diff --git a/_clang-format b/.clang-format similarity index 100% rename from _clang-format rename to .clang-format diff --git a/.clang-tidy b/.clang-tidy index b5f9d549338..b242b140753 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,3 +1,7 @@ +HeaderFilterRegex: '/(examples|include|src|tests)/.*\.hpp' + +FormatStyle: file + Checks: > -*, readability-identifier-naming, diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 0cb8002152e..c081eb72997 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2019-2024 Intel Corporation +# Copyright 2019-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,44 +15,48 @@ #=============================================================================== # Default -* @oneapi-src/onednn-arch @intel-innersource/dnn-arch +* @uxlfoundation/onednn-arch # Github automation -/.github/ @oneapi-src/onednn-devops +/.github/ @uxlfoundation/onednn-devops # CPU Engine -/src/cpu/aarch64/ @oneapi-src/onednn-cpu-aarch64 @intel-innersource/dnn-arch -/src/cpu/x64/ @oneapi-src/onednn-cpu-x64 @intel-innersource/dnn-cpu -/src/cpu/rnn/ @oneapi-src/onednn-cpu-x64 @intel-innersource/dnn-cpu +/src/cpu/aarch64/ @uxlfoundation/onednn-cpu-aarch64 +/src/cpu/x64/ @uxlfoundation/onednn-cpu-x64 +/src/cpu/rnn/ @uxlfoundation/onednn-cpu-x64 # GPU Engine -/src/gpu/amd/ @oneapi-src/onednn-gpu-amd @intel-innersource/dnn-arch -/src/gpu/intel/ @oneapi-src/onednn-gpu-intel @intel-innersource/dnn-gpu -/src/gpu/nvidia/ @oneapi-src/onednn-gpu-nvidia @intel-innersource/dnn-arch -/src/gpu/generic/ @oneapi-src/onednn-arch @intel-innersource/dnn-arch @intel-innersource/dnn-gpu -/src/gpu/generic/sycl/ @oneapi-src/onednn-gpu-generic @intel-innersource/dnn-arch @intel-innersource/dnn-gpu +/src/gpu/amd/ @uxlfoundation/onednn-gpu-amd +/src/gpu/intel/ @uxlfoundation/onednn-gpu-intel +/src/gpu/nvidia/ @uxlfoundation/onednn-gpu-nvidia +/src/gpu/generic/ @uxlfoundation/onednn-arch +/src/gpu/generic/sycl/ @uxlfoundation/onednn-gpu-generic # Tests -/tests/benchdnn/inputs/ @oneapi-src/onednn-maintain @intel-innersource/dnn-arch @intel-innersource/dnn-cpu @intel-innersource/dnn-gpu -/tests/benchdnn/graph/ @oneapi-src/onednn-graph @oneapi-src/onednn-arch @intel-innersource/dnn-graph @intel-innersource/dnn-arch -/tests/benchdnn/inputs/graph/ @oneapi-src/onednn-graph @oneapi-src/onednn-arch @intel-innersource/dnn-graph @intel-innersource/dnn-arch -/tests/gtests/graph/ @oneapi-src/onednn-graph @intel-innersource/dnn-graph +/tests/benchdnn/inputs/ @uxlfoundation/onednn-maintain +/tests/benchdnn/graph/ @uxlfoundation/onednn-graph @uxlfoundation/onednn-arch +/tests/benchdnn/inputs/graph/ @uxlfoundation/onednn-graph @uxlfoundation/onednn-arch +/tests/gtests/graph/ @uxlfoundation/onednn-graph # Graph API -/src/graph/ @oneapi-src/onednn-graph @intel-innersource/dnn-graph - -# Graph compiler -/src/graph/backend/graph_compiler/ @intel-innersource/dnn-compiler -/tests/gtests/graph/unit/backend/graph_compiler/ @intel-innersource/dnn-compiler +/src/graph/ @uxlfoundation/onednn-graph # Documentation -*.md @oneapi-src/onednn-doc @oneapi-src/onednn-arch @intel-innersource/dnn-doc @intel-innersource/dnn-arch -/doc/ @oneapi-src/onednn-doc @oneapi-src/onednn-arch @intel-innersource/dnn-doc @intel-innersource/dnn-arch +*.md @uxlfoundation/onednn-doc @uxlfoundation/onednn-arch +/doc/ @uxlfoundation/onednn-doc 
@uxlfoundation/onednn-arch
+
+# Third party components
+/third_party/ @uxlfoundation/onednn-arch
+/third_party/level_zero/ @uxlfoundation/onednn-gpu-intel
+/third_party/mdapi/ @uxlfoundation/onednn-gpu-intel
+/third_party/ngen/ @uxlfoundation/onednn-gpu-intel
+/third_party/xbyak/ @uxlfoundation/onednn-cpu-x64
+/third_party/xbyak_aarch64/ @uxlfoundation/onednn-cpu-aarch64
 
 # Governance and process
-/.github/CODEOWNERS @oneapi-src/onednn-maintain
-/SECURITY.md @oneapi-src/onednn-maintain
-/MAINTAINERS.md @oneapi-src/onednn-maintain
-/CONTRIBUTING.md @oneapi-src/onednn-maintain
-/CODING_STANDARDS.md @oneapi-src/onednn-maintain
-/CODE_OF_CONDUCT.md @oneapi-src/onednn-maintain
+/.github/CODEOWNERS @uxlfoundation/onednn-maintain
+/SECURITY.md @uxlfoundation/onednn-maintain
+/MAINTAINERS.md @uxlfoundation/onednn-maintain
+/CONTRIBUTING.md @uxlfoundation/onednn-maintain
+/CODING_STANDARDS.md @uxlfoundation/onednn-maintain
+/CODE_OF_CONDUCT.md @uxlfoundation/onednn-maintain
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index 2770b4545dc..6141ad2ca74 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -12,7 +12,7 @@ factors are considered important to reproduce an issue.
 
 # Version
 Report oneDNN version and githash. Version information is printed to stdout
-in [verbose mode](https://oneapi-src.github.io/oneDNN/dev_guide_verbose.html).
+in [verbose mode](https://uxlfoundation.github.io/oneDNN/dev_guide_verbose.html).
 
 # Environment
 oneDNN includes hardware-specific optimizations and may behave
@@ -28,10 +28,10 @@ the following information to help reproduce the issue:
 
 # Steps to reproduce
 Please check that the issue is reproducible with the latest revision on
-master. Include all the steps to reproduce the issue.
+main. Include all the steps to reproduce the issue.
 
-You can use [verbose mode](https://oneapi-src.github.io/oneDNN/dev_guide_verbose.html)
-and [benchdnn](https://github.com/oneapi-src/oneDNN/tree/master/tests/benchdnn)
+You can use [verbose mode](https://uxlfoundation.github.io/oneDNN/dev_guide_verbose.html)
+and [benchdnn](https://github.com/uxlfoundation/oneDNN/tree/main/tests/benchdnn)
 to validate correctness of all primitives the library supports. If this does
 not work a short C/C++ program or modified unit tests demonstrating the issue
 will greatly help with the investigation.
@@ -40,7 +40,7 @@ will greatly help with the investigation.
 Document behavior you observe. For performance defects, like performance
 regressions or a function being slow, provide a log including output generated
 by your application in
-[verbose mode](https://oneapi-src.github.io/oneDNN/dev_guide_verbose.html).
+[verbose mode](https://uxlfoundation.github.io/oneDNN/dev_guide_verbose.html).
 
 # Expected behavior
 Document behavior you expect.
\ No newline at end of file
diff --git a/.github/automation/.azure-pipeline.yml b/.github/automation/.azure-pipeline.yml
deleted file mode 100644
index a6ddac46fe9..00000000000
--- a/.github/automation/.azure-pipeline.yml
+++ /dev/null
@@ -1,132 +0,0 @@
-#! /bin/bash
-
-#===============================================================================
-# Copyright 2019-2024 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#=============================================================================== - -trigger: -- main -- rls-* - -jobs: - - job: 'ClangFormat' - pool: - vmImage: 'ubuntu-20.04' - steps: - - script: | - .github/automation/env/clang.sh 11 - displayName: 'init' - - script: | - .github/automation/clang-format.sh - displayName: 'ClangFormat_Check' - failOnStderr: true - - job: 'Ubuntu20' - timeoutInMinutes: 120 - pool: - vmImage: 'ubuntu-20.04' - strategy: - matrix: - clang: - CC: clang - CXX: clang++ - gcc: - CC: gcc - CXX: g++ - steps: - - script: | - if [ "$(CC)" == "clang" ]; then - .github/automation/env/clang.sh 9 - fi - displayName: "Init_Env" - - script: | - .github/automation/build.sh --threading omp --mode Release --source-dir $(pwd) --build-dir $(pwd)/build - displayName: 'build' - - script: | - .github/automation/test.sh --build-dir $(pwd)/build --report-dir $(pwd)/report - displayName: 'test' - failOnStderr: true - - job: 'Ubuntu22' - timeoutInMinutes: 120 - pool: - vmImage: 'ubuntu-22.04' - strategy: - matrix: - clang: - CC: clang - CXX: clang++ - gcc: - CC: gcc - CXX: g++ - steps: - - script: | - if [ "$(CC)" == "clang" ]; then - .github/automation/env/clang.sh 15 - fi - displayName: "Init_Env" - - script: | - .github/automation/build.sh --threading omp --mode Release --source-dir $(pwd) --build-dir $(pwd)/build - displayName: 'build' - - script: | - .github/automation/test.sh --build-dir $(pwd)/build --report-dir $(pwd)/report - displayName: 'test' - failOnStderr: true - - job: 'macOS12' - timeoutInMinutes: 120 - pool: - vmImage: 'macOS-12' - steps: - - script: | - .github/automation/build.sh --threading omp --mode Release --source-dir $(pwd) --build-dir $(pwd)/build - displayName: 'build' - - script: | - .github/automation/test.sh --build-dir $(pwd)/build --report-dir $(pwd)/report - displayName: 'test' - failOnStderr: true - - job: 'macOS13' - timeoutInMinutes: 120 - pool: - vmImage: 'macOS-13' - steps: - - script: | - .github/automation/build.sh --threading omp --mode Release --source-dir $(pwd) --build-dir $(pwd)/build - displayName: 'build' - - script: | - .github/automation/test.sh --build-dir $(pwd)/build --report-dir $(pwd)/report - displayName: 'test' - failOnStderr: true - - job: 'Windows_Server_2022' - timeoutInMinutes: 120 - pool: - vmImage: 'windows-2022' - steps: - - script: | - .github\automation\build.bat /THREADING omp /MODE Release /VSVERSION vs2022 /SOURCEDIR %CD% /BUILDDIR %CD%\build - displayName: 'build' - - script: | - .github\automation\test.bat /BUILDDIR %CD%\build /MODE Release /REPORTDIR %CD%\report - displayName: 'test' - failOnStderr: true - - job: 'Windows_Server_2019' - timeoutInMinutes: 120 - pool: - vmImage: 'windows-2019' - steps: - - script: | - .github\automation\build.bat /THREADING omp /MODE Release /VSVERSION vs2019 /SOURCEDIR %CD% /BUILDDIR %CD%\build - displayName: 'build' - - script: | - .github\automation\test.bat /BUILDDIR %CD%\build /MODE Release /REPORTDIR %CD%\report - displayName: 'test' - failOnStderr: true diff --git a/.github/automation/aarch64/build.sh b/.github/automation/aarch64/build.sh new file mode 100755 index 
00000000000..a3d8d81f26b --- /dev/null +++ b/.github/automation/aarch64/build.sh @@ -0,0 +1,54 @@ +#! /bin/bash + +# ******************************************************************************* +# Copyright 2024 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ******************************************************************************* + +# Build oneDNN for aarch64. + +set -o errexit -o pipefail -o noclobber + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" + +# Defines MP, CC, CXX and OS. +source ${SCRIPT_DIR}/common.sh + +export ACL_ROOT_DIR=${ACL_ROOT_DIR:-"${PWD}/ComputeLibrary"} + +CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-"Release"} +ONEDNN_TEST_SET=${ONEDNN_TEST_SET:-"SMOKE"} +ONEDNN_BUILD_GRAPH=${ONEDNN_BUILD_GRAPH:-"ON"} + +if [[ "$ONEDNN_ACTION" == "configure" ]]; then + set -x + cmake \ + -Bbuild -S. \ + -DDNNL_USE_ACL=ON \ + -DONEDNN_BUILD_GRAPH=$ONEDNN_BUILD_GRAPH \ + -DDNNL_CPU_RUNTIME=$ONEDNN_THREADING \ + -DONEDNN_WERROR=ON \ + -DDNNL_BUILD_FOR_CI=ON \ + -DONEDNN_TEST_SET=$ONEDNN_TEST_SET \ + -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE + set +x +elif [[ "$ONEDNN_ACTION" == "build" ]]; then + set -x + cmake --build build + set +x +else + echo "Unknown action: $ONEDNN_ACTION" + exit 1 +fi diff --git a/.github/automation/aarch64/build_acl.sh b/.github/automation/aarch64/build_acl.sh new file mode 100755 index 00000000000..53cc2a825fc --- /dev/null +++ b/.github/automation/aarch64/build_acl.sh @@ -0,0 +1,81 @@ +#! /bin/bash + +# ******************************************************************************* +# Copyright 2020-2025 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ******************************************************************************* + +# Build ACL from github. + +set -o errexit -o pipefail -o noclobber + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" + +# Defines MP, CC, CXX and OS. 
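+# As an illustrative sketch only (the authoritative definitions live in
+# common.sh, added later in this patch), the sourced file provides roughly:
+#   export OS=$(uname)            # "Linux" or "Darwin"
+#   export MP="-j$(nproc)"        # parallelism flag forwarded to scons/make
+#   export CC=gcc-${GCC_VERSION}  # when BUILD_TOOLSET=gcc
+#   export CXX=g++-${GCC_VERSION}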
+source ${SCRIPT_DIR}/common.sh
+
+ACL_BUILD_TYPE=${ACL_BUILD_TYPE:-"Release"}
+ACL_ROOT_DIR=${ACL_ROOT_DIR:-"${PWD}/ComputeLibrary"}
+ACL_REPO="https://github.com/ARM-software/ComputeLibrary.git"
+
+if [[ "$OS" == "Linux" ]]; then
+    ACL_MULTI_ISA_SUPPORT=1
+    if [[ "$ACL_THREADING" == "OMP" ]]; then
+        ACL_OPENMP=1
+    elif [[ "$ACL_THREADING" == "SEQ" ]]; then
+        ACL_OPENMP=0
+    fi
+    ACL_OS="linux"
+elif [[ "$OS" == "Darwin" ]]; then
+    ACL_MULTI_ISA_SUPPORT=0
+    ACL_OPENMP=0
+    ACL_OS="macos"
+else
+    echo "Unknown OS: $OS"
+    exit 1
+fi
+
+if [[ "$ACL_BUILD_TYPE" == "Release" ]]; then
+    ACL_DEBUG=0
+elif [[ "$ACL_BUILD_TYPE" == "Debug" ]]; then
+    ACL_DEBUG=1
+else
+    echo "Unknown build config: $ACL_BUILD_TYPE"
+    exit 1
+fi
+
+if [[ "$ACL_ACTION" == "clone" ]]; then
+    set -x
+    git clone --branch $ACL_VERSION --depth 1 $ACL_REPO $ACL_ROOT_DIR
+    set +x
+elif [[ "$ACL_ACTION" == "build" ]]; then
+    set -x
+    cd $ACL_ROOT_DIR
+    scons $MP Werror=0 debug=$ACL_DEBUG neon=1 opencl=0 embed_kernels=0 \
+        os=$ACL_OS arch=armv8.2-a build=native multi_isa=$ACL_MULTI_ISA_SUPPORT \
+        fixed_format_kernels=1 cppthreads=0 openmp=$ACL_OPENMP examples=0 \
+        validation_tests=0
+    set +x
+else
+    echo "Unknown action: $ACL_ACTION"
+    exit 1
+fi
diff --git a/.github/automation/aarch64/ci.json b/.github/automation/aarch64/ci.json
new file mode 100644
index 00000000000..bdd44eaed7c
--- /dev/null
+++ b/.github/automation/aarch64/ci.json
@@ -0,0 +1,8 @@
+{
+  "dependencies": {
+    "acl": "v25.02",
+    "gcc": "13",
+    "clang": "17",
+    "onednn-base": "v3.7"
+  }
+}
diff --git a/.github/automation/aarch64/common.sh b/.github/automation/aarch64/common.sh
new file mode 100644
index 00000000000..cfb483eb468
--- /dev/null
+++ b/.github/automation/aarch64/common.sh
@@ -0,0 +1,46 @@
+#! /bin/bash
+
+# *******************************************************************************
+# Copyright 2024-2025 Arm Limited and affiliates.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *******************************************************************************
+
+# Common variables for aarch64 ci. Exports:
+# CC, CXX, OS, MP
+
+set -o errexit -o pipefail -o noclobber
+
+export OS=$(uname)
+
+# Num threads on system.
+if [[ "$OS" == "Darwin" ]]; then
+    export MP="-j$(sysctl -n hw.ncpu)"
+elif [[ "$OS" == "Linux" ]]; then
+    export MP="-j$(nproc)"
+fi
+
+if [[ "$BUILD_TOOLSET" == "gcc" ]]; then
+    export CC=gcc-${GCC_VERSION}
+    export CXX=g++-${GCC_VERSION}
+elif [[ "$BUILD_TOOLSET" == "clang" ]]; then
+    export CC=clang
+    export CXX=clang++
+fi
+
+# Print every exported variable.
+echo "OS: $OS"
+echo "MP: $MP"
+echo "Toolset: $BUILD_TOOLSET"
+echo "CC: $CC"
+echo "CXX: $CXX"
diff --git a/.github/automation/aarch64/get_acl.sh b/.github/automation/aarch64/get_acl.sh
new file mode 100755
index 00000000000..7745b9b9764
--- /dev/null
+++ b/.github/automation/aarch64/get_acl.sh
@@ -0,0 +1,95 @@
+#! /bin/bash
+
+# *******************************************************************************
+# Copyright 2024-2025 Arm Limited and affiliates.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *******************************************************************************
+
+set -o errexit -o pipefail -o noclobber
+
+WORKSPACE=${GITHUB_WORKSPACE:-$(pwd)}
+echo "github workspace $GITHUB_WORKSPACE"
+
+os_type=$(uname)
+
+ACL_WITH_ASSERTS=${ACL_WITH_ASSERTS:-0}
+ACL_VERSION=${ACL_VERSION:-v24.08.1}
+
+if [[ "$os_type" == "Linux" ]]; then
+    echo "This machine is running Linux"
+    ARCHIVE="arm_compute-${ACL_VERSION}-linux-aarch64-cpu-bin.tar.gz"
+elif [[ "$os_type" == "Darwin" ]]; then
+    echo "This machine is running macOS"
+    ARCHIVE="arm_compute-${ACL_VERSION}-macos-aarch64-cpu-bin.tar.gz"
+else
+    echo "Unknown OS: $os_type"
+    exit 1
+fi
+
+# Set version and root directory
+export ACL_ROOT_DIR="${WORKSPACE}/ComputeLibrary"
+
+echo "ACL_VERSION: ${ACL_VERSION}"
+echo "ACL_ROOT_DIR: ${ACL_ROOT_DIR}"
+echo "ACL_WITH_ASSERTS: ${ACL_WITH_ASSERTS}"
+
+# Download the specified Compute Library version
+if [[ ! -f $ARCHIVE ]]; then
+    ACL_URL="https://github.com/ARM-software/ComputeLibrary/releases/download/${ACL_VERSION}/${ARCHIVE}"
+    echo "Downloading ACL from ${ACL_URL}"
+    wget ${ACL_URL}
+else
+    echo "$ARCHIVE already exists, skipping download."
+fi
+
+# Function to find the appropriate lib directory
+find_acl_lib_dir() {
+    local dirs=("$ACL_ROOT_DIR"/lib/*/)
+    local selected_dir=""
+
+    # Select directory based on build type
+    for dir in "${dirs[@]}"; do
+        if [[ $ACL_WITH_ASSERTS == 1 ]]; then
+            [[ "$dir" == *"-asserts/" ]] && selected_dir="$dir" && break
+        else
+            [[ "$dir" != *"-asserts/" ]] && selected_dir="$dir" && break
+        fi
+    done
+
+    # Return result or exit if not found
+    if [[ -z "$selected_dir" ]]; then
+        echo "No matching ACL lib directory found." >&2
+        exit 1
+    else
+        echo "$selected_dir"
+    fi
+}
+
+# Extract the tarball if not already extracted
+if [[ ! -d $ACL_ROOT_DIR ]]; then
+    mkdir -p $ACL_ROOT_DIR
+    tar -xzvf "${ARCHIVE}" -C $ACL_ROOT_DIR --strip-components=1 >/dev/null 2>&1
+else
+    echo "$ACL_ROOT_DIR directory already exists, skipping extraction."
+fi
+
+# Find the ACL library directory
+ACL_LIB_DIR=$(find_acl_lib_dir)
+echo "Using ACL lib from ${ACL_LIB_DIR}"
+echo "cp contents from ${ACL_LIB_DIR} to ${ACL_ROOT_DIR}/lib"
+cp -rf "$ACL_LIB_DIR"* "$ACL_ROOT_DIR/lib/"
+
+echo "${ACL_VERSION}" >"${ACL_ROOT_DIR}/arm_compute/arm_compute_version.embed"
diff --git a/.github/automation/aarch64/skipped-tests.sh b/.github/automation/aarch64/skipped-tests.sh
new file mode 100755
index 00000000000..01f01923e76
--- /dev/null
+++ b/.github/automation/aarch64/skipped-tests.sh
@@ -0,0 +1,81 @@
+#!/usr/bin/env bash
+
+# *******************************************************************************
+# Copyright 2025 Arm Limited and affiliates.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *******************************************************************************
+
+# Emit the regex of tests to skip when testing oneDNN on AArch64.
+
+set -eo pipefail
+
+OS=${OS:-"Linux"}
+
+# The graph API is not officially supported on AArch64 for now.
+SKIPPED_GRAPH_TEST_FAILURES="test_graph_unit_dnnl_sdp_decomp_cpu"
+SKIPPED_GRAPH_TEST_FAILURES+="|test_graph_unit_dnnl_mqa_decomp_cpu"
+
+# Described in issue: https://github.com/uxlfoundation/oneDNN/issues/2175
+SKIPPED_TEST_FAILURES="test_benchdnn_modeC_matmul_multidims_cpu"
+
+# We currently have some OS- and config-specific test failures.
+if [[ "$OS" == "Linux" ]]; then
+    if [[ "$CMAKE_BUILD_TYPE" == "Debug" ]]; then
+        # test_matmul is time-consuming, so we only run it in Release mode to save time.
+        SKIPPED_TEST_FAILURES+="|test_matmul"
+    fi
+    SKIPPED_TEST_FAILURES+="|test_benchdnn_modeC_binary_ci_cpu"
+    SKIPPED_TEST_FAILURES+="|test_benchdnn_modeC_binary_different_dt_ci_cpu"
+
+    SKIPPED_GRAPH_TEST_FAILURES+="|test_benchdnn_modeC_graph_ci_cpu"
+    SKIPPED_GRAPH_TEST_FAILURES+="|cpu-graph-gqa-cpp"
+    SKIPPED_GRAPH_TEST_FAILURES+="|cpu-graph-mqa-cpp"
+    SKIPPED_GRAPH_TEST_FAILURES+="|cpu-graph-sdpa-cpp"
+    SKIPPED_GRAPH_TEST_FAILURES+="|cpu-graph-sdpa-stacked-qkv-cpp"
+    SKIPPED_GRAPH_TEST_FAILURES+="|test_graph_unit_dnnl_large_partition_cpu"
+
+    # Tests failing in the OpenVINO Toolkit oneDNN fork
+    SKIPPED_TEST_FAILURES+="|test_batch_normalization"
+    SKIPPED_TEST_FAILURES+="|test_eltwise"
+    SKIPPED_TEST_FAILURES+="|test_iface_attr"
+    SKIPPED_TEST_FAILURES+="|test_lrn"
+    SKIPPED_TEST_FAILURES+="|test_pooling_forward"
+    SKIPPED_TEST_FAILURES+="|test_reduction"
+    SKIPPED_TEST_FAILURES+="|test_api"
+    SKIPPED_TEST_FAILURES+="|test_benchdnn_modeC_binary_smoke_cpu"
+    SKIPPED_TEST_FAILURES+="|test_benchdnn_modeC_bnorm_smoke_cpu"
+    SKIPPED_TEST_FAILURES+="|test_benchdnn_modeC_conv_smoke_cpu"
+    SKIPPED_TEST_FAILURES+="|test_benchdnn_modeC_deconv_smoke_cpu"
+    SKIPPED_TEST_FAILURES+="|test_benchdnn_modeC_eltwise_smoke_cpu"
+    SKIPPED_TEST_FAILURES+="|test_benchdnn_modeC_lrn_smoke_cpu"
+    SKIPPED_TEST_FAILURES+="|test_benchdnn_modeC_pool_smoke_cpu"
+    SKIPPED_TEST_FAILURES+="|test_benchdnn_modeC_reduction_smoke_cpu"
+fi
+
+# Nightly failures
+SKIPPED_NIGHTLY_TEST_FAILURES="test_benchdnn_modeC_bnorm_all_blocked_cpu"
+SKIPPED_NIGHTLY_TEST_FAILURES+="|test_benchdnn_modeC_bnorm_regressions_cpu"
+SKIPPED_NIGHTLY_TEST_FAILURES+="|test_benchdnn_modeC_conv_int8_cpu"
+SKIPPED_NIGHTLY_TEST_FAILURES+="|test_benchdnn_modeC_graph_fusions_cpu"
+SKIPPED_NIGHTLY_TEST_FAILURES+="|test_benchdnn_modeC_matmul_sparse_gpu_cpu"
+SKIPPED_NIGHTLY_TEST_FAILURES+="|test_benchdnn_modeC_reorder_all_cpu"
+
+# * c7g failures. TODO: scope these to c7g only. Better yet, fix them.
+SKIPPED_NIGHTLY_TEST_FAILURES+="|test_benchdnn_modeC_binary_all_cpu"
+SKIPPED_NIGHTLY_TEST_FAILURES+="|test_benchdnn_modeC_graph_int8_cpu"
+
+SKIPPED_TEST_FAILURES+="|${SKIPPED_GRAPH_TEST_FAILURES}|${SKIPPED_NIGHTLY_TEST_FAILURES}"
+
+printf "%s" "${SKIPPED_TEST_FAILURES}"
diff --git a/.github/automation/aarch64/test.sh b/.github/automation/aarch64/test.sh
new file mode 100755
index 00000000000..d26fccb9aea
--- /dev/null
+++ b/.github/automation/aarch64/test.sh
@@ -0,0 +1,38 @@
+#!/usr/bin/env bash
+
+# *******************************************************************************
+# Copyright 2024-2025 Arm Limited and affiliates.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *******************************************************************************
+
+# Test oneDNN for aarch64.
+
+set -o errexit -o pipefail -o noclobber
+
+SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
+
+export CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-"Release"}
+
+# Defines MP, CC, CXX and OS.
+source ${SCRIPT_DIR}/common.sh
+
+# Sequential (probably macOS) builds should use num proc parallelism.
+if [[ "$ONEDNN_THREADING" == "SEQ" ]]; then
+    export CTEST_PARALLEL_LEVEL=""
+fi
+
+set -x
+ctest --no-tests=error --output-on-failure -E $("${SCRIPT_DIR}"/skipped-tests.sh)
+set +x
diff --git a/.github/automation/build.sh b/.github/automation/build.sh
deleted file mode 100755
index 5a684f6353f..00000000000
--- a/.github/automation/build.sh
+++ /dev/null
@@ -1,108 +0,0 @@
-#! /bin/bash
-
-#===============================================================================
-# Copyright 2019-2023 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#===============================================================================
-
-
-while [[ $# -gt 0 ]]; do
-    key="$1"
-
-    case $key in
-        --threading)
-        BUILD_THREADING="$2"
-        ;;
-        --mode)
-        BUILD_MODE="$2"
-        ;;
-        --source-dir)
-        SORUCE_DIR="$2"
-        ;;
-        --acl-dir)
-        ACL_DIR="$2"
-        ;;
-        --build-dir)
-        BUILD_DIR="$2"
-        ;;
-        --cmake-opt)
-        CMAKE_OPT="$2"
-        ;;
-        *)
-        echo "Unknown option: $1"
-        exit 1
-        ;;
-    esac
-    shift
-    shift
-done
-
-CMAKE_OPTIONS="-DCMAKE_BUILD_TYPE=${BUILD_MODE} -DDNNL_BUILD_FOR_CI=ON -DDNNL_WERROR=ON ${CMAKE_OPT}"
-
-CPU_RUNTIME="NONE"
-GPU_RUNTIME="NONE"
-
-if [ "${BUILD_THREADING}" == "tbb" ]; then
-    CPU_RUNTIME="TBB"
-    echo "Info: Setting DNNL_CPU_RUNTIME to TBB..."
-elif [ "${BUILD_THREADING}" == "omp" ]; then
-    echo "Info: Setting DNNL_CPU_RUNTIME to OMP..."
- CPU_RUNTIME="OMP" -elif [ "${BUILD_THREADING}" == "ocl" ]; then - echo "Info: Setting DNNL_CPU_RUNTIME to OMP..." - echo "Info: Setting DNNL_GPU_RUNTIME to OCL..." - CPU_RUNTIME="OMP" - GPU_RUNTIME="OCL" -else - echo "Error unknown threading: ${BUILD_THREADING}" - exit 1 -fi - -CMAKE_OPTIONS="${CMAKE_OPTIONS} - -DDNNL_CPU_RUNTIME=${CPU_RUNTIME} - -DDNNL_GPU_RUNTIME=${GPU_RUNTIME} - -DDNNL_TEST_SET=SMOKE - " - -# Enable Compute Library backend if a location for the built library is given -# NOTE: only for AArch64 builds. -if [ ! -z ${ACL_DIR} ]; then - export ACL_ROOT_DIR=$ACL_DIR - CMAKE_OPTIONS="${CMAKE_OPTIONS} -DDNNL_AARCH64_USE_ACL=ON" - echo "Info: Building with Arm Compute Library backend for Aarch64..." -fi - -if [ "$(uname)" == "Linux" ]; then - MAKE_OP="-j$(grep -c processor /proc/cpuinfo)" -else - MAKE_OP="-j$(sysctl -n hw.physicalcpu)" -fi - -cd "${SORUCE_DIR}" -echo "Calling CMake with otions: ${CMAKE_OPTIONS}" -cmake . -B${BUILD_DIR} ${CMAKE_OPTIONS} -err=$? -if [ "$err" != 0 ]; then - if [ -e "${BUILD_DIR}/CMakeFiles/CMakeOutput.log" ]; then - echo "CMakeOutput.log:" - cat ${BUILD_DIR}/CMakeFiles/CMakeOutput.log - fi - if [ -e "${BUILD_DIR}/CMakeFiles/CMakeError.log" ]; then - echo "CMakeError.log:" - cat ${BUILD_DIR}/CMakeFiles/CMakeError.log - fi - exit $err -fi -cd ${BUILD_DIR} && make -k ${MAKE_OP} -exit $? diff --git a/.github/automation/build_acl.sh b/.github/automation/build_acl.sh deleted file mode 100755 index 41c6b0b4a2e..00000000000 --- a/.github/automation/build_acl.sh +++ /dev/null @@ -1,61 +0,0 @@ -#! /bin/bash - -# ******************************************************************************* -# Copyright 2020-2023 Arm Limited and affiliates. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ******************************************************************************* - -# Compute Library build defaults -ACL_VERSION="v23.11" -ACL_DIR="${PWD}/ComputeLibrary" -ACL_ARCH="armv8a" -ACL_MULTI_ISA_SUPPORT=0 - -while [[ $# -gt 0 ]]; do - case $1 in - --version) - ACL_VERSION="v$2" - shift - ;; - --arch) - ACL_ARCH="$2" - shift - ;; - --multi_isa) - ACL_MULTI_ISA_SUPPORT=1 - ;; - --root-dir) - ACL_DIR="$2" - shift - ;; - *) - echo "Unknown option: $1" - exit 1 - ;; - esac - shift -done - -readonly ACL_REPO="https://github.com/ARM-software/ComputeLibrary.git" -MAKE_NP="-j$(grep -c processor /proc/cpuinfo)" - -git clone --branch $ACL_VERSION --depth 1 $ACL_REPO $ACL_DIR -cd $ACL_DIR - -scons --silent $MAKE_NP Werror=0 debug=0 neon=1 opencl=0 embed_kernels=0 \ - os=linux arch=$ACL_ARCH build=native multi_isa=$ACL_MULTI_ISA_SUPPORT \ - fixed_format_kernels=1 - -exit $? diff --git a/.github/automation/clang-format.sh b/.github/automation/clang-format.sh index b38c3877466..28c99331109 100755 --- a/.github/automation/clang-format.sh +++ b/.github/automation/clang-format.sh @@ -16,20 +16,31 @@ # limitations under the License. 
 #===============================================================================
 
-echo "Using clang-format version: $(clang-format --version)"
+CLANG_FORMAT=clang-format-11
+
+echo "Checking ${CLANG_FORMAT}"
+if ! ${CLANG_FORMAT} --version; then
+    echo "${CLANG_FORMAT} is not available or not working correctly."
+    exit 1
+fi
+
 echo "Starting format check..."
-for filename in $(find "$(pwd)" -type f | grep -P ".*\.(c|cpp|h|hpp|cl)$"); do clang-format -style=file -i $filename; done
+for filename in $(find "$(pwd)" -type f | grep -P ".*\.(c|cpp|h|hpp|cl)$"); do ${CLANG_FORMAT} -style=file -i $filename; done
 
 RETURN_CODE=0
-echo $(git status) | grep "nothing to commit" > /dev/null
+echo $(git status) | grep "nothing to commit" > /dev/null
 if [ $? -eq 1 ]; then
-    echo "Clang-format check FAILED! Found not formatted files!"
-    echo "$(git status)"
+    echo "Clang-format check FAILED! The following files must be formatted with ${CLANG_FORMAT}:"
+    echo "$(git diff --name-only)"
+    echo
+    echo "Changes required to pass this check:"
+    echo "$(git diff)"
+    echo
     RETURN_CODE=3
 else
-    echo "Clang-format check PASSED! Not formatted files not found..."
+    echo "Clang-format check PASSED!"
 fi
-exit ${RETURN_CODE}
\ No newline at end of file
+exit ${RETURN_CODE}
diff --git a/.github/automation/commit-msg-check.py b/.github/automation/commit-msg-check.py
new file mode 100755
index 00000000000..aa6ca2cd5f6
--- /dev/null
+++ b/.github/automation/commit-msg-check.py
@@ -0,0 +1,88 @@
+#!/usr/bin/python3
+
+# *******************************************************************************
+# Copyright 2024 Arm Limited and affiliates.
+# Copyright 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *******************************************************************************
+
+import argparse
+import subprocess
+import re
+import sys
+
+# Ensure the scope ends in a colon and that same-level scopes are
+# comma-delimited.
+# Current implementation only checks the first level scope as ':' can be used
+# in the commit description (ex: TBB::tbb or bf16:bf16).
+# TODO: Limit scopes to an acceptable list of tags.
+def __scopeCheck(msg: str):
+    status = "Message scope:"
+
+    if not re.match('^[a-z0-9_]+(, [a-z0-9_]+)*: ', msg):
+        print(f"{status} FAILED: Commit message must follow the format "
+              "<scope>:[ <scope>:] <description>")
+        return False
+
+    print(f"{status} OK")
+    return True
+
+# Ensure a character limit for the first line.
+def __numCharacterCheck(msg: str):
+    status = "Message length:"
+    if len(msg) <= 72:
+        print(f"{status} OK")
+        return True
+    else:
+        # Fixup commits usually include the full name of the commit they are
+        # fixing, which adds 6 more symbols to the message. Let them in.
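+        # For example (hypothetical summary, for illustration only):
+        # "fixup: cpu: aarch64: fix convolution segfault" repeats the full
+        # summary of the commit being fixed, so it may exceed the limit.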
+        if re.match('^fixup: ', msg):
+            print(f"{status} Fixup message, OK")
+            return True
+        else:
+            print(f"{status} FAILED: Commit message summary must not "
+                  "exceed 72 characters.")
+            return False
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("head", help="Head commit of PR branch")
+    parser.add_argument("base", help="Base commit of PR branch")
+    args = parser.parse_args()
+    base: str = args.base
+    head: str = args.head
+
+    commit_range = base + ".." + head
+    messages = subprocess.run(["git", "rev-list", "--format=oneline",
+                               commit_range], capture_output=True, text=True).stdout
+
+    is_ok = True
+    for i in messages.splitlines():
+        print(i)
+        commit_msg = i.split(' ', 1)[1]
+        result = __numCharacterCheck(commit_msg)
+        is_ok = is_ok and result
+        result = __scopeCheck(commit_msg)
+        is_ok = is_ok and result
+
+    if is_ok:
+        print("All commit messages are formatted correctly.")
+    else:
+        print("Some commit message checks failed. Please align commit messages "
+              "with the Contributing Guidelines and update the PR.")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.github/automation/env/qemu.sh b/.github/automation/env/qemu.sh
deleted file mode 100755
index 71dc2c636f6..00000000000
--- a/.github/automation/env/qemu.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-#! /bin/bash
-
-#===============================================================================
-# Copyright 2020 FUJITSU LIMITED
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#===============================================================================
-
-# Download, build and install QEMU
-wget https://download.qemu.org/qemu-5.0.0.tar.xz
-tar xJf qemu-5.0.0.tar.xz > /dev/null
-cd qemu-5.0.0
-./configure --target-list=aarch64-linux-user > /dev/null
-make > /dev/null
-make install > /dev/null
diff --git a/.github/automation/performance/bench_nightly_performance.sh b/.github/automation/performance/bench_nightly_performance.sh
new file mode 100644
index 00000000000..9b66f0a7c01
--- /dev/null
+++ b/.github/automation/performance/bench_nightly_performance.sh
@@ -0,0 +1,53 @@
+#! /bin/bash
+
+# *******************************************************************************
+# Copyright 2025 Arm Limited and affiliates.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************* + +# Usage: bash bench_nightly_performance.sh {baseline_benchdnn_executable} {benchdnn_executable} {baseline_results_file} {new_results_file} + +IFS=$'\n' # Prevents shuffling from using spaces as delimiters + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +TESTS=( + "$1 --matmul --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/matmul_nightly >> $3" + "$2 --matmul --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/matmul_nightly >> $4" + "$1 --conv --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/conv_nightly >> $3" + "$2 --conv --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/conv_nightly >> $4" + "$1 --eltwise --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/eltwise_nightly >> $3" + "$2 --eltwise --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/eltwise_nightly >> $4" + "$1 --reorder --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/reorder_nightly >> $3" + "$2 --reorder --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/reorder_nightly >> $4" + ) + +N=5 + +for i in $( seq $N ) +do + echo "Testing loop ${i} / ${N}..." + + TESTS=( $(shuf -e "${TESTS[@]}") ) + + for test in "${TESTS[@]}" + do + echo "Starting ${test}" + SECONDS=0 + eval $test + duration=$SECONDS + echo "Completed in $((duration / 60)):$((duration % 60))" + done +done diff --git a/.github/automation/performance/bench_pr_performance.sh b/.github/automation/performance/bench_pr_performance.sh new file mode 100755 index 00000000000..48f0eebd643 --- /dev/null +++ b/.github/automation/performance/bench_pr_performance.sh @@ -0,0 +1,52 @@ +#! /bin/bash + +# ******************************************************************************* +# Copyright 2025 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ******************************************************************************* + +# Usage: bash bench_pr_performance.sh {baseline_benchdnn_executable} {benchdnn_executable} {baseline_results_file} {new_results_file} + +IFS=$'\n' # Prevents shuffling from using spaces as delimiters +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +TESTS=( + "$1 --matmul --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/matmul >> $3" + "$2 --matmul --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/matmul >> $4" + "$1 --conv --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/conv >> $3" + "$2 --conv --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/conv >> $4" + "$1 --eltwise --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/eltwise >> $3" + "$2 --eltwise --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/eltwise >> $4" + "$1 --reorder --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/reorder >> $3" + "$2 --reorder --mode=P --perf-template=%prb%,%-time% --batch=${SCRIPT_DIR}/inputs/reorder >> $4" + ) + +N=5 + +for i in $( seq $N ) +do + echo "Testing loop ${i} / ${N}..." + + TESTS=( $(shuf -e "${TESTS[@]}") ) + + for test in "${TESTS[@]}" + do + echo "Starting ${test}" + SECONDS=0 + eval $test + duration=$SECONDS + echo "Completed in $((duration / 60)):$((duration % 60))" + done +done diff --git a/.github/automation/performance/benchdnn_comparison.py b/.github/automation/performance/benchdnn_comparison.py new file mode 100644 index 00000000000..1dba59e4e4b --- /dev/null +++ b/.github/automation/performance/benchdnn_comparison.py @@ -0,0 +1,98 @@ +#!/usr/bin/python3 + +# ******************************************************************************* +# Copyright 2025 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# *******************************************************************************
+
+import sys
+import os
+from collections import defaultdict
+from scipy.stats import ttest_ind
+import warnings
+import statistics
+
+
+def compare_two_benchdnn(file1, file2, tolerance=0.05):
+    """
+    Compare two benchdnn output files
+    """
+    with open(file1) as f:
+        r1 = f.readlines()
+
+    with open(file2) as f:
+        r2 = f.readlines()
+
+    # Trim non-formatted lines and split the problem from time
+    r1 = [x.split(",") for x in r1 if x[0:8] == "--mode=P"]
+    r2 = [x.split(",") for x in r2 if x[0:8] == "--mode=P"]
+
+    if (len(r1) == 0) or (len(r2) == 0):
+        warnings.warn("One or both of the test results have zero lines")
+    if len(r1) != len(r2):
+        warnings.warn("The number of benchdnn runs does not match")
+
+    r1_samples = defaultdict(list)
+    r2_samples = defaultdict(list)
+
+    # v ends with a newline; strip it before parsing the time.
+    for k, v in r1:
+        r1_samples[k].append(float(v[:-1]))
+    for k, v in r2:
+        r2_samples[k].append(float(v[:-1]))
+
+    failed_tests = []
+    times = {}
+    for prb, r1_times in r1_samples.items():
+        if prb not in r2_samples:
+            warnings.warn(f"{prb} exists in {file1} but not {file2}")
+            continue
+
+        r2_times = r2_samples[prb]
+
+        res = ttest_ind(r2_times, r1_times, alternative='greater')
+        r1_med = statistics.median(r1_times)
+        r2_med = statistics.median(r2_times)
+        times[prb] = (r1_med, r2_med)
+        times_str = f" {times[prb][0]} vs {times[prb][1]}"
+
+        # pass the test if:
+        # the t-test passes (i.e. pvalue > 0.05) OR
+        # both the median time and min time have not
+        # slowed down by more than 10%
+        passed = res.pvalue > 0.05 or \
+            ((r2_med - r1_med) / r1_med < 0.1 and \
+            (min(r2_times) - min(r1_times)) / min(r1_times) < 0.1)
+        if not passed:
+            failed_tests.append(prb + times_str)
+
+    if "GITHUB_OUTPUT" in os.environ:
+        with open(os.environ["GITHUB_OUTPUT"], "a") as f:
+            print(f"pass={not failed_tests}", file=f)
+
+    if not failed_tests:
+        print("Regression tests passed")
+    else:
+        message = "\n----The following regression tests failed:----\n" + \
+            "\n".join(failed_tests) + "\n"
+        if "GITHUB_OUTPUT" in os.environ:
+            out_message = message.replace("\n", "%0A")
+            with open(os.environ["GITHUB_OUTPUT"], "a") as f:
+                print(f'message={out_message}', file=f)
+        print(message)
+        raise Exception("Some regression tests failed")
+
+if __name__ == "__main__":
+    compare_two_benchdnn(sys.argv[1], sys.argv[2])
diff --git a/.github/automation/performance/inputs/conv b/.github/automation/performance/inputs/conv
new file mode 100644
index 00000000000..554fb596465
--- /dev/null
+++ b/.github/automation/performance/inputs/conv
@@ -0,0 +1,27 @@
+# *******************************************************************************
+# Copyright 2025 Arm Limited and affiliates.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# ******************************************************************************* + +--reset +--dir=FWD_D +--dt=f32 +mb1_ic64oc256_ih200oh200kh1sh1dh0ph0_iw267ow267kw1sw1dw0pw0 + +--reset +--dir=FWD_D +--dt=f32 +--attr-fpmath=bf16 +mb1_ic64oc256_ih200oh200kh1sh1dh0ph0_iw267ow267kw1sw1dw0pw0 diff --git a/.github/automation/performance/inputs/conv_nightly b/.github/automation/performance/inputs/conv_nightly new file mode 100644 index 00000000000..04699bee0e5 --- /dev/null +++ b/.github/automation/performance/inputs/conv_nightly @@ -0,0 +1,25 @@ +# ******************************************************************************* +# Copyright 2025 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ******************************************************************************* +--reset +--batch=conv + +--reset +--dt=f32 +--alg=auto +--dir=FWD_D,FWD_B +--attr-fpmath=,bf16 +--batch=shapes_resnet_50 diff --git a/.github/automation/performance/inputs/eltwise b/.github/automation/performance/inputs/eltwise new file mode 100644 index 00000000000..35935ce2629 --- /dev/null +++ b/.github/automation/performance/inputs/eltwise @@ -0,0 +1,23 @@ +# ******************************************************************************* +# Copyright 2025 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ******************************************************************************* +--reset +--inplace=true +--alg=gelu_erf +--dir=FWD_D +--dt=f32,bf16 +--tag=abc +1536x384 diff --git a/.github/automation/performance/inputs/eltwise_nightly b/.github/automation/performance/inputs/eltwise_nightly new file mode 100644 index 00000000000..801ff66a3ba --- /dev/null +++ b/.github/automation/performance/inputs/eltwise_nightly @@ -0,0 +1,41 @@ +# ******************************************************************************* +# Copyright 2025 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ******************************************************************************* + +--reset +--batch=eltwise + +--reset + +--dt=f32 +--tag=abx,axb +--dir=FWD_D +--attr-post-ops=, + +## algs which do not support alpha and beta + relu with alpha=0 +--alpha=0 --beta=0 +--alg=exp,exp_dst,gelu_erf,gelu_tanh,relu_dst,tanh,tanh_dst +384x384 + +## algs which support negative alpha +--alpha=-2 --beta=0 +--alg=elu,relu,swish +384x384 + +## algs which support alpha and beta +--alpha=-2 --beta=3 +--alg=linear +384x384 diff --git a/.github/automation/performance/inputs/matmul b/.github/automation/performance/inputs/matmul new file mode 100644 index 00000000000..f9deaac91f0 --- /dev/null +++ b/.github/automation/performance/inputs/matmul @@ -0,0 +1,32 @@ +# ******************************************************************************* +# Copyright 2025 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ******************************************************************************* +--reset +--stag=ab +--wtag=any +--dtag=ab +--attr-post-ops=sum +--dt=f32 +1500x1536:1536x384 + +--reset +--stag=ab +--wtag=any +--dtag=ab +--attr-post-ops=sum +--attr-fpmath=bf16 +--dt=f32 +1500x1536:1536x384 diff --git a/.github/automation/performance/inputs/matmul_nightly b/.github/automation/performance/inputs/matmul_nightly new file mode 100644 index 00000000000..9a37d97c7ec --- /dev/null +++ b/.github/automation/performance/inputs/matmul_nightly @@ -0,0 +1,43 @@ +# ******************************************************************************* +# Copyright 2025 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ******************************************************************************* +--reset +--batch=matmul + +# Plain cases +--reset +--dt=f32,s8:s8:f32 +--bia-dt=f32,undef +--bia_mask=2 +--batch=shapes_2d_ci +--bia_mask=4 +--batch=shapes_3d + +--dt=f32 +--bia-dt=f32,undef +--bia_mask=2 +--attr-fpmath=bf16 +--batch=shapes_2d_ci +--bia_mask=4 +--batch=shapes_3d + +#f16 +--dt=f16:f16:f16 +--bia-dt=undef +--bia_mask=2 +--batch=shapes_2d_ci +--bia_mask=4 +--batch=shapes_3d diff --git a/.github/automation/performance/inputs/reorder b/.github/automation/performance/inputs/reorder new file mode 100644 index 00000000000..38441f10f2c --- /dev/null +++ b/.github/automation/performance/inputs/reorder @@ -0,0 +1,39 @@ +# ******************************************************************************* +# Copyright 2025 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ******************************************************************************* + +--reset +--sdt=f32 +--ddt=f32 +--allow-enum-tags-only=0 +--stag=ba +--dtag=Ab4a,Ab8a +384x384 + +--reset +--sdt=f32 +--ddt=bf16 +--allow-enum-tags-only=0 +--stag=ba +--dtag=BA8b4a,BA4b4a +384x384 + +--reset +--sdt=bf16 +--ddt=f32 +--allow-enum-tags-only=0 +--stag=BA8b4a,BA4b4a +384x384 diff --git a/.github/automation/performance/inputs/reorder_nightly b/.github/automation/performance/inputs/reorder_nightly new file mode 100644 index 00000000000..62e5f9ed262 --- /dev/null +++ b/.github/automation/performance/inputs/reorder_nightly @@ -0,0 +1,27 @@ +# ******************************************************************************* +# Copyright 2025 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ******************************************************************************* + +--reset +--batch=reorder + +--reset +--sdt=f32,s8 +--ddt=f32,s8 + +--stag=abx,axb,aBx4b,aBx8b +--dtag=abx,axb,aBx4b,aBx8b +4x256x5x5 diff --git a/.github/automation/x64/build_linters.sh b/.github/automation/x64/build_linters.sh new file mode 100755 index 00000000000..58951e313ca --- /dev/null +++ b/.github/automation/x64/build_linters.sh @@ -0,0 +1,40 @@ +# Build oneDNN for PR linter checks. + +set -o errexit -o pipefail -o noclobber + +export CC=clang +export CXX=clang++ + +if [[ "$ONEDNN_ACTION" == "configure" ]]; then + if [[ "$GITHUB_JOB" == "pr-clang-tidy" ]]; then + set -x + cmake \ + -Bbuild -S. 
\
+            -DCMAKE_BUILD_TYPE=debug \
+            -DONEDNN_BUILD_GRAPH=ON \
+            -DDNNL_EXPERIMENTAL=ON \
+            -DDNNL_EXPERIMENTAL_SPARSE=ON \
+            -DDNNL_EXPERIMENTAL_PROFILING=ON \
+            -DDNNL_EXPERIMENTAL_UKERNEL=ON \
+            -DONEDNN_EXPERIMENTAL_LOGGING=ON \
+            -DDNNL_CPU_RUNTIME=OMP \
+            -DDNNL_GPU_RUNTIME=OCL \
+            -DDNNL_WERROR=ON \
+            -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
+        set +x
+    elif [[ "$GITHUB_JOB" == "pr-format-tags" ]]; then
+        set -x
+        cmake -B../build -S. -DONEDNN_BUILD_GRAPH=OFF -DDNNL_EXPERIMENTAL_SPARSE=ON
+        set +x
+    else
+        echo "Unknown linter job: $GITHUB_JOB"
+        exit 1
+    fi
+elif [[ "$ONEDNN_ACTION" == "build" ]]; then
+    set -x
+    cmake --build build -j`nproc`
+    set +x
+else
+    echo "Unknown action: $ONEDNN_ACTION"
+    exit 1
+fi
diff --git a/.github/automation/build.bat b/.github/azure/build.bat
similarity index 100%
rename from .github/automation/build.bat
rename to .github/azure/build.bat
diff --git a/.github/azure/build.sh b/.github/azure/build.sh
new file mode 100755
index 00000000000..f8a589b8bb1
--- /dev/null
+++ b/.github/azure/build.sh
@@ -0,0 +1,108 @@
+#! /bin/bash
+
+#===============================================================================
+# Copyright 2019-2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#===============================================================================
+
+
+while [[ $# -gt 0 ]]; do
+    key="$1"
+
+    case $key in
+        --threading)
+        BUILD_THREADING="$2"
+        ;;
+        --mode)
+        BUILD_MODE="$2"
+        ;;
+        --source-dir)
+        SOURCE_DIR="$2"
+        ;;
+        --acl-dir)
+        ACL_DIR="$2"
+        ;;
+        --build-dir)
+        BUILD_DIR="$2"
+        ;;
+        --cmake-opt)
+        CMAKE_OPT="$2"
+        ;;
+        *)
+        echo "Unknown option: $1"
+        exit 1
+        ;;
+    esac
+    shift
+    shift
+done
+
+CMAKE_OPTIONS="-DCMAKE_BUILD_TYPE=${BUILD_MODE} -DDNNL_BUILD_FOR_CI=ON -DDNNL_WERROR=ON ${CMAKE_OPT}"
+
+CPU_RUNTIME="NONE"
+GPU_RUNTIME="NONE"
+
+if [ "${BUILD_THREADING}" == "tbb" ]; then
+    CPU_RUNTIME="TBB"
+    echo "Info: Setting DNNL_CPU_RUNTIME to TBB..."
+elif [ "${BUILD_THREADING}" == "omp" ]; then
+    echo "Info: Setting DNNL_CPU_RUNTIME to OMP..."
+    CPU_RUNTIME="OMP"
+elif [ "${BUILD_THREADING}" == "ocl" ]; then
+    echo "Info: Setting DNNL_CPU_RUNTIME to OMP..."
+    echo "Info: Setting DNNL_GPU_RUNTIME to OCL..."
+    CPU_RUNTIME="OMP"
+    GPU_RUNTIME="OCL"
+else
+    echo "Error: unknown threading: ${BUILD_THREADING}"
+    exit 1
+fi
+
+CMAKE_OPTIONS="${CMAKE_OPTIONS}
+    -DDNNL_CPU_RUNTIME=${CPU_RUNTIME}
+    -DDNNL_GPU_RUNTIME=${GPU_RUNTIME}
+    -DDNNL_TEST_SET=SMOKE
+    "
+
+# Enable Compute Library backend if a location for the built library is given
+# NOTE: only for AArch64 builds.
+if [ ! -z ${ACL_DIR} ]; then
+    export ACL_ROOT_DIR=$ACL_DIR
+    CMAKE_OPTIONS="${CMAKE_OPTIONS} -DDNNL_USE_ACL=ON"
+    echo "Info: Building with Arm Compute Library backend for AArch64..."
+fi
+
+if [ "$(uname)" == "Linux" ]; then
+    MAKE_OP="-j$(grep -c processor /proc/cpuinfo)"
+else
+    MAKE_OP="-j$(sysctl -n hw.physicalcpu)"
+fi
+
+cd "${SOURCE_DIR}"
+echo "Calling CMake with options: ${CMAKE_OPTIONS}"
+cmake . -B${BUILD_DIR} ${CMAKE_OPTIONS}
+err=$?
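+# Note: cmake's exit status is captured above so that, on a configure
+# failure, the CMakeOutput.log/CMakeError.log dumps below can run before
+# the script exits; this makes CI triage easier.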
+if [ "$err" != 0 ]; then + if [ -e "${BUILD_DIR}/CMakeFiles/CMakeOutput.log" ]; then + echo "CMakeOutput.log:" + cat ${BUILD_DIR}/CMakeFiles/CMakeOutput.log + fi + if [ -e "${BUILD_DIR}/CMakeFiles/CMakeError.log" ]; then + echo "CMakeError.log:" + cat ${BUILD_DIR}/CMakeFiles/CMakeError.log + fi + exit $err +fi +cd ${BUILD_DIR} && make -k ${MAKE_OP} +exit $? diff --git a/.github/azure/ci-x64.yml b/.github/azure/ci-x64.yml new file mode 100644 index 00000000000..2ee792dd7cc --- /dev/null +++ b/.github/azure/ci-x64.yml @@ -0,0 +1,111 @@ +#=============================================================================== +# Copyright 2019-2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +# Reference: +# https://learn.microsoft.com/en-us/azure/devops/pipelines/yaml-schema + +trigger: + batch: true + branches: + include: [ main, 'rls-*' ] + paths: + include: + - .github/azure + - cmake + - examples + - include + - src + - tests + - CMakeLists.txt + exclude: + - src/gpu + - src/cpu/aarch64 + - src/cpu/ppc64 + - src/cpu/rv64 + - src/cpu/s390x + - src/xpu + +pr: + autoCancel: true + branches: + include: [ main, 'rls-*' ] + paths: + include: + - .github/azure + - cmake + - examples + - include + - src + - tests + - CMakeLists.txt + exclude: + - src/gpu + - src/cpu/aarch64 + - src/cpu/ppc64 + - src/cpu/rv64 + - src/cpu/s390x + - src/xpu + +jobs: + - job: 'Ubuntu22' + timeoutInMinutes: 120 + pool: + vmImage: 'ubuntu-22.04' + strategy: + matrix: + clang: + CC: clang + CXX: clang++ + gcc: + CC: gcc + CXX: g++ + steps: + - script: | + if [ "$(CC)" == "clang" ]; then + .github/azure/env/clang.sh 15 + fi + displayName: "Init_Env" + - script: | + .github/azure/build.sh --threading omp --mode Release --source-dir $(pwd) --build-dir $(pwd)/build + displayName: 'build' + - script: | + .github/azure/test.sh --build-dir $(pwd)/build --report-dir $(pwd)/report + displayName: 'test' + failOnStderr: true + - job: 'macOS14' + timeoutInMinutes: 120 + pool: + vmImage: 'macOS-14' + steps: + - script: | + .github/azure/build.sh --threading omp --mode Release --source-dir $(pwd) --build-dir $(pwd)/build + displayName: 'build' + - script: | + .github/azure/test.sh --build-dir $(pwd)/build --report-dir $(pwd)/report + displayName: 'test' + failOnStderr: true + - job: 'Windows_Server_2022' + timeoutInMinutes: 120 + pool: + vmImage: 'windows-2022' + steps: + - script: | + .github\azure\build.bat /THREADING omp /MODE Release /VSVERSION vs2022 /SOURCEDIR %CD% /BUILDDIR %CD%\build + displayName: 'build' + - script: | + .github\azure\test.bat /BUILDDIR %CD%\build /MODE Release /REPORTDIR %CD%\report + displayName: 'test' + failOnStderr: true diff --git a/.github/automation/env/clang.sh b/.github/azure/env/clang.sh similarity index 100% rename from .github/automation/env/clang.sh rename to .github/azure/env/clang.sh diff --git a/.github/automation/test.bat b/.github/azure/test.bat similarity index 100% rename from 
.github/automation/test.bat rename to .github/azure/test.bat diff --git a/.github/automation/test.sh b/.github/azure/test.sh similarity index 100% rename from .github/automation/test.sh rename to .github/azure/test.sh diff --git a/.github/labels.yml b/.github/labels.yml index 81e5e89ac72..98ec4fdde80 100644 --- a/.github/labels.yml +++ b/.github/labels.yml @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2024 Intel Corporation +# Copyright 2024-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,6 +41,33 @@ documentation: - changed-files: - any-glob-to-any-file: ['**/*.md', 'doc/**'] +# Common code +component:build: +- changed-files: + - any-glob-to-any-file: + - 'cmake/**' + - 'CMakeLists.txt' + +component:examples: +- changed-files: + - any-glob-to-any-file: 'examples/**' + +component:tests: +- changed-files: + - any-glob-to-any-file: 'tests/**' + +component:api: +- changed-files: + - any-glob-to-any-file: 'include/**' + +component:graph-api: +- changed-files: + - any-glob-to-any-file: + - 'src/graph/**' + - 'tests/benchdnn/graph/**' + - 'tests/gtests/graph/**' + - 'doc/graph/**' + # CPU Engine platform:cpu-aarch64: - changed-files: diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 8dde2c631ea..7e7f8473297 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,6 +1,6 @@ # Description -Please include a summary of the change. Please also include relevant motivation and context. See [contribution guidelines](https://github.com/oneapi-src/oneDNN/blob/master/CONTRIBUTING.md) for more details. If the change fixes an issue not documented in the project's Github issue tracker, please document all steps necessary to reproduce it. +Please include a summary of the change. Please also include relevant motivation and context. See [contribution guidelines](https://github.com/uxlfoundation/oneDNN/blob/main/CONTRIBUTING.md) for more details. If the change fixes an issue not documented in the project's Github issue tracker, please document all steps necessary to reproduce it. Fixes # (github issue) @@ -26,7 +26,7 @@ Fixes # (github issue) - [ ] Have you included information on how to reproduce the issue (either in a github issue or in this PR)? - [ ] Have you added relevant regression tests? -## [RFC](https://github.com/oneapi-src/oneDNN/tree/rfcs) PR +## [RFC](https://github.com/uxlfoundation/oneDNN/tree/rfcs) PR -- [ ] Does RFC document follow the [template](https://github.com/oneapi-src/oneDNN/blob/rfcs/rfcs/template.md#onednn-design-document-rfc)? +- [ ] Does RFC document follow the [template](https://github.com/uxlfoundation/oneDNN/blob/rfcs/rfcs/template.md#onednn-design-document-rfc)? - [ ] Have you added a link to the rendered document? diff --git a/.github/workflows/aarch64-acl.yml b/.github/workflows/aarch64-acl.yml new file mode 100644 index 00000000000..a241b5680da --- /dev/null +++ b/.github/workflows/aarch64-acl.yml @@ -0,0 +1,124 @@ +# ******************************************************************************* +# Copyright 2025 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *******************************************************************************
+
+name: "Build ACL cache"
+
+#* To avoid duplicate jobs running when both push and PR triggers are satisfied, we use this:
+#* https://github.com/orgs/community/discussions/26940#discussioncomment-5686753
+on:
+  workflow_call:
+  workflow_dispatch:
+
+# Declare default permissions as read only.
+permissions: read-all
+
+jobs:
+  # Cache is built sequentially to avoid cache-hit race conditions
+  build-cache:
+    strategy:
+      max-parallel: 1
+      matrix:
+        config: [
+          { name: MacOS, label: macos-14, threading: SEQ, toolset: clang, build: Release },
+          { name: cb100, label: ubuntu-24.04-arm, threading: OMP, toolset: gcc, build: Release },
+          { name: c6g, label: ah-ubuntu_22_04-c6g_2x-50, threading: OMP, toolset: clang, build: Debug },
+          { name: c6g, label: ah-ubuntu_22_04-c6g_2x-50, threading: OMP, toolset: gcc, build: Release }
+        ]
+
+    name: ${{ matrix.config.name }}, ${{ matrix.config.toolset }}, ${{ matrix.config.threading }}, ${{ matrix.config.build }}
+    runs-on: ${{ matrix.config.label }}
+    steps:
+      - name: Checkout oneDNN
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          path: oneDNN
+
+      - name: Read version file
+        id: get-versions
+        run: |
+          content=`cat ${{ github.workspace }}/oneDNN/.github/automation/aarch64/ci.json`
+          content="${content//[$'\t\r\n$ ']}"
+          echo "output=$content" >> $GITHUB_OUTPUT
+
+      - name: Clone ACL
+        run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build_acl.sh
+        env:
+          ACL_ACTION: clone
+          ACL_ROOT_DIR: ${{ github.workspace }}/ComputeLibrary
+          ACL_VERSION: ${{ fromJson(steps.get-versions.outputs.output).dependencies.acl }}
+
+      - name: Get ACL commit hash for cache key
+        id: get_acl_commit_hash
+        run: (cd ${{ github.workspace }}/ComputeLibrary && echo "ACLCommitHash=$(git rev-parse --short HEAD)") >> $GITHUB_OUTPUT
+
+      - name: Get system name
+        id: get_system_name
+        run: (echo "SystemName=$(uname)") >> $GITHUB_OUTPUT
+
+      - name: Restore cached ACL
+        id: cache-acl-restore
+        uses: actions/cache/restore@v4
+        with:
+          key: ${{ steps.get_system_name.outputs.SystemName }}-acl-${{ matrix.config.toolset }}-${{ matrix.config.build }}-${{ steps.get_acl_commit_hash.outputs.ACLCommitHash }}
+          path: ${{ github.workspace }}/ComputeLibrary/build
+
+      - name: Install Scons (MacOS)
+        if: ${{ matrix.config.name == 'MacOS' && (steps.cache-acl-restore.outputs.cache-hit != 'true') }}
+        run: brew install scons
+
+      - name: Install Scons (Linux)
+        if: ${{ matrix.config.name != 'MacOS' && (steps.cache-acl-restore.outputs.cache-hit != 'true') }}
+        run: |
+          sudo apt update -y
+          sudo apt install -y scons
+
+      - if: ${{ contains(matrix.config.label,'ubuntu') && (matrix.config.threading == 'OMP') && (steps.cache-acl-restore.outputs.cache-hit != 'true') }}
+        name: Install openmp
+        run: |
+          sudo apt install -y libomp-dev
+
+      - if: ${{ contains(matrix.config.label,'ubuntu') && (matrix.config.toolset == 'gcc') && (steps.cache-acl-restore.outputs.cache-hit != 'true') }}
+        name: Install gcc
+        run: |
+          sudo add-apt-repository ppa:ubuntu-toolchain-r/test -y
+          sudo apt update -y
+          sudo apt install -y g++-${{ fromJson(steps.get-versions.outputs.output).dependencies.gcc }}
+
+      - if: ${{ contains(matrix.config.label,'ubuntu') && (matrix.config.toolset == 'clang') && (steps.cache-acl-restore.outputs.cache-hit != 'true') }}
+        name: Install clang
+        uses: KyleMayes/install-llvm-action@6ba6e2cd3813def9879be378609d87cb3ef3bac3
+        with:
+          version: ${{ fromJson(steps.get-versions.outputs.output).dependencies.clang }}
+
+      - name: Build ACL
+        if: ${{ steps.cache-acl-restore.outputs.cache-hit != 'true' }}
+        run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build_acl.sh
+        env:
+          ACL_ACTION: build
+          ACL_ROOT_DIR: ${{ github.workspace }}/ComputeLibrary
+          ACL_THREADING: ${{ matrix.config.threading }}
+          BUILD_TOOLSET: ${{ matrix.config.toolset }}
+          ACL_BUILD_TYPE: ${{ matrix.config.build }}
+          GCC_VERSION: ${{ fromJson(steps.get-versions.outputs.output).dependencies.gcc }}
+
+      - name: Save ACL in cache
+        id: cache-acl_build-save
+        if: ${{ steps.cache-acl-restore.outputs.cache-hit != 'true' }}
+        uses: actions/cache/save@v4
+        with:
+          key: ${{ steps.get_system_name.outputs.SystemName }}-acl-${{ matrix.config.toolset }}-${{ matrix.config.build }}-${{ steps.get_acl_commit_hash.outputs.ACLCommitHash }}
+          path: ${{ github.workspace }}/ComputeLibrary/build
diff --git a/.github/workflows/ci-aarch64.yml b/.github/workflows/ci-aarch64.yml
new file mode 100644
index 00000000000..6b0a2923c97
--- /dev/null
+++ b/.github/workflows/ci-aarch64.yml
@@ -0,0 +1,241 @@
+# *******************************************************************************
+# Copyright 2024-2025 Arm Limited and affiliates.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *******************************************************************************
+
+name: "CI AArch64"
+
+#* To avoid duplicate jobs running when both push and PR triggers are satisfied, we use this:
+#* https://github.com/orgs/community/discussions/26940#discussioncomment-5686753
+on:
+  push:
+    branches: [main, "rls-*"]
+    paths:
+      - ".github/automation/*"
+      - ".github/automation/aarch64"
+      - ".github/workflows/aarch64-acl.yml"
+      - ".github/workflows/ci-aarch64.yml"
+      - "cmake/**"
+      - "examples/**"
+      - "include/**"
+      - "src/common/**"
+      - "src/cpu/*"
+      - "src/cpu/aarch64/**"
+      - "tests/**"
+      - "CMakeLists.txt"
+  pull_request:
+    types: [opened, synchronize, reopened]
+    paths:
+      - ".github/automation/*"
+      - ".github/automation/aarch64"
+      - ".github/workflows/aarch64-acl.yml"
+      - ".github/workflows/ci-aarch64.yml"
+      - "cmake/**"
+      - "examples/**"
+      - "include/**"
+      - "src/common/**"
+      - "src/cpu/*"
+      - "src/cpu/aarch64/**"
+      - "tests/**"
+      - "CMakeLists.txt"
+  #* allow manual trigger of workflow when needed.
+  workflow_dispatch:
+
+#* Stop stale workflows when pull requests are updated: https://stackoverflow.com/a/70972844
+#* Does not apply to the main branch.
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
+# Declare default permissions as read only.
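+# Individual jobs can still request broader scopes with a job-level permissions block.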
+permissions: read-all + +jobs: + build-acl-cache: + uses: ./.github/workflows/aarch64-acl.yml + + build-and-test: + needs: build-acl-cache + strategy: + matrix: + config: [ + { name: MacOS, label: macos-14, threading: SEQ, toolset: clang, build: Release, testset: SMOKE }, + { name: cb100, label: ubuntu-24.04-arm, threading: OMP, toolset: gcc, build: Release, testset: SMOKE }, + { name: c6g, label: ah-ubuntu_22_04-c6g_4x-50, threading: OMP, toolset: gcc, build: Release, testset: CI }, + { name: c6g, label: ah-ubuntu_22_04-c6g_2x-50, threading: OMP, toolset: clang, build: Debug, testset: SMOKE }, + { name: c7g, label: ah-ubuntu_22_04-c7g_4x-50, threading: OMP, toolset: gcc, build: Release, testset: CI } + ] + + name: ${{ matrix.config.name }}, ${{ matrix.config.toolset }}, ${{ matrix.config.threading }}, ${{ matrix.config.build }} + runs-on: ${{ matrix.config.label }} + steps: + - name: Checkout oneDNN + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: oneDNN + + - name: Read version file + id: get-versions + run: | + content=`cat ${{ github.workspace }}/oneDNN/.github/automation/aarch64/ci.json` + content="${content//[$'\t\r\n$ ']}" + echo "output=$content" >> $GITHUB_OUTPUT + + # Note: This will create a github actions cache + - name: Get latest CMake and Ninja + uses: lukka/get-cmake@56d043d188c3612951d8755da8f4b709ec951ad6 # v3.31.6 + with: + cmakeVersion: 3.31.0 + ninjaVersion: 1.12.0 + + - if: ${{ (contains(matrix.config.label,'ubuntu') && (matrix.config.threading == 'OMP')) }} + name: Install openmp + run: | + sudo apt install -y libomp-dev + + - if: ${{ (contains(matrix.config.label,'ubuntu') && (matrix.config.toolset == 'gcc')) }} + name: Install gcc + run: | + sudo add-apt-repository ppa:ubuntu-toolchain-r/test -y + sudo apt update -y + sudo apt install -y g++-${{ fromJson(steps.get-versions.outputs.output).dependencies.gcc }} + + - if: ${{ (contains(matrix.config.label,'ubuntu') && (matrix.config.toolset == 'clang')) }} + name: Install clang + uses: KyleMayes/install-llvm-action@6ba6e2cd3813def9879be378609d87cb3ef3bac3 + with: + version: ${{ fromJson(steps.get-versions.outputs.output).dependencies.clang }} + + - name: setup python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install scipy + if: ${{ matrix.config.build == 'Release' }} + run: pip install scipy statistics + + - name: Clone ACL + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build_acl.sh + env: + ACL_ACTION: clone + ACL_ROOT_DIR: ${{ github.workspace }}/ComputeLibrary + ACL_VERSION: ${{ fromJson(steps.get-versions.outputs.output).dependencies.acl }} + + - name: Get ACL commit hash for cache key + id: get_acl_commit_hash + run: (cd ${{ github.workspace }}/ComputeLibrary && echo "ACLCommitHash=$(git rev-parse --short HEAD)") >> $GITHUB_OUTPUT + + - name: Get system name + id: get_system_name + run: (echo "SystemName=$(uname)") >> $GITHUB_OUTPUT + + - name: Restore cached ACL + id: cache-acl-restore + uses: actions/cache/restore@v4 + with: + key: ${{ steps.get_system_name.outputs.SystemName }}-acl-${{ matrix.config.toolset }}-${{ matrix.config.build }}-${{ steps.get_acl_commit_hash.outputs.ACLCommitHash }} + path: ${{ github.workspace }}/ComputeLibrary/build + + - name: Configure oneDNN + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build.sh + working-directory: ${{ github.workspace }}/oneDNN + env: + ACL_ROOT_DIR: ${{ github.workspace }}/ComputeLibrary + BUILD_TOOLSET: ${{ matrix.config.toolset }} + 
CMAKE_BUILD_TYPE: ${{ matrix.config.build }} + CMAKE_GENERATOR: Ninja + GCC_VERSION: ${{ fromJson(steps.get-versions.outputs.output).dependencies.gcc }} + ONEDNN_ACTION: configure + ONEDNN_TEST_SET: ${{ matrix.config.testset }} + ONEDNN_THREADING: ${{ matrix.config.threading }} + + - name: Build oneDNN + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build.sh + working-directory: ${{ github.workspace }}/oneDNN + env: + ONEDNN_ACTION: build + + - name: Run oneDNN tests + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/test.sh + working-directory: ${{ github.workspace }}/oneDNN/build + env: + BUILD_TOOLSET: ${{ matrix.config.toolset }} + CMAKE_BUILD_TYPE: ${{ matrix.config.build }} + CTEST_PARALLEL_LEVEL: 6 + DYLD_LIBRARY_PATH: ${{ github.workspace }}/ComputeLibrary/build + ONEDNN_THREADING: ${{ matrix.config.threading }} + + ## Performance test steps ## + - name: Checkout oneDNN base + if: ${{ github.event_name == 'pull_request' && matrix.config.build == 'Release' && matrix.config.name != 'cb100' }} + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + ref: ${{ github.base_ref }} + path: oneDNN_base + + - name: Configure oneDNN base + if: ${{ github.event_name == 'pull_request' && matrix.config.build == 'Release' && matrix.config.name != 'cb100' }} + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build.sh + working-directory: ${{ github.workspace }}/oneDNN_base + env: + ACL_ROOT_DIR: ${{ github.workspace }}/ComputeLibrary + BUILD_TOOLSET: ${{ matrix.config.toolset }} + CMAKE_BUILD_TYPE: ${{ matrix.config.build }} + CMAKE_GENERATOR: Ninja + GCC_VERSION: ${{ fromJson(steps.get-versions.outputs.output).dependencies.gcc }} + ONEDNN_ACTION: configure + ONEDNN_TEST_SET: ${{ matrix.config.testset }} + ONEDNN_THREADING: ${{ matrix.config.threading }} + + - name: Build oneDNN base + if: ${{ github.event_name == 'pull_request' && matrix.config.build == 'Release' && matrix.config.name != 'cb100' }} + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build.sh + working-directory: ${{ github.workspace }}/oneDNN_base + env: + ONEDNN_ACTION: build + + - name: Run performance tests + shell: bash + if: ${{ github.event_name == 'pull_request' && matrix.config.build == 'Release' && matrix.config.name != 'cb100' }} + run: | + OMP_NUM_THREADS=4 bash ${{ github.workspace }}/oneDNN/.github/automation/performance/bench_pr_performance.sh ${{ github.workspace }}/oneDNN_base/build/tests/benchdnn/benchdnn ${{ github.workspace }}/oneDNN/build/tests/benchdnn/benchdnn base_4.txt new_4.txt + OMP_NUM_THREADS=16 bash ${{ github.workspace }}/oneDNN/.github/automation/performance/bench_pr_performance.sh ${{ github.workspace }}/oneDNN_base/build/tests/benchdnn/benchdnn ${{ github.workspace }}/oneDNN/build/tests/benchdnn/benchdnn base_16.txt new_16.txt + env: + DYLD_LIBRARY_PATH: ${{ github.workspace }}/ComputeLibrary/build + + - name: Compare performance test results + if: ${{ github.event_name == 'pull_request' && matrix.config.build == 'Release' && matrix.config.name != 'cb100' }} + id: performance-test + continue-on-error: true + run: | + echo "4 threads:" + python ${{ github.workspace }}/oneDNN/.github/automation/performance/benchdnn_comparison.py base_4.txt new_4.txt + echo "16 threads:" + python ${{ github.workspace }}/oneDNN/.github/automation/performance/benchdnn_comparison.py base_16.txt new_16.txt + + - name: Check performance test failure + if: ${{ steps.performance-test.outputs.pass != 'True' && github.event_name == 
'pull_request' && matrix.config.build == 'Release' && matrix.config.name != 'cb100' }} + run: echo "::warning file=.github/workflows/ci-aarch64.yml,line=1,col=1::${{ steps.performance-test.outputs.message }}" + + # This job adds a check named "CI AArch64" that represents overall + # workflow status and can be used in branch rulesets + status: + needs: build-and-test + runs-on: ubuntu-latest + name: "CI AArch64" + steps: + - name: Print success + run: echo Success diff --git a/.github/workflows/clang-tidy.yml b/.github/workflows/clang-tidy.yml new file mode 100644 index 00000000000..f4bd097bb50 --- /dev/null +++ b/.github/workflows/clang-tidy.yml @@ -0,0 +1,85 @@ +name: "Clang-Tidy" + +on: + pull_request: + types: [opened, edited, synchronize, reopened] + paths: + - ".github/automation/x64/**" + - ".github/workflows/clang-tidy.yml" + - "cmake/**" + - "examples/**" + - "include/**" + - "src/common/**" + - "src/cpu/*" + - "src/cpu/gemm/**" + - "src/cpu/matmul/**" + - "src/cpu/reorder/**" + - "src/cpu/rnn/**" + - "src/cpu/x64/**" + - "src/gpu/*" + - "src/gpu/intel/**" + - "src/graph/**" + - "tests/**" + - "CMakeLists.txt" + +## Declare default permissions as read only. +permissions: read-all + +# Kill stale checks +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +jobs: + pr-clang-tidy: + name: Clang-Tidy + runs-on: ubuntu-latest + steps: + - name: Checkout oneDNN + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + + - name: Install clang + run: | + sudo apt-get update + sudo apt-get install -y clang libomp-dev ocl-icd-libopencl1 ocl-icd-opencl-dev + + - name: Configure oneDNN + run: .github/automation/x64/build_linters.sh + env: + ONEDNN_ACTION: configure + + - name: Check source files + run: | + echo -e "Checking Clang-Tidy $(clang-tidy --version)\n" + touch source-check.log + for file in $(git diff --name-only ${{ github.event.pull_request.head.sha }} ${{ github.event.pull_request.base.sha }} | grep -E '\.cpp$'); + do + if grep -q "$file" "build/compile_commands.json"; then + echo -e "\nAnalyzing $file" + clang-tidy -p build --header-filter='' $file 2>&1 | tee -a source-check.log + else + echo "Skipped $file as it's not built in x64 OpenMP/OpenCL configuration." + fi + done + grep -i -E "warning:|error:" source-check.log | sort -u + grep -q -i -E "warning:|error:" source-check.log && exit 1 || true + + - name: Check header files + if: always() + continue-on-error: true + run: | + echo -e "Checking Clang-Tidy $(clang-tidy --version)\n" + touch headers-check.log + for file in $(git diff --name-only ${{ github.event.pull_request.head.sha }} ${{ github.event.pull_request.base.sha }} | grep -E '\.cpp$'); + do + if grep -q "$file" "build/compile_commands.json"; then + echo -e "\nAnalyzing $file" + clang-tidy -p build $file 2>&1 | tee -a headers-check.log + else + echo "Skipped $file as it's not built in x64 OpenMP/OpenCL configuration." 
+            fi
+          done
+          grep -i -E "warning:|error:" headers-check.log | sort -u
+          grep -q -i -E "warning:|error:" headers-check.log && exit 1 || true
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
index 09d2f175a7b..64d4471fb87 100644
--- a/.github/workflows/labeler.yml
+++ b/.github/workflows/labeler.yml
@@ -28,7 +28,7 @@ jobs:
       pull-requests: write
 
     steps:
-    - uses: actions/labeler@v5.0.0
+    - uses: actions/labeler@8558fd74291d67161a8a78ce36a881fa63b766a9 # v5.0.0
       with:
         sync-labels: true
         configuration-path: '.github/labels.yml'
diff --git a/.github/workflows/nightly-aarch64.yml b/.github/workflows/nightly-aarch64.yml
new file mode 100644
index 00000000000..d8086ae2f66
--- /dev/null
+++ b/.github/workflows/nightly-aarch64.yml
@@ -0,0 +1,146 @@
+# *******************************************************************************
+# Copyright 2024-2025 Arm Limited and affiliates.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *******************************************************************************
+
+name: "Nightly AArch64"
+
+on:
+  #* allow manual trigger of workflow when needed. Useful for a nightly.
+  workflow_dispatch:
+  schedule:
+    #* minute (0-59) hour (0-23) day (1-31) month (1-12) day of the week (0 - 6)
+    #* cron jobs run on the default (main) branch.
+    #* set to run at 5am UTC
+    - cron: "0 5 * * *"
+
+#* Stop stale workflows, though we should never hit this unless it hangs for a whole day.
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
+# Declare default permissions as read only.
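+# No job in this workflow needs write access to the repository.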
+permissions: read-all + +jobs: + build-acl-cache: + uses: ./.github/workflows/aarch64-acl.yml + + test-performance: + uses: ./.github/workflows/performance-aarch64.yml + + build-and-test: + needs: build-acl-cache + strategy: + matrix: + config: [ + { name: c6g, label: ah-ubuntu_22_04-c6g_8x-100, threading: OMP, toolset: gcc, build: Release, testset: NIGHTLY }, + { name: c7g, label: ah-ubuntu_22_04-c7g_8x-100, threading: OMP, toolset: gcc, build: Release, testset: NIGHTLY } + ] + + name: ${{ matrix.config.name }}, ${{ matrix.config.toolset }}, ${{ matrix.config.threading }}, ${{ matrix.config.build }} + runs-on: ${{ matrix.config.label }} + steps: + + - name: Checkout oneDNN + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: oneDNN + + # Note: This will create a github actions cache + - name: Get latest CMake and Ninja + uses: lukka/get-cmake@56d043d188c3612951d8755da8f4b709ec951ad6 # v3.31.6 + with: + cmakeVersion: 3.31.0 + ninjaVersion: 1.12.0 + + - if: ${{ matrix.config.threading == 'OMP' }} + name: Install openmp + run: | + sudo apt install -y libomp-dev + + - name: Read version file + id: get-versions + run: | + content=`cat ${{ github.workspace }}/oneDNN/.github/automation/aarch64/ci.json` + content="${content//[$'\t\r\n$ ']}" + echo "output=$content" >> $GITHUB_OUTPUT + + - name: Install gcc + run: | + sudo add-apt-repository ppa:ubuntu-toolchain-r/test -y + sudo apt update -y + sudo apt install -y g++-${{ fromJson(steps.get-versions.outputs.output).dependencies.gcc }} + + - name: Clone ACL + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build_acl.sh + env: + ACL_ACTION: clone + ACL_ROOT_DIR: ${{ github.workspace }}/ComputeLibrary + ACL_VERSION: ${{ fromJson(steps.get-versions.outputs.output).dependencies.acl }} + + - name: Get ACL commit hash for cache key + id: get_acl_commit_hash + run: (cd ${{ github.workspace }}/ComputeLibrary && echo "ACLCommitHash=$(git rev-parse --short HEAD)") >> $GITHUB_OUTPUT + + - name: Get system name + id: get_system_name + run: (echo "SystemName=$(uname)") >> $GITHUB_OUTPUT + + - name: Restore cached ACL + id: cache-acl-restore + uses: actions/cache/restore@v4 + with: + key: ${{ steps.get_system_name.outputs.SystemName }}-acl-${{ matrix.config.toolset }}-${{ matrix.config.build }}-${{ steps.get_acl_commit_hash.outputs.ACLCommitHash }} + path: ${{ github.workspace }}/ComputeLibrary/build + + - name: Configure oneDNN + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build.sh + working-directory: ${{ github.workspace }}/oneDNN + env: + ACL_ROOT_DIR: ${{ github.workspace }}/ComputeLibrary + BUILD_TOOLSET: ${{ matrix.config.toolset }} + CMAKE_BUILD_TYPE: ${{ matrix.config.build }} + CMAKE_GENERATOR: Ninja + GCC_VERSION: ${{ fromJson(steps.get-versions.outputs.output).dependencies.gcc }} + ONEDNN_ACTION: configure + ONEDNN_TEST_SET: ${{ matrix.config.testset }} + ONEDNN_THREADING: ${{ matrix.config.threading }} + + - name: Build oneDNN + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build.sh + working-directory: ${{ github.workspace }}/oneDNN + env: + ONEDNN_ACTION: build + + - name: Run oneDNN tests + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/test.sh + working-directory: ${{ github.workspace }}/oneDNN/build + env: + BUILD_TOOLSET: ${{ matrix.config.toolset }} + CMAKE_BUILD_TYPE: ${{ matrix.config.build }} + CTEST_PARALLEL_LEVEL: 8 + DYLD_LIBRARY_PATH: ${{ github.workspace }}/ComputeLibrary/build + ONEDNN_THREADING: ${{ matrix.config.threading }} + 
+ #* This job adds a check named "Nightly AArch64" that represents overall + #* workflow status and can be used in branch rulesets + status: + needs: build-and-test + runs-on: ubuntu-latest + name: "Nightly AArch64" + steps: + - name: Print success + run: echo Success diff --git a/.github/workflows/openssf-scorecard.yml b/.github/workflows/openssf-scorecard.yml index 60d94920c1b..2b4efc813dc 100644 --- a/.github/workflows/openssf-scorecard.yml +++ b/.github/workflows/openssf-scorecard.yml @@ -41,12 +41,12 @@ jobs: steps: - name: "Checkout code" - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: persist-credentials: false - name: "Run analysis" - uses: ossf/scorecard-action@62b2cac7ed8198b15735ed49ab1e5cf35480ba46 # v2.4.0 + uses: ossf/scorecard-action@f49aabe0b5af0936a0987cfb85d86b75731b0186 # v2.4.1 with: results_file: results.sarif results_format: sarif diff --git a/.github/workflows/performance-aarch64.yml b/.github/workflows/performance-aarch64.yml new file mode 100644 index 00000000000..b142240db7d --- /dev/null +++ b/.github/workflows/performance-aarch64.yml @@ -0,0 +1,171 @@ +# ******************************************************************************* +# Copyright 2024-2025 Arm Limited and affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ******************************************************************************* + +name: "Performance AArch64" + +on: + workflow_call: + +#* Stop stale workflows +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-performance + cancel-in-progress: true + +# Declare default permissions as read only. 
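+# The benchmarking jobs below only read sources and restore caches, so read-only is enough.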
+permissions: read-all + +jobs: + build-acl-cache: + uses: ./.github/workflows/aarch64-acl.yml + + build-and-test-performance: + needs: build-acl-cache + strategy: + matrix: + config: [ + { name: c7g, label: ah-ubuntu_22_04-c7g_m-100, threading: OMP, toolset: gcc, build: Release, testset: NIGHTLY } + ] + + name: ${{ matrix.config.name }}, ${{ matrix.config.toolset }}, ${{ matrix.config.threading }}, ${{ matrix.config.build }} + runs-on: ${{ matrix.config.label }} + steps: + + - name: Checkout oneDNN + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: oneDNN + + # Note: This will create a github actions cache + - name: Get latest CMake and Ninja + uses: lukka/get-cmake@56d043d188c3612951d8755da8f4b709ec951ad6 # v3.31.6 + with: + cmakeVersion: 3.31.0 + ninjaVersion: 1.12.0 + + - if: ${{ matrix.config.threading == 'OMP' }} + name: Install openmp + run: | + sudo apt install -y libomp-dev + + - name: Read version file + id: get-versions + run: | + content=`cat ${{ github.workspace }}/oneDNN/.github/automation/aarch64/ci.json` + content="${content//[$'\t\r\n$ ']}" + echo "output=$content" >> $GITHUB_OUTPUT + + - name: Install gcc + run: | + sudo add-apt-repository ppa:ubuntu-toolchain-r/test -y + sudo apt update -y + sudo apt install -y g++-${{ fromJson(steps.get-versions.outputs.output).dependencies.gcc }} + + - name: setup python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install scipy + if: ${{ matrix.config.build == 'Release' }} + run: pip install scipy statistics + + - name: Clone ACL + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build_acl.sh + env: + ACL_ACTION: clone + ACL_ROOT_DIR: ${{ github.workspace }}/ComputeLibrary + ACL_VERSION: ${{ fromJson(steps.get-versions.outputs.output).dependencies.acl }} + + - name: Get ACL commit hash for cache key + id: get_acl_commit_hash + run: (cd ${{ github.workspace }}/ComputeLibrary && echo "ACLCommitHash=$(git rev-parse --short HEAD)") >> $GITHUB_OUTPUT + + - name: Get system name + id: get_system_name + run: (echo "SystemName=$(uname)") >> $GITHUB_OUTPUT + + - name: Restore cached ACL + id: cache-acl-restore + uses: actions/cache/restore@v4 + with: + key: ${{ steps.get_system_name.outputs.SystemName }}-acl-${{ matrix.config.toolset }}-${{ matrix.config.build }}-${{ steps.get_acl_commit_hash.outputs.ACLCommitHash }} + path: ${{ github.workspace }}/ComputeLibrary/build + + - name: Configure oneDNN + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build.sh + working-directory: ${{ github.workspace }}/oneDNN + env: + ACL_ROOT_DIR: ${{ github.workspace }}/ComputeLibrary + BUILD_TOOLSET: ${{ matrix.config.toolset }} + CMAKE_BUILD_TYPE: ${{ matrix.config.build }} + CMAKE_GENERATOR: Ninja + GCC_VERSION: ${{ fromJson(steps.get-versions.outputs.output).dependencies.gcc }} + ONEDNN_ACTION: configure + ONEDNN_TEST_SET: ${{ matrix.config.testset }} + ONEDNN_THREADING: ${{ matrix.config.threading }} + + - name: Build oneDNN + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build.sh + working-directory: ${{ github.workspace }}/oneDNN + env: + ONEDNN_ACTION: build + + - name: Checkout oneDNN base + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + ref: ${{ fromJson(steps.get-versions.outputs.output).dependencies.onednn-base }} + path: oneDNN_base + + - name: Configure oneDNN base + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build.sh + working-directory: ${{ github.workspace }}/oneDNN_base + 
env: + ACL_ROOT_DIR: ${{ github.workspace }}/ComputeLibrary + BUILD_TOOLSET: ${{ matrix.config.toolset }} + CMAKE_BUILD_TYPE: ${{ matrix.config.build }} + CMAKE_GENERATOR: Ninja + GCC_VERSION: ${{ fromJson(steps.get-versions.outputs.output).dependencies.gcc }} + ONEDNN_ACTION: configure + ONEDNN_TEST_SET: ${{ matrix.config.testset }} + ONEDNN_THREADING: ${{ matrix.config.threading }} + + - name: Build oneDNN base + run: ${{ github.workspace }}/oneDNN/.github/automation/aarch64/build.sh + working-directory: ${{ github.workspace }}/oneDNN_base + env: + ONEDNN_ACTION: build + + - name: Run performance tests + shell: bash + run: | + OMP_NUM_THREADS=16 bash ${{ github.workspace }}/oneDNN/.github/automation/performance/bench_nightly_performance.sh ${{ github.workspace }}/oneDNN_base/build/tests/benchdnn/benchdnn ${{ github.workspace }}/oneDNN/build/tests/benchdnn/benchdnn base.txt new.txt + env: + DYLD_LIBRARY_PATH: ${{ github.workspace }}/ComputeLibrary/build + + - name: Compare 16 threads performance test results + run: | + python ${{ github.workspace }}/oneDNN/.github/automation/performance/benchdnn_comparison.py base.txt new.txt + + #* This job adds a check named "Nightly Performance AArch64" that represents overall + #* workflow status and can be used in branch rulesets + status: + needs: build-and-test-performance + runs-on: ubuntu-latest + name: "Nightly Performance AArch64" + steps: + - name: Print success + run: echo Success diff --git a/.github/workflows/pr-linter.yml b/.github/workflows/pr-linter.yml new file mode 100644 index 00000000000..4ba07ddaa15 --- /dev/null +++ b/.github/workflows/pr-linter.yml @@ -0,0 +1,98 @@ +# ******************************************************************************* +# Copyright 2024 Arm Limited and affiliates. +# Copyright 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ******************************************************************************* + +name: "PR Linters" + +on: + pull_request: + types: [opened, edited, synchronize, reopened] + +# Declare default permissions as read only. 
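+# Linter jobs only inspect the PR contents; none of them needs write access.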
+permissions: read-all + +# Kill stale checks +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +jobs: + pr-commits: + name: Commit messages + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Check commit messages + run: python3 ./.github/automation/commit-msg-check.py "${{ github.event.pull_request.head.sha }}" "${{ github.event.pull_request.base.sha }}" + + pr-clang-format: + name: Clang-Format + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install clang-format + run: sudo apt update && sudo apt install -y "clang-format-11" + - name: Check code formatting + run: .github/automation/clang-format.sh + + pr-format-tags: + name: Format tags consistency + runs-on: ubuntu-latest + steps: + - name: Checkout oneDNN + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: source + + - name: Install clang + run: | + sudo apt-get update + sudo apt-get install -y clang libomp-dev + + - name: Install castxml package + run: | + python -m venv venv + source venv/bin/activate + python -m pip install --no-cache-dir --disable-pip-version-check castxml + + - name: Configure oneDNN + run: .github/automation/x64/build_linters.sh + working-directory: ${{ github.workspace }}/source + env: + ONEDNN_ACTION: configure + + - name: Check format-tags + run: | + venv/bin/castxml --castxml-cc-gnu-c clang --castxml-output=1 -I${{ github.workspace }}/source/include -I${{ github.workspace }}/build/include ${{ github.workspace }}/source/include/oneapi/dnnl/dnnl_types.h -o ${{ github.workspace }}/types.xml + python ${{ github.workspace }}/source/scripts/generate_dnnl_debug.py ${{ github.workspace }}/types.xml + python ${{ github.workspace }}/source/scripts/generate_format_tags.py + cd ${{ github.workspace }}/source/ + git diff | grep . && exit 1 || true + + pr-status: + name: Formatting + runs-on: ubuntu-latest + needs: [ pr-commits, pr-clang-format, pr-format-tags ] + steps: + - name: Print success + run: echo "Success" \ No newline at end of file diff --git a/.gitignore b/.gitignore index db170326a5d..381a4f20f45 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ #=============================================================================== # Copyright 2019-2021 Intel Corporation -# Copyright 2024 Arm Limited and affiliates. +# Copyright 2024-2025 Arm Limited and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ # limitations under the License. #=============================================================================== -build +/build* external .vs .vscode diff --git a/CITATION.cff b/CITATION.cff index 1598115c9a1..84eb0055376 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -8,8 +8,8 @@ message: >- type: software authors: - name: oneDNN Contributors -repository-code: 'https://github.com/oneapi-src/oneDNN' -url: 'https://oneapi-src.github.io/oneDNN' +repository-code: 'https://github.com/uxlfoundation/oneDNN' +url: 'https://uxlfoundation.github.io/oneDNN' abstract: >- oneAPI Deep Neural Network Library (oneDNN) is an open-source cross-platform performance library of basic building blocks for deep learning applications. 
@@ -18,4 +18,4 @@ abstract: >- oneDNN has experimental support for the following architectures: NVIDIA GPU, AMD GPU, OpenPOWER Power ISA (PPC64), IBMz (s390x), and RISC-V. license: Apache-2.0 -version: v3.6 +version: v3.8 diff --git a/CMakeLists.txt b/CMakeLists.txt index af2522a0721..a106cea3122 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2016-2019 Intel Corporation +# Copyright 2016-2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,54 +14,26 @@ # limitations under the License. #=============================================================================== -cmake_minimum_required(VERSION 2.8.12) +cmake_minimum_required(VERSION 3.13) -if(POLICY CMP0022) - cmake_policy(SET CMP0022 NEW) -endif() - -# Foo::Bar always refers to an IMPORTED target -if(POLICY CMP0028) - cmake_policy(SET CMP0028 NEW) -endif() - -if(POLICY CMP0054) - cmake_policy(SET CMP0054 NEW) -endif() - -# Enable RPATH on MacOS/OSX -if(POLICY CMP0042) - cmake_policy(SET CMP0042 NEW) -endif() - -# Do not export symbols from executables -if(POLICY CMP0065) - cmake_policy(SET CMP0065 NEW) -endif() - -# Pass linker flags to try_compile -if(POLICY CMP0056) - cmake_policy(SET CMP0056 NEW) -endif() - -# Always link with full path -if(POLICY CMP0060) - cmake_policy(SET CMP0060 NEW) -endif() +# CMake minimum required version enables all policies introduced in minimum +# version and earlier versions. Policies introduced in future versions +# are handled individually in the section below. -# Pass compiler flags to try_compile -if(POLICY CMP0066) - cmake_policy(SET CMP0066 NEW) +# CMake 3.14: Install rules from add_subdirectory() calls are interleaved +# with those in caller. +if(POLICY CMP0082) + cmake_policy(SET CMP0082 NEW) endif() -# Use _ROOT env. variable as a prefix -if(POLICY CMP0074) - cmake_policy(SET CMP0074 NEW) +# CMake 3.27: The FindPythonInterp and FindPythonLibs modules are removed. +if(POLICY CMP0148) + cmake_policy(SET CMP0148 NEW) endif() -# Install rules order -if(POLICY CMP0082) - cmake_policy(SET CMP0082 NEW) +# CMake 3.27: The FindCUDA module is removed. +if(POLICY CMP0146) + cmake_policy(SET CMP0146 OLD) endif() if("${CMAKE_BUILD_TYPE}" STREQUAL "") @@ -69,26 +41,31 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "") set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel RelWithAssert RelWithMDd...") endif() +if (CMAKE_GENERATOR MATCHES "^Visual Studio") + message(STATUS +"oneDNN build configuration is based on the CMAKE_BUILD_TYPE value, but + the CMake generator '${CMAKE_GENERATOR}' does not respect it and requires + using the --config option to choose the build type. 
Changing the build type + using the --config option requires rerunning CMake from scratch with a + matching CMAKE_BUILD_TYPE value.") +endif() set(PROJECT_NAME "oneDNN") set(PROJECT_FULL_NAME "oneAPI Deep Neural Network Library (oneDNN)") -set(PROJECT_VERSION "3.6.0") +set(PROJECT_VERSION "3.8.0") -if (CMAKE_VERSION VERSION_LESS 3.0) - project(${PROJECT_NAME} C CXX) -else() - cmake_policy(SET CMP0048 NEW) - project(${PROJECT_NAME} VERSION "${PROJECT_VERSION}" LANGUAGES C CXX) -endif() +project(${PROJECT_NAME} VERSION "${PROJECT_VERSION}" LANGUAGES C CXX) if (NOT CMAKE_SIZEOF_VOID_P EQUAL 8) - message(FATAL_ERROR "oneDNN supports 64 bit platforms only") + message(WARNING "oneDNN officially supports 64 bit platforms only") endif() # Set the target architecture. if(NOT DNNL_TARGET_ARCH) if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64.*|AARCH64.*|arm64.*|ARM64.*)") set(DNNL_TARGET_ARCH "AARCH64") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(arm.*|ARM.*)") + set(DNNL_TARGET_ARCH "ARM") elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(ppc64.*|PPC64.*|powerpc64.*)") set(DNNL_TARGET_ARCH "PPC64") elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(s390x.*|S390X.*)") @@ -133,30 +110,22 @@ include("cmake/host_compiler.cmake") include("cmake/configuring_primitive_list.cmake") if(UNIX OR MINGW) - if(CMAKE_VERSION VERSION_LESS "3.1.0") - # No CMAKE__STANDARD, so add directly to CMAKE__FLAGS - # (prepended so the user can override) - set(CMAKE_C_FLAGS "-std=c99 ${CMAKE_C_FLAGS}") - # Let SYCL to choose the C++ standard it needs. - if(NOT DNNL_WITH_SYCL) - set(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS}") - endif() - else() - # CMAKE__STANDARD support, so set it to our defaults, unless - # overridden by the user - if(NOT DEFINED CMAKE_C_STANDARD) - set(CMAKE_C_STANDARD 99) - endif() - if(NOT DEFINED CMAKE_CXX_STANDARD AND NOT DNNL_WITH_SYCL) - set(CMAKE_CXX_STANDARD 11) - endif() - - # Disable -std=gnuXX and -std=gnu++XX - set(CMAKE_C_EXTENSIONS OFF) - set(CMAKE_CXX_EXTENSIONS OFF) + # CMAKE__STANDARD support, so set it to our defaults, unless + # overridden by the user + if(NOT DEFINED CMAKE_C_STANDARD) + set(CMAKE_C_STANDARD 99) + endif() + if(NOT DEFINED CMAKE_CXX_STANDARD AND NOT DNNL_WITH_SYCL) + set(CMAKE_CXX_STANDARD 11) endif() -endif() + # Disable -std=gnuXX and -std=gnu++XX + set(CMAKE_C_EXTENSIONS OFF) + set(CMAKE_CXX_EXTENSIONS OFF) +endif() +if (ANDROID) + set(CMAKE_CXX_STANDARD 20) +endif() # Handle cases when OpenMP runtime is requested but not found: override CPU # runtime to be sequential if(DNNL_CPU_RUNTIME STREQUAL "OMP" AND @@ -179,7 +148,7 @@ configure_file( "${PROJECT_BINARY_DIR}/README" ) -if(DNNL_INSTALL_MODE MATCHES "^(BUNDLE|BUNDLE_V2)$" AND NOT DEFINED CMAKE_INSTALL_LIBDIR) +if(DNNL_INSTALL_MODE STREQUAL "BUNDLE" AND NOT DEFINED CMAKE_INSTALL_LIBDIR) # define CMAKE_INSTALL_LIBDIR as "lib" in the case of bundle set(CMAKE_INSTALL_LIBDIR "lib") endif() @@ -192,10 +161,6 @@ add_subdirectory(examples) add_subdirectory(tests) if(DNNL_INSTALL_MODE STREQUAL "BUNDLE") - install(FILES LICENSE DESTINATION ${CMAKE_INSTALL_PREFIX}) - install(FILES THIRD-PARTY-PROGRAMS DESTINATION ${CMAKE_INSTALL_PREFIX}) - install(FILES ${PROJECT_BINARY_DIR}/README DESTINATION ${CMAKE_INSTALL_PREFIX}) -elseif(DNNL_INSTALL_MODE STREQUAL "BUNDLE_V2") install(FILES LICENSE DESTINATION ${CMAKE_INSTALL_DATAROOTDIR}/doc/${LIB_PACKAGE_NAME}) install(FILES THIRD-PARTY-PROGRAMS DESTINATION ${CMAKE_INSTALL_DATAROOTDIR}/doc/${LIB_PACKAGE_NAME}) install(FILES ${PROJECT_BINARY_DIR}/README DESTINATION 
${CMAKE_INSTALL_DATAROOTDIR}/doc/${LIB_PACKAGE_NAME}) diff --git a/CODING_STANDARDS.md b/CODING_STANDARDS.md index bfe16cae060..4698cebdd3b 100644 --- a/CODING_STANDARDS.md +++ b/CODING_STANDARDS.md @@ -25,7 +25,7 @@ oneDNN uses [clang-tidy](https://clang.llvm.org/extra/clang-tidy/) in order to diagnose and fix common style violations and easy-to-fix issues in the code base. For instructions on how to use `clang-tidy`, please refer to the [clang-tidy -RFC](https://github.com/oneapi-src/oneDNN/blob/rfcs/rfcs/20200813-clang-tidy/README.md). +RFC](https://github.com/uxlfoundation/oneDNN/blob/rfcs/rfcs/20200813-clang-tidy/README.md). The list of clang-tidy checks the oneDNN code base follows is available in the `.clang-tidy` file found in the oneDNN top-level directory. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ea8c718f80e..100d8429486 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -7,8 +7,8 @@ requests! To get started, see the GitHub You can: - Submit your changes directly with a - [pull request](https://github.com/oneapi-src/oneDNN/pulls) -- Log a bug or feedback with an [issue](https://github.com/oneapi-src/oneDNN/issues) + [pull request](https://github.com/uxlfoundation/oneDNN/pulls) +- Log a bug or feedback with an [issue](https://github.com/uxlfoundation/oneDNN/issues) **See also:** [Contributor Covenant](CODE_OF_CONDUCT.md) code of conduct. @@ -54,7 +54,7 @@ For Comments (RFC) process, which consists of opening, discussing, and accepting (promoting) RFC pull requests. More information about the process can be found in the dedicated -[`rfcs`](https://github.com/oneapi-src/oneDNN/tree/rfcs) branch. +[`rfcs`](https://github.com/uxlfoundation/oneDNN/tree/rfcs) branch. ## Code contribution guidelines @@ -146,7 +146,7 @@ Use the following command to run tests selected by a build configuration: ``` To modify the coverage, use the -[`ONEDNN_TEST_SET`](https://oneapi-src.github.io/oneDNN/dev_guide_build_options.html#onednn-test-set) +[`ONEDNN_TEST_SET`](https://uxlfoundation.github.io/oneDNN/dev_guide_build_options.html#onednn-test-set) build option. More details on how to run benchdnn can be found in diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 212a815977f..c5bc3a05772 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -35,7 +35,7 @@ Privileges: ## Code Owner A Code Owner has responsibility for a specific project component or a functional -area. Code Owners are collectively responsible, with other Code Owners, +area. Code Owners are collectively responsible, with other Code Owners, for developing and maintaining their component or functional areas, including reviewing all changes to their their areas of responsibility and indicating whether those changes are ready to merge. They have a track record of @@ -72,7 +72,7 @@ including name, Github username, and affiliation. 2. At least two specific component Maintainers approve the PR. ## Maintainer -Maintainers are the most established contributors who are responsible for the +Maintainers are the most established contributors who are responsible for the project technical direction and participate in making decisions about the strategy and priorities of the project. @@ -100,7 +100,7 @@ Privileges: * Can recommend Code Owners to become Maintainers. Process of becoming a maintainer: -1. A Maintainer may nominate a current Reviewer to become a new Maintainer by +1. A Maintainer may nominate a current Reviewer to become a new Maintainer by opening a PR against MAINTAINERS.md file. 2. 
A majority of the current Maintainers must then approve the PR. @@ -108,7 +108,7 @@ opening a PR against MAINTAINERS.md file. ## Core (API, Architecture, Tests) -Team: @oneapi-src/onednn-arch +Team: @uxlfoundation/onednn-arch | Name | Github ID | Affiliation | Role | | ----------------- | --------------------- | ----------------- | ---------- | @@ -117,10 +117,11 @@ Team: @oneapi-src/onednn-arch | Mourad Gouicem | @mgouicem | Intel Corporation | Maintainer | | Vadim Pirogov | @vpirogov | Intel Corporation | Maintainer | | Ankit Manerikar | @avmanerikar | Intel Corporation | Code Owner | +| Stefan Palicki | @spalicki | Intel Corporation | Code Owner | ## Graph API -Team: @oneapi-src/onednn-graph +Team: @uxlfoundation/onednn-graph | Name | Github ID | Affiliation | Role | | ------------------ | --------------------- | ----------------- | ---------- | @@ -130,39 +131,39 @@ Team: @oneapi-src/onednn-graph | Shaojie Cui | @ShanSimu | Intel Corporation | Code Owner | | Yonghao Gu | @gyhintel | Intel Corporation | Code Owner | | Rong Zhang | @rongzha1 | Intel Corporation | Code Owner | -| Zhailong Mu | @muzhailong | Intel Corporation | Code Owner | | Xiang Guo | @xiang1guo | Intel Corporation | Code Owner | -| Jiaming Song | @litchilitchy | Intel Corporation | Code Owner | | Yixin Bao | @ElaineBao | Intel Corporation | Code Owner | ## CPU Engine ### x64 -Team: @oneapi-src/onednn-cpu-x64 +Team: @uxlfoundation/onednn-cpu-x64 | Name | Github ID | Affiliation | Role | | ------------------ | --------------------- | ----------------- | ---------- | | Andrey Kalinin | @ankalinin | Intel Corporation | Maintainer | -| Srinivas Putta | @nivas-x86 | Intel Corporation | Maintainer | | Tatyana Primak | @tprimak | Intel Corporation | Maintainer | +| Alexey Makarevich | @amakarev | Intel Corporation | Code Owner | | David Eberius | @davideberius | Intel Corporation | Code Owner | -| John Karasev | @karashjoh000 | Intel Corporation | Code Owner | | Stefan Palicki | @spalicki | Intel Corporation | Code Owner | | Tomasz Czeszun | @tczeszun | Intel Corporation | Code Owner | -| Xuxin Zen | @xuxinzen | Intel Corporation | Code Owner | +| Xuxin Zeng | @xuxinzen | Intel Corporation | Code Owner | ### AArch64 -Team: @oneapi-src/onednn-cpu-aarch64 +Team: @uxlfoundation/onednn-cpu-aarch64 | Name | Github ID | Affiliation | Role | | ------------------ | --------------------- | ----------------- | ---------- | +| Hamza Butt | @theComputeKid | Arm Ltd | Maintainer | | Crefeda Rodrigues | @cfrod | Arm Ltd | Code Owner | | David Svantesson | @davsva01 | Arm Ltd | Code Owner | -| Johnatan Deakin | @jondea | Arm Ltd | Code Owner | -| Hamza Butt | @theComputeKid | Arm Ltd | Code Owner | +| Jonathan Deakin | @jondea | Arm Ltd | Code Owner | +| Radu Salavat | @Radu2k | Arm Ltd | Code Owner | +| Siddhartha Menon | @Sqvid | Arm Ltd | Code Owner | | Sunita Nadampalli | @snadampal | Amazon.com, Inc. | Code Owner | +| Ryo Suzuki | @Ryo-not-rio | Arm Ltd | Code Owner | ### OpenPOWER (PPC64) @@ -184,7 +185,7 @@ Vacant. Maintained by Core team. 
### Intel -Team: @oneapi-src/onednn-gpu-intel +Team: @uxlfoundation/onednn-gpu-intel | Name | Github ID | Affiliation | Role | | ------------------ | --------------------- | ----------------- | ---------- | @@ -192,9 +193,11 @@ Team: @oneapi-src/onednn-gpu-intel | Konstantin Arturov | @karturov | Intel Corporation | Maintainer | | Peter Caday | @petercad | Intel Corporation | Maintainer | | Andy Kassen | @atkassen | Intel Corporation | Code Owner | +| Daniel Youssif | @dyoussif | Intel Corporation | Code Owner | | Haleema Sadia | @h-sadia | Intel Corporation | Code Owner | | Andrey Guskov | @hidefromkgb | Intel Corporation | Code Owner | | Gallagher Pryor | @pv-pterab-s | Intel Corporation | Code Owner | +| Kealan Barbieri | @kealan-barbieri | Intel Corporation | Code Owner | | Roy Oursler | @rjoursler | Intel Corporation | Code Owner | | Simon Ewing | @Simonsays095 | Intel Corporation | Code Owner | | Sergey Kazakov | @skazakov1 | Intel Corporation | Code Owner | @@ -204,42 +207,44 @@ Team: @oneapi-src/onednn-gpu-intel ### NVIDIA, AMD, and generic GPU Teams: -* @oneapi-src/onednn-gpu-nvidia -* @oneapi-src/onednn-gpu-amd -* @oneapi-src/onednn-gpu-generic +* @uxlfoundation/onednn-gpu-nvidia +* @uxlfoundation/onednn-gpu-amd +* @uxlfoundation/onednn-gpu-generic | Name | Github ID | Affiliation | Role | | ------------------ | --------------------- | ----------------- | ---------- | -| Dylan Angus | @dylan-angus-codeplay | Codeplay Software | Code Owner | -| John Osorio | @kala85 | Codeplay Software | Code Owner | +| Anton Mitkov | @ShanoToni | Codeplay Software | Code Owner | +| Atharva Dubey | @AD2605 | Codeplay Software | Code Owner | | Mehdi Goli | @mehdi-goli | Codeplay Software | Code Owner | -| Anton Mitkov | @ShaoToni | Codeplay Software | Code Owner | +| Nicolò Scipione | @s-Nick | Codeplay Software | Code Owner | | Svetlozar Georgiev | @sgeor255 | Codeplay Software | Code Owner | +| Romain Biessy | @Rbiessy | Codeplay Software | Code Owner | ## Support functions ### Documentation -Team: @oneapi-src/onednn-doc +Team: @uxlfoundation/onednn-doc | Name | Github ID | Affiliation | Role | | ------------------ | --------------------- | ----------------- | ---------- | | Vadim Pirogov | @vpirogov | Intel Corporation | Maintainer | -| Deb Taylor | @deb-intel | Intel Corporation | Code Owner | +| Ranu Kundu | @ranukund | Intel Corporation | Code Owner | +| Tao Lv | @TaoLv | Intel Corporation | Code Owner | ### DevOps -Team: @oneapi-src/onednn-devops +Team: @uxlfoundation/onednn-devops | Name | Github ID | Affiliation | Role | | ------------------ | --------------------- | ----------------- | ---------- | | Sergey Razumovskiy | @srazumov | Intel Corporation | Maintainer | | Vadim Pirogov | @vpirogov | Intel Corporation | Maintainer | +| Hamza Butt | @theComputeKid | Arm Ltd | Code Owner | ### Release management | Name | Github ID | Affiliation | Role | | ------------------ | --------------------- | ----------------- | ---------- | -| Harry Mao | @harrymao2022 | Intel Corporation | Maintainer | | Tatyana Primak | @tprimak | Intel Corporation | Maintainer | | Vadim Pirogov | @vpirogov | Intel Corporation | Maintainer | diff --git a/README.binary.in b/README.binary.in index 76d7f7e39a6..1d8ae9f9e03 100644 --- a/README.binary.in +++ b/README.binary.in @@ -13,17 +13,17 @@ developers interested in improving application performance on CPUs and GPUs. This package contains oneDNN v@PROJECT_VERSION@ (@DNNL_VERSION_HASH@). 
You can find information about the latest version and release notes -at the oneDNN Github (https://github.com/oneapi-src/oneDNN/releases). +at the oneDNN Github (https://github.com/uxlfoundation/oneDNN/releases). Documentation ------------- * Developer guide -(https://oneapi-src.github.io/oneDNN/v@DNNL_VERSION_MAJOR@.@DNNL_VERSION_MINOR@) +(https://uxlfoundation.github.io/oneDNN/v@DNNL_VERSION_MAJOR@.@DNNL_VERSION_MINOR@) explains the programming model, supported functionality, and implementation details, and includes annotated examples. * API reference -(https://oneapi-src.github.io/oneDNN/v@DNNL_VERSION_MAJOR@.@DNNL_VERSION_MINOR@/modules.html) +(https://uxlfoundation.github.io/oneDNN/v@DNNL_VERSION_MAJOR@.@DNNL_VERSION_MINOR@/modules.html) provides a comprehensive reference of the library API. System Requirements @@ -48,7 +48,7 @@ just-in-time (JIT) code generation to deploy the code optimized for the latest supported ISA. Future ISAs may have initial support in the library disabled by default and require the use of run-time controls to enable them. See CPU dispatcher control -(https://oneapi-src.github.io/oneDNN/dev_guide_cpu_dispatcher_control.html) +(https://uxlfoundation.github.io/oneDNN/dev_guide_cpu_dispatcher_control.html) for more details. The library is optimized for the following GPUs: @@ -65,7 +65,7 @@ Support ------- Submit questions, feature requests, and bug reports on the -GitHub issues page (https://github.com/oneapi-src/oneDNN/issues). +GitHub issues page (https://github.com/uxlfoundation/oneDNN/issues). License ------- @@ -102,7 +102,7 @@ govern your use of the third party programs as set forth in the # Security -Security Policy (https://github.com/oneapi-src/oneDNN/blob/main/SECURITY.md) +Security Policy (https://github.com/uxlfoundation/oneDNN/blob/main/SECURITY.md) outlines our guidelines and procedures for ensuring the highest level of Security and trust for our users who consume oneDNN. diff --git a/README.md b/README.md index 6b5c384a069..861de627f76 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ oneAPI Deep Neural Network Library (oneDNN) =========================================== [![OpenSSF Best Practices](https://www.bestpractices.dev/projects/8762/badge)](https://www.bestpractices.dev/projects/8762) -[![OpenSSF Scorecard](https://api.securityscorecards.dev/projects/github.com/oneapi-src/oneDNN/badge)](https://securityscorecards.dev/viewer/?uri=github.com/oneapi-src/oneDNN) +[![OpenSSF Scorecard](https://api.securityscorecards.dev/projects/github.com/uxlfoundation/oneDNN/badge)](https://securityscorecards.dev/viewer/?uri=github.com/uxlfoundation/oneDNN) oneAPI Deep Neural Network Library (oneDNN) is an open-source cross-platform performance library of basic building blocks for deep learning applications. @@ -18,18 +18,33 @@ AMD\* GPU, OpenPOWER\* Power ISA (PPC64), IBMz\* (s390x), and RISC-V. oneDNN is intended for deep learning applications and framework developers interested in improving application performance on CPUs and GPUs. -Deep learning practitioners should use one of the -[applications enabled with oneDNN](#applications-enabled-with-onednn). 
+ +Deep learning practitioners should use one of the applications enabled with oneDNN: + +* [Apache SINGA](https://singa.apache.org) +* [DeepLearning4J\*](https://deeplearning4j.konduit.ai) +* [Flashlight\*](https://github.com/flashlight/flashlight) +* [MATLAB\* Deep Learning Toolbox](https://www.mathworks.com/help/deeplearning) +* [ONNX Runtime](https://onnxruntime.ai) +* [OpenVINO(TM) toolkit](https://github.com/openvinotoolkit/openvino) +* [PaddlePaddle\*](http://www.paddlepaddle.org) +* [PyTorch\*](https://pytorch.org). Intel GPU support and additional +optimizations are available with [Intel® Extension for PyTorch*]. +* [TensorFlow\*](https://www.tensorflow.org). Intel GPU support and additional +optimizations are available with [Intel® Extension for TensorFlow*]. + +[Intel® Extension for PyTorch*]: https://github.com/intel/intel-extension-for-pytorch +[Intel® Extension for TensorFlow*]: https://github.com/intel/intel-extension-for-tensorflow [UXL Foundation]: http://www.uxlfoundation.org -[oneAPI specification]: https://spec.oneapi.io +[oneAPI specification]: https://oneapi-spec.uxlfoundation.org/specifications/oneapi/latest/elements/onednn/source/ # Table of Contents - [Documentation](#documentation) -- [Installation](#installation) - [System Requirements](#system-requirements) -- [Applications Enabled with oneDNN](#applications-enabled-with-onednn) +- [Installation](#installation) +- [Validated Configurations](#validated-configurations) - [Governance](#governance) - [Support](#support) - [Contributing](#contributing) @@ -39,32 +54,18 @@ Deep learning practitioners should use one of the # Documentation -* [Developer Guide] explains the programming model, supported functionality, - and implementation details, and includes annotated examples. -* [API Reference] provides a comprehensive reference of the library API. - -[Developer Guide]: https://oneapi-src.github.io/oneDNN -[API Reference]: https://oneapi-src.github.io/oneDNN/group_dnnl_api.html - -# Installation +* [oneDNN Developer Guide and Reference] explains the programming + model, supported functionality, implementation details, and includes + annotated examples. +* [API Reference] provides a comprehensive reference of the library + API. +* [Release Notes] explains the new features, performance + optimizations, and improvements implemented in each version of + oneDNN. -Binary distribution of this software is available in: -* [Anaconda] -* [Intel oneAPI] - -The packages do not include library dependencies and these need to be resolved -in the application at build time. See the [System Requirements] section below -and the [Build Options] section in the [Developer Guide] for more details on -CPU and GPU runtimes. - -If the configuration you need is not available, you can -[build the library from source][Build from Source].
- -[Anaconda]: https://anaconda.org/conda-forge/onednn -[Intel oneAPI]: https://www.intel.com/content/www/us/en/developer/tools/oneapi/onednn.html -[System Requirements]: #system-requirements -[Build Options]: https://oneapi-src.github.io/oneDNN/dev_guide_build_options.html -[Build from Source]: https://oneapi-src.github.io/oneDNN/dev_guide_build.html +[oneDNN Developer Guide and Reference]: https://uxlfoundation.github.io/oneDNN +[API Reference]: https://uxlfoundation.github.io/oneDNN/group_dnnl_api.html +[Release Notes]: https://github.com/uxlfoundation/oneDNN/releases # System Requirements @@ -119,15 +120,15 @@ The library is optimized for the following GPUs: (formerly Meteor Lake, Arrow Lake and Lunar Lake) * future Intel Arc graphics (code name Battlemage) -[CPU dispatcher control]: https://oneapi-src.github.io/oneDNN/dev_guide_cpu_dispatcher_control.html -[Linking Guide]: https://oneapi-src.github.io/oneDNN/dev_guide_link.html +[CPU dispatcher control]: https://uxlfoundation.github.io/oneDNN/dev_guide_cpu_dispatcher_control.html +[Linking Guide]: https://uxlfoundation.github.io/oneDNN/dev_guide_link.html ## Requirements for Building from Source oneDNN supports systems meeting the following requirements: * Operating system with Intel 64 / Arm 64 / Power / IBMz architecture support * C++ compiler with C++11 standard support -* [CMake] 2.8.12 or later +* [CMake] 3.13 or later The following tools are required to build oneDNN documentation: * [Doxygen] 1.8.5 or later @@ -173,7 +174,7 @@ On a CPU based on Arm AArch64 architecture, oneDNN CPU engine can be built with machine learning applications and provides AArch64 optimized implementations of core functions. This functionality currently requires that ACL is downloaded and built separately. See [Build from Source] section of the Developer Guide for -details. oneDNN only supports Compute Library versions 24.08.1 or later. +details. oneDNN only supports Compute Library versions 24.11.1 or later. [Arm Compute Library (ACL)]: https://github.com/arm-software/ComputeLibrary @@ -239,12 +240,12 @@ is enabled: [timeout detection and recovery]: https://learn.microsoft.com/en-us/windows-hardware/drivers/display/timeout-detection-and-recovery [TdrDelay]: https://learn.microsoft.com/en-us/windows-hardware/drivers/display/tdr-registry-keys#tdrdelay -### Runtime Dependencies +## Runtime Dependencies When oneDNN is built from source, the library runtime dependencies and specific versions are defined by the build environment. 
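As a hedged illustration of that point, the sketch below is a hypothetical CMake initial-cache file (used as `cmake -C runtimes.cmake ..`) pinning the two runtime options that the dependency tables in the next section map to concrete shared libraries. The option names `DNNL_CPU_RUNTIME` and `DNNL_GPU_RUNTIME` are real build options that appear in those tables; the cache-file mechanism and the chosen values are illustrative only.

```cmake
# runtimes.cmake -- hypothetical initial-cache file.
# Selecting OMP and OCL here is what makes an OpenMP runtime and the
# OpenCL loader become runtime dependencies of the built library, per
# the tables that follow.
set(DNNL_CPU_RUNTIME "OMP" CACHE STRING "CPU threading runtime")
set(DNNL_GPU_RUNTIME "OCL" CACHE STRING "GPU compute runtime")
```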
-#### Linux +### Linux Common dependencies: * GNU C Library (`libc.so`) @@ -265,7 +266,7 @@ Runtime-specific dependencies: | `DNNL_GPU_RUNTIME=OCL` | any | OpenCL loader (`libOpenCL.so`) | `DNNL_GPU_RUNTIME=SYCL` | Intel oneAPI DPC++ Compiler | Intel oneAPI DPC++ Compiler runtime (`libsycl.so`), OpenCL loader (`libOpenCL.so`), oneAPI Level Zero loader (`libze_loader.so`) -#### Windows +### Windows Common dependencies: * Microsoft Visual C++ Redistributable (`msvcrt.dll`) @@ -281,7 +282,7 @@ Runtime-specific dependencies: | `DNNL_GPU_RUNTIME=OCL` | any | OpenCL loader (`OpenCL.dll`) | `DNNL_GPU_RUNTIME=SYCL` | Intel oneAPI DPC++ Compiler | Intel oneAPI DPC++ Compiler runtime (`sycl.dll`), OpenCL loader (`OpenCL.dll`), oneAPI Level Zero loader (`ze_loader.dll`) -#### macOS +### macOS Common dependencies: * System C/C++ runtime (`libc++.dylib`, `libSystem.dylib`) @@ -293,11 +294,32 @@ Runtime-specific dependencies: | `DNNL_CPU_RUNTIME=OMP` | Intel C/C++ Compiler | Intel OpenMP runtime (`libiomp5.dylib`) | `DNNL_CPU_RUNTIME=TBB` | any | TBB (`libtbb.dylib`) -### Validated Configurations +# Installation -CPU engine was validated on RedHat\* Enterprise Linux 8 with -* GNU Compiler Collection 5.4, 6.1, 7.2, 8.1, 9.1, 11.1, 11.3 -* Clang\* 7.1, 8.0, 9.0, 14.0.6 +You can download and install the oneDNN library using one of the following options: + +- Binary Distribution: You can download pre-built binary packages from + the following sources: + - [conda-forge]: If the configuration you need is not available on + the conda-forge channel, you can build the library using the + Source Distribution. + - Intel oneAPI: + - [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.htm) + - [Intel® oneDNN standalone package](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onednn-download.html) + +- Source Distribution: You can build the library from source by + following the instructions on the [Build from Source] page. 
+ +[conda-forge]: https://anaconda.org/conda-forge/onednn +[System Requirements]: #system-requirements +[Build Options]: https://uxlfoundation.github.io/oneDNN/dev_guide_build_options.html +[Build from Source]: https://uxlfoundation.github.io/oneDNN/dev_guide_build.html + +# Validated Configurations + +x86-64 CPU engine was validated on RedHat\* Enterprise Linux 8 with +* GNU Compiler Collection 8.5, 9.5, 11.1, 11.3 +* Clang\* 11.0, 14.0.6 * [Intel oneAPI DPC++/C++ Compiler] 2024.0 on Windows Server\* 2019 with @@ -307,16 +329,19 @@ on Windows Server\* 2019 with on macOS 11 (Big Sur) with * Apple LLVM version 13.0 -on Ubuntu 20.04 AArch64 with -* GNU Compiler Collection 7.0, 8.0, 9.0, 10.0 -* Clang\* 9.0, 17.0 +AArch64 CPU engine was validated on Ubuntu 22.04 with +* GNU Compiler Collection 10.0, 13.0 +* Clang\* 17.0 * [Arm Compiler for Linux] 24.04 * [Arm Compute Library (ACL)] built for armv8-a arch, latest stable version available at the time of release +on macOS 14 (Sonoma) with +* Apple LLVM version 15.0 + GPU engine was validated on Ubuntu\* 22.04 with -* GNU Compiler Collection 7.2, 8.1, and 9.1 -* Clang 7.1, 8.0, 9.0 +* GNU Compiler Collection 8.5, and 9.5 +* Clang 11.0 * [Intel oneAPI DPC++/C++ Compiler] 2024.0 * [Intel Software for General Purpose GPU capabilities] latest stable version available at the time of release @@ -331,24 +356,6 @@ time of release [Intel Arc & Iris Xe Graphics Driver]: https://www.intel.com/content/www/us/en/download/785597/intel-arc-iris-xe-graphics-windows.html [Arm Compiler for Linux]: https://developer.arm.com/Tools%20and%20Software/Arm%20Compiler%20for%20Linux -# Applications Enabled with oneDNN - -* [Apache\* MXNet](https://mxnet.apache.org) -* [Apache SINGA](https://singa.apache.org) -* [DeepLearning4J\*](https://deeplearning4j.konduit.ai) -* [Flashlight\*](https://github.com/flashlight/flashlight) -* [Korali](https://github.com/cselab/korali) -* [MATLAB\* Deep Learning Toolbox](https://www.mathworks.com/help/deeplearning) -* [ONNX Runtime](https://onnxruntime.ai) -* [OpenVINO(TM) toolkit](https://github.com/openvinotoolkit/openvino) -* [PaddlePaddle\*](http://www.paddlepaddle.org) -* [PyTorch\*](https://pytorch.org). Intel GPU support and additional -optimizations are available with [Intel Extension for PyTorch]. -* [Tensorflow\*](https://www.tensorflow.org). Intel GPU support and additional -optimizations are available with [Intel Extension for Tensorflow]. - -[Intel Extension for PyTorch]: https://github.com/intel/intel-extension-for-pytorch -[Intel Extension for Tensorflow]: https://github.com/intel/intel-extension-for-tensorflow # Support @@ -358,7 +365,7 @@ Submit questions, feature requests, and bug reports on the You can also contact oneDNN developers via [UXL Foundation Slack] using [#onednn] channel. -[Github issues]: https://github.com/oneapi-src/oneDNN/issues +[Github issues]: https://github.com/uxlfoundation/oneDNN/issues [UXL Foundation Slack]: https://slack-invite.uxlfoundation.org/ [#onednn]: https://uxlfoundation.slack.com/channels/onednn @@ -384,37 +391,31 @@ schedule and work already in progress towards future milestones in Github's [Milestones] section. If you are looking for a specific task to start, consider selecting from issues that are marked with the [help wanted] label. -If you have an idea on how to improve the library: -* For changes impacting the public API or library overall, such as adding new -primitives or changes to the architecture, submit an [RFC pull request]. 
-* Ensure that the changes are consistent with the [code contribution guidelines] -and [coding standards]. -* Ensure that you can build the product and run all the examples with your -patch. -* Submit a [pull request]. - -For additional details, see [contribution guidelines](CONTRIBUTING.md). You can -also contact oneDNN developers and maintainers via [UXL Foundation Slack] using -[#onednn] channel. -This project is intended to be a safe, welcoming space for collaboration, and -contributors are expected to adhere to the +See [contribution guidelines](CONTRIBUTING.md) to start contributing +to oneDNN. You can also contact oneDNN developers and maintainers via +[UXL Foundation Slack] using [#onednn] channel. + +This project is intended to be a safe, welcoming space for +collaboration, and contributors are expected to adhere to the [Contributor Covenant](CODE_OF_CONDUCT.md) code of conduct. -[RFC pull request]: https://github.com/oneapi-src/oneDNN/tree/rfcs +[RFC pull request]: https://github.com/uxlfoundation/oneDNN/tree/rfcs [code contribution guidelines]: CONTRIBUTING.md#code-contribution-guidelines [coding standards]: CONTRIBUTING.md#coding-standards -[pull request]: https://github.com/oneapi-src/oneDNN/pulls -[Milestones]: https://github.com/oneapi-src/oneDNN/milestones -[help wanted]: https://github.com/oneapi-src/oneDNN/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22 +[pull request]: https://github.com/uxlfoundation/oneDNN/pulls +[Milestones]: https://github.com/uxlfoundation/oneDNN/milestones +[help wanted]: https://github.com/uxlfoundation/oneDNN/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22 + # License -oneDNN is licensed under [Apache License Version 2.0](LICENSE). Refer to the -"[LICENSE](LICENSE)" file for the full license text and copyright notice. +oneDNN is licensed under [Apache License Version 2.0](LICENSE). Refer +to the "[LICENSE](LICENSE)" file for the full license text and +copyright notice. -This distribution includes third party software governed by separate license -terms. +This distribution includes third-party software governed by separate +license terms. 3-clause BSD license: * [Xbyak](https://github.com/herumi/xbyak) @@ -443,17 +444,17 @@ and OpenCL Driver](https://github.com/intel/compute-runtime) Interface](https://github.com/intel/metrics-discovery) * [spdlog](https://github.com/gabime/spdlog) -This third party software, even if included with the distribution of -the Intel software, may be governed by separate license terms, including -without limitation, third party license terms, other Intel software license -terms, and open source software license terms. These separate license terms -govern your use of the third party programs as set forth in the -"[THIRD-PARTY-PROGRAMS](THIRD-PARTY-PROGRAMS)" file. +This third-party software, even if included with the distribution of +the Intel software, may be governed by separate license terms, +including without limitation, third-party license terms, other Intel +software license terms, and open source software license terms. These +separate license terms govern your use of the third-party programs as +set forth in the "[THIRD-PARTY-PROGRAMS](THIRD-PARTY-PROGRAMS)" file. # Security [Security Policy](SECURITY.md) outlines our guidelines and procedures -for ensuring the highest level of Security and trust for our users +for ensuring the highest level of security and trust for our users who consume oneDNN.
# Trademark Information diff --git a/SECURITY.md b/SECURITY.md index 0613b2e7703..279574c78fc 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -64,6 +64,6 @@ If you have any suggestions on how this Policy could be improved, please submit an issue or a pull request to this repository. Please **do not** report potential vulnerabilities or security flaws via a pull request. -[1]: https://github.com/oneapi-src/oneDNN/releases/latest -[2]: https://github.com/oneapi-src/oneDNN/security/advisories/new -[3]: https://github.com/oneapi-src/oneDNN/security/advisories +[1]: https://github.com/uxlfoundation/oneDNN/releases/latest +[2]: https://github.com/uxlfoundation/oneDNN/security/advisories/new +[3]: https://github.com/uxlfoundation/oneDNN/security/advisories diff --git a/THIRD-PARTY-PROGRAMS b/THIRD-PARTY-PROGRAMS index c377e234ed9..fa47ab926ed 100644 --- a/THIRD-PARTY-PROGRAMS +++ b/THIRD-PARTY-PROGRAMS @@ -496,7 +496,7 @@ limitations under the License. END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- -6. Boost C++ Libraries (src/common/primitive_hashing.hpp, src/graph/backend/graph_compiler/core/src/util/hash_utils.hpp) +6. Boost C++ Libraries (src/common/primitive_hashing.hpp) Copyright 2005-2014 Daniel James. Boost Software License - Version 1.0 - August 17th, 2003 @@ -610,227 +610,3 @@ THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -10. LLVM (src/graph/backend/graph_compiler/core/src/util/array_ref.hpp) -============================================================================== -The LLVM Project is under the Apache License v2.0 with LLVM Exceptions: -============================================================================== - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. 
- - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - ----- LLVM Exceptions to the Apache 2.0 License ---- - -As an exception, if, as a result of your compiling your source code, portions -of this Software are embedded into an Object form of such source code, you -may redistribute such embedded portions in such Object form without complying -with the conditions of Sections 4(a), 4(b) and 4(d) of the License. - -In addition, if you combine or link compiled forms of this Software with -software that is licensed under the GPLv2 ("Combined Software") and if a -court of competent jurisdiction determines that the patent provision (Section -3), the indemnity provision (Section 9) or other Section of the License -conflicts with the conditions of the GPLv2, you may retroactively and -prospectively choose to deem waived or otherwise exclude such Section(s) of -the License, but only in their entirety and only with respect to the Combined -Software. 
diff --git a/cmake/ACL.cmake b/cmake/ACL.cmake index b185f7ba340..d619e6f9226 100644 --- a/cmake/ACL.cmake +++ b/cmake/ACL.cmake @@ -21,17 +21,17 @@ endif() set(acl_cmake_included true) include("cmake/options.cmake") -if(NOT DNNL_TARGET_ARCH STREQUAL "AARCH64") +if(NOT DNNL_TARGET_ARCH MATCHES "^(AARCH64|ARM)$") return() endif() -if(NOT DNNL_AARCH64_USE_ACL) +if(NOT DNNL_USE_ACL) return() endif() find_package(ACL REQUIRED) -set(ACL_MINIMUM_VERSION "24.08.1") +set(ACL_MINIMUM_VERSION "24.11.1") if(ACL_FOUND) file(GLOB_RECURSE ACL_VERSION_FILE ${ACL_INCLUDE_DIR}/*/arm_compute_version.embed) @@ -67,7 +67,7 @@ if(ACL_FOUND) message(STATUS "Arm Compute Library: ${ACL_LIBRARIES}") message(STATUS "Arm Compute Library headers: ${ACL_INCLUDE_DIRS}") - add_definitions(-DDNNL_AARCH64_USE_ACL) + add_definitions(-DDNNL_USE_ACL) set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_EXTENSIONS "OFF") endif() diff --git a/cmake/Doxygen.cmake b/cmake/Doxygen.cmake index 5d27d650a9e..a9409985be8 100644 --- a/cmake/Doxygen.cmake +++ b/cmake/Doxygen.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2016-2024 Intel Corporation +# Copyright 2016-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -49,7 +49,7 @@ if(DOXYGEN_FOUND) COMMENT "Generating API documentation in .xml format with Doxygen" VERBATIM) add_custom_target(doc_doxygen DEPENDS ${DOXYGEN_STAMP_FILE}) - if(NOT DNNL_INSTALL_MODE MATCHES "BUNDLE|BUNDLE_V2") + if(NOT DNNL_INSTALL_MODE STREQUAL "BUNDLE") install( DIRECTORY ${DOXYGEN_OUTPUT_DIR} DESTINATION share/doc/${LIB_PACKAGE_NAME} OPTIONAL) diff --git a/cmake/FindMIOpen.cmake b/cmake/FindMIOpen.cmake index 3928ce0dbce..727c16730af 100644 --- a/cmake/FindMIOpen.cmake +++ b/cmake/FindMIOpen.cmake @@ -34,6 +34,7 @@ list(APPEND EXTRA_SHARED_LIBS amd_comgr) # Prioritize MIOPENROOT list(APPEND miopen_root_hints + $ENV{ROCM_PATH} ${MIOPENROOT} $ENV{MIOPENROOT} "/opt/rocm" @@ -68,6 +69,10 @@ if(EXISTS "${MIOpen_INCLUDE_DIR}/miopen/version.h") "${MIOpen_MAJOR_VERSION}.${MIOpen_MINOR_VERSION}.${MIOpen_PATCH_VERSION}" ) + if(${MIOpen_MAJOR_VERSION} LESS 3) + add_definitions(-DMIOPEN_HAS_INT8X4=1) + endif() + unset(MIOpen_VERSION_CONTENT) else() message(WARNING "MIOpen version couldn't be identified.") diff --git a/cmake/FindOpenCL.cmake b/cmake/FindOpenCL.cmake index de876351714..711850959ba 100644 --- a/cmake/FindOpenCL.cmake +++ b/cmake/FindOpenCL.cmake @@ -47,18 +47,18 @@ function(_FIND_OPENCL_VERSION) set(CMAKE_REQUIRED_QUIET ${OpenCL_FIND_QUIETLY}) CMAKE_PUSH_CHECK_STATE() - foreach(VERSION "2_2" "2_1" "2_0" "1_2" "1_1" "1_0") + foreach(VERSION "3_0" "2_2" "2_1" "2_0" "1_2" "1_1" "1_0") set(CMAKE_REQUIRED_INCLUDES "${OpenCL_INCLUDE_DIR}") if(APPLE) CHECK_SYMBOL_EXISTS( CL_VERSION_${VERSION} - "${OpenCL_INCLUDE_DIR}/Headers/cl.h" + "Headers/cl.h" OPENCL_VERSION_${VERSION}) else() CHECK_SYMBOL_EXISTS( CL_VERSION_${VERSION} - "${OpenCL_INCLUDE_DIR}/CL/cl.h" + "CL/cl.h" OPENCL_VERSION_${VERSION}) endif() diff --git a/cmake/FindcublasLt.cmake b/cmake/FindcublasLt.cmake new file mode 100644 index 00000000000..bb7d4a3d5df --- /dev/null +++ b/cmake/FindcublasLt.cmake @@ -0,0 +1,48 @@ +# =============================================================================== +# Copyright 2020-2025 Intel Corporation +# Copyright 2020-2024 Codeplay Software Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file 
except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# =============================================================================== + +find_package(CUDA 10.0 REQUIRED) +find_package(Threads REQUIRED) + +find_path( + CUBLASLT_INCLUDE_DIR "cublasLt.h" + HINTS ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES include) + +find_library(CUDA_DRIVER_LIBRARY cuda) + +find_library( + CUBLASLT_LIBRARY cublasLt + HINTS ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64 bin) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args( + cublasLt REQUIRED_VARS CUBLASLT_INCLUDE_DIR CUDA_INCLUDE_DIRS CUBLASLT_LIBRARY + CUDA_LIBRARIES CUDA_DRIVER_LIBRARY) + +if(NOT TARGET cublasLt::cublasLt) + add_library(cublasLt::cublasLt SHARED IMPORTED) + set_target_properties( + cublasLt::cublasLt + PROPERTIES IMPORTED_LOCATION ${CUBLASLT_LIBRARY} + INTERFACE_INCLUDE_DIRECTORIES + "${CUBLASLT_INCLUDE_DIR};${CUDA_INCLUDE_DIRS}" + INTERFACE_LINK_LIBRARIES + "Threads::Threads;${CUDA_DRIVER_LIBRARY};${CUDA_LIBRARIES}" + INTERFACE_COMPILE_DEFINITIONS CUDA_NO_HALF) +endif() diff --git a/cmake/FindrocBLAS.cmake b/cmake/FindrocBLAS.cmake index c36baa8b473..45743c28873 100644 --- a/cmake/FindrocBLAS.cmake +++ b/cmake/FindrocBLAS.cmake @@ -19,21 +19,23 @@ find_package(Threads REQUIRED) # Prioritize ROCBLASROOT list(APPEND rocblas_root_hints + $ENV{ROCM_PATH} ${ROCBLASROOT} $ENV{ROCBLASROOT} "/opt/rocm" - "/opt/rocm/rocblas") + "/opt/rocm/rocblas" + "/opt/rocm/lib") find_path( rocBLAS_INCLUDE_DIR "rocblas.h" HINTS ${rocblas_root_hints} - PATH_SUFFIXES include + PATH_SUFFIXES include include/rocblas ) find_library( rocBLAS_LIBRARY rocblas HINTS ${rocblas_root_hints} - PATH_SUFFIXES lib + PATH_SUFFIXES lib lib/rocblas ) if(EXISTS "${rocBLAS_INCLUDE_DIR}/internal/rocblas-version.h") diff --git a/cmake/OpenMP.cmake b/cmake/OpenMP.cmake index 9484c268506..75aeba8a467 100644 --- a/cmake/OpenMP.cmake +++ b/cmake/OpenMP.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2017-2024 Intel Corporation +# Copyright 2017-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ endif() set(OpenMP_cmake_included true) include("cmake/Threading.cmake") -if (APPLE AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") +if (APPLE AND CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") # OSX Clang doesn't have OpenMP by default. # But we still want to build the library. set(_omp_severity "WARNING") @@ -31,19 +31,6 @@ else() set(_omp_severity "FATAL_ERROR") endif() -macro(set_openmp_values_for_old_cmake) - #newer version for findOpenMP (>= v. 
3.9) - if(CMAKE_VERSION VERSION_LESS "3.9" AND OPENMP_FOUND) - if(${CMAKE_MAJOR_VERSION} VERSION_LESS "3" AND ${CMAKE_CXX_COMPILER_ID} STREQUAL "Intel") - # Override FindOpenMP flags for Intel Compiler (otherwise deprecated) - set(OpenMP_CXX_FLAGS "-fopenmp") - set(OpenMP_C_FLAGS "-fopenmp") - endif() - set(OpenMP_C_FOUND true) - set(OpenMP_CXX_FOUND true) - endif() -endmacro() - if(DPCPP_HOST_COMPILER_KIND STREQUAL "DEFAULT") # XXX: workaround: when -fsycl is specified the compiler doesn't define # _OPENMP macro causing `find_package(OpenMP)` to fail. @@ -51,10 +38,7 @@ if(DPCPP_HOST_COMPILER_KIND STREQUAL "DEFAULT") # the -fsycl option by default so it has to be explicitly disabled. set(_omp_original_cmake_cxx_flags "${CMAKE_CXX_FLAGS}") string(REGEX REPLACE "-fsycl" "-fno-sycl" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") - find_package(OpenMP) - set_openmp_values_for_old_cmake() - set(CMAKE_CXX_FLAGS "${_omp_original_cmake_cxx_flags}") endif() @@ -68,13 +52,6 @@ if(NOT OpenMP_CXX_FOUND AND MSVC AND CMAKE_CXX_COMPILER_ID MATCHES "(Clang|Intel # The ICX driver doesn't link OpenMP library even if `/Qopenmp` # was specified. set(OpenMP_FLAGS "/Qopenmp -Xclang --dependent-lib=libiomp5md") - else() - if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "10.0") - # version < 10 can't pass cl-style `/openmp` flag - set(OpenMP_FLAGS "-Xclang -fopenmp") - # ... and requires explicit linking against omp library - set(OpenMP_CXX_LIBRARIES "libomp.lib") - endif() endif() set(OpenMP_C_FLAGS ${OpenMP_FLAGS}) set(OpenMP_CXX_FLAGS ${OpenMP_FLAGS}) diff --git a/cmake/SDL.cmake b/cmake/SDL.cmake index cf8a7d61f51..10953c021af 100644 --- a/cmake/SDL.cmake +++ b/cmake/SDL.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2017-2024 Intel Corporation +# Copyright 2017-2025 Intel Corporation # Copyright 2021 FUJITSU LIMITED # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -30,16 +30,11 @@ macro(sdl_unix_common_ccxx_flags var) append(${var} "-fPIC -Wformat -Wformat-security") endmacro() -macro(sdl_gnu_common_ccxx_flags var) - if(DPCPP_HOST_COMPILER_KIND STREQUAL "GNU") - # GNU compiler 7.4 or newer is required for host compiler - append(${var} "-fstack-protector-strong") - else() - if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 4.9) - append(${var} "-fstack-protector-all") - else() - append(${var} "-fstack-protector-strong") - endif() +macro(sdl_gnu_common_ccxx_flags var gnu_version) + append(${var} "-fstack-protector-strong") + if(NOT (${gnu_version} VERSION_LESS 8.0) + AND (DNNL_TARGET_ARCH STREQUAL "X64")) + append(${var} "-fcf-protection=full") endif() endmacro() @@ -49,44 +44,84 @@ endmacro() # only. To prevent warnings on users' side who use the library and turn # this warning on, let's use it too. 
Applicable for the library sources # and interfaces only (tests currently rely on that fact heavily) -macro(sdl_gnu_src_ccxx_flags var) +macro(sdl_unix_src_ccxx_flags var) append(${var} "-Wmissing-field-initializers") endmacro() -macro(sdl_gnu_example_ccxx_flags var) +macro(sdl_unix_example_ccxx_flags var) # At this point the flags for src and examples are the same - sdl_gnu_src_ccxx_flags(${var}) + sdl_unix_src_ccxx_flags(${var}) endmacro() -if(UNIX) - set(CMAKE_CCXX_FLAGS) +set(ONEDNN_SDL_COMPILER_FLAGS) +set(ONEDNN_SDL_LINKER_FLAGS) - sdl_unix_common_ccxx_flags(CMAKE_CCXX_FLAGS) - append(CMAKE_CXX_FLAGS_RELEASE "-D_FORTIFY_SOURCE=2") - append(CMAKE_C_FLAGS_RELEASE "-D_FORTIFY_SOURCE=2") +if(UNIX) + sdl_unix_common_ccxx_flags(ONEDNN_SDL_COMPILER_FLAGS) + sdl_unix_src_ccxx_flags(CMAKE_SRC_CCXX_FLAGS) + sdl_unix_example_ccxx_flags(CMAKE_EXAMPLE_CCXX_FLAGS) + if(UPPERCASE_CMAKE_BUILD_TYPE STREQUAL "RELEASE") + append(ONEDNN_SDL_COMPILER_FLAGS "-D_FORTIFY_SOURCE=2") + endif() if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - sdl_gnu_common_ccxx_flags(CMAKE_CCXX_FLAGS) - sdl_gnu_src_ccxx_flags(CMAKE_SRC_CCXX_FLAGS) - sdl_gnu_example_ccxx_flags(CMAKE_EXAMPLE_CCXX_FLAGS) - elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + sdl_gnu_common_ccxx_flags(ONEDNN_SDL_COMPILER_FLAGS + CMAKE_CXX_COMPILER_VERSION) + elseif(CMAKE_CXX_COMPILER_ID MATCHES "(Apple)?[Cc]lang") get_filename_component(CXX_CMD_NAME ${CMAKE_CXX_COMPILER} NAME) # Fujitsu CXX compiler does not support "-fstack-protector-all". if(NOT CXX_CMD_NAME STREQUAL "FCC") - append(CMAKE_CCXX_FLAGS "-fstack-protector-all") + append(ONEDNN_SDL_COMPILER_FLAGS "-fstack-protector-all") endif() elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel") - append(CMAKE_CXX_FLAGS "-fstack-protector") + append(ONEDNN_SDL_COMPILER_FLAGS "-fstack-protector") endif() - append(CMAKE_C_FLAGS "${CMAKE_CCXX_FLAGS}") - append(CMAKE_CXX_FLAGS "${CMAKE_CCXX_FLAGS}") if(APPLE) - append(CMAKE_SHARED_LINKER_FLAGS "-Wl,-bind_at_load") - append(CMAKE_EXE_LINKER_FLAGS "-Wl,-bind_at_load") + append(ONEDNN_SDL_LINKER_FLAGS "-Wl,-bind_at_load") else() + # Only applies to executables. append(CMAKE_EXE_LINKER_FLAGS "-pie") - append(CMAKE_SHARED_LINKER_FLAGS "-Wl,-z,noexecstack -Wl,-z,relro -Wl,-z,now") - append(CMAKE_EXE_LINKER_FLAGS "-Wl,-z,noexecstack -Wl,-z,relro -Wl,-z,now") + append(ONEDNN_SDL_LINKER_FLAGS "-Wl,-z,noexecstack -Wl,-z,relro -Wl,-z,now") + endif() +elseif(WIN32) + if(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + append(ONEDNN_SDL_COMPILER_FLAGS "/GS /Gy /guard:cf /DYNAMICBASE /sdl") + append(ONEDNN_SDL_LINKER_FLAGS "/NXCOMPAT /LTCG") + elseif(CMAKE_BASE_NAME STREQUAL "icx") + append(ONEDNN_SDL_COMPILER_FLAGS "/GS /Gy /guard:cf /Wformat /Wformat-security") + append(ONEDNN_SDL_LINKER_FLAGS "/link /NXCOMPAT") + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + append(ONEDNN_SDL_COMPILER_FLAGS "-Wformat -Wformat-security") + if(UPPERCASE_CMAKE_BUILD_TYPE STREQUAL "RELEASE") + append(ONEDNN_SDL_COMPILER_FLAGS "-D_FORTIFY_SOURCE=2") + endif() + get_filename_component(CXX_CMD_NAME ${CMAKE_CXX_COMPILER} NAME) + # Fujitsu CXX compiler does not support "-fstack-protector-all". + if(NOT CXX_CMD_NAME STREQUAL "FCC") + append(ONEDNN_SDL_COMPILER_FLAGS "-fstack-protector-all") + endif() + append(ONEDNN_SDL_LINKER_FLAGS "-Xlinker /NXCOMPAT -Xlinker /LTCG") + endif() + + if(NOT MINGW) + # For a Windows build, a malicious DLL can be injected because of the + # uncontrolled search order for load-time linked libraries defined for a + # Windows setting. 
The following cmake flags change the search order so that + # DLLs are loaded from the current working directory only if it is under a path + # in the Safe Load List. + if(CMAKE_BASE_NAME STREQUAL "icx") + # add ICX-style linker flags + append(ONEDNN_SDL_LINKER_FLAGS "/link /DEPENDENTLOADFLAG:0x2000") + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + # add Clang-style linker flags + append(ONEDNN_SDL_LINKER_FLAGS "-Xlinker /DEPENDENTLOADFLAG:0x2000") + else() + # Default to MSVC-style definition + append(ONEDNN_SDL_LINKER_FLAGS "/DEPENDENTLOADFLAG:0x2000") + endif() endif() -elseif(MSVC AND ${CMAKE_CXX_COMPILER_ID} STREQUAL MSVC) - set(CMAKE_CCXX_FLAGS "/guard:cf") endif() + +append(CMAKE_C_FLAGS "${ONEDNN_SDL_COMPILER_FLAGS}") +append(CMAKE_CXX_FLAGS "${ONEDNN_SDL_COMPILER_FLAGS}") +append(CMAKE_SHARED_LINKER_FLAGS "${ONEDNN_SDL_LINKER_FLAGS}") +append(CMAKE_EXE_LINKER_FLAGS "${ONEDNN_SDL_LINKER_FLAGS}") diff --git a/cmake/SYCL.cmake b/cmake/SYCL.cmake index bfaf25e53a1..5ca1c1c1beb 100644 --- a/cmake/SYCL.cmake +++ b/cmake/SYCL.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2019-2024 Intel Corporation +# Copyright 2019-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -88,12 +88,13 @@ endmacro() if(DNNL_SYCL_CUDA) suppress_warnings_for_nvidia_target() find_package(cuBLAS REQUIRED) + find_package(cublasLt REQUIRED) find_package(cuDNN REQUIRED) - adjust_headers_priority("cuBLAS::cuBLAS;cuDNN::cuDNN") + adjust_headers_priority("cuBLAS::cuBLAS;cuDNN::cuDNN;cublasLt::cublasLt") add_definitions_with_host_compiler("-DCUDA_NO_HALF") - list(APPEND EXTRA_SHARED_LIBS cuBLAS::cuBLAS cuDNN::cuDNN) + list(APPEND EXTRA_SHARED_LIBS cuBLAS::cuBLAS cuDNN::cuDNN cublasLt::cublasLt) message(STATUS "DPC++ support is enabled (CUDA)") elseif(DNNL_SYCL_HIP) find_package(HIP REQUIRED) @@ -135,14 +136,7 @@ endif() # #pragma message("The Intel extensions have been moved into cl_ext.h. # Please include cl_ext.h directly.") if(NOT WIN32) - if(${CMAKE_VERSION} VERSION_LESS "3.1.0") - # Prior to CMake 3.1 the Makefile generators did not escape # correctly - # inside make variable assignments used in generated makefiles, causing - # them to be treated as comments. This is a workaround. - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-\\#pragma-messages") - else() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-#pragma-messages") - endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-#pragma-messages") endif() add_definitions_with_host_compiler("-DCL_TARGET_OPENCL_VERSION=300") diff --git a/cmake/Sphinx.cmake b/cmake/Sphinx.cmake index 99b7de2868f..ed1e17a41f0 100644 --- a/cmake/Sphinx.cmake +++ b/cmake/Sphinx.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2021 Intel Corporation +# Copyright 2021-2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,9 +22,9 @@ if(Sphinx_cmake_included) endif() set(Sphinx_cmake_included true) -find_package(PythonInterp 2.7) +find_package(Python 3.7 COMPONENTS Interpreter) find_package(Sphinx) -if (PYTHONINTERP_FOUND AND SPHINX_FOUND) +if (Python_FOUND AND SPHINX_FOUND) set(SPHINX_GENERATOR "html" CACHE STRING "specifies generator for Sphinx") set(SPHINX_OUTPUT_DIR @@ -52,7 +52,7 @@ if (PYTHONINTERP_FOUND AND SPHINX_FOUND) COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/doc/sphinx/_static ${SPHINX_SOURCE_DIR}/_static - COMMAND ${PYTHON_EXECUTABLE} + COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/cleanup.py ${SPHINX_SOURCE_DIR} COMMAND ${SPHINX_EXECUTABLE} -b ${SPHINX_GENERATOR} -D release=v${PROJECT_VERSION} -j auto rst ${SPHINX_OUTPUT_DIR} @@ -60,4 +60,4 @@ if (PYTHONINTERP_FOUND AND SPHINX_FOUND) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/reference COMMENT "Generating API documentation with Sphinx" VERBATIM) add_custom_target(doc_sphinx DEPENDS ${SPHINX_STAMP_FILE} doc_doxyrest) -endif(PYTHONINTERP_FOUND AND SPHINX_FOUND) +endif(Python_FOUND AND SPHINX_FOUND) diff --git a/cmake/TBB.cmake b/cmake/TBB.cmake index 7c82c428b41..ef6669da8a9 100644 --- a/cmake/TBB.cmake +++ b/cmake/TBB.cmake @@ -26,7 +26,10 @@ include("cmake/Threading.cmake") macro(handle_tbb_target) if(TBB_FOUND) set_property(TARGET TBB::tbb PROPERTY "MAP_IMPORTED_CONFIG_RELWITHMDD" "DEBUG") - include_directories_with_host_compiler(${_tbb_include_dirs}) + foreach(inc_dir ${_tbb_include_dirs}) + include_directories(BEFORE SYSTEM ${inc_dir}) + append_host_compiler_options(CMAKE_CXX_FLAGS "-I${inc_dir}") + endforeach() list(APPEND EXTRA_SHARED_LIBS TBB::tbb) # Print TBB location @@ -59,7 +62,7 @@ macro(handle_tbb_target) add_definitions(-DTBB_PREVIEW_TASK_ARENA_CONSTRAINTS_EXTENSION=1) endmacro() -if(NOT DNNL_CPU_THREADING_RUNTIME STREQUAL "TBB") +if(NOT "${DNNL_CPU_THREADING_RUNTIME}" MATCHES "^(TBB|TBB_AUTO)$") return() endif() diff --git a/cmake/Threading.cmake b/cmake/Threading.cmake index 5ad3d903b07..cdbef190c7f 100644 --- a/cmake/Threading.cmake +++ b/cmake/Threading.cmake @@ -39,22 +39,12 @@ list(APPEND EXTRA_SHARED_LIBS "${CMAKE_THREAD_LIBS_INIT}") # A macro to avoid code duplication macro(find_package_tbb) - # Try to find TBB using a TBB-provided CMake config file. - find_package(TBB QUIET COMPONENTS tbb) - # If the previous `find_package` call failed then try to - # use a TBB CMake config file that is maintained by oneDNN. - # The reason the previous call may fail is that TBB package is - # very old and doesn't provide a CMake config file. - if(NOT TBB_FOUND) - message(STATUS "TBB-provided CMake config either failed or was not found. 
Trying to use a custom one.") - set(_cmake_proj_dir "${PROJECT_SOURCE_DIR}/cmake") - if(WIN32) - find_package(TBB ${ARGN} COMPONENTS tbb HINTS ${_cmake_proj_dir}/win) - elseif(APPLE) - find_package(TBB ${ARGN} COMPONENTS tbb HINTS ${_cmake_proj_dir}/mac) - elseif(UNIX) - find_package(TBB ${ARGN} COMPONENTS tbb HINTS ${_cmake_proj_dir}/lnx) - endif() + if(WIN32) + find_package(TBB ${ARGN} COMPONENTS tbb) + elseif(APPLE) + find_package(TBB ${ARGN} COMPONENTS tbb) + elseif(UNIX) + find_package(TBB ${ARGN} COMPONENTS tbb) endif() if(TBB_FOUND) diff --git a/cmake/config.cmake.in b/cmake/config.cmake.in index 24a35b5d4bf..0cdd6f754e3 100644 --- a/cmake/config.cmake.in +++ b/cmake/config.cmake.in @@ -21,6 +21,8 @@ set(DNNL_GPU_RUNTIME "@DNNL_GPU_RUNTIME@") set(DNNL_BLAS_VENDOR "@DNNL_BLAS_VENDOR@") +set(DNNL_GPU_VENDOR "@DNNL_GPU_VENDOR@") + if(DNNL_CPU_THREADING_RUNTIME STREQUAL "TBB") # Try to find TBB using a TBB-provided CMake config file. find_package(TBB QUIET COMPONENTS tbb) @@ -62,6 +64,14 @@ check_required_components("@LIB_PACKAGE_NAME@") if(DNNL_CPU_RUNTIME STREQUAL "SYCL" OR DNNL_CPU_RUNTIME STREQUAL "DPCPP" OR DNNL_GPU_RUNTIME STREQUAL "SYCL" OR DNNL_GPU_RUNTIME STREQUAL "DPCPP") + if(DNNL_GPU_VENDOR STREQUAL "NVIDIA") + set(DNNL_ORIGINAL_CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH}) + list(INSERT CMAKE_MODULE_PATH 0 ${PACKAGE_PREFIX_DIR}/@LIB_CONFIG_INSTALL_DIR@) + find_package(cuDNN REQUIRED) + find_package(cuBLAS REQUIRED) + find_package(cublasLt REQUIRED) + set(CMAKE_MODULE_PATH ${DNNL_ORIGINAL_CMAKE_MODULE_PATH}) + endif() set(DNNL_COMPILE_FLAGS "-fsycl") @HANDLE_BUNDLE_DEBUG_SYCL_CONFIGURATION@ endif() diff --git a/cmake/configuring_primitive_list.cmake b/cmake/configuring_primitive_list.cmake index 3524f171070..55fc83b33e2 100644 --- a/cmake/configuring_primitive_list.cmake +++ b/cmake/configuring_primitive_list.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2021-2024 Intel Corporation +# Copyright 2021-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
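To show the consumer-side effect of the config-file change above (which re-finds the cuDNN, cuBLAS, and cublasLt modules when `DNNL_GPU_VENDOR` is NVIDIA), here is a minimal downstream `CMakeLists.txt` sketch. The project name and source file are hypothetical; `DNNL::dnnl` is the target the installed package exports.

```cmake
# Hypothetical consumer project; a plain find_package suffices because
# the shipped dnnl config file resolves any GPU-vendor dependencies
# (e.g. the CUDA find modules above) on the consumer's behalf.
cmake_minimum_required(VERSION 3.13)
project(myapp CXX)

find_package(dnnl CONFIG REQUIRED)

add_executable(myapp main.cpp)
target_link_libraries(myapp PRIVATE DNNL::dnnl)
```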
@@ -32,7 +32,7 @@ else() foreach(impl ${DNNL_ENABLE_PRIMITIVE}) string(TOUPPER ${impl} uimpl) if(NOT "${uimpl}" MATCHES - "^(BATCH_NORMALIZATION|BINARY|CONCAT|CONVOLUTION|DECONVOLUTION|ELTWISE|INNER_PRODUCT|LAYER_NORMALIZATION|LRN|MATMUL|POOLING|PRELU|REDUCTION|REORDER|RESAMPLING|RNN|SDPA|SHUFFLE|SOFTMAX|SUM)$") + "^(BATCH_NORMALIZATION|BINARY|CONCAT|CONVOLUTION|DECONVOLUTION|ELTWISE|GROUP_NORMALIZATION|INNER_PRODUCT|LAYER_NORMALIZATION|LRN|MATMUL|POOLING|PRELU|REDUCTION|REORDER|RESAMPLING|RNN|SDPA|SHUFFLE|SOFTMAX|SUM)$") message(FATAL_ERROR "Unsupported primitive: ${uimpl}") endif() set(BUILD_${uimpl} TRUE) @@ -58,7 +58,7 @@ if (DNNL_ENABLE_PRIMITIVE_GPU_ISA STREQUAL "ALL") else() foreach(isa ${DNNL_ENABLE_PRIMITIVE_GPU_ISA}) string(TOUPPER ${isa} uisa) - if(NOT "${uisa}" MATCHES "^(GEN9|GEN11|XELP|XEHP|XEHPG|XEHPC|XE2)$") + if(NOT "${uisa}" MATCHES "^(GEN9|GEN11|XELP|XEHP|XEHPG|XEHPC|XE2|XE3)$") message(FATAL_ERROR "Unsupported primitive GPU ISA: ${uisa}") endif() set(BUILD_${uisa} TRUE) diff --git a/cmake/coverage.cmake b/cmake/coverage.cmake index ef0d06eed83..ce1799ed120 100644 --- a/cmake/coverage.cmake +++ b/cmake/coverage.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2019-2020 Intel Corporation +# Copyright 2019-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,11 +36,7 @@ if("${DNNL_CODE_COVERAGE}" STREQUAL "GCOV") message(FATAL_ERROR "GCOV not found in path") endif() - if("${CMAKE_CXX_COMPILER_ID}" MATCHES "(Apple)?[Cc]lang") - if("${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 3) - message(FATAL_ERROR "Clang version must be 3.0.0 or greater! Aborting...") - endif() - elseif(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + if(NOT CMAKE_CXX_COMPILER_ID MATCHES "(Apple)?[Cc]lang|GNU") message(FATAL_ERROR "Unsupported compiler: ${CMAKE_CXX_COMPILER_ID}") endif() @@ -49,7 +45,7 @@ if("${DNNL_CODE_COVERAGE}" STREQUAL "GCOV") if(NOT CMAKE_BUILD_TYPE MATCHES "[Dd]ebug") message(WARNING "Code coverage results with an optimised (non-Debug) build may be misleading") - endif() + endif() if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") link_libraries(gcov) diff --git a/cmake/dnnl_compat.cmake b/cmake/dnnl_compat.cmake index 17e3d0192e1..7e2fc043a1a 100644 --- a/cmake/dnnl_compat.cmake +++ b/cmake/dnnl_compat.cmake @@ -35,6 +35,8 @@ endmacro() set(COMPAT_CACHE_BOOL_VARS "EXPERIMENTAL" "EXPERIMENTAL_SPARSE" + "EXPERIMENTAL_UKERNEL" + "EXPERIMENTAL_LOGGING" "VERBOSE" "ENABLE_CONCURRENT_EXEC" "ENABLE_PRIMITIVE_CACHE" diff --git a/cmake/gen_gpu_kernel.cmake b/cmake/gen_gpu_kernel.cmake index 672c88ef877..dfc5feb1d9b 100644 --- a/cmake/gen_gpu_kernel.cmake +++ b/cmake/gen_gpu_kernel.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2019-2024 Intel Corporation +# Copyright 2019-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
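A hedged sketch of what the two new `COMPAT_CACHE_BOOL_VARS` entries in the dnnl_compat.cmake hunk above buy: assuming the compat macro defined earlier in that file (not shown in this hunk) forwards `ONEDNN_*`-prefixed cache entries to their `DNNL_*` equivalents, either spelling now enables the same experimental feature.

```cmake
# Illustrative only: the compat spelling on the configure line,
#   cmake -DONEDNN_EXPERIMENTAL_UKERNEL=ON ..
# is assumed to be forwarded by the compat layer to the canonical form:
set(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "build the ukernel API")
```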
@@ -22,18 +22,22 @@ file(READ ${CL_FILE} cl_file_lines) -# Remove C++ style comments -string(REGEX REPLACE "//[^\n]*\n" "\n" cl_file_lines "${cl_file_lines}") -# Remove repeated whitespaces -string(REGEX REPLACE " +" " " cl_file_lines "${cl_file_lines}") -# Remove leading whitespaces -string(REGEX REPLACE "\n " "\n" cl_file_lines "${cl_file_lines}") -# Remove empty lines -string(REGEX REPLACE "\n+" "\n" cl_file_lines "${cl_file_lines}") +string(LENGTH "${cl_file_lines}" len) +if(MINIFY OR len GREATER 65535) + # Remove C++ style comments + string(REGEX REPLACE "//[^\n]*\n" "\n" cl_file_lines "${cl_file_lines}") + # Remove repeated whitespaces + string(REGEX REPLACE " +" " " cl_file_lines "${cl_file_lines}") + # Remove leading whitespaces + string(REGEX REPLACE "\n " "\n" cl_file_lines "${cl_file_lines}") + # Remove empty lines + string(REGEX REPLACE "\n+" "\n" cl_file_lines "${cl_file_lines}") +endif() string(LENGTH "${cl_file_lines}" len) if(len GREATER 65535) - message(WARNING "Windows requires string literals to fit in 65535 bytes. Please split ${CL_FILE}.") + message(FATAL_ERROR + "Windows requires string literals to fit in 65535 bytes. Please split ${CL_FILE}.") endif() get_filename_component(cl_file_name ${CL_FILE} NAME_WE) diff --git a/cmake/gen_gpu_kernel_list.cmake b/cmake/gen_gpu_kernel_list.cmake index 02f8cacb9bb..f64f90d6259 100644 --- a/cmake/gen_gpu_kernel_list.cmake +++ b/cmake/gen_gpu_kernel_list.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2020-2021 Intel Corporation +# Copyright 2020-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -46,6 +46,11 @@ endfunction() function(gen_gpu_kernel_list ker_list_templ ker_list_src ker_sources headers) set(_sources "${SOURCES}") + set(MINIFY "ON") + if(DNNL_DEV_MODE OR CMAKE_BUILD_TYPE STREQUAL "Debug") + set(MINIFY "OFF") + endif() + set(KER_LIST_EXTERN) set(KER_LIST_ENTRIES) set(KER_HEADERS_EXTERN) @@ -62,6 +67,7 @@ function(gen_gpu_kernel_list ker_list_templ ker_list_src ker_sources headers) COMMAND ${CMAKE_COMMAND} -DCL_FILE="${header_path}" -DGEN_FILE="${gen_file}" + -DMINIFY="${MINIFY}" -P ${PROJECT_SOURCE_DIR}/cmake/gen_gpu_kernel.cmake DEPENDS ${header_path} ) @@ -81,6 +87,7 @@ function(gen_gpu_kernel_list ker_list_templ ker_list_src ker_sources headers) COMMAND ${CMAKE_COMMAND} -DCL_FILE="${ker_path}" -DGEN_FILE="${gen_file}" + -DMINIFY="${MINIFY}" -P ${PROJECT_SOURCE_DIR}/cmake/gen_gpu_kernel.cmake DEPENDS ${ker_path} ) diff --git a/cmake/host_compiler.cmake b/cmake/host_compiler.cmake index 22b5ed60bd3..7d64edbbe8a 100644 --- a/cmake/host_compiler.cmake +++ b/cmake/host_compiler.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2021-2024 Intel Corporation +# Copyright 2021-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,6 +35,8 @@ if(DPCPP_HOST_COMPILER_KIND MATCHES "^(GNU|CLANG)$") platform_unix_and_mingw_common_cxx_flags(DPCPP_HOST_COMPILER_OPTS) sdl_unix_common_ccxx_flags(DPCPP_HOST_COMPILER_OPTS) + sdl_unix_src_ccxx_flags(DPCPP_SRC_COMPILER_OPTS) + sdl_unix_example_ccxx_flags(DPCPP_EXAMPLE_COMPILER_OPTS) # SYCL uses C++17 features in headers hence C++17 support should be enabled # for host compiler. 
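Since the gen_gpu_kernel.cmake hunk above runs in CMake script mode, the minification path can be exercised standalone. The sketch below mirrors the `add_custom_command` invocation added in gen_gpu_kernel_list.cmake; the kernel file names are placeholders.

```cmake
# Hypothetical one-off invocation of the kernel embedding script.
# MINIFY=ON forces comment/whitespace stripping even when the source
# is already under the 65535-byte MSVC string-literal limit.
execute_process(
    COMMAND ${CMAKE_COMMAND}
        -DCL_FILE="kernel.cl"
        -DGEN_FILE="kernel.cl.h"
        -DMINIFY="ON"
        -P cmake/gen_gpu_kernel.cmake)
```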
@@ -78,9 +80,7 @@ if(DPCPP_HOST_COMPILER_KIND MATCHES "^(GNU|CLANG)$") if(DPCPP_HOST_COMPILER_KIND STREQUAL "GNU") platform_gnu_nowarn_ccxx_flags(DPCPP_CXX_NOWARN_FLAGS ${DPCPP_HOST_COMPILER_MAJOR_VER}.${DPCPP_HOST_COMPILER_MINOR_VER}) - sdl_gnu_common_ccxx_flags(DPCPP_HOST_COMPILER_OPTS) - sdl_gnu_src_ccxx_flags(DPCPP_SRC_CXX_FLAGS) - sdl_gnu_example_ccxx_flags(DPCPP_EXAMPLE_CXX_FLAGS) + sdl_gnu_common_ccxx_flags(DPCPP_HOST_COMPILER_OPTS DPCPP_HOST_COMPILER_VER) # SYCL headers contain some comments that trigger warning with GNU compiler append(DPCPP_HOST_COMPILER_OPTS "-Wno-comment") @@ -100,6 +100,11 @@ if(DPCPP_HOST_COMPILER_KIND MATCHES "^(GNU|CLANG)$") # Affects both, GNU and CLANG kinds. append(CMAKE_CXX_FLAGS "-Wno-unused-command-line-argument") + # Option `-fsycl-unnamed-lambda` is enabled by default, but not compatible + # with `-fsycl-host-compiler`. While icpx driver adds + # `-fno-sycl-unnamed-lambda` to avoid build issues clang++ does not do that. + append(CMAKE_CXX_FLAGS "-fno-sycl-unnamed-lambda") + append(CMAKE_CXX_FLAGS "-fsycl-host-compiler=${DPCPP_HOST_COMPILER}") append_host_compiler_options(CMAKE_CXX_FLAGS "${DPCPP_HOST_COMPILER_OPTS}") endif() diff --git a/cmake/host_compiler_id.cmake b/cmake/host_compiler_id.cmake index 2a74d703c4f..966c553023c 100644 --- a/cmake/host_compiler_id.cmake +++ b/cmake/host_compiler_id.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2024 Intel Corporation +# Copyright 2024-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -98,13 +98,13 @@ message(STATUS "Host compiler version: ${DPCPP_HOST_COMPILER_MAJOR_VER}.${DPCPP_ # Check the version of the provided host compiler. if(DPCPP_HOST_COMPILER_KIND STREQUAL "GNU") - if((DPCPP_HOST_COMPILER_MAJOR_VER LESS 7) OR (DPCPP_HOST_COMPILER_MAJOR_VER EQUAL 7 AND DPCPP_HOST_COMPILER_MINOR_VER LESS 4)) - message(FATAL_ERROR "The minimum version of ${DPCPP_HOST_COMPILER_KIND} host compiler is 7.4.") + if(DPCPP_HOST_COMPILER_MAJOR_VER LESS 8) + message(FATAL_ERROR "The minimum version of ${DPCPP_HOST_COMPILER_KIND} host compiler is 8.0.") endif() endif() if(DPCPP_HOST_COMPILER_KIND STREQUAL "CLANG") - if(DPCPP_HOST_COMPILER_MAJOR_VER LESS 8) - message(FATAL_ERROR "The minimum version of ${DPCPP_HOST_COMPILER_KIND} host compiler is 8.0.") + if(DPCPP_HOST_COMPILER_MAJOR_VER LESS 11) + message(FATAL_ERROR "The minimum version of ${DPCPP_HOST_COMPILER_KIND} host compiler is 11.0.") endif() endif() diff --git a/cmake/lnx/TBBConfig.cmake b/cmake/lnx/TBBConfig.cmake deleted file mode 100644 index bedbff68e39..00000000000 --- a/cmake/lnx/TBBConfig.cmake +++ /dev/null @@ -1,183 +0,0 @@ -#=============================================================================== -# Copyright 2017-2020 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-#=============================================================================== - -# TBB_FOUND should not be set explicitly. It is defined automatically by CMake. -# Handling of TBB_VERSION is in TBBConfigVersion.cmake. - -if (NOT TBB_FIND_COMPONENTS) - set(TBB_FIND_COMPONENTS "tbb;tbbmalloc;tbbmalloc_proxy") - foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - set(TBB_FIND_REQUIRED_${_tbb_component} 1) - endforeach() -endif() - -# Add components with internal dependencies: tbbmalloc_proxy -> tbbmalloc -list(FIND TBB_FIND_COMPONENTS tbbmalloc_proxy _tbbmalloc_proxy_ix) -if (NOT _tbbmalloc_proxy_ix EQUAL -1) - list(FIND TBB_FIND_COMPONENTS tbbmalloc _tbbmalloc_ix) - if (_tbbmalloc_ix EQUAL -1) - list(APPEND TBB_FIND_COMPONENTS tbbmalloc) - set(TBB_FIND_REQUIRED_tbbmalloc ${TBB_FIND_REQUIRED_tbbmalloc_proxy}) - endif() -endif() - -# oneDNN changes: use TBBROOT to locate Intel TBB -# get_filename_component(_tbb_root "${CMAKE_CURRENT_LIST_FILE}" PATH) -# get_filename_component(_tbb_root "${_tbb_root}" PATH) -if (NOT TBBROOT) - if(DEFINED ENV{TBBROOT}) - set (TBBROOT $ENV{TBBROOT}) - endif() -endif() - -set(_tbb_root ${TBBROOT}) - -set(_tbb_x32_subdir ia32) -set(_tbb_x64_subdir intel64) - -if (CMAKE_SIZEOF_VOID_P EQUAL 8) - set(_tbb_arch_subdir ${_tbb_x64_subdir}) -else() - set(_tbb_arch_subdir ${_tbb_x32_subdir}) -endif() - -if (CMAKE_CXX_COMPILER_LOADED) - set(_tbb_compiler_id ${CMAKE_CXX_COMPILER_ID}) - set(_tbb_compiler_ver ${CMAKE_CXX_COMPILER_VERSION}) -elseif (CMAKE_C_COMPILER_LOADED) - set(_tbb_compiler_id ${CMAKE_C_COMPILER_ID}) - set(_tbb_compiler_ver ${CMAKE_C_COMPILER_VERSION}) -endif() - -# For non-GCC compilers try to find version of system GCC to choose right compiler subdirectory. -if (NOT _tbb_compiler_id STREQUAL "GNU") - execute_process(COMMAND gcc --version OUTPUT_VARIABLE _tbb_gcc_ver_output ERROR_QUIET) - string(REGEX REPLACE ".*gcc.* ([0-9]+\\.[0-9]+)\\.[0-9]+.*" "\\1" _tbb_compiler_ver "${_tbb_gcc_ver_output}") - if (NOT _tbb_compiler_ver) - message(FATAL_ERROR "This Intel TBB package is intended to be used only environment with available 'gcc'") - endif() - unset(_tbb_gcc_ver_output) -endif() - -if (EXISTS "${_tbb_root}/lib/${_tbb_arch_subdir}") - set(_tbb_lib ${_tbb_root}/lib/${_tbb_arch_subdir}) - set(_tbb_inc ${_tbb_root}/include) - - file(GLOB _tbb_gcc_versions_available RELATIVE ${_tbb_lib} ${_tbb_lib}/*) - # shall we check _tbb_gcc_versions_available is not empty? 
- foreach (_tbb_gcc_version ${_tbb_gcc_versions_available}) - string(SUBSTRING ${_tbb_gcc_version} 3 -1 _tbb_gcc_version_number) - if (NOT _tbb_compiler_ver VERSION_LESS _tbb_gcc_version_number) - set(_tbb_compiler_subdir ${_tbb_gcc_version}) - endif() - endforeach() -else() - if (TBBROOT) - set(__tbb_hint_path "${TBBROOT}") - else() - set(__tbb_hint_path "/non/existing/path") - endif() - - # try to find TBB in the system - find_library(_tbb_lib NAMES tbb - HINTS "${__tbb_hint_path}" - PATH_SUFFIXES lib lib64) - find_path(_tbb_inc NAMES tbb.h - HINTS "${__tbb_hint_path}" - PATH_SUFFIXES include tbb include/tbb) - unset(__tbb_hint_path) - - if (NOT _tbb_lib OR NOT _tbb_inc) - message("FATAL_ERROR" "Cannot find TBB") - endif() - - get_filename_component(_tbb_lib "${_tbb_lib}" PATH) - get_filename_component(_tbb_inc "${_tbb_inc}" PATH) - - set(_tbb_arch_subdir "") - set(_tbb_compiler_subdir "") -endif() - -unset(_tbb_gcc_version_number) -unset(_tbb_compiler_id) -unset(_tbb_compiler_ver) - -# Now we check that all the needed component are present -get_filename_component(_tbb_lib_path "${_tbb_lib}/${_tbb_compiler_subdir}" ABSOLUTE) - -if (TBB_FOUND) - return() -endif() - -foreach (_tbb_soversion 2 12) -foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - set(_tbb_release_lib - "${_tbb_lib_path}/lib${_tbb_component}.so.${_tbb_soversion}") - set(_tbb_debug_lib - "${_tbb_lib_path}/lib${_tbb_component}_debug.so.${_tbb_soversion}") - - # oneDNN change: check library existence (BUILD_MODE related only, not both) - string(TOUPPER "${CMAKE_BUILD_TYPE}" UPPERCASE_CMAKE_BUILD_TYPE) - if (UPPERCASE_CMAKE_BUILD_TYPE STREQUAL "DEBUG") - if (EXISTS "${_tbb_debug_lib}") - set(_lib_exists TRUE) - elseif (EXISTS "${_tbb_release_lib}") - message(FATAL_ERROR - "Intel TBB release library is found here: ${_tbb_release_lib}. 
" - "But the debug library - (lib${_tbb_component}_debug.so.${_tbb_soversion}) is missing.") - endif() - else() - if (EXISTS "${_tbb_release_lib}") - set(_lib_exists TRUE) - endif() - endif() - - if (_lib_exists) - if (NOT TARGET TBB::${_tbb_component}) - add_library(TBB::${_tbb_component} SHARED IMPORTED) - set_target_properties(TBB::${_tbb_component} PROPERTIES - IMPORTED_CONFIGURATIONS "RELEASE;DEBUG" - IMPORTED_LOCATION_RELEASE "${_tbb_release_lib}" - IMPORTED_LOCATION_DEBUG "${_tbb_debug_lib}" - INTERFACE_INCLUDE_DIRECTORIES "${_tbb_inc}") - - # Add internal dependencies for imported targets: TBB::tbbmalloc_proxy -> TBB::tbbmalloc - if (_tbb_component STREQUAL tbbmalloc_proxy) - set_target_properties(TBB::tbbmalloc_proxy PROPERTIES INTERFACE_LINK_LIBRARIES TBB::tbbmalloc) - endif() - - list(APPEND TBB_IMPORTED_TARGETS TBB::${_tbb_component}) - set(TBB_${_tbb_component}_FOUND 1) - endif() - break() - endif() -endforeach() -endforeach() - -if (NOT _lib_exists AND TBB_FIND_REQUIRED AND TBB_FIND_REQUIRED_${_tbb_component}) - message(FATAL_ERROR "Missed required Intel TBB component: ${_tbb_component}") -endif() - -unset(_tbb_x32_subdir) -unset(_tbb_x64_subdir) -unset(_tbb_arch_subdir) -unset(_tbb_compiler_subdir) -unset(_tbbmalloc_proxy_ix) -unset(_tbbmalloc_ix) -unset(_tbb_lib_path) -unset(_tbb_release_lib) -unset(_tbb_debug_lib) diff --git a/cmake/mac/TBBConfig.cmake b/cmake/mac/TBBConfig.cmake deleted file mode 100644 index 7bb9af865e2..00000000000 --- a/cmake/mac/TBBConfig.cmake +++ /dev/null @@ -1,127 +0,0 @@ -#=============================================================================== -# Copyright 2017-2020 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#=============================================================================== - -# TBB_FOUND should not be set explicitly. It is defined automatically by CMake. -# Handling of TBB_VERSION is in TBBConfigVersion.cmake. - -if (NOT TBB_FIND_COMPONENTS) - set(TBB_FIND_COMPONENTS "tbb;tbbmalloc;tbbmalloc_proxy") - foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - set(TBB_FIND_REQUIRED_${_tbb_component} 1) - endforeach() -endif() - -# Add components with internal dependencies: tbbmalloc_proxy -> tbbmalloc -list(FIND TBB_FIND_COMPONENTS tbbmalloc_proxy _tbbmalloc_proxy_ix) -if (NOT _tbbmalloc_proxy_ix EQUAL -1) - list(FIND TBB_FIND_COMPONENTS tbbmalloc _tbbmalloc_ix) - if (_tbbmalloc_ix EQUAL -1) - list(APPEND TBB_FIND_COMPONENTS tbbmalloc) - set(TBB_FIND_REQUIRED_tbbmalloc ${TBB_FIND_REQUIRED_tbbmalloc_proxy}) - endif() -endif() - -# oneDNN changes: use TBBROOT to locate Intel TBB -# get_filename_component(_tbb_root "${CMAKE_CURRENT_LIST_FILE}" PATH) -# get_filename_component(_tbb_root "${_tbb_root}" PATH) -if (NOT TBBROOT) - if(DEFINED ENV{TBBROOT}) - set (TBBROOT $ENV{TBBROOT}) - else() - message("FATAL_ERROR" "TBBROOT is unset") - endif() -endif() - -set(_tbb_root ${TBBROOT}) - -set(_tbb_x32_subdir .) -set(_tbb_x64_subdir .) 
- -if (CMAKE_SIZEOF_VOID_P EQUAL 8) - set(_tbb_arch_subdir ${_tbb_x64_subdir}) -else() - set(_tbb_arch_subdir ${_tbb_x32_subdir}) -endif() - -set(_tbb_compiler_subdir .) - -get_filename_component(_tbb_lib_path "${_tbb_root}/lib/${_tbb_arch_subdir}/${_tbb_compiler_subdir}" ABSOLUTE) - -if (TBB_FOUND) - return() -endif() - -foreach (_tbb_lib_version .12 "") -foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - set(_tbb_release_lib "${_tbb_lib_path}/lib${_tbb_component}${_tbb_lib_version}.dylib") - set(_tbb_debug_lib "${_tbb_lib_path}/lib${_tbb_component}_debug${_tbb_lib_version}.dylib") - - # oneDNN change: check library existence (BUILD_MODE related only, not both) - string(TOUPPER "${CMAKE_BUILD_TYPE}" UPPERCASE_CMAKE_BUILD_TYPE) - if (UPPERCASE_CMAKE_BUILD_TYPE STREQUAL "DEBUG") - if (EXISTS "${_tbb_debug_lib}") - set(_lib_exists TRUE) - elseif (EXISTS "${_tbb_release_lib}") - message(FATAL_ERROR - "Intel TBB release library is found here: ${_tbb_release_lib}. " - "But the debug library - (lib${_tbb_component}_debug${_tbb_lib_version}.dylib) is missing.") - endif() - else() - if (EXISTS "${_tbb_release_lib}") - set(_lib_exists TRUE) - endif() - endif() - - if (_lib_exists) - if (NOT TARGET TBB::${_tbb_component}) - add_library(TBB::${_tbb_component} SHARED IMPORTED) - set_target_properties(TBB::${_tbb_component} PROPERTIES - IMPORTED_CONFIGURATIONS "RELEASE;DEBUG" - IMPORTED_LOCATION_RELEASE "${_tbb_release_lib}" - IMPORTED_LOCATION_DEBUG "${_tbb_debug_lib}" - INTERFACE_INCLUDE_DIRECTORIES "${_tbb_root}/include") - - # Add internal dependencies for imported targets: TBB::tbbmalloc_proxy -> TBB::tbbmalloc - if (_tbb_component STREQUAL tbbmalloc_proxy) - set_target_properties(TBB::tbbmalloc_proxy PROPERTIES INTERFACE_LINK_LIBRARIES TBB::tbbmalloc) - endif() - - list(APPEND TBB_IMPORTED_TARGETS TBB::${_tbb_component}) - set(TBB_${_tbb_component}_FOUND 1) - endif() - break() - endif() -endforeach() -endforeach() - -foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - if (NOT TARGET TBB::${_tbb_component} AND TBB_FIND_REQUIRED AND TBB_FIND_REQUIRED_${_tbb_component}) - message(FATAL_ERROR "Missed required Intel TBB component: ${_tbb_component}") - endif() -endforeach() - -unset(_tbb_x32_subdir) -unset(_tbb_x64_subdir) -unset(_tbb_arch_subdir) -unset(_tbb_compiler_subdir) -unset(_tbbmalloc_proxy_ix) -unset(_tbbmalloc_ix) -unset(_tbb_lib_path) -unset(_tbb_release_lib) -unset(_tbb_debug_lib) -unset(_tbb_lib_version) -unset(_lib_exists) diff --git a/cmake/options.cmake b/cmake/options.cmake index 0bb963ae24b..b73128d759d 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2018-2024 Intel Corporation +# Copyright 2018-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -81,12 +81,11 @@ set(DNNL_TEST_SET "CI" CACHE STRING semicolon separated string, e.g., DNNL_TEST_SET=CI;NO_CORR.") set(DNNL_INSTALL_MODE "DEFAULT" CACHE STRING - "specifies installation mode; supports DEFAULT, BUNDLE and BUNDLE_V2. + "specifies installation mode; supports DEFAULT and BUNDLE. - When BUNDLE or BUNDLE_V2 option is set oneDNN will be installed as a bundle - which contains examples and benchdnn. 
The difference between BUNDLE and
-    BUNDLE_V2 is in the directory layout.")
-if (NOT "${DNNL_INSTALL_MODE}" MATCHES "^(DEFAULT|BUNDLE|BUNDLE_V2)$")
+    When the BUNDLE option is set, oneDNN will be installed as a bundle
+    which contains examples and benchdnn.")
+if (NOT "${DNNL_INSTALL_MODE}" MATCHES "^(DEFAULT|BUNDLE)$")
     message(FATAL_ERROR "Unsupported install mode: ${DNNL_INSTALL_MODE}")
 endif()
 
@@ -123,9 +122,9 @@ set(DNNL_ENABLE_PRIMITIVE "ALL" CACHE STRING
     - ALL (the default). Includes all primitives to be enabled.
     - <PRIMITIVE_NAME>. Includes only the selected primitive to be enabled.
       Possible values are: BATCH_NORMALIZATION, BINARY, CONCAT, CONVOLUTION,
-      DECONVOLUTION, ELTWISE, INNER_PRODUCT, LAYER_NORMALIZATION, LRN, MATMUL,
-      POOLING, PRELU, REDUCTION, REORDER, RESAMPLING, RNN, SDPA, SHUFFLE,
-      SOFTMAX, SUM.
+      DECONVOLUTION, ELTWISE, GROUP_NORMALIZATION, INNER_PRODUCT,
+      LAYER_NORMALIZATION, LRN, MATMUL, POOLING, PRELU, REDUCTION, REORDER,
+      RESAMPLING, RNN, SDPA, SHUFFLE, SOFTMAX, SUM.
     - <PRIMITIVE_NAME>;<PRIMITIVE_NAME>;... Includes only selected primitives
       to be enabled at build time. This is treated as CMake string, thus,
       semicolon is a mandatory delimiter between names. This is the way to
      specify several
@@ -147,7 +146,7 @@ set(DNNL_ENABLE_PRIMITIVE_GPU_ISA "ALL" CACHE STRING
     implementations will always be available. Valid values:
     - ALL (the default). Includes all ISA to be enabled.
     - <ISA_NAME>;<ISA_NAME>;... Includes only selected ISA to be enabled.
-      Possible values are: GEN9, GEN11, XELP, XEHP, XEHPG, XEHPC, XE2.")
+      Possible values are: GEN9, GEN11, XELP, XEHP, XEHPG, XEHPC, XE2, XE3.")
 
 set(ONEDNN_ENABLE_GEMM_KERNELS_ISA "ALL" CACHE STRING
     "Specifies an ISA set of GeMM kernels residing in x64/gemm folder to be
@@ -224,13 +223,6 @@ option(DNNL_EXPERIMENTAL_LOGGING
     independently from DNNL_EXPERIMENTAL." OFF) # disabled by default
 
-option(ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_BACKEND
-    "builds oneDNN Graph API graph-compiler backend" OFF)
-set(ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_LLVM_CONFIG "AUTO" CACHE STRING
-    "graph-compiler's llvm-config path")
-set(ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_JIT "builtin" CACHE STRING
-    "the optional JIT backends for graph-compiler: llvm;c;builtin")
-
 # ======================
 # Profiling capabilities
 # ======================
@@ -262,7 +254,7 @@ set(DNNL_CPU_RUNTIME "OMP" CACHE STRING
     To use Threading Building Blocks (TBB) one should also set TBBROOT (either
     environment variable or CMake option) to the library location.")
-if(NOT "${DNNL_CPU_RUNTIME}" MATCHES "^(NONE|OMP|TBB|SEQ|THREADPOOL|DPCPP|SYCL)$")
+if(NOT "${DNNL_CPU_RUNTIME}" MATCHES "^(NONE|OMP|TBB|TBB_AUTO|SEQ|THREADPOOL|DPCPP|SYCL)$")
     message(FATAL_ERROR "Unsupported CPU runtime: ${DNNL_CPU_RUNTIME}")
 endif()
 
@@ -381,6 +373,8 @@ set(DNNL_USE_CLANG_TIDY "NONE" CACHE STRING
     - NONE (default)
       Clang-tidy is disabled.
     - CHECK
+      Enables checks from .clang-tidy for source code.
+    - CHECK_ALL
       Enables checks from .clang-tidy.
     - FIX
       Enables checks from .clang-tidy and fix found issues.
@@ -419,8 +413,11 @@ set(DNNL_BLAS_VENDOR "NONE" CACHE STRING
 # AArch64 optimizations with Arm Compute Library
 # ==============================================
 
-option(DNNL_AARCH64_USE_ACL "Enables use of AArch64 optimised functions
+option(DNNL_USE_ACL "Enables use of ARM optimised functions
     from Arm Compute Library. This is only supported on AArch64 builds and
    assumes there is a functioning Compute Library build available at the
    location specified by the environment variable ACL_ROOT_DIR."
OFF) + +option(DNNL_XBYAK_NO_EXCEPTION + "Enables XBYAK_NO_EXCEPTION" ON) # enabled by default diff --git a/cmake/platform.cmake b/cmake/platform.cmake index fc8a7c13e35..aa06aaef7ca 100644 --- a/cmake/platform.cmake +++ b/cmake/platform.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2016-2024 Intel Corporation +# Copyright 2016-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -78,11 +78,12 @@ macro(platform_gnu_nowarn_ccxx_flags var gnu_version) append(${var} "-Wno-strict-overflow") # suppress false positive warnings about uninitialized variables append(${var} "-Wno-maybe-uninitialized") - # suppress false positive warnings with 10.x: GCC Bugzilla – Bug 96963 + # suppress false positive warnings with 9.x+: GCC Bugzilla – Bug 96963 # assume 0.0 is unknown version - always suppress the warning if(${gnu_version} VERSION_EQUAL 0.0 OR - (${gnu_version} VERSION_GREATER 10.0 AND ${gnu_version} VERSION_LESS 11.0)) + ${gnu_version} VERSION_GREATER 9.0) append(${var} "-Wno-stringop-overflow") + append(${var} "-Wno-array-bounds") endif() endmacro() @@ -119,12 +120,26 @@ endif() if(MSVC) set(USERCONFIG_PLATFORM "x64") append_if(DNNL_WERROR CMAKE_CCXX_FLAGS "/WX") + + # Generating frame pointers for easier performance profiling + if(DNNL_TARGET_ARCH STREQUAL "X64") + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + append(CMAKE_CCXX_FLAGS "-fno-omit-frame-pointer -mno-omit-leaf-frame-pointer") + else() + append(CMAKE_CCXX_FLAGS "/Oy-") + endif() + endif() + if(${CMAKE_CXX_COMPILER_ID} STREQUAL MSVC) append(CMAKE_CCXX_FLAGS "/MP") # increase number of sections in obj file append(CMAKE_CCXX_FLAGS "/bigobj") # make preprocessor standard compliant append(CMAKE_CCXX_FLAGS "/Zc:preprocessor") + # Set UTF-8 as default encoding to be consistent with other compilers + append(CMAKE_CCXX_FLAGS "/utf-8") + # Enable __cplusplus macro to align behavior with other compilers + append(CMAKE_CCXX_FLAGS "/Zc:__cplusplus") # int64_t -> int (tent) append(CMAKE_CCXX_NOWARN_FLAGS "/wd4244") # workaround: macro outputs defined token in msvs header @@ -152,7 +167,7 @@ if(MSVC) # disable: icpc deprecation notice append(CMAKE_CXX_FLAGS_DEBUG "-Qdiag-disable:10441") endif() - if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + if(CMAKE_CXX_COMPILER_ID MATCHES "(Apple)?[Cc]lang") append(CMAKE_CCXX_NOEXCEPT_FLAGS "-fno-exceptions") # Clang cannot vectorize some loops with #pragma omp simd and gets # very upset. Tell it that it's okay and that we love it @@ -227,14 +242,22 @@ elseif(UNIX OR MINGW) append(CMAKE_CCXX_NOWARN_FLAGS "-Wno-recommended-option") # Older compiler versions may not support "-Wno-recommended-option". 
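One of the MSVC flags added above, `/Zc:__cplusplus`, is easy to sanity-check: without it MSVC freezes the `__cplusplus` macro at its legacy value regardless of the selected standard. A minimal probe (illustrative, not from this patch):

~~~cpp
#include <iostream>

int main() {
    // Without /Zc:__cplusplus, MSVC reports the legacy 199711L here even in
    // C++17 mode; with the flag it reports the real value (e.g. 201703L),
    // matching GCC and Clang. Any code that dispatches on __cplusplus
    // depends on this behavior being aligned across compilers.
    std::cout << "__cplusplus = " << __cplusplus << "\n";
    return 0;
}
~~~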
append(CMAKE_CCXX_FLAGS "-Wno-unknown-warning-option") + + # Align with GCC -Wall + append(CMAKE_CCXX_FLAGS "-Wsign-compare") + endif() + + # Generating frame pointers for easier performance profiling + if(DNNL_TARGET_ARCH STREQUAL "X64") + append(CMAKE_CCXX_FLAGS "-fno-omit-frame-pointer -mno-omit-leaf-frame-pointer") endif() platform_unix_and_mingw_common_ccxx_flags(CMAKE_CCXX_FLAGS) platform_unix_and_mingw_common_cxx_flags(CMAKE_CXX_FLAGS) platform_unix_and_mingw_noexcept_ccxx_flags(CMAKE_CMAKE_CCXX_NOEXCEPT_FLAGS) # compiler specific settings - if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - if(DNNL_TARGET_ARCH STREQUAL "AARCH64") + if(CMAKE_CXX_COMPILER_ID MATCHES "(Apple)?[Cc]lang") + if(DNNL_TARGET_ARCH MATCHES "^(AARCH64|ARM)$") if (NOT CMAKE_BUILD_TYPE STREQUAL "Debug") set(DEF_ARCH_OPT_FLAGS "-O3") endif() @@ -276,8 +299,11 @@ elseif(UNIX OR MINGW) if(DNNL_USE_CLANG_SANITIZER STREQUAL "MemoryWithOrigin") append(CMAKE_CCXX_SANITIZER_FLAGS "-fsanitize-memory-track-origins=2") - append(CMAKE_CCXX_SANITIZER_FLAGS - "-fno-omit-frame-pointer") + # Already enabled for x64 + if(NOT DNNL_TARGET_ARCH STREQUAL "X64") + append(CMAKE_CCXX_SANITIZER_FLAGS + "-fno-omit-frame-pointer") + endif() endif() set(DNNL_ENABLED_CLANG_SANITIZER "${DNNL_USE_CLANG_SANITIZER}") elseif(DNNL_USE_CLANG_SANITIZER STREQUAL "Undefined") @@ -302,25 +328,35 @@ elseif(UNIX OR MINGW) message(STATUS "Using Clang ${DNNL_ENABLED_CLANG_SANITIZER} " "sanitizer (experimental!)") - append(CMAKE_CCXX_SANITIZER_FLAGS "-g -fno-omit-frame-pointer") + append(CMAKE_CCXX_SANITIZER_FLAGS "-g") + # Already enabled for x64 + if(NOT DNNL_TARGET_ARCH STREQUAL "X64") + append(CMAKE_CCXX_SANITIZER_FLAGS "-fno-omit-frame-pointer") + endif() + # Blacklist to ignore false-positive cases. Each case may be # assigned to a specific sanitizer. See online doc for help. append(CMAKE_CCXX_SANITIZER_FLAGS "-fsanitize-blacklist=${PROJECT_SOURCE_DIR}/.clang-ignorelist") endif() - if (DNNL_USE_CLANG_TIDY MATCHES "(CHECK|FIX)" AND ${CMAKE_VERSION} VERSION_LESS "3.6.0") - message(FATAL_ERROR "Using clang-tidy requires CMake 3.6.0 or newer") - elseif(DNNL_USE_CLANG_TIDY MATCHES "(CHECK|FIX)") + if(DNNL_USE_CLANG_TIDY MATCHES "(CHECK|CHECK_ALL|FIX)") find_program(CLANG_TIDY NAMES clang-tidy) if(NOT CLANG_TIDY) message(FATAL_ERROR "Clang-tidy not found") else() + # FIXME: Remove --header-filter option once clang-tidy warnings + # are addressed if(DNNL_USE_CLANG_TIDY STREQUAL "CHECK") + set(CMAKE_CXX_CLANG_TIDY ${CLANG_TIDY} + --header-filter='') + message(STATUS "Using clang-tidy to run checks for source") + elseif(DNNL_USE_CLANG_TIDY STREQUAL "CHECK_ALL") set(CMAKE_CXX_CLANG_TIDY ${CLANG_TIDY}) - message(STATUS "Using clang-tidy to run checks") + message(STATUS "Using clang-tidy to run checks for source and headers") elseif(DNNL_USE_CLANG_TIDY STREQUAL "FIX") - set(CMAKE_CXX_CLANG_TIDY ${CLANG_TIDY} -fix) + set(CMAKE_CXX_CLANG_TIDY ${CLANG_TIDY} + -fix) message(STATUS "Using clang-tidy to run checks and fix found issues") endif() endif() @@ -333,13 +369,7 @@ elseif(UNIX OR MINGW) append(CMAKE_CCXX_FLAGS "-Wno-ignored-attributes") endif() - # XXX: Suppress an erroneous warning of nested lambda visibility - # exceeding that of the containing class (GCC Bugzilla - Bug 80947). 
-    if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8 AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 6.0)
-        append(CMAKE_CCXX_FLAGS "-Wno-attributes")
-    endif()
-
-    if(DNNL_TARGET_ARCH STREQUAL "AARCH64")
+    if(DNNL_TARGET_ARCH MATCHES "^(AARCH64|ARM)$")
         if (NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
             set(DEF_ARCH_OPT_FLAGS "-O3")
         endif()
@@ -418,8 +448,7 @@ if(DNNL_ARCH_OPT_FLAGS STREQUAL "HostOpts")
     set(DNNL_ARCH_OPT_FLAGS "${DEF_ARCH_OPT_FLAGS}")
 endif()
 
-append(CMAKE_C_FLAGS "${CMAKE_CCXX_FLAGS} ${DNNL_ARCH_OPT_FLAGS}")
-append(CMAKE_CXX_FLAGS "${CMAKE_CCXX_FLAGS} ${DNNL_ARCH_OPT_FLAGS}")
+append(CMAKE_CCXX_FLAGS "${DNNL_ARCH_OPT_FLAGS}")
 
 if(APPLE)
     set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
@@ -435,8 +464,18 @@ endif()
 
 if (DNNL_TARGET_ARCH STREQUAL "RV64")
     # Check if the RVV Intrinsics can be compiled with the current toolchain and flags
     include(CheckCXXSourceCompiles)
-    check_cxx_source_compiles("#include <riscv_vector.h>
-        int main() { return 0; };"
+    check_cxx_source_compiles("#if !defined(__riscv) || !defined(__riscv_v)
+        #error \"RISC-V or vector extension(RVV) is not supported by the compiler\"
+        #endif
+
+        #if defined(__riscv_v_intrinsic) && __riscv_v_intrinsic < 12000
+        #error \"RISC-V intrinsics v0.12 or higher is required\"
+        #endif
+
+        #include <riscv_vector.h>
+
+        int main() {
+            return 0;
+        };"
        CAN_COMPILE_RVV_INTRINSICS
    )
    # set CAN_COMPILE_RVV_INTRINSICS to TRUE / FALSE instead of 1 / "" (Undefined)
@@ -454,3 +493,6 @@ if (DNNL_TARGET_ARCH STREQUAL "RV64")
     message(STATUS "Can compile RVV Intrinsics: ${CAN_COMPILE_RVV_INTRINSICS}")
     message(STATUS "DNNL_RISCV_USE_RVV_INTRINSICS: ${DNNL_RISCV_USE_RVV_INTRINSICS}")
 endif()
+
+append(CMAKE_C_FLAGS "${CMAKE_CCXX_FLAGS}")
+append(CMAKE_CXX_FLAGS "${CMAKE_CCXX_FLAGS}")
diff --git a/cmake/testing.cmake b/cmake/testing.cmake
index a002920fcbe..8d9b3f7fd16 100644
--- a/cmake/testing.cmake
+++ b/cmake/testing.cmake
@@ -1,5 +1,5 @@
 #===============================================================================
-# Copyright 2020-2024 Intel Corporation
+# Copyright 2020-2025 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@ set(DNNL_TEST_SET_COVERAGE "0")
 set(DNNL_TEST_SET_COVERAGE_STR "")
 set(DNNL_TEST_SET_HAS_NO_CORR "0")
 set(DNNL_TEST_SET_HAS_ADD_BITWISE "0")
+set(DNNL_TEST_SET_HAS_GRAPH_EXE "0")
 
 function(check_consistency entry)
     if(NOT DNNL_TEST_SET_COVERAGE EQUAL 0)
@@ -57,6 +58,8 @@ foreach(entry ${DNNL_TEST_SET})
         set(DNNL_TEST_SET_HAS_NO_CORR "1")
     elseif(entry STREQUAL "ADD_BITWISE")
         set(DNNL_TEST_SET_HAS_ADD_BITWISE "1")
+    elseif(entry STREQUAL "GRAPH_EXE")
+        set(DNNL_TEST_SET_HAS_GRAPH_EXE "1")
     elseif(entry STREQUAL "CI_NO_CORR") # Left here for compatibility till v4.0
         set(DNNL_TEST_SET_COVERAGE ${DNNL_TEST_SET_CI})
         set(DNNL_TEST_SET_COVERAGE_STR "CI")
@@ -68,7 +71,7 @@ foreach(entry ${DNNL_TEST_SET})
         message(FATAL_ERROR "The DNNL_TEST_SET entry ${entry} is not recognized.
" "Supported values are:" - "NIGHTLY, CI, SMOKE, NO_CORR, ADD_BITWISE.") + "NIGHTLY, CI, SMOKE, NO_CORR, ADD_BITWISE, GRAPH_EXE.") endif() endforeach() @@ -79,3 +82,6 @@ endif() if(DNNL_TEST_SET_HAS_ADD_BITWISE EQUAL 1) message(STATUS "Enabled testing modifier: Add bitwise validation") endif() +if(DNNL_TEST_SET_HAS_GRAPH_EXE EQUAL 1) + message(STATUS "Enabled testing modifier: Use graph execution") +endif() diff --git a/cmake/utils.cmake b/cmake/utils.cmake index fdd6b2c95dc..05d55f5ccc9 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -142,35 +142,16 @@ macro(append_to_windows_path_list path_list path) endif() endmacro() -function(target_link_libraries_build target list) - # Foreach is required for compatibility with 2.8.11 ways - foreach(lib ${list}) - target_link_libraries(${target} LINK_PUBLIC - "$") - endforeach(lib) -endfunction() - +# Strip paths from libraries before populating INSTALL_INTERFACE function(target_link_libraries_install target list) - # Foreach is required for compatibility with 2.8.11 ways foreach(lib ${list}) get_filename_component(base "${lib}" NAME) - target_link_libraries(${target} LINK_PUBLIC - "$") + target_link_libraries(${target} PUBLIC "$") endforeach(lib) endfunction() function(find_libm var) - # This is to account for the linker cache in OSX11. might work - # with lower than 3.9.4, but was not able to test with anything - # between 2.8 and 3.9. See here for more details: - # https://gitlab.kitware.com/cmake/cmake/-/issues/20863 - if (APPLE AND (${CMAKE_HOST_SYSTEM_VERSION} VERSION_GREATER "20.0.0") - AND (${CMAKE_VERSION} VERSION_LESS "3.9.4")) - message(INFO "Using OSX11 and above with CMAKE older than 3.18 can cause linking issues.") - set(OSX11_AND_OLDER_CMAKE TRUE) - endif() - - if(UNIX AND (NOT (APPLE AND OSX11_AND_OLDER_CMAKE))) + if(UNIX) find_library(${var} m REQUIRED) endif() endfunction() diff --git a/cmake/win/TBBConfig.cmake b/cmake/win/TBBConfig.cmake deleted file mode 100644 index 623147f53ac..00000000000 --- a/cmake/win/TBBConfig.cmake +++ /dev/null @@ -1,164 +0,0 @@ -#=============================================================================== -# Copyright 2017-2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#=============================================================================== - -# TBB_FOUND should not be set explicitly. It is defined automatically by CMake. -# Handling of TBB_VERSION is in TBBConfigVersion.cmake. 
- -if (NOT TBB_FIND_COMPONENTS) - set(TBB_FIND_COMPONENTS "tbb;tbbmalloc;tbbmalloc_proxy") - foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - set(TBB_FIND_REQUIRED_${_tbb_component} 1) - endforeach() -endif() - -# Add components with internal dependencies: tbbmalloc_proxy -> tbbmalloc -list(FIND TBB_FIND_COMPONENTS tbbmalloc_proxy _tbbmalloc_proxy_ix) -if (NOT _tbbmalloc_proxy_ix EQUAL -1) - list(FIND TBB_FIND_COMPONENTS tbbmalloc _tbbmalloc_ix) - if (_tbbmalloc_ix EQUAL -1) - list(APPEND TBB_FIND_COMPONENTS tbbmalloc) - set(TBB_FIND_REQUIRED_tbbmalloc ${TBB_FIND_REQUIRED_tbbmalloc_proxy}) - endif() -endif() - -# oneDNN changes: use TBBROOT to locate Intel TBB -# get_filename_component(_tbb_root "${CMAKE_CURRENT_LIST_FILE}" PATH) -# get_filename_component(_tbb_root "${_tbb_root}" PATH) -if (NOT TBBROOT) - if(DEFINED ENV{TBBROOT}) - set (TBBROOT $ENV{TBBROOT}) - else() - message("FATAL_ERROR" "TBBROOT is unset") - endif() -endif() - -set(_tbb_root ${TBBROOT}) - -set(_tbb_x32_subdir ia32) -set(_tbb_x64_subdir intel64) - -if (CMAKE_SIZEOF_VOID_P EQUAL 8) - set(_tbb_arch_subdir ${_tbb_x64_subdir}) -else() - set(_tbb_arch_subdir ${_tbb_x32_subdir}) -endif() - -# Workaround: 3.19.0 and 3.19.1 versions don't define MSVC_VERSION. -# The workaround is to assume that vc14 is used. -set(_tbb_detect_msvc_version FALSE) -if (NOT ${CMAKE_VERSION} VERSION_EQUAL "3.19.0" AND NOT ${CMAKE_VERSION} VERSION_EQUAL "3.19.1") - set(_tbb_detect_msvc_version TRUE) -endif() - -# Detect the most relevant MSVC subdirectory -set(_tbb_msvc_1700_subdir vc11) -set(_tbb_msvc_1800_subdir vc12) -set(_tbb_msvc_1900_subdir vc14) - -# oneDNN changes: if the project is not with MSVC, try to use MSVC 1900 -set(_tbb_msvc_ver 1900) - -if (_tbb_detect_msvc_version) - if (MSVC) - set(_tbb_msvc_ver ${MSVC_VERSION}) - endif() - if (MSVC_VERSION VERSION_LESS 1700) - message(FATAL_ERROR "This Intel TBB package is intended to be used only in the project with MSVC version 1700 (vc11) or higher") - elseif (MSVC_VERSION VERSION_GREATER 1900) - set(_tbb_msvc_ver 1900) - endif() -endif() -set(_tbb_compiler_subdir ${_tbb_msvc_${_tbb_msvc_ver}_subdir}) -unset(_tbb_msvc_1700_subdir) -unset(_tbb_msvc_1800_subdir) -unset(_tbb_msvc_1900_subdir) - -if (WINDOWS_STORE) - set(_tbb_compiler_subdir ${_tbb_compiler_subdir}_ui) -endif() - -#set conveniance variable to locate TBB files (these are used for a PSXE install) -get_filename_component(_tbb_lib_path "${_tbb_root}/lib/${_tbb_arch_subdir}/${_tbb_compiler_subdir}" ABSOLUTE) -get_filename_component(_tbb_inc_path "${_tbb_root}/include/" ABSOLUTE) - -if (TBB_FOUND) - return() -endif() - -foreach (_tbb_lib_version 12 "") -foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - set(_tbb_release_lib "${_tbb_lib_path}/${_tbb_component}${_tbb_lib_version}.lib") - set(_tbb_debug_lib "${_tbb_lib_path}/${_tbb_component}${_tbb_lib_version}_debug.lib") - - # oneDNN change: check library existence (BUILD_MODE related only, not both) - string(TOUPPER "${CMAKE_BUILD_TYPE}" UPPERCASE_CMAKE_BUILD_TYPE) - if (UPPERCASE_CMAKE_BUILD_TYPE STREQUAL "DEBUG") - if (EXISTS "${_tbb_debug_lib}") - set(_lib_exists TRUE) - elseif (EXISTS "${_tbb_release_lib}") - message(FATAL_ERROR - "Intel TBB release library is found here: ${_tbb_release_lib}. 
" - "But the debug library - (lib${_tbb_component}${tbb_lib_version}_debug.lib) is missing.") - endif() - else() - if (EXISTS "${_tbb_release_lib}") - set(_lib_exists TRUE) - endif() - endif() - - if (_lib_exists) - if (NOT TARGET TBB::${_tbb_component}) - add_library(TBB::${_tbb_component} SHARED IMPORTED) - set_target_properties(TBB::${_tbb_component} PROPERTIES - IMPORTED_CONFIGURATIONS "RELEASE;DEBUG" - IMPORTED_LOCATION_RELEASE "${_tbb_release_lib}" - IMPORTED_LOCATION_DEBUG "${_tbb_debug_lib}" - INTERFACE_INCLUDE_DIRECTORIES "${_tbb_inc_path}" - IMPORTED_IMPLIB_RELEASE "${_tbb_release_lib}" - IMPORTED_IMPLIB_DEBUG "${_tbb_debug_lib}" - INTERFACE_COMPILE_DEFINITIONS "__TBB_NO_IMPLICIT_LINKAGE=1") - - # Add internal dependencies for imported targets: TBB::tbbmalloc_proxy -> TBB::tbbmalloc - if (_tbb_component STREQUAL tbbmalloc_proxy) - set_target_properties(TBB::tbbmalloc_proxy PROPERTIES INTERFACE_LINK_LIBRARIES TBB::tbbmalloc) - endif() - - list(APPEND TBB_IMPORTED_TARGETS TBB::${_tbb_component}) - set(TBB_${_tbb_component}_FOUND 1) - endif() - break() - endif() -endforeach() -endforeach() - -foreach (_tbb_component ${TBB_FIND_COMPONENTS}) - if (NOT TARGET TBB::${_tbb_component} AND TBB_FIND_REQUIRED AND TBB_FIND_REQUIRED_${_tbb_component}) - message(FATAL_ERROR "Missed required Intel TBB component: ${_tbb_component}") - endif() -endforeach() - -unset(_tbb_x32_subdir) -unset(_tbb_x64_subdir) -unset(_tbb_arch_subdir) -unset(_tbb_compiler_subdir) -unset(_tbbmalloc_proxy_ix) -unset(_tbbmalloc_ix) -unset(_tbb_lib_path) -unset(_tbb_release_lib) -unset(_tbb_debug_lib) -unset(_tbb_lib_version) -unset(_lib_exists) diff --git a/doc/advanced/experimental.md b/doc/advanced/experimental.md index 0f55dfc0243..b3464c75871 100644 --- a/doc/advanced/experimental.md +++ b/doc/advanced/experimental.md @@ -22,14 +22,14 @@ Both kinds of experimental features can be enabled simultaneously. | Environment variable | Description | |:-----------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------| -| ONEDNN_EXPERIMENTAL_BNORM_STATS_ONE_PASS | Calculate mean and variance in batch normalization(BN) in single pass ([RFC](https://github.com/oneapi-src/oneDNN/tree/rfcs/rfcs/20210519-single-pass-bnorm)). | +| ONEDNN_EXPERIMENTAL_BNORM_STATS_ONE_PASS | Calculate mean and variance in batch normalization(BN) in single pass ([RFC](https://github.com/uxlfoundation/oneDNN/tree/rfcs/rfcs/20210519-single-pass-bnorm)). | +| ONEDNN_EXPERIMENTAL_GPU_CONV_V2 | Enable shapeless GPU convolution implementation (the feature is under development). | | Build time option | Description | |:-------------------------------------------|:-------------------------------------------------------------------| | ONEDNN_EXPERIMENTAL_SPARSE | Enable experimental API and functionality for sparse domain. | | ONEDNN_EXPERIMENTAL_UKERNEL | Enable experimental microkernel APIs and functionalities. | | ONEDNN_EXPERIMENTAL_PROFILING | Enable experimental profiling API. | -| ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_BACKEND | Enable experimental graph compiler backend of the graph component. | | ONEDNN_EXPERIMENTAL_LOGGING | Enable experimental logging support for oneDNN verbose mode. | ## Features details @@ -55,25 +55,29 @@ of buffers. The order of the buffers in the vector matters and should correspond the buffers' indices. oneDNN also introduces a new format kind dnnl::memory::format_kind::sparse. 
-Sparse encoding (a.k.a. sparse format) is an
-enumeration type that specifies how data is encoded. Currently, oneDNN
-supports CSR (Compressed Sparse Row) and PACKED sparse encodings
-(dnnl::memory::sparse_encoding::csr, dnnl::memory::sparse_encoding_packed).
+Sparse encoding (a.k.a. sparse format) is an enumeration type that specifies
+how data is encoded. Currently, oneDNN supports Compressed Sparse Row (CSR),
+sorted co-ordinate (COO), and PACKED sparse encodings
+(dnnl::memory::sparse_encoding::csr, dnnl::memory::sparse_encoding::coo,
+dnnl::memory::sparse_encoding::packed) for the CPU engine, and only sorted
+COO for the GPU engine.
 
 The memory descriptor has dedicated static member functions for creating memory
 descriptors for different sparse encodings.
 
 Each encoding defines the number and meaning of the buffers.
 
-| Sparse encoding | Buffers                                  |
-|:----------------|:-----------------------------------------|
-| CSR             | 0 - values, 1 - indices, 2 - pointers    |
-| PACKED          | The meaning and content are unspecified  |
+| Sparse encoding | Buffers                                                                      |
+|:----------------|:-----------------------------------------------------------------------------|
+| CSR             | 0 - values, 1 - indices, 2 - pointers                                        |
+| Sorted COO      | 0 - values, 1 to *ndims* - indices (*ndims* - number of tensor dimensions)   |
+| PACKED          | The meaning and content are unspecified                                      |
 
-The pseudo-code below demonstrates how to create a memory object
-for CSR sparse encoding and use the new API to work with the
+The pseudocode below demonstrates how to create a memory object
+for the CSR and COO sparse encodings and use the new API to work with the
 underlying handles.
 
+###### CSR Encoding:
 ~~~cpp
     using namespace dnnl;
     const memory::dim M = 4, N = 6;
@@ -119,6 +123,49 @@ underlying handles.
     assert(pointers_handle == (void *)csr_pointers.data());
 ~~~
 
+###### Sorted COO Encoding:
+~~~cpp
+    using namespace dnnl;
+    const memory::dim M = 4, N = 6;
+    const memory::dim nnz = 5;
+    const auto values_dt = memory::data_type::f32;
+    const auto indices_dt = memory::data_type::s32;
+
+    // Create a memory descriptor for COO sparse encoding.
+    const auto coo_md = memory::desc::coo(
+            {M, N}, // Dimensions
+            values_dt, // Data type of values
+            nnz, // Number of non-zero entries
+            indices_dt); // Data type of indices (metadata)
+
+    // A sparse matrix represented in the COO format.
+    std::vector<float> coo_values = {2.5f, 1.5f, 1.5f, 2.5f, 2.0f};
+    std::vector<int32_t> coo_row_indices = {0, 1, 2, 2, 3};
+    std::vector<int32_t> coo_col_indices = {0, 2, 0, 5, 1};
+
+    // Create a memory object for the given buffers with values and metadata.
+ memory coo_mem(coo_md, engine, { + coo_values.data(), // Buffer with values + coo_row_indices.data(), // Buffer with row indices (metadata) + coo_col_indices.data() // Buffer with column indices (metadata) + }); + + const auto values_sz = coo_mem.get_size(0); + const auto indices_sz = coo_mem.get_size(1); + + assert(values_sz == coo_values.size() * sizeof(float)); + assert(indices_sz == coo_row_indices.size() * sizeof(int32_t)); + assert(indices_sz == coo_col_indices.size() * sizeof(int32_t)); + + void *values_handle = coo_mem.get_data_handle(0); + void *row_indices_handle = coo_mem.get_data_handle(1); + void *col_indices_handle = coo_mem.get_data_handle(2); + + assert(values_handle == (void *)coo_values.data()); + assert(row_indices_handle == (void *)coo_row_indices.data()); + assert(col_indices_handle == (void *)coo_col_indices.data()); +~~~ + A memory descriptor created for the sparse encoding PACKED cannot be used to create a memory object. It can only be used to create a primitive descriptor to query the actual memory descriptor @@ -132,14 +179,15 @@ This option enables the matmul primitive that can work with sparse input tensors. ###### CSR encoding -Only one of the input tensors is allowed to be sparse. The -output tensor is always dense. +Supported only for the CPU engine. Only one of the input tensors can be sparse. +The output tensor is always dense. -The following data types combinations are supported: +The following data type combinations are supported: -| Values | Indices | Pointers | -|:-------|:--------|:---------| -| f32 | s32 | s32 | +| Values (src, weight, dst) | Indices | +|:----------------------------|:---------| +| f16, f16, f16 | s32 | +| f32, f32, f32 | s32 | The following format tags are supported for dense input/output tensors: @@ -154,6 +202,34 @@ Benchdnn can be used to test matmul with a CSR input tensor as follows: For the case above, the number of non-zero elements for the source tensor is calculated as max(4 * 1000000 * (1 - 0.99), 1). +###### COO encoding +Supported only for the CPU and GPU engines. Only one of the input tensors can +be sparse. The output tensor is always dense. + +The following data type combinations are supported: + +| Values (src, weight, dst) | Indices | +|:----------------------------|:---------| +| f16, f16, f16 | s32 | +| f32, f32, f32 | s32 | + +The following format tags are supported for dense weights tensor: + +* ab +* ba + +The following format tags are supported for dense destination tensor: + +* ab + +See the example [here](@ref cpu_matmul_coo_cpp). + +Benchdnn can be used to test matmul with a COO input tensor as follows: +`./benchdnn --matmul --encoding=coo+0.99:: --wtag=ab --dtag=ab 4x1000000:1000000x128` + +For the case above, the number of non-zero elements for the source tensor is +calculated as max(4 * 1000000 * (1 - 0.99), 1). + ###### PACKED encoding Only the weights tensor is allowed to be sparse. The other tensors @@ -164,6 +240,7 @@ scales, zero-points, etc) that is supported for the dense weights should also work for the sparse weights. Currently, matmul has the following limitations for the PACKED encoding: +* Supported only for the CPU engine * Only Intel Advanced Matrix Extensions (Intel AMX) instruction set architecture (ISA) is supported * Only `s8` data type for the weights is supported @@ -188,11 +265,10 @@ In general, it is expected that all reorder-related functionality destination tensor should also work for the sparse one. 
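Tying the matmul support above together, below is a heavily hedged C++ sketch of creating a matmul primitive descriptor with a CSR-encoded source on a CPU engine. The shapes, nnz, and data types are illustrative only, and creation is expected to fail with an unimplemented status where no sparse kernel applies:

~~~cpp
#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);

    const memory::dim M = 4, K = 1000000, N = 128;
    const memory::dim nnz = 40000; // illustrative non-zero count for src

    // Sparse CSR source; dense weights and destination use tag::ab, one of
    // the dense format tags listed above.
    auto src_md = memory::desc::csr({M, K}, memory::data_type::f32, nnz,
            memory::data_type::s32, memory::data_type::s32);
    auto wei_md = memory::desc({K, N}, memory::data_type::f32,
            memory::format_tag::ab);
    auto dst_md = memory::desc({M, N}, memory::data_type::f32,
            memory::format_tag::ab);

    // Only the src tensor is sparse; dst is always dense, matching the
    // restrictions documented above.
    auto pd = matmul::primitive_desc(eng, src_md, wei_md, dst_md);
    auto prim = matmul(pd);
    (void)prim;
    return 0;
}
~~~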
#### Common Limitations -* This functionality is not supported for SYCL and OpenCL runtimes -* The interoperability API for sparse memory is not provided +* The interoperability API to get/set data handles is not supported. Use the +runtime agnostic API to do that. * Sparse memory and memory descriptor can only be used with the Matrix -Multiplication and Reorder primitives -* Sparse memory can be created only for a CPU engine +Multiplication and Reorder primitives. ### ONEDNN_EXPERIMENTAL_UKERNEL @@ -257,11 +333,6 @@ user-provided queue. * Only Intel vendor is supported for SYCL runtime * Out-of-order queue is not supported -### ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_BACKEND -This option extends the coverage scope of the graph API to cover larger fusion -patterns apart from primitive patterns. Refer to -[Graph Compiler](@ref dev_guide_graph_compiler) for more details. - @warning - Enabling some experimental features does not guarantee that the library will utilize them - Enabling some experimental features might change the accuracy of oneDNN primitives diff --git a/doc/advanced/understanding_memory_formats.md b/doc/advanced/understanding_memory_formats.md index 30d2f191179..c23cfeb124e 100644 --- a/doc/advanced/understanding_memory_formats.md +++ b/doc/advanced/understanding_memory_formats.md @@ -115,9 +115,9 @@ in this example. One can create memory with **NCHW** data layout using #dnnl_nchw of the enum type #dnnl_format_tag_t defined in -[dnnl_types.h](https://github.com/oneapi-src/oneDNN/blob/master/include/oneapi/dnnl/dnnl_types.h) +[dnnl_types.h](https://github.com/uxlfoundation/oneDNN/blob/main/include/oneapi/dnnl/dnnl_types.h) for the C API, and dnnl::memory::format_tag::nchw defined in -[dnnl.hpp](https://github.com/oneapi-src/oneDNN/blob/master/include/oneapi/dnnl/dnnl.hpp) +[dnnl.hpp](https://github.com/uxlfoundation/oneDNN/blob/main/include/oneapi/dnnl/dnnl.hpp) for the C++ API. diff --git a/doc/build/build.md b/doc/build/build.md index 6fb58bd6610..cdd1ea0c2d6 100644 --- a/doc/build/build.md +++ b/doc/build/build.md @@ -3,16 +3,16 @@ Build from Source {#dev_guide_build} ## Download the Source Code -Download [oneDNN source code](https://github.com/oneapi-src/oneDNN/archive/master.zip) -or clone [the repository](https://github.com/oneapi-src/oneDNN.git). +Download [oneDNN source code](https://github.com/uxlfoundation/oneDNN/archive/main.zip) +or clone [the repository](https://github.com/uxlfoundation/oneDNN.git). ~~~sh -git clone https://github.com/oneapi-src/oneDNN.git +git clone https://github.com/uxlfoundation/oneDNN.git ~~~ ## Build the Library -Ensure that all [software dependencies](https://github.com/oneapi-src/oneDNN#requirements-for-building-from-source) +Ensure that all [software dependencies](https://github.com/uxlfoundation/oneDNN#requirements-for-building-from-source) are in place and have at least the minimal supported version. The oneDNN build system is based on CMake. Use @@ -51,7 +51,7 @@ cmake .. - Build the library ~~~sh -make -j +make -j$(nproc) ~~~ #### Intel oneAPI DPC++/C++ Compiler with SYCL runtime @@ -86,7 +86,7 @@ it is installed in a custom location. - Build the library ~~~sh -make -j +make -j$(nproc) ~~~ #### GCC targeting AArch64 on x64 host @@ -106,7 +106,7 @@ cmake .. \ - Build the library ~~~sh -make -j +make -j$(nproc) ~~~ #### GCC with Arm Compute Library (ACL) on AArch64 host @@ -117,13 +117,13 @@ make -j ~~~sh export ACL_ROOT_DIR= cmake .. 
\ - -DDNNL_AARCH64_USE_ACL=ON \ + -DDNNL_USE_ACL=ON \ ~~~ - Build the library ~~~sh -make -j +make -j$(nproc) ~~~ ### Windows @@ -142,9 +142,13 @@ cmake -G "Visual Studio 16 2019" .. cmake --build . --config=Release ~~~ -@note CMake's Microsoft Visual Studio generator does not respect `CMAKE_BUILD_TYPE` option. -Solution file supports both Debug and Release builds with Debug being the default. -You can choose specific build type with `--config` option. +@note Currently, the oneDNN build system has limited support for multi-config + generators. Build configuration is based on the `CMAKE_BUILD_TYPE` option + (`Release` by default), and CMake must be rerun from scratch every time + the build type changes to apply the new build configuration. You can choose + a specific build type with the `--config` option (the solution file supports + both `Debug` and `Release` builds), but it must refer to the same build type + (`Release`, `Debug`, etc.) as selected with the `CMAKE_BUILD_TYPE` option. @note You can also open `oneDNN.sln` to build the project from the Microsoft Visual Studio IDE. diff --git a/doc/build/build_options.md b/doc/build/build_options.md index 2bcdede9ce2..a98310fb367 100644 --- a/doc/build/build_options.md +++ b/doc/build/build_options.md @@ -13,7 +13,6 @@ oneDNN supports the following build-time options. | ONEDNN_BUILD_TESTS | **ON**, OFF | Controls building the tests | | ONEDNN_BUILD_GRAPH | **ON**, OFF | Controls building graph component | | ONEDNN_ENABLE_GRAPH_DUMP | ON, **OFF** | Controls dumping graph artifacts | -| ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_BACKEND | ON, **OFF** | Enables the [graph compiler backend](@ref dev_guide_graph_compiler) of the graph component (experimental)| | ONEDNN_ARCH_OPT_FLAGS | *compiler flags* | Specifies compiler optimization flags (see warning note below) | | ONEDNN_ENABLE_CONCURRENT_EXEC | ON, **OFF** | Disables sharing a common scratchpad between primitives in #dnnl::scratchpad_mode::library mode | | ONEDNN_ENABLE_JIT_PROFILING | **ON**, OFF | Enables [integration with performance profilers](@ref dev_guide_profilers) | @@ -87,14 +86,14 @@ dependencies for forward propagation kind part. #### ONEDNN_ENABLE_PRIMITIVE This option supports several values: `ALL` (the default) which enables all primitives implementations or a set of `BATCH_NORMALIZATION`, `BINARY`, -`CONCAT`, `CONVOLUTION`, `DECONVOLUTION`, `ELTWISE`, `INNER_PRODUCT`, -`LAYER_NORMALIZATION`, `LRN`, `MATMUL`, `POOLING`, `PRELU`, `REDUCTION`, -`REORDER`, `RESAMPLING`, `RNN`, `SDPA`, `SHUFFLE`, `SOFTMAX`, `SUM`. When a set -is used, only those selected primitives implementations will be available. -Attempting to use other primitive implementations will end up returning an -unimplemented status when creating primitive descriptor. In order to specify a -set, a CMake-style string should be used, with semicolon delimiters, as in this -example: +`CONCAT`, `CONVOLUTION`, `DECONVOLUTION`, `ELTWISE`, `GROUP_NORMALIZATION`, +`INNER_PRODUCT`, `LAYER_NORMALIZATION`, `LRN`, `MATMUL`, `POOLING`, `PRELU`, +`REDUCTION`, `REORDER`, `RESAMPLING`, `RNN`, `SDPA`, `SHUFFLE`, `SOFTMAX`, +`SUM`. When a set is used, only those selected primitives implementations will +be available. Attempting to use other primitive implementations will end up +returning an unimplemented status when creating primitive descriptor. 
In order +to specify a set, a CMake-style string should be used, with semicolon +delimiters, as in this example: ``` -DONEDNN_ENABLE_PRIMITIVE=CONVOLUTION;MATMUL;REORDER ``` @@ -118,7 +117,7 @@ Example that enables SSE41 and AVX2 sets: #### ONEDNN_ENABLE_PRIMITIVE_GPU_ISA This option supports several values: `ALL` (the default) which enables all ISA implementations or any set of `GEN9`, `GEN11`, `XELP`, `XEHP`, `XEHPG`, -`XEHPC`, and `XE2`. Selected ISA will enable correspondent parts in +`XEHPC`, `XE2`, and `XE3`. Selected ISA will enable correspondent parts in just-in-time kernel generation based implementations. OpenCL based kernels and implementations will always be available. Example that enables XeLP and XeHP set: @@ -303,7 +302,7 @@ $ cmake -DONEDNN_BLAS_VENDOR=ARMPL .. Additional options available for development/debug purposes. These options are subject to change without notice, see -[`cmake/options.cmake`](https://github.com/oneapi-src/oneDNN/blob/master/cmake/options.cmake) +[`cmake/options.cmake`](https://github.com/uxlfoundation/oneDNN/blob/main/cmake/options.cmake) for details. ## GPU Options @@ -335,20 +334,3 @@ CMake error. |:------------------------|:-------------------| | ONEDNN_GPU_VENDOR | NVIDIA | | ONEDNN_ENABLE_PRIMITIVE | PRIMITIVE_NAME | - -## Graph Compiler Backend Limitations - -As a backend of the graph component, besides the options described in -[Graph component limitations](@ref component_limitation), graph compiler -backend has some extra limitations. Specifying unsupported build options will -lead to a CMake error. - -| CMake Option | Unsupported Values | -| :-----------------------| :------------------| -| ONEDNN_CPU_RUNTIME | THREADPOOL, SYCL | -| ONEDNN_GPU_RUNTIME | OCL, SYCL | - -Besides, the instructions contained in the kernels generated by the graph -compiler backend are [AVX512_CORE](@ref dev_guide_cpu_dispatcher_control) or -above, so these kernels will not be dispatched on systems that do not have -corresponding instruction sets support. diff --git a/doc/graph/experimental_graph_compiler.md b/doc/graph/experimental_graph_compiler.md deleted file mode 100644 index 487de765695..00000000000 --- a/doc/graph/experimental_graph_compiler.md +++ /dev/null @@ -1,161 +0,0 @@ -Graph Compiler {#dev_guide_graph_compiler} -========================================== - -oneDNN Graph Compiler is an experimental backend for oneDNN Graph API. It can -generate optimized implementations for complex computational graphs including -multi-head attention (MHA), multi-layer perceptron (MLP), and convolution -residual blocks over typical data types for both inference and training. It -also brings improved performance by providing more flexible operator fusion. - -Use of oneDNN Graph Compiler is transparent for applications, as it does not -involve API or programming model changes. - -## Build-Time Controls -The following build time options only work when -`ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_BACKEND` is ON. - -| CMake Option | Supported values (defaults in bold) | Description | -| :--- | :--- | :--- | -| ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_JIT | llvm, c, **builtin** | Selects the CPU codegen and JIT to be built by graph compiler backend. Multiple codegen approaches can be used simultaneously. See the [example](@ref jit_options) for setting multiple codegen methods. | -| ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_LLVM_CONFIG | **AUTO**, *path to llvm-config binary* | Defines the method for detecting and configuring LLVM. 
| - -@anchor jit_options -### Codegen and JIT Options -Graph compiler backend supports several different codegen and JIT options -including C, LLVM, and builtin (xbyak). Users can choose to build a subset of -available options by setting the `ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_JIT` -option. - -~~~bash -cmake .. -DONEDNN_BUILD_GRAPH=ON -DONEDNN_EXPERIMENTAL_GRAPH_COMPILER_BACKEND=ON -DONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_JIT="c;builtin" -~~~ - -This will only build `c` and `builtin` codegen options. - -~~~bash -cmake .. -DONEDNN_BUILD_GRAPH=ON -DONEDNN_EXPERIMENTAL_GRAPH_COMPILER_BACKEND=ON -DONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_JIT="llvm;c;builtin" -~~~ - -This will build all three codegen options. - -#### C -C codegen generates temporary cpp files and adopts `g++` to compile them into -the executable. It can be used for debugging purposes as the generated code is -more friendly and readable to developers. - -#### LLVM -LLVM codegen generates LLVM-IR in memory. It provides the best performance -among all supported codegen methods. When LLVM codegen is chosen, extra LLVM -dependency is required. If LLVM does not exist in this case, a CMake error will -occur. - -Users can set `ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_LLVM_CONFIG` to specify -the LLVM to be integrated. By default, -`ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_LLVM_CONFIG` is set to `AUTO`, which -auto-detects existing LLVM in the environment. If auto-detection fails or user -wants to explicitly specify the version of LLVM, a specific path to -*llvm-config binary* shall be set. - -Users can follow the [guidelines](https://llvm.org/docs/GettingStarted.html#getting-the-source-code-and-building-llvm) -to build and install LLVM from source, or download and install the pre-built -binary from [here](https://apt.llvm.org/). - -@note **LLVM 10.0 or above** is required to enable LLVM codegen. - -#### Builtin -Builtin codegen and JIT method is implemented with xbyak technology inside. -Compared with C or LLVM codegen, it has no extra dependency. - -## Environment Variables -The following environment variables are introduced by the graph compiler -backend. - -| Environment Variable | Value | Description | -| :--- | :--- |:--- | -| ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_JIT | **llvm** | Uses LLVM as codegen and JIT method | -| | builtin | Uses builtin as codegen and JIT method | -| | c | Uses C as codegen and JIT method | -| ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_OPT_LEVEL | 0 | Turns off optimization passes and sets the compilation optimization level to be 0 in C and LLVM JIT | -| | 1,2,**3** | Sets the compilation optimization level of C and LLVM JIT | -| ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_KERNEL_TRACE | **0** | No kernel execution trace output | -| | 1,*stderr or filename.json* | Generates kernel execution trace to the file specified by the given filename with chrome tracing format | -| ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_PRINT_PASS_RESULT | **0** | No IR output after each graph or tensor IR pass | -| ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_PRINT_PASS_RESULT | 1 | Prints the output IR of each graph and tensor IR passes | -| ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_VERBOSE | **0** | No verbose output | -| | 1 | Prints warning messages during compilation | -| | 2 | Prints warning messages and info logs (e.g. 
fusion-related information) during compilation | -| ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_DUMP_GENCODE | *path_to_dump* | Dumps the generated kernel in C | -| ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_C_INCLUDE | *path_to_c_codegen_header* | Specifies the C codegen header for JIT compilation | - -### Enable Tracing - -~~~bash -ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_KERNEL_TRACE=1 ./application -~~~ - -This will produce a kernel execution trace in JSON format that will be -stored to the default destination: `./sctrace.json`. - -~~~bash -ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_KERNEL_TRACE=1,stderr ./application -~~~ - -This will dump a kernel execution trace to the *stderr* stream. - -~~~bash -ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_KERNEL_TRACE=1,/tmp/filename.json ./application -~~~ - -This will produce a kernel execution trace in JSON format that will be stored -to the user specified path `/tmp/filename.json`. - -### Switch Between Different Codegen Methods -By default, codegen methods have priorities ranked from higher to lower as -`llvm`, `c`, `builtin`. When multiple codegen and JIT methods are enabled at -build stage, the method with the highest priority is adopted at runtime by -default. - -Users can switch to a different codegen method at runtime by setting -`ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_JIT`. - -~~~bash -ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_JIT=builtin ./application -~~~ - -This will switch the CPU codegen and JIT method to `builtin` (xbyak). - -~~~bash -ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_CPU_JIT=c ./application -~~~ - -This will switch the CPU codegen and JIT method to `c`. - -When using C codegen option, the generated C code will rely on existing runtime -function declarations in `cpu_include.hpp`. -`ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_C_INCLUDE` environment variable is used to -specify the corresponding include path. -Normally, the include path is automatically set at CMake build stage. But if -the following error message occurs -`environment variable ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_C_INCLUDE is not set`, -users shall manually set `ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_C_INCLUDE` to -`/path_to_onednn_repo/src/graph/backend/graph_compiler/core/src`. - -@warning The specified codegen method must be built. Otherwise, the default -codegen method would be used. - -### Enable Code Dumping -Users can use `ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_DUMP_GENCODE` variable to -generate offline C kernels. - -~~~bash -ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_DUMP_GENCODE="./dump_code" ./application -~~~ - -This will dump the generated C kernels to `dump_code` folder. - -@warning `ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_DUMP_GENCODE` works under both LLVM -and C codegen. - -@warning The user specified `ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_DUMP_GENCODE` -path shall be an existing folder. Otherwise the code dumping will not be in -effect. diff --git a/doc/graph/fusion_patterns/gated_mlp.md b/doc/graph/fusion_patterns/gated_mlp.md new file mode 100644 index 00000000000..73611ab6e1f --- /dev/null +++ b/doc/graph/fusion_patterns/gated_mlp.md @@ -0,0 +1,123 @@ +Gated Multi-Layer Perceptron (Gated-MLP) {#dev_guide_graph_gated_mlp} +===================================================================== + +## Overview + +Gated Multi-Layer Perceptron (Gated-MLP) is a variant of MLP which is widely +used as the Feed Forward Network (FFN) in many Transformer-based Large Language +Models (LLMs). 
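The sections below define this pattern formally; as a taste of the Graph API style they rely on, here is a heavily hedged sketch of assembling the two first-layer MatMuls ("FC up" and "FC gate"). The ids, shapes, and op wiring are hypothetical and modeled on the repository's Graph API examples, not copied from them:

~~~cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

int main() {
    using namespace dnnl::graph;
    using dt = logical_tensor::data_type;
    using lt = logical_tensor::layout_type;

    graph g(dnnl::engine::kind::cpu);

    // Hypothetical ids/shapes: src (B, C) and two weights (C, H).
    logical_tensor src{0, dt::f32, {8, 512}, lt::strided};
    logical_tensor w1{1, dt::f32, {512, 2048}, lt::strided};
    logical_tensor w2{2, dt::f32, {512, 2048}, lt::strided};
    logical_tensor up{3, dt::f32, {8, 2048}, lt::strided};
    logical_tensor gate{4, dt::f32, {8, 2048}, lt::strided};

    op fc_up(0, op::kind::MatMul, "fc_up"); // "FC up": src x W1
    fc_up.add_inputs({src, w1});
    fc_up.add_outputs({up});

    op fc_gate(1, op::kind::MatMul, "fc_gate"); // "FC gate": src x W2
    fc_gate.add_inputs({src, w2});
    fc_gate.add_outputs({gate});

    g.add_op(fc_up);
    g.add_op(fc_gate);
    // ... the activation, Multiply, and final "FC down" MatMul follow ...
    g.finalize();

    auto partitions = g.get_partitions(); // ideally one fused partition
    (void)partitions;
    return 0;
}
~~~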
+
+Typically, the FFN in the Transformer architecture [1] is defined as a
+two-layer MLP with a ReLU activation in between, which can be replaced with
+other activations.
+
+\f[
+
+    FFN(src,W,V) = ReLU(src \cdot W) \cdot V
+
+\f]
+
+Gated Linear Unit (GLU) is adopted to replace the first linear layer to
+improve the quality of Transformer-based models [2]:
+
+\f[
+
+    GLU(src,W_1,W_2) = (src \cdot W_1) \otimes Sigmoid(src \cdot W_2) \\
+
+    FFN(src,W_1,W_2,V) = GLU(src,W_1,W_2) \cdot V
+
+\f]
+
+where \f$ src \cdot W_1 \f$ is usually called "FC (fully-connected) up",
+\f$ src \cdot W_2 \f$ is called "FC gate", and the last linear layer is called
+"FC down".
+
+Swish activation is further adopted to replace Sigmoid in the GLU to form
+swiGLU.
+
+\f[
+
+    Swish(x) = x \otimes Sigmoid(x) \\
+
+    swiGLU(src,W_1,W_2) = (src \cdot W_1) \otimes Swish(src \cdot W_2) \\
+
+    FFN(src,W_1,W_2,V) = swiGLU(src,W_1,W_2) \cdot V
+
+\f]
+
+The Gated-MLP based on swiGLU is also adopted in LLMs like LLaMA [3], Qwen [4],
+etc.
+
+## Gated-MLP patterns
+
+oneDNN supports Gated-MLP and its optimization through Graph API [5] by
+defining the graph, getting partitions from the graph, and optimizing the
+kernels underneath. In general, a Gated-MLP pattern is defined as a directed
+acyclic graph (DAG) using oneDNN Graph API.
+
+### Floating-point Gated-MLP
+
+oneDNN defines floating-point (f32, bf16, and f16) Gated-MLP as follows. The
+blue nodes are required when defining a Gated-MLP pattern while the brown nodes
+are optional.
+
+![Gated-MLP pattern](images/fp-gated-mlp.png)
+
+1. The first MatMul on the top left calculates "FC up": \f$ src \cdot W_1 \f$.
+   See [MatMul](@ref dev_guide_op_matmul) operation in Graph API.
+2. The second MatMul on the top right calculates "FC gate": \f$ src \cdot W_2 \f$.
+3. The Activation node is optional. If required, it can be constructed with the
+   activation operations in Graph API, for example, [ReLU](@ref dev_guide_op_relu),
+   [GELU](@ref dev_guide_op_gelu), [Sigmoid](@ref dev_guide_op_sigmoid), and so on.
+   For Swish activation, the node can be constructed with the [Sigmoid](@ref dev_guide_op_sigmoid)
+   and [Multiply](@ref dev_guide_op_multiply) operations as below. You can also
+   refer to the [Gated-MLP example](https://github.com/uxlfoundation/oneDNN/tree/main/examples/graph/gated_mlp.cpp)
+   for the Swish definition.
+
+   ![Swish Activation](images/gated-mlp-swish.png)
+
+4. The last MatMul on the bottom performs the "FC down" operation between the
+   GLU output and \f$V\f$.
+
+## Data Types
+
+oneDNN supports the floating-point Gated-MLP pattern with data types f32, bf16,
+and f16. You can specify the data type via the input and output data type fields
+of logical tensors for each operation. oneDNN does not support mixing different
+floating-point data types in a floating-point Gated-MLP pattern.
+
+The definition of the data types and support status on different CPU and GPU
+platforms follow the general description in @ref dev_guide_data_types.
+
+## Implementation limitations
+
+1. oneDNN primitive-based Gated-MLP is implemented as the reference
+   implementation on both Intel Architecture Processors and Intel Graphics
+   Products. In this case, floating-point Gated-MLP patterns are usually
+   implemented with three f32, bf16, or f16 matmul (with binary or eltwise
+   post-ops) primitives.
+2. The Gated-MLP patterns functionally support all input shapes meeting the
+   shape requirements of each operation in the graph. For example, the `MatMul`
+   operation requires shape consistency for the `k` dimension.
+   The `Multiply` operation requires the input tensors to have the same shape,
+   or shapes that can be properly broadcast based on the operation attribute.
+
+## Examples
+
+oneDNN provides a [Gated-MLP
+example](https://github.com/uxlfoundation/oneDNN/tree/main/examples/graph/gated_mlp.cpp)
+demonstrating how to construct a typical floating-point Gated-MLP pattern with
+oneDNN Graph API on CPU and GPU with different runtimes.
+
+For applications where the weights of FC up and FC gate are combined as a single
+tensor, oneDNN also provides an
+[example](https://github.com/uxlfoundation/oneDNN/tree/main/examples/graph/gated_mlp_wei_combined.cpp)
+demonstrating how to create the weight tensors for the pattern with the offsets
+and strides from the combined weight tensor.
+
+## References
+
+1. Attention is all you need, https://arxiv.org/abs/1706.03762v7
+2. GLU Variants Improve Transformer, https://arxiv.org/abs/2002.05202
+3. LLaMA: Open and Efficient Foundation Language Models, https://arxiv.org/abs/2302.13971
+4. Qwen Technical Report, https://arxiv.org/abs/2309.16609
+5. oneDNN Graph API documentation, https://uxlfoundation.github.io/oneDNN/graph_extension.html
diff --git a/doc/graph/fusion_patterns/gqa.md b/doc/graph/fusion_patterns/gqa.md
new file mode 100644
index 00000000000..84d846924cb
--- /dev/null
+++ b/doc/graph/fusion_patterns/gqa.md
@@ -0,0 +1,106 @@
+Grouped Query Attention (GQA) {#dev_guide_graph_gqa}
+====================================================
+
+## Overview
+
+In a typical Scaled Dot-Product Attention (SDPA) [1], the input Query, Key, and
+Value tensors have the same head number. Loading the Key and Value tensors in
+each generation step becomes a performance bottleneck, especially as the
+sequence length grows.
+
+To reduce the memory bandwidth overhead of loading the Key and Value tensors,
+Multi-Query Attention (MQA) [2] is created by reducing the head number of Key
+and Value tensors to one, which means multiple Query heads map to the same
+single Key and Value tensor. However, MQA may lead to model quality degradation
+and training instability. Therefore, Grouped-Query Attention (GQA) [3], an
+interpolation between the typical SDPA and MQA, is proposed with a single Key
+and Value head per subgroup of Query heads. The head number of Key and Value
+equals the group number of Query heads.
+
+The notations used in the document:
+
+- N: the mini-batch size.
+- H_q: the head number of Query.
+- H_kv: the head number of Key or Value.
+- N_rep: H_q / H_kv, indicates how many Query heads are mapped to one Key head.
+- S: the sequence length.
+- D: the size of each head.
+
+## GQA Pattern
+
+Similar to how SDPA is supported, the GQA pattern is also defined as a
+directed acyclic graph (DAG) using oneDNN Graph API. oneDNN extends the
+[SDPA pattern](@ref dev_guide_graph_sdpa) to support floating-point (f32, bf16,
+and f16) GQA as follows. The blue nodes are required when defining a GQA pattern
+while the brown nodes are optional.
+
+![GQA pattern](images/gqa.png)
+
+Compared to a typical SDPA pattern, there are a few differences in the GQA
+pattern:
+
+1. The input Query has shape (N, H_q, S, D). It will be reshaped to (N, H_kv,
+   N_rep, S, D) by splitting the H_q dimension into H_kv and N_rep. The
+   reshaping can be constructed using the [StaticReshape](@ref dev_guide_op_staticreshape)
+   operation in Graph API.
+2. Similarly, the input Key and Value have shape (N, H_kv, S, D).
+   They will be reshaped to (N, H_kv, 1, S, D) to meet the input shape
+   requirement of the [MatMul](@ref dev_guide_op_matmul) operation.
+3. The second MatMul calculates the dot products between the probabilities
+   after SoftMax and the Value node, and generates output with shape
+   (N, H_kv, N_rep, S, D).
+4. Another StaticReshape operation is applied to the output of the second MatMul
+   to convert the shape into (N, H_q, S, D) by combining the H_kv and N_rep
+   dimensions.
+5. The input scale factor and mask in the pattern also need to meet the
+   operations' shape requirements, which can similarly be achieved through
+   StaticReshape. Apart from that, they have the same definition as described
+   in the typical SDPA pattern.
+
+## Data Types
+
+oneDNN supports the floating-point GQA pattern with data types f32, bf16, and
+f16. You can specify the data type via the input and output data type fields of
+logical tensors for each operation. oneDNN does not support mixing different
+floating-point data types in a floating-point GQA pattern.
+
+The definition of the data types and support status on different CPU and GPU
+platforms follow the general description in @ref dev_guide_data_types.
+
+## Implementation Limitations
+
+1. oneDNN primitive-based GQA is implemented as the reference implementation on
+   both Intel Architecture Processors and Intel Graphics Products. The reference
+   implementation requires memory to store the intermediate results of the dot
+   products between Query and Key which takes \f$O(S^2)\f$ memory. It may lead
+   to an out-of-memory error when computing long sequence length inputs on
+   platforms with limited memory.
+2. The GQA patterns functionally support all input shapes meeting the shape
+   requirements of each operation in the graph.
+3. CPU
+   - Optimized implementation is available for 4D Q/K/V tensors with shape
+     defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for Key and Value.
+   - Optimized implementation is available for OpenMP runtime and Threadpool
+     runtime on Intel Architecture Processors.
+   - Specifically for OpenMP runtime, the optimized implementation requires `N *
+     H_q > 2 * thread number` to get enough parallelism.
+4. GPU
+   - Optimized implementation is available for 4D Q/K/V tensors with shape
+     defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for Key and Value.
+   - Optimized implementation is available for floating-point GQA with `f16`
+     data type and `D <= 256` on Intel Graphics Products with Intel(R) Xe Matrix
+     Extensions (Intel(R) XMX) support.
+
+## Example
+
+oneDNN provides a [GQA
+example](https://github.com/uxlfoundation/oneDNN/tree/main/examples/graph/gqa.cpp)
+demonstrating how to construct a floating-point GQA pattern with oneDNN Graph
+API on CPU and GPU with different runtimes.
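+
+Before the full example, here is a minimal sketch of the reshape steps that
+distinguish GQA from SDPA, written against the public Graph API. It is
+illustrative only: the sizes (N=1, H_q=32, H_kv=8, hence N_rep=4, S=128, D=64),
+tensor IDs, and op names are hypothetical, and the scale, mask, SoftMax, and
+second MatMul parts of the pattern are elided.
+
+~~~cpp
+#include <cstdint>
+#include <vector>
+
+#include "oneapi/dnnl/dnnl_graph.hpp"
+
+using namespace dnnl::graph;
+
+int main() {
+    const auto dt = logical_tensor::data_type::f32;
+    const auto lt = logical_tensor::layout_type::strided;
+
+    // Query (N, H_q, S, D) reshaped to (N, H_kv, N_rep, S, D).
+    logical_tensor q {0, dt, {1, 32, 128, 64}, lt};
+    logical_tensor q5d {1, dt, {1, 8, 4, 128, 64}, lt};
+    op reshape_q {0, op::kind::StaticReshape, {q}, {q5d}, "reshape_q"};
+    reshape_q.set_attr<std::vector<int64_t>>(op::attr::shape, {1, 8, 4, 128, 64});
+    reshape_q.set_attr<bool>(op::attr::special_zero, false);
+
+    // Key (N, H_kv, S, D) reshaped to (N, H_kv, 1, S, D) for broadcasting.
+    logical_tensor k {2, dt, {1, 8, 128, 64}, lt};
+    logical_tensor k5d {3, dt, {1, 8, 1, 128, 64}, lt};
+    op reshape_k {1, op::kind::StaticReshape, {k}, {k5d}, "reshape_k"};
+    reshape_k.set_attr<std::vector<int64_t>>(op::attr::shape, {1, 8, 1, 128, 64});
+    reshape_k.set_attr<bool>(op::attr::special_zero, false);
+
+    // First MatMul: Q x K^T, producing shape (N, H_kv, N_rep, S, S).
+    logical_tensor score {4, dt, {1, 8, 4, 128, 128}, lt};
+    op qk {2, op::kind::MatMul, {q5d, k5d}, {score}, "qk"};
+    qk.set_attr<bool>(op::attr::transpose_b, true);
+
+    graph g {engine::kind::cpu};
+    g.add_op(reshape_q);
+    g.add_op(reshape_k);
+    g.add_op(qk);
+    g.finalize();
+    auto partitions = g.get_partitions();
+    return partitions.empty() ? 1 : 0;
+}
+~~~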
+
+## References
+
+[1] Attention is all you need, https://arxiv.org/abs/1706.03762v7
+
+[2] Fast Transformer Decoding: One Write-Head is All You Need, https://arxiv.org/abs/1911.02150
+
+[3] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints, https://arxiv.org/abs/2305.13245
diff --git a/doc/graph/fusion_patterns/images/compressed_sdpa_pattern.png b/doc/graph/fusion_patterns/images/compressed_sdpa_pattern.png
new file mode 100644
index 00000000000..b0563e7fb0e
Binary files /dev/null and b/doc/graph/fusion_patterns/images/compressed_sdpa_pattern.png differ
diff --git a/doc/graph/fusion_patterns/images/fp-gated-mlp.png b/doc/graph/fusion_patterns/images/fp-gated-mlp.png
new file mode 100644
index 00000000000..a52952ce87b
Binary files /dev/null and b/doc/graph/fusion_patterns/images/fp-gated-mlp.png differ
diff --git a/doc/graph/fusion_patterns/images/gated-mlp-swish.png b/doc/graph/fusion_patterns/images/gated-mlp-swish.png
new file mode 100644
index 00000000000..2050ee8d871
Binary files /dev/null and b/doc/graph/fusion_patterns/images/gated-mlp-swish.png differ
diff --git a/doc/graph/fusion_patterns/images/gqa.png b/doc/graph/fusion_patterns/images/gqa.png
new file mode 100644
index 00000000000..0871903bcda
Binary files /dev/null and b/doc/graph/fusion_patterns/images/gqa.png differ
diff --git a/doc/graph/images/sdpa-mask-1.png b/doc/graph/fusion_patterns/images/sdpa-mask-1.png
similarity index 100%
rename from doc/graph/images/sdpa-mask-1.png
rename to doc/graph/fusion_patterns/images/sdpa-mask-1.png
diff --git a/doc/graph/images/sdpa-mask-2.png b/doc/graph/fusion_patterns/images/sdpa-mask-2.png
similarity index 100%
rename from doc/graph/images/sdpa-mask-2.png
rename to doc/graph/fusion_patterns/images/sdpa-mask-2.png
diff --git a/doc/graph/fusion_patterns/images/sdpa-mask-3.png b/doc/graph/fusion_patterns/images/sdpa-mask-3.png
new file mode 100644
index 00000000000..339a9589122
Binary files /dev/null and b/doc/graph/fusion_patterns/images/sdpa-mask-3.png differ
diff --git a/doc/graph/images/sdpa-reorder.png b/doc/graph/fusion_patterns/images/sdpa-reorder.png
similarity index 100%
rename from doc/graph/images/sdpa-reorder.png
rename to doc/graph/fusion_patterns/images/sdpa-reorder.png
diff --git a/doc/graph/fusion_patterns/images/sdpa.png b/doc/graph/fusion_patterns/images/sdpa.png
new file mode 100644
index 00000000000..07add3d2afe
Binary files /dev/null and b/doc/graph/fusion_patterns/images/sdpa.png differ
diff --git a/doc/graph/fusion_patterns/sdpa.md b/doc/graph/fusion_patterns/sdpa.md
new file mode 100644
index 00000000000..75528dc56ba
--- /dev/null
+++ b/doc/graph/fusion_patterns/sdpa.md
@@ -0,0 +1,157 @@
+Scaled Dot-Product Attention (SDPA) {#dev_guide_graph_sdpa}
+===========================================================
+
+## Overview
+
+Scaled Dot-Product Attention (SDPA) is introduced in [1] as the core operation
+of the Transformer block, which has become the backbone of many language models
+and generative models (BERT, Stable Diffusion, GPT, etc.).
+
+The input of SDPA consists of query (Q), key (K), and value (V). The attention
+output is computed as:
+
+\f[
+
+    attention(Q,K,V) = V \cdot softmax(\frac{QK^T}{\sqrt{d_k}})
+
+\f]
+
+\f$d_k\f$ is the dimension size of K. Other notations used in the document:
+
+- N: the mini-batch size.
+- H: the head number.
+- S: the sequence length.
+- D: the size of each head.
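+
+As a concrete shape walk-through (added for clarity; the shapes follow
+directly from the formula above): with Q, K, and V all of shape (N, H, S, D),
+the first MatMul \f$QK^T\f$ produces an (N, H, S, S) score tensor, SoftMax
+preserves that shape, and the final multiplication with V brings the attention
+output back to (N, H, S, D).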
+ +## SDPA patterns + +oneDNN supports SDPA and its optimization through Graph API [2] by defining the +SDPA graph, getting partition from the graph, and optimizing the kernels +underneath. In general, an SDPA pattern is defined as a directional acyclic +graph (DAG) using oneDNN Graph API. + +### Floating-point SDPA + +oneDNN defines floating-point (f32, bf16, or f16) SDPA as follows. The blue +nodes are required when defining an SDPA pattern while the brown parts are +optional. + +![SDPA pattern](images/sdpa.png) + +1. The first MatMul calculates the dot products between Query and Key. See + [MatMul](@ref dev_guide_op_matmul) operation in Graph API. +2. The Scale node is optional and is used to scale the output of the first + MatMul with a scaling factor. It can be constructed by [Multiply](@ref dev_guide_op_multiply) + or [Divide](@ref dev_guide_op_divide) operation in Graph API. The scaling + factor is given by users as an input of SDPA. \f$\sqrt{d_k}\f$ in the formula + is not considered as a part of the SDPA pattern because it is a constant. +3. The Mask node is optional and is used to apply an attention mask to the + output of the previous Scale node. There are two types of masks that can + be applied: + + 1. Explicit user-generated mask: You can explicitly create a mask tensor + and pass it to the library for the computation of SDPA. In this case, mask + can be constructed by [Add](@ref dev_guide_op_add) + or [Select](@ref dev_guide_op_select) operation in Graph API for different + mask policies (for example, causal mask or padding mask). When the + Add operation is used to apply the mask, the input mask is usually an upper + triangular matrix with all the elements above the diagonal filled with + `-inf` and zeroes elsewhere. The `-inf` entries will become zero probability + after Softmax is applied in the next step. + Alternatively, a Select operation may be used. In this case, the + input is a boolean tensor (for example, with the boolean value set to `true` + on and below the diagonal, and `false` above the diagonal). + A `false` element in the mask forces the corresponding element of the scaled + output to `-inf`, while a `true` element leaves it unchanged. + + ![SDPA-mask-1](images/sdpa-mask-1.png) ![SDPA-mask-2](images/sdpa-mask-2.png) + + 2. Implicit library-generated mask: You can use the operations in the library + to generate a mask by constructing a subgraph. Currently, Graph API supports + generating an implicit causal mask (top-left aligned) using operations of + [GenIndex](@ref dev_guide_op_genindex), [GreaterEqual](@ref dev_guide_op_greaterequal) + and [Select](@ref dev_guide_op_select). + + ![SDPA-mask-3](images/sdpa-mask-3.png) + +4. The SoftMax operation takes the masked output and transforms it into + probabilities between 0 and 1. See [SoftMax](@ref dev_guide_op_softmax) + operation in Graph API. +5. The second MatMul calculates the dot products between the probabilities after + SoftMax and Value. +6. The Reorder node is optional and used to reshape or transpose the attention + output for cases where the attention output is transformed from shape (N, H, + S, D) to (N, S, H, D) or (N, S, H * D). The node can be constructed by the + combinations of [StaticTranspose](@ref dev_guide_op_statictranspose) and + [StaticReshape](@ref dev_guide_op_staticreshape) operation in Graph API. + + ![SDPA-Reorder](images/sdpa-reorder.png) + + +## Data Types + +oneDNN supports the floating-point SDPA pattern with data types f32, bf16, and +f16. 
You can specify the data type via the input and output logical tensors'
+data type fields for each operation.
+
+oneDNN supports bf16 or f16 SDPA with f32 intermediate type, which means the
+Q/K/V tensors have bf16 or f16 data type while the output of the first MatMul,
+Scale, Mask, and the input of SoftMax are in f32 data type.
+
+oneDNN supports the quantized SDPA pattern with int8-f32 mixed precision,
+int8-bf16 mixed precision, and int8-f16 mixed precision data types.
+
+The definition of the data types and support status on different CPU and GPU
+platforms follow the general description in @ref dev_guide_data_types.
+
+## Implementation limitations
+
+1. oneDNN primitive-based SDPA is implemented as the reference implementation on
+   both Intel Architecture Processors and Intel Graphics Products. In this case,
+   floating-point SDPA patterns are usually implemented with f32, bf16, or f16
+   matmul (with post-ops) and softmax primitives, while quantized SDPA patterns
+   are implemented with int8 matmul (with post-ops) and f32, bf16, or f16
+   softmax primitives. The reference implementation requires memory to store the
+   intermediate results of the dot products between Query and Key which takes
+   \f$O(S^2)\f$ memory. It may lead to an out-of-memory error when computing
+   long sequence length inputs on platforms with limited memory. For an implicit
+   causal mask, the reference implementation is only available on CPU.
+2. The SDPA patterns functionally support all input shapes meeting the shape
+   requirements of each operation in the graph. For example, Add, Multiply,
+   Divide, and Select operations require the input tensors to have the same
+   shape or shapes that can be properly broadcast based on the operation
+   attribute.
+3. CPU
+   - Optimized implementation is available for 4D Q/K/V tensors with shape
+     defined as (N, H, S, D).
+   - Optimized implementation is available for OpenMP runtime and Threadpool
+     runtime on Intel Architecture Processors.
+   - Specifically for OpenMP runtime, the optimized implementation requires `N *
+     H > 2 * thread number` to get enough parallelism.
+4. GPU
+   - Optimized implementation is available for 4D Q/K/V tensors with shape
+     defined as (N, H, S, D).
+   - Optimized implementation is available for `f16` or `bf16` SDPA with `f32`
+     intermediate data type and `D <= 256` on Intel Graphics Products with
+     Intel(R) Xe Matrix Extensions (Intel(R) XMX) support.
+
+## Example
+
+oneDNN provides an [SDPA
+example](https://github.com/uxlfoundation/oneDNN/tree/main/examples/graph/sdpa.cpp)
+demonstrating how to construct a typical floating-point SDPA pattern with oneDNN
+Graph API on CPU and GPU with different runtimes.
+
+oneDNN also provides a [MQA (Multi-Query Attention)
+example](https://github.com/uxlfoundation/oneDNN/tree/main/examples/graph/mqa.cpp) [3]
+demonstrating how to construct a floating-point MQA pattern with the same
+pattern structure as in the SDPA example but a different head number in Key and
+Value tensors. In MQA, the head number of Key and Value is always one.
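+
+The following minimal sketch shows how such a pattern can be assembled with the
+Graph API. It is not a substitute for the full example: the shapes (N=1, H=16,
+S=384, D=64), tensor IDs, and op names are hypothetical, the mask is omitted,
+and error handling is skipped for brevity.
+
+~~~cpp
+#include <cstdint>
+
+#include "oneapi/dnnl/dnnl_graph.hpp"
+
+using namespace dnnl::graph;
+
+int main() {
+    const auto dt = logical_tensor::data_type::f32;
+    const auto lt = logical_tensor::layout_type::strided;
+
+    logical_tensor q {0, dt, {1, 16, 384, 64}, lt};   // (N, H, S, D)
+    logical_tensor k {1, dt, {1, 16, 384, 64}, lt};   // (N, H, S, D)
+    logical_tensor v {2, dt, {1, 16, 384, 64}, lt};   // (N, H, S, D)
+    logical_tensor scale {3, dt, {1}, lt};            // user-provided factor
+    logical_tensor score {4, dt, {1, 16, 384, 384}, lt};
+    logical_tensor scaled {5, dt, {1, 16, 384, 384}, lt};
+    logical_tensor probs {6, dt, {1, 16, 384, 384}, lt};
+    logical_tensor out {7, dt, {1, 16, 384, 64}, lt};
+
+    // First MatMul: Q x K^T.
+    op qk {0, op::kind::MatMul, {q, k}, {score}, "qk"};
+    qk.set_attr<bool>(op::attr::transpose_b, true);
+
+    // Scale node constructed with Divide, as described in step 2.
+    op div {1, op::kind::Divide, {score, scale}, {scaled}, "scale"};
+
+    // SoftMax over the last dimension.
+    op softmax {2, op::kind::SoftMax, {scaled}, {probs}, "softmax"};
+    softmax.set_attr<int64_t>(op::attr::axis, -1);
+
+    // Second MatMul: probabilities x V.
+    op pv {3, op::kind::MatMul, {probs, v}, {out}, "pv"};
+
+    graph g {engine::kind::cpu};
+    g.add_op(qk);
+    g.add_op(div);
+    g.add_op(softmax);
+    g.add_op(pv);
+    g.finalize();
+    auto partitions = g.get_partitions(); // ideally one fused SDPA partition
+    return partitions.empty() ? 1 : 0;
+}
+~~~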
+
+## References
+
+[1] Attention is all you need, https://arxiv.org/abs/1706.03762v7
+
+[2] oneDNN Graph API documentation, https://uxlfoundation.github.io/oneDNN/graph_extension.html
+
+[3] Fast Transformer Decoding: One Write-Head is All You Need, https://arxiv.org/abs/1911.02150
diff --git a/doc/graph/fusion_patterns/sdpa_with_compressed_kv.md b/doc/graph/fusion_patterns/sdpa_with_compressed_kv.md
new file mode 100644
index 00000000000..e7a55ef571c
--- /dev/null
+++ b/doc/graph/fusion_patterns/sdpa_with_compressed_kv.md
@@ -0,0 +1,119 @@
+SDPA with Compressed Key and Value {#dev_guide_graph_sdpa_compressed_kv}
+========================================================================
+
+## Overview
+
+int4 and int8 compressions for Key and Value are exploited in fused Scaled
+Dot-Product Attention (SDPA) [1] to reduce the memory footprint of generative
+inference of LLMs, especially when the KV cache mechanism is adopted.
+Specifically, Key and Value tensors are stored using lower precision data types
+like int4 and int8 to reduce memory usage, and are subsequently de-quantized to
+wider floating-point data types such as f16 and bf16 for computation.
+
+Note that grouped quantization is required to improve the model accuracy,
+especially for int4 data types. In this case, a group size is needed as an
+attribute for quantization, which indicates the number of elements that share
+the same scaling factor and zero-points in each quantization group.
+
+The notations used in this topic are:
+
+- N: The mini-batch size.
+- H: The head number.
+- S: The sequence length.
+- D: The size of each head.
+- G: The group size.
+
+## SDPA Pattern
+
+The SDPA pattern with compressed Key and Value is defined as a directed
+acyclic graph (DAG) using oneDNN Graph API. oneDNN extends the
+[SDPA pattern](@ref dev_guide_graph_sdpa) to support the following three kinds
+of compressed SDPA patterns:
+
+1. SDPA with compressed Key and Value.
+2. SDPA with floating-point Key and compressed Value.
+3. SDPA with compressed Key and floating-point Value.
+
+The floating-point data types include f32, f16, and bf16, and the compressed
+data type refers to low-precision integral data types, including int4 (u4/s4)
+and int8 (u8/s8).
+
+In oneDNN Graph API, we support quantization through a pattern with quantization
+operations such as [DynamicDequantize](@ref dev_guide_op_dynamicdequantize) and
+[DynamicQuantize](@ref dev_guide_op_dynamicquantize). The supported pattern is
+as follows. The blue nodes are required while the brown nodes are optional.
+
+![compressed SDPA pattern](images/compressed_sdpa_pattern.png)
+
+Compared to a typical SDPA pattern, there are a few differences:
+
+1. Two additional DynamicDequantize operations are applied to the input Key and
+Value to convert the integral values to floating-point values.
+2. Apart from the Query, Key and Value inputs, the pattern requires additional
+quantization information such as scales and zero-points for the dequantization
+of Key and Value tensors. Currently, oneDNN only supports grouped quantization
+on one dimension; specifically, the shapes of scales and zero-points for Key and
+Value de-quantization should be (N, H, S, D/G).
+3. Additionally, the `group_shape` attribute of the quantization operations must
+be specified as (1, 1, 1, G) for Key and Value dequantization (see the sketch
+after this list).
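+
+The following minimal sketch illustrates the Key dequantization part of the
+pattern under hypothetical sizes (N=1, H=8, S=128, D=64, G=16, so D/G=4) with a
+u8 Key cache and f16 scales. Tensor IDs and op names are made up, the rest of
+the SDPA graph consuming the dequantized Key is elided, and the exact `qtype`
+string for grouped quantization (`per_group` below) should be checked against
+the DynamicDequantize operation reference.
+
+~~~cpp
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "oneapi/dnnl/dnnl_graph.hpp"
+
+using namespace dnnl::graph;
+
+int main() {
+    const auto lt = logical_tensor::layout_type::strided;
+
+    // Compressed Key (N, H, S, D) with grouped scales/zero-points of
+    // shape (N, H, S, D/G).
+    logical_tensor k_u8 {0, logical_tensor::data_type::u8, {1, 8, 128, 64}, lt};
+    logical_tensor k_scale {1, logical_tensor::data_type::f16, {1, 8, 128, 4}, lt};
+    logical_tensor k_zp {2, logical_tensor::data_type::u8, {1, 8, 128, 4}, lt};
+    logical_tensor k_f16 {3, logical_tensor::data_type::f16, {1, 8, 128, 64}, lt};
+
+    op deq_k {0, op::kind::DynamicDequantize, {k_u8, k_scale, k_zp}, {k_f16},
+            "deq_k"};
+    deq_k.set_attr<std::string>(op::attr::qtype, "per_group");
+    deq_k.set_attr<std::vector<int64_t>>(op::attr::group_shape, {1, 1, 1, 16});
+
+    graph g {engine::kind::gpu};
+    // Compute in f16 and let the mode apply to the integral inputs as well.
+    g.set_fpmath_mode(dnnl::fpmath_mode::f16, /*apply_to_int=*/true);
+    g.add_op(deq_k);
+    // ... add the remaining SDPA ops consuming k_f16 here ...
+    g.finalize();
+    auto partitions = g.get_partitions();
+    return partitions.empty() ? 1 : 0;
+}
+~~~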
+
+## Data Types
+
+oneDNN supports the following combinations of data types for Query, Key, Value,
+output, scale for Key, zero-points for Key, scale for Value, and zero-points
+for Value:
+
+| Query | Key | Scale_K | Zp_K | Value | Scale_V | Zp_V | Output |
+|:--------|:--------|:--------|:----------------|:-------|:--------|:----------------|:-------|
+| dt_fp | dt_int | dt_fp | u4,s4,u8,s8,s32 | dt_int | dt_fp | u4,s4,u8,s8,s32 | dt_fp |
+| dt_fp | dt_int | dt_fp | u4,s4,u8,s8,s32 | dt_fp | N/A | N/A | dt_fp |
+| dt_fp | dt_fp | N/A | N/A | dt_int | dt_fp | u4,s4,u8,s8,s32 | dt_fp |
+
+Notes:
+- dt_fp can be: f16, bf16, or f32.
+- dt_int can be: u8, s8, u4, or s4.
+- zero-point inputs are optional.
+
+You can specify the data type via the input and output data type fields of
+logical tensors for each operation. The definition of the data types and support
+status on different CPU and GPU platforms follow the general description in
+@ref dev_guide_data_types.
+
+### Floating-point Math Mode
+
+You should set the floating-point math mode
+(@ref dev_guide_attributes_fpmath_mode) when using SDPA with compressed Key and
+Value. Generally, the math mode should align with the data type of the Query,
+which indicates the computation data type. Additionally, the second boolean
+flag, `apply_to_int`, should be set to true. You can configure these attribute
+values using the `set_fpmath_mode` API
+(@ref dnnl::graph::graph::set_fpmath_mode) on the graph object.
+
+## Implementation Limitations
+
+- oneDNN primitive-based SDPA with compressed Key and Value is implemented as
+a reference implementation on both Intel Architecture Processors and Intel
+Graphics Products. The reference implementation requires memory to store the
+intermediate results of the dot products between Query and Key which takes
+\f$O(S^2)\f$ memory. It may lead to an out-of-memory error when computing long
+sequence length inputs on platforms with limited memory.
+- The compressed SDPA patterns functionally support all input shapes meeting
+the shape requirements of each operation in the graph.
+- CPU
+  - oneDNN does not currently provide an optimized implementation on CPU. All
+    executions will be implemented with the primitive-based reference
+    computation.
+- GPU
+  - Optimized implementation is available for 4D Q/K/V tensors with the shape
+    defined as (N, H, S, D) for Query and Value, (N, H, D, S) for Key,
+    (N, H, D/G, S) for scales and zero-points of Key (if available) and
+    (N, H, S, D/G) for scales and zero-points of Value (if available).
+  - Optimized implementation is available for compressed SDPA with `f16`
+    computation data type on Intel Graphics Products with Intel(R) Xe Matrix
+    Extensions (Intel(R) XMX) support.
+  - If int4 zero-points are specified, the optimized implementation will only
+    be available when the group size equals 16.
+
+## References
+
+[1] Attention is all you need, https://arxiv.org/abs/1706.03762v7
diff --git a/doc/graph/images/sdpa.png b/doc/graph/images/sdpa.png
deleted file mode 100644
index 87f4443bf49..00000000000
Binary files a/doc/graph/images/sdpa.png and /dev/null differ
diff --git a/doc/graph/operations/Add.md b/doc/graph/operations/Add.md
index 6f5342b382c..5fef1ab7d7e 100644
--- a/doc/graph/operations/Add.md
+++ b/doc/graph/operations/Add.md
@@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is
 Add operation supports the following data type combinations.
 
-| Src_0 / Src_1 | Dst |
-|:--------------|:-----|
-| f32 | f32 |
-| bf16 | bf16 |
-| f16 | f16 |
+| Src_0 | Src_1 | Dst |
+|:----------|:----------|:-----|
+| f32 | f32 | f32 |
+| bf16 | bf16 | bf16 |
+| f16 | f16 | f16 |
+| f32 | bf16, f16 | f32 |
+| bf16, f16 | f32 | f32 |
diff --git a/doc/graph/operations/Divide.md b/doc/graph/operations/Divide.md
index 8c4ab535544..11689c9b7eb 100644
--- a/doc/graph/operations/Divide.md
+++ b/doc/graph/operations/Divide.md
@@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is
 Divide operation supports the following data type combinations.
 
-| Src_0 / Src_1 | Dst |
-|:--------------|:-----|
-| f32 | f32 |
-| bf16 | bf16 |
-| f16 | f16 |
+| Src_0 | Src_1 | Dst |
+|:----------|:----------|:-----|
+| f32 | f32 | f32 |
+| bf16 | bf16 | bf16 |
+| f16 | f16 | f16 |
+| f32 | bf16, f16 | f32 |
+| bf16, f16 | f32 | f32 |
diff --git a/doc/graph/operations/DynamicDequantize.md b/doc/graph/operations/DynamicDequantize.md
index 9e730f1fb54..46aba4667f9 100644
--- a/doc/graph/operations/DynamicDequantize.md
+++ b/doc/graph/operations/DynamicDequantize.md
@@ -3,11 +3,11 @@ DynamicDequantize {#dev_guide_op_dynamicdequantize}
 
 ## General
 
-DynamicDequantize operation converts a quantized (s8 or u8) tensor to a f32
-tensor. It supports both per-tensor and per-channel asymmetric linear
-de-quantization. Rounding mode is library-implementation defined. Unlike the
-@ref dev_guide_op_dequantize, DynamicDequantize takes scales and zero-points as
-operator src tensors.
+The Dynamic Dequantize operation converts a quantized (s4, u4, s8, or u8)
+tensor to a bf16, f16, or f32 tensor. It supports per-tensor, per-channel, and
+per-group asymmetric linear de-quantization. The rounding mode is defined by
+the library implementation. Unlike the @ref dev_guide_op_dequantize, Dynamic
+Dequantize takes scales and zero-points as operator src tensors.
 
 For per-tensor de-quantization
 
@@ -16,12 +16,23 @@
 For per-channel de-quantization, taking channel axis = 1 as an example:
 \f[ {dst}_{\cdots,i,\cdots,\cdots} = (src_{\cdots,i,\cdots,\cdots} - zps_i)*scales_i,i\in [0,channelNum-1] \f]
 
+For per-group de-quantization, let's take group shape = Gx1 as an example. It
+indicates that one scaling factor will be adopted for G elements in the src
+tensor. On the dimensions where group quantization is adopted, make channelNum
+equal to the dimension of src and groupNum equal to channelNum/group size:
+ \f[ {dst}_{i,\cdots} = (src_{i,\cdots} - zps_j)*scales_j,i\in [0,channelNum-1],j\in [0,groupNum-1] \f]
+Where:
+ \f[ i = j*groupSize+k,k\in [0,groupSize-1] \f]
+On other dimensions:
+ \f[ {dst}_{i,\cdots} = (src_{i,\cdots} - zps_i)*scales_i,i\in [0,channelNum-1] \f]
+
 ## Operation attributes
 
 | Attribute Name | Description | Value Type | Supported Values | Required or Optional |
 |:-------------------------------------------|:---------------------------------------------------------------------|:-----------|:------------------------------------------------------------------------------------------------------------------------------------------------|:---------------------|
 | [qtype](@ref dnnl::graph::op::attr::qtype) | Specifies which de-quantization type is used. | string | `per_tensor` (default), `per_channel` | Optional |
-| [axis](@ref dnnl::graph::op::attr::axis) | Specifies dimension on which per-channel de-quantization is applied. | s64 | A s64 value in the range of [-r, r-1] where r = rank(src), `1` by default. Negative value means counting the dimension backwards from the end. | Optional |
+| [axis](@ref dnnl::graph::op::attr::axis) | Specifies dimension on which per-channel de-quantization is applied. | s64 | An s64 value in the range of [-r, r-1] where r = rank(src), `1` by default. Negative values mean counting the dimension backwards from the end. | Optional |
+| [group_shape](@ref dnnl::graph::op::attr::group_shape) | Specifies the group shape of an operation. | s64 | An s64 list indicating the group size on the dimensions where grouped quantization is adopted. | Optional |
 
 ## Execution arguments
 
@@ -36,15 +47,23 @@ constructing an operation.
 | 1 | `scales` | Required |
 | 2 | `zps` | Optional |
 
-@note `scales` is a f32 1D tensor to be applied to the de-quantization formula.
-For `qtype` = `per-tensor`, there should be only one element in the scales
-tensor. For `qtype` = `per-channel`, the element number should be equal to the
-element number of src tensor along the dimension axis.
-
-@note `zps` is a 1D tensor with offset values that map to zero. For `qtype` =
-`per-tensor`, there should be only one element in the zps tensor. For `qtype` =
+@note `scales` is a bf16/f16/f32 tensor to be applied to the de-quantization
+formula. For `qtype` = `per-tensor`, there should be only one element in the
+`scales` tensor. For `qtype` = `per-channel`, the element number should be
+equal to the element number of the src tensor along the dimension axis. For
+`qtype` = `per-group`, the `scales` tensor should have the same number of
+dimensions as the `src` tensor. On the dimensions where grouped quantization is
+applied, the dimension should be the number of groups, which equals
+`src_dim` / `group_size`, while other dimensions should match the `src` tensor.
+
+@note `zps` is a tensor with offset values that map to zero. For `qtype` =
+`per-tensor`, there should be only one element in the `zps` tensor. For `qtype` =
 `per-channel`, the element number should be equal to the element number of input
-tensor along the dimension axis. If omitted, zps values are assumed to be zero.
+tensor along the dimension axis. For `qtype` = `per-group`, the `zps` tensor
+should have the same number of dimensions as the `src` tensor. On the dimensions
+where grouped quantization is applied, the dimension should be the number of
+groups, which equals `src_dim` / `group_size`, while other dimensions should
+match the `src` tensor. If omitted, the `zps` values are assumed to be zero.
 
 ### Outputs
 
@@ -58,5 +77,9 @@ DynamicDequantize operation supports the following data type combinations.
 
 | Src | Dst | Scales | Zps |
 |:----|:----|:-------|:------------|
-| s8 | f32 | f32 | s8, u8, s32 |
-| u8 | f32 | f32 | s8, u8, s32 |
+| s8 | f16, bf16, f32 | f16, bf16, f32 | s8, u8, s32 |
+| u8 | f16, bf16, f32 | f16, bf16, f32 | s8, u8, s32 |
+| s4 | f16, bf16, f32 | f16, bf16, f32 | s4, u4, s32 |
+| u4 | f16, bf16, f32 | f16, bf16, f32 | s4, u4, s32 |
+
+The data types of `scales` and `dst` are expected to be the same.
diff --git a/doc/graph/operations/GenIndex.md b/doc/graph/operations/GenIndex.md
new file mode 100644
index 00000000000..ff3306633dc
--- /dev/null
+++ b/doc/graph/operations/GenIndex.md
@@ -0,0 +1,39 @@
+GenIndex {#dev_guide_op_genindex}
+=================================
+
+## General
+
+The GenIndex operation creates an index tensor along a specified axis of
+an input tensor. The resulting index tensor has the same shape as the
+input tensor, with each element representing the index along the
+specified axis.
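+
+For example (an illustration, not taken from the library documentation): for a
+2x3 `src` tensor, `axis` = 1 yields `dst` = `[[0, 1, 2], [0, 1, 2]]`, while
+`axis` = 0 yields `dst` = `[[0, 0, 0], [1, 1, 1]]`. The values of `src` do not
+affect `dst`; only its shape and the `axis` attribute do.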
+
+## Operation Attributes
+
+| Attribute Name | Description | Value Type | Supported Values | Required or Optional |
+|:------------------------------------------|:----------------------------------------------------------------|:-----------|:-----------------------------------------------------------|:---------------------|
+| [axis](@ref dnnl::graph::op::attr::axis) | Specifies the dimension along which index values are generated. | s64 | An s64 value in the range of [-r, r-1] where r = rank(src) | Required |
+
+## Execution Arguments
+
+### Input
+
+| Index | Argument Name | Required or Optional |
+|:------|:--------------|:---------------------|
+| 0 | `src` | Required |
+
+### Output
+
+| Index | Argument Name | Required or Optional |
+|:------|:--------------|:---------------------|
+| 0 | `dst` | Required |
+
+## Supported Data Types
+
+The GenIndex operation supports the following data type combinations.
+
+| Src | Dst |
+|:-------|:-------|
+| f32 | s32 |
+| bf16 | s32 |
+| f16 | s32 |
diff --git a/doc/graph/operations/GreaterEqual.md b/doc/graph/operations/GreaterEqual.md
new file mode 100644
index 00000000000..912b066b45c
--- /dev/null
+++ b/doc/graph/operations/GreaterEqual.md
@@ -0,0 +1,49 @@
+GreaterEqual {#dev_guide_op_greaterequal}
+=========================================
+
+## General
+
+The GreaterEqual operation performs an element-wise greater-than-or-equal
+comparison between two given tensors. This operation applies
+the multi-directional broadcast rules to ensure compatibility between
+the tensors of different shapes.
+
+\f[ dst = \begin{cases} true & \text{if}\ src_0 \ge src_1 \\
+    false & \text{if}\ src_0 < src_1 \end{cases} \f]
+
+## Operation Attributes
+
+| Attribute Name | Description | Value Type | Supported Values | Required or Optional |
+|:-------------------------------------------------------------|:-----------------------------------------------------------|:-----------|:-------------------------|:---------------------|
+| [auto_broadcast](@ref dnnl::graph::op::attr::auto_broadcast) | Specifies rules used for auto-broadcasting of src tensors. | string | `none`,`numpy` (default) | Optional |
+
+## Execution Arguments
+
+### Input
+
+| Index | Argument Name | Required or Optional |
+|:------|:--------------|:---------------------|
+| 0 | `src_0` | Required |
+| 1 | `src_1` | Required |
+
+@note Both src shapes should match and no auto-broadcasting is allowed if
+the `auto_broadcast` attribute is `none`. `src_0` and `src_1` shapes can be
+different and auto-broadcasting is allowed if the `auto_broadcast` attribute
+is `numpy`. Broadcasting is performed according to the `auto_broadcast` value.
+
+### Output
+
+| Index | Argument Name | Required or Optional |
+|:------|:--------------|:---------------------|
+| 0 | `dst` | Required |
+
+## Supported Data Types
+
+The GreaterEqual operation supports the following data type combinations.
+
+| Src_0 / Src_1 | Dst |
+|:--------------|:---------|
+| f32 | boolean |
+| bf16 | boolean |
+| f16 | boolean |
+| s32 | boolean |
diff --git a/doc/graph/operations/MatMul.md b/doc/graph/operations/MatMul.md
index d2b4cc89b0f..7879393969a 100644
--- a/doc/graph/operations/MatMul.md
+++ b/doc/graph/operations/MatMul.md
@@ -61,8 +61,8 @@ constructing an operation.
 
 MatMul operation supports the following data type combinations.
 
-| Src | Weights | Bias | Dst |
-|:-----|:--------|:-----|:-----|
-| f32 | f32 | f32 | f32 |
-| bf16 | bf16 | bf16 | bf16 |
-| f16 | f16 | f16 | f16 |
+| Src | Weights | Bias | Dst |
+|:-----|:--------|:-----|:----------|
+| f32 | f32 | f32 | f32 |
+| bf16 | bf16 | bf16 | f32, bf16 |
+| f16 | f16 | f16 | f32, f16 |
diff --git a/doc/graph/operations/Multiply.md b/doc/graph/operations/Multiply.md
index 625bfea10d2..24e09881e10 100644
--- a/doc/graph/operations/Multiply.md
+++ b/doc/graph/operations/Multiply.md
@@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is
 Multiply operation supports the following data type combinations.
 
-| Src_0 / Src_1 | Dst |
-|:--------------|:-----|
-| f32 | f32 |
-| bf16 | bf16 |
-| f16 | f16 |
+| Src_0 | Src_1 | Dst |
+|:----------|:----------|:-----|
+| f32 | f32 | f32 |
+| bf16 | bf16 | bf16 |
+| f16 | f16 | f16 |
+| f32 | bf16, f16 | f32 |
+| bf16, f16 | f32 | f32 |
diff --git a/doc/graph/operations/Softmax.md b/doc/graph/operations/Softmax.md
index 6655eb218d6..467634b1d05 100644
--- a/doc/graph/operations/Softmax.md
+++ b/doc/graph/operations/Softmax.md
@@ -36,8 +36,8 @@ constructing an operation.
 SoftMax operation supports the following data type combinations.
 
-| Src | Dst |
-|:-----|:-----|
-| f32 | f32 |
-| bf16 | bf16 |
-| f16 | f16 |
+| Src | Dst |
+|:-----|:----------------|
+| f32 | f32, bf16, f16 |
+| bf16 | bf16 |
+| f16 | f16 |
diff --git a/doc/graph/operations/Subtract.md b/doc/graph/operations/Subtract.md
index 28138271a5a..bca45816cc8 100644
--- a/doc/graph/operations/Subtract.md
+++ b/doc/graph/operations/Subtract.md
@@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is
 Subtract operation supports the following data type combinations.
 
-| Src_0 / Src_1 | Dst |
-|:--------------|:-----|
-| f32 | f32 |
-| bf16 | bf16 |
-| f16 | f16 |
+| Src_0 | Src_1 | Dst |
+|:----------|:----------|:-----|
+| f32 | f32 | f32 |
+| bf16 | bf16 | bf16 |
+| f16 | f16 | f16 |
+| f32 | bf16, f16 | f32 |
+| bf16, f16 | f32 | f32 |
diff --git a/doc/graph/programming_model/graph_basic_concepts.md b/doc/graph/programming_model/graph_basic_concepts.md
index 5ee5eaa558b..b2f75349d1f 100644
--- a/doc/graph/programming_model/graph_basic_concepts.md
+++ b/doc/graph/programming_model/graph_basic_concepts.md
@@ -41,13 +41,19 @@ tensor as the edge between them.
 
 ## Graph
 
 `Graph` (@ref dnnl::graph::graph) contains a set of operations. A graph object
-is associated to a specific engine kind (@ref dnnl::engine::kind). Multiple
-operations can be added (@ref dnnl::graph::graph::add_op) along with input and
-output logical tensors to a graph. After finishing adding operations,
-finalization API (@ref dnnl::graph::graph::finalize) can be called to indicate
-that the graph is ready for partitioning. By calling partitioning API (@ref
-dnnl::graph::graph::get_partitions), a group of partitions from the graph will
-be returned.
+is associated to a specific engine kind (@ref dnnl::engine::kind). In addition,
+you can set the graph-level floating-point math mode through the setter API
+(@ref dnnl::graph::graph::set_fpmath_mode) or in the constructor. The API
+accepts two parameters: the floating-point math mode and an optional boolean
+flag that indicates whether to use floating-point arithmetic for integral
+operations.
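+
+A minimal sketch of the setter API is shown below (the engine kind and math
+mode are illustrative choices, not requirements):
+
+~~~cpp
+#include "oneapi/dnnl/dnnl_graph.hpp"
+
+int main() {
+    using namespace dnnl::graph;
+
+    graph g {engine::kind::cpu};
+    // Allow bf16 math for f32 operations; the second argument opts integral
+    // operations into the same floating-point arithmetic.
+    g.set_fpmath_mode(dnnl::fpmath_mode::bf16, /*apply_to_int=*/true);
+    // ... add operations, finalize, and get partitions as described below.
+    return 0;
+}
+~~~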
+
+Multiple operations can be added (@ref dnnl::graph::graph::add_op) along with
+input and output logical tensors to a graph. After finishing adding the
+operations, the finalization API (@ref dnnl::graph::graph::finalize) can be
+called to indicate that the graph is ready for partitioning. By calling the
+partitioning API (@ref dnnl::graph::graph::get_partitions), a group of
+partitions from the graph will be returned.
 
 ## Partition
diff --git a/doc/graph/programming_model/low_precision.md b/doc/graph/programming_model/low_precision.md
index 35118771b96..83b7744ba25 100644
--- a/doc/graph/programming_model/low_precision.md
+++ b/doc/graph/programming_model/low_precision.md
@@ -52,7 +52,6 @@ Graph operations support bf16 and f16 data types.
 
 A TypeCast operation performing down conversion should be inserted clearly to
 indicate the use of low numeric precision. oneDNN Graph implementation fully
-honors the API-specified numeric precision and only performs the computation
-using the API-specified or higher numeric precision.
+honors the API-specified numeric precision.
 
 @img{bf16_programming.jpg,Figure 2: Overview of bf16 programming model.,80%,}
diff --git a/doc/graph/rst/graph_fusion_patterns.rst b/doc/graph/rst/graph_fusion_patterns.rst
new file mode 100644
index 00000000000..ce4bca97f7f
--- /dev/null
+++ b/doc/graph/rst/graph_fusion_patterns.rst
@@ -0,0 +1,171 @@
+Fusion Patterns
+###############
+
+.. default-role:: math
+.. toctree::
+   :maxdepth: 1
+   :hidden:
+
+   dev_guide_graph_gated_mlp
+   dev_guide_graph_gqa
+   dev_guide_graph_sdpa_compressed_kv
+   dev_guide_graph_sdpa
+
+
+The following fusion patterns are subgraphs that the oneDNN Graph API
+recognizes as candidates for fusion. The patterns are described using
+oneDNN Graph operation (op) names with the following convention.
+
+.. note::
+   oneDNN Graph performs limited input validation to minimize
+   the performance overheads. The application is responsible for
+   sanitizing inputs passed to the library. Because large ``u8`` or
+   ``s8`` inputs may lead to accumulator overflow, you can use
+   floating-point patterns instead of quantized patterns.
+
+``"+"`` describes a chain of two ops. The preceding op produces an
+output tensor, which is consumed by the following op as its first
+operand.
+
+``"[]"`` describes a component of the overall pattern description. For
+example, it could include a subgraph or all the op choices within the
+bracket.
+
+``"|"`` describes choices of multiple operations, say A+[B|C] means the
+graph partition contains A followed by B or C.
+
+``","`` describes a graph composed of multiple subgraphs, each of which
+marks its output tensor explicitly; the output is then consumed by other
+subgraphs.
+
+``Superscript`` denotes the number of repetitions of a pattern. For example,
+A+[B|C] `^{3}` means the graph partition
+contains A followed by three ops, each of them either B or C. The
+superscript can be a range of numbers, allowing a range of
+repetitions. If the range is between 0 and 1, we use superscript ``"?"``.
+
+``Subscript`` denotes the input and output tensors that explicitly mark
+the producer and consumer relations within one graph
+partition. For example,
+A `_{>t1}`\ +B+C `_{<t1}`, where ``">"`` refers to the output tensor,
+and ``"<"`` to the input tensor. Input
+and output tensors between neighbor ops are not explicitly marked; for
+example, B consumes t1 implicitly in the example above.
+
+Subscript ``"out"`` marks the output tensor of a certain op to be the
+output of a graph partition. For example, in
+A `_{>t1}`\ +B `_{>out}`\ +C `_{<t1,>out}`,
+B's output and C's output are marked as output tensors.
+
+Subscript ``"in"`` marks the input tensor of a certain op to be the
+input of a graph partition. For example, in
+A `_{<in1,>t1}`\ +B+C `_{<in2,<t1}`, A's input tensor in1 and C's input
+tensor in2 are marked as input tensors of the graph partition.
+
+Inference
+~~~~~~~~~
+
+Floating Point Patterns
+^^^^^^^^^^^^^^^^^^^^^^^
+
+.. list-table::
+   :widths: 75 25
+   :header-rows: 1
+
+   * - Pattern
+     - Description
+   * - Scaled Dot-Product Attention
+     - Refer to `Scaled Dot-Product Attention (SDPA) <dev_guide_graph_sdpa.html>`_ for more details.
+   * - Grouped Query Attention
+     - Refer to `Grouped Query Attention (GQA) <dev_guide_graph_gqa.html>`_ for more details.
+   * - Scaled Dot-Product Attention with Compressed Key/Value
+     - Refer to `Scaled Dot-Product Attention with Compressed Key/Value <dev_guide_graph_sdpa_compressed_kv.html>`_ for more details.
+   * - Gated Multi-Layer Perceptron (Gated-MLP)
+     - Refer to `Gated Multi-Layer Perceptron (Gated-MLP) <dev_guide_graph_gated_mlp.html>`_ for more details.
+   * - Convolution + BiasAdd `^?` + BatchNormInference `^?` + [Unary \| Binary] `^{0-3}` `_{>out}`
+     - This pattern is widely used in Convolutional Neural Networks, for example ResNet, ResNext, SSD, etc.
+   * - ConvTranspose + BiasAdd `^?` + [Unary \| Binary] `^{0-3}` `_{>out}`
+     - This pattern is widely used in Generative Adversarial Networks.
+   * - Interpolate + [Unary \| Binary] `^{0-3}` `_{>out}`
+     - This pattern is widely used for image processing.
+   * - MatMul + BiasAdd `^?` + [Unary \| Binary] `^{0-3}` + Select `^?` `_{>out}`
+     - This pattern is widely used in language models and recommendation models, for example BERT, DLRM, etc.
+   * - Reduction + [Unary \| Binary] `^{0-3}` `_{>out}`
+     - This pattern is widely used for data processing, for example loss reduction.
+   * - Unary + Binary `^{0-3}` `_{>out}`
+     - This pattern is widely used in Convolutional Neural Networks.
+   * - Binary + [Unary \| Binary] `^{0-3}` `_{>out}`
+     - This pattern is widely used in language models and recommendation models, for example BERT, DLRM, etc.
+   * - [AvgPool \| MaxPool] + Binary `^{0-3}` `_{>out}`
+     - This pattern is widely used in Convolutional Neural Networks.
+   * - BatchNormInference + ReLU `_{>out}`
+     - This pattern is widely used in Convolutional Neural Networks, for example DenseNet.
+   * - Reciprocal + Multiply `_{>out}`
+     - N/A
+   * - Reorder + Add `_{>out}`
+     - N/A
+
+Quantized Patterns
+^^^^^^^^^^^^^^^^^^
+
+.. list-table::
+   :widths: 75 25
+   :header-rows: 1
+
+   * - Pattern
+     - Description
+   * - Quantize `^?` + Dequantize `_{>t1}`, Dequantize `_{>t2}` `^{0-3}`, Dequantize + Convolution `_{<t1,<t2,>out}`
+     - N/A
+   * - Quantize `^?` + Dequantize `_{>t1}`, Dequantize `_{>t2}` `^{0-3}`, Dequantize + ConvTranspose `_{<t1,<t2,>out}`
+     - N/A
+   * - Quantize `^?` + Dequantize `_{>t1}`, Dequantize `_{>t2}` `^{0-3}`, Dequantize + MatMul `_{<t1,<t2,>out}`
+     - N/A
+   * - Dequantize + [AvgPool \| MaxPool] + Quantize `_{>out}`
+     - N/A
+   * - Dequantize `_{>t1}`, Dequantize + [AvgPool \| MaxPool] + Add `_{<t1,>out}`
+     - N/A
+   * - Dequantize + Reorder + Quantize `_{>out}`
+     - This pattern is widely used in Generative Adversarial Networks.
+   * - Dequantize `_{>t1}`, Dequantize + Reorder + Add `_{<t1,>out}`
+     - This pattern is widely used for image processing.
+   * - [SoftMax \| LayerNorm \| GroupNorm] + [Unary \| Binary] `^{0-3}` + Quantize `_{>out}`
+     - This pattern is used in SmoothQuant to fuse scales and quantization into previous layers.
+
+Training
+~~~~~~~~
+
+.. list-table::
+   :widths: 75 25
+   :header-rows: 1
+
+   * - Pattern
+     - Description
+   * - ConvolutionBackwardWeights + BiasAddBackward `_{>out}`
+     - N/A
+   * - ReLUBackward + BatchNormTrainingBackward `_{>out}`
+     - N/A
diff --git a/doc/graph/sdpa.md b/doc/graph/sdpa.md
deleted file mode 100644
index 1b0864a5c76..00000000000
--- a/doc/graph/sdpa.md
+++ /dev/null
@@ -1,128 +0,0 @@
-Scaled Dot-Product Attention (SDPA) {#dev_guide_graph_sdpa}
-===========================================================
-
-## Background
-
-Scaled Dot-Product Attention (SDPA) was introduced in [1] as the core operation
-of Transformer block which now becomes the backbone of many language models and
-generative models (BERT, Stable Diffusion, GPT, etc.).
-
-The input of SDPA consists of query (Q), key (K), and value (V). The attention
-output is computed as:
-
-\f[
-
-    attention(Q,K,V) = V \cdot softmax(\frac{QK^T}{\sqrt{d_k}})
-
-\f]
-
-\f$d_k\f$ is the dimension size of K. Other notations used in the document:
-
-- N: the mini-batch size.
-- H: the number of multi-head.
-- S: the sequence length.
-- D: the size of each head.
-
-## SDPA patterns
-
-oneDNN supports SDPA and its optimization through Graph API [2] by defining the
-SDPA graph, getting partition from the graph, and optimizing the kernels
-underneath. In general, an SDPA pattern is defined as a directional acyclic
-graph (DAG) using oneDNN Graph API.
-
-### Floating point SDPA
-
-oneDNN defines floating point (f32, bf16, or f16) SDPA as follows. The blue
-nodes are required when defining an SDPA pattern while the brown parts are
-optional.
-
-![SDPA pattern](images/sdpa.png)
-
-1. The first MatMul calculates the dot products between Query and Key. See
-   [MatMul](@ref dev_guide_op_matmul) operation in Graph API.
-2. The Scale node is optional and is used to scale the output of the first
-   MatMul with a scaling factor. It can be constructed by [Multiply](@ref dev_guide_op_multiply)
-   or [Divide](@ref dev_guide_op_divide) operation in Graph API. The scaling
-   factor is given by users as an input of SDPA. \f$\sqrt{d_k}\f$ in the formula
-   is not considered as part of the SDPA pattern as it is constant.
-3. The Mask node is optional and is used to apply an attention mask to the
-   output of the previous Scale node. It can be constructed by [Add](@ref dev_guide_op_add)
-   or [Select](@ref dev_guide_op_select) operation in Graph API for different
-   mask policies (eg. causal mask or padding mask). When Add operation is used
-   to apply the mask, the input mask is usually an upper triangular matrix with
-   all the elements above the diagonal filled with `-inf` and zeroes elsewhere.
-   The `-inf` entries will become zero probability after Softmax is applied in
-   the next step. Alternately, a Select operation may be used. In this case, the
-   input is a boolean tensor (for example, with `true` on and below the
-   diagonal, and `false` above the diagonal). A `false` element in the mask
-   forces the corresponding element of the scaled output to `-inf`, while a
-   `true` element leaves it unchanged.
-
-   ![SDPA-mask-1](images/sdpa-mask-1.png) ![SDPA-mask-2](images/sdpa-mask-2.png)
-
-4. The SoftMax operation takes the masked output and transforms it into
-   probabilities between 0 and 1. See [SoftMax](@ref dev_guide_op_softmax)
-   operation in Graph API.
-5. The second MatMul calculates the dot products between the probabilities after
-   SoftMax and Value.
-6.
The Reorder node is optional and used to reshape or transpose the attention - output for cases where the attention output is transformed from shape (N, H, - S, D) to (N, S, H, D) or (N, S, H * D). The node can be constructed by the - combinations of [StaticTranspose](@ref dev_guide_op_statictranspose) and - [StaticReshape](@ref dev_guide_op_staticreshape) operation in Graph API. - - ![SDPA-Reorder](images/sdpa-reorder.png) - - -## Data types - -oneDNN supports the floating point SDPA pattern with data types f32, bf16, and -f16. oneDNN users can specify the data type via the input and output logical -tensors' data type fields for each operation. oneDNN does not support mixing -different floating data types in a floating point SDPA pattern. - -oneDNN supports the quantized SDPA pattern with int8-f32 mixed precision, -int8-bf16 mixed precision, and int8-f16 mixed precision data types. - -The definition of the data types and support status on different CPU and GPU -platforms follow the general description in @ref dev_guide_data_types. - -## Implementation limitations - -1. oneDNN primitive-based SDPA is implemented as the reference implementation on - both Intel Architecture Processors and Intel Graphics Products. In this case, - floating point SDPA patterns are usually implemented with f32/bf16/f16 matmul - (with post-ops) and softmax primitives, while quantized SDPA patterns are - implemented with int8 matmul (with post-ops) and f32/bf16/f16 softmax - primitives. -2. The SDPA patterns functionally supports all input shapes meeting the shape - requirements of each operation in the graph. For example, Add, Multiply, - Divide, and Select operations require the input tensors to have the same - shape or the shapes can be properly broadcasted based on the operation - attribute. -3. CPU - - Optimized implementation is available for 4D Q/K/V tensors with shape - defined as (N, H, S, D). - - Optimized implementation is available for OpenMP runtime and Threadpool - runtime on Intel Architecture Processors. - - Specifically for OpenMP runtime, the optimized implementation requires `N * - H > 2 * thread number` to get enough parallelism. -4. GPU - - Optimized implementation is available for 4D Q/K/V tensors with shape - defined as (N, H, S, D). - - Optimized implementation is available for floating point SDPA with `f16` - data type and `D <= 256` on Intel Graphics Products with Intel(R) Xe Matrix - Extensions (Intel(R) XMX) support. - -## Example - -oneDNN provides an [SDPA -example](https://github.com/oneapi-src/oneDNN/tree/main/examples/graph/sdpa.cpp) -demonstrating how to construct a typical floating point SDPA pattern with oneDNN -Graph API on CPU and GPU with different runtimes. - -## References - -[1] Attention is all you need, https://arxiv.org/abs/1706.03762v7 - -[2] oneDNN Graph API documentation, https://oneapi-src.github.io/oneDNN/graph_extension.html diff --git a/doc/graph/supported_patterns.md b/doc/graph/supported_patterns.md deleted file mode 100644 index 6118a088929..00000000000 --- a/doc/graph/supported_patterns.md +++ /dev/null @@ -1,159 +0,0 @@ -Supported Fusion Patterns {#dev_guide_graph_fusion_patterns} -============================================================ - -@anchor fusion_patterns -## Fusion Patterns - -The following fusion patterns are subgraphs that the oneDNN Graph API recognizes -as candidate for fusion. The patterns are described using oneDNN Graph -operation (op) names with the following convention. 
- -@note oneDNN Graph performs limited input validation to minimize the performance -overheads. The application is responsible for sanitizing inputs passed to the -library. For large u8 or s8 inputs may lead to accumulator overflow, you can use -floating point patterns instead of quantized patterns. - -`"+"` describes a chain of two ops. The preceding op produces an output tensor, -which is consumed by the following op as its first operand. - -`"[]"` describes a component of the overall pattern description. For example, -it could include a subgraph or all the op choices within the bracket. - -`"|"` describes choices of multiple operations, say A+[B|C] means the graph -partition contains A followed by B or C. - -`","` describes a graph composed of multiple subgraphs, each subgraph marks its -output tensor explicitly, which is consumed by other subgraphs. - -`Superscript` denotes the numbers of repetition pattern. For example, -A+[B|C]\f$^{3}\f$ means the graph partition contains A followed by three ops, -each of them is either B or C. The superscript could be a range of number -meaning allowing a range of repetition. If the range is between 0 and 1, we use -superscript `"?"`. - -`Subscript` denotes the input and output tensors which need to explicitly mark -the producer and consumer relation within one graph partition. For example, -A\f$_{>t1}\f$+B+C\f$_{"` refers to the output -tensor, and `"<"` for input tensor. Input and output tensor between neighbor -ops are not explicitly marked, for example, B consumes t1 implicitly in the -example above. - -Subscript `"out"` marks the output tensor of a certain op to be the output of -a graph partition. For example, in -A\f$_{>t1}\f$+B\f$_{>out}\f$+C\f$_{out}\f$, B's output and C's output -are marked as output tensors. - -Subscript `"in"` marks the input tensor of a certain op to be the input of a -graph partition. For example, in A\f$_{t1}\f$+B+C\f$_{out}\f$ | This pattern is widely used in Convolution Neural Networks, for example ResNet, ResNext, SSD, etc. | -| ConvTranspose + BiasAdd\f$^?\f$ + [Unary \| Binary]\f$^{0-3}\f$\f$_{>out}\f$ | This pattern is widely used in Generative Adversarial Networks. | -| Interpolate + [Unary \| Binary]\f$^{0-3}\f$\f$_{>out}\f$ | This pattern is widely used for image processing. | -| MatMul + BiasAdd\f$^?\f$ + [Unary \| Binary]\f$^{0-3}\f$ + Select\f$^?\f$\f$_{>out}\f$ | This pattern is widely used in language models and recommendation models, for example BERT, DLRM, etc. | -| Reduction + [Unary \| Binary]\f$^{0-3}\f$\f$_{>out}\f$ | This pattern is widely used for data processing, for example loss reduction. | -| Unary + Binary\f$^{0-3}\f$\f$_{>out}\f$ | This pattern is widely used in Convolution Neural Networks. | -| Binary + [Unary \| Binary]\f$^{0-3}\f$\f$_{>out}\f$ | This pattern is widely used in Generative Adversarial Networks, for example ParallelWaveGAN. | -| [AvgPool \| MaxPool] + Binary\f$^{0-3}\f$\f$_{>out}\f$ | This pattern is widely used in Convolution Neural Networks. | -| BatchNormInference + ReLU\f$_{>out}\f$ | This pattern is widely used in Convolution Neural Networks, for example DenseNet. | -| Reciprocal + Multiply\f$_{>out}\f$ | N/A | -| Reorder + Add\f$_{>out}\f$ | N/A | -| Scaled Dot-Product Attention | Refer to @ref dev_guide_graph_sdpa for more details. 
|
-
-#### Quantized Patterns
-
-| Pattern | Description |
-|:--------|:-----------------------------|
-| Quantize\f$^?\f$ + Dequantize\f$_{>t1}\f$, Dequantize\f$_{>t2}\f$\f$^{0-3}\f$, Dequantize + Convolution\f$_{<t1}\f$ + BiasAdd\f$^?\f$ + [Unary \| Binary\f$_{<t2}\f$]\f$^{0-3}\f$ + Quantize\f$_{>out}\f$ | N/A |
-| Quantize\f$^?\f$ + Dequantize\f$_{>t1}\f$, Dequantize\f$_{>t2}\f$\f$^{0-3}\f$, Dequantize + ConvTranspose\f$_{<t1}\f$ + BiasAdd\f$^?\f$ + [Unary \| Binary\f$_{<t2}\f$]\f$^{0-3}\f$ + Quantize\f$_{>out}\f$ |N/A |
-| Quantize\f$^?\f$ + Dequantize\f$_{>t1}\f$, Dequantize\f$_{>t2}\f$\f$^{0-3}\f$, Dequantize + MatMul\f$_{<t1}\f$ + BiasAdd\f$^?\f$ + [Unary \| Binary\f$_{<t2}\f$]\f$^{0-3}\f$ + Quantize\f$_{>out}\f$ |N/A |
-| Dequantize + [AvgPool \| MaxPool] + Quantize\f$_{>out}\f$ |N/A |
-| Dequantize\f$_{>t1}\f$, Dequantize + [AvgPool \| MaxPool] + Add\f$_{<t1}\f$ + Quantize\f$_{>out}\f$ |N/A |
-| Dequantize + Reorder + Quantize\f$_{>out}\f$ |N/A |
-| Dequantize\f$_{>t1}\f$, Dequantize + Reorder + Add\f$_{<t1}\f$ + Quantize\f$_{>out}\f$ |N/A |
-| [SoftMax \| LayerNorm \| GroupNorm] + [Unary \| Binary]\f$^{0-3}\f$ + Quantize\f$_{>out}\f$ | This pattern is used in SmoothQuant to fuse scales and quantization into previous layers |
-
-### Training
-
-| Pattern | Description |
-|:--------|:-----------------------------|
-| ConvolutionBackwardWeights + BiasAddBackward\f$_{>out}\f$ | N/A |
-| ReLUBackward + BatchNormTrainingBackward\f$_{>out}\f$ |N/A |
-
-All the above fusion patterns are supported by default.
-
-## Aggressive Fusion Patterns
-Aggressive fusion patterns also follow the pattern description convention
-defined in the [Fusion Patterns](@ref fusion_patterns) section.
-
-@note Aggressive fusion patterns are only supported when
-[Graph Compiler](@ref dev_guide_graph_compiler) is enabled.
-
-The following categories will also be used to describe aggressive fusion
-patterns.
-
-- ReshapeTranspose = [StaticReshape + StaticTranspose\f$^{1-2}\f$]
-
-- Activation = [ReLU \| Sigmoid \| GELU]
-
-- ActivationBackward = [ReLUBackward \| SigmoidBackward \| GELUBackward]
-
-### Inference
-
-#### Floating Point Patterns
-
-| Pattern | Description |
-|:--------|:-----------------------------|
-| MatMul + [Multiply \| Divide] + Add + Softmax + MatMul + StaticTranspose + Reorder\f$_{>out}\f$ | Multi-head Attention. This pattern is widely used in models containing encoder-decoder structures, for example BERT. |
-| ReshapeTranspose\f$_{>t1}\f$, ReshapeTranspose\f$_{>t2}\f$, ReshapeTranspose + MatMul\f$_{<t1}\f$ + [Multiply \| Divide] + Add + Softmax + MatMul\f$_{<t2}\f$ + StaticTranspose + Reorder\f$_{>out}\f$ | Multi-head Attention. |
-| MatMul + Activation\f$_{>t1}\f$, [MatMul\f$_{<t1}\f$ + Activation\f$_{>t1}\f$]\f$^{0-4}\f$, MatMul\f$_{<t1}\f$ + Activation\f$_{>out}\f$ | Multi-layer Perceptron. This pattern is widely used in recommendation models, for example DLRM. |
-| [Convolution + BiasAdd\f$^{?}\f$ + ReLU]\f$^{1-3}\f$ + Convolution + BiasAdd\f$^{?}\f$ + Add + ReLU\f$_{>out}\f$ | Identical Bottleneck. Enabled only in single thread runtime scenario. This pattern is widely used in Convolution Neural Networks, for example ResNet. |
-| Convolution + BiasAdd\f$^{?}\f$\f$_{>t1}\f$, [Convolution + BiasAdd\f$^{?}\f$ + ReLU]\f$^{1-3}\f$ + Convolution + BiasAdd\f$^{?}\f$ + Add\f$_{<t1}\f$ + ReLU\f$_{>out}\f$ | Convolutional Bottleneck. Enabled only in single thread runtime scenario. This pattern is widely used in Convolution Neural Networks, for example ResNet. |
-
-#### Quantized Patterns
-
-| Pattern | Description |
-|:--------|:-----------------------------|
-| Dequantize\f$_{>t1}\f$, Dequantize\f$_{>t2}\f$, Dequantize + MatMul\f$_{<t1}\f$ + [Multiply \| Divide] + Add + Softmax + Quantize + Dequantize + MatMul\f$_{<t2}\f$ + StaticTranspose + Reorder + Quantize\f$_{>out}\f$ | Quantized Multi-head Attention. |
-| Dequantize + ReshapeTranspose\f$_{>t1}\f$, Dequantize + ReshapeTranspose\f$_{>t2}\f$, Dequantize + MatMul\f$_{<t1}\f$ + [Multiply \| Divide] + Add + Softmax + Quantize + Dequantize + MatMul\f$_{<t2}\f$ + StaticTranspose + Reorder + Quantize\f$_{>out}\f$ | Quantized Multi-head Attention.
| -| Dequantize\f$_{>t1}\f$, Dequantize + MatMul\f$_{t2}\f$, [Dequantize\f$_{>t3}\f$, Dequantize\f$_{t2}\f$]\f$^{0-4}\f$, Dequantize\f$_{>t4}\f$, Dequantize\f$_{out}\f$ | Quantized Multi-layer Perceptron. | -| Dequantize\f$_{>t2}\f$, Dequantize\f$_{>t3}\f$, [Dequantize\f$_{>t1}\f$, Dequantize + Convolution\f$_{out}\f$ | Quantized Identical Bottleneck. Enabled only in single thread runtime scenario. | -| [Dequantize\f$_{>t1}\f$, Dequantize + Convolution\f$_{t2}\f$, Dequantize\f$_{>t4}\f$, [Dequantize\f$_{>t3}\f$, Dequantize + Convolution\f$_{out}\f$ | Quantized Convolutional Bottleneck. Enabled only in single thread runtime scenario. | - -### Training - -| Pattern | Description | -|:--------|:-----------------------------| -| Dequantize\f$_{>t1}\f$, Dequantize\f$_{>t2}\f$, Dequantize + MatMul\f$_{out}\f$ | Multi-head Attention Training Forward Pattern. | -| StaticReshape + StaticTranspose\f$_{>t1}\f$ + MatMul + Multiply\f$_{>t2}\f$ + Subtract\f$_{t4}\f$ + MatMul\f$_{>out1}\f$, Multiply\f$_{t3}\f$, MatMul\f$_{out2}\f$, MatMul\f$_{out3}\f$ | Multi-head Attention Training Backward Pattern. | -| MatMul\f$_{>out1}\f$ + Activation\f$_{>t1,>out2}\f$, [MatMul\f$_{out3}\f$ + Activation\f$_{>t1,>out4}\f$]\f$^{0-4}\f$, MatMul\f$_{out5}\f$ + Activation\f$_{>out6}\f$ | Multi-layer Perceptron Training Forward Pattern. | -| StaticTranspose\f$^{?}\f$\f$_{>t0}\f$, ActivationBackward\f$_{>t2}\f$ + MatMul\f$_{t1}\f$, ReduceSum\f$^{?}\f$\f$_{out1}\f$, StaticTranspose\f$^{?}\f$ + MatMul\f$_{out2}\f$, [StaticTranspose\f$^{?}\f$\f$_{>t3}\f$, ActivationBackward\f$_{>t4,t1}\f$, ReduceSum\f$^{?}\f$\f$_{out3}\f$, StaticTranspose\f$^{?}\f$ + MatMul\f$_{out4}\f$]\f$^{0-4}\f$, StaticTranspose\f$^{?}\f$\f$_{>t5}\f$, ActivationBackward\f$_{>t6,out5}\f$, ReduceSum\f$^{?}\f$\f$_{out6}\f$, StaticTranspose\f$^{?}\f$ + MatMul\f$_{out7}\f$ | Multi-layer Perceptron Training Backward Pattern. | -| Convolution\f$_{>out1}\f$ + BatchNormForwardTraining\f$_{>out2}\f$ + ReLU\f$_{>out3}\f$ + Convolution\f$_{>out4}\f$ + BatchNormForwardTraining\f$_{>out5}\f$ + ReLU\f$_{>out6}\f$ + Convolution\f$_{>out7}\f$ + BatchNormForwardTraining\f$_{>out8}\f$ + Add + ReLU\f$_{>out9}\f$ | Identical Bottleneck Training Forward Pattern. | -| Convolution\f$_{>out1}\f$ + BatchNormForwardTraining\f$_{>t1,>out2}\f$, Convolution\f$_{>out3}\f$ + BatchNormForwardTraining\f$_{>out4}\f$ + ReLU\f$_{>out5}\f$ + Convolution\f$_{>out6}\f$ + BatchNormForwardTraining\f$_{>out7}\f$ + ReLU\f$_{>out8}\f$ + Convolution\f$_{>out9}\f$ + BatchNormForwardTraining\f$_{>out10}\f$ + Add\f$_{out11}\f$ | Convolutional Bottleneck Training Forward Pattern. | -| ReLUBackward\f$_{>t1}\f$ + BatchNormTrainingBackward\f$_{>t2,>out1}\f$ + ConvolutionBackwardData + ReLUBackward + BatchNormTrainingBackward\f$_{>t3,>out2}\f$ + ConvolutionBackwardData + ReLUBackward + BatchNormTrainingBackward\f$_{>t4,>out3}\f$ + ConvolutionBackwardData + Add\f$_{out4}\f$, ConvolutionBackwardWeights\f$_{out5}\f$, ConvolutionBackwardWeights\f$_{out6}\f$, ConvolutionBackwardWeights\f$_{out7}\f$ | Identical Bottleneck Training Backward Pattern. 
| -| ReLUBackward\f$_{>t1}\f$ + BatchNormTrainingBackward\f$_{>t2,>out1}\f$ + ConvolutionBackwardData + ReLUBackward + BatchNormTrainingBackward\f$_{>t3,>out2}\f$ + ConvolutionBackwardData + ReLUBackward + BatchNormTrainingBackward\f$_{>t4,>out3}\f$ + ConvolutionBackwardData + Add\f$_{out4}\f$, BatchNormTrainingBackward\f$_{t5,>out5}\f$ + ConvolutionBackwardData\f$_{>t6}\f$, ConvolutionBackwardWeights\f$_{out6}\f$, ConvolutionBackwardWeights\f$_{out7}\f$, ConvolutionBackwardWeights\f$_{out8}\f$, ConvolutionBackwardWeights\f$_{out9}\f$ | Convolutional Bottleneck Training Backward Pattern. | diff --git a/doc/performance_considerations/benchdnn.md b/doc/performance_considerations/benchdnn.md index 36d8124d059..56fc9628e1c 100644 --- a/doc/performance_considerations/benchdnn.md +++ b/doc/performance_considerations/benchdnn.md @@ -4,4 +4,4 @@ Benchmarking Performance {#dev_guide_benchdnn} oneDNN has a built-in benchmarking program called benchdnn. For a complete description of the available options and working examples, see -the [benchdnn readme](https://github.com/oneapi-src/oneDNN/blob/master/tests/benchdnn/README.md#benchdnn). +the [benchdnn readme](https://github.com/uxlfoundation/oneDNN/blob/main/tests/benchdnn/README.md#benchdnn). diff --git a/doc/performance_considerations/verbose.md b/doc/performance_considerations/verbose.md index ca1f02e95e5..9696125237b 100644 --- a/doc/performance_considerations/verbose.md +++ b/doc/performance_considerations/verbose.md @@ -27,6 +27,7 @@ the type of tracing information to display. |:---------------------------|:--------------------|:--------------------------------------------------| | `ONEDNN_VERBOSE` | `none` | no messages printed | | \ | **`error`** | **error messages** (default) | +| \ | `warn` | warning messages | | \ | `check` | primitive creation parameter checking information | | \ | `profile_create` | primitive creation timings | | \ | `profile_exec` | primitive execution timings | @@ -150,7 +151,7 @@ Above, we can see that the highest performance implementations were not dispatched either because they required a higher ISA, or because they did not support that datatype configuration. A complete list of verbose messages encountered in the dispatch mode -can be found [here](https://oneapi-src.github.io/oneDNN/dev_guide_verbose_table.html) along with their explanation. +can be found [here](https://uxlfoundation.github.io/oneDNN/dev_guide_verbose_table.html) along with their explanation. ### Enable ONEDNN_VERBOSE with timestamps @@ -195,6 +196,7 @@ Each subsequent line of primitive verbose information is formatted as a comma-separated list and contains the following, in order of appearance in the line from left to right: * `onednn_verbose` marker string +* verbose mode version: `v0` or `v1` * if `ONEDNN_VERBOSE_TIMESTAMP=1` is specified, start time of the call. On Linux this number represents amount of milliseconds since Unix epoch. On Windows this number represents amount of milliseconds since the last system start. @@ -238,7 +240,7 @@ primitive execution. @note When oneDNN verbose mode is enabled for builds with -[Compute Library for the Arm architecture](https://oneapi-src.github.io/oneDNN/dev_guide_build.html#gcc-with-arm-compute-library-acl-on-aarch64-host), +[Compute Library for the Arm architecture](https://uxlfoundation.github.io/oneDNN/dev_guide_build.html#gcc-with-arm-compute-library-acl-on-aarch64-host), any failures in the validation of Compute Library primitives will be detailed in the verbose output. 
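For a concrete starting point, the following minimal sketch shows verbose output
being enabled from application code rather than through the environment. It
assumes the `dnnl::set_verbose()` helper declared in `dnnl.hpp` (levels mirror
the classic `ONEDNN_VERBOSE` settings: `0` for none, `1` for execution
profiling, `2` for creation plus execution profiling); setting the
`ONEDNN_VERBOSE` environment variable before launching the application remains
the non-intrusive alternative.

~~~cpp
// Minimal sketch, assuming dnnl::set_verbose() is available in this build.
#include "dnnl.hpp"

int main() {
    dnnl::set_verbose(2); // 0: none, 1: exec profiling, 2: create + exec

    dnnl::engine eng(dnnl::engine::kind::cpu, 0);
    dnnl::stream strm(eng);

    // A tiny eltwise primitive is enough to trigger onednn_verbose lines in
    // the format described above.
    dnnl::memory::desc md({16}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::a);
    auto pd = dnnl::eltwise_forward::primitive_desc(eng,
            dnnl::prop_kind::forward_inference,
            dnnl::algorithm::eltwise_relu, md, md, 0.f, 0.f);
    dnnl::memory mem(md, eng);
    dnnl::eltwise_forward(pd).execute(
            strm, {{DNNL_ARG_SRC, mem}, {DNNL_ARG_DST, mem}});
    strm.wait();
    return 0;
}
~~~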
diff --git a/doc/performance_considerations/verbose_table.md b/doc/performance_considerations/verbose_table.md
index 34d678aca08..ace5d59afd1 100644
--- a/doc/performance_considerations/verbose_table.md
+++ b/doc/performance_considerations/verbose_table.md
@@ -33,6 +33,7 @@ The following catalogue lists verbose messages, explanations, and additional inf
 |`<t> has a bad number of dimensions <ndims>` |`t`- tensor, `ndims`- number of tensor dimensions | all | Tensor data has bad or invalid number of dimensions for the current primitive operation. **Example**: The `convolution` primitive expects only 1D-, 2D- or 3D-spatial tensors for operations and prints this message for any other data with higher dimensions. |
 |`bad dimensions <t>:<axis>` |`t`- tensor, `axis`- axis | all | Tensor `<t>` has an invalid dimension along the specified axis. **Example**: The `concat` primitive prints this message when the destination tensor dimension along the concatenated axis does not match the sum of the dimensions of the concatenated tensors. |
 |`dimension <t0>:<a0> is inconsistent with <t1>:<a1>` |`t0, t1` - tensors, `a0, a1` - tensor axes | all | Tensors `t0, t1` have inconsistent dimensions along axes `a0` and `a1` respectively. **Example**: This is encountered for the `matmul` primitive when the input matrices have mismatching dimensions. |
+|`out-of-range dimensions for <t>` |`t` - tensor | all | One of the dimensions of tensor `t` is beyond the maximum range that can be processed by the current implementation. |
 |`tensors <t0> and <t1> have inconsistent number of dimensions` |`t0, t1` - tensors | all | Tensors `t0, t1` have inconsistent dimensions for primitive operation. |
 |`tensors <t0> and <t1> have inconsistent datatypes` |`t0, t1` - tensors | all | Tensors `t0, t1` have inconsistent data types for primitive operation. |
 |**Unsupported Combinations** | | | |
@@ -53,22 +54,25 @@ The following catalogue lists verbose messages, explanations, and additional inf
 |`alpha and beta parameters are not properly set` | | `eltwise` | Alpha and beta parameters are not properly set for the elementwise algorithm. |
 |`large shapes fall back` | | `gemm` | Heuristic to skip current implementation for large tensor shapes for better performance. |
 |`only trivial strides are supported` | | `gemm`, `rnn` | Current implementation for the primitive does not process non-trivial stride values. |
-|`unsupported fpmath mode` | | `matmul` | [Floating-point math mode](https://oneapi-src.github.io/oneDNN/group_dnnl_api_fpmath_mode.html?highlight=math%20mode) is not supported by the current primitive implementation. |
+|`unsupported fpmath mode` | | `matmul` | [Floating-point math mode](https://uxlfoundation.github.io/oneDNN/group_dnnl_api_fpmath_mode.html?highlight=math%20mode) is not supported by the current primitive implementation. |
 |`small shapes fall back` | | `matmul` | Heuristic to skip current implementation for small tensor shapes for better performance. |
 |`incompatible gemm format` | | `matmul`, `ip` | Specified GeMM format is incompatible with the current primitive implementation. |
 |`unsupported <t> tensor layout` |`t` - tensor | `reorder` | The data layout for the source/destination tensor is not supported by the current implementation. |
 |`bad axis` | | `softmax`, `shuffle` | Bad or invalid axis specified for softmax/shuffle operation. |
 |`unsupported <d> architecture` | `d` - `dnnl::engine::kind` | `gemm` | Unsupported architecture for specified device-type. Typically encountered when current GPU device does not support the primitive. |
 |**Miscellaneous** | | | |
-|`failed to create nested primitive <pm>` |`pm` - `dnnl::primitive` | all | Descriptor initialization for the nested primitive implementation was unsuccessful. |
+|`failed to create nested primitive` |`pm` - `dnnl::primitive` | all | Descriptor initialization for the nested primitive implementation was unsuccessful. |
 |`failed to create descriptor` |`pm` -`dnnl::primitive`, `dnnl::memory` | all | Descriptor initialization for the primitive or memory object was unsuccessful. |
-|`bad accumulation mode` | | all | Bad or invalid [accumulation mode](https://oneapi-src.github.io/oneDNN/enum_dnnl_accumulation_mode.html) specified for primitive attribute `dnnl::primitive_attr`. |
+|`bad accumulation mode` | | all | Bad or invalid [accumulation mode](https://uxlfoundation.github.io/oneDNN/enum_dnnl_accumulation_mode.html) specified for primitive attribute `dnnl::primitive_attr`. |
 |`unsupported md flag` |`t` - tensor | all | Bad or unsupported flags specified for the memory descriptor `dnnl::memory::desc`. |
 |`problem is not mathematically consistent` | | all | *(self-explanatory)* |
 |`workspace mismatch between forward and backward primitive descriptors`| | all | *(self-explanatory)* |
-|`workspace initialization failed` | | all | [Workspace](https://oneapi-src.github.io/oneDNN/dev_guide_inference_and_training_aspects.html?highlight=workspace#workspace) descriptor initialization was unsuccessful during primitive creation. |
+|`workspace initialization failed` | | all | [Workspace](https://uxlfoundation.github.io/oneDNN/dev_guide_inference_and_training_aspects.html?highlight=workspace#workspace) descriptor initialization was unsuccessful during primitive creation. |
 |`invalid datatype for <t>` |`t` - tensor | all | The data type for the tensor/data processed by the primitive is invalid. **Example**: This is encountered when an undefined data type `data_type::undef` is specified for the accumulator. |
-|`failed to run kernel deterministically` | | all | failed to run application in the [deterministic mode](https://oneapi-src.github.io/oneDNN/dev_guide_attributes_deterministic.html?highlight=deterministic). |
+|`failed to run kernel deterministically` | | all | failed to run application in the [deterministic mode](https://uxlfoundation.github.io/oneDNN/dev_guide_attributes_deterministic.html?highlight=deterministic). |
+|`skipping or dispatching to another implementation` | | all | *(self-explanatory)* |
+|`failed to create kernel <k>` |`k` - kernel name | all | *(self-explanatory)* |
+

 ## Engine Creation

@@ -76,13 +80,13 @@ The following catalogue lists verbose messages, explanations, and additional inf
 |:-----------------------------------------------------|:----------|:------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------|
 |`bad engine kind` | | all | Invalid value for `dnnl::engine::kind` encountered during engine creation. |
 |`invalid <d> device in environment: index <i>` |`d` - `dnnl::engine::kind`, `i` - device index | all | Device of type `dnnl::engine::kind` and index `i` is invalid for the current environment. |
-|`no <d> device is available` |`d` - `dnnl::engine::kind` | all | No device of type `dnnl::engine::kind` was found during engine creation. |
-|`<n> <d> devices are available but <i> was queried` |`d` - `dnnl::engine::kind`, `n` - number of `d` devices, `i` - queried device index | all | Queried index is out-of-range for device of type `dnnl::engine::kind`.
 |
+|`no <d> device is available` |`d` - `dnnl::engine::kind`, `k` - `dnnl::impl::runtime_kind` | all | No device of type `dnnl::engine::kind` was found during engine creation. |
+|`<n> <d> devices are available but <i> device index was queried` |`d` - `dnnl::engine::kind`, `n` - number of `d` devices, `i` - queried device index | all | Queried index is out-of-range for device of type `dnnl::engine::kind`. |
 |`device not found in the given context` | | all | *(self-explanatory)* |
 |`unsupported <d> platform (expected <d0> got <d1>)` |`d` - `dnnl::engine::kind`, `d0` - queried platform, `d1` - available platform | `sycl`, `opencl` | Unsupported device platform encountered during engine creation. |
-|`failed to create <d> engine with index <i>` |`d` - `dnnl::engine::kind`, `i` - device index |all | Engine creation was unsuccessful for specified device index and kind. |
+|`failed to create <d> engine with index <i>` |`d` - `dnnl::engine::kind`, `i` - device index |all | Engine creation was unsuccessful for the specified device index and kind. |
 |`unsupported backend` |`d` - `dnnl::engine::kind` | `sycl` | *(self-explanatory)* |
-|`profiling capabilities are not supported` | | all | Experimental profiling ([ONEDNN_EXPERIMENTAL_PROFILING](https://oneapi-src.github.io/oneDNN/dev_guide_experimental.html?highlight=profiling#onednn-experimental-profiling)) is not enabled for the application. |
+|`profiling capabilities are not supported` | | all | Experimental profiling ([ONEDNN_EXPERIMENTAL_PROFILING](https://uxlfoundation.github.io/oneDNN/dev_guide_experimental.html?highlight=profiling#onednn-experimental-profiling)) is not enabled for the application. |

 ## Memory Creation and Related Operations

@@ -92,6 +96,6 @@ The following catalogue lists verbose messages, explanations, and additional inf
 |`bad arguments for memory descriptor` | Bad or unsupported values passed to the memory descriptor `dnnl::memory::desc` during memory object creation. |
 |`invalid memory index` | An out-of-range value encountered for memory handle during data mapping. |
 |`unsupported memory stride` | Memory descriptor initialization failed due to unsupported value for memory strides. |
-|`scratchpad memory limit exceeded` | [Scratchpad](https://oneapi-src.github.io/oneDNN/dev_guide_attributes_scratchpad.html?highlight=scratchpad) space is exhausted during GEMM kernel initialization. |
+|`scratchpad memory limit exceeded` | [Scratchpad](https://uxlfoundation.github.io/oneDNN/dev_guide_attributes_scratchpad.html?highlight=scratchpad) space is exhausted during GEMM kernel initialization. |
 |`scratchpad initialization unsuccessful` | *(self-explanatory)* |

diff --git a/doc/primitives/batch_normalization.md b/doc/primitives/batch_normalization.md
index a85f6ac442a..6fddd954fab 100644
--- a/doc/primitives/batch_normalization.md
+++ b/doc/primitives/batch_normalization.md
@@ -103,8 +103,8 @@ requires different inputs and outputs. For clarity, a summary is shown below.
 | #dnnl_normalization_flags_none | *Inputs*: \src <br><br> *Outputs*: \dst | *Inputs*: \src <br><br> *Outputs*: \dst, \f$\mu\f$, \f$\sigma^2\f$ | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \diffsrc | Same as for #dnnl_backward |
 | #dnnl_use_global_stats | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \dst | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \diffsrc | Same as for #dnnl_backward |
 | #dnnl_use_scale | *Inputs*: \src, \f$\gamma\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\gamma\f$ <br><br> *Outputs*: \dst, \f$\mu\f$, \f$\sigma^2\f$ | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$ <br><br> *Outputs*: \diffsrc, \f$\diffgamma\f$ | Not supported |
-| #dnnl_use_shift | *Inputs*: \src, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\beta\f$ <br><br> *Outputs*: \dst, \f$\mu\f$, \f$\sigma^2\f$ | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\beta\f$ <br><br> *Outputs*: \diffsrc, \f$\diffbeta\f$ | Not supported |
-| #dnnl_use_global_stats \| #dnnl_use_scale \| #dnnl_use_shift | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \diffsrc, \f$\diffgamma\f$, \f$\diffbeta\f$ | Not supported |
+| #dnnl_use_shift | *Inputs*: \src, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\beta\f$ <br><br> *Outputs*: \dst, \f$\mu\f$, \f$\sigma^2\f$ | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \diffsrc, \f$\diffbeta\f$ | Not supported |
+| #dnnl_use_global_stats \| #dnnl_use_scale \| #dnnl_use_shift | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$ <br><br> *Outputs*: \diffsrc, \f$\diffgamma\f$, \f$\diffbeta\f$ | Not supported |
 | `flags` \| #dnnl_fuse_norm_relu | *Inputs*: same as with `flags` <br><br> *Outputs*: same as with `flags` | *Inputs*: same as with `flags` <br><br> *Outputs*: same as with `flags`, [Workspace](@ref dev_guide_inference_and_training_aspects_workspace) | *Inputs*: same as with `flags`, [Workspace](@ref dev_guide_inference_and_training_aspects_workspace) <br><br> *Outputs*: same as with `flags` | Same as for #dnnl_backward if `flags` do not contain #dnnl_use_scale or #dnnl_use_shift; not supported otherwise |
 | `flags` \| #dnnl_fuse_norm_add_relu | *Inputs*: same as with `flags` and \f$\src_1\f$ for fused binary addition <br><br> *Outputs*: same as with `flags` | *Inputs*: same as with `flags` and \f$\src_1\f$ for fused binary addition <br><br> *Outputs*: same as with `flags`, [Workspace](@ref dev_guide_inference_and_training_aspects_workspace) | *Inputs*: same as with `flags`, [Workspace](@ref dev_guide_inference_and_training_aspects_workspace) <br><br> *Outputs*: same as with `flags` and \f$\diffsrc_1\f$ for fused binary addition | Same as for #dnnl_backward if `flags` do not contain #dnnl_use_scale or #dnnl_use_shift; not supported otherwise |

@@ -193,7 +193,7 @@ If #dnnl_use_scale or #dnnl_use_shift are used, the scale
 (\f$\gamma\f$) and shift (\f$\beta\f$) are separate 1D tensors of shape
 \f$C\f$.

-The format of the corresponding memory object must be #dnnl_nc (#dnnl_ab).
+The format of the corresponding memory object must be #dnnl_a.

 #### Source, Destination, and Their Gradients
diff --git a/doc/primitives/binary.md b/doc/primitives/binary.md
index 81b8bed6157..47f50248d31 100644
--- a/doc/primitives/binary.md
+++ b/doc/primitives/binary.md
@@ -16,9 +16,18 @@ between tensors source 0 and source 1 (the variable names follow the standard
     \src_0(\overline{x}) \mathbin{op} \src_1(\overline{x}),
 \f]

-where \f$op\f$ is one of addition, subtraction, multiplication, division,
-greater than or equal to, greater than, less than or equal to, less than,
-equal to, not equal to, get maximum value, and get minimum value.
+where \f$op\f$ is one of the following operators: addition (\f$+\f$),
+subtraction (\f$-\f$), multiplication (\f$\times\f$), division (\f$\div\f$),
+greater than or equal to (\f$\geq\f$), greater than (\f$>\f$),
+less than or equal to (\f$\leq\f$), less than (\f$<\f$), equal to (\f$=\f$),
+not equal to (\f$\neq\f$), get maximum value (\f$\max(\cdot)\f$),
+get minimum value (\f$\min(\cdot)\f$), and the conditional select operation.
+For the conditional select operation, the binary primitive uses a third input
+tensor \f$\src_2\f$ to select between the two source tensors:
+
+\f[
+    \dst[i] = \src_2[i] ? \src_0[i] : \src_1[i]
+\f]

 The binary primitive does not have a notion of forward or backward
 propagations.
@@ -31,6 +40,7 @@ argument index as specified by the following table.
 |-----------------------------|---------------------------------------------------------------------------|
 | \f$\src_0\f$ | DNNL_ARG_SRC_0 |
 | \f$\src_1\f$ | DNNL_ARG_SRC_1 |
+| \f$\src_2\f$ | DNNL_ARG_SRC_2 |
 | \dst | DNNL_ARG_DST |
 | \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
 | \f$binary scale0\f$ | DNNL_ARG_ATTR_SCALES \| DNNL_ARG_SRC_0 |
@@ -65,6 +75,10 @@ argument index as specified by the following table.
   be overwritten. In-place mode requires the \dst and source 0 data types to be
   the same. Different data types will unavoidably lead to correctness issues.

+ * For the binary select operation, broadcast semantics are not supported for
+   the third conditional input tensor. For this case, the dimensions and layout
+   of the conditional input tensor must match those of the source 0 tensor.
+
 ### Post-Ops and Attributes

 The following attributes are supported:
@@ -80,6 +94,8 @@ The following attributes are supported:

 The source and destination tensors may have `f32`, `bf16`, `f16`, `s32` or
 `s8/u8` data types.
+For the binary select operation, the conditional input tensor can only be
+of `s8` data type.

 The binary primitive supports the following combinations of data types:

 | Source 0 / 1 | Destination |
@@ -106,7 +122,6 @@ meaning associated with any of tensors dimensions.

 2. **GPU**
    - Only tensors of 6 or fewer dimensions are supported.
-   - s32 data type is not supported.
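To make the select semantics above concrete, here is a hedged sketch of
creating and executing the operation. It assumes the select algorithm is
exposed as `dnnl::algorithm::binary_select` and that
`dnnl::binary::primitive_desc` provides an overload accepting the conditional
tensor descriptor; verify both against the release you build with.

~~~cpp
// Hedged sketch: dst[i] = src_2[i] ? src_0[i] : src_1[i].
#include "dnnl.hpp"
using namespace dnnl;

void binary_select_sketch(engine &eng, stream &strm) {
    memory::dims dims = {2, 16, 8, 8};
    memory::desc src0_md(dims, memory::data_type::f32, memory::format_tag::nchw);
    memory::desc src1_md(dims, memory::data_type::f32, memory::format_tag::nchw);
    // The conditional tensor must be s8 and match src_0's dimensions/layout.
    memory::desc cond_md(dims, memory::data_type::s8, memory::format_tag::nchw);
    memory::desc dst_md(dims, memory::data_type::f32, memory::format_tag::nchw);

    // Assumed overload with a third source descriptor; may differ per release.
    auto pd = binary::primitive_desc(eng, algorithm::binary_select,
            src0_md, src1_md, cond_md, dst_md);

    memory src0(src0_md, eng), src1(src1_md, eng);
    memory cond(cond_md, eng), dst(dst_md, eng);
    binary(pd).execute(strm,
            {{DNNL_ARG_SRC_0, src0}, {DNNL_ARG_SRC_1, src1},
                    {DNNL_ARG_SRC_2, cond}, {DNNL_ARG_DST, dst}});
    strm.wait();
}
~~~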
## Performance Tips
diff --git a/doc/primitives/convolution.md b/doc/primitives/convolution.md
index 069c781cc03..205ba5d240e 100644
--- a/doc/primitives/convolution.md
+++ b/doc/primitives/convolution.md
@@ -100,6 +100,12 @@ Here:
 - \f$OW = \left\lfloor{\frac{IW - DKW + PW_L + PW_R}{SW}} \right\rfloor + 1,\f$
   where \f$DKW = 1 + (KW - 1) \cdot (DW + 1)\f$.

+@note In oneDNN, convolution without dilation is defined by setting the dilation
+parameters to `0`. This differs from PyTorch and TensorFlow, where a non-dilated
+case corresponds to a dilation value of `1`. As a result, the PyTorch and
+TensorFlow dilation parameters need to be adjusted by subtracting `1` (for example,
+\f$DH_{onednn} = DH_{torch} - 1\f$ and \f$DW_{onednn} = DW_{torch} - 1\f$).
+
 #### Deconvolution (Transposed Convolution)

 Deconvolutions (also called fractionally strided convolutions or transposed
@@ -160,22 +166,22 @@ N/A.

 Convolution primitive supports the following combination of data types for
 source, destination, and weights memory objects:

-| Propagation | Source | Weights | Destination | Bias |
-|:---------------|:----------|:-------------|:----------------------------|:----------------------------|
-| forward | f32 | f32 | f32, u8, s8 | f32 |
-| forward | f16 | f16 | f16, f32, u8, s8 | f16, f32 |
-| forward | u8, s8 | s8 | u8, s8, s32, f32, f16, bf16 | u8, s8, s32, f32, f16, bf16 |
-| forward | bf16 | bf16 | f32, bf16 | f32, bf16 |
-| forward | f8_e5m2 | f8_e5m2 | f8_e5m2, f32, f16, bf16 | f32 |
-| forward | f64 | f64 | f64 | f64 |
-| backward | f32, bf16 | bf16 | bf16 | |
-| backward | f32, f16 | f16 | f16 | |
-| backward | f8_e5m2 | f8_e5m2 | f8_e5m2 | |
-| backward | f32 | f32 | f32 | f32 |
-| backward | f64 | f64 | f64 | f64 |
-| weights update | bf16 | f32, bf16 | bf16, s8, u8 | f32, bf16 |
-| weights update | f16 | f32, f16 | f16 | f32, f16 |
-| weights update | f8_e5m2 | f32, f8_e5m2 | f8_e5m2 | f32 |
+| Propagation | Source | Weights | Destination | Bias |
+|:---------------|:-----------------|:----------------------|:---------------------------------|:----------------------------|
+| forward | f32 | f32 | f32, u8, s8 | f32 |
+| forward | f16 | f16 | f16, f32, u8, s8 | f16, f32 |
+| forward | u8, s8 | s8 | u8, s8, s32, f32, f16, bf16 | u8, s8, s32, f32, f16, bf16 |
+| forward | bf16 | bf16 | f32, bf16 | f32, bf16 |
+| forward | f8_e5m2, f8_e4m3 | f8_e5m2, f8_e4m3 | f8_e5m2, f8_e4m3, f32, f16, bf16 | f32 |
+| forward | f64 | f64 | f64 | f64 |
+| backward | f32, bf16 | bf16 | bf16 | |
+| backward | f32, f16 | f16 | f16 | |
+| backward | f8_e5m2, f8_e4m3 | f8_e5m2, f8_e4m3 | f8_e5m2, f8_e4m3 | |
+| backward | f32 | f32 | f32 | f32 |
+| backward | f64 | f64 | f64 | f64 |
+| weights update | bf16 | f32, bf16 | bf16, s8, u8 | f32, bf16 |
+| weights update | f16 | f32, f16 | f16 | f32, f16 |
+| weights update | f8_e5m2, f8_e4m3 | f32, f8_e5m2, f8_e4m3 | f8_e5m2, f8_e4m3 | f32 |

 @warning There might be hardware and/or implementation specific restrictions.
@@ -432,8 +438,8 @@ of Winograd algorithm implementations.

 3. **GPU**
    - Depthwise post-op is not supported
-   - Only reference support is available for f8_e4m3. Optimized implementation
-     is available for f8_e5m2 on Intel(R) Data Center GPU Max Series only.
+   - `f8` implementation uses Intel XMX cores only on Intel GPUs based on the
+     Xe-HPC, Xe2-LPG, and Xe2-HPG uArchs.

4.
**CPU** - Only reference support for fp8 data types (f8_e5m2, f8_e4m3) is diff --git a/doc/primitives/lrn.md b/doc/primitives/lrn.md index 7e4ca54a8fc..24e472af655 100644 --- a/doc/primitives/lrn.md +++ b/doc/primitives/lrn.md @@ -14,7 +14,7 @@ The LRN primitive performs a forward or backward local response normalization. The LRN operation is defined by the following formulas (the variable names follow the standard @ref dev_guide_conventions): -LRN [across channels](#dnnl_lrn_across_channels): +LRN across channels: \f[ \dst(n, c, h, w) = @@ -26,7 +26,7 @@ LRN [across channels](#dnnl_lrn_across_channels): \src(n, c, h, w), \f] -LRN [within channel](#dnnl_lrn_within_channel): +LRN within a single channel: \f[ \dst(n, c, h, w) = diff --git a/doc/primitives/matmul.md b/doc/primitives/matmul.md index a8073c68090..1a47dbc8437 100644 --- a/doc/primitives/matmul.md +++ b/doc/primitives/matmul.md @@ -67,7 +67,7 @@ argument index as specified by the following table. user must pass fully specified memory objects so that the primitive is able to perform the computations. Note that the less information about shapes or format is available at the creation stage, the less performant execution - will be. In particular, if the shape is not known at creation stage, one + will be. In particular, if the shape is not known at the creation stage, you cannot use the special format tag #dnnl::memory::format_tag::any to enable an implementation to choose the most appropriate memory format for the corresponding input or output shapes. On the other hand, run-time specified @@ -80,13 +80,13 @@ argument index as specified by the following table. invalid. 3. The broadcasting shape consistency check is not done for the dimensions with - #DNNL_RUNTIME_DIM_VAL. It is user responsibility to make sure the dimensions + #DNNL_RUNTIME_DIM_VAL. Make sure the dimensions for the tensors are valid. 4. Multiple batch dimensions and broadcasting of batch dimensions of `src` and `weights` are supported for both CPU and GPU engines. - Please check tutorials below to see #DNNL_RUNTIME_DIM_VAL support in use. + Check the tutorials below to see #DNNL_RUNTIME_DIM_VAL support in use. ### Data Types @@ -94,14 +94,17 @@ The MatMul primitive supports the following combinations of data types for source, destination, weights, and bias tensors: -| Source | Weights | Destination | Bias | -|:---------------|:----------|:----------------------------|:----------------------------| -| f32 | f32 | f32 | f32 | -| f16 | f16 | f16, u8, s8 | f16, f32 | -| bf16 | bf16 | f32, bf16 | bf16, f32 | -| f32, bf16, f16 | u8, s8 | f32, bf16, f16 | f32, bf16, f16 | -| u8, s8 | s8 | u8, s8, s32, f32, f16, bf16 | u8, s8, s32, f32, f16, bf16 | -| f8_e5m2 | f8_e5m2 | f32, f16, bf16, f8_e5m2 | f32, bf16, f16 | +| Source | Weights | Destination | Bias | +|:-----------------|:---------------------|:---------------------------------|:----------------------------| +| f64 | f64 | f64 | f64, f32, f16, bf16, s8, u8 | +| f32 | f32 | f32 | f32, bf16, f16, u8, s8 | +| f16 | f16, u8, s8, u4, s4 | f16, u8, s8 | f32 | +| f16 | f16, u8, s8 | f32 | f32, f16 | +| bf16 | bf16, u8, s8, u4, s4 | f32, bf16 | f32, bf16 | +| f32, bf16, f16 | u8, s8 | f32, bf16, f16 | f32, bf16, f16 | +| f8_e5m2, f8_e4m3 | f8_e5m2, f8_e4m3 | f32, f16, bf16, f8_e5m2, f8_e4m3 | f32, bf16, f16 | +| u8, s8 | s8 | u8, s8, s32, f32, f16, bf16 | u8, s8, s32, f32, f16, bf16 | + ### Data Representation @@ -178,8 +181,8 @@ memory buffer that shares its shape with the destination buffer). 
   - Sum post-op doesn't support data types other than the destination data type.
   - Bias of bf16 data type is supported for configuration with bf16 source data
     type and weights bf16 data type, and up to three dimensional matrices.
-   - Only reference support is available for f8_e4m3. Optimized implementation
-     for f8_e5m2 is available only on Intel(R) Data Center GPU Max Series.
+   - Optimized implementations for fp8 data types are available only on Intel(R)
+     Data Center GPU Max Series and Intel(R) Xe2 Graphics.
   - Configuration with int8 source data type, s8 weight data type and bf16
     destination data type doesn't support:
     * Destination zero point.
     * Runtime dimensions.
     * Three and higher dimensional matrices.
   - The layout of dropout mask has to be exactly the same as that of dst.
+
3. **CPU**
   - Configuration with int8 source data type, s8 weight data type and f16
     destination data type isn't supported.
   - Configuration with floating point source data type, integer weights data
     type and floating point destination data type is not optimized.
-   - Only reference support for fp8 data types (f8_e5m2, f8_e4m3) is
-     is available on CPU.
   - The layout of dropout mask has to be exactly the same as that of dst.

## Performance Tips
diff --git a/doc/primitives/prelu.md b/doc/primitives/prelu.md
index 52c9669b097..3418cc7c3b9 100644
--- a/doc/primitives/prelu.md
+++ b/doc/primitives/prelu.md
@@ -62,7 +62,7 @@ For no broadcast case, results are calculated using formula:
         \diffdst(n, c, h, w) \cdot \weights(n, c, h, w) & \mbox{if } \src(n, c, h, w) \leq 0
     \end{cases}\\\\
-    \diff_weights(n, c, h, w) &=
+    \diffweights(n, c, h, w) &=
         \min(\src(n, c, h, w), 0) \cdot \diffdst(n, c, h, w)
 \f]
diff --git a/doc/primitives/reorder.md b/doc/primitives/reorder.md
index 76a50405afa..16a8310ec15 100644
--- a/doc/primitives/reorder.md
+++ b/doc/primitives/reorder.md
@@ -115,15 +115,16 @@ would lead to the following operation:

 \f[
     \dst(\overline{x}) =
-        scale_{src} \cdot \src(\overline{x} - shift_{src}) +
+        scale_{src} \cdot (\src(\overline{x}) - shift_{src}) +
         \beta \cdot \dst(\overline{x}) + shift_{dst}
 \f]

 @note
  * The intermediate operations are being done using single precision floating
    point data type.
- * \f$scale_{src}\f$ and \f$scale_{dst}\f$ must be passed during execution runtime
-   as a separate memory argument. Using \f$scale_{src}\f$ argument will lead to
+ * \f$scale_{src}\f$, \f$shift_{src}\f$, \f$scale_{dst}\f$, and
+   \f$shift_{dst}\f$ must be passed at execution time as separate
+   memory arguments. Using \f$scale_{src}\f$ argument will lead to
    multiplication of tensor values by a scale value. Using \f$scale_{dst}\f$
    argument will lead to division of tensor values by a scale value.
diff --git a/doc/primitives/softmax.md b/doc/primitives/softmax.md
index 1ec1a4961e1..cbe15452e6d 100644
--- a/doc/primitives/softmax.md
+++ b/doc/primitives/softmax.md
@@ -100,12 +100,28 @@ argument index as specified by the following table.

 Attributes enable you to modify the behavior of the softmax primitive.
The following attributes are supported by the softmax primitive:

-| Propagation | Type | Operation | Description | Restrictions |
-|:------------|:----------|:-----------------------------------------------------|:--------------------------------------------------------------|:-----------------------------------------------------------------------|
-| forward | attribute | [Scales](@ref dnnl::primitive_attr::set_scales_mask) | Scales the corresponding tensor by the given scale factor(s). | Supported only for int8 softmax and one scale per tensor is supported. |
-| forward | post-op | [Binary](@ref dnnl::post_ops::append_binary) | Applies a @ref dnnl_api_binary operation to the result | General binary post-op restrictions |
-| forward | Post-op | [Eltwise](@ref dnnl::post_ops::append_eltwise) | Applies an @ref dnnl_api_eltwise operation to the result. | |
-
+| Propagation | Type | Operation | Description | Restrictions |
+|:------------|:----------|:----------------------------------------------------------------------|:--------------------------------------------------------------|:-----------------------------------------------------------------------|
+| forward | attribute | [Scales](@ref dnnl::primitive_attr::set_scales_mask) | Scales the corresponding tensor by the given scale factor(s). | Supported only for int8 softmax and one scale per tensor is supported. |
+| forward | post-op | [Binary](@ref dnnl::post_ops::append_binary) | Applies a @ref dnnl_api_binary operation to the result | General binary post-op restrictions |
+| forward | Post-op | [Eltwise](@ref dnnl::post_ops::append_eltwise) | Applies an @ref dnnl_api_eltwise operation to the result. | |
+| forward | attribute | [Accumulation mode](@ref dnnl::primitive_attr::set_accumulation_mode) | Defines the implementation's accumulation arithmetic. | Only the values `strict`, `relaxed`, and `any` are supported. |
+
+#### Accumulation Mode
+
+You can optimize performance of the forward operation when the source and
+destination floating-point data types of the operation are equal and different
+from `f32`. When the destination data type is different from `f32`, additional
+memory will be used to accumulate data and store it in the destination memory
+buffer for a requested data type. You can opt out of the additional memory by
+setting the accumulation mode to
+[relaxed](@ref dnnl::accumulation_mode::relaxed) or
+[any](@ref dnnl::accumulation_mode::any), which will use the precision of the
+destination data type to accumulate intermediate results directly into the
+destination memory buffer. This performance optimization, however, results
+in a minor decrease in accuracy. Depending on the actual data, the difference
+between `strict` and `relaxed` accumulation can reach several units in the
+last place (ulps).

 ### Data Type Support
diff --git a/doc/programming_model/data_types.md b/doc/programming_model/data_types.md
index 9166a2ff2d8..eaee2f2dfb1 100644
--- a/doc/programming_model/data_types.md
+++ b/doc/programming_model/data_types.md
@@ -7,8 +7,7 @@ to be the golden standard in deep learning applications and is supported in
 all the library functions. The purpose of low precision data types support
 is to improve performance of compute intensive operations, such as
 convolutions, inner product, and recurrent neural network cells
-in comparison to fp32. Boolean data type is used for Graph Compiler to optimize
-operations which take bool as inputs and/or outputs data type.
+in comparison to fp32.
| Data type | Description | |:----------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| @@ -21,32 +20,27 @@ operations which take bool as inputs and/or outputs data type. | boolean | bool (size is C++ implementation defined) | | f8\_e5m2 | [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) with 5 exponent and 2 mantissa bits | | f8\_e4m3 | [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) with 4 exponent and 3 mantissa bits | +| e8m0 | [MX standard 8-bit scaling type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) | +| f4\_e2m1 | [MX standard 4-bit floating-point](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2 exponent and 1 mantissa bits | +| f4\_e3m0 | 4-bit floating-point with 3 exponent bits and no mantissa bit | -@note - Boolean is only supported in the Graph Compiler in CPU engines. No - primitives support boolean during primitive computation. - ## Inference and Training oneDNN supports training and inference with the following data types: -| Usage mode | CPU | GPU | -|:-----------|:---------------------------------------------------------|:----------------------------------------------| -| Inference | f32, bf16, f16, f8\_e5m2/f8\_e4m3, s8/u8, s4/u4, boolean | f32, bf16, f16, f8\_e5m2/f8\_e4m3, s8/u8, f64 | -| Training | f32, bf16, f16 | f32, bf16, f16, f64 | +| Usage mode | CPU | GPU | +|:-----------|:-----------------------------------------------------------------------------|:----------------------------------------------| +| Inference | f32, bf16, f16, f8\_e5m2/f8\_e4m3, f4\_e2m1, f4\_e3m0, s8/u8, s4/u4, boolean | f32, bf16, f16, f8\_e5m2/f8\_e4m3, s8/u8, f64 | +| Training | f32, bf16, f16, f8\_e5m2/f8\_e4m3 | f32, bf16, f16, f8\_e5m2/f8\_e4m3, f64 | @note Using lower precision arithmetic may require changes in the deep learning model implementation. @note - f64 is only supported for convolution, reorder, layer normalization and - pooling primitives, on the GPU engine. - -@note - Boolean is only supported by the oneDNN graph API when the graph compiler - backend is enabled. + f64 is supported only for matmul, convolution, reorder, layer normalization, and + pooling primitives on the GPU engine. @note s4/u4 data types are only supported as a storage data type for weights argument @@ -75,12 +69,12 @@ post-ops). The following formula governs the datatypes dynamic during a primitive computation: \f[ -\operatorname{convert_{dst\_dt}} ( \operatorname{dst\_zero\_point_{f32}} + \operatorname{postops_{f32}} (\operatorname{oscale_{f32}} * \operatorname{convert_{f32}} (\operatorname{Op}(\operatorname{src_{src\_dt}}, \operatorname{weights_{wei\_dt}}, ...)))) +\operatorname{convert_{dst\_dt}} ( \operatorname{zp_{dst}} + 1/\operatorname{scale_{dst}} * \operatorname{postops_{f32}} (\operatorname{convert_{f32}} (\operatorname{Op}(\operatorname{src_{src\_dt}}, \operatorname{weights_{wei\_dt}}, ...)))) \f] The `Op` output datatype depends on the datatype of its inputs: - if `src`, `weights`, ... are floating-point datatype (f32, f16, - bf16, f8\_e5m2, f8\_e4m3), then the `Op` outputs f32 elements. + bf16, f8\_e5m2, f8\_e4m3, f4\_e2m1, f4\_e3m0), then the `Op` outputs f32 elements. 
- if `src`, `weights`, ... are integral datatypes (s8, u8, s32), then the
  `Op` outputs s32 elements.
- if the primitive allows mixing input datatypes, the `Op` outputs
@@ -96,7 +90,15 @@
No downconversions are allowed by default, but can be enabled using
the floating-point math controls described in
@ref dev_guide_attributes_fpmath_mode (a hedged usage sketch is provided at
the end of this section).
-
+The \f$convert_{dst\_dt}\f$ conversion is guaranteed to be faithfully
+rounded but not guaranteed to be correctly rounded (the returned value
+is not always the closest one but one of the two closest representable
+values). In particular, some hardware platforms have no direct
+conversion instructions from f32 data type to low-precision data types
+such as fp8 or fp4, and will perform conversion through an
+intermediate data type (for example f16 or bf16), which may result in
+[double
+rounding](https://en.wikipedia.org/wiki/Rounding#Double_rounding).

### Rounding mode and denormal handling

@@ -111,8 +113,11 @@ the floating-point environment can control:

 @note
   For CPU devices, the default floating-point environment is defined by
-  the C and C++ standards in the fenv.h header. Rounding mode can be
-  changed globally using the fesetround() C function.
+  the C and C++ standards in the following header:
+~~~cpp
+#include <fenv.h>
+~~~
+  Rounding mode can be changed globally using the `fesetround()` C function.

 @note
   Most DNN applications do not require precise computations with denormal
@@ -164,7 +169,8 @@ types that oneDNN recognizes.
 | bf16 | Intel DL Boost with bfloat16 support |
 | f16 | Intel AVX512-FP16 |
 | boolean | Intel AVX2 |
-| f8\_e5m2, f8\_e4m3 | TBA. |
+| f8\_e5m2, f8\_e4m3 | Intel AVX512-FP16 |
+| f4\_e2m1, f4\_e3m0 | TBA |

 @note
   See @ref dev_guide_int8_computations in the Developer Guide for additional
@@ -205,30 +211,33 @@ library:
   * Intel(R) Data Center GPU Flex Series (formerly Arctic Sound)
 * Xe-HPC (accelerated f16, bf16, u8, and s8 support via DPAS and f64 support via MAD)
   * Intel(R) Data Center GPU Max Series (formerly Ponte Vecchio)
+* Xe2-LPG
+  * Intel(R) Graphics for Intel(R) Core(TM) Ultra processors (Series 2) (formerly Lunar Lake)
+* Xe2-HPG
+  * Intel(R) Arc(TM) B-Series Graphics (formerly Battlemage)

 The following table indicates the data types with performant compute primitives
 for each uArch supported by oneDNN. Unless otherwise noted, all data types have
 reference support on all architectures.

-| uArch | Supported Data types |
-|:-------|:-------------------------------------------------|
-| Xe-LP | f32, f16, s8, u8 |
-| Xe-HPG | f32, f16, bf16, s8, u8 |
-| Xe-HPC | f64, f32, bf16, f16, s8, u8 |
-| TBA | f64, f32, bf16, f16, s8, u8, f8\_e5m2, f8\_e4m3 |
+| uArch | Supported Data types |
+|:--------|:--------------------------------------------------------------------|
+| Xe-LP | f32, f16, s8, u8 |
+| Xe-HPG | f32, f16, bf16, s8, u8 |
+| Xe-HPC | f64, f32, bf16, f16, s8, u8 |
+| Xe2-LPG | f64, f32, bf16, f16, s8, u8 |
+| Xe2-HPG | f64, f32, bf16, f16, s8, u8 |
+| TBA | f64, f32, bf16, f16, s8, u8, f8\_e5m2, f8\_e4m3, f4\_e2m1, f4\_e3m0 |
+

 @note
   f64 configurations are only supported on GPU engines with HW capability for
   double-precision floating-point.

 @note
-  f8\_e5m2 compute operations have limited performance through upconversion on
-  Xe-HPC.
+  f8\_e5m2 and f8\_e4m3 compute operations have limited performance through upconversion on
+  Xe-HPC and Xe2 GPUs.

 @note
   f16 operations may be faster with f16 accumulation on GPU architectures
   older than Xe-HPC. Newer architectures accumulate to f32.
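As promised above, a short sketch of the floating-point math controls. It
requests implicit f32-to-bf16 downconversion inside a matmul via
`dnnl::primitive_attr::set_fpmath_mode()`; the shapes and names are
illustrative only.

~~~cpp
// Hedged sketch: allow implicit f32 -> bf16 downconversion inside a matmul.
#include "dnnl.hpp"
using namespace dnnl;

matmul make_bf16_relaxed_matmul(const engine &eng) {
    memory::desc src_md({128, 256}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc wei_md({256, 512}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc dst_md({128, 512}, memory::data_type::f32, memory::format_tag::ab);

    primitive_attr attr;
    // Permit the implementation to downconvert f32 inputs to bf16 internally;
    // without this attribute no downconversions are allowed.
    attr.set_fpmath_mode(fpmath_mode::bf16);

    auto pd = matmul::primitive_desc(eng, src_md, wei_md, dst_md, attr);
    return matmul(pd);
}
~~~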
- -@note - Boolean is only supported by the oneDNN graph API when the graph compiler - backend is enabled. The graph compiler backend only supports the CPU engine. diff --git a/doc/rst/graph_extension.rst b/doc/rst/graph_extension.rst index 5d835236442..cd681e1d9c2 100644 --- a/doc/rst/graph_extension.rst +++ b/doc/rst/graph_extension.rst @@ -6,7 +6,6 @@ Graph Extension graph_programming_model graph_supported_operations - dev_guide_graph_fusion_patterns + graph_fusion_patterns dev_guide_graph_dump dev_guide_constant_tensor_cache - dev_guide_graph_compiler diff --git a/doc/rst/index.rst b/doc/rst/index.rst index 8cdad3d559d..9a0e3d232c1 100644 --- a/doc/rst/index.rst +++ b/doc/rst/index.rst @@ -1,5 +1,5 @@ -oneAPI Deep Neural Network Library Developer Guide and Reference -======================================================================= +oneAPI Deep Neural Network Library (oneDNN) Developer Guide and Reference +========================================================================= .. toctree:: :maxdepth: 1 diff --git a/doc/rst/orphans.rst b/doc/rst/orphans.rst index db05dbc3653..636e6a6d210 100644 --- a/doc/rst/orphans.rst +++ b/doc/rst/orphans.rst @@ -37,6 +37,7 @@ Orphans example_convolution.cpp.rst example_cpu_cnn_training_f32.c.rst example_cpu_matmul_csr.cpp.rst + example_cpu_matmul_coo.cpp.rst example_cpu_matmul_quantization.cpp.rst example_cpu_matmul_weights_compression.cpp.rst example_cpu_rnn_inference_f32.cpp.rst @@ -79,6 +80,7 @@ Orphans page_convolution_example_cpp.rst page_convolution_example_cpp_short.rst page_cpu_matmul_csr_cpp + page_cpu_matmul_coo_cpp page_cpu_matmul_weights_compression_cpp page_cpu_matmul_quantization_cpp.rst page_cpu_matmul_quantization_cpp_short.rst diff --git a/doc/sphinx/_static/favicons.png b/doc/sphinx/_static/favicons.png new file mode 100644 index 00000000000..f450376b19e Binary files /dev/null and b/doc/sphinx/_static/favicons.png differ diff --git a/doc/sphinx/_static/oneAPI-rgb-rev-100.png b/doc/sphinx/_static/oneAPI-rgb-rev-100.png new file mode 100644 index 00000000000..58d2d5c54e5 Binary files /dev/null and b/doc/sphinx/_static/oneAPI-rgb-rev-100.png differ diff --git a/doc/sphinx/conf.py b/doc/sphinx/conf.py index 976056daff7..a512c72b941 100644 --- a/doc/sphinx/conf.py +++ b/doc/sphinx/conf.py @@ -51,7 +51,7 @@ def whereis(binary): # -- Project information ----------------------------------------------------- project = 'oneDNN' -copyright = '2016-2024 Intel Corporation' +copyright = '2016-2025 Intel Corporation' author = '' # -- General configuration --------------------------------------------------- @@ -116,6 +116,8 @@ def whereis(binary): # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] +source_suffix = '.rst' + # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. 
@@ -135,11 +137,16 @@ def whereis(binary):
 html_static_path = ['_static']
 #html_js_files = [('dnnl.js', {'defer': 'defer'})]

+html_logo = '_static/oneAPI-rgb-rev-100.png'
+html_favicon = '_static/favicons.png'
+
 html_theme_options = {
-    "repository_url": "https://github.com/oneapi-src/oneDNN",
-    "repository_branch": "master",
+    "repository_url": "https://github.com/uxlfoundation/oneDNN",
+    "repository_branch": "main",
     "use_repository_button": True,
-    "use_download_button": False
+    "use_download_button": True,
+    "path_to_docs": "doc",
+    "use_issues_button": True
 }

 mathjax3_config = {
diff --git a/doc/ukernel/operations/brgemm.md b/doc/ukernel/operations/brgemm.md
index acefec29b2f..7e4520fb032 100644
--- a/doc/ukernel/operations/brgemm.md
+++ b/doc/ukernel/operations/brgemm.md
@@ -44,14 +44,14 @@ The BRGeMM ukernel supports the following combinations of data types.

 Because of hardware restrictions, the BRGeMM ukernel requires a specific data
 layout. For x86-64 architecture this layout applies to a B matrix. It is
-expressed through @ref dnnl::ukernel::pack_type which can be queried by
-@ref dnnl::ukernel::brgemm::get_B_pack_type call. If the query returns
-@ref dnnl::ukernel::brgemm::pack_type::no_trans, then packing is not required.
+expressed through #dnnl::ukernel::pack_type which can be queried by
+#dnnl::ukernel::brgemm::get_B_pack_type call. If the query returns
+#dnnl::ukernel::pack_type::no_trans, then packing is not required.
 Otherwise, the user is responsible for packing the data appropriately before
-calling @ref dnnl::ukernel::brgemm::execute, either with custom code, or by
-using a dedicated set of APIs: @ref dnnl::ukernel::transform::generate for
+calling #dnnl::ukernel::brgemm::execute, either with custom code, or by
+using a dedicated set of APIs: #dnnl::ukernel::transform::generate for
 generating a kernel of a transform routine and
-@ref dnnl::ukernel::transform::execute to run the generated kernel.
+#dnnl::ukernel::transform::execute to run the generated kernel.

 ## Attributes
diff --git a/doc/ukernel/operations/transform.md b/doc/ukernel/operations/transform.md
index 1e0ecb742ff..be3f24d8455 100644
--- a/doc/ukernel/operations/transform.md
+++ b/doc/ukernel/operations/transform.md
@@ -2,16 +2,26 @@ Data transformation {#dev_guide_ukernel_transform}
 =======================================

 >
-> [API Reference](@ref dnnl_api_ukernel_brgemm)
+> [API Reference](@ref dnnl::ukernel::transform)
 >

 ## General

-The transform ukernel allows users to convert data from one format to the other,
-similar to what reorder primitive provides functionally.
+The [BRGeMM ukernel](@ref dev_guide_ukernel_brgemm) might require the B tensor
+in a specific memory layout depending on target data types and the machine
+architecture. Check the requirement by calling the
+[get_B_pack_type()](@ref dnnl::ukernel::brgemm::get_B_pack_type) function. If it
+returns the [pack32](@ref dnnl::ukernel::pack_type::pack32) type, packing is
+required; otherwise, it is not.
+
+The transform ukernel allows the conversion of data from the original layout,
+which is described as either
+[non-transposed](@ref dnnl::ukernel::pack_type::no_trans) or
+[transposed](@ref dnnl::ukernel::pack_type::trans), to the layout requested by
+the BRGeMM ukernel.
+
+The only supported output packing type is `pack32`.

-The only output data format supported by this routine is packed format, which is
-required by B matrices in [BRGeMM ukernel](@ref dev_guide_ukernel_brgemm).
This is an out-of-place operation.
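A hedged usage sketch of the query-then-pack flow described above follows; it
assumes the experimental ukernel API in `dnnl_ukernel.hpp`, whose exact
signatures may vary between releases.

~~~cpp
// Hedged sketch: pack B only when the BRGeMM ukernel reports pack32.
#include <cstdint>
#include <vector>
#include "oneapi/dnnl/dnnl_ukernel.hpp"
using namespace dnnl;
using namespace dnnl::ukernel;

void pack_B_if_needed() {
    const memory::dim K = 64, N = 48; // out_ld of 48 is a supported value
    // Assumption: get_B_pack_type() reports the packing required for the
    // given A/B data types (bf16 here), as described above.
    pack_type pt = brgemm::get_B_pack_type(
            memory::data_type::bf16, memory::data_type::bf16);
    if (pt != pack_type::pack32) return; // plain layout is already usable

    std::vector<uint16_t> B(K * N), B_packed(K * N);
    // Source is non-transposed with leading dimension N; the destination
    // leading dimension must be 16, 32, 48, or 64 (see the limitations below).
    transform pack_B(K, N, pack_type::no_trans, /*in_ld=*/N, /*out_ld=*/N,
            memory::data_type::bf16, memory::data_type::bf16);
    pack_B.generate();
    pack_B.execute(B.data(), B_packed.data());
}
~~~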
## Data Types

@@ -34,7 +44,9 @@ No attribute is supported for transform ukernel.

 ## Implementation limitations

-- Destination leading dimension only supported values are: 16, 32, 48, or 64.
+- Destination leading dimension, or `out_ld`, must be one of the following
+  values: `16`, `32`, `48`, or `64`. This is an implementation limitation;
+  there are no efficient kernels for other leading dimension values.

 ## Examples
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 2d848af454a..56d6a48787b 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1,5 +1,5 @@
 #===============================================================================
-# Copyright 2016-2024 Intel Corporation
+# Copyright 2016-2025 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -57,6 +57,7 @@ file(GLOB_RECURSE headers *.hpp *.h)

 if(NOT DNNL_EXPERIMENTAL_SPARSE)
     list(REMOVE_ITEM sources ${CMAKE_CURRENT_SOURCE_DIR}/cpu_matmul_csr.cpp)
+    list(REMOVE_ITEM sources ${CMAKE_CURRENT_SOURCE_DIR}/cpu_matmul_coo.cpp)
     list(REMOVE_ITEM sources
         ${CMAKE_CURRENT_SOURCE_DIR}/cpu_matmul_weights_compression.cpp)
 endif()
@@ -74,7 +75,10 @@ if(DNNL_SYCL_CUDA)
         ${CMAKE_CURRENT_SOURCE_DIR}/primitives/lstm.cpp
         ${CMAKE_CURRENT_SOURCE_DIR}/primitives/layer_normalization.cpp
         ${CMAKE_CURRENT_SOURCE_DIR}/primitives/reorder.cpp
-        ${CMAKE_CURRENT_SOURCE_DIR}/primitives/shuffle.cpp)
+        ${CMAKE_CURRENT_SOURCE_DIR}/primitives/shuffle.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/primitives/group_normalization.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/primitives/vanilla_rnn.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/primitives/lbr_gru.cpp)
 endif()

 # Remove examples for Graph API if graph component is not enabled
@@ -90,9 +94,26 @@ if(NOT ONEDNN_BUILD_GRAPH)
         ${CMAKE_CURRENT_SOURCE_DIR}/graph/mqa.cpp
         ${CMAKE_CURRENT_SOURCE_DIR}/graph/sdpa_stacked_qkv.cpp
         ${CMAKE_CURRENT_SOURCE_DIR}/graph/gqa.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/graph/gated_mlp.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/graph/gated_mlp_wei_combined.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/graph/gated_mlp_int4.cpp
         )
 endif()

+if(DNNL_SYCL_GENERIC)
+    list(REMOVE_ITEM sources
+        # XXX: Enable when InnerProduct is implemented
+        ${CMAKE_CURRENT_SOURCE_DIR}/cnn_inference_f32.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/primitives/inner_product.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/rnn_training_f32.cpp
+        # XXX: Enable when Reduction is implemented
+        ${CMAKE_CURRENT_SOURCE_DIR}/primitives/reduction.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/primitives/group_normalization.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/primitives/lbr_gru.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/primitives/lstm.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/primitives/vanilla_rnn.cpp)
+endif()
+
 if(DNNL_SYCL_HIP)
     # Build examples for supported primitives that support required features.
set(sources) @@ -183,12 +204,8 @@ foreach(src ${sources}) endif() endforeach() -if (DNNL_INSTALL_MODE STREQUAL "BUNDLE" OR DNNL_INSTALL_MODE STREQUAL "BUNDLE_V2") - if(DNNL_INSTALL_MODE STREQUAL "BUNDLE") - set(BUNDLE_EXAMPLES_DIR "examples") - else() - set(BUNDLE_EXAMPLES_DIR "${CMAKE_INSTALL_DATAROOTDIR}/doc/${LIB_PACKAGE_NAME}/examples") - endif() +if (DNNL_INSTALL_MODE STREQUAL "BUNDLE") + set(BUNDLE_EXAMPLES_DIR "${CMAKE_INSTALL_DATAROOTDIR}/doc/${LIB_PACKAGE_NAME}/examples") configure_file(CMakeLists.txt.in CMakeLists.txt @ONLY) install(FILES diff --git a/examples/CMakeLists.txt.in b/examples/CMakeLists.txt.in index ed408567f85..4ac638719fc 100644 --- a/examples/CMakeLists.txt.in +++ b/examples/CMakeLists.txt.in @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2019-2024 Intel Corporation +# Copyright 2019-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ # limitations under the License. #=============================================================================== -cmake_minimum_required(VERSION 2.8.12) +cmake_minimum_required(VERSION 3.13) if("${CMAKE_BUILD_TYPE}" STREQUAL "") message(STATUS "CMAKE_BUILD_TYPE is unset, defaulting to Release") @@ -28,18 +28,9 @@ project (DNNL_EXAMPLES) set(DNNL_CPU_RUNTIME "@DNNL_CPU_RUNTIME@") set(DNNL_GPU_RUNTIME "@DNNL_GPU_RUNTIME@") -if(POLICY CMP0015) - cmake_policy(SET CMP0015 NEW) -endif() - -# Use _ROOT env. variable as a prefix -if(POLICY CMP0074) - cmake_policy(SET CMP0074 NEW) -endif() - set(DNNL_INSTALL_MODE "@DNNL_INSTALL_MODE@") set(IS_NEW_DIR_LAYOUT FALSE) -if(DNNL_INSTALL_MODE STREQUAL "BUNDLE_V2") +if(DNNL_INSTALL_MODE STREQUAL "BUNDLE") set(IS_NEW_DIR_LAYOUT TRUE) endif() @@ -84,25 +75,8 @@ if(CMAKE_BASE_NAME MATCHES "^(icx|icpx)$") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-recommended-option -Wno-unknown-warning-option") endif() -function(find_libm var) - # This is to account for the linker cache in OSX11. might work - # with lower than 3.9.4, but was not able to test with anything - # between 2.8 and 3.9. See here for more details: - # https://gitlab.kitware.com/cmake/cmake/-/issues/20863 - if (APPLE AND (${CMAKE_HOST_SYSTEM_VERSION} VERSION_GREATER "20.0.0") - AND (${CMAKE_VERSION} VERSION_LESS "3.9.4")) - message(INFO "Using OSX11 and above with CMAKE older than 3.18 can cause linking issues.") - set(OSX11_AND_OLDER_CMAKE TRUE) - endif() - - if(UNIX AND (NOT (APPLE AND OSX11_AND_OLDER_CMAKE))) - find_library(${var} m REQUIRED) - endif() -endfunction() - - if(UNIX OR MINGW) - find_libm(LIBM) + find_library(LIBM m REQUIRED) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99") if(NOT DNNL_WITH_SYCL) @@ -116,22 +90,11 @@ if(UNIX OR MINGW) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w") endif() -if(WIN32 AND ${CMAKE_CXX_COMPILER_ID} STREQUAL MSVC) +if(${CMAKE_CXX_COMPILER_ID} STREQUAL MSVC) add_definitions(/Qpar) add_definitions(/openmp) else() find_package(OpenMP) - #newer version for findOpenMP (>= v. 
3.9) - if(CMAKE_VERSION VERSION_LESS "3.9" AND OPENMP_FOUND) - if(${CMAKE_MAJOR_VERSION} VERSION_LESS "3" AND - ${CMAKE_CXX_COMPILER_ID} STREQUAL "Intel") - # Override FindOpenMP flags for Intel Compiler (otherwise deprecated) - set(OpenMP_CXX_FLAGS "-fopenmp") - set(OpenMP_C_FLAGS "-fopenmp") - endif() - set(OpenMP_C_FOUND true) - set(OpenMP_CXX_FOUND true) - endif() if(OpenMP_C_FOUND) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") endif() @@ -190,7 +153,7 @@ elseif(APPLE) set(CTESTCONFIG_PATH "${DNNLROOT}/lib") endif() -# Common configuration for tests / test cases on Windows and Apple +# Configuration for tests / test cases on Windows function(maybe_configure_test name kind) if(WIN32) string(REPLACE ";" "\;" PATH "${CTESTCONFIG_PATH};$ENV{PATH}") @@ -198,14 +161,6 @@ function(maybe_configure_test name kind) if(CMAKE_GENERATOR MATCHES "Visual Studio") configure_file(template.vcxproj.user ${name}.vcxproj.user @ONLY) endif() - elseif(APPLE) - # When LIBRARY_PATH is set (e.g. when using compiler env. scripts) - # cmake may stop passing `rpath` linker option. The hack below adds the - # LIBRARY_PATH to DYLD_LIBRARY_PATH to make the executable find its - # dependencies. - # TODO: the problem may be in older version of cmake (2.8.11), revisit. - set_property(${kind} ${name} PROPERTY ENVIRONMENT - "DYLD_LIBRARY_PATH=${CTESTCONFIG_PATH}:$ENV{LIBRARY_PATH}:$ENV{DYLD_LIBRARY_PATH}") endif() endfunction() diff --git a/examples/bnorm_u8_via_binary_postops.cpp b/examples/bnorm_u8_via_binary_postops.cpp index e72c852239a..eab8bf18635 100644 --- a/examples/bnorm_u8_via_binary_postops.cpp +++ b/examples/bnorm_u8_via_binary_postops.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,9 +46,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void bnorm_u8_via_binary_postops(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -102,12 +99,18 @@ void bnorm_u8_via_binary_postops(dnnl::engine::kind engine_kind) { oscale_data.begin(), oscale_data.end(), []() { return 0.5f; }); // Create descriptors. - auto src_md = memory::desc(src_dims, dt::u8, tag::nhwc); - auto mean_md = memory::desc(params_dims, dt::f32, tag::nhwc); - auto variance_md = memory::desc(params_dims, dt::f32, tag::nhwc); - auto scale_md = memory::desc(params_dims, dt::f32, tag::nhwc); - auto shift_md = memory::desc(params_dims, dt::f32, tag::nhwc); - auto oscale_md = memory::desc(params_dims, dt::f32, tag::nhwc); + auto src_md = memory::desc( + src_dims, memory::data_type::u8, memory::format_tag::nhwc); + auto mean_md = memory::desc( + params_dims, memory::data_type::f32, memory::format_tag::nhwc); + auto variance_md = memory::desc( + params_dims, memory::data_type::f32, memory::format_tag::nhwc); + auto scale_md = memory::desc( + params_dims, memory::data_type::f32, memory::format_tag::nhwc); + auto shift_md = memory::desc( + params_dims, memory::data_type::f32, memory::format_tag::nhwc); + auto oscale_md = memory::desc( + params_dims, memory::data_type::f32, memory::format_tag::nhwc); // Create src memory objects. 
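The hunks in this file, and in the example files below, all apply the same mechanical change: the file-local `dt`/`tag` aliases are dropped in favor of fully qualified enum names. A minimal before/after sketch of the pattern (shapes illustrative):

    // Before: file-local aliases.
    using tag = dnnl::memory::format_tag;
    using dt = dnnl::memory::data_type;
    auto md_old = dnnl::memory::desc({1, 8, 14, 14}, dt::f32, tag::nhwc);

    // After: fully qualified names, as in the hunks above and below.
    auto md_new = dnnl::memory::desc({1, 8, 14, 14},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::nhwc);
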
auto src_mem = memory(src_md, engine); diff --git a/examples/cnn_inference_f32.cpp b/examples/cnn_inference_f32.cpp index 355f0c11561..24c39eca166 100644 --- a/examples/cnn_inference_f32.cpp +++ b/examples/cnn_inference_f32.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2022 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,9 +51,6 @@ using namespace dnnl; void simple_net(engine::kind engine_kind, int times = 100) { - using tag = memory::format_tag; - using dt = memory::data_type; - /// Initialize an engine and stream. The last parameter in the call represents /// the index of the engine. /// @snippet cnn_inference_f32.cpp Initialize engine and stream @@ -91,33 +88,43 @@ void simple_net(engine::kind engine_kind, int times = 100) { std::vector conv1_bias(product(conv1_bias_tz)); //[Allocate buffers] - /// Create memory that describes data layout in the buffers. This example uses - /// tag::nchw (batch-channels-height-width) for input data and tag::oihw - /// for weights. + /// Create memory that describes data layout in the buffers. This example + /// uses dnnl::memory::format_tag::nchw (batch-channels-height-width) + /// for input data and dnnl::memory::format_tag::oihw for weights. /// @snippet cnn_inference_f32.cpp Create user memory //[Create user memory] - auto user_src_memory = memory({{conv1_src_tz}, dt::f32, tag::nchw}, eng); + auto user_src_memory = memory( + {{conv1_src_tz}, memory::data_type::f32, memory::format_tag::nchw}, + eng); write_to_dnnl_memory(user_src.data(), user_src_memory); auto user_weights_memory - = memory({{conv1_weights_tz}, dt::f32, tag::oihw}, eng); + = memory({{conv1_weights_tz}, memory::data_type::f32, + memory::format_tag::oihw}, + eng); write_to_dnnl_memory(conv1_weights.data(), user_weights_memory); - auto conv1_user_bias_memory - = memory({{conv1_bias_tz}, dt::f32, tag::x}, eng); + auto conv1_user_bias_memory = memory( + {{conv1_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); write_to_dnnl_memory(conv1_bias.data(), conv1_user_bias_memory); //[Create user memory] - /// Create memory descriptors with layout tag::any. The `any` format enables - /// the convolution primitive to choose the data format that will result in - /// best performance based on its input parameters (convolution kernel - /// sizes, strides, padding, and so on). If the resulting format is different - /// from `nchw`, the user data must be transformed to the format required for - /// the convolution (as explained below). + /// Create memory descriptors with layout dnnl::memory::format_tag::any. + /// The `any` format enables the convolution primitive to choose the data + /// format that will result in best performance based on its input + /// parameters (convolution kernel sizes, strides, padding, and so on). + /// If the resulting format is different from `nchw`, the user data must be + /// transformed to the format required for the convolution (as explained + /// below). 
/// @snippet cnn_inference_f32.cpp Create convolution memory descriptors //[Create convolution memory descriptors] - auto conv1_src_md = memory::desc({conv1_src_tz}, dt::f32, tag::any); - auto conv1_bias_md = memory::desc({conv1_bias_tz}, dt::f32, tag::any); - auto conv1_weights_md = memory::desc({conv1_weights_tz}, dt::f32, tag::any); - auto conv1_dst_md = memory::desc({conv1_dst_tz}, dt::f32, tag::any); + auto conv1_src_md = memory::desc( + {conv1_src_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv1_bias_md = memory::desc( + {conv1_bias_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv1_weights_md = memory::desc({conv1_weights_tz}, + memory::data_type::f32, memory::format_tag::any); + auto conv1_dst_md = memory::desc( + {conv1_dst_tz}, memory::data_type::f32, memory::format_tag::any); //[Create convolution memory descriptors] /// Create a convolution primitive descriptor by specifying engine, @@ -136,9 +143,9 @@ void simple_net(engine::kind engine_kind, int times = 100) { conv1_strides, conv1_padding, conv1_padding); //[Create convolution primitive descriptor] - /// Check whether data and weights formats required by convolution is different - /// from the user format. In case it is different change the layout using - /// reorder primitive. + /// Check whether data and weights formats required by convolution is + /// different from the user format. In case it is different change the + /// layout using reorder primitive. /// @snippet cnn_inference_f32.cpp Reorder data and weights //[Reorder data and weights] auto conv1_src_memory = user_src_memory; @@ -180,7 +187,8 @@ void simple_net(engine::kind engine_kind, int times = 100) { /// Create the relu primitive. For better performance, keep the input data /// format for ReLU (as well as for other operation primitives until another /// convolution or inner product is encountered) the same as the one chosen - /// for convolution. Also note that ReLU is done in-place by using conv1 memory. + /// for convolution. Also note that ReLU is done in-place by using conv1 + /// memory. /// @snippet cnn_inference_f32.cpp Create relu primitive //[Create relu primitive] auto relu1_prim_desc @@ -224,11 +232,12 @@ void simple_net(engine::kind engine_kind, int times = 100) { memory::dims pool_dilation = {0, 0}; memory::dims pool_padding = {0, 0}; - auto pool1_dst_md = memory::desc({pool1_dst_tz}, dt::f32, tag::any); + auto pool1_dst_md = memory::desc( + {pool1_dst_tz}, memory::data_type::f32, memory::format_tag::any); /// For training execution, pooling requires a private workspace memory - /// to perform the backward pass. However, pooling should not use 'workspace' - /// for inference, because this is detrimental to performance. + /// to perform the backward pass. However, pooling should not use + /// 'workspace' for inference, because this is detrimental to performance. 
/// @snippet cnn_inference_f32.cpp Create pooling primitive /// /// The example continues to create more layers according @@ -260,17 +269,24 @@ void simple_net(engine::kind engine_kind, int times = 100) { // create memory for user data auto conv2_user_weights_memory - = memory({{conv2_weights_tz}, dt::f32, tag::goihw}, eng); + = memory({{conv2_weights_tz}, memory::data_type::f32, + memory::format_tag::goihw}, + eng); write_to_dnnl_memory(conv2_weights.data(), conv2_user_weights_memory); - auto conv2_user_bias_memory - = memory({{conv2_bias_tz}, dt::f32, tag::x}, eng); + auto conv2_user_bias_memory = memory( + {{conv2_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); write_to_dnnl_memory(conv2_bias.data(), conv2_user_bias_memory); // create memory descriptors for convolution data w/ no specified format - auto conv2_src_md = memory::desc({conv2_src_tz}, dt::f32, tag::any); - auto conv2_bias_md = memory::desc({conv2_bias_tz}, dt::f32, tag::any); - auto conv2_weights_md = memory::desc({conv2_weights_tz}, dt::f32, tag::any); - auto conv2_dst_md = memory::desc({conv2_dst_tz}, dt::f32, tag::any); + auto conv2_src_md = memory::desc( + {conv2_src_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv2_bias_md = memory::desc( + {conv2_bias_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv2_weights_md = memory::desc({conv2_weights_tz}, + memory::data_type::f32, memory::format_tag::any); + auto conv2_dst_md = memory::desc( + {conv2_dst_tz}, memory::data_type::f32, memory::format_tag::any); // create a convolution auto conv2_prim_desc = convolution_forward::primitive_desc(eng, @@ -348,7 +364,8 @@ void simple_net(engine::kind engine_kind, int times = 100) { memory::dims pool2_dilation = {0, 0}; memory::dims pool2_padding = {0, 0}; - auto pool2_dst_md = memory::desc({pool2_dst_tz}, dt::f32, tag::any); + auto pool2_dst_md = memory::desc( + {pool2_dst_tz}, memory::data_type::f32, memory::format_tag::any); // create a pooling auto pool2_pd = pooling_forward::primitive_desc(eng, @@ -377,17 +394,24 @@ void simple_net(engine::kind engine_kind, int times = 100) { // create memory for user data auto conv3_user_weights_memory - = memory({{conv3_weights_tz}, dt::f32, tag::oihw}, eng); + = memory({{conv3_weights_tz}, memory::data_type::f32, + memory::format_tag::oihw}, + eng); write_to_dnnl_memory(conv3_weights.data(), conv3_user_weights_memory); - auto conv3_user_bias_memory - = memory({{conv3_bias_tz}, dt::f32, tag::x}, eng); + auto conv3_user_bias_memory = memory( + {{conv3_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); write_to_dnnl_memory(conv3_bias.data(), conv3_user_bias_memory); // create memory descriptors for convolution data w/ no specified format - auto conv3_src_md = memory::desc({conv3_src_tz}, dt::f32, tag::any); - auto conv3_bias_md = memory::desc({conv3_bias_tz}, dt::f32, tag::any); - auto conv3_weights_md = memory::desc({conv3_weights_tz}, dt::f32, tag::any); - auto conv3_dst_md = memory::desc({conv3_dst_tz}, dt::f32, tag::any); + auto conv3_src_md = memory::desc( + {conv3_src_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv3_bias_md = memory::desc( + {conv3_bias_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv3_weights_md = memory::desc({conv3_weights_tz}, + memory::data_type::f32, memory::format_tag::any); + auto conv3_dst_md = memory::desc( + {conv3_dst_tz}, memory::data_type::f32, memory::format_tag::any); // create a convolution auto conv3_prim_desc = convolution_forward::primitive_desc(eng, @@ 
-450,17 +474,24 @@ void simple_net(engine::kind engine_kind, int times = 100) { // create memory for user data auto conv4_user_weights_memory - = memory({{conv4_weights_tz}, dt::f32, tag::goihw}, eng); + = memory({{conv4_weights_tz}, memory::data_type::f32, + memory::format_tag::goihw}, + eng); write_to_dnnl_memory(conv4_weights.data(), conv4_user_weights_memory); - auto conv4_user_bias_memory - = memory({{conv4_bias_tz}, dt::f32, tag::x}, eng); + auto conv4_user_bias_memory = memory( + {{conv4_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); write_to_dnnl_memory(conv4_bias.data(), conv4_user_bias_memory); // create memory descriptors for convolution data w/ no specified format - auto conv4_src_md = memory::desc({conv4_src_tz}, dt::f32, tag::any); - auto conv4_bias_md = memory::desc({conv4_bias_tz}, dt::f32, tag::any); - auto conv4_weights_md = memory::desc({conv4_weights_tz}, dt::f32, tag::any); - auto conv4_dst_md = memory::desc({conv4_dst_tz}, dt::f32, tag::any); + auto conv4_src_md = memory::desc( + {conv4_src_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv4_bias_md = memory::desc( + {conv4_bias_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv4_weights_md = memory::desc({conv4_weights_tz}, + memory::data_type::f32, memory::format_tag::any); + auto conv4_dst_md = memory::desc( + {conv4_dst_tz}, memory::data_type::f32, memory::format_tag::any); // create a convolution auto conv4_prim_desc = convolution_forward::primitive_desc(eng, @@ -522,17 +553,24 @@ void simple_net(engine::kind engine_kind, int times = 100) { // create memory for user data auto conv5_user_weights_memory - = memory({{conv5_weights_tz}, dt::f32, tag::goihw}, eng); + = memory({{conv5_weights_tz}, memory::data_type::f32, + memory::format_tag::goihw}, + eng); write_to_dnnl_memory(conv5_weights.data(), conv5_user_weights_memory); - auto conv5_user_bias_memory - = memory({{conv5_bias_tz}, dt::f32, tag::x}, eng); + auto conv5_user_bias_memory = memory( + {{conv5_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); write_to_dnnl_memory(conv5_bias.data(), conv5_user_bias_memory); // create memory descriptors for convolution data w/ no specified format - auto conv5_src_md = memory::desc({conv5_src_tz}, dt::f32, tag::any); - auto conv5_weights_md = memory::desc({conv5_weights_tz}, dt::f32, tag::any); - auto conv5_bias_md = memory::desc({conv5_bias_tz}, dt::f32, tag::any); - auto conv5_dst_md = memory::desc({conv5_dst_tz}, dt::f32, tag::any); + auto conv5_src_md = memory::desc( + {conv5_src_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv5_weights_md = memory::desc({conv5_weights_tz}, + memory::data_type::f32, memory::format_tag::any); + auto conv5_bias_md = memory::desc( + {conv5_bias_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv5_dst_md = memory::desc( + {conv5_dst_tz}, memory::data_type::f32, memory::format_tag::any); // create a convolution auto conv5_prim_desc = convolution_forward::primitive_desc(eng, @@ -591,7 +629,8 @@ void simple_net(engine::kind engine_kind, int times = 100) { std::vector pool5_dst(product(pool5_dst_tz)); - auto pool5_dst_md = memory::desc({pool5_dst_tz}, dt::f32, tag::any); + auto pool5_dst_md = memory::desc( + {pool5_dst_tz}, memory::data_type::f32, memory::format_tag::any); // create a pooling auto pool5_pd = pooling_forward::primitive_desc(eng, @@ -618,16 +657,24 @@ void simple_net(engine::kind engine_kind, int times = 100) { // create memory for user data auto fc6_user_weights_memory - = 
memory({{fc6_weights_tz}, dt::f32, tag::oihw}, eng); + = memory({{fc6_weights_tz}, memory::data_type::f32, + memory::format_tag::oihw}, + eng); write_to_dnnl_memory(fc6_weights.data(), fc6_user_weights_memory); - auto fc6_user_bias_memory = memory({{fc6_bias_tz}, dt::f32, tag::x}, eng); + auto fc6_user_bias_memory = memory( + {{fc6_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); write_to_dnnl_memory(fc6_bias.data(), fc6_user_bias_memory); // create memory descriptors for convolution data w/ no specified format - auto fc6_src_md = memory::desc({fc6_src_tz}, dt::f32, tag::any); - auto fc6_bias_md = memory::desc({fc6_bias_tz}, dt::f32, tag::any); - auto fc6_weights_md = memory::desc({fc6_weights_tz}, dt::f32, tag::any); - auto fc6_dst_md = memory::desc({fc6_dst_tz}, dt::f32, tag::any); + auto fc6_src_md = memory::desc( + {fc6_src_tz}, memory::data_type::f32, memory::format_tag::any); + auto fc6_bias_md = memory::desc( + {fc6_bias_tz}, memory::data_type::f32, memory::format_tag::any); + auto fc6_weights_md = memory::desc( + {fc6_weights_tz}, memory::data_type::f32, memory::format_tag::any); + auto fc6_dst_md = memory::desc( + {fc6_dst_tz}, memory::data_type::f32, memory::format_tag::any); // create a inner_product auto fc6_prim_desc = inner_product_forward::primitive_desc(eng, @@ -667,17 +714,23 @@ void simple_net(engine::kind engine_kind, int times = 100) { std::vector fc7_bias(product(fc7_bias_tz)); // create memory for user data - auto fc7_user_weights_memory - = memory({{fc7_weights_tz}, dt::f32, tag::nc}, eng); + auto fc7_user_weights_memory = memory( + {{fc7_weights_tz}, memory::data_type::f32, memory::format_tag::nc}, + eng); write_to_dnnl_memory(fc7_weights.data(), fc7_user_weights_memory); - auto fc7_user_bias_memory = memory({{fc7_bias_tz}, dt::f32, tag::x}, eng); + auto fc7_user_bias_memory = memory( + {{fc7_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); write_to_dnnl_memory(fc7_bias.data(), fc7_user_bias_memory); // create memory descriptors for convolution data w/ no specified format - auto fc7_bias_md = memory::desc({fc7_bias_tz}, dt::f32, tag::any); - auto fc7_weights_md = memory::desc({fc7_weights_tz}, dt::f32, tag::any); - auto fc7_dst_md = memory::desc({fc7_dst_tz}, dt::f32, tag::any); + auto fc7_bias_md = memory::desc( + {fc7_bias_tz}, memory::data_type::f32, memory::format_tag::any); + auto fc7_weights_md = memory::desc( + {fc7_weights_tz}, memory::data_type::f32, memory::format_tag::any); + auto fc7_dst_md = memory::desc( + {fc7_dst_tz}, memory::data_type::f32, memory::format_tag::any); // create a inner_product auto fc7_prim_desc = inner_product_forward::primitive_desc(eng, @@ -709,18 +762,26 @@ void simple_net(engine::kind engine_kind, int times = 100) { std::vector fc8_bias(product(fc8_bias_tz)); // create memory for user data - auto fc8_user_weights_memory - = memory({{fc8_weights_tz}, dt::f32, tag::nc}, eng); + auto fc8_user_weights_memory = memory( + {{fc8_weights_tz}, memory::data_type::f32, memory::format_tag::nc}, + eng); write_to_dnnl_memory(fc8_weights.data(), fc8_user_weights_memory); - auto fc8_user_bias_memory = memory({{fc8_bias_tz}, dt::f32, tag::x}, eng); + auto fc8_user_bias_memory = memory( + {{fc8_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); write_to_dnnl_memory(fc8_bias.data(), fc8_user_bias_memory); - auto user_dst_memory = memory({{fc8_dst_tz}, dt::f32, tag::nc}, eng); + auto user_dst_memory = memory( + {{fc8_dst_tz}, memory::data_type::f32, memory::format_tag::nc}, + eng); 
write_to_dnnl_memory(user_dst.data(), user_dst_memory); // create memory descriptors for convolution data w/ no specified format - auto fc8_bias_md = memory::desc({fc8_bias_tz}, dt::f32, tag::any); - auto fc8_weights_md = memory::desc({fc8_weights_tz}, dt::f32, tag::any); - auto fc8_dst_md = memory::desc({fc8_dst_tz}, dt::f32, tag::any); + auto fc8_bias_md = memory::desc( + {fc8_bias_tz}, memory::data_type::f32, memory::format_tag::any); + auto fc8_weights_md = memory::desc( + {fc8_weights_tz}, memory::data_type::f32, memory::format_tag::any); + auto fc8_dst_md = memory::desc( + {fc8_dst_tz}, memory::data_type::f32, memory::format_tag::any); // create a inner_product auto fc8_prim_desc = inner_product_forward::primitive_desc(eng, diff --git a/examples/cnn_inference_int8.cpp b/examples/cnn_inference_int8.cpp index 1565f18ebca..7cfbe9b0d1b 100644 --- a/examples/cnn_inference_int8.cpp +++ b/examples/cnn_inference_int8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,9 +33,6 @@ using namespace dnnl; void simple_net_int8(engine::kind engine_kind) { - using tag = memory::format_tag; - using dt = memory::data_type; - auto eng = engine(engine_kind, 0); stream s(eng); @@ -89,12 +86,18 @@ void simple_net_int8(engine::kind engine_kind) { /// The user data will be in its original 32-bit floating point format. /// @snippet cnn_inference_int8.cpp Allocate buffers //[Allocate buffers] - auto user_src_memory = memory({{conv_src_tz}, dt::f32, tag::nchw}, eng); + auto user_src_memory = memory( + {{conv_src_tz}, memory::data_type::f32, memory::format_tag::nchw}, + eng); write_to_dnnl_memory(user_src.data(), user_src_memory); auto user_weights_memory - = memory({{conv_weights_tz}, dt::f32, tag::oihw}, eng); + = memory({{conv_weights_tz}, memory::data_type::f32, + memory::format_tag::oihw}, + eng); write_to_dnnl_memory(conv_weights.data(), user_weights_memory); - auto user_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng); + auto user_bias_memory = memory( + {{conv_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); write_to_dnnl_memory(conv_bias.data(), user_bias_memory); //[Allocate buffers] @@ -112,10 +115,14 @@ void simple_net_int8(engine::kind engine_kind) { /// > Bias does not support quantization. 
/// @snippet cnn_inference_int8.cpp Create convolution memory descriptors //[Create convolution memory descriptors] - auto conv_src_md = memory::desc({conv_src_tz}, dt::u8, tag::any); - auto conv_bias_md = memory::desc({conv_bias_tz}, dt::f32, tag::any); - auto conv_weights_md = memory::desc({conv_weights_tz}, dt::s8, tag::any); - auto conv_dst_md = memory::desc({conv_dst_tz}, dt::u8, tag::any); + auto conv_src_md = memory::desc( + {conv_src_tz}, memory::data_type::u8, memory::format_tag::any); + auto conv_bias_md = memory::desc( + {conv_bias_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv_weights_md = memory::desc( + {conv_weights_tz}, memory::data_type::s8, memory::format_tag::any); + auto conv_dst_md = memory::desc( + {conv_dst_tz}, memory::data_type::u8, memory::format_tag::any); //[Create convolution memory descriptors] /// Configuring int8-specific parameters in an int8 primitive is done @@ -129,7 +136,8 @@ void simple_net_int8(engine::kind engine_kind) { conv_attr.set_scales_mask(DNNL_ARG_DST, dst_mask); // Prepare dst scales - auto dst_scale_md = memory::desc({1}, dt::f32, tag::x); + auto dst_scale_md + = memory::desc({1}, memory::data_type::f32, memory::format_tag::x); auto dst_scale_memory = memory(dst_scale_md, eng); write_to_dnnl_memory(dst_scales.data(), dst_scale_memory); //[Configure scaling] @@ -194,7 +202,8 @@ void simple_net_int8(engine::kind engine_kind) { auto conv_src_memory = memory(conv_prim_desc.src_desc(), eng); primitive_attr src_attr; src_attr.set_scales_mask(DNNL_ARG_DST, src_mask); - auto src_scale_md = memory::desc({1}, dt::f32, tag::x); + auto src_scale_md + = memory::desc({1}, memory::data_type::f32, memory::format_tag::x); auto src_scale_memory = memory(src_scale_md, eng); write_to_dnnl_memory(src_scales.data(), src_scale_memory); auto src_reorder_pd @@ -208,7 +217,8 @@ void simple_net_int8(engine::kind engine_kind) { auto conv_weights_memory = memory(conv_prim_desc.weights_desc(), eng); primitive_attr weight_attr; weight_attr.set_scales_mask(DNNL_ARG_DST, weight_mask); - auto wei_scale_md = memory::desc({1}, dt::f32, tag::x); + auto wei_scale_md + = memory::desc({1}, memory::data_type::f32, memory::format_tag::x); auto wei_scale_memory = memory(wei_scale_md, eng); write_to_dnnl_memory(weight_scales.data(), wei_scale_memory); auto weight_reorder_pd @@ -251,7 +261,9 @@ void simple_net_int8(engine::kind engine_kind) { /// computation output data. /// @snippet cnn_inference_int8.cpp Dequantize the result ///[Dequantize the result] - auto user_dst_memory = memory({{conv_dst_tz}, dt::f32, tag::nchw}, eng); + auto user_dst_memory = memory( + {{conv_dst_tz}, memory::data_type::f32, memory::format_tag::nchw}, + eng); write_to_dnnl_memory(user_dst.data(), user_dst_memory); primitive_attr dst_attr; dst_attr.set_scales_mask(DNNL_ARG_SRC, dst_mask); diff --git a/examples/cnn_training_bf16.cpp b/examples/cnn_training_bf16.cpp index 9ef4f8a4d1b..0bcee7201f2 100644 --- a/examples/cnn_training_bf16.cpp +++ b/examples/cnn_training_bf16.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2022 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
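The int8 example above routes all quantization through `dnnl::primitive_attr`. The recurring pattern, with mask 0 meaning a single per-tensor scale and the convention real_value = scale * quantized_value, looks like this in isolation (a minimal sketch; an engine `eng` is assumed):

    using namespace dnnl;
    primitive_attr attr;
    attr.set_scales_mask(DNNL_ARG_DST, 0); // one scale for the whole dst
    // The scale value itself is passed as an f32 memory object at execution
    // time under DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST.
    auto scale_md = memory::desc(
            {1}, memory::data_type::f32, memory::format_tag::x);
    auto scale_mem = memory(scale_md, eng);
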
@@ -38,9 +38,6 @@ using namespace dnnl; void simple_net(engine::kind engine_kind) { - using tag = memory::format_tag; - using dt = memory::data_type; - auto eng = engine(engine_kind, 0); stream s(eng); @@ -79,27 +76,36 @@ void simple_net(engine::kind engine_kind) { conv_bias[i] = sinf((float)i); // create memory for user data - auto conv_user_src_memory - = memory({{conv_src_tz}, dt::f32, tag::nchw}, eng); + auto conv_user_src_memory = memory( + {{conv_src_tz}, memory::data_type::f32, memory::format_tag::nchw}, + eng); write_to_dnnl_memory(net_src.data(), conv_user_src_memory); auto conv_user_weights_memory - = memory({{conv_weights_tz}, dt::f32, tag::oihw}, eng); + = memory({{conv_weights_tz}, memory::data_type::f32, + memory::format_tag::oihw}, + eng); write_to_dnnl_memory(conv_weights.data(), conv_user_weights_memory); - auto conv_user_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng); + auto conv_user_bias_memory = memory( + {{conv_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); write_to_dnnl_memory(conv_bias.data(), conv_user_bias_memory); // create memory descriptors for bfloat16 convolution data w/ no specified // format tag(`any`) // tag `any` lets a primitive(convolution in this case) // chose the memory format preferred for best performance. - auto conv_src_md = memory::desc({conv_src_tz}, dt::bf16, tag::any); - auto conv_weights_md = memory::desc({conv_weights_tz}, dt::bf16, tag::any); - auto conv_dst_md = memory::desc({conv_dst_tz}, dt::bf16, tag::any); + auto conv_src_md = memory::desc( + {conv_src_tz}, memory::data_type::bf16, memory::format_tag::any); + auto conv_weights_md = memory::desc({conv_weights_tz}, + memory::data_type::bf16, memory::format_tag::any); + auto conv_dst_md = memory::desc( + {conv_dst_tz}, memory::data_type::bf16, memory::format_tag::any); // here bias data type is set to bf16. // additionally, f32 data type is supported for bf16 convolution. 
- auto conv_bias_md = memory::desc({conv_bias_tz}, dt::bf16, tag::any); + auto conv_bias_md = memory::desc( + {conv_bias_tz}, memory::data_type::bf16, memory::format_tag::any); // create a convolution primitive descriptor @@ -225,11 +231,13 @@ void simple_net(engine::kind engine_kind) { memory::dims pool_padding = {0, 0}; // create memory for pool dst data in user format - auto pool_user_dst_memory - = memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng); + auto pool_user_dst_memory = memory( + {{pool_dst_tz}, memory::data_type::f32, memory::format_tag::nchw}, + eng); // create pool dst memory descriptor in format any for bfloat16 data type - auto pool_dst_md = memory::desc({pool_dst_tz}, dt::bf16, tag::any); + auto pool_dst_md = memory::desc( + {pool_dst_tz}, memory::data_type::bf16, memory::format_tag::any); // create a pooling primitive descriptor auto pool_pd = pooling_forward::primitive_desc(eng, prop_kind::forward, @@ -269,14 +277,17 @@ void simple_net(engine::kind engine_kind) { net_diff_dst[i] = sinf((float)i); // create memory for user diff dst data stored in float data type - auto pool_user_diff_dst_memory - = memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng); + auto pool_user_diff_dst_memory = memory( + {{pool_dst_tz}, memory::data_type::f32, memory::format_tag::nchw}, + eng); write_to_dnnl_memory(net_diff_dst.data(), pool_user_diff_dst_memory); // Backward pooling // create memory descriptors for pooling - auto pool_diff_src_md = memory::desc({lrn_data_tz}, dt::bf16, tag::any); - auto pool_diff_dst_md = memory::desc({pool_dst_tz}, dt::bf16, tag::any); + auto pool_diff_src_md = memory::desc( + {lrn_data_tz}, memory::data_type::bf16, memory::format_tag::any); + auto pool_diff_dst_md = memory::desc( + {pool_dst_tz}, memory::data_type::bf16, memory::format_tag::any); // backward primitive descriptor needs to hint forward descriptor auto pool_bwd_pd = pooling_backward::primitive_desc(eng, @@ -305,7 +316,8 @@ void simple_net(engine::kind engine_kind) { {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); // Backward lrn - auto lrn_diff_dst_md = memory::desc({lrn_data_tz}, dt::bf16, tag::any); + auto lrn_diff_dst_md = memory::desc( + {lrn_data_tz}, memory::data_type::bf16, memory::format_tag::any); const auto &lrn_diff_src_md = lrn_diff_dst_md; // create backward lrn primitive descriptor @@ -335,8 +347,10 @@ void simple_net(engine::kind engine_kind) { {DNNL_ARG_WORKSPACE, lrn_workspace_memory}}); // Backward relu - auto relu_diff_src_md = memory::desc({relu_data_tz}, dt::bf16, tag::any); - auto relu_diff_dst_md = memory::desc({relu_data_tz}, dt::bf16, tag::any); + auto relu_diff_src_md = memory::desc( + {relu_data_tz}, memory::data_type::bf16, memory::format_tag::any); + auto relu_diff_dst_md = memory::desc( + {relu_data_tz}, memory::data_type::bf16, memory::format_tag::any); auto relu_src_md = conv_pd.dst_desc(); // create backward relu primitive_descriptor @@ -367,14 +381,20 @@ void simple_net(engine::kind engine_kind) { // create user format diff weights and diff bias memory for float data type auto conv_user_diff_weights_memory - = memory({{conv_weights_tz}, dt::f32, tag::nchw}, eng); - auto conv_diff_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng); + = memory({{conv_weights_tz}, memory::data_type::f32, + memory::format_tag::nchw}, + eng); + auto conv_diff_bias_memory = memory( + {{conv_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); // create memory descriptors for bfloat16 convolution data - auto conv_bwd_src_md = memory::desc({conv_src_tz}, dt::bf16, 
tag::any); - auto conv_diff_weights_md - = memory::desc({conv_weights_tz}, dt::bf16, tag::any); - auto conv_diff_dst_md = memory::desc({conv_dst_tz}, dt::bf16, tag::any); + auto conv_bwd_src_md = memory::desc( + {conv_src_tz}, memory::data_type::bf16, memory::format_tag::any); + auto conv_diff_weights_md = memory::desc({conv_weights_tz}, + memory::data_type::bf16, memory::format_tag::any); + auto conv_diff_dst_md = memory::desc( + {conv_dst_tz}, memory::data_type::bf16, memory::format_tag::any); // use diff bias provided by the user auto conv_diff_bias_md = conv_diff_bias_memory.get_desc(); diff --git a/examples/cnn_training_f32.cpp b/examples/cnn_training_f32.cpp index a5aaa5e2b4e..89668569d75 100644 --- a/examples/cnn_training_f32.cpp +++ b/examples/cnn_training_f32.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2022 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,9 +35,6 @@ using namespace dnnl; void simple_net(engine::kind engine_kind) { - using tag = memory::format_tag; - using dt = memory::data_type; - auto eng = engine(engine_kind, 0); stream s(eng); @@ -75,23 +72,32 @@ void simple_net(engine::kind engine_kind) { conv_bias[i] = sinf((float)i); // create memory for user data - auto conv_user_src_memory - = memory({{conv_src_tz}, dt::f32, tag::nchw}, eng); + auto conv_user_src_memory = memory( + {{conv_src_tz}, memory::data_type::f32, memory::format_tag::nchw}, + eng); write_to_dnnl_memory(net_src.data(), conv_user_src_memory); auto conv_user_weights_memory - = memory({{conv_weights_tz}, dt::f32, tag::oihw}, eng); + = memory({{conv_weights_tz}, memory::data_type::f32, + memory::format_tag::oihw}, + eng); write_to_dnnl_memory((void *)conv_weights.data(), conv_user_weights_memory); - auto conv_user_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng); + auto conv_user_bias_memory = memory( + {{conv_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); write_to_dnnl_memory(conv_bias.data(), conv_user_bias_memory); // create memory descriptors for convolution data w/ no specified // format tag(`any`) // tag `any` lets a primitive(convolution in this case) // chose the memory format preferred for best performance. 
- auto conv_src_md = memory::desc({conv_src_tz}, dt::f32, tag::any); - auto conv_bias_md = memory::desc({conv_bias_tz}, dt::f32, tag::any); - auto conv_weights_md = memory::desc({conv_weights_tz}, dt::f32, tag::any); - auto conv_dst_md = memory::desc({conv_dst_tz}, dt::f32, tag::any); + auto conv_src_md = memory::desc( + {conv_src_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv_bias_md = memory::desc( + {conv_bias_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv_weights_md = memory::desc( + {conv_weights_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv_dst_md = memory::desc( + {conv_dst_tz}, memory::data_type::f32, memory::format_tag::any); // create a convolution primitive descriptor auto conv_pd = convolution_forward::primitive_desc(eng, prop_kind::forward, @@ -189,12 +195,14 @@ void simple_net(engine::kind engine_kind) { memory::dims pool_padding = {0, 0}; // create memory for pool dst data in user format - auto pool_user_dst_memory - = memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng); + auto pool_user_dst_memory = memory( + {{pool_dst_tz}, memory::data_type::f32, memory::format_tag::nchw}, + eng); write_to_dnnl_memory(net_dst.data(), pool_user_dst_memory); // create pool dst memory descriptor in format any - auto pool_dst_md = memory::desc({pool_dst_tz}, dt::f32, tag::any); + auto pool_dst_md = memory::desc( + {pool_dst_tz}, memory::data_type::f32, memory::format_tag::any); // create a pooling primitive descriptor auto pool_pd = pooling_forward::primitive_desc(eng, prop_kind::forward, @@ -233,14 +241,17 @@ void simple_net(engine::kind engine_kind) { net_diff_dst[i] = sinf((float)i); // create memory for user diff dst data - auto pool_user_diff_dst_memory - = memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng); + auto pool_user_diff_dst_memory = memory( + {{pool_dst_tz}, memory::data_type::f32, memory::format_tag::nchw}, + eng); write_to_dnnl_memory(net_diff_dst.data(), pool_user_diff_dst_memory); // Backward pooling // create memory descriptors for pooling - auto pool_diff_src_md = memory::desc({lrn_data_tz}, dt::f32, tag::any); - auto pool_diff_dst_md = memory::desc({pool_dst_tz}, dt::f32, tag::any); + auto pool_diff_src_md = memory::desc( + {lrn_data_tz}, memory::data_type::f32, memory::format_tag::any); + auto pool_diff_dst_md = memory::desc( + {pool_dst_tz}, memory::data_type::f32, memory::format_tag::any); // backward primitive descriptor needs to hint forward descriptor auto pool_bwd_pd = pooling_backward::primitive_desc(eng, @@ -269,7 +280,8 @@ void simple_net(engine::kind engine_kind) { {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); // Backward lrn - auto lrn_diff_dst_md = memory::desc({lrn_data_tz}, dt::f32, tag::any); + auto lrn_diff_dst_md = memory::desc( + {lrn_data_tz}, memory::data_type::f32, memory::format_tag::any); const auto &lrn_diff_src_md = lrn_diff_dst_md; // create backward lrn primitive descriptor @@ -299,8 +311,10 @@ void simple_net(engine::kind engine_kind) { {DNNL_ARG_WORKSPACE, lrn_workspace_memory}}); // Backward relu - auto relu_diff_src_md = memory::desc({relu_data_tz}, dt::f32, tag::any); - auto relu_diff_dst_md = memory::desc({relu_data_tz}, dt::f32, tag::any); + auto relu_diff_src_md = memory::desc( + {relu_data_tz}, memory::data_type::f32, memory::format_tag::any); + auto relu_diff_dst_md = memory::desc( + {relu_data_tz}, memory::data_type::f32, memory::format_tag::any); auto relu_src_md = conv_pd.dst_desc(); // create backward relu primitive_descriptor @@ -333,18 +347,25 @@ void 
simple_net(engine::kind engine_kind) { std::vector conv_diff_bias_buffer(product(conv_bias_tz)); auto conv_user_diff_weights_memory - = memory({{conv_weights_tz}, dt::f32, tag::nchw}, eng); + = memory({{conv_weights_tz}, memory::data_type::f32, + memory::format_tag::nchw}, + eng); write_to_dnnl_memory(conv_user_diff_weights_buffer.data(), conv_user_diff_weights_memory); - auto conv_diff_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng); + auto conv_diff_bias_memory = memory( + {{conv_bias_tz}, memory::data_type::f32, memory::format_tag::x}, + eng); write_to_dnnl_memory(conv_diff_bias_buffer.data(), conv_diff_bias_memory); // create memory descriptors - auto conv_bwd_src_md = memory::desc({conv_src_tz}, dt::f32, tag::any); - auto conv_diff_bias_md = memory::desc({conv_bias_tz}, dt::f32, tag::any); - auto conv_diff_weights_md - = memory::desc({conv_weights_tz}, dt::f32, tag::any); - auto conv_diff_dst_md = memory::desc({conv_dst_tz}, dt::f32, tag::any); + auto conv_bwd_src_md = memory::desc( + {conv_src_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv_diff_bias_md = memory::desc( + {conv_bias_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv_diff_weights_md = memory::desc( + {conv_weights_tz}, memory::data_type::f32, memory::format_tag::any); + auto conv_diff_dst_md = memory::desc( + {conv_dst_tz}, memory::data_type::f32, memory::format_tag::any); // create backward convolution primitive descriptor auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc(eng, diff --git a/examples/cpu_matmul_coo.cpp b/examples/cpu_matmul_coo.cpp new file mode 100644 index 00000000000..e16411015ea --- /dev/null +++ b/examples/cpu_matmul_coo.cpp @@ -0,0 +1,108 @@ +/******************************************************************************* +* Copyright 2024-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/// @example cpu_matmul_coo.cpp +/// > Annotated version: @ref cpu_matmul_coo_cpp +/// +/// This C++ API example demonstrates how to create and execute a +/// [MatMul](@ref dev_guide_matmul) primitive that uses a source tensor +/// encoded with the COO sparse encoding. 
+///
+/// @page cpu_matmul_coo_cpp MatMul Primitive Example
+///
+/// @include cpu_matmul_coo.cpp
+
+#include <algorithm>
+#include <cmath>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "dnnl.hpp"
+#include "example_utils.hpp"
+
+using namespace dnnl;
+
+bool check_result(dnnl::memory dst_mem) {
+    // clang-format off
+    const std::vector<float> expected_result = {8.750000, 11.250000, 2.500000,
+                                                6.000000, 2.250000, 3.750000,
+                                                19.000000, 15.500000, 5.250000,
+                                                4.000000, 7.000000, 3.000000};
+    // clang-format on
+    std::vector<float> dst_data(expected_result.size());
+    read_from_dnnl_memory(dst_data.data(), dst_mem);
+    return expected_result == dst_data;
+}
+
+void sparse_matmul() {
+    dnnl::engine engine(engine::kind::cpu, 0);
+
+    const memory::dim M = 4;
+    const memory::dim N = 3;
+    const memory::dim K = 6;
+
+    // A sparse matrix represented in the COO format.
+    std::vector<float> src_coo_values = {2.5f, 1.5f, 1.5f, 2.5f, 2.0f};
+    std::vector<int32_t> src_coo_row_indices = {0, 1, 2, 2, 3};
+    std::vector<int32_t> src_coo_col_indices = {0, 2, 0, 5, 1};
+
+    // clang-format off
+    std::vector<float> weights_data = {3.5f, 4.5f, 1.0f,
+                                       2.0f, 3.5f, 1.5f,
+                                       4.0f, 1.5f, 2.5f,
+                                       3.5f, 5.5f, 4.5f,
+                                       1.5f, 2.5f, 5.5f,
+                                       5.5f, 3.5f, 1.5f};
+    // clang-format on
+
+    const int nnz = static_cast<int>(src_coo_values.size());
+
+    // Create a memory descriptor for COO format by providing information
+    // about number of non-zero entries and data types of metadata.
+    const auto src_coo_md = memory::desc::coo(
+            {M, K}, memory::data_type::f32, nnz, memory::data_type::s32);
+    const auto wei_md = memory::desc(
+            {K, N}, memory::data_type::f32, memory::format_tag::oi);
+    const auto dst_md = memory::desc(
+            {M, N}, memory::data_type::f32, memory::format_tag::nc);
+
+    // This memory is created for the given values and metadata of COO format.
+    memory src_coo_mem(src_coo_md, engine,
+            {src_coo_values.data(), src_coo_row_indices.data(),
+                    src_coo_col_indices.data()});
+    memory wei_mem(wei_md, engine, weights_data.data());
+    memory dst_mem(dst_md, engine);
+
+    dnnl::stream stream(engine);
+
+    auto sparse_matmul_pd
+            = matmul::primitive_desc(engine, src_coo_md, wei_md, dst_md);
+    auto sparse_matmul_prim = matmul(sparse_matmul_pd);
+
+    std::unordered_map<int, memory> sparse_matmul_args;
+    sparse_matmul_args.insert({DNNL_ARG_SRC, src_coo_mem});
+    sparse_matmul_args.insert({DNNL_ARG_WEIGHTS, wei_mem});
+    sparse_matmul_args.insert({DNNL_ARG_DST, dst_mem});
+
+    sparse_matmul_prim.execute(stream, sparse_matmul_args);
+    stream.wait();
+    if (!check_result(dst_mem)) throw std::runtime_error("Unexpected output.");
+}
+
+int main(int argc, char **argv) {
+    return handle_example_errors({engine::kind::cpu}, sparse_matmul);
+}
diff --git a/examples/cpu_matmul_csr.cpp b/examples/cpu_matmul_csr.cpp
index 7033f4aef81..e7823b685bd 100644
--- a/examples/cpu_matmul_csr.cpp
+++ b/examples/cpu_matmul_csr.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2023 Intel Corporation
+* Copyright 2023-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@
 /// > Annotated version: @ref cpu_matmul_csr_cpp
 ///
 /// This C++ API example demonstrates how to create and execute a
-/// [MatMul](@ref dev_guide_matmul) primitive that uses a weights tensor
+/// [MatMul](@ref dev_guide_matmul) primitive that uses a source tensor
 /// encoded with the CSR sparse encoding.
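For reference, the two sparse encodings used by these examples differ only in their metadata. The five COO triplets above place 2.5 at (0, 0), 1.5 at (1, 2), 1.5 at (2, 0), 2.5 at (2, 5), and 2.0 at (3, 1) of the 4x6 source matrix; the two descriptors compare as follows (both signatures taken from the examples):

    // COO: one (row, col) index pair per non-zero value.
    const auto coo_md = memory::desc::coo(
            {M, K}, memory::data_type::f32, nnz, memory::data_type::s32);
    // CSR: per-value column indices plus compressed row pointers, hence
    // two metadata data types.
    const auto csr_md = memory::desc::csr({M, K}, memory::data_type::f32,
            nnz, memory::data_type::s32, memory::data_type::s32);
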
 ///
 /// @page cpu_matmul_csr_cpp MatMul Primitive Example
@@ -36,9 +36,6 @@
 
 using namespace dnnl;
 
-using tag = memory::format_tag;
-using dt = memory::data_type;
-
 bool check_result(dnnl::memory dst_mem) {
     // clang-format off
     const std::vector<float> expected_result = {8.750000, 11.250000, 2.500000,
@@ -77,10 +74,12 @@ void sparse_matmul() {
 
     // Create a memory descriptor for CSR format by providing information
     // about number of non-zero entries and data types of metadata.
-    const auto src_csr_md
-            = memory::desc::csr({M, K}, dt::f32, nnz, dt::s32, dt::s32);
-    const auto wei_md = memory::desc({K, N}, dt::f32, tag::oi);
-    const auto dst_md = memory::desc({M, N}, dt::f32, tag::nc);
+    const auto src_csr_md = memory::desc::csr({M, K}, memory::data_type::f32,
+            nnz, memory::data_type::s32, memory::data_type::s32);
+    const auto wei_md = memory::desc(
+            {K, N}, memory::data_type::f32, memory::format_tag::oi);
+    const auto dst_md = memory::desc(
+            {M, N}, memory::data_type::f32, memory::format_tag::nc);
 
     // This memory is created for the given values and metadata of CSR format.
     memory src_csr_mem(src_csr_md, engine,
diff --git a/examples/cpu_matmul_weights_compression.cpp b/examples/cpu_matmul_weights_compression.cpp
index 1169838b6e5..4bbc772f8c9 100644
--- a/examples/cpu_matmul_weights_compression.cpp
+++ b/examples/cpu_matmul_weights_compression.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2023 Intel Corporation
+* Copyright 2023-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -37,9 +37,6 @@
 
 using namespace dnnl;
 
-using tag = memory::format_tag;
-using dt = memory::data_type;
-
 void matmul_example(dnnl::engine::kind engine_kind) {
 
     // Create execution dnnl::engine.
dnnl::engine engine(engine_kind, 0); @@ -79,22 +76,31 @@ void matmul_example(dnnl::engine::kind engine_kind) { const memory::dim nnz = std::count_if(weights_data.begin(), weights_data.end(), [](float v) { return v != 0.0f; }); - auto src_md = memory::desc(src_dims, dt::f32, tag::ab); - auto dst_md = memory::desc(dst_dims, dt::f32, tag::ab); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::ab); + auto dst_md = memory::desc( + dst_dims, memory::data_type::f32, memory::format_tag::ab); auto src_mem = memory(src_md, engine); auto dst_mem = memory(dst_md, engine); - auto user_src_mem = memory({src_dims, dt::f32, tag::ab}, engine); - auto user_weights_mem = memory({weights_dims, dt::f32, tag::ab}, engine); - auto user_dst_mem = memory({dst_dims, dt::f32, tag::ab}, engine); + auto user_src_mem = memory( + {src_dims, memory::data_type::f32, memory::format_tag::ab}, engine); + auto user_weights_mem = memory( + {weights_dims, memory::data_type::f32, memory::format_tag::ab}, + engine); + auto user_dst_mem = memory( + {dst_dims, memory::data_type::f32, memory::format_tag::ab}, engine); write_to_dnnl_memory(src_data.data(), src_mem); write_to_dnnl_memory(weights_data.data(), user_weights_mem); - auto matmul_src_md = memory::desc(src_dims, dt::u8, tag::any); - auto matmul_weights_md = memory::desc::packed(weights_dims, dt::s8, nnz); - auto matmul_dst_md = memory::desc(dst_dims, dt::u8, tag::any); + auto matmul_src_md = memory::desc( + src_dims, memory::data_type::u8, memory::format_tag::any); + auto matmul_weights_md + = memory::desc::packed(weights_dims, memory::data_type::s8, nnz); + auto matmul_dst_md = memory::desc( + dst_dims, memory::data_type::u8, memory::format_tag::any); matmul::primitive_desc matmul_pd; try { diff --git a/examples/example_utils.hpp b/examples/example_utils.hpp index 136dfe6147f..8ff0676dc77 100644 --- a/examples/example_utils.hpp +++ b/examples/example_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,9 @@ * limitations under the License. *******************************************************************************/ +/// @file +/// Examples C++ Utility Functions + #ifndef EXAMPLE_UTILS_HPP #define EXAMPLE_UTILS_HPP @@ -76,7 +79,7 @@ inline void finalize() { #endif } -dnnl::engine::kind validate_engine_kind(dnnl::engine::kind akind) { +inline dnnl::engine::kind validate_engine_kind(dnnl::engine::kind akind) { // Checking if a GPU exists on the machine if (akind == dnnl::engine::kind::gpu) { if (dnnl::engine::get_count(dnnl::engine::kind::gpu) == 0) { @@ -91,6 +94,7 @@ dnnl::engine::kind validate_engine_kind(dnnl::engine::kind akind) { // Exception class to indicate that the example uses a feature that is not // available on the current systems. It is not treated as an error then, but // just notifies a user. +// NOLINTNEXTLINE(readability-identifier-naming) struct example_allows_unimplemented : public std::exception { example_allows_unimplemented(const char *message) noexcept : message(message) {} @@ -104,7 +108,7 @@ inline const char *engine_kind2str_upper(dnnl::engine::kind kind); // Returns `0` on success, `1` or oneDNN error, and `2` on example error. 
 inline int handle_example_errors(
         std::initializer_list<dnnl::engine::kind> engine_kinds,
-        std::function<void()> example) {
+        const std::function<void()> &example) {
     int exit_code = 0;
 
     try {
diff --git a/examples/graph/gated_mlp.cpp b/examples/graph/gated_mlp.cpp
new file mode 100644
index 00000000000..fd7547a486c
--- /dev/null
+++ b/examples/graph/gated_mlp.cpp
@@ -0,0 +1,275 @@
+/*******************************************************************************
+* Copyright 2024-2025 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+#include <random>
+#include <vector>
+
+#include "oneapi/dnnl/dnnl.hpp"
+#include "oneapi/dnnl/dnnl_graph.hpp"
+
+#include "graph_example_utils.hpp"
+
+using namespace dnnl;
+
+using namespace dnnl::graph;
+using layout_type = logical_tensor::layout_type;
+using dim = logical_tensor::dim;
+using dims = logical_tensor::dims;
+
+struct mlp_dims_t {
+    dim mb;
+    dim ic;
+    dim oc;
+};
+
+static const int min_runs = 4;
+
+// This is adapted from the fill_random() function in matmul_perf.cpp.
+void fill_random(std::vector<float> &out) {
+    static std::vector<float> random_data_f;
+    constexpr size_t nrand = 1037;
+
+    if (random_data_f.empty()) {
+        std::mt19937 generator;
+        std::uniform_real_distribution<float> dist_f(-1.0f, 1.0f);
+
+        random_data_f.resize(nrand);
+        for (auto &d : random_data_f)
+            d = dist_f(generator);
+    }
+
+    for (size_t i = 0; i < out.size(); i += nrand) {
+        size_t chunk = std::min(nrand, out.size() - i);
+        std::memcpy(&out[i], random_data_f.data(), chunk * sizeof(float));
+    }
+}
+
+const char *get_type_string(logical_tensor::data_type dt) {
+    const char *type_string = "unknown";
+
+#define TYPE_CASE(T) \
+    if (dt == logical_tensor::data_type::T) type_string = #T;
+    TYPE_CASE(f16);
+    TYPE_CASE(f32);
+    TYPE_CASE(bf16);
+#undef TYPE_CASE
+
+    return type_string;
+}
+
+void print_test_case(logical_tensor::data_type dt, const mlp_dims_t &p) {
+    std::cout << '[' << std::setw(4) << get_type_string(dt);
+    std::cout << " mb = " << p.mb << ", ic = " << p.ic << ", oc = " << p.oc;
+    std::cout << "] " << std::flush;
+}
+
+void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
+        const mlp_dims_t &p, double time_limit = 0.) {
+    const bool quick_test = (time_limit == 0.);
+    print_test_case(dt, p);
+
+    allocator alloc = create_allocator(ekind);
+
+    // Create execution dnnl::engine.
+    dnnl::engine eng = make_engine_with_allocator(ekind, 0, alloc);
+    // Create dnnl::stream.
+    dnnl::stream strm(eng);
+
+    // input shape
+    const dims src_sz = {p.mb, p.ic};
+    // weight0/weight1 shape: fc_gate and fc_up
+    const dims wei0_sz = {p.ic, p.oc};
+    // hidden shape
+    const dims hd_sz = {p.mb, p.oc};
+    // weight2 shape: fc_down
+    const dims wei2_sz = {p.oc, p.ic};
+    // output shape
+    const dims out_sz = {p.mb, p.ic};
+
+    // Incremental IDs used to create logical tensors and operations.
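The ops created next assemble the gated-MLP computation

    dst = (swish(src · W0) ⊙ (src · W1)) · W2,  with swish(x) = x · sigmoid(x),

where ⊙ is an elementwise multiply; swish is spelled out as a Sigmoid op followed by a Multiply op on the fc_gate output, since the graph is composed from these primitive op kinds.
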
+    size_t id = 0;
+
+    // fc_gate
+    auto src = logical_tensor(id++, dt, src_sz, layout_type::strided);
+    auto wei0 = logical_tensor(id++, dt, wei0_sz, layout_type::strided);
+    auto out0 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto fc_gate = op(id++, op::kind::MatMul, "fc_gate");
+    fc_gate.add_inputs({src, wei0});
+    fc_gate.add_outputs({out0});
+
+    // fc_up
+    auto wei1 = logical_tensor(id++, dt, wei0_sz, layout_type::strided);
+    auto out1 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto fc_up = op(id++, op::kind::MatMul, "fc_up");
+    fc_up.add_inputs({src, wei1});
+    fc_up.add_outputs({out1});
+
+    // activation swish: sigmoid
+    auto out2 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto swi_sig = op(id++, op::kind::Sigmoid, "swish/sigmoid");
+    swi_sig.add_inputs({out0});
+    swi_sig.add_outputs({out2});
+
+    // activation swish: multiply
+    auto out3 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto swi_mul = op(id++, op::kind::Multiply, "swish/multiply");
+    swi_mul.add_inputs({out0, out2});
+    swi_mul.add_outputs({out3});
+
+    // multiplication
+    auto out4 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto mul = op(id++, op::kind::Multiply, "mul");
+    mul.add_inputs({out3, out1});
+    mul.add_outputs({out4});
+
+    // fc_down
+    auto wei2 = logical_tensor(id++, dt, wei2_sz, layout_type::strided);
+    auto dst = logical_tensor(id++, dt, out_sz, layout_type::strided);
+    auto fc_down = op(id++, op::kind::MatMul, "fc_down");
+    fc_down.add_inputs({out4, wei2});
+    fc_down.add_outputs({dst});
+
+    // Construct a gated mlp graph with engine kind and operations.
+    dnnl::graph::graph mlp(ekind);
+    mlp.add_op(fc_gate);
+    mlp.add_op(fc_up);
+    mlp.add_op(swi_sig);
+    mlp.add_op(swi_mul);
+    mlp.add_op(mul);
+    mlp.add_op(fc_down);
+    mlp.finalize();
+
+    // Get partitions from the mlp graph.
+    std::vector<partition> partitions = mlp.get_partitions();
+    // This is just for oneDNN testing purposes.
+    if (partitions.size() != 1) {
+        std::cout << "unsupported mlp" << std::endl;
+        return;
+    }
+
+    // Compile the partition with inputs, outputs, and an engine.
+    compiled_partition cp
+            = partitions[0].compile({src, wei0, wei1, wei2}, {dst}, eng);
+
+    // Create tensor objects
+    auto ts_src = tensor(src, eng);
+    auto ts_wei0 = tensor(wei0, eng);
+    auto ts_wei1 = tensor(wei1, eng);
+    auto ts_wei2 = tensor(wei2, eng);
+    auto ts_dst = tensor(dst, eng);
+
+    // Allocate user data.
+    std::vector<float> src_data(product(src_sz));
+    std::vector<float> wei0_data(product(wei0_sz));
+    std::vector<float> wei1_data(product(wei0_sz));
+    std::vector<float> wei2_data(product(wei2_sz));
+
+    fill_random(src_data);
+    fill_random(wei0_data);
+    fill_random(wei1_data);
+    fill_random(wei2_data);
+
+    // Write data to tensor object's handle.
+    write_to_dnnl_tensor(src_data.data(), ts_src);
+    write_to_dnnl_tensor(wei0_data.data(), ts_wei0);
+    write_to_dnnl_tensor(wei1_data.data(), ts_wei1);
+    write_to_dnnl_tensor(wei2_data.data(), ts_wei2);
+
+    // Warmup run.
+    // Execute the compiled partition of the gated MLP.
+    cp.execute(strm, {ts_src, ts_wei0, ts_wei1, ts_wei2}, {ts_dst});
+
+    // Wait for the computation to finish.
+    strm.wait();
+
+    // First run.
+    auto start_first = std::chrono::steady_clock::now();
+    cp.execute(strm, {ts_src, ts_wei0, ts_wei1, ts_wei2}, {ts_dst});
+    strm.wait();
+    auto end_first = std::chrono::steady_clock::now();
+    std::chrono::duration<double, std::milli> dur_first
+            = end_first - start_first;
+
+    if (quick_test) return;
+
+    // Timing runs.
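The loop below is sized from the first-run latency and the reported average excludes that first run. The same scheme as a standalone helper, for reference (the `avg_ms` name and the std::function wrapper are illustrative, not part of the example):

    #include <chrono>
    #include <functional>

    // Average per-call time in milliseconds over `runs` calls, after one
    // untimed warmup call that absorbs first-run costs.
    inline double avg_ms(const std::function<void()> &step, int runs) {
        step(); // warmup
        auto t0 = std::chrono::steady_clock::now();
        for (int i = 0; i < runs; i++)
            step();
        std::chrono::duration<double, std::milli> d
                = std::chrono::steady_clock::now() - t0;
        return d.count() / runs;
    }
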
+    const int runs = std::max(min_runs, int(time_limit / dur_first.count()));
+    auto start = std::chrono::steady_clock::now();
+    for (int i = 0; i <= runs; i++) {
+        cp.execute(strm, {ts_src, ts_wei0, ts_wei1, ts_wei2}, {ts_dst});
+    }
+    strm.wait();
+    auto end = std::chrono::steady_clock::now();
+    std::chrono::duration<double, std::milli> duration = end - start;
+
+    // Display the results.
+    double avg_time = (duration.count() - dur_first.count()) / runs;
+    std::cout << "graph runs: " << runs + 1 << "; ";
+    std::cout << "avg_time: " << avg_time << " ms" << std::endl;
+}
+
+void bad_args() {
+    std::cerr << "Usage: graph-gated-mlp-cpp [cpu|gpu]\n"
+                 "       graph-gated-mlp-cpp [cpu|gpu] <mb> <ic> <oc>\n\n";
+    throw std::invalid_argument("Incorrect input arguments.");
+}
+
+void bench(engine::kind ekind, dnnl_data_type_t dt, const mlp_dims_t &p,
+        double time_limit = 0.) {
+    try {
+        bench_gated_mlp(ekind, static_cast<logical_tensor::data_type>(dt), p,
+                time_limit);
+        get_mem_pool().clear();
+    } catch (dnnl::error &e) {
+        // Catch and report unimplemented cases.
+        if (e.status == dnnl_unimplemented) {
+            std::cout << "unsupported mlp" << std::endl;
+        } else
+            throw;
+    }
+}
+
+void mlp_perf(engine::kind ekind, int argc, char **argv) {
+    // default testing parameters
+    mlp_dims_t params = {1, 4096, 14336};
+
+    if (argc > 2) {
+        if (argc == 5) {
+            params.mb = std::atoi(argv[2]);
+            params.ic = std::atoi(argv[3]);
+            params.oc = std::atoi(argv[4]);
+        } else {
+            bad_args();
+        }
+
+        if (params.mb <= 0 || params.ic <= 0 || params.oc <= 0) { bad_args(); }
+    }
+
+    bench(ekind, dnnl_f32, params, 2000.0 /*ms*/);
+    bench(ekind, dnnl_bf16, params, 2000.0 /*ms*/);
+    bench(ekind, dnnl_f16, params, 2000.0 /*ms*/);
+}
+
+int main(int argc, char **argv) {
+    return handle_example_errors(
+            mlp_perf, parse_engine_kind(argc, argv, 3), argc, argv);
+}
diff --git a/examples/graph/gated_mlp_int4.cpp b/examples/graph/gated_mlp_int4.cpp
new file mode 100644
index 00000000000..2910dba8712
--- /dev/null
+++ b/examples/graph/gated_mlp_int4.cpp
@@ -0,0 +1,356 @@
+/*******************************************************************************
+* Copyright 2025 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+#include <random>
+#include <vector>
+
+#include "oneapi/dnnl/dnnl.hpp"
+#include "oneapi/dnnl/dnnl_graph.hpp"
+
+#include "graph_example_utils.hpp"
+
+using namespace dnnl;
+
+using namespace dnnl::graph;
+using data_type = logical_tensor::data_type;
+using layout_type = logical_tensor::layout_type;
+using dim = logical_tensor::dim;
+using dims = logical_tensor::dims;
+
+struct mlp_dims_t {
+    dim mb;
+    dim ic;
+    dim oc;
+    dim gr; // group size for int4 group quantization
+};
+
+static const int min_runs = 4;
+
+// This is adapted from the fill_random() function in matmul_perf.cpp.
+void fill_random(std::vector<float> &out) {
+    static std::vector<float> random_data_f;
+    constexpr size_t nrand = 1037;
+
+    if (random_data_f.empty()) {
+        std::mt19937 generator;
+        std::uniform_real_distribution<float> dist_f(-1.0f, 1.0f);
+
+        random_data_f.resize(nrand);
+        for (auto &d : random_data_f)
+            d = dist_f(generator);
+    }
+
+    for (size_t i = 0; i < out.size(); i += nrand) {
+        size_t chunk = std::min(nrand, out.size() - i);
+        std::memcpy(&out[i], random_data_f.data(), chunk * sizeof(float));
+    }
+}
+
+const char *get_type_string(logical_tensor::data_type dt) {
+    const char *type_string = "unknown";
+
+#define TYPE_CASE(T) \
+    if (dt == logical_tensor::data_type::T) type_string = #T;
+    TYPE_CASE(f16);
+    TYPE_CASE(f32);
+    TYPE_CASE(bf16);
+#undef TYPE_CASE
+
+    return type_string;
+}
+
+void print_test_case(logical_tensor::data_type dt, const mlp_dims_t &p) {
+    std::cout << '[' << std::setw(4) << get_type_string(dt);
+    std::cout << " mb = " << p.mb << ", ic = " << p.ic << ", oc = " << p.oc
+              << ", group size = " << p.gr;
+    std::cout << "] " << std::flush;
+}
+
+void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
+        const mlp_dims_t &p, double time_limit = 0.) {
+    const bool quick_test = (time_limit == 0.);
+    print_test_case(dt, p);
+
+    // input shape
+    const dims src_sz = {p.mb, p.ic};
+    // weight0/weight1 shape: fc_gate and fc_up
+    const dims wei0_sz = {p.ic, p.oc};
+    const dims wei0_scales_sz = {p.ic, p.oc / p.gr};
+    // hidden shape
+    const dims hd_sz = {p.mb, p.oc};
+    // weight2 shape: fc_down
+    const dims wei2_sz = {p.oc, p.ic};
+    const dims wei2_scales_sz = {p.oc, p.ic / p.gr};
+    // output shape
+    const dims out_sz = {p.mb, p.ic};
+
+    allocator alloc = create_allocator(ekind);
+
+    // Create execution dnnl::engine.
+    dnnl::engine eng = make_engine_with_allocator(ekind, 0, alloc);
+    // Create dnnl::stream.
+    dnnl::stream strm(eng);
+
+    // Incremental IDs used to create logical tensors and operations.
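+    // Weights below are stored as u4 with per-group scales and zero-points:
+    // with group_shape {1, gr}, each run of gr consecutive elements along
+    // the last weight axis shares one scale and one zero-point, which is why
+    // the scales/zero-points tensors are gr times smaller on that axis.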
+    size_t id = 0;
+
+    // dequantize for fc_gate weights
+    auto wei0_int4 = logical_tensor(
+            id++, data_type::u4, wei0_sz, layout_type::strided);
+    auto wei0_scales
+            = logical_tensor(id++, dt, wei0_scales_sz, layout_type::strided);
+    auto wei0_zps = logical_tensor(
+            id++, data_type::u8, wei0_scales_sz, layout_type::strided);
+    auto wei0_dt = logical_tensor(id++, dt, wei0_sz, layout_type::strided);
+    auto deq_gate = op(id++, op::kind::DynamicDequantize, "deq_gate");
+    deq_gate.set_attr<std::string>(op::attr::qtype, "per_group");
+    deq_gate.set_attr<std::vector<int64_t>>(op::attr::group_shape, {1, p.gr});
+    deq_gate.set_attr<int64_t>(op::attr::axis, -1);
+    deq_gate.add_inputs({wei0_int4, wei0_scales, wei0_zps});
+    deq_gate.add_outputs({wei0_dt});
+
+    // fc_gate
+    auto src = logical_tensor(id++, dt, src_sz, layout_type::strided);
+    auto out0 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto fc_gate = op(id++, op::kind::MatMul, "fc_gate");
+    fc_gate.add_inputs({src, wei0_dt});
+    fc_gate.add_outputs({out0});
+
+    // dequantize for fc_up weights
+    auto wei1_int4 = logical_tensor(
+            id++, data_type::u4, wei0_sz, layout_type::strided);
+    auto wei1_scales
+            = logical_tensor(id++, dt, wei0_scales_sz, layout_type::strided);
+    auto wei1_zps = logical_tensor(
+            id++, data_type::u8, wei0_scales_sz, layout_type::strided);
+    auto wei1_dt = logical_tensor(id++, dt, wei0_sz, layout_type::strided);
+    auto deq_up = op(id++, op::kind::DynamicDequantize, "deq_up");
+    deq_up.set_attr<std::string>(op::attr::qtype, "per_group");
+    deq_up.set_attr<std::vector<int64_t>>(op::attr::group_shape, {1, p.gr});
+    deq_up.set_attr<int64_t>(op::attr::axis, -1);
+    deq_up.add_inputs({wei1_int4, wei1_scales, wei1_zps});
+    deq_up.add_outputs({wei1_dt});
+
+    // fc_up
+    auto out1 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto fc_up = op(id++, op::kind::MatMul, "fc_up");
+    fc_up.add_inputs({src, wei1_dt});
+    fc_up.add_outputs({out1});
+
+    // activation swish: sigmoid
+    auto out2 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto swi_sig = op(id++, op::kind::Sigmoid, "swish/sigmoid");
+    swi_sig.add_inputs({out0});
+    swi_sig.add_outputs({out2});
+
+    // activation swish: multiply
+    auto out3 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto swi_mul = op(id++, op::kind::Multiply, "swish/multiply");
+    swi_mul.add_inputs({out0, out2});
+    swi_mul.add_outputs({out3});
+
+    // multiplication
+    auto out4 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto mul = op(id++, op::kind::Multiply, "mul");
+    mul.add_inputs({out3, out1});
+    mul.add_outputs({out4});
+
+    // dequantize for fc_down weights
+    auto wei2_int4 = logical_tensor(
+            id++, data_type::u4, wei2_sz, layout_type::strided);
+    auto wei2_scales
+            = logical_tensor(id++, dt, wei2_scales_sz, layout_type::strided);
+    auto wei2_zps = logical_tensor(
+            id++, data_type::u8, wei2_scales_sz, layout_type::strided);
+    auto wei2_dt = logical_tensor(id++, dt, wei2_sz, layout_type::strided);
+    auto deq_down = op(id++, op::kind::DynamicDequantize, "deq_down");
+    deq_down.set_attr<std::string>(op::attr::qtype, "per_group");
+    deq_down.set_attr<std::vector<int64_t>>(op::attr::group_shape, {1, p.gr});
+    deq_down.set_attr<int64_t>(op::attr::axis, -1);
+    deq_down.add_inputs({wei2_int4, wei2_scales, wei2_zps});
+    deq_down.add_outputs({wei2_dt});
+
+    // fc_down
+    auto dst = logical_tensor(id++, dt, out_sz, layout_type::strided);
+    auto fc_down = op(id++, op::kind::MatMul, "fc_down");
+    fc_down.add_inputs({out4, wei2_dt});
+    fc_down.add_outputs({dst});
+
+    // Construct a gated mlp graph with engine kind and operations.
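+    // Note: set_fpmath_mode below is called with apply_to_int = true, which
+    // requests that the chosen floating-point math mode also apply to the
+    // integer (int4) weight computation inside the fused partition.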
+    dnnl::graph::graph mlp(ekind);
+    mlp.set_fpmath_mode(fpmath_mode::strict, true);
+    mlp.add_op(deq_gate);
+    mlp.add_op(deq_up);
+    mlp.add_op(fc_gate);
+    mlp.add_op(fc_up);
+    mlp.add_op(swi_sig);
+    mlp.add_op(swi_mul);
+    mlp.add_op(mul);
+    mlp.add_op(deq_down);
+    mlp.add_op(fc_down);
+    mlp.finalize();
+
+    // Get partitions from the mlp graph.
+    std::vector<partition> partitions = mlp.get_partitions();
+    // This is just for oneDNN testing purposes.
+    if (partitions.size() != 1) {
+        std::cout << "unsupported mlp" << std::endl;
+        return;
+    }
+
+    // Compile the partition with inputs, outputs, and an engine.
+    compiled_partition cp = partitions[0].compile(
+            {src, wei0_int4, wei0_scales, wei0_zps, wei1_int4, wei1_scales,
+                    wei1_zps, wei2_int4, wei2_scales, wei2_zps},
+            {dst}, eng);
+
+    // Create tensor objects
+    auto ts_src = tensor(src, eng);
+    auto ts_wei0 = tensor(wei0_int4, eng);
+    auto ts_wei0_scales = tensor(wei0_scales, eng);
+    auto ts_wei0_zps = tensor(wei0_zps, eng);
+    auto ts_wei1 = tensor(wei1_int4, eng);
+    auto ts_wei1_scales = tensor(wei1_scales, eng);
+    auto ts_wei1_zps = tensor(wei1_zps, eng);
+    auto ts_wei2 = tensor(wei2_int4, eng);
+    auto ts_wei2_scales = tensor(wei2_scales, eng);
+    auto ts_wei2_zps = tensor(wei2_zps, eng);
+    auto ts_dst = tensor(dst, eng);
+
+    // Allocate user data.
+    std::vector<float> src_data(product(src_sz));
+    std::vector<float> wei0_data(product(wei0_sz));
+    std::vector<float> wei1_data(product(wei0_sz));
+    std::vector<float> wei2_data(product(wei2_sz));
+
+    fill_random(src_data);
+    fill_random(wei0_data);
+    fill_random(wei1_data);
+    fill_random(wei2_data);
+
+    // Write data to tensor object's handle.
+    write_to_dnnl_tensor(src_data.data(), ts_src);
+    write_to_dnnl_tensor(wei0_data.data(), ts_wei0);
+    write_to_dnnl_tensor(wei1_data.data(), ts_wei1);
+    write_to_dnnl_tensor(wei2_data.data(), ts_wei2);
+
+    // Warmup run.
+    // Execute the compiled partition of mlp.
+    // TODO: initialize the scales and zps.
+    cp.execute(strm,
+            {ts_src, ts_wei0, ts_wei0_scales, ts_wei0_zps, ts_wei1,
+                    ts_wei1_scales, ts_wei1_zps, ts_wei2, ts_wei2_scales,
+                    ts_wei2_zps},
+            {ts_dst});
+
+    // Wait for the computation to finish.
+    strm.wait();
+
+    // First run.
+    auto start_first = std::chrono::steady_clock::now();
+    cp.execute(strm,
+            {ts_src, ts_wei0, ts_wei0_scales, ts_wei0_zps, ts_wei1,
+                    ts_wei1_scales, ts_wei1_zps, ts_wei2, ts_wei2_scales,
+                    ts_wei2_zps},
+            {ts_dst});
+    strm.wait();
+    auto end_first = std::chrono::steady_clock::now();
+    std::chrono::duration<double, std::milli> dur_first
+            = end_first - start_first;
+
+    if (quick_test) return;
+
+    // Timing runs.
+    const int runs = std::max(min_runs, int(time_limit / dur_first.count()));
+    auto start = std::chrono::steady_clock::now();
+    for (int i = 0; i <= runs; i++) {
+        cp.execute(strm,
+                {ts_src, ts_wei0, ts_wei0_scales, ts_wei0_zps, ts_wei1,
+                        ts_wei1_scales, ts_wei1_zps, ts_wei2, ts_wei2_scales,
+                        ts_wei2_zps},
+                {ts_dst});
+    }
+    strm.wait();
+    auto end = std::chrono::steady_clock::now();
+    std::chrono::duration<double, std::milli> duration = end - start;
+
+    // Display the results.
+    double avg_time = (duration.count() - dur_first.count()) / runs;
+    std::cout << "graph runs: " << runs + 1 << "; ";
+    std::cout << "avg_time: " << avg_time << " ms" << std::endl;
+}
+
+void bad_args() {
+    std::cerr << "Usage: graph-gated-mlp-int4-cpp [cpu|gpu]\n"
+                 "       graph-gated-mlp-int4-cpp [cpu|gpu] "
+                 "<mb> <ic> <oc> <gr>\n\n";
+    throw std::invalid_argument("Incorrect input arguments.");
+}
+
+void bench(engine::kind ekind, dnnl_data_type_t dt, const mlp_dims_t &p,
+        double time_limit = 0.) {
+    try {
+        bench_gated_mlp(ekind, static_cast<logical_tensor::data_type>(dt), p,
+                time_limit);
+        get_mem_pool().clear();
+    } catch (dnnl::error &e) {
+        // Catch and report unimplemented cases.
+        if (e.status == dnnl_unimplemented) {
+            std::cout << "unsupported mlp" << std::endl;
+        } else
+            throw;
+    }
+}
+
+void mlp_perf(engine::kind ekind, int argc, char **argv) {
+    // default testing parameters
+    mlp_dims_t params = {1, 4096, 14336, 128};
+
+    if (argc > 2) {
+        if (argc == 6) {
+            params.mb = std::atoi(argv[2]);
+            params.ic = std::atoi(argv[3]);
+            params.oc = std::atoi(argv[4]);
+            params.gr = std::atoi(argv[5]);
+        } else {
+            bad_args();
+        }
+
+        if (params.mb <= 0 || params.ic <= 0 || params.oc <= 0
+                || params.gr <= 0) {
+            bad_args();
+        }
+
+        if (params.ic < params.gr || params.oc < params.gr
+                || params.ic % params.gr != 0 || params.oc % params.gr != 0) {
+            bad_args();
+        }
+    }
+
+    bench(ekind, dnnl_f32, params, 2000.0 /*ms*/);
+    bench(ekind, dnnl_bf16, params, 2000.0 /*ms*/);
+    bench(ekind, dnnl_f16, params, 2000.0 /*ms*/);
+}
+
+int main(int argc, char **argv) {
+    return handle_example_errors(
+            mlp_perf, parse_engine_kind(argc, argv, 4), argc, argv);
+}
diff --git a/examples/graph/gated_mlp_wei_combined.cpp b/examples/graph/gated_mlp_wei_combined.cpp
new file mode 100644
index 00000000000..2d3a6c88071
--- /dev/null
+++ b/examples/graph/gated_mlp_wei_combined.cpp
@@ -0,0 +1,300 @@
+/*******************************************************************************
+* Copyright 2024-2025 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <algorithm>
+#include <cassert>
+#include <chrono>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+#include <random>
+#include <vector>
+
+#include "oneapi/dnnl/dnnl.hpp"
+#include "oneapi/dnnl/dnnl_graph.hpp"
+
+#include "graph_example_utils.hpp"
+
+using namespace dnnl;
+
+using namespace dnnl::graph;
+using layout_type = logical_tensor::layout_type;
+using dim = logical_tensor::dim;
+using dims = logical_tensor::dims;
+
+struct mlp_dims_t {
+    dim mb;
+    dim ic;
+    dim oc;
+};
+
+static const int min_runs = 4;
+
+// this is changed from the fill_random() function in matmul_perf.cpp.
+void fill_random(std::vector<float> &out) {
+    static std::vector<float> random_data_f;
+    constexpr size_t nrand = 1037;
+
+    if (random_data_f.empty()) {
+        std::mt19937 generator;
+        std::uniform_real_distribution<float> dist_f(-1.0f, 1.0f);
+
+        random_data_f.resize(nrand);
+        for (auto &d : random_data_f)
+            d = dist_f(generator);
+    }
+
+    for (size_t i = 0; i < out.size(); i += nrand) {
+        size_t chunk = std::min(nrand, out.size() - i);
+        std::memcpy(&out[i], random_data_f.data(), chunk * sizeof(float));
+    }
+}
+
+const char *get_type_string(logical_tensor::data_type dt) {
+    const char *type_string = "unknown";
+
+#define TYPE_CASE(T) \
+    if (dt == logical_tensor::data_type::T) type_string = #T;
+    TYPE_CASE(f16);
+    TYPE_CASE(f32);
+    TYPE_CASE(bf16);
+#undef TYPE_CASE
+
+    return type_string;
+}
+
+size_t size_of(logical_tensor::data_type dt) {
+    // This example only supports f32, bf16, and f16.
+    switch (dt) {
+        case logical_tensor::data_type::f32: return 4;
+        case logical_tensor::data_type::bf16:
+        case logical_tensor::data_type::f16: return 2;
+        default: assert(!"unknown data_type");
+    }
+
+    return (size_t)-1; /* not supposed to be reachable */
+}
+
+void print_test_case(logical_tensor::data_type dt, const mlp_dims_t &p) {
+    std::cout << '[' << std::setw(4) << get_type_string(dt);
+    std::cout << " mb = " << p.mb << ", ic = " << p.ic << ", oc = " << p.oc;
+    std::cout << "] " << std::flush;
+}
+
+void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
+        const mlp_dims_t &p, double time_limit = 0.) {
+    const bool quick_test = (time_limit == 0.);
+    print_test_case(dt, p);
+
+    allocator alloc = create_allocator(ekind);
+
+    // Create execution dnnl::engine.
+    dnnl::engine eng = make_engine_with_allocator(ekind, 0, alloc);
+    // Create dnnl::stream.
+    dnnl::stream strm(eng);
+
+    // input shape
+    const dims src_sz = {p.mb, p.ic};
+    // weight0/weight1 shape: fc_gate and fc_up
+    const dims wei0_sz = {p.ic, p.oc};
+    // hidden shape
+    const dims hd_sz = {p.mb, p.oc};
+    // weight2 shape: fc_down
+    const dims wei2_sz = {p.oc, p.ic};
+    // output shape
+    const dims out_sz = {p.mb, p.ic};
+
+    // wei0 and wei1 are combined into shape (ic, 2 * oc), assuming the
+    // first part is wei0 for fc_gate and the second part is wei1 for fc_up.
+    const dims combined_wei0_sz = {p.ic, 2 * p.oc};
+    const dims combined_wei0_st = {2 * p.oc, 1};
+
+    // Incremental IDs used to create logical tensors and operations.
+    size_t id = 0;
+
+    // This logical tensor is not part of the graph but is used to generate
+    // the big chunk of device memory, which would already exist in a real
+    // user application or framework.
+    auto combined_wei0
+            = logical_tensor(id++, dt, combined_wei0_sz, layout_type::strided);
+
+    // fc_gate: wei0 is non-contiguous now.
+    auto src = logical_tensor(id++, dt, src_sz, layout_type::strided);
+    auto wei0 = logical_tensor(id++, dt, wei0_sz, combined_wei0_st);
+    auto out0 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto fc_gate = op(id++, op::kind::MatMul, "fc_gate");
+    fc_gate.add_inputs({src, wei0});
+    fc_gate.add_outputs({out0});
+
+    // fc_up: wei1 is non-contiguous now.
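+    // wei1 aliases the second (ic, oc) half of the combined (ic, 2 * oc)
+    // buffer: it keeps the same row stride of 2 * oc elements, and its
+    // starting offset of oc elements is applied later via the data handle.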
+    auto wei1 = logical_tensor(id++, dt, wei0_sz, combined_wei0_st);
+    auto out1 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto fc_up = op(id++, op::kind::MatMul, "fc_up");
+    fc_up.add_inputs({src, wei1});
+    fc_up.add_outputs({out1});
+
+    // activation swish: sigmoid
+    auto out2 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto swi_sig = op(id++, op::kind::Sigmoid, "swish/sigmoid");
+    swi_sig.add_inputs({out0});
+    swi_sig.add_outputs({out2});
+
+    // activation swish: multiply
+    auto out3 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto swi_mul = op(id++, op::kind::Multiply, "swish/multiply");
+    swi_mul.add_inputs({out0, out2});
+    swi_mul.add_outputs({out3});
+
+    // multiplication
+    auto out4 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto mul = op(id++, op::kind::Multiply, "mul");
+    mul.add_inputs({out3, out1});
+    mul.add_outputs({out4});
+
+    // fc_down
+    auto wei2 = logical_tensor(id++, dt, wei2_sz, layout_type::strided);
+    auto dst = logical_tensor(id++, dt, out_sz, layout_type::strided);
+    auto fc_down = op(id++, op::kind::MatMul, "fc_down");
+    fc_down.add_inputs({out4, wei2});
+    fc_down.add_outputs({dst});
+
+    // Construct a gated mlp graph with engine kind and operations.
+    dnnl::graph::graph mlp(ekind);
+    mlp.add_op(fc_gate);
+    mlp.add_op(fc_up);
+    mlp.add_op(swi_sig);
+    mlp.add_op(swi_mul);
+    mlp.add_op(mul);
+    mlp.add_op(fc_down);
+    mlp.finalize();
+
+    // Get partitions from the mlp graph.
+    std::vector<partition> partitions = mlp.get_partitions();
+    // This is just for oneDNN testing purposes.
+    if (partitions.size() != 1) {
+        std::cout << "unsupported mlp" << std::endl;
+        return;
+    }
+
+    // Compile the partition with inputs, outputs, and an engine.
+    compiled_partition cp
+            = partitions[0].compile({src, wei0, wei1, wei2}, {dst}, eng);
+
+    // Create tensor objects
+    auto ts_src = tensor(src, eng);
+    auto ts_combined_wei0 = tensor(combined_wei0, eng);
+    auto ts_wei2 = tensor(wei2, eng);
+    auto ts_dst = tensor(dst, eng);
+
+    // Allocate user data.
+    std::vector<float> src_data(product(src_sz));
+    std::vector<float> combined_wei0_data(product(combined_wei0_sz));
+    std::vector<float> wei2_data(product(wei2_sz));
+
+    fill_random(src_data);
+    fill_random(combined_wei0_data);
+    fill_random(wei2_data);
+
+    // Write data to tensor object's handle.
+    write_to_dnnl_tensor(src_data.data(), ts_src);
+    write_to_dnnl_tensor(combined_wei0_data.data(), ts_combined_wei0);
+    write_to_dnnl_tensor(wei2_data.data(), ts_wei2);
+
+    // create ts_wei0, ts_wei1 from the data handle of combined_wei0 and
+    // offsets.
+    char *handle
+            = reinterpret_cast<char *>(ts_combined_wei0.get_data_handle());
+    auto ts_wei0 = tensor(wei0, eng, handle);
+    auto ts_wei1 = tensor(wei1, eng, handle + p.oc * size_of(dt));
+
+    // Warmup run.
+    // Execute the compiled partition of mlp.
+    cp.execute(strm, {ts_src, ts_wei0, ts_wei1, ts_wei2}, {ts_dst});
+
+    // Wait for the computation to finish.
+    strm.wait();
+
+    // First run.
+    auto start_first = std::chrono::steady_clock::now();
+    cp.execute(strm, {ts_src, ts_wei0, ts_wei1, ts_wei2}, {ts_dst});
+    strm.wait();
+    auto end_first = std::chrono::steady_clock::now();
+    std::chrono::duration<double, std::milli> dur_first
+            = end_first - start_first;
+
+    if (quick_test) return;
+
+    // Timing runs.
+    const int runs = std::max(min_runs, int(time_limit / dur_first.count()));
+    auto start = std::chrono::steady_clock::now();
+    for (int i = 0; i <= runs; i++) {
+        cp.execute(strm, {ts_src, ts_wei0, ts_wei1, ts_wei2}, {ts_dst});
+    }
+    strm.wait();
+    auto end = std::chrono::steady_clock::now();
+    std::chrono::duration<double, std::milli> duration = end - start;
+
+    // Display the results.
+    double avg_time = (duration.count() - dur_first.count()) / runs;
+    std::cout << "graph runs: " << runs + 1 << "; ";
+    std::cout << "avg_time: " << avg_time << " ms" << std::endl;
+}
+
+void bad_args() {
+    std::cerr << "Usage: graph-gated-mlp-wei-combined-cpp [cpu|gpu]\n"
+                 "       graph-gated-mlp-wei-combined-cpp [cpu|gpu] "
+                 "<mb> <ic> <oc>\n\n";
+    throw std::invalid_argument("Incorrect input arguments.");
+}
+
+void bench(engine::kind ekind, dnnl_data_type_t dt, const mlp_dims_t &p,
+        double time_limit = 0.) {
+    try {
+        bench_gated_mlp(ekind, static_cast<logical_tensor::data_type>(dt), p,
+                time_limit);
+        get_mem_pool().clear();
+    } catch (dnnl::error &e) {
+        // Catch and report unimplemented cases.
+        if (e.status == dnnl_unimplemented) {
+            std::cout << "unsupported mlp" << std::endl;
+        } else
+            throw;
+    }
+}
+
+void mlp_perf(engine::kind ekind, int argc, char **argv) {
+    // default testing parameters
+    mlp_dims_t params = {1, 4096, 14336};
+
+    if (argc > 2) {
+        if (argc == 5) {
+            params.mb = std::atoi(argv[2]);
+            params.ic = std::atoi(argv[3]);
+            params.oc = std::atoi(argv[4]);
+        } else {
+            bad_args();
+        }
+
+        if (params.mb <= 0 || params.ic <= 0 || params.oc <= 0) { bad_args(); }
+    }
+
+    bench(ekind, dnnl_f32, params, 2000.0 /*ms*/);
+    bench(ekind, dnnl_bf16, params, 2000.0 /*ms*/);
+    bench(ekind, dnnl_f16, params, 2000.0 /*ms*/);
+}
+
+int main(int argc, char **argv) {
+    return handle_example_errors(
+            mlp_perf, parse_engine_kind(argc, argv, 3), argc, argv);
+}
diff --git a/examples/graph/gqa.cpp b/examples/graph/gqa.cpp
index 3c60bfd1d9d..0f0e244116d 100644
--- a/examples/graph/gqa.cpp
+++ b/examples/graph/gqa.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2024 Intel Corporation
+* Copyright 2024-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -29,7 +29,6 @@
 #include "graph_example_utils.hpp"
 
 using namespace dnnl;
-using tag = memory::format_tag;
 
 using namespace dnnl::graph;
 using layout_type = logical_tensor::layout_type;
diff --git a/examples/graph/graph_example_utils.hpp b/examples/graph/graph_example_utils.hpp
index 02a9a844812..4671aa36d9d 100644
--- a/examples/graph/graph_example_utils.hpp
+++ b/examples/graph/graph_example_utils.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2023-2024 Intel Corporation
+* Copyright 2023-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -40,7 +40,8 @@
 /// @param partitions a list of partitions
 /// @param id_to_set_any_layout a set of ids of logical tensors with any layout
 /// type
-void set_any_layout(const std::vector<dnnl::graph::partition> &partitions,
+inline void set_any_layout(
+        const std::vector<dnnl::graph::partition> &partitions,
         std::unordered_set<size_t> &id_to_set_any_layout) {
     // mapping from output tensor id to the all supported flags of
     // supported partitions, we may only need outputs' supported flags
@@ -104,30 +105,30 @@ void set_any_layout(const std::vector<dnnl::graph::partition> &partitions,
     }
 }
 
-struct cpu_deletor {
-    cpu_deletor() = default;
+struct cpu_deletor_t {
+    cpu_deletor_t() = default;
     void operator()(void *ptr) {
         if (ptr) free(ptr);
     }
 };
 
 #ifdef DNNL_WITH_SYCL
-struct sycl_deletor {
-    sycl_deletor() = delete;
+struct sycl_deletor_t {
+    sycl_deletor_t() = delete;
     ::sycl::context ctx_;
-    sycl_deletor(const ::sycl::context &ctx) : ctx_(ctx) {}
+    sycl_deletor_t(const ::sycl::context &ctx) : ctx_(ctx) {}
     void operator()(void *ptr) {
         if (ptr) ::sycl::free(ptr, ctx_);
     }
 };
 
-void *sycl_malloc_wrapper(
+inline void *sycl_malloc_wrapper(
         size_t size, size_t alignment, const void *dev, const void *ctx) {
     return malloc_shared(size, *static_cast<const sycl::device *>(dev),
             *static_cast<const sycl::context *>(ctx));
 }
 
-void sycl_free_wrapper(
+inline void sycl_free_wrapper(
         void *ptr, const void *device, const void *context, void *event) {
     // Device is not used in this example, but it may be useful for some users
     // application.
@@ -142,7 +143,7 @@ void sycl_free_wrapper(
 }
 #endif
 
-void allocate_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
+inline void allocate_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
         const std::vector<dnnl::graph::logical_tensor> &lts,
         std::vector<std::shared_ptr<void>> &data_buffer,
         const dnnl::engine &eng) {
@@ -152,14 +153,14 @@ void allocate_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
 
         // memory allocation
         data_buffer.push_back({});
-        data_buffer.back().reset(malloc(mem_size), cpu_deletor {});
+        data_buffer.back().reset(malloc(mem_size), cpu_deletor_t {});
 
         dnnl::graph::tensor new_ts {lt, eng, data_buffer.back().get()};
         tensors.push_back(new_ts);
     }
 }
 
-void allocate_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
+inline void allocate_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
        const std::vector<dnnl::graph::logical_tensor> &lts,
        std::vector<std::shared_ptr<void>> &data_buffer,
        std::unordered_map<size_t, dnnl::graph::tensor> &global_outputs_ts_map,
@@ -180,7 +181,7 @@ void allocate_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
 
         // memory allocation
         data_buffer.push_back({});
-        data_buffer.back().reset(malloc(mem_size), cpu_deletor {});
+        data_buffer.back().reset(malloc(mem_size), cpu_deletor_t {});
 
         dnnl::graph::tensor new_ts {lt, eng, data_buffer.back().get()};
         tensors.push_back(new_ts);
@@ -191,7 +192,7 @@ void allocate_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
 }
 
 #ifdef DNNL_WITH_SYCL
-void allocate_sycl_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
+inline void allocate_sycl_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
         const std::vector<dnnl::graph::logical_tensor> &lts,
         std::vector<std::shared_ptr<void>> &data_buffer,
         sycl::queue &q, const dnnl::engine &eng) {
@@ -203,14 +204,14 @@ void allocate_sycl_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
         data_buffer.push_back({});
         data_buffer.back().reset(::sycl::malloc_shared(mem_size,
                                          q.get_device(), q.get_context()),
-                sycl_deletor {q.get_context()});
+                sycl_deletor_t {q.get_context()});
 
         dnnl::graph::tensor new_ts {lt, eng, data_buffer.back().get()};
         tensors.push_back(new_ts);
     }
 }
 
-void allocate_sycl_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
+inline void allocate_sycl_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
        const std::vector<dnnl::graph::logical_tensor> &lts,
        std::vector<std::shared_ptr<void>> &data_buffer,
        std::unordered_map<size_t, dnnl::graph::tensor> &global_outputs_ts_map,
@@ -233,7 +234,7 @@ void allocate_sycl_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
         data_buffer.push_back({});
         data_buffer.back().reset(::sycl::malloc_shared(mem_size,
                                          q.get_device(), q.get_context()),
-                sycl_deletor {q.get_context()});
+                sycl_deletor_t {q.get_context()});
 
         dnnl::graph::tensor new_ts {lt, eng, data_buffer.back().get()};
         tensors.push_back(new_ts);
@@ -292,7 +293,7 @@ static void *ocl_malloc_device(
 }
 
 static void ocl_free(
-        void *ptr, cl_device_id dev, const cl_context ctx, cl_event event) {
+        void *ptr, cl_device_id dev, cl_context ctx, cl_event event) {
     if (nullptr == ptr) return;
     using F = cl_int (*)(cl_context, void *);
     if (event) { OCL_CHECK(clWaitForEvents(1, &event)); }
@@ -305,7 +306,7 @@ static void ocl_free(
     OCL_CHECK(f(ctx, ptr));
 }
 
-void allocate_ocl_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
+inline void allocate_ocl_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
        const std::vector<dnnl::graph::logical_tensor> &lts,
        std::vector<std::shared_ptr<void>> &data_buffer,
        std::unordered_map<size_t, dnnl::graph::tensor> &global_outputs_ts_map,
@@ -341,7 +342,8 @@ void allocate_ocl_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
     }
 }
 
-void ocl_memcpy(dnnl::engine &eng, void *dst, const void *src, size_t size) {
+inline void ocl_memcpy(
+        dnnl::engine &eng, void *dst, const void *src, size_t size) {
     using F = cl_int (*)(cl_command_queue, cl_bool, void *, const void *,
             size_t, cl_uint, const cl_event *, cl_event *);
     if (!src || !dst) return;
@@ -370,8 +372,6 @@ void ocl_memcpy(dnnl::engine &eng, void *dst, const void *src, size_t size) {
     err = f(queue, CL_FALSE, dst, src, size, 0, nullptr, nullptr);
     if (err != CL_SUCCESS)
         throw std::runtime_error("clEnqueueMemcpyINTEL failed");
-
-    return;
 }
 #endif
 
@@ -525,7 +525,7 @@ class simple_memory_pool_t {
 #ifdef DNNL_WITH_SYCL
             auto sh_ptr = std::shared_ptr<void> {
                     sycl_malloc_wrapper(size, alignment, dev, ctx),
-                    sycl_deletor {*static_cast<const ::sycl::context *>(ctx)}};
+                    sycl_deletor_t {*static_cast<const ::sycl::context *>(ctx)}};
 #endif
 
 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
@@ -535,7 +535,7 @@ class simple_memory_pool_t {
 #endif
             ptr = sh_ptr.get();
             // record the map of mm size and its ptr for reuse
-            map_size_ptr_.emplace(std::make_pair(size, sh_ptr));
+            map_size_ptr_.emplace(size, sh_ptr);
             is_free_ptr_[ptr] = false;
         }
         return ptr;
     }
@@ -562,10 +562,11 @@ class simple_memory_pool_t {
             }
         }
         if (need_alloc_new_mm) {
-            auto sh_ptr = std::shared_ptr<void> {malloc(size), cpu_deletor {}};
+            auto sh_ptr
+                    = std::shared_ptr<void> {malloc(size), cpu_deletor_t {}};
             ptr = sh_ptr.get();
             // record the map of mm size and its ptr for reuse
-            map_size_ptr_.emplace(std::make_pair(size, sh_ptr));
+            map_size_ptr_.emplace(size, sh_ptr);
             is_free_ptr_[ptr] = false;
         }
         return ptr;
     }
@@ -575,27 +576,26 @@ class simple_memory_pool_t {
     void deallocate(
             void *ptr, const void *device, const void *context, void *event) {
         std::lock_guard<std::mutex> pool_guard(pool_lock);
-        if (event) {
-            auto sycl_deps_ptr = static_cast<::sycl::event *>(event);
-            sycl_deps_ptr->wait();
-        }
+        // This example currently uses in-order queues, so kernels are
+        // executed in the order in which they are submitted; there is no
+        // need to wait on the event.
         is_free_ptr_[ptr] = true;
         return;
     }
 #endif
 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
     void deallocate(
-            void *ptr, cl_device_id dev, const cl_context ctx, cl_event event) {
+            void *ptr, cl_device_id dev, cl_context ctx, cl_event event) {
         std::lock_guard<std::mutex> pool_guard(pool_lock);
-        if (event) { OCL_CHECK(clWaitForEvents(1, &event)); }
+        // This example currently uses in-order queues, so kernels are
+        // executed in the order in which they are submitted; there is no
+        // need to wait on the event.
         is_free_ptr_[ptr] = true;
-        return;
     }
 #endif
     void deallocate_host(void *ptr) {
         std::lock_guard<std::mutex> pool_guard(pool_lock);
         is_free_ptr_[ptr] = true;
-        return;
     }
     void clear() {
         dnnl::graph::set_compiled_partition_cache_capacity(0);
@@ -609,12 +609,12 @@ class simple_memory_pool_t {
     std::unordered_map<void *, bool> is_free_ptr_;
 };
 
-simple_memory_pool_t &get_mem_pool() {
+inline simple_memory_pool_t &get_mem_pool() {
     static simple_memory_pool_t mem_pool;
     return mem_pool;
 }
 
-dnnl::graph::allocator create_allocator(dnnl::engine::kind ekind) {
+inline dnnl::graph::allocator create_allocator(dnnl::engine::kind ekind) {
     if (ekind == dnnl::engine::kind::cpu) {
 #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL
         auto alloc_func = [](size_t size, size_t alignment, const void *dev,
@@ -653,8 +653,8 @@ dnnl::graph::allocator create_allocator(dnnl::engine::kind ekind) {
                 cl_context ctx) -> void * {
             return get_mem_pool().allocate(size, alignment, dev, ctx);
         };
-        auto dealloc_func = [](void *ptr, cl_device_id dev,
-                                    const cl_context ctx, cl_event event) {
+        auto dealloc_func = [](void *ptr, cl_device_id dev, cl_context ctx,
+                                    cl_event event) {
             return get_mem_pool().deallocate(ptr, dev, ctx, event);
         };
         return dnnl::graph::ocl_interop::make_allocator(
diff --git a/examples/graph/mqa.cpp b/examples/graph/mqa.cpp
index 3f35b4684de..4097d356a00 100644
--- a/examples/graph/mqa.cpp
+++ b/examples/graph/mqa.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2024 Intel Corporation
+* Copyright 2024-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -29,7 +29,6 @@
 #include "graph_example_utils.hpp"
 
 using namespace dnnl;
-using tag = memory::format_tag;
 
 using namespace dnnl::graph;
 using layout_type = logical_tensor::layout_type;
diff --git a/examples/graph/sdpa.cpp b/examples/graph/sdpa.cpp
index 3b86c51808c..99f3dc2c64b 100644
--- a/examples/graph/sdpa.cpp
+++ b/examples/graph/sdpa.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2024 Intel Corporation
+* Copyright 2024-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -29,7 +29,6 @@
 #include "graph_example_utils.hpp"
 
 using namespace dnnl;
-using tag = memory::format_tag;
 
 using namespace dnnl::graph;
 using layout_type = logical_tensor::layout_type;
@@ -97,6 +96,9 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt,
     // Create dnnl::stream.
     dnnl::stream strm(eng);
 
+    // Intermediate data type
+    const memory::data_type dt_inter = memory::data_type::f32;
+
     // Prepare input and output shapes to construct the sdpa graph.
     const memory::dims q_sz = {p.mb, p.head_num, p.query_num, p.head_size};
     const memory::dims k_sz = {p.mb, p.head_num, p.head_size, p.seq_len};
@@ -109,11 +111,12 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt,
     //   scaled_score = score / scale
     //   masked_score = scaled_score + mask
     // All combined in a single matmul primitive.
- auto query_md = memory::desc(q_sz, dt, tag::abcd); - auto key_md = memory::desc(k_sz, dt, tag::abdc); - auto score_md = memory::desc(score_sz, dt, tag::abcd); - auto scale_md = memory::desc(scale_sz, dt, tag::abcd); - auto mask_md = memory::desc(mask_sz, dt, tag::abcd); + auto query_md = memory::desc(q_sz, dt, memory::format_tag::abcd); + auto key_md = memory::desc(k_sz, dt, memory::format_tag::abdc); + auto score_md = memory::desc(score_sz, dt_inter, memory::format_tag::abcd); + auto scale_md = memory::desc(scale_sz, dt, memory::format_tag::abcd); + auto mask_md = memory::desc(mask_sz, dt, memory::format_tag::abcd); + auto probs_md = memory::desc(score_sz, dt, memory::format_tag::abcd); primitive_attr bmm1_attr; bmm1_attr.set_scratchpad_mode(scratchpad_mode::user); @@ -131,16 +134,16 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt, softmax_attr.set_scratchpad_mode(scratchpad_mode::user); auto softmax_pd = softmax_forward::primitive_desc(eng, prop_kind::forward_inference, algorithm::softmax_accurate, score_md, - score_md, /* axis = */ score_md.get_ndims() - 1, softmax_attr); + probs_md, /* axis = */ score_md.get_ndims() - 1, softmax_attr); auto softmax_prim = softmax_forward(softmax_pd); // attention_output = attention_probs x value - auto value_md = memory::desc(v_sz, dt, tag::abcd); - auto output_md = memory::desc(q_sz, dt, tag::abcd); + auto value_md = memory::desc(v_sz, dt, memory::format_tag::abcd); + auto output_md = memory::desc(q_sz, dt, memory::format_tag::abcd); primitive_attr bmm2_attr; bmm2_attr.set_scratchpad_mode(scratchpad_mode::user); auto bmm2_pd = matmul::primitive_desc( - eng, score_md, value_md, output_md, bmm2_attr); + eng, probs_md, value_md, output_md, bmm2_attr); auto bmm2_prim = matmul(bmm2_pd); // Create memory objects @@ -180,10 +183,11 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt, } auto scratchpad_md = memory::desc({static_cast(max_scratchpad_size)}, - memory::data_type::u8, tag::a); + memory::data_type::u8, memory::format_tag::a); // allocate intermediate memory auto m_score = memory(score_md, eng); + auto m_probs = memory(probs_md, eng); auto m_scratchpad = memory(scratchpad_md, eng); const auto loop = [&]() { @@ -198,11 +202,11 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt, {DNNL_ARG_SCRATCHPAD, m_scratchpad}}); softmax_prim.execute(strm, - {{DNNL_ARG_SRC, m_score}, {DNNL_ARG_DST, m_score}, + {{DNNL_ARG_SRC, m_score}, {DNNL_ARG_DST, m_probs}, {DNNL_ARG_SCRATCHPAD, m_scratchpad}}); bmm2_prim.execute(strm, - {{DNNL_ARG_SRC, m_score}, {DNNL_ARG_WEIGHTS, m_value}, + {{DNNL_ARG_SRC, m_probs}, {DNNL_ARG_WEIGHTS, m_value}, {DNNL_ARG_DST, m_output}, {DNNL_ARG_SCRATCHPAD, m_scratchpad}}); }; @@ -283,10 +287,13 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt, // Incremental IDs used to create logical tensors and operations. 
size_t id = 0; + // Intermediate data type + const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32; + // score = query x key.T auto query = logical_tensor(id++, dt, qv_sz, layout_type::strided); auto key = logical_tensor(id++, dt, k_sz, layout_type::strided); - auto score = logical_tensor(id++, dt, score_sz, layout_type::strided); + auto score = logical_tensor(id++, dt_inter, score_sz, layout_type::strided); auto bmm1 = op(id++, op::kind::MatMul, "bmm1"); bmm1.set_attr(op::attr::transpose_b, true); bmm1.add_inputs({query, key}); @@ -295,7 +302,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt, // scaled_score = score / scale auto scale = logical_tensor(id++, dt, scale_sz, layout_type::strided); auto scaled_score - = logical_tensor(id++, dt, score_sz, layout_type::strided); + = logical_tensor(id++, dt_inter, score_sz, layout_type::strided); auto scale_div = op(id++, op::kind::Divide, "scale_div"); scale_div.add_inputs({score, scale}); scale_div.add_outputs({scaled_score}); @@ -303,7 +310,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt, // masked_score = scaled_score + mask auto mask = logical_tensor(id++, dt, mask_sz, layout_type::strided); auto masked_score - = logical_tensor(id++, dt, score_sz, layout_type::strided); + = logical_tensor(id++, dt_inter, score_sz, layout_type::strided); auto mask_add = op(id++, op::kind::Add, "mask_add"); mask_add.add_inputs({scaled_score, mask}); mask_add.add_outputs({masked_score}); diff --git a/examples/graph/sdpa_stacked_qkv.cpp b/examples/graph/sdpa_stacked_qkv.cpp index dacae9b4672..29920224192 100644 --- a/examples/graph/sdpa_stacked_qkv.cpp +++ b/examples/graph/sdpa_stacked_qkv.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,7 +29,6 @@ #include "graph_example_utils.hpp" using namespace dnnl; -using tag = memory::format_tag; using namespace dnnl::graph; using layout_type = logical_tensor::layout_type; @@ -143,6 +142,9 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt, // Incremental IDs used to create logical tensors and operations. size_t id = 0; + // Intermediate data type + const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32; + // This logical tensor is not part of the graph but is used to generate the // big chunk of device memory which should be already there in real user // application or framework. @@ -153,7 +155,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt, auto key = logical_tensor(id++, dt, qkv_sz, qkv_strides); // Though query and key are non-contiguous above, the output score is still // contiguous. 
- auto score = logical_tensor(id++, dt, score_sz, layout_type::strided); + auto score = logical_tensor(id++, dt_inter, score_sz, layout_type::strided); auto bmm1 = op(id++, op::kind::MatMul, "bmm1"); bmm1.set_attr(op::attr::transpose_b, true); bmm1.add_inputs({query, key}); @@ -162,7 +164,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt, // scaled_score = score / scale auto scale = logical_tensor(id++, dt, scale_sz, layout_type::strided); auto scaled_score - = logical_tensor(id++, dt, score_sz, layout_type::strided); + = logical_tensor(id++, dt_inter, score_sz, layout_type::strided); auto scale_div = op(id++, op::kind::Divide, "scale_div"); scale_div.add_inputs({score, scale}); scale_div.add_outputs({scaled_score}); @@ -170,7 +172,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt, // masked_score = scaled_score + mask auto mask = logical_tensor(id++, dt, mask_sz, layout_type::strided); auto masked_score - = logical_tensor(id++, dt, score_sz, layout_type::strided); + = logical_tensor(id++, dt_inter, score_sz, layout_type::strided); auto mask_add = op(id++, op::kind::Add, "mask_add"); mask_add.add_inputs({scaled_score, mask}); mask_add.add_outputs({masked_score}); diff --git a/examples/matmul_perf.cpp b/examples/matmul_perf.cpp index d1d29677a53..083d082d549 100644 --- a/examples/matmul_perf.cpp +++ b/examples/matmul_perf.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022 Intel Corporation +* Copyright 2022-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,20 +28,17 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - struct gemm_dims_t { memory::dim m, n, k; }; static const int min_runs = 4; -const char *get_type_string(dt type) { +const char *get_type_string(memory::data_type type) { const char *type_string = "unknown"; #define TYPE_CASE(T) \ - if (type == dt::T) type_string = #T; + if (type == memory::data_type::T) type_string = #T; TYPE_CASE(f16); TYPE_CASE(f32); TYPE_CASE(f64); @@ -53,7 +50,7 @@ const char *get_type_string(dt type) { return type_string; } -void print_test_case(dt type, gemm_dims_t dims) { +void print_test_case(memory::data_type type, gemm_dims_t dims) { std::cout << '[' << std::setw(4) << get_type_string(type); if (dims.m == dims.n && dims.m == dims.k) std::cout << " m = n = k = " << dims.m; @@ -89,9 +86,10 @@ void fill_random(std::vector &out, bool is_integer) { } } -double run_case(engine::kind engine_kind, dt type, gemm_dims_t dims, - double time_limit = 0.) { - bool is_integer = (type == dt::s8 || type == dt::u8); +double run_case(engine::kind engine_kind, memory::data_type type, + gemm_dims_t dims, double time_limit = 0.) { + bool is_integer + = (type == memory::data_type::s8 || type == memory::data_type::u8); bool quick_test = (time_limit == 0.); // Create execution dnnl::engine. @@ -115,12 +113,14 @@ double run_case(engine::kind engine_kind, dt type, gemm_dims_t dims, // Create memory descriptors and memory objects for src, weights, bias, and // dst. 
- auto a_md = memory::desc(a_dims, type, tag::any); - auto b_md = memory::desc(b_dims, type, tag::any); - auto c_md = memory::desc(c_dims, type, tag::any); + auto a_md = memory::desc(a_dims, type, memory::format_tag::any); + auto b_md = memory::desc(b_dims, type, memory::format_tag::any); + auto c_md = memory::desc(c_dims, type, memory::format_tag::any); - auto a_in_md = memory::desc(a_dims, dt::f32, tag::ab); - auto b_in_md = memory::desc(b_dims, dt::f32, tag::ab); + auto a_in_md = memory::desc( + a_dims, memory::data_type::f32, memory::format_tag::ab); + auto b_in_md = memory::desc( + b_dims, memory::data_type::f32, memory::format_tag::ab); auto a_in_mem = memory(a_in_md, engine); auto b_in_mem = memory(b_in_md, engine); @@ -197,7 +197,7 @@ double run_case(engine::kind engine_kind, dt type, gemm_dims_t dims, return avg_time; } -void run(engine::kind engine_kind, dt type, gemm_dims_t dims, +void run(engine::kind engine_kind, memory::data_type type, gemm_dims_t dims, double time_limit) { try { if (dims.m * dims.n != 0) { @@ -257,10 +257,10 @@ void matmul_perf(engine::kind engine_kind, int argc, char **argv) { if (dims.m <= 0 || dims.n <= 0 || dims.k <= 0) bad_args(); } - run(engine_kind, dt::f32, dims, 2.0); - run(engine_kind, dt::f16, dims, 2.0); - run(engine_kind, dt::bf16, dims, 2.0); - run(engine_kind, dt::s8, dims, 2.0); + run(engine_kind, memory::data_type::f32, dims, 2.0); + run(engine_kind, memory::data_type::f16, dims, 2.0); + run(engine_kind, memory::data_type::bf16, dims, 2.0); + run(engine_kind, memory::data_type::s8, dims, 2.0); } int main(int argc, char **argv) { diff --git a/examples/primitives/augru.cpp b/examples/primitives/augru.cpp index 28abc6557af..dd0901c6c62 100644 --- a/examples/primitives/augru.cpp +++ b/examples/primitives/augru.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022 Intel Corporation +* Copyright 2022-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,9 +42,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void augru_example(dnnl::engine::kind engine_kind) { if (engine_kind == engine::kind::gpu) @@ -100,10 +97,14 @@ void augru_example(dnnl::engine::kind engine_kind) { }); // Create memory descriptors and memory objects for src, bias, and dst. - auto src_layer_md = memory::desc(src_dims, dt::f32, tag::tnc); - auto attention_md = memory::desc(attention_dims, dt::f32, tag::tnc); - auto bias_md = memory::desc(bias_dims, dt::f32, tag::ldgo); - auto dst_layer_md = memory::desc(dst_dims, dt::f32, tag::tnc); + auto src_layer_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::tnc); + auto attention_md = memory::desc( + attention_dims, memory::data_type::f32, memory::format_tag::tnc); + auto bias_md = memory::desc( + bias_dims, memory::data_type::f32, memory::format_tag::ldgo); + auto dst_layer_md = memory::desc( + dst_dims, memory::data_type::f32, memory::format_tag::tnc); auto src_layer_mem = memory(src_layer_md, engine); auto attention_mem = memory(attention_md, engine); @@ -112,10 +113,12 @@ void augru_example(dnnl::engine::kind engine_kind) { // Create memory objects for weights using user's memory layout. In this // example, LDIGO is assumed. 
- auto user_weights_layer_mem - = memory({weights_dims, dt::f32, tag::ldigo}, engine); - auto user_weights_iter_mem - = memory({weights_dims, dt::f32, tag::ldigo}, engine); + auto user_weights_layer_mem = memory( + {weights_dims, memory::data_type::f32, memory::format_tag::ldigo}, + engine); + auto user_weights_iter_mem = memory( + {weights_dims, memory::data_type::f32, memory::format_tag::ldigo}, + engine); // Write data to memory object's handle. write_to_dnnl_memory(src_layer_data.data(), src_layer_mem); @@ -126,8 +129,10 @@ void augru_example(dnnl::engine::kind engine_kind) { // Create memory descriptors for weights with format_tag::any. This enables // the AUGRU primitive to choose the optimized memory layout. - auto augru_weights_layer_md = memory::desc(weights_dims, dt::f32, tag::any); - auto augru_weights_iter_md = memory::desc(weights_dims, dt::f32, tag::any); + auto augru_weights_layer_md = memory::desc( + weights_dims, memory::data_type::f32, memory::format_tag::any); + auto augru_weights_iter_md = memory::desc( + weights_dims, memory::data_type::f32, memory::format_tag::any); // Optional memory descriptors for recurrent data. auto src_iter_md = memory::desc(); diff --git a/examples/primitives/batch_normalization.cpp b/examples/primitives/batch_normalization.cpp index 5c3e163f968..088f8470cfa 100644 --- a/examples/primitives/batch_normalization.cpp +++ b/examples/primitives/batch_normalization.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,9 +44,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void batch_normalization_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -91,9 +88,12 @@ void batch_normalization_example(dnnl::engine::kind engine_kind) { }); // Create src and scale/shift memory descriptors and memory objects. - auto src_md = memory::desc(src_dims, dt::f32, tag::nchw); - auto dst_md = memory::desc(src_dims, dt::f32, tag::nchw); - auto scaleshift_md = memory::desc(scaleshift_dims, dt::f32, tag::x); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); + auto dst_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); + auto scaleshift_md = memory::desc( + scaleshift_dims, memory::data_type::f32, memory::format_tag::x); auto src_mem = memory(src_md, engine); auto scale_mem = memory(scaleshift_md, engine); diff --git a/examples/primitives/binary.cpp b/examples/primitives/binary.cpp index 5b4707bc886..999650d0932 100644 --- a/examples/primitives/binary.cpp +++ b/examples/primitives/binary.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,9 +42,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void binary_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -78,9 +75,12 @@ void binary_example(dnnl::engine::kind engine_kind) { }); // Create src and dst memory descriptors. 
- auto src_0_md = memory::desc(src_0_dims, dt::f32, tag::nchw); - auto src_1_md = memory::desc(src_1_dims, dt::f32, tag::nchw); - auto dst_md = memory::desc(src_0_dims, dt::f32, tag::nchw); + auto src_0_md = memory::desc( + src_0_dims, memory::data_type::f32, memory::format_tag::nchw); + auto src_1_md = memory::desc( + src_1_dims, memory::data_type::f32, memory::format_tag::nchw); + auto dst_md = memory::desc( + src_0_dims, memory::data_type::f32, memory::format_tag::nchw); // Create src memory objects. auto src_0_mem = memory(src_0_md, engine); diff --git a/examples/primitives/concat.cpp b/examples/primitives/concat.cpp index e563722aeb8..50935810f98 100644 --- a/examples/primitives/concat.cpp +++ b/examples/primitives/concat.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,9 +43,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void concat_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -85,7 +82,8 @@ void concat_example(dnnl::engine::kind engine_kind) { std::vector src_mems; for (int n = 0; n < num_src; ++n) { - auto md = memory::desc(src_dims, dt::f32, tag::nchw); + auto md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); auto mem = memory(md, engine); // Write data to memory object's handle. diff --git a/examples/primitives/convolution.cpp b/examples/primitives/convolution.cpp index e3ecc0f9b94..cc8eb768363 100644 --- a/examples/primitives/convolution.cpp +++ b/examples/primitives/convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,9 +43,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void convolution_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -75,6 +72,7 @@ void convolution_example(dnnl::engine::kind engine_kind) { // dimensions. memory::dims src_dims = {N, IC, IH, IW}; memory::dims weights_dims = {OC, IC, KH, KW}; + // To simulate an empty bias use an empty initializer `{}`. memory::dims bias_dims = {OC}; memory::dims dst_dims = {N, OC, OH, OW}; @@ -86,7 +84,7 @@ void convolution_example(dnnl::engine::kind engine_kind) { // Allocate buffers. std::vector src_data(product(src_dims)); std::vector weights_data(product(weights_dims)); - std::vector bias_data(OC); + std::vector bias_data(product(bias_dims)); std::vector dst_data(product(dst_dims)); // Initialize src, weights, and dst tensors. @@ -105,26 +103,39 @@ void convolution_example(dnnl::engine::kind engine_kind) { // Create memory objects for tensor data (src, weights, dst). In this // example, NCHW layout is assumed for src and dst, and OIHW for weights. 
- auto user_src_mem = memory({src_dims, dt::f32, tag::nchw}, engine); - auto user_weights_mem = memory({weights_dims, dt::f32, tag::oihw}, engine); - auto user_dst_mem = memory({dst_dims, dt::f32, tag::nchw}, engine); + auto user_src_mem = memory( + {src_dims, memory::data_type::f32, memory::format_tag::nchw}, + engine); + auto user_weights_mem = memory( + {weights_dims, memory::data_type::f32, memory::format_tag::oihw}, + engine); + auto user_dst_mem = memory( + {dst_dims, memory::data_type::f32, memory::format_tag::nchw}, + engine); // Create memory descriptors with format_tag::any for the primitive. This // enables the convolution primitive to choose memory layouts for an // optimized primitive implementation, and these layouts may differ from the // ones provided by the user. - auto conv_src_md = memory::desc(src_dims, dt::f32, tag::any); - auto conv_weights_md = memory::desc(weights_dims, dt::f32, tag::any); - auto conv_dst_md = memory::desc(dst_dims, dt::f32, tag::any); + auto conv_src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::any); + auto conv_weights_md = memory::desc( + weights_dims, memory::data_type::f32, memory::format_tag::any); + auto conv_dst_md = memory::desc( + dst_dims, memory::data_type::f32, memory::format_tag::any); // Create memory descriptor and memory object for input bias. - auto user_bias_md = memory::desc(bias_dims, dt::f32, tag::a); + auto user_bias_md = bias_dims.empty() + ? memory::desc() + : memory::desc( + bias_dims, memory::data_type::f32, memory::format_tag::a); auto user_bias_mem = memory(user_bias_md, engine); // Write data to memory object's handle. write_to_dnnl_memory(src_data.data(), user_src_mem); write_to_dnnl_memory(weights_data.data(), user_weights_mem); - write_to_dnnl_memory(bias_data.data(), user_bias_mem); + if (!bias_dims.empty()) + write_to_dnnl_memory(bias_data.data(), user_bias_mem); // Create primitive post-ops (ReLU). const float alpha = 0.f; @@ -254,20 +265,30 @@ void depthwise_convolution_example(dnnl::engine::kind engine_kind) { // Create memory objects for tensor data (src, weights, dst). In this // example, NCHW layout is assumed for src and dst, and OIHW for weights. - auto user_src_mem = memory({src_dims, dt::f32, tag::nchw}, engine); - auto user_weights_mem = memory({weights_dims, dt::f32, tag::goihw}, engine); - auto user_dst_mem = memory({dst_dims, dt::f32, tag::nchw}, engine); + auto user_src_mem = memory( + {src_dims, memory::data_type::f32, memory::format_tag::nchw}, + engine); + auto user_weights_mem = memory( + {weights_dims, memory::data_type::f32, memory::format_tag::goihw}, + engine); + auto user_dst_mem = memory( + {dst_dims, memory::data_type::f32, memory::format_tag::nchw}, + engine); // Create memory descriptors with format_tag::any for the primitive. This // enables the convolution primitive to choose memory layouts for an // optimized primitive implementation, and these layouts may differ from the // ones provided by the user. - auto conv_src_md = memory::desc(src_dims, dt::f32, tag::any); - auto conv_weights_md = memory::desc(weights_dims, dt::f32, tag::any); - auto conv_dst_md = memory::desc(dst_dims, dt::f32, tag::any); + auto conv_src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::any); + auto conv_weights_md = memory::desc( + weights_dims, memory::data_type::f32, memory::format_tag::any); + auto conv_dst_md = memory::desc( + dst_dims, memory::data_type::f32, memory::format_tag::any); // Create memory descriptor and memory object for input bias. 
- auto user_bias_md = memory::desc(bias_dims, dt::f32, tag::a); + auto user_bias_md = memory::desc( + bias_dims, memory::data_type::f32, memory::format_tag::a); auto user_bias_mem = memory(user_bias_md, engine); // Write data to memory object's handle. diff --git a/examples/primitives/deconvolution.cpp b/examples/primitives/deconvolution.cpp index 841b7f7ba8d..f1efdc61386 100644 --- a/examples/primitives/deconvolution.cpp +++ b/examples/primitives/deconvolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,9 +43,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void deconvolution_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -111,20 +108,30 @@ void deconvolution_example(dnnl::engine::kind engine_kind) { // Create memory objects for tensor data (src, weights, dst). In this // example, NCHW layout is assumed for src and dst, and OIHW for weights. - auto user_src_mem = memory({src_dims, dt::f32, tag::nchw}, engine); - auto user_weights_mem = memory({weights_dims, dt::f32, tag::oihw}, engine); - auto user_dst_mem = memory({dst_dims, dt::f32, tag::nchw}, engine); + auto user_src_mem = memory( + {src_dims, memory::data_type::f32, memory::format_tag::nchw}, + engine); + auto user_weights_mem = memory( + {weights_dims, memory::data_type::f32, memory::format_tag::oihw}, + engine); + auto user_dst_mem = memory( + {dst_dims, memory::data_type::f32, memory::format_tag::nchw}, + engine); // Create memory descriptors with format_tag::any for the primitive. This // enables the deconvolution primitive to choose memory layouts for an // optimized primitive implementation, and these layouts may differ from the // ones provided by the user. - auto deconv_src_md = memory::desc(src_dims, dt::f32, tag::any); - auto deconv_weights_md = memory::desc(weights_dims, dt::f32, tag::any); - auto deconv_dst_md = memory::desc(dst_dims, dt::f32, tag::any); + auto deconv_src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::any); + auto deconv_weights_md = memory::desc( + weights_dims, memory::data_type::f32, memory::format_tag::any); + auto deconv_dst_md = memory::desc( + dst_dims, memory::data_type::f32, memory::format_tag::any); // Create memory descriptor and memory object for input bias. - auto user_bias_md = memory::desc(bias_dims, dt::f32, tag::a); + auto user_bias_md = memory::desc( + bias_dims, memory::data_type::f32, memory::format_tag::a); auto user_bias_mem = memory(user_bias_md, engine); // Write data to memory object's handle. diff --git a/examples/primitives/eltwise.cpp b/examples/primitives/eltwise.cpp index 2bea8dcbe08..acd59ed26a3 100644 --- a/examples/primitives/eltwise.cpp +++ b/examples/primitives/eltwise.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,9 +39,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void eltwise_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. 
@@ -73,8 +70,10 @@ void eltwise_example(dnnl::engine::kind engine_kind) { }); // Create src and dst memory descriptors and memory objects. - auto src_md = memory::desc(src_dims, dt::f32, tag::nchw); - auto dst_md = memory::desc(dst_dims, dt::f32, tag::nchw); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); + auto dst_md = memory::desc( + dst_dims, memory::data_type::f32, memory::format_tag::nchw); auto src_mem = memory(src_md, engine); auto dst_mem = memory(dst_md, engine); diff --git a/examples/primitives/group_normalization.cpp b/examples/primitives/group_normalization.cpp index ce9ea87455f..84c67a41a57 100644 --- a/examples/primitives/group_normalization.cpp +++ b/examples/primitives/group_normalization.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023-2024 Intel Corporation +* Copyright 2023-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,9 +43,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void group_normalization_example(engine::kind engine_kind) { // Create execution dnnl::engine. dnnl::engine engine(engine_kind, 0); @@ -93,9 +90,12 @@ void group_normalization_example(engine::kind engine_kind) { }); // Create src and scale/shift memory descriptors and memory objects. - auto src_md = memory::desc(src_dims, dt::f32, tag::ncdhw); - auto dst_md = memory::desc(src_dims, dt::f32, tag::ncdhw); - auto scaleshift_md = memory::desc(scaleshift_dims, dt::f32, tag::x); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::ncdhw); + auto dst_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::ncdhw); + auto scaleshift_md = memory::desc( + scaleshift_dims, memory::data_type::f32, memory::format_tag::x); auto src_mem = memory(src_md, engine); auto scale_mem = memory(scaleshift_md, engine); diff --git a/examples/primitives/inner_product.cpp b/examples/primitives/inner_product.cpp index f987b88ca16..334c092151c 100644 --- a/examples/primitives/inner_product.cpp +++ b/examples/primitives/inner_product.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,9 +42,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void inner_product_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -89,9 +86,12 @@ void inner_product_example(dnnl::engine::kind engine_kind) { // Create memory descriptors and memory objects for src and dst. In this // example, NCHW layout is assumed. 
- auto src_md = memory::desc(src_dims, dt::f32, tag::nchw); - auto bias_md = memory::desc(bias_dims, dt::f32, tag::a); - auto dst_md = memory::desc(dst_dims, dt::f32, tag::nc); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); + auto bias_md = memory::desc( + bias_dims, memory::data_type::f32, memory::format_tag::a); + auto dst_md = memory::desc( + dst_dims, memory::data_type::f32, memory::format_tag::nc); auto src_mem = memory(src_md, engine); auto bias_mem = memory(bias_md, engine); @@ -99,7 +99,9 @@ void inner_product_example(dnnl::engine::kind engine_kind) { // Create memory object for user's layout for weights. In this example, OIHW // is assumed. - auto user_weights_mem = memory({weights_dims, dt::f32, tag::oihw}, engine); + auto user_weights_mem = memory( + {weights_dims, memory::data_type::f32, memory::format_tag::oihw}, + engine); // Write data to memory object's handles. write_to_dnnl_memory(src_data.data(), src_mem); @@ -110,8 +112,8 @@ void inner_product_example(dnnl::engine::kind engine_kind) { // the inner product primitive to choose the memory layout for an optimized // primitive implementation, and this format may differ from the one // provided by the user. - auto inner_product_weights_md - = memory::desc(weights_dims, dt::f32, tag::any); + auto inner_product_weights_md = memory::desc( + weights_dims, memory::data_type::f32, memory::format_tag::any); // Create primitive post-ops (ReLU). const float alpha = 0.f; diff --git a/examples/primitives/layer_normalization.cpp b/examples/primitives/layer_normalization.cpp index 0079bc59b23..9bdc7dd2a8f 100644 --- a/examples/primitives/layer_normalization.cpp +++ b/examples/primitives/layer_normalization.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2023 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,9 +43,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void layer_normalization_example(dnnl::engine::kind engine_kind) { /// Create execution dnnl::engine. @@ -89,9 +86,12 @@ void layer_normalization_example(dnnl::engine::kind engine_kind) { }); // Create src memory descriptor and memory objects. - auto src_md = memory::desc(src_dims, dt::f32, tag::tnc); - auto dst_md = memory::desc(src_dims, dt::f32, tag::tnc); - auto scaleshift_md = memory::desc(scaleshift_dims, dt::f32, tag::x); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::tnc); + auto dst_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::tnc); + auto scaleshift_md = memory::desc( + scaleshift_dims, memory::data_type::f32, memory::format_tag::x); auto src_mem = memory(src_md, engine); auto scale_mem = memory(scaleshift_md, engine); @@ -105,7 +105,8 @@ void layer_normalization_example(dnnl::engine::kind engine_kind) { // Create primitive descriptor. 
const float epsilon = 1.e-10f; auto lnorm_pd = layer_normalization_forward::primitive_desc(engine, - prop_kind::forward_training, src_md, dst_md, dt::f32, epsilon, + prop_kind::forward_training, src_md, dst_md, memory::data_type::f32, + epsilon, normalization_flags::use_scale | normalization_flags::use_shift); // Use the memory descriptors from the primitive to create memory objects diff --git a/examples/primitives/lbr_gru.cpp b/examples/primitives/lbr_gru.cpp index aeba8103c88..d343b44a33c 100644 --- a/examples/primitives/lbr_gru.cpp +++ b/examples/primitives/lbr_gru.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,9 +42,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void lbr_gru_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. dnnl::engine engine(engine_kind, 0); @@ -98,9 +95,12 @@ void lbr_gru_example(dnnl::engine::kind engine_kind) { }); // Create memory descriptors and memory objects for src, bias, and dst. - auto src_layer_md = memory::desc(src_dims, dt::f32, tag::tnc); - auto bias_md = memory::desc(bias_dims, dt::f32, tag::ldgo); - auto dst_layer_md = memory::desc(dst_layer_dims, dt::f32, tag::tnc); + auto src_layer_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::tnc); + auto bias_md = memory::desc( + bias_dims, memory::data_type::f32, memory::format_tag::ldgo); + auto dst_layer_md = memory::desc( + dst_layer_dims, memory::data_type::f32, memory::format_tag::tnc); auto src_layer_mem = memory(src_layer_md, engine); auto bias_mem = memory(bias_md, engine); @@ -110,9 +110,13 @@ void lbr_gru_example(dnnl::engine::kind engine_kind) { // example, LDIGO (num_layers, num_directions, input_channels, num_gates, // output_channels) is assumed. auto user_weights_layer_mem - = memory({weights_layer_dims, dt::f32, tag::ldigo}, engine); + = memory({weights_layer_dims, memory::data_type::f32, + memory::format_tag::ldigo}, + engine); auto user_weights_iter_mem - = memory({weights_iter_dims, dt::f32, tag::ldigo}, engine); + = memory({weights_iter_dims, memory::data_type::f32, + memory::format_tag::ldigo}, + engine); // Write data to memory object's handle. // For GRU cells, the gates order is update, reset and output @@ -125,8 +129,10 @@ void lbr_gru_example(dnnl::engine::kind engine_kind) { // Create memory descriptors for weights with format_tag::any. This enables // the lbr_gru primitive to choose the optimized memory layout. - auto weights_layer_md = memory::desc(weights_layer_dims, dt::f32, tag::any); - auto weights_iter_md = memory::desc(weights_iter_dims, dt::f32, tag::any); + auto weights_layer_md = memory::desc(weights_layer_dims, + memory::data_type::f32, memory::format_tag::any); + auto weights_iter_md = memory::desc( + weights_iter_dims, memory::data_type::f32, memory::format_tag::any); // Optional memory descriptors for recurrent data. 
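Because the weight descriptors above use memory::format_tag::any, the layout the primitive ultimately selects may differ from the user's ldigo buffers. A sketch of the usual query-and-reorder idiom, assuming a primitive descriptor gru_pd and a stream engine_stream from the surrounding example:

    // Reorder user weights only if the primitive chose a different layout.
    auto weights_layer_mem = user_weights_layer_mem;
    if (gru_pd.weights_layer_desc() != user_weights_layer_mem.get_desc()) {
        weights_layer_mem = memory(gru_pd.weights_layer_desc(), engine);
        reorder(user_weights_layer_mem, weights_layer_mem)
                .execute(engine_stream, user_weights_layer_mem,
                        weights_layer_mem);
    }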
// Default memory descriptor for initial hidden states of the GRU cells diff --git a/examples/primitives/lrn.cpp b/examples/primitives/lrn.cpp index 0d6df092d32..ce75c807910 100644 --- a/examples/primitives/lrn.cpp +++ b/examples/primitives/lrn.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,9 +39,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void lrn_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -69,8 +66,10 @@ void lrn_example(dnnl::engine::kind engine_kind) { }); // Create src and dst memory descriptors and memory objects. - auto src_md = memory::desc(src_dims, dt::f32, tag::nchw); - auto dst_md = memory::desc(src_dims, dt::f32, tag::nchw); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); + auto dst_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); auto src_mem = memory(src_md, engine); auto dst_mem = memory(src_md, engine); diff --git a/examples/primitives/lstm.cpp b/examples/primitives/lstm.cpp index ba579944662..67514bb65c3 100644 --- a/examples/primitives/lstm.cpp +++ b/examples/primitives/lstm.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,9 +42,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void lstm_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -90,9 +87,12 @@ void lstm_example(dnnl::engine::kind engine_kind) { }); // Create memory descriptors and memory objects for src, bias, and dst. - auto src_layer_md = memory::desc(src_dims, dt::f32, tag::tnc); - auto bias_md = memory::desc(bias_dims, dt::f32, tag::ldgo); - auto dst_layer_md = memory::desc(dst_dims, dt::f32, tag::tnc); + auto src_layer_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::tnc); + auto bias_md = memory::desc( + bias_dims, memory::data_type::f32, memory::format_tag::ldgo); + auto dst_layer_md = memory::desc( + dst_dims, memory::data_type::f32, memory::format_tag::tnc); auto src_layer_mem = memory(src_layer_md, engine); auto bias_mem = memory(bias_md, engine); @@ -100,10 +100,12 @@ void lstm_example(dnnl::engine::kind engine_kind) { // Create memory objects for weights using user's memory layout. In this // example, LDIGO is assumed. - auto user_weights_layer_mem - = memory({weights_dims, dt::f32, tag::ldigo}, engine); - auto user_weights_iter_mem - = memory({weights_dims, dt::f32, tag::ldigo}, engine); + auto user_weights_layer_mem = memory( + {weights_dims, memory::data_type::f32, memory::format_tag::ldigo}, + engine); + auto user_weights_iter_mem = memory( + {weights_dims, memory::data_type::f32, memory::format_tag::ldigo}, + engine); // Write data to memory object's handle. write_to_dnnl_memory(src_layer_data.data(), src_layer_mem); @@ -113,8 +115,10 @@ void lstm_example(dnnl::engine::kind engine_kind) { // Create memory descriptors for weights with format_tag::any. 
This enables // the LSTM primitive to choose the optimized memory layout. - auto lstm_weights_layer_md = memory::desc(weights_dims, dt::f32, tag::any); - auto lstm_weights_iter_md = memory::desc(weights_dims, dt::f32, tag::any); + auto lstm_weights_layer_md = memory::desc( + weights_dims, memory::data_type::f32, memory::format_tag::any); + auto lstm_weights_iter_md = memory::desc( + weights_dims, memory::data_type::f32, memory::format_tag::any); // Optional memory descriptors for recurrent data. auto src_iter_md = memory::desc(); diff --git a/examples/primitives/matmul.cpp b/examples/primitives/matmul.cpp index 08d3faf11ae..fff7efeb0c2 100644 --- a/examples/primitives/matmul.cpp +++ b/examples/primitives/matmul.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,9 +41,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void matmul_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -84,10 +81,14 @@ void matmul_example(dnnl::engine::kind engine_kind) { // Create memory descriptors and memory objects for src, weights, bias, and // dst. - auto src_md = memory::desc(src_dims, dt::f32, tag::abc); - auto weights_md = memory::desc(weights_dims, dt::f32, tag::abc); - auto bias_md = memory::desc(bias_dims, dt::f32, tag::abc); - auto dst_md = memory::desc(dst_dims, dt::f32, tag::abc); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::abc); + auto weights_md = memory::desc( + weights_dims, memory::data_type::f32, memory::format_tag::abc); + auto bias_md = memory::desc( + bias_dims, memory::data_type::f32, memory::format_tag::abc); + auto dst_md = memory::desc( + dst_dims, memory::data_type::f32, memory::format_tag::abc); auto src_mem = memory(src_md, engine); auto weights_mem = memory(weights_md, engine); diff --git a/examples/primitives/pooling.cpp b/examples/primitives/pooling.cpp index 92a2c877801..a84c37e6028 100644 --- a/examples/primitives/pooling.cpp +++ b/examples/primitives/pooling.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,9 +39,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void pooling_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -66,6 +63,17 @@ void pooling_example(dnnl::engine::kind engine_kind) { DH = 1, // height-wise dilation DW = 1; // width-wise dilation + // oneDNN uses the following formula to calculate destination dimensions: + // dst = (src - ((weights - 1) * (dilation_onednn + 1) + 1)) / stride + 1 + // + // PyTorch and TensorFlow use a different formula: + // dst = (src - ((weights - 1) * dilation_torch + 1)) / stride + 1 + // + // As a result, the PyTorch and TensorFlow dilation parameters need to be + // adjusted by subtracting 1: + // dilation_onednn = dilation_torch - 1. + // +    // Output tensor height and width.
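To sanity-check the conversion with concrete (hypothetical) sizes: for a source extent of 60, kernel 3, stride 1, zero padding, and a PyTorch dilation of 2 (so the oneDNN dilation is 1), both conventions agree: oneDNN gives (60 - ((3 - 1) * (1 + 1) + 1)) / 1 + 1 = 56, and PyTorch gives (60 - ((3 - 1) * 2 + 1)) / 1 + 1 = 56.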
const memory::dim OH = (IH - ((KH - 1) * DH + KH) + PH_L + PH_R) / SH + 1; const memory::dim OW = (IW - ((KW - 1) * DW + KW) + PW_L + PW_R) / SW + 1; @@ -92,10 +100,12 @@ void pooling_example(dnnl::engine::kind engine_kind) { }); // Create memory descriptors and memory objects for src and dst. - auto src_md = memory::desc(src_dims, dt::f32, tag::nchw); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); auto src_mem = memory(src_md, engine); - auto dst_md = memory::desc(dst_dims, dt::f32, tag::nchw); + auto dst_md = memory::desc( + dst_dims, memory::data_type::f32, memory::format_tag::nchw); auto dst_mem = memory(dst_md, engine); // Write data to memory object's handle. diff --git a/examples/primitives/prelu.cpp b/examples/primitives/prelu.cpp index 9f46e61231f..986f90bc553 100644 --- a/examples/primitives/prelu.cpp +++ b/examples/primitives/prelu.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,9 +38,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void prelu_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -78,18 +75,27 @@ void prelu_example(dnnl::engine::kind engine_kind) { // Create memory objects for tensor data (src, weights, dst). In this // example, NCHW layout is assumed for src, weights and dst. - auto user_src_mem = memory({src_dims, dt::f32, tag::nchw}, engine); - auto user_weights_mem = memory({weights_dims, dt::f32, tag::nchw}, engine); - auto user_dst_mem = memory({dst_dims, dt::f32, tag::nchw}, engine); + auto user_src_mem = memory( + {src_dims, memory::data_type::f32, memory::format_tag::nchw}, + engine); + auto user_weights_mem = memory( + {weights_dims, memory::data_type::f32, memory::format_tag::nchw}, + engine); + auto user_dst_mem = memory( + {dst_dims, memory::data_type::f32, memory::format_tag::nchw}, + engine); // Create memory descriptors for the primitive. Src tag is set // to match src memory object. Setting weights tag to format_tag::any // enables the PReLU primitive to choose memory layout for an optimized // primitive implementation, and that layout may differ from the one // provided by the user. - auto src_md = memory::desc(src_dims, dt::f32, tag::nchw); - auto weights_md = memory::desc(weights_dims, dt::f32, tag::any); - auto dst_md = memory::desc(src_dims, dt::f32, tag::any); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); + auto weights_md = memory::desc( + weights_dims, memory::data_type::f32, memory::format_tag::any); + auto dst_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::any); // Write data to memory object's handle. write_to_dnnl_memory(src_data.data(), user_src_mem); diff --git a/examples/primitives/reduction.cpp b/examples/primitives/reduction.cpp index 86dc1f70cbc..cde6abafdbc 100644 --- a/examples/primitives/reduction.cpp +++ b/examples/primitives/reduction.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -34,9 +34,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void reduction_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -66,8 +63,10 @@ void reduction_example(dnnl::engine::kind engine_kind) { }); // Create src and dst memory descriptors and memory objects. - auto src_md = memory::desc(src_dims, dt::f32, tag::nchw); - auto dst_md = memory::desc(dst_dims, dt::f32, tag::nchw); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); + auto dst_md = memory::desc( + dst_dims, memory::data_type::f32, memory::format_tag::nchw); auto src_mem = memory(src_md, engine); auto dst_mem = memory(dst_md, engine); diff --git a/examples/primitives/reorder.cpp b/examples/primitives/reorder.cpp index 066b6eef97c..aa6d87f6c5e 100644 --- a/examples/primitives/reorder.cpp +++ b/examples/primitives/reorder.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,9 +41,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void reorder_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -72,8 +69,10 @@ void reorder_example(dnnl::engine::kind engine_kind) { }); // Create memory descriptors and memory objects for src and dst. - auto src_md = memory::desc(src_dims, dt::f32, tag::nchw); - auto dst_md = memory::desc(src_dims, dt::s8, tag::nhwc); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); + auto dst_md = memory::desc( + src_dims, memory::data_type::s8, memory::format_tag::nhwc); auto src_mem = memory(src_md, engine); auto dst_mem = memory(dst_md, engine); @@ -94,7 +93,8 @@ void reorder_example(dnnl::engine::kind engine_kind) { // Create primitive post-ops (per-channel output scales) primitive_attr reorder_attr; reorder_attr.set_scales_mask(DNNL_ARG_DST, 1 << ic_dim); - auto dst_scales_mem = memory({{IC}, dt::f32, tag::x}, engine); + auto dst_scales_mem = memory( + {{IC}, memory::data_type::f32, memory::format_tag::x}, engine); write_to_dnnl_memory(scales.data(), dst_scales_mem); // Create primitive descriptor. diff --git a/examples/primitives/resampling.cpp b/examples/primitives/resampling.cpp index 59d6f948a6a..5bbefcd727a 100644 --- a/examples/primitives/resampling.cpp +++ b/examples/primitives/resampling.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,9 +39,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void resampling_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -73,8 +70,10 @@ void resampling_example(dnnl::engine::kind engine_kind) { }); // Create memory descriptors and memory objects for src and dst. 
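One detail worth spelling out for the reorder hunk above: set_scales_mask() only declares that scales will exist; the values themselves are runtime arguments. A sketch of the corresponding execute call, assuming reorder_pd, src_mem, dst_mem, dst_scales_mem, and engine_stream from that example:

    // Runtime scales are passed under DNNL_ARG_ATTR_SCALES combined with
    // the argument they apply to (here, the destination).
    reorder(reorder_pd).execute(engine_stream,
            {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem},
                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_mem}});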
- auto src_md = memory::desc(src_dims, dt::f32, tag::nchw); - auto dst_md = memory::desc(dst_dims, dt::f32, tag::nchw); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); + auto dst_md = memory::desc( + dst_dims, memory::data_type::f32, memory::format_tag::nchw); auto src_mem = memory(src_md, engine); auto dst_mem = memory(dst_md, engine); diff --git a/examples/primitives/shuffle.cpp b/examples/primitives/shuffle.cpp index 6a437a23d63..7f3041e8587 100644 --- a/examples/primitives/shuffle.cpp +++ b/examples/primitives/shuffle.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,9 +41,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void shuffle_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -76,11 +73,15 @@ void shuffle_example(dnnl::engine::kind engine_kind) { const int group_size = 4; // Create memory descriptor and memory objects for src and dst. - auto src_md = memory::desc(src_dims, dt::f32, tag::nchw); - auto dst_md = memory::desc(src_dims, dt::f32, tag::nchw); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); + auto dst_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); auto src_mem = memory(src_md, engine); - auto dst_mem = memory({src_dims, dt::f32, tag::abcd}, engine); + auto dst_mem = memory( + {src_dims, memory::data_type::f32, memory::format_tag::abcd}, + engine); // Write data to memory object's handle. write_to_dnnl_memory(src_data.data(), src_mem); diff --git a/examples/primitives/softmax.cpp b/examples/primitives/softmax.cpp index 58160edc245..183ac7cde76 100644 --- a/examples/primitives/softmax.cpp +++ b/examples/primitives/softmax.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,9 +43,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void softmax_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -70,8 +67,10 @@ void softmax_example(dnnl::engine::kind engine_kind) { }); // Create src memory descriptor and memory object. - auto src_md = memory::desc(src_dims, dt::f32, tag::nc); - auto dst_md = memory::desc(src_dims, dt::f32, tag::nc); + auto src_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nc); + auto dst_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nc); auto src_mem = memory(src_md, engine); // Write data to memory object's handle. diff --git a/examples/primitives/sum.cpp b/examples/primitives/sum.cpp index 19fc3e2e097..41149ff7146 100644 --- a/examples/primitives/sum.cpp +++ b/examples/primitives/sum.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -41,9 +41,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void sum_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. @@ -84,7 +81,8 @@ std::vector<memory> src_mem; for (int n = 0; n < num_src; ++n) { - auto md = memory::desc(src_dims, dt::f32, tag::nchw); + auto md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::nchw); auto mem = memory(md, engine); // Write data to memory object's handle. diff --git a/examples/primitives/vanilla_rnn.cpp b/examples/primitives/vanilla_rnn.cpp index a288468fc09..e8f2c99ce17 100644 --- a/examples/primitives/vanilla_rnn.cpp +++ b/examples/primitives/vanilla_rnn.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,9 +42,6 @@ using namespace dnnl; -using tag = memory::format_tag; -using dt = memory::data_type; - void vanilla_rnn_example(dnnl::engine::kind engine_kind) { // Create execution dnnl::engine. dnnl::engine engine(engine_kind, 0); @@ -95,9 +92,12 @@ void vanilla_rnn_example(dnnl::engine::kind engine_kind) { // Create memory descriptors and memory objects for src, bias, and dst. - auto src_layer_md = memory::desc(src_dims, dt::f32, tag::tnc); - auto bias_md = memory::desc(bias_dims, dt::f32, tag::ldgo); - auto dst_layer_md = memory::desc(dst_layer_dims, dt::f32, tag::tnc); + auto src_layer_md = memory::desc( + src_dims, memory::data_type::f32, memory::format_tag::tnc); + auto bias_md = memory::desc( + bias_dims, memory::data_type::f32, memory::format_tag::ldgo); + auto dst_layer_md = memory::desc( + dst_layer_dims, memory::data_type::f32, memory::format_tag::tnc); auto src_layer_mem = memory(src_layer_md, engine); auto bias_mem = memory(bias_md, engine); @@ -106,10 +106,12 @@ void vanilla_rnn_example(dnnl::engine::kind engine_kind) { // Create memory objects for weights using user's memory layout. In this // example, LDIGO (num_layers, num_directions, input_channels, num_gates, // output_channels) is assumed. - auto user_weights_layer_mem - = memory({weights_dims, dt::f32, tag::ldigo}, engine); - auto user_weights_iter_mem - = memory({weights_dims, dt::f32, tag::ldigo}, engine); + auto user_weights_layer_mem = memory( + {weights_dims, memory::data_type::f32, memory::format_tag::ldigo}, + engine); + auto user_weights_iter_mem = memory( + {weights_dims, memory::data_type::f32, memory::format_tag::ldigo}, + engine); // Write data to memory object's handle. write_to_dnnl_memory(src_layer_data.data(), src_layer_mem); @@ -119,8 +121,10 @@ void vanilla_rnn_example(dnnl::engine::kind engine_kind) { // Create memory descriptors for weights with format_tag::any. This enables // the Vanilla primitive to choose the optimized memory layout. - auto weights_layer_md = memory::desc(weights_dims, dt::f32, tag::any); - auto weights_iter_md = memory::desc(weights_dims, dt::f32, tag::any); + auto weights_layer_md = memory::desc( + weights_dims, memory::data_type::f32, memory::format_tag::any); + auto weights_iter_md = memory::desc( + weights_dims, memory::data_type::f32, memory::format_tag::any); // Optional memory descriptors for recurrent data.
// Default memory descriptor for initial hidden states of the GRU cells diff --git a/examples/rnn_training_f32.cpp b/examples/rnn_training_f32.cpp index 42546f3adaa..fbde4aa0fa0 100644 --- a/examples/rnn_training_f32.cpp +++ b/examples/rnn_training_f32.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -68,9 +68,6 @@ const int common_n_layers = 1; const int lstm_n_gates = 4; void simple_net(engine::kind engine_kind) { - using tag = memory::format_tag; - using dt = memory::data_type; - auto eng = engine(engine_kind, 0); stream s(eng); @@ -173,14 +170,14 @@ void simple_net(engine::kind engine_kind) { // Create auxiliary f32 memory descriptor // based on user-supplied dimensions and layout. - auto formatted_md - = [=](const memory::dims &dimensions, memory::format_tag layout) { - return memory::desc {{dimensions}, dt::f32, layout}; - }; + auto formatted_md = [=](const memory::dims &dimensions, + memory::format_tag layout) { + return memory::desc {{dimensions}, memory::data_type::f32, layout}; + }; // Create auxiliary generic f32 memory descriptor // based on supplied dimensions, with format_tag::any. auto generic_md = [=](const memory::dims &dimensions) { - return formatted_md(dimensions, tag::any); + return formatted_md(dimensions, memory::format_tag::any); }; // @@ -203,8 +200,9 @@ void simple_net(engine::kind engine_kind) { // Memory for the user allocated memory // Suppose user data is in tnc format. - auto net_src_memory - = dnnl::memory({{net_src_dims}, dt::f32, tag::tnc}, eng); + auto net_src_memory = dnnl::memory( + {{net_src_dims}, memory::data_type::f32, memory::format_tag::tnc}, + eng); write_to_dnnl_memory(net_src.data(), net_src_memory); // src_layer memory of the leftmost and rightmost RNN primitives // are accessed through the respective sub-memories in larger memory. @@ -222,34 +220,44 @@ void simple_net(engine::kind engine_kind) { // primitive prefers it in a different format.
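For readers skimming the lambdas above: formatted_md is just shorthand for an f32 memory::desc with an explicit layout, and generic_md additionally defers the layout choice to the primitive via format_tag::any. A quick equivalence check using the example's own net_src_dims (assert from <cassert>):

    // These two descriptors are identical by construction.
    auto a = formatted_md(net_src_dims, memory::format_tag::tnc);
    auto b = memory::desc(
            {net_src_dims}, memory::data_type::f32, memory::format_tag::tnc);
    assert(a == b);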
std::vector<float> user_common_weights_layer( tz_volume(common_weights_layer_dims), 1.0f); - auto user_common_weights_layer_memory = dnnl::memory( - {common_weights_layer_dims, dt::f32, tag::ldigo}, eng); + auto user_common_weights_layer_memory + = dnnl::memory({common_weights_layer_dims, memory::data_type::f32, + memory::format_tag::ldigo}, + eng); write_to_dnnl_memory( user_common_weights_layer.data(), user_common_weights_layer_memory); std::vector<float> user_common_weights_iter( tz_volume(common_weights_iter_dims), 1.0f); - auto user_common_weights_iter_memory = dnnl::memory( - {{common_weights_iter_dims}, dt::f32, tag::ldigo}, eng); + auto user_common_weights_iter_memory + = dnnl::memory({{common_weights_iter_dims}, memory::data_type::f32, + memory::format_tag::ldigo}, + eng); write_to_dnnl_memory( user_common_weights_layer.data(), user_common_weights_iter_memory); std::vector<float> user_common_bias(tz_volume(common_bias_dims), 1.0f); auto user_common_bias_memory - = dnnl::memory({{common_bias_dims}, dt::f32, tag::ldgo}, eng); + = dnnl::memory({{common_bias_dims}, memory::data_type::f32, + memory::format_tag::ldgo}, + eng); write_to_dnnl_memory(user_common_bias.data(), user_common_bias_memory); std::vector<float> user_leftmost_dst_layer( tz_volume(leftmost_dst_layer_dims), 1.0f); auto user_leftmost_dst_layer_memory - = dnnl::memory({{leftmost_dst_layer_dims}, dt::f32, tag::tnc}, eng); + = dnnl::memory({{leftmost_dst_layer_dims}, memory::data_type::f32, + memory::format_tag::tnc}, + eng); write_to_dnnl_memory( user_leftmost_dst_layer.data(), user_leftmost_dst_layer_memory); std::vector<float> user_rightmost_dst_layer( tz_volume(rightmost_dst_layer_dims), 1.0f); - auto user_rightmost_dst_layer_memory = dnnl::memory( - {{rightmost_dst_layer_dims}, dt::f32, tag::tnc}, eng); + auto user_rightmost_dst_layer_memory + = dnnl::memory({{rightmost_dst_layer_dims}, memory::data_type::f32, + memory::format_tag::tnc}, + eng); write_to_dnnl_memory( user_rightmost_dst_layer.data(), user_rightmost_dst_layer_memory); @@ -265,7 +273,8 @@ void simple_net(engine::kind engine_kind) { generic_md(common_weights_layer_dims), // weights_layer_desc generic_md(common_weights_iter_dims), // weights_iter_desc generic_md(common_bias_dims), // bias_desc - formatted_md(leftmost_dst_layer_dims, tag::tnc), // dst_layer_desc + formatted_md(leftmost_dst_layer_dims, + memory::format_tag::tnc), // dst_layer_desc generic_md(leftmost_dst_iter_dims), // dst_iter_desc generic_md(leftmost_dst_iter_c_dims) // dst_iter_c_desc ); @@ -304,7 +313,8 @@ void simple_net(engine::kind engine_kind) { generic_md(common_weights_layer_dims), // weights_layer_desc generic_md(common_weights_iter_dims), // weights_iter_desc generic_md(common_bias_dims), // bias_desc - formatted_md(rightmost_dst_layer_dims, tag::tnc), // dst_layer_desc + formatted_md(rightmost_dst_layer_dims, + memory::format_tag::tnc), // dst_layer_desc memory::desc(), // dst_iter_desc memory::desc() // dst_iter_c_desc ); @@ -410,8 +420,8 @@ void simple_net(engine::kind engine_kind) { // User-provided memory for backward by data output std::vector<float> net_diff_src(tz_volume(net_src_dims), 1.0f); - auto net_diff_src_memory - = dnnl::memory(formatted_md(net_src_dims, tag::tnc), eng); + auto net_diff_src_memory = dnnl::memory( + formatted_md(net_src_dims, memory::format_tag::tnc), eng); write_to_dnnl_memory(net_diff_src.data(), net_diff_src_memory); // diff_src follows the same layout we have for net_src @@ -429,13 +439,14 @@ void simple_net(engine::kind engine_kind) { std::vector<float> user_common_diff_weights_layer(
tz_volume(common_weights_layer_dims), 1.0f); auto user_common_diff_weights_layer_memory = dnnl::memory( - formatted_md(common_weights_layer_dims, tag::ldigo), eng); + formatted_md(common_weights_layer_dims, memory::format_tag::ldigo), + eng); write_to_dnnl_memory(user_common_diff_weights_layer.data(), user_common_diff_weights_layer_memory); std::vector<float> user_common_diff_bias(tz_volume(common_bias_dims), 1.0f); - auto user_common_diff_bias_memory - = dnnl::memory(formatted_md(common_bias_dims, tag::ldgo), eng); + auto user_common_diff_bias_memory = dnnl::memory( + formatted_md(common_bias_dims, memory::format_tag::ldgo), eng); write_to_dnnl_memory( user_common_diff_bias.data(), user_common_diff_bias_memory); @@ -448,8 +459,8 @@ void simple_net(engine::kind engine_kind) { }; // Suppose user data is in tnc format. std::vector<float> net_diff_dst(tz_volume(net_diff_dst_dims), 1.0f); - auto net_diff_dst_memory - = dnnl::memory(formatted_md(net_diff_dst_dims, tag::tnc), eng); + auto net_diff_dst_memory = dnnl::memory( + formatted_md(net_diff_dst_dims, memory::format_tag::tnc), eng); write_to_dnnl_memory(net_diff_dst.data(), net_diff_dst_memory); // diff_dst_layer memory of the leftmost and rightmost RNN primitives // are accessed through the respective sub-memory in larger memory. @@ -474,7 +485,8 @@ generic_md(common_weights_layer_dims), // weights_layer_desc generic_md(common_weights_iter_dims), // weights_iter_desc generic_md(common_bias_dims), // bias_desc - formatted_md(leftmost_dst_layer_dims, tag::tnc), // dst_layer_desc + formatted_md(leftmost_dst_layer_dims, + memory::format_tag::tnc), // dst_layer_desc generic_md(leftmost_dst_iter_dims), // dst_iter_desc generic_md(leftmost_dst_iter_c_dims), // dst_iter_c_desc user_leftmost_diff_src_layer_md, // diff_src_layer_desc @@ -519,7 +531,8 @@ generic_md(common_weights_layer_dims), // weights_layer_desc generic_md(common_weights_iter_dims), // weights_iter_desc generic_md(common_bias_dims), // bias_desc - formatted_md(rightmost_dst_layer_dims, tag::tnc), // dst_layer_desc + formatted_md(rightmost_dst_layer_dims, + memory::format_tag::tnc), // dst_layer_desc memory::desc(), // dst_iter_desc memory::desc(), // dst_iter_c_desc user_rightmost_diff_src_layer_md, // diff_src_layer_desc diff --git a/examples/sycl_interop_buffer.cpp b/examples/sycl_interop_buffer.cpp index 7ef65de296a..59972ad6cbb 100644 --- a/examples/sycl_interop_buffer.cpp +++ b/examples/sycl_interop_buffer.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ /// @section sycl_interop_buffer_cpp_headers Public headers /// /// To start using oneDNN, we must first include the @ref dnnl.hpp -/// header file in the application. We also include CL/sycl.hpp from DPC++ for +/// header file in the application. We also include sycl/sycl.hpp from DPC++ for /// using SYCL APIs and @ref dnnl_debug.h, which contains some debugging /// facilities such as returning a string representation /// for common oneDNN C types.
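For context on what sycl_interop_buffer.cpp does with that header: the buffer flavor of the interop API wraps a sycl::buffer as the storage behind a oneDNN memory object. A minimal sketch (md, N, and engine are assumed to exist in the surrounding code):

    #include <sycl/sycl.hpp>
    #include "oneapi/dnnl/dnnl_sycl.hpp"

    sycl::buffer<float, 1> buf {sycl::range<1>(N)};
    auto mem = dnnl::sycl_interop::make_memory(
            md, engine, dnnl::sycl_interop::memory_kind::buffer);
    dnnl::sycl_interop::set_buffer(mem, buf); // memory now reads/writes buf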
@@ -56,8 +56,6 @@ #if __has_include(<sycl/sycl.hpp>) #include <sycl/sycl.hpp> -#elif __has_include(<CL/sycl.hpp>) -#include <CL/sycl.hpp> #else #error "Unsupported compiler" #endif diff --git a/examples/sycl_interop_usm.cpp b/examples/sycl_interop_usm.cpp index a61d8bbf353..713c05b9ab5 100644 --- a/examples/sycl_interop_usm.cpp +++ b/examples/sycl_interop_usm.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,8 +21,6 @@ #if __has_include(<sycl/sycl.hpp>) #include <sycl/sycl.hpp> -#elif __has_include(<CL/sycl.hpp>) -#include <CL/sycl.hpp> #else #error "Unsupported compiler" #endif diff --git a/examples/tutorials/matmul/cpu_matmul_quantization.cpp b/examples/tutorials/matmul/cpu_matmul_quantization.cpp index 5fee5ed17de..b7c0264b944 100644 --- a/examples/tutorials/matmul/cpu_matmul_quantization.cpp +++ b/examples/tutorials/matmul/cpu_matmul_quantization.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2022 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -154,7 +154,7 @@ void compute_q10n_params(const char *message, const std::vector<float> &v, #ifndef OMIT_WORKAROUND_FOR_SKX // Read more in CPU / Section 1 here: - // https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html + // https://uxlfoundation.github.io/oneDNN/dev_guide_int8_computations.html if (std::is_same<T, uint8_t>::value) max_int /= 2; #endif @@ -203,7 +203,10 @@ int compare_vectors(const std::vector<float> &v1, } // namespace -engine eng(engine::kind::cpu, 0); // We create a global engine for simplicity +const engine &eng() { + static const engine eng(engine::kind::cpu, 0); + return eng; +} // Quantize float data into X_int_m oneDNN memory using the q10n parameters // @@ -216,23 +219,23 @@ engine eng(engine::kind::cpu, 0); // We create a global engine for simplicity // - X_int_m -- prepared oneDNN memory that would hold quantized values void quantize(const std::vector<float> &X_f32, float scale_X, int32_t zp_X, memory &X_int_m) { - using dt = memory::data_type; - - stream s(eng); + stream s(eng()); memory::desc x_int_md = X_int_m.get_desc(); const auto &dims = x_int_md.get_dims(); - memory::desc x_f32_md({dims[0], dims[1]}, dt::f32, {dims[1], 1}); - memory X_f32_m(x_f32_md, eng, (void *)X_f32.data()); + memory::desc x_f32_md( + {dims[0], dims[1]}, memory::data_type::f32, {dims[1], 1}); + memory X_f32_m(x_f32_md, eng(), (void *)X_f32.data()); primitive_attr q10n_attr; q10n_attr.set_scales_mask(DNNL_ARG_DST, /* mask */ 0); q10n_attr.set_zero_points_mask(DNNL_ARG_DST, /* mask */ 0); - reorder::primitive_desc q10n_pd(eng, x_f32_md, eng, x_int_md, q10n_attr); - memory dst_scale_X_m({{1}, dt::f32, {1}}, eng, &scale_X); - memory zp_X_m({{1}, dt::s32, {1}}, eng, &zp_X); + reorder::primitive_desc q10n_pd( + eng(), x_f32_md, eng(), x_int_md, q10n_attr); + memory dst_scale_X_m({{1}, memory::data_type::f32, {1}}, eng(), &scale_X); + memory zp_X_m({{1}, memory::data_type::s32, {1}}, eng(), &zp_X); reorder(q10n_pd).execute(s, {{DNNL_ARG_SRC, X_f32_m}, {DNNL_ARG_DST, X_int_m}, {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scale_X_m}, @@ -256,15 +259,15 @@ void f32_matmul_compute(int64_t M, int64_t N, int64_t K, memory::desc c_md({M, N}, memory::data_type::f32, {N, 1}); // Wrap raw pointers into
oneDNN memory objects - memory A_f32_m(a_md, eng, (void *)A_f32.data()); - memory B_f32_m(b_md, eng, (void *)B_f32.data()); - memory C_f32_m(c_md, eng, (void *)C_f32.data()); + memory A_f32_m(a_md, eng(), (void *)A_f32.data()); + memory B_f32_m(b_md, eng(), (void *)B_f32.data()); + memory C_f32_m(c_md, eng(), (void *)C_f32.data()); // Create a MatMul primitive - matmul::primitive_desc matmul_pd(eng, a_md, b_md, c_md); + matmul::primitive_desc matmul_pd(eng(), a_md, b_md, c_md); matmul matmul_p(matmul_pd); - stream s(eng); + stream s(eng()); matmul_p.execute(s, {{DNNL_ARG_SRC, A_f32_m}, {DNNL_ARG_WEIGHTS, B_f32_m}, {DNNL_ARG_DST, C_f32_m}}); @@ -281,7 +284,7 @@ void f32_matmul_compute(int64_t M, int64_t N, int64_t K, void dynamic_q10n_matmul(int64_t M, int64_t N, int64_t K, const std::vector<float> &A_f32, const std::vector<float> &B_f32, std::vector<uint8_t> &C_u8, float &scale_C, int32_t &zp_C) { - stream s(eng); + stream s(eng()); float scale_A, scale_B; int32_t zp_A, zp_B; @@ -295,13 +298,13 @@ void dynamic_q10n_matmul(int64_t M, int64_t N, int64_t K, // Quantize matrix A_u8 using reorder primitive std::vector<uint8_t> A_u8(M * K, 0); memory::desc a_u8_md({M, K}, memory::data_type::u8, {K, 1}); - memory A_u8_m(a_u8_md, eng, (void *)A_u8.data()); + memory A_u8_m(a_u8_md, eng(), (void *)A_u8.data()); quantize(A_f32, scale_A, zp_A, A_u8_m); // Quantize matrix B_s8 using reorder primitive std::vector<int8_t> B_s8(K * N, 0); memory::desc b_s8_md({K, N}, memory::data_type::s8, {N, 1}); - memory B_s8_m(b_s8_md, eng, (void *)B_s8.data()); + memory B_s8_m(b_s8_md, eng(), (void *)B_s8.data()); quantize(B_f32, scale_B, 0, B_s8_m); // Compute C_f32. We cannot directly compute C_u8 since we don't know the @@ -319,7 +322,7 @@ void dynamic_q10n_matmul(int64_t M, int64_t N, int64_t K, std::vector<float> C_f32(M * N, 0); memory::desc c_f32_md({M, N}, memory::data_type::f32, {N, 1}); - memory C_f32_m(c_f32_md, eng, (void *)C_f32.data()); + memory C_f32_m(c_f32_md, eng(), (void *)C_f32.data()); // Create and compute a reduced precision MatMul primitive { @@ -329,12 +332,12 @@ matmul_attr.set_zero_points_mask(DNNL_ARG_SRC, /* mask */ 0); matmul::primitive_desc matmul_pd( - eng, a_u8_md, b_s8_md, c_f32_md, matmul_attr); + eng(), a_u8_md, b_s8_md, c_f32_md, matmul_attr); matmul matmul_p(matmul_pd); - memory scales_A_m({{1}, memory::data_type::f32, {1}}, eng, &scale_A); - memory scales_B_m({{1}, memory::data_type::f32, {1}}, eng, &scale_B); - memory zp_A_m({{1}, memory::data_type::s32, {1}}, eng, &zp_A); + memory scales_A_m({{1}, memory::data_type::f32, {1}}, eng(), &scale_A); + memory scales_B_m({{1}, memory::data_type::f32, {1}}, eng(), &scale_B); + memory zp_A_m({{1}, memory::data_type::s32, {1}}, eng(), &zp_A); matmul_p.execute(s, {{DNNL_ARG_SRC, A_u8_m}, {DNNL_ARG_WEIGHTS, B_s8_m}, @@ -349,7 +352,7 @@ // Finally quantize the matrix C memory::desc c_u8_md({M, N}, memory::data_type::u8, {N, 1}); - memory C_u8_m(c_u8_md, eng, (void *)C_u8.data()); + memory C_u8_m(c_u8_md, eng(), (void *)C_u8.data()); quantize(C_f32, scale_C, zp_C, C_u8_m); } diff --git a/examples/tutorials/matmul/cpu_sgemm_and_matmul.cpp b/examples/tutorials/matmul/cpu_sgemm_and_matmul.cpp index 643b8a2f473..749a6911608 100644 --- a/examples/tutorials/matmul/cpu_sgemm_and_matmul.cpp +++ b/examples/tutorials/matmul/cpu_sgemm_and_matmul.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2022 Intel
Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -111,7 +111,10 @@ int compare_vectors(const std::vector<float> &v1, const std::vector<float> &v2, int number_of_runs = 1; float fixed_beta = 0.f; -engine eng(engine::kind::cpu, 0); // We create a global engine for simplicity +const engine &eng() { + static const engine eng(engine::kind::cpu, 0); + return eng; +} // Create a _dynamic_ MatMul primitive that can work with arbitrary shapes // and alpha parameters. @@ -143,7 +146,7 @@ matmul dynamic_matmul_create() { } // Create a MatMul primitive - matmul::primitive_desc matmul_pd(eng, a_md, b_md, c_md, attr); + matmul::primitive_desc matmul_pd(eng(), a_md, b_md, c_md, attr); return matmul(matmul_pd); } @@ -164,15 +167,15 @@ void dynamic_matmul_execute(matmul &matmul_p, char transA, char transB, dims b_strides = tolower(transB) == 'n' ? dims {ldb, 1} : dims {1, ldb}; // Wrap raw pointers into oneDNN memories (with proper shapes) - memory A_m({{M, K}, memory::data_type::f32, a_strides}, eng, (void *)A); - memory B_m({{K, N}, memory::data_type::f32, b_strides}, eng, (void *)B); - memory C_m({{M, N}, memory::data_type::f32, {ldc, 1}}, eng, (void *)C); + memory A_m({{M, K}, memory::data_type::f32, a_strides}, eng(), (void *)A); + memory B_m({{K, N}, memory::data_type::f32, b_strides}, eng(), (void *)B); + memory C_m({{M, N}, memory::data_type::f32, {ldc, 1}}, eng(), (void *)C); // Prepare oneDNN memory for alpha - memory alpha_m({{1}, memory::data_type::f32, {1}}, eng, &alpha); + memory alpha_m({{1}, memory::data_type::f32, {1}}, eng(), &alpha); // Execute the MatMul primitive - stream s(eng); + stream s(eng()); matmul_p.execute(s, {{DNNL_ARG_SRC, A_m}, {DNNL_ARG_WEIGHTS, B_m}, {DNNL_ARG_DST, C_m}, {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, alpha_m}}); diff --git a/examples/tutorials/matmul/weights_decompression_matmul.cpp b/examples/tutorials/matmul/weights_decompression_matmul.cpp index ead80cfc451..b5d5d465f8a 100644 --- a/examples/tutorials/matmul/weights_decompression_matmul.cpp +++ b/examples/tutorials/matmul/weights_decompression_matmul.cpp @@ -160,10 +160,10 @@ void infer(const matmul &matmul_p, int64_t M, int64_t N, int64_t K, int64_t G, void weights_decompression_matmul(engine::kind engine_kind) { engine eng(engine_kind, 0); - const int64_t K = 96; + const int64_t K = 64; const int64_t N = 1000; const int64_t M = 100; - // Quantization Group size for scales + // Quantization Group size for scales. Must be divisible by 32. const int64_t G = K / 2; auto matmul_pd = matmul_pd_create(M, N, K, G, eng); diff --git a/examples/ukernels/cpu_brgemm.cpp b/examples/ukernels/cpu_brgemm.cpp index 5c119f5c527..2b2bc45c72e 100644 --- a/examples/ukernels/cpu_brgemm.cpp +++ b/examples/ukernels/cpu_brgemm.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,9 +36,6 @@ using namespace dnnl; using namespace dnnl::ukernel; -using tag = memory::format_tag; -using dt = memory::data_type; - void brgemm_example() { // Create execution dnnl::engine.
Needed for reorders to operate over input @@ -57,17 +54,37 @@ void brgemm_example() { } const memory::dim n_calls = K / K_k; + memory::data_type a_dt = memory::data_type::u8; + memory::data_type b_dt = memory::data_type::s8; + memory::data_type c_dt = memory::data_type::s32; // Accumulator data type. + memory::data_type d_dt = memory::data_type::f32; // Output data type. + + // Query the packing requirement from the ukernel. It's enough to query + // packing requirements once for multiple objects. + // Based on this information, a specific `ldb` value can be used, since + // the transform routine supports only a limited set of values. + bool need_pack = false; + try { + need_pack = brgemm::get_B_pack_type(a_dt, b_dt) == pack_type::pack32; + } catch (error &e) { + if (e.status == dnnl_unimplemented) + throw example_allows_unimplemented { + "Kernel is not supported on this platform.\n"}; + + // on any other error just re-throw + throw; + } + const memory::dim lda = K; + // `ldb` for `need_pack = true` must be one of 16, 32, 48, or 64. This + // example doesn't explore options for dividing N into blocks, which would + // likely be needed for N > 64. + // const memory::dim ldb = need_pack ? N_block : N; const memory::dim ldb = N; const memory::dim ldc = N; // Leading dimension for accumulator. const memory::dim ldd = N; // Leading dimension for an actual output. const memory::dim batch_size = n_calls - 1; - memory::data_type a_dt = dt::u8; - memory::data_type b_dt = dt::s8; - memory::data_type c_dt = dt::s32; // Accumulator data type. - memory::data_type d_dt = dt::f32; // Output data type. - // A, B, and C tensors dimensions. memory::dims A_dims = {M, K}; memory::dims B_dims = {K, N}; @@ -111,11 +128,16 @@ void brgemm_example() { // Create f32 memories. They are used as data holders and reordered into // memories passed to the ukernel. - auto A_f32_md = memory::desc(A_dims, dt::f32, tag::ab); - auto B_f32_md = memory::desc(B_dims, dt::f32, tag::ab); - auto binary_add_f32_md = memory::desc(binary_add_dims, dt::f32, tag::ab); - auto B_scales_f32_md = memory::desc(B_scales_dims, dt::f32, tag::ab); - auto D_f32_md = memory::desc(D_dims, dt::f32, tag::ab); + auto A_f32_md = memory::desc( + A_dims, memory::data_type::f32, memory::format_tag::ab); + auto B_f32_md = memory::desc( + B_dims, memory::data_type::f32, memory::format_tag::ab); + auto binary_add_f32_md = memory::desc( + binary_add_dims, memory::data_type::f32, memory::format_tag::ab); + auto B_scales_f32_md = memory::desc( + B_scales_dims, memory::data_type::f32, memory::format_tag::ab); + auto D_f32_md = memory::desc( + D_dims, memory::data_type::f32, memory::format_tag::ab); auto A_f32_mem = memory(A_f32_md, engine, A_user_data.data()); auto B_f32_mem = memory(B_f32_md, engine, B_user_data.data()); @@ -127,12 +149,14 @@ void brgemm_example() { // Create ukernel memories in requested data types. // Note that all formats are `ab`.
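When need_pack comes back true, B has to be repacked before any execute call; the ukernel API provides a transform object for this. A condensed sketch using the example's own names (B_packed_ptr is hypothetical; the exact K extent to pack follows the example's blocking scheme):

    // Repack B from a plain ab layout into the pack32 layout the BRGeMM
    // ukernel expects; the last four arguments are the source and
    // destination leading dimensions and data types.
    ukernel::transform pack_B(/* K = */ K_k, /* N = */ N,
            pack_type::no_trans, /* in_ld = */ N, /* out_ld = */ ldb,
            /* in_dt = */ b_dt, /* out_dt = */ b_dt);
    pack_B.generate();
    pack_B.execute(B_ptr, B_packed_ptr);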
- auto A_md = memory::desc(A_dims, a_dt, tag::ab); - auto B_md = memory::desc(B_dims, b_dt, tag::ab); - auto binary_add_md = memory::desc(binary_add_dims, dt::f32, tag::ab); - auto B_scales_md = memory::desc(B_scales_dims, dt::f32, tag::ab); - auto C_md = memory::desc(C_dims, c_dt, tag::ab); - auto D_md = memory::desc(D_dims, d_dt, tag::ab); + auto A_md = memory::desc(A_dims, a_dt, memory::format_tag::ab); + auto B_md = memory::desc(B_dims, b_dt, memory::format_tag::ab); + auto binary_add_md = memory::desc( + binary_add_dims, memory::data_type::f32, memory::format_tag::ab); + auto B_scales_md = memory::desc( + B_scales_dims, memory::data_type::f32, memory::format_tag::ab); + auto C_md = memory::desc(C_dims, c_dt, memory::format_tag::ab); + auto D_md = memory::desc(D_dims, d_dt, memory::format_tag::ab); auto A_mem = memory(A_md, engine); auto B_mem = memory(B_md, engine); @@ -213,7 +237,7 @@ void brgemm_example() { // Specify post-ops for the brgemm object. brg_po.set_post_ops(ldd, d_dt, brgemm_ops); // Specify quantization scales for B. - if (b_dt == dt::s8 || b_dt == dt::u8) { + if (b_dt == memory::data_type::s8 || b_dt == memory::data_type::u8) { brg_po.set_B_scales(/* mask = */ 2); } // Finalize the initialization. @@ -239,12 +263,6 @@ void brgemm_example() { void *B_base_ptr = B_ptr; size_t blocked_B_size = 0; - // Query the packing requirement from the kernel. It's enough to query - // packing requirements from a single object as long as only dimension - // settings change between objects. - // Note: example uses the one that always present regardless of dimensions. - const bool need_pack = brg_po.get_B_pack_type() == pack_type::pack32; - // If packing is needed, create a dedicated object for data transformation. if (need_pack) { // Packing B tensor routine. The BRGeMM ukernel expects B passed in a @@ -312,11 +330,21 @@ void brgemm_example() { params.set_post_ops_args(bin_po_ptrs.data()); params.set_B_scales(B_scales_mem.get_data_handle()); - // An execute call. The difference here is an additional D tensor pointer - // to store final output result after finishing accumulation and post-ops - // application. - brg_po.execute(A_ptr, B_base_ptr, A_B_po_offsets, C_ptr, - D_mem.get_data_handle(), scratchpad.data(), params); + // An execute call. The difference here is that, when post operations are + // requested, an additional D tensor pointer is required to store the + // final output after accumulation and post-ops application. + // Additionally, a special `params` object carrying the post operation + // handles is required. + // + // If post operations are not defined, that call is invalid; a dedicated + // API call checks this state. + if (brg_po.is_execute_postops_valid()) { + brg_po.execute(A_ptr, B_base_ptr, A_B_po_offsets, C_ptr, + D_mem.get_data_handle(), scratchpad.data(), params); + } else { + brg_po.execute( + A_ptr, B_base_ptr, A_B_po_offsets, C_ptr, scratchpad.data()); + } // Once all computations are done, need to release HW context. brgemm::release_hw_context(); diff --git a/include/oneapi/dnnl/dnnl.h b/include/oneapi/dnnl/dnnl.h index fbc34a49ad5..ab5871f1c2e 100644 --- a/include/oneapi/dnnl/dnnl.h +++ b/include/oneapi/dnnl/dnnl.h @@ -24,6 +24,7 @@ #include "oneapi/dnnl/dnnl_config.h" #include "oneapi/dnnl/dnnl_types.h" #include "oneapi/dnnl/dnnl_version.h" +#include #ifdef __cplusplus extern "C" { @@ -420,6 +421,8 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode( /// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask( dnnl_primitive_attr_t attr, int arg, int mask); +dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_dims( + dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims, dnnl_data_type_t data_type); /// Sets primitive attributes scaling factors for primitive operations for a /// given memory argument. The scaling factors must be passed at execution time @@ -467,6 +470,8 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales( /// otherwise. dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_mask( dnnl_primitive_attr_t attr, int arg, int mask); +dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_dims( + dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims, dnnl_data_type_t data_type); /// Sets primitive attributes zero points for primitive operations for a given /// memory argument. The zero points must be passed at execution time @@ -499,7 +504,7 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points( /// /// @param attr Primitive attributes. /// @param arg Argument for which rounding mode should be set. -/// @params mode Rounding mode to apply to the argument. +/// @param mode Rounding mode to apply to the argument. /// @returns #dnnl_success on success and a status describing the error /// otherwise. dnnl_status_t DNNL_API dnnl_primitive_attr_set_rounding( @@ -509,12 +514,24 @@ /// /// @param attr Primitive attributes. /// @param arg Argument for which rounding mode query applies. -/// @params mode Output rounding mode. +/// @param mode Output rounding mode. /// @returns #dnnl_success on success and a status describing the error /// otherwise. dnnl_status_t DNNL_API dnnl_primitive_attr_get_rounding( dnnl_primitive_attr_t attr, int arg, dnnl_rounding_mode_t *mode); +dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_compensations( + const_dnnl_primitive_attr_t attr, int *count, int *mask, const int32_t **compensations); + +dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_compensations( + dnnl_primitive_attr_t attr, int count, int mask); + +dnnl_status_t DNNL_API dnnl_primitive_attr_set_input_zero_points( + dnnl_primitive_attr_t attr, int count, int mask); + +dnnl_status_t DNNL_API dnnl_primitive_attr_set_weights_zero_points( + dnnl_primitive_attr_t attr, int count, int mask); + /// Returns primitive attributes post-ops. /// /// @warning @@ -716,6 +733,13 @@ dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw( dnnl_data_type_t *dst_data_type, dnnl_dim_t *kernel_size, dnnl_dim_t *stride_size, dnnl_dim_t *padding_l_size); +/// Appends DW convolution post operation to the @p post_ops with the given +/// spatial, kernel, stride, and data type parameters. +/// +/// The kind of this post operation is #dnnl_convolution. +dnnl_status_t DNNL_API dnnl_post_ops_append_dw_conv( + dnnl_post_ops_t post_ops, int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt); + /// Appends a binary post-op. /// /// The kind of this post operation is #dnnl_binary.
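As a cross-reference for the scales-related declarations above: the mask-based flavor maps directly onto the C++ API, and the mask value encodes which dimensions carry distinct scales. A sketch (the per-output-channel mask is just an illustration):

    // Declare per-output-channel (dimension 1) scales on weights; the
    // actual values are passed at execution time under
    // DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS.
    dnnl::primitive_attr attr;
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 1 << 1);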
@@ -795,6 +819,18 @@ dnnl_status_t DNNL_API dnnl_post_ops_append_prelu( dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu( const_dnnl_post_ops_t post_ops, int index, int *mask); +/// Appends a depthwise post-op defined by the algorithm kind @p alg; +/// @p offset is an array of @p offset_size buffer offsets. +dnnl_status_t DNNL_API dnnl_post_ops_append_depthwise( + dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, size_t offset_size, const size_t* offset); + +/// Appends a quantization post-op defined by the algorithm kind @p alg with +/// the given per-channel flags, all-default flags, and buffer offsets. +dnnl_status_t DNNL_API dnnl_post_ops_append_quantization( + dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, + size_t per_channel_size, const bool* per_channel, + size_t all_default_size, const bool* all_default, + size_t offset_size, const size_t* offset); + +/// Appends a binarization post-op defined by the algorithm kind @p alg with +/// the given @p weights_data thresholds and @p output_mask. +dnnl_status_t DNNL_API dnnl_post_ops_append_binarization( + dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, const float* weights_data, const float* output_mask); + /// @} dnnl_api_attributes /// @} dnnl_api_primitives @@ -898,6 +934,29 @@ dnnl_status_t DNNL_API dnnl_memory_desc_create_with_csr_encoding( dnnl_data_type_t data_type, dnnl_dim_t nnz, dnnl_data_type_t indices_dt, dnnl_data_type_t pointers_dt); +/// Creates a memory descriptor for COO encoding. +/// +/// The created memory descriptor will describe a memory object that +/// contains n+1 buffers for an n-dimensional tensor. +/// The buffers have the following meaning and assigned numbers (index): +/// - 0: values +/// - 1: indices for dimension 0 +/// - 2: indices for dimension 1 ... +/// - n: indices for dimension n-1 +/// +/// @param memory_desc Output memory descriptor. +/// @param ndims Number of dimensions. +/// @param dims Array of dimensions. +/// @param data_type Elements data type. +/// @param nnz Number of non-zero entries. +/// @param indices_dt Data type of indices. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_desc_create_with_coo_encoding( + dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, + dnnl_data_type_t data_type, dnnl_dim_t nnz, + dnnl_data_type_t indices_dt); + /// Creates a memory descriptor for packed sparse encoding. /// /// The created memory descriptor cannot be used to create a memory @@ -921,6 +980,19 @@ dnnl_status_t DNNL_API dnnl_memory_desc_create_with_packed_encoding( dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, dnnl_dim_t nnz); #endif +/// Initializes a sparse descriptor. +/// +/// @param memory_desc Output memory descriptor. +/// @param encoding Encoding. +/// @param ndims Number of dimensions. +/// @param dims Array of dimensions. +/// @param data_type Elements data type. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_desc_create_sparse( + dnnl_memory_desc_t *memory_desc, + dnnl_sparse_encoding_t encoding, int ndims, + const dnnl_dims_t dims, dnnl_data_type_t data_type); /// Creates a memory descriptor for a region inside an area /// described by an existing memory descriptor. @@ -1175,7 +1247,6 @@ size_t DNNL_API dnnl_memory_desc_get_size(const_dnnl_memory_desc_t memory_desc); size_t DNNL_API dnnl_memory_desc_get_size_v2( const_dnnl_memory_desc_t memory_desc, int index); #endif - /// Returns the size of data type. /// /// @param data_type Data type. @@ -1228,7 +1299,6 @@ dnnl_status_t DNNL_API dnnl_memory_create_v2(dnnl_memory_t *memory, const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine, int nhandles, void **handles); #endif - /// Returns the memory descriptor for a memory object. /// /// @param memory Memory object.
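To make the COO buffer numbering above concrete, a short sketch of creating such a descriptor; the 4x4 shape, nnz of 3, and s32 index type are illustrative:

```cpp
#include "oneapi/dnnl/dnnl.h"

void coo_md_sketch() {
    // A 4 x 4 f32 tensor with 3 non-zeros: buffer 0 holds the 3 values,
    // buffer 1 the 3 row indices, buffer 2 the 3 column indices.
    dnnl_dims_t dims = {4, 4};
    dnnl_memory_desc_t md = nullptr;
    dnnl_status_t st = dnnl_memory_desc_create_with_coo_encoding(
            &md, 2, dims, dnnl_f32, /* nnz = */ 3, dnnl_s32);
    if (st == dnnl_success) dnnl_memory_desc_destroy(md);
}
```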
@@ -1340,7 +1410,6 @@ dnnl_status_t DNNL_API dnnl_memory_unmap_data( dnnl_status_t DNNL_API dnnl_memory_unmap_data_v2( const_dnnl_memory_t memory, void *mapped_ptr, int index); #endif - /// Returns memory object's data handle. /// /// @param memory Memory object. @@ -1477,7 +1546,7 @@ dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create( /// Creates a primitive descriptor for a binary primitive. /// /// @note -/// Memory descriptors @p src1_desc and @p dst_desc are alloweded to be +/// Memory descriptors @p src1_desc and @p dst_desc are allowed to be /// initialized with #dnnl_format_tag_any or with format_kind set to /// #dnnl_format_kind_any. /// @@ -1504,6 +1573,37 @@ dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create( const_dnnl_memory_desc_t src1_desc, const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr); +/// Creates a primitive descriptor for a binary primitive with support of +/// ternary operators. +/// +/// @note +/// Memory descriptors @p src1_desc, @p src2_desc and @p dst_desc are +/// allowed to be initialized with #dnnl_format_tag_any or with format_kind +/// set to #dnnl_format_kind_any. +/// +/// @note +/// All memory descriptors must have the same number of dimensions. +/// Element broadcasting is supported for memory descriptor @p src1_desc +/// and is applied to @p src1_desc dimensions that have a size equal to 1. +/// There is no broadcasting support for @p src2_desc. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param alg_kind Algorithm kind. +/// @param src0_desc Source 0 memory descriptor. +/// @param src1_desc Source 1 memory descriptor. +/// @param src2_desc Source memory descriptor for ternary operations. Might +/// be empty. +/// @param dst_desc Destination memory descriptor. +/// @param attr Primitive attributes. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create_v2( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src0_desc, + const_dnnl_memory_desc_t src1_desc, const_dnnl_memory_desc_t src2_desc, + const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr); + /// @} dnnl_api_binary /// @addtogroup dnnl_api_convolution @@ -2638,6 +2738,10 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams( const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales); +dnnl_status_t DNNL_API dnnl_primitive_attr_set_src_dyn_quant_params( + dnnl_primitive_attr_t attr, uint64_t group_size); +dnnl_status_t DNNL_API dnnl_primitive_attr_get_src_dyn_quant_params( + dnnl_primitive_attr_t attr, uint64_t* group_size); /// @} dnnl_api_attributes /// @addtogroup dnnl_api_rnn diff --git a/include/oneapi/dnnl/dnnl.hpp b/include/oneapi/dnnl/dnnl.hpp index 1dc369eaa05..8e55eabfcd0 100644 --- a/include/oneapi/dnnl/dnnl.hpp +++ b/include/oneapi/dnnl/dnnl.hpp @@ -1,5 +1,6 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation +* Copyright 2024-2025 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
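A hedged sketch of calling the ternary-capable `dnnl_binary_primitive_desc_create_v2` entry point declared above for a select operation. Engine setup and cleanup are minimal, and treating `src2` as an s8 condition tensor is an assumption based on the select semantics rather than something the declaration spells out:

```cpp
#include "oneapi/dnnl/dnnl.h"

void binary_select_sketch() {
    dnnl_engine_t engine = nullptr;
    dnnl_engine_create(&engine, dnnl_cpu, 0);

    // dst = src2 ? src0 : src1, element-wise over 2 x 3 tensors.
    dnnl_dims_t dims = {2, 3};
    dnnl_memory_desc_t src0, src1, src2, dst;
    dnnl_memory_desc_create_with_tag(&src0, 2, dims, dnnl_f32, dnnl_ab);
    dnnl_memory_desc_create_with_tag(&src1, 2, dims, dnnl_f32, dnnl_ab);
    dnnl_memory_desc_create_with_tag(&src2, 2, dims, dnnl_s8, dnnl_ab);
    dnnl_memory_desc_create_with_tag(&dst, 2, dims, dnnl_f32, dnnl_ab);

    dnnl_primitive_desc_t pd = nullptr;
    dnnl_binary_primitive_desc_create_v2(&pd, engine, dnnl_binary_select,
            src0, src1, src2, dst, /* attr = */ nullptr);

    if (pd) dnnl_primitive_desc_destroy(pd);
    for (auto *md : {src0, src1, src2, dst}) dnnl_memory_desc_destroy(md);
    dnnl_engine_destroy(engine);
}
```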
@@ -19,6 +20,7 @@ #ifndef ONEAPI_DNNL_DNNL_HPP #define ONEAPI_DNNL_DNNL_HPP +// NOLINTBEGIN(readability-identifier-naming) #include "oneapi/dnnl/dnnl_config.h" @@ -29,6 +31,7 @@ #include #include #include +#include <array> #include #include "oneapi/dnnl/dnnl.h" @@ -148,6 +151,10 @@ struct primitive : public handle { layer_normalization = dnnl_layer_normalization, /// A group normalization primitive group_normalization = dnnl_group_normalization, + + /// A depthwise primitive + depthwise = dnnl_depthwise, + /// A quantization primitive + quantization = dnnl_quantization, + /// A binarization primitive + binarization = dnnl_binarization, }; using handle::handle; @@ -168,7 +175,7 @@ struct primitive : public handle { const std::vector<uint8_t> &cache_blob); /// Constructs a primitive from a primitive descriptor. /// /// @param pd Primitive descriptor. primitive(const primitive_desc &pd); @@ -406,6 +413,12 @@ enum class algorithm { eltwise_hardswish = dnnl_eltwise_hardswish, /// Elementwise: hardsigmoid eltwise_hardsigmoid = dnnl_eltwise_hardsigmoid, + /// Elementwise: hsigmoid + eltwise_hsigmoid = dnnl_eltwise_hsigmoid, + /// Elementwise: round_half_to_even + eltwise_round_half_to_even = dnnl_eltwise_round_half_to_even, + /// Elementwise: round_half_away_from_zero + eltwise_round_half_away_from_zero = dnnl_eltwise_round_half_away_from_zero, /// Elementwise: rectified linear unit (ReLU) (dst for backward) eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd, /// Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward) @@ -470,6 +483,10 @@ enum class algorithm { binary_eq = dnnl_binary_eq, /// Binary not equal binary_ne = dnnl_binary_ne, + /// Binary select + binary_select = dnnl_binary_select, + /// Binary prelu + binary_prelu = dnnl_binary_prelu, /// Nearest Neighbor resampling method resampling_nearest = dnnl_resampling_nearest, /// Linear (Bilinear, Trilinear) resampling method @@ -496,6 +513,13 @@ enum class algorithm { softmax_accurate = dnnl_softmax_accurate, /// LogSoftmax, numerically stable softmax_log = dnnl_softmax_log, + + /// Depthwise: scale and shift + depthwise_scale_shift = dnnl_depthwise_scale_shift, + /// Depthwise: PReLU + depthwise_prelu = dnnl_depthwise_prelu, + + /// Quantization: quantize and dequantize + quantization_quantize_dequantize = dnnl_quantization_quantize_dequantize, + /// Quantization: quantize only + quantization_quantize = dnnl_quantization_quantize, + /// Binarization: depthwise + binarization_depthwise = dnnl_binarization_depthwise, }; /// Converts algorithm kind enum value from C++ API to C API type. @@ -831,10 +855,10 @@ struct memory : public handle { using handle::handle; /// Integer type for representing dimension sizes and indices. - typedef dnnl_dim_t dim; + using dim = dnnl_dim_t; /// Vector of dimensions. Implementations are free to force a limit on the /// vector's length. - typedef std::vector<dim> dims; + using dims = std::vector<dim>; /// Helper function that validates that an `std::vector` of dimensions can /// be safely converted to the C API array ::dnnl_dims_t. Throws if @@ -852,6 +876,10 @@ struct memory : public handle { enum class data_type { /// Undefined data type (used for empty memory descriptors). undef = dnnl_data_type_undef, + /// 4-bit float data type with 3-bit exponent and 0 bit mantissa. + f4_e3m0 = dnnl_f4_e3m0, + /// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa. + f4_e2m1 = dnnl_f4_e2m1, /// [MX-compliant 8-bit compliant scale data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 8-bit exponent.
e8m0 = dnnl_e8m0, /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) @@ -879,6 +907,10 @@ struct memory : public handle { s4 = dnnl_s4, /// 4-bit unsigned integer. u4 = dnnl_u4, + /// 1-bit integer + bin = dnnl_bin, + /// 4-bit normalized float. + nf4 = dnnl_nf4, }; /// Returns size of data type in bytes. @@ -901,6 +933,7 @@ struct memory : public handle { /// Format kind for sparse tensors. sparse = dnnl_format_kind_sparse, #endif + sparsed = dnnl_format_sparse, /// A special format kind that indicates that tensor format is opaque. opaque = dnnl_format_kind_opaque, }; @@ -918,6 +951,8 @@ struct memory : public handle { /// only be used to create a primitive descriptor to query the /// actual memory descriptor (similar to the format tag `any`). packed = dnnl_packed, + /// Coordinate Sparse (COO) encoding. + coo = dnnl_coo, }; #endif @@ -1213,6 +1248,7 @@ struct memory : public handle { AB16b64a2b = dnnl_AB16b64a2b, Ab4a = dnnl_Ab4a, Ab8a = dnnl_Ab8a, + Ab32a = dnnl_Ab32a, Abc16a = dnnl_Abc16a, ABc16a16b = dnnl_ABc16a16b, ABc4a4b = dnnl_ABc4a4b, @@ -1302,6 +1338,7 @@ struct memory : public handle { aBCd4b4c = dnnl_aBCd4b4c, ABcd8a16b2a = dnnl_ABcd8a16b2a, ABcd8a8b = dnnl_ABcd8a8b, + ABcd8a32b = dnnl_ABcd8a32b, ABcd8a4b = dnnl_ABcd8a4b, ABcd8a2b = dnnl_ABcd8a2b, /// 4D tensor blocked by 2nd dimension with block size 8 @@ -1407,6 +1444,7 @@ struct memory : public handle { aBdeC8b4c = dnnl_aBdeC8b4c, aBdefc16b = dnnl_aBdefc16b, aCBdef16c16b = dnnl_aCBdef16c16b, + aCBdef8b8c = dnnl_aCBdef8b8c, aCBdef16b16c = dnnl_aCBdef16b16c, aBdefc4b = dnnl_aBdefc4b, aBdefc8b = dnnl_aBdefc8b, @@ -1417,8 +1455,10 @@ struct memory : public handle { Acb8a = dnnl_Acb8a, AcB8a2b = dnnl_AcB8a2b, AcB8a4b = dnnl_AcB8a4b, + aCBd8b8c = dnnl_aCBd8b8c, aCBd16b16c = dnnl_aCBd16b16c, aCBd16c16b = dnnl_aCBd16c16b, + aCBde8b8c = dnnl_aCBde8b8c, aCBde16b16c = dnnl_aCBde16b16c, aCBde16c16b = dnnl_aCBde16c16b, Acdb16a = dnnl_Acdb16a, @@ -1431,14 +1471,19 @@ struct memory : public handle { Acdeb8a = dnnl_Acdeb8a, AcdeB8a2b = dnnl_AcdeB8a2b, AcdeB8a4b = dnnl_AcdeB8a4b, + BAc8a8b = dnnl_BAc8a8b, BAc16a16b = dnnl_BAc16a16b, BAc16b16a = dnnl_BAc16b16a, + BAcd8a8b = dnnl_BAcd8a8b, BAcd16a16b = dnnl_BAcd16a16b, BAcd16b16a = dnnl_BAcd16b16a, ABcd32a32b = dnnl_ABcd32a32b, BAcde16b16a = dnnl_BAcde16b16a, + BAcde8a8b = dnnl_BAcde8a8b, BAcde16a16b = dnnl_BAcde16a16b, aBdec32b = dnnl_aBdec32b, + Abcdef4a = dnnl_Abcdef4a, + Abcdef8a = dnnl_Abcdef8a, Abcdef16a = dnnl_Abcdef16a, Abcdef32a = dnnl_Abcdef32a, Acdb32a = dnnl_Acdb32a, @@ -1460,10 +1505,12 @@ struct memory : public handle { AB8a2b = dnnl_AB8a2b, abDc16d = dnnl_abDc16d, abDc32d = dnnl_abDc32d, + abDC16d4c = dnnl_abDC16d4c, abDC32d4c = dnnl_abDC32d4c, abCd32c = dnnl_abCd32c, abdEc16e = dnnl_abdEc16e, abdEc32e = dnnl_abdEc32e, + abdEC16e4c = dnnl_abdEC16e4c, abdEC32e2c = dnnl_abdEC32e2c, abdEC32e4c = dnnl_abdEC32e4c, abdCe16c = dnnl_abdCe16c, @@ -1596,6 +1643,9 @@ struct memory : public handle { BA16a32b4a = dnnl_BA16a32b4a, BA16a48b4a = dnnl_BA16a48b4a, BA16a64b4a = dnnl_BA16a64b4a, + BA24b8a = dnnl_BA24b8a, + aCB24c8b = dnnl_aCB24c8b, + abDC24d8c = dnnl_abDC24d8c, decbA16a = dnnl_decbA16a, decbA8a = dnnl_decbA8a, defcbA16a = dnnl_defcbA16a, @@ -1686,7 +1736,10 @@ struct memory : public handle { IOdhw16i16o = dnnl_IOdhw16i16o, gIOhw16i16o = dnnl_gIOhw16i16o, gOhwi32o = dnnl_gOhwi32o, + Goidhw4g = dnnl_Goidhw4g, + Goidhw8g = dnnl_Goidhw8g, Goidhw16g = dnnl_Goidhw16g, + IOw8o8i = 
dnnl_IOw8o8i, IOw16o16i = dnnl_IOw16o16i, OIw16i16o = dnnl_OIw16i16o, OwI16i16o = dnnl_OwI16i16o, @@ -1743,6 +1796,7 @@ struct memory : public handle { Owi8o = dnnl_Owi8o, OwI8o2i = dnnl_OwI8o2i, OwI8o4i = dnnl_OwI8o4i, + IOhw8o8i = dnnl_IOhw8o8i, IOhw16o16i = dnnl_IOhw16o16i, Ohwi16o = dnnl_Ohwi16o, OhwI16o2i = dnnl_OhwI16o2i, @@ -1786,8 +1840,11 @@ struct memory : public handle { OhwI8i8o = dnnl_OhwI8i8o, OIhw8o16i2o = dnnl_OIhw8o16i2o, OIhw8o8i = dnnl_OIhw8o8i, + OIhw8o32i = dnnl_OIhw8o32i, + OIhw16o32i = dnnl_OIhw16o32i, OIhw8o4i = dnnl_OIhw8o4i, OIhw2i8o4i = dnnl_OIhw2i8o4i, + IOdhw8o8i = dnnl_IOdhw8o8i, IOdhw16o16i = dnnl_IOdhw16o16i, Odhwi16o = dnnl_Odhwi16o, OdhwI16o2i = dnnl_OdhwI16o2i, @@ -1841,6 +1898,7 @@ struct memory : public handle { OdhwI8i8o = dnnl_OdhwI8i8o, OIdhw8o8i = dnnl_OIdhw8o8i, OIdhw8o4i = dnnl_OIdhw8o4i, + gIOw8o8i = dnnl_gIOw8o8i, gIOw16o16i = dnnl_gIOw16o16i, gOIw16i16o = dnnl_gOIw16i16o, gOIw16o16i = dnnl_gOIw16o16i, @@ -1869,6 +1927,7 @@ struct memory : public handle { gOwI8o4i = dnnl_gOwI8o4i, Goiw8g = dnnl_Goiw8g, Goiw16g = dnnl_Goiw16g, + gIOhw8o8i = dnnl_gIOhw8o8i, gIOhw16o16i = dnnl_gIOhw16o16i, gOhwi16o = dnnl_gOhwi16o, gOhwI16o2i = dnnl_gOhwI16o2i, @@ -1915,6 +1974,7 @@ struct memory : public handle { gOIhw8o8i = dnnl_gOIhw8o8i, gOIhw8o4i = dnnl_gOIhw8o4i, gIOdhw16i16o = dnnl_gIOdhw16i16o, + gIOdhw8o8i = dnnl_gIOdhw8o8i, gIOdhw16o16i = dnnl_gIOdhw16o16i, gOdhwi16o = dnnl_gOdhwi16o, gOdhwI16o2i = dnnl_gOdhwI16o2i, @@ -1955,8 +2015,10 @@ struct memory : public handle { ldOi16o = abDc16d, ldOi32o = abDc32d, + ldOI16o4i = abDC16d4c, ldOI32o4i = abDC32d4c, ldgOi16o = abdEc16e, + ldgOI16o4i = abdEC16e4c, ldgOi32o = abdEc32e, ldgOI32o2i = abdEC32e2c, ldgOI32o4i = abdEC32e4c, @@ -2721,7 +2783,6 @@ struct memory : public handle { /// A memory descriptor. struct desc : public handle { using handle::handle; - friend struct memory; /// Constructs a zero (empty) memory descriptor. Such a memory @@ -2828,6 +2889,38 @@ struct memory : public handle { return desc {md}; } + /// Function for creating a memory descriptor for COO sparse encodings. + /// + /// The created memory descriptor will describe a memory object that + /// contains n+1 buffers for an n-dimensional tensor. + /// The buffers have the following meaning and assigned numbers (index): + /// - 0: values + /// - 1: indices for dimension 0 + /// - 2: indices for dimension 1 ... + /// - n: indices for dimension n-1 + /// + /// @param adims Tensor dimensions. + /// @param adata_type Data precision/type. + /// @param nnz Number of non-zero entries. + /// @param index_dt Data type of indices. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case a + /// zero memory descriptor will be constructed. This flag is + /// optional and defaults to false. + static desc coo(const dims &adims, data_type adata_type, dim nnz, + data_type index_dt, bool allow_empty = false) { + validate_dims(adims); + dnnl_memory_desc_t md = nullptr; + dnnl_status_t status = dnnl_memory_desc_create_with_coo_encoding( + &md, (int)adims.size(), adims.data(), + convert_to_c(adata_type), nnz, convert_to_c(index_dt)); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a memory descriptor for COO sparse " + "encoding"); + return desc {md}; + } + /// Function for creating a memory descriptor for packed sparse /// encoding. 
/// @@ -2880,6 +2973,31 @@ struct memory : public handle { reset(md); } + /// @fork + /// Copy constructor for memory::desc. + /// Ensures a deep copy (the underlying C structure is copied as well) to + /// preserve the behavior of oneDNN 2.x versions. + /// + /// @param adesc Memory descriptor to copy. + desc(const memory::desc& adesc) { + auto cdesc = adesc.get(); + dnnl_memory_desc_t cloned_md = nullptr; + dnnl_memory_desc_clone(&cloned_md, cdesc); + + reset(cloned_md); + } + + /// @fork + /// Creates a memory descriptor with packed sparse encoding. + desc sparse_desc(const dims &adims, data_type adata_type, + bool allow_empty = false) { + dnnl_memory_desc_t md = nullptr; + dnnl_status_t status = dnnl_memory_desc_create_sparse(&md, dnnl_sparse_encoding_packed, + (int)adims.size(), adims.data(), convert_to_c(adata_type)); + + if (!allow_empty) + error::wrap_c_api(status, + "could not construct a memory descriptor with sparse format"); + return desc(md); + } /// Constructs a memory descriptor for a region inside an area /// described by this memory descriptor. // @@ -3128,9 +3246,9 @@ struct memory : public handle { /// Returns the data type of the memory descriptor. /// /// @returns The data type. - memory::data_type get_data_type() const { - return query_data_type(query::data_type); - } + // memory::data_type get_data_type() const { + // return query_data_type(query::data_type); + // } #endif /// Returns the format kind of the memory descriptor. @@ -3145,6 +3263,30 @@ ... : dnnl::memory::format_kind::undef; } + /// Returns the sparse encoding of the memory descriptor. + /// + /// @returns The sparse encoding. + dnnl_sparse_encoding_t get_sparse_encoding() const { + dnnl_sparse_encoding_t sparse_encoding; + dnnl_status_t status = dnnl_memory_desc_query( + get(), dnnl_query_sparse_encoding, &sparse_encoding); + return status == dnnl_success + ? sparse_encoding + : dnnl_sparse_encoding_undef; + } + + /// Returns the data type of the memory descriptor. + /// + /// @returns The data type. + memory::data_type get_data_type() const { + dnnl_data_type_t data_type; + dnnl_status_t status = dnnl_memory_desc_query( + get(), dnnl_query_data_type, &data_type); + return status == dnnl_success + ? static_cast<memory::data_type>(data_type) + : dnnl::memory::data_type::undef; + } + /// Returns dimensions of the memory descriptor. /// /// Potentially expensive due to the data copy involved. @@ -3322,6 +3464,44 @@ ... reset(result); } #else + /// Constructs a memory object. + /// + /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory + /// object will have the underlying buffer set. In this case, the buffer + /// will be initialized as if #dnnl::memory::set_data_handle() had been + /// called. + /// + /// @sa memory::set_data_handle() + /// + /// @param md Memory descriptor. + /// @param aengine Engine to store the data on. + /// @param handle Handle of the memory buffer to use. + /// - A pointer to the user-allocated buffer. In this case the library + /// doesn't own the buffer. + /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to + /// allocate the buffer for the memory object. In this case the + /// library owns the buffer. + /// - #DNNL_MEMORY_NONE to create dnnl::memory without an underlying + /// buffer. + // memory(const desc &md, const engine &aengine, void *handle) { + // dnnl_memory_t result; + // error::wrap_c_api( + // dnnl_memory_create(&result, md.get(), aengine.get(), handle), + // "could not create a memory object"); + // reset(result); + // } + + /// Constructs a memory object.
+ /// + /// The underlying buffer(s) for the memory will be allocated by the + /// library. + /// + /// @param md Memory descriptor. + /// @param aengine Engine to store the data on. + // memory(const desc &md, const engine &aengine) + // : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {} +#endif + /// Constructs a memory object. /// /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory /// object will have the underlying buffer set. In this case, the buffer /// will be initialized as if #dnnl::memory::set_data_handle() had been /// called. /// /// @sa memory::set_data_handle() /// /// @param md Memory descriptor. /// @param aengine Engine to store the data on. /// @param handle Handle of the memory buffer to use. @@ -3349,15 +3529,11 @@ struct memory : public handle { reset(result); } - /// Constructs a memory object. - /// /// The underlying buffer for the memory will be allocated by the library. - /// /// @param md Memory descriptor. /// @param aengine Engine to store the data on. memory(const desc &md, const engine &aengine) : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {} -#endif /// Returns the associated memory descriptor. desc get_desc() const { @@ -3805,6 +3981,12 @@ struct post_ops : public handle { "could not append a binary post-op"); } + void append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt) { + error::wrap_c_api(dnnl_post_ops_append_dw_conv(get(), + in_h, in_w, ker_h, ker_w, str_h, str_w, in_dt), + "could not append a dw convolution post-op"); + } + /// Returns the parameters of a binary post-op. /// /// @param index Index of the binary post-op. @@ -3889,6 +4071,23 @@ struct post_ops : public handle { error::wrap_c_api(dnnl_post_ops_get_params_prelu(get(), index, &mask), "could not get parameters of a binary post-op"); } + + void append_depthwise(algorithm alg, const std::array<size_t, 3>& offset) { + error::wrap_c_api(dnnl_post_ops_append_depthwise(get(), convert_to_c(alg), offset.size(), offset.data()), + "could not append a depthwise post-op"); + } + + void append_quantization(algorithm alg, const std::array<bool, 6>& per_channel, const std::array<bool, 6>& all_default, + const std::array<size_t, 6>& offset) { + error::wrap_c_api(dnnl_post_ops_append_quantization(get(), convert_to_c(alg), per_channel.size(), per_channel.data(), + all_default.size(), all_default.data(), offset.size(), offset.data()), + "could not append a quantization post-op"); + } + + void append_binarization(algorithm alg, const float* weights_data, const float* output_mask) { + error::wrap_c_api(dnnl_post_ops_append_binarization(get(), convert_to_c(alg), weights_data, output_mask), + "could not append a binarization post-op"); + } }; /// @cond DO_NOT_DOCUMENT_THIS @@ -4067,6 +4266,10 @@ struct primitive_attr : public handle { error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask), "could not set scales primitive attribute"); } + void set_scales_dims(int arg, const memory::dims& dims, memory::data_type data_type = memory::data_type::f32) { + error::wrap_c_api(dnnl_primitive_attr_set_scales_dims(get(), arg, dims.data(), dims.size(), memory::convert_to_c(data_type)), + "could not set scales primitive attribute"); + } /// Sets scaling factors for primitive operations for a given memory /// argument. The scaling factors must be passed at execution time @@ -4112,6 +4315,11 @@ struct primitive_attr : public handle { dnnl_primitive_attr_set_zero_points_mask(get(), arg, mask), "could not set zero points primitive attribute"); } + void set_zero_points_dims(int arg, const memory::dims& dims, memory::data_type dt) { + error::wrap_c_api( + dnnl_primitive_attr_set_zero_points_dims(get(), arg, dims.data(), dims.size(), memory::convert_to_c(dt)), + "could not set zero points primitive attribute"); + } /// Sets zero points for primitive operations for a given memory argument.
/// The zero points must be passed at execution time as an argument with @@ -4139,10 +4347,28 @@ struct primitive_attr : public handle { "could not set zero points primitive attribute"); } + void set_output_compensations(dnnl_dim_t count, int mask) + { + error::wrap_c_api(dnnl_primitive_attr_set_output_compensations(get(), count, mask), + "could not set int output compensations"); + } + + void set_input_zero_points(dnnl_dim_t count, int mask) + { + error::wrap_c_api(dnnl_primitive_attr_set_input_zero_points(get(), count, mask), + "could not set int input zero_points"); + } + + void set_weights_zero_points(dnnl_dim_t count, int mask) + { + error::wrap_c_api(dnnl_primitive_attr_set_weights_zero_points(get(), count, mask), + "could not set int weights zero_points"); + } + /// Returns post-ops previously set via set_post_ops(). /// /// @returns Post-ops. - const post_ops get_post_ops() const { + post_ops get_post_ops() const { const_dnnl_post_ops_t const_c_post_ops; error::wrap_c_api( dnnl_primitive_attr_get_post_ops(get(), &const_c_post_ops), @@ -4161,7 +4387,7 @@ struct primitive_attr : public handle { /// by the respective primitive descriptor constructor. /// /// @param ops Post-ops object to copy post-ops from. - void set_post_ops(const post_ops ops) { + void set_post_ops(const post_ops &ops) { error::wrap_c_api(dnnl_primitive_attr_set_post_ops(get(), ops.get()), "could not set post-ops primitive attribute"); } @@ -4362,6 +4588,16 @@ struct primitive_attr : public handle { for (dnnl_dim_t c = 0; c < count; c++) scales[c] = c_scales[c]; } + + void set_src_dyn_quant_params(uint64_t group_size) { + error::wrap_c_api(dnnl_primitive_attr_set_src_dyn_quant_params(get(), group_size), + "could not set src dynamic quantization parameters primitive attribute"); + } + + void get_src_dyn_quant_params(uint64_t& group_size) const { + error::wrap_c_api(dnnl_primitive_attr_get_src_dyn_quant_params(get(), &group_size), + "could not get src dynamic quantization parameters primitive attribute"); + } }; /// @} dnnl_api_attributes @@ -5002,8 +5238,10 @@ struct reorder : public primitive { dst_engine.get(), attr.get()); if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a reorder " - "primitive"); + "could not create a primitive descriptor for " + "the reorder primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); } @@ -5030,8 +5268,10 @@ struct reorder : public primitive { dst.get_engine().get(), attr.get()); if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a reorder " - "primitive"); + "could not create a primitive descriptor for " + "the reorder primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); } @@ -5155,8 +5395,10 @@ struct concat : public primitive { concat_dimension, c_srcs.data(), attr.get()); if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a concat " - "primitive"); + "could not create a primitive descriptor for " + "the concat primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); reset(status == dnnl_success ? 
result : dnnl_primitive_desc_t()); } @@ -5189,8 +5431,10 @@ struct concat : public primitive { concat_dimension, c_api_srcs.data(), attr.get()); if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a concat " - "primitive"); + "could not create a primitive descriptor for " + "the concat primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); } @@ -5271,8 +5515,10 @@ struct sum : public primitive { scales.data(), c_api_srcs.data(), attr.get()); if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a sum " - "primitive"); + "could not create a primitive descriptor for " + "the sum primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); } @@ -5306,8 +5552,10 @@ struct sum : public primitive { scales.data(), c_api_srcs.data(), attr.get()); if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a sum " - "primitive"); + "could not create a primitive descriptor for " + "the sum primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); } @@ -5647,8 +5895,10 @@ struct convolution_forward : public primitive { &padding_l[0], &padding_r[0], attr.get()); if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a " - "convolution forward propagation primitive"); + "could not create a primitive descriptor for " + "the convolution forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } }; @@ -5839,8 +6089,10 @@ struct convolution_backward_data : public primitive { hint_fwd_pd.get(), attr.get()); if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a " - "convolution backward propagation primitive"); + "could not create a primitive descriptor for " + "the convolution backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } }; @@ -6145,8 +6397,10 @@ struct convolution_backward_weights : public primitive { &padding_r[0], hint_fwd_pd.get(), attr.get()); if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a " - "convolution weights update primitive"); + "could not create a primitive descriptor for " + "the convolution weights update primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } }; @@ -6441,8 +6695,10 @@ struct deconvolution_forward : public primitive { &padding_l[0], &padding_r[0], attr.get()); if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a " - "deconvolution forward propagation primitive"); + "could not create a primitive descriptor for " + "the deconvolution forward propagation primitive. 
Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } }; @@ -6631,8 +6887,10 @@ struct deconvolution_backward_data : public primitive { hint_fwd_pd.get(), attr.get()); if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a " - "deconvolution backward propagation primitive"); + "could not create a primitive descriptor for " + "the deconvolution backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } }; @@ -6930,8 +7188,10 @@ struct deconvolution_backward_weights : public primitive { &padding_r[0], hint_fwd_pd.get(), attr.get()); if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a " - "deconvolution weights update primitive"); + "could not create a primitive descriptor for " + "the deconvolution weights update primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } }; @@ -7009,8 +7269,10 @@ struct lrn_forward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a lrn " - "forward propagation primitive"); + "could not create a primitive descriptor for " + "the lrn forward propagation primitive. Run workload " + "with environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); reset(pd); } @@ -7116,8 +7378,10 @@ struct lrn_backward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a lrn " - "backward propagation primitive"); + "could not create a primitive descriptor for " + "the lrn backward propagation primitive. Run workload " + "with environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); reset(pd); } @@ -7329,8 +7593,10 @@ struct eltwise_forward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for an " - "eltwise forward propagation primitive"); + "could not create a primitive descriptor for " + "the eltwise forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } }; @@ -7506,8 +7772,10 @@ struct eltwise_backward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for an " - "eltwise backward propagation primitive"); + "could not create a primitive descriptor for " + "the eltwise backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } }; @@ -7579,8 +7847,10 @@ struct softmax_forward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a softmax " - "forward propagation primitive"); + "could not create a primitive descriptor for " + "the softmax forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } @@ -7670,8 +7940,10 @@ struct softmax_backward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a softmax " - "backward propagation primitive"); + "could not create a primitive descriptor for " + "the softmax backward propagation primitive. 
Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } @@ -7788,8 +8060,11 @@ struct batch_normalization_forward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a batch " - "normalization forward propagation primitive"); + "could not create a primitive descriptor for " + "the batch normalization forward propagation " + "primitive. Run workload with environment variable " + "ONEDNN_VERBOSE=all to get additional diagnostic " + "information."); reset(pd); } @@ -7916,8 +8191,11 @@ struct batch_normalization_backward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a batch " - "normalization backward propagation primitive"); + "could not create a primitive descriptor for " + "the batch normalization backward propagation " + "primitive. Run workload with environment variable " + "ONEDNN_VERBOSE=all to get additional diagnostic " + "information."); reset(pd); } @@ -8061,8 +8339,11 @@ struct group_normalization_forward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a group " - "normalization forward propagation primitive"); + "could not create a primitive descriptor for " + "the group normalization forward propagation " + "primitive. Run workload with environment variable " + "ONEDNN_VERBOSE=all to get additional diagnostic " + "information."); reset(pd); } @@ -8193,8 +8474,11 @@ struct group_normalization_backward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a group " - "normalization backward propagation primitive"); + "could not create a primitive descriptor for " + "the group normalization backward propagation " + "primitive. Run workload with environment variable " + "ONEDNN_VERBOSE=all to get additional diagnostic " + "information."); reset(pd); } @@ -8499,8 +8783,11 @@ struct layer_normalization_forward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a layer " - "normalization forward propagation primitive"); + "could not create a primitive descriptor for " + "the layer normalization forward propagation " + "primitive. Run workload with environment variable " + "ONEDNN_VERBOSE=all to get additional diagnostic " + "information."); reset(pd); } }; @@ -8768,8 +9055,11 @@ struct layer_normalization_backward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a layer " - "normalization backward propagation primitive"); + "could not create a primitive descriptor for " + "the layer normalization backward propagation " + "primitive. Run workload with environment variable " + "ONEDNN_VERBOSE=all to get additional diagnostic " + "information."); reset(pd); } }; @@ -8908,8 +9198,10 @@ struct inner_product_forward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for an inner " - "product forward propagation primitive"); + "could not create a primitive descriptor for " + "the inner product forward propagation primitive. 
Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } }; @@ -8975,8 +9267,10 @@ struct inner_product_backward_data : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for an inner " - "product backward propagation primitive"); + "could not create a primitive descriptor for " + "the inner product backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } @@ -9136,8 +9430,10 @@ struct inner_product_backward_weights : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for an inner " - "product weights gradient primitive"); + "could not create a primitive descriptor for " + "the inner product weights gradient primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } }; @@ -9438,8 +9734,10 @@ struct rnn_primitive_desc_base : public primitive_desc { weights_iter_desc.get(), bias_desc.get(), dst_layer_desc.get(), dst_iter_desc.get(), convert_to_c(flags), alpha, beta, attr.get()); - msg = "could not create a primitive descriptor for a vanilla " - "RNN forward propagation primitive"; + msg = "could not create a primitive descriptor for " + "the vanilla RNN forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."; break; case algorithm::vanilla_lstm: status = dnnl_lstm_forward_primitive_desc_create(&pd, @@ -9452,8 +9750,10 @@ struct rnn_primitive_desc_base : public primitive_desc { dst_layer_desc.get(), dst_iter_desc.get(), optional_arg(dst_iter_c_desc), convert_to_c(flags), attr.get()); - msg = "could not create a primitive descriptor for an LSTM " - "forward propagation primitive"; + msg = "could not create a primitive descriptor for " + "the LSTM forward propagation primitive. Run workload " + "with environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."; break; case algorithm::vanilla_gru: status = dnnl_gru_forward_primitive_desc_create(&pd, @@ -9463,8 +9763,10 @@ struct rnn_primitive_desc_base : public primitive_desc { weights_iter_desc.get(), bias_desc.get(), dst_layer_desc.get(), dst_iter_desc.get(), convert_to_c(flags), attr.get()); - msg = "could not create a primitive descriptor for a GRU " - "forward propagation primitive"; + msg = "could not create a primitive descriptor for " + "the GRU forward propagation primitive. Run workload " + "with environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."; break; case algorithm::lbr_gru: status = dnnl_lbr_gru_forward_primitive_desc_create(&pd, @@ -9474,8 +9776,10 @@ struct rnn_primitive_desc_base : public primitive_desc { weights_iter_desc.get(), bias_desc.get(), dst_layer_desc.get(), dst_iter_desc.get(), convert_to_c(flags), attr.get()); - msg = "could not create a primitive descriptor for an LBR GRU " - "forward propagation primitive"; + msg = "could not create a primitive descriptor for " + "the LBR GRU forward propagation primitive. 
Run workload " + "with environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."; break; case algorithm::vanilla_augru: status = dnnl_augru_forward_primitive_desc_create(&pd, @@ -9485,8 +9789,10 @@ struct rnn_primitive_desc_base : public primitive_desc { weights_layer_desc.get(), weights_iter_desc.get(), bias_desc.get(), dst_layer_desc.get(), dst_iter_desc.get(), convert_to_c(flags), attr.get()); - msg = "could not create a primitive descriptor for an AUGRU " - "forward propagation primitive"; + msg = "could not create a primitive descriptor for " + "the AUGRU forward propagation primitive. Run workload " + "with environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."; break; case algorithm::lbr_augru: status = dnnl_lbr_augru_forward_primitive_desc_create(&pd, @@ -9496,8 +9802,10 @@ struct rnn_primitive_desc_base : public primitive_desc { weights_layer_desc.get(), weights_iter_desc.get(), bias_desc.get(), dst_layer_desc.get(), dst_iter_desc.get(), convert_to_c(flags), attr.get()); - msg = "could not create a primitive descriptor for an LBR " - "AUGRU forward propagation primitive"; + msg = "could not create a primitive descriptor for " + "the LBR AUGRU forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."; break; default: status = dnnl_unimplemented; } @@ -9555,8 +9863,10 @@ struct rnn_primitive_desc_base : public primitive_desc { diff_dst_layer_desc.get(), diff_dst_iter_desc.get(), convert_to_c(flags), alpha, beta, hint_fwd_pd.get(), attr.get()); - msg = "could not create a primitive descriptor for a vanilla " - "RNN backward propagation primitive"; + msg = "could not create a primitive descriptor for " + "the vanilla RNN backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."; break; case algorithm::vanilla_lstm: status = dnnl_lstm_backward_primitive_desc_create(&pd, @@ -9578,8 +9888,10 @@ struct rnn_primitive_desc_base : public primitive_desc { diff_dst_iter_desc.get(), optional_arg(diff_dst_iter_c_desc), convert_to_c(flags), hint_fwd_pd.get(), attr.get()); - msg = "could not create a primitive descriptor for an LSTM " - "backward propagation primitive"; + msg = "could not create a primitive descriptor for " + "the LSTM backward propagation primitive. Run workload " + "with environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."; break; case algorithm::vanilla_gru: status = dnnl_gru_backward_primitive_desc_create(&pd, @@ -9593,8 +9905,10 @@ struct rnn_primitive_desc_base : public primitive_desc { diff_weights_iter_desc.get(), diff_bias_desc.get(), diff_dst_layer_desc.get(), diff_dst_iter_desc.get(), convert_to_c(flags), hint_fwd_pd.get(), attr.get()); - msg = "could not create a primitive descriptor for a GRU " - "backward propagation primitive"; + msg = "could not create a primitive descriptor for " + "the GRU backward propagation primitive. 
Run workload " + "with environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."; break; case algorithm::lbr_gru: status = dnnl_lbr_gru_backward_primitive_desc_create(&pd, @@ -9608,8 +9922,10 @@ struct rnn_primitive_desc_base : public primitive_desc { diff_weights_iter_desc.get(), diff_bias_desc.get(), diff_dst_layer_desc.get(), diff_dst_iter_desc.get(), convert_to_c(flags), hint_fwd_pd.get(), attr.get()); - msg = "could not create a primitive descriptor for an LBR GRU " - "backward propagation primitive"; + msg = "could not create a primitive descriptor for " + "the LBR GRU backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."; break; case algorithm::vanilla_augru: status = dnnl_augru_backward_primitive_desc_create(&pd, @@ -9625,8 +9941,10 @@ struct rnn_primitive_desc_base : public primitive_desc { diff_weights_iter_desc.get(), diff_bias_desc.get(), diff_dst_layer_desc.get(), diff_dst_iter_desc.get(), convert_to_c(flags), hint_fwd_pd.get(), attr.get()); - msg = "could not create a primitive descriptor for an AUGRU " - "backward propagation primitive"; + msg = "could not create a primitive descriptor for " + "the AUGRU backward propagation primitive. Run workload " + "with environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."; break; case algorithm::lbr_augru: status = dnnl_lbr_augru_backward_primitive_desc_create(&pd, @@ -9642,8 +9960,10 @@ struct rnn_primitive_desc_base : public primitive_desc { diff_weights_iter_desc.get(), diff_bias_desc.get(), diff_dst_layer_desc.get(), diff_dst_iter_desc.get(), convert_to_c(flags), hint_fwd_pd.get(), attr.get()); - msg = "could not create a primitive descriptor for an LBR " - "AUGRU backward propagation primitive"; + msg = "could not create a primitive descriptor for " + "the LBR AUGRU backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."; break; default: status = dnnl_unimplemented; } @@ -12381,8 +12701,10 @@ struct shuffle_forward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a shuffle " - "forward propagation primitive"); + "could not create a primitive descriptor for " + "the shuffle forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } @@ -12468,8 +12790,10 @@ struct shuffle_backward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a shuffle " - "backward propagation primitive"); + "could not create a primitive descriptor for " + "the shuffle backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } @@ -12560,8 +12884,46 @@ struct binary : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a binary " - "operation primitive"); + "could not create a primitive descriptor for " + "the binary operation primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(pd); + } + + /// Constructs a primitive descriptor for an elementwise binary operator + /// primitive with support of ternary operators. + /// + /// @param aengine Engine to use. 
+ /// @param aalgorithm Elementwise binary algorithm. + /// @param src0 Memory descriptor for source tensor #0. + /// @param src1 Memory descriptor for source tensor #1. + /// @param src2 Memory descriptor for source tensor #2 for ternary + /// operations. Might be empty. + /// @param dst Memory descriptor for destination tensor. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src0, const memory::desc &src1, + const memory::desc &src2, const memory::desc &dst, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_binary_primitive_desc_create_v2(&pd, + aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(), + src1.get(), src2.get(), dst.get(), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the binary v2 operation primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); reset(pd); } @@ -12581,6 +12943,9 @@ struct binary : public primitive { /// Returns the memory descriptor for source #1. memory::desc src1_desc() const { return base::src_desc(1); } + /// Returns the memory descriptor for source #2. + memory::desc src2_desc() const { return base::src_desc(2); } + /// @copydoc dnnl::primitive_desc_base::dst_desc()const memory::desc dst_desc() const { return base::dst_desc(0); } @@ -12700,8 +13065,10 @@ struct matmul : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a matmul " - "primitive"); + "could not create a primitive descriptor for " + "the matmul primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); reset(pd); } }; @@ -12863,8 +13230,10 @@ struct resampling_forward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a " - "resampling forward propagation primitive"); + "could not create a primitive descriptor for " + "the resampling forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } }; @@ -12987,8 +13356,10 @@ struct resampling_backward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a " - "resampling backward propagation primitive"); + "could not create a primitive descriptor for " + "the resampling backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } }; @@ -13324,8 +13695,10 @@ struct prelu_forward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a prelu " - "forward propagation primitive"); + "could not create a primitive descriptor for " + "the prelu forward propagation primitive. 
Run workload " + "with environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); reset(pd); } @@ -13409,8 +13782,10 @@ struct prelu_backward : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a prelu " - "backward propagation primitive"); + "could not create a primitive descriptor for " + "the prelu backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); reset(pd); } @@ -13509,8 +13884,10 @@ struct reduction : public primitive { if (!allow_empty) error::wrap_c_api(status, - "could not create a primitive descriptor for a " - "reduction primitive descriptor"); + "could not create a primitive descriptor for " + "the reduction primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); reset(pd); } @@ -13864,4 +14241,5 @@ namespace dnnl = ::dnnl; /// @} dnnl_api +// NOLINTEND(readability-identifier-naming) #endif /* ONEAPI_DNNL_DNNL_HPP */ diff --git a/include/oneapi/dnnl/dnnl_common.hpp b/include/oneapi/dnnl/dnnl_common.hpp index 562f2d4aaa3..1112f863c32 100644 --- a/include/oneapi/dnnl/dnnl_common.hpp +++ b/include/oneapi/dnnl/dnnl_common.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022-2024 Intel Corporation +* Copyright 2022-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #ifndef ONEAPI_DNNL_DNNL_COMMON_HPP #define ONEAPI_DNNL_DNNL_COMMON_HPP +// NOLINTBEGIN(readability-identifier-naming) /// @cond DO_NOT_DOCUMENT_THIS #include @@ -127,7 +128,7 @@ template <typename T, typename traits = handle_traits<T>> struct handle { private: static dnnl_status_t dummy_destructor(T) { return dnnl_success; } - std::shared_ptr<typename std::remove_pointer<T>::type> data_ {0}; + std::shared_ptr<typename std::remove_pointer<T>::type> data_ {nullptr}; protected: bool operator==(const T other) const { return other == data_.get(); } @@ -370,6 +371,7 @@ struct stream : public handle<dnnl_stream_t> { } }; +//NOLINTBEGIN(bugprone-macro-parentheses) #define DNNL_DEFINE_BITMASK_OPS(enum_name) \ inline enum_name operator|(enum_name lhs, enum_name rhs) { \ return static_cast<enum_name>( \ @@ -407,6 +409,7 @@ struct stream : public handle<dnnl_stream_t> { inline enum_name operator~(enum_name rhs) { \ return static_cast<enum_name>(~static_cast<uint64_t>(rhs)); \ } +//NOLINTEND(bugprone-macro-parentheses) DNNL_DEFINE_BITMASK_OPS(stream::flags) @@ -476,4 +479,5 @@ inline dnnl_accumulation_mode_t convert_to_c(accumulation_mode mode) { /// @} dnnl_api -#endif +// NOLINTEND(readability-identifier-naming) +#endif /* ONEAPI_DNNL_DNNL_COMMON_HPP */ diff --git a/include/oneapi/dnnl/dnnl_common_types.h b/include/oneapi/dnnl/dnnl_common_types.h index 5b6348ebae7..56ac2a8ecf3 100644 --- a/include/oneapi/dnnl/dnnl_common_types.h +++ b/include/oneapi/dnnl/dnnl_common_types.h @@ -104,6 +104,14 @@ typedef enum { dnnl_u4 = 12, /// [MX-compliant 8-bit compliant scale data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 8-bit exponent. dnnl_e8m0 = 13, + /// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa. + dnnl_f4_e2m1 = 14, + /// 4-bit float data type with 3-bit exponent and 0 bit mantissa. + dnnl_f4_e3m0 = 15, + /// 4-bit normalized float. + dnnl_nf4 = 16, + /// 1-bit integer.
+ dnnl_bin = 17, /// Parameter to allow internal only data_types without undefined behavior. /// This parameter is chosen to be valid for so long as sizeof(int) >= 2. diff --git a/include/oneapi/dnnl/dnnl_config.h.in b/include/oneapi/dnnl/dnnl_config.h.in index f2ba61b6511..af74c13a072 100644 --- a/include/oneapi/dnnl/dnnl_config.h.in +++ b/include/oneapi/dnnl/dnnl_config.h.in @@ -70,6 +70,9 @@ /// TBB runtime (CPU only) #define DNNL_RUNTIME_TBB 4u +/// TBB runtime with auto partitioning (CPU only) +#define DNNL_RUNTIME_TBB_AUTO 5u + /// Threadpool runtime (CPU only) #define DNNL_RUNTIME_THREADPOOL 8u @@ -222,6 +225,7 @@ #cmakedefine01 BUILD_XEHPG #cmakedefine01 BUILD_XEHPC #cmakedefine01 BUILD_XE2 +#cmakedefine01 BUILD_XE3 // GeMM kernels ISA controls #cmakedefine01 BUILD_GEMM_KERNELS_ALL #cmakedefine01 BUILD_GEMM_KERNELS_NONE diff --git a/include/oneapi/dnnl/dnnl_debug.h b/include/oneapi/dnnl/dnnl_debug.h index 9efa63dd61e..14b7fb596e4 100644 --- a/include/oneapi/dnnl/dnnl_debug.h +++ b/include/oneapi/dnnl/dnnl_debug.h @@ -44,6 +44,7 @@ const char DNNL_API *dnnl_fmt_tag2str(dnnl_format_tag_t v); const char DNNL_API *dnnl_prop_kind2str(dnnl_prop_kind_t v); const char DNNL_API *dnnl_prim_kind2str(dnnl_primitive_kind_t v); const char DNNL_API *dnnl_alg_kind2str(dnnl_alg_kind_t v); +const char DNNL_API *dnnl_sparse_encoding2str(dnnl_sparse_encoding_t v); const char DNNL_API *dnnl_rnn_flags2str(dnnl_rnn_flags_t v); const char DNNL_API *dnnl_rnn_direction2str(dnnl_rnn_direction_t v); const char DNNL_API *dnnl_scratchpad_mode2str(dnnl_scratchpad_mode_t v); diff --git a/include/oneapi/dnnl/dnnl_graph.h b/include/oneapi/dnnl/dnnl_graph.h index a0d465982ca..77f7b46b48f 100644 --- a/include/oneapi/dnnl/dnnl_graph.h +++ b/include/oneapi/dnnl/dnnl_graph.h @@ -590,6 +590,28 @@ dnnl_status_t DNNL_API dnnl_graph_graph_create_with_fpmath_mode( /// otherwise. dnnl_status_t DNNL_API dnnl_graph_graph_destroy(dnnl_graph_graph_t graph); +/// Set the floating point math mode for a graph. +/// +/// @param graph The target graph. +/// @param mode The floating-point math mode. +/// @param apply_to_int The flag that controls whether to use floating-point +/// arithmetic for integral operations. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_set_fpmath_mode( + dnnl_graph_graph_t graph, dnnl_fpmath_mode_t mode, int apply_to_int); + +/// Get the floating point math mode for a graph. +/// +/// @param graph The target graph. +/// @param mode The floating-point math mode. +/// @param apply_to_int The flag that controls whether to use floating-point +/// arithmetic for integral operations. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_get_fpmath_mode( + dnnl_graph_graph_t graph, dnnl_fpmath_mode_t *mode, int *apply_to_int); + /// Adds an operation into a graph. The API will return failure if the operator /// has already been added to the graph or the operation cannot pass the schema /// check in the library (eg. 
input and output numbers and data types, the diff --git a/include/oneapi/dnnl/dnnl_graph.hpp b/include/oneapi/dnnl/dnnl_graph.hpp index 1d178e07973..288105aa08b 100644 --- a/include/oneapi/dnnl/dnnl_graph.hpp +++ b/include/oneapi/dnnl/dnnl_graph.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #ifndef ONEAPI_DNNL_DNNL_GRAPH_HPP #define ONEAPI_DNNL_DNNL_GRAPH_HPP +// NOLINTBEGIN(readability-identifier-naming) #include "oneapi/dnnl/dnnl_common.hpp" #include "oneapi/dnnl/dnnl_graph.h" @@ -270,6 +271,10 @@ class logical_tensor { /// floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) /// with a 4-bit exponent and a 3-bit mantissa. f8_e4m3 = dnnl_f8_e4m3, + /// 4-bit signed integer. + s4 = dnnl_s4, + /// 4-bit unsigned integer. + u4 = dnnl_u4, }; /// Layout type @@ -360,7 +365,7 @@ class logical_tensor { layout_type ltype, property_type ptype = property_type::undef) { dnnl_graph_logical_tensor_t val; // if dimension size equals to 0, it's a scalar - if (adims.size() == 0) + if (adims.empty()) error::wrap_c_api(dnnl_graph_logical_tensor_init(&val, tid, convert_to_c(dtype), 0, convert_to_c(ltype), convert_to_c(ptype)), @@ -415,7 +420,7 @@ class logical_tensor { property_type ptype = property_type::undef) { dnnl_graph_logical_tensor_t val; - if (adims.size() == 0) { + if (adims.empty()) { error::wrap_c_api(dnnl_graph_logical_tensor_init(&val, tid, convert_to_c(dtype), 0, convert_to_c(layout_type::opaque), @@ -832,6 +837,8 @@ class op : public op_handle { TanhBackward = dnnl_graph_op_tanh_backward, TypeCast = dnnl_graph_op_type_cast, Wildcard = dnnl_graph_op_wildcard, + GenIndex = dnnl_graph_op_gen_index, + GreaterEqual = dnnl_graph_op_greater_equal, // Sentinel LastSymbol = dnnl_graph_op_last_symbol, }; @@ -908,6 +915,12 @@ class op : public op_handle { weights_shape = dnnl_graph_op_attr_weights_shape, /// Specifies a zps attribute to an op. zps = dnnl_graph_op_attr_zps, + /// Specifies the group shape of an op. The size of the vector should + /// match that of the input. For the dimensions where the grouped + /// quantization occurs, the values should correspond to the group + /// size, which indicates the number of elements that will share the + /// same scaling factor. + group_shape = dnnl_graph_op_attr_group_shape, // bool attributes. The value of these attributes can be any single bool // value. @@ -1373,6 +1386,10 @@ class graph : public graph_handle { /// mode. All partitions returned from the graph will inherit the engine /// kind and floating-point math mode. /// + /// Setting the floating-point math mode enables automatic down-conversion + /// of inputs for the given graph, promoting speedup by using + /// lower-precision data types when available. + /// /// @param engine_kind Engine kind. /// @param mode Floating-point math mode. graph(engine::kind engine_kind, fpmath_mode mode) { @@ -1384,6 +1401,37 @@ class graph : public graph_handle { reset(g); } + /// Set the floating point math mode for a graph. Users can enforce the + /// graph to comply with the mode by specifying a boolean flag with the + /// setter function. + /// + /// @param mode The floating-point math mode. 
+    /// @param apply_to_int The flag that controls whether to use
+    /// floating-point arithmetic for integral operations.
+    void set_fpmath_mode(fpmath_mode mode, bool apply_to_int = false) {
+        error::wrap_c_api(dnnl_graph_graph_set_fpmath_mode(
+                                  get(), convert_to_c(mode), apply_to_int),
+                "could not set fpmath mode graph attribute");
+    }
+
+    /// Get the floating point math mode and the boolean flag that specifies
+    /// whether the graph will be enforced to comply with the mode.
+    ///
+    /// @param mode The floating-point math mode.
+    /// @param apply_to_int The flag that controls whether to use
+    /// floating-point arithmetic for integral operations.
+    void get_fpmath_mode(fpmath_mode &mode, bool &apply_to_int) const {
+        dnnl_fpmath_mode_t c_mode;
+        int c_apply_to_int;
+
+        error::wrap_c_api(dnnl_graph_graph_get_fpmath_mode(
+                                  get(), &c_mode, &c_apply_to_int),
+                "could not get fpmath mode graph attribute");
+
+        mode = fpmath_mode(c_mode);
+        apply_to_int = static_cast<bool>(c_apply_to_int);
+    }
+
     /// Adds an op into the graph to construct a computational DAG. The API will
     /// return failure if the operator has already been added to the graph or
     /// the operation cannot pass the schema check in the library (eg. input and
@@ -1584,4 +1632,5 @@ namespace dnnl = ::dnnl;
 
 /// @} dnnl_api
 
-#endif
+// NOLINTEND(readability-identifier-naming)
+#endif /* ONEAPI_DNNL_DNNL_GRAPH_HPP */
diff --git a/include/oneapi/dnnl/dnnl_graph_ocl.hpp b/include/oneapi/dnnl/dnnl_graph_ocl.hpp
index 636dc0d1c47..18ff36bd686 100644
--- a/include/oneapi/dnnl/dnnl_graph_ocl.hpp
+++ b/include/oneapi/dnnl/dnnl_graph_ocl.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2024 Intel Corporation
+* Copyright 2024-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
diff --git a/include/oneapi/dnnl/dnnl_graph_sycl.hpp b/include/oneapi/dnnl/dnnl_graph_sycl.hpp
index 8f694f4b36b..2507842cb38 100644
--- a/include/oneapi/dnnl/dnnl_graph_sycl.hpp
+++ b/include/oneapi/dnnl/dnnl_graph_sycl.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -25,8 +25,6 @@
 
 #if __has_include(<sycl/sycl.hpp>)
 #include <sycl/sycl.hpp>
-#elif __has_include(<CL/sycl.hpp>)
-#include <CL/sycl.hpp>
 #else
 #error "Unsupported compiler"
 #endif
diff --git a/include/oneapi/dnnl/dnnl_graph_types.h b/include/oneapi/dnnl/dnnl_graph_types.h
index 4ec65da25cd..4aeb4d6bd87 100644
--- a/include/oneapi/dnnl/dnnl_graph_types.h
+++ b/include/oneapi/dnnl/dnnl_graph_types.h
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright 2020-2024 Intel Corporation
+ * Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -256,6 +256,8 @@ typedef enum {
     dnnl_graph_op_select,
     dnnl_graph_op_pow,
     dnnl_graph_op_group_norm,
+    dnnl_graph_op_gen_index,
+    dnnl_graph_op_greater_equal,
     dnnl_graph_op_last_symbol,
 } dnnl_graph_op_kind_t;
 
@@ -327,6 +329,8 @@ typedef enum {
     dnnl_graph_op_attr_weights_shape,
     /// Specifies a zps attribute to an op.
     dnnl_graph_op_attr_zps,
+    /// Specifies a group shape attribute to an op.
+    dnnl_graph_op_attr_group_shape,
 
     // bool attributes. The value of these attributes can be any single bool
     // value.
diff --git a/include/oneapi/dnnl/dnnl_ocl.h b/include/oneapi/dnnl/dnnl_ocl.h
index 6300bb7459f..70d0c5460a0 100644
--- a/include/oneapi/dnnl/dnnl_ocl.h
+++ b/include/oneapi/dnnl/dnnl_ocl.h
@@ -75,6 +75,35 @@ dnnl_status_t DNNL_API dnnl_ocl_interop_memory_create(dnnl_memory_t *memory,
         const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
         dnnl_ocl_interop_memory_kind_t memory_kind, void *handle);
 
+#ifdef DNNL_EXPERIMENTAL_SPARSE
+/// Creates a memory object with multiple handles.
+///
+/// @param memory Output memory object.
+/// @param memory_desc Memory descriptor.
+/// @param engine Engine to use.
+/// @param memory_kind Memory allocation kind to specify the type of handles.
+/// @param nhandles Number of handles.
+/// @param handles Handles of the memory buffers to use as underlying storages.
+///     For each element of the @p handles array the following applies:
+///     - A USM pointer to the user-allocated buffer. In this case the library
+///       doesn't own the buffer. Requires @p memory_kind to be equal to
+///       dnnl_ocl_interop_usm.
+///     - An OpenCL buffer. In this case the library doesn't own the buffer.
+///       Requires @p memory_kind to be equal to dnnl_ocl_interop_buffer.
+///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
+///       allocate the buffer that corresponds to the memory allocation kind
+///       @p memory_kind for the memory object. In this case the library
+///       owns the buffer.
+///     - The DNNL_MEMORY_NONE specific value. Instructs the library to
+///       create memory object without an underlying buffer.
+/// @returns #dnnl_success on success and a status describing the error
+///     otherwise.
+dnnl_status_t DNNL_API dnnl_ocl_interop_memory_create_v2(dnnl_memory_t *memory,
+        const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
+        dnnl_ocl_interop_memory_kind_t memory_kind, int nhandles,
+        void **handles);
+#endif
+
 /// Returns the memory allocation kind associated with a memory object.
 ///
 /// @param memory Memory to query.
diff --git a/include/oneapi/dnnl/dnnl_ocl.hpp b/include/oneapi/dnnl/dnnl_ocl.hpp
index c2466bc8276..de3b4150b8a 100644
--- a/include/oneapi/dnnl/dnnl_ocl.hpp
+++ b/include/oneapi/dnnl/dnnl_ocl.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2022 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -236,6 +236,112 @@ inline memory_kind get_memory_kind(const memory &amemory) {
     return static_cast<memory_kind>(ckind);
 }
 
+#ifdef DNNL_EXPERIMENTAL_SPARSE
+/// Creates a memory object with multiple handles.
+///
+/// @param memory_desc Memory descriptor.
+/// @param aengine Engine to use.
+/// @param kind Memory allocation kind to specify the type of handles.
+/// @param handles Handles of the memory buffers to use as underlying storages.
+///     For each element of the @p handles array the following applies:
+///     - A USM pointer to the user-allocated buffer. In this case the library
+///       doesn't own the buffer. Requires @p memory_kind to be equal to
+///       dnnl_ocl_interop_usm.
+///     - An OpenCL buffer. In this case the library doesn't own the buffer.
+///       Requires @p memory_kind to be equal to dnnl_ocl_interop_buffer.
+///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
+///       allocate the buffer that corresponds to the memory allocation kind
+///       @p memory_kind for the memory object. In this case the library
+///       owns the buffer.
+///     - The DNNL_MEMORY_NONE specific value. Instructs the library to
+///       create memory object without an underlying buffer.
+///
+/// If the @p handles vector is not provided the library will allocate all
+/// buffers as if all handles have the special value DNNL_MEMORY_ALLOCATE.
+///
+/// @returns #dnnl_success on success and a status describing the error
+///     otherwise.
+inline memory make_memory(const memory::desc &memory_desc,
+        const engine &aengine, memory_kind kind,
+        std::vector<void *> handles = {}) {
+    if (handles.empty()) {
+        const int nhandles = memory_desc.get_num_handles();
+        handles.resize(nhandles, DNNL_MEMORY_ALLOCATE);
+    }
+
+    dnnl_memory_t c_memory;
+    error::wrap_c_api(
+            dnnl_ocl_interop_memory_create_v2(&c_memory, memory_desc.get(),
+                    aengine.get(), convert_to_c(kind), (int)handles.size(),
+                    handles.data()),
+            "could not create a memory");
+    return memory(c_memory);
+}
+
+/// Constructs a memory object with multiple OpenCL buffers.
+///
+/// @param memory_desc Memory descriptor.
+/// @param aengine Engine to use.
+/// @param mem_objects A vector of OpenCL buffers to use.
+///
+/// @returns Created memory object.
+inline memory make_memory(const memory::desc &memory_desc,
+        const engine &aengine, std::vector<cl_mem> mem_objects) {
+    const int nhandles = memory_desc.get_num_handles();
+    std::vector<void *> handles(nhandles, DNNL_MEMORY_NONE);
+    memory amemory(memory_desc, aengine, handles);
+    for (int i = 0; i < nhandles; i++)
+        amemory.set_data_handle(mem_objects[i], i);
+    return amemory;
+}
+
+/// Creates a memory object.
+///
+/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the
+/// constructed memory object will have the underlying buffer set. In this
+/// case, the buffer will be initialized as if:
+/// - dnnl::memory::set_data_handle() had been called, if @p memory_kind is
+///   equal to dnnl::ocl_interop::memory_kind::usm, or
+/// - dnnl::ocl_interop::set_mem_object() has been called, if @p memory_kind is
+///   equal to dnnl::ocl_interop::memory_kind::buffer.
+///
+/// @param memory_desc Memory descriptor.
+/// @param aengine Engine to use.
+/// @param kind Memory allocation kind to specify the type of handle.
+/// @param handle Handle of the memory buffer to use as an underlying storage.
+///     - A USM pointer to the user-allocated buffer. In this case the library
+///       doesn't own the buffer. Requires @p memory_kind to be equal to
+///       dnnl::ocl_interop::memory_kind::usm.
+///     - An OpenCL buffer. In this case the library doesn't own the buffer.
+///       Requires @p memory_kind to be equal to
+///       dnnl::ocl_interop::memory_kind::buffer.
+///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
+///       allocate the buffer that corresponds to the memory allocation kind
+///       @p memory_kind for the memory object. In this case the library
+///       owns the buffer.
+///     - The DNNL_MEMORY_NONE specific value. Instructs the library to
+///       create memory object without an underlying buffer.
+///
+/// @returns Created memory object.
+inline memory make_memory(const memory::desc &memory_desc,
+        const engine &aengine, memory_kind kind, void *handle) {
+    return make_memory(
+            memory_desc, aengine, kind, std::vector<void *> {handle});
+}
+
+/// Constructs a memory object from an OpenCL buffer.
+///
+/// @param memory_desc Memory descriptor.
+/// @param aengine Engine to use.
+/// @param mem_object An OpenCL buffer to use.
+///
+/// @returns Created memory object.
+inline memory make_memory(const memory::desc &memory_desc,
+        const engine &aengine, cl_mem mem_object) {
+    return make_memory(memory_desc, aengine, std::vector<cl_mem> {mem_object});
+}
+#else
+
 /// Creates a memory object.
 ///
 /// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the
@@ -288,6 +394,7 @@ inline memory make_memory(const memory::desc &memory_desc,
     set_mem_object(amemory, mem_object);
     return amemory;
 }
+#endif
 
 /// Executes computations specified by the primitive in a specified stream and
 /// returns a SYCL event.
diff --git a/include/oneapi/dnnl/dnnl_sycl.h b/include/oneapi/dnnl/dnnl_sycl.h
index ed61d92435b..a4abe851836 100644
--- a/include/oneapi/dnnl/dnnl_sycl.h
+++ b/include/oneapi/dnnl/dnnl_sycl.h
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2022 Intel Corporation
+* Copyright 2020-2024 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -95,6 +95,36 @@ dnnl_status_t DNNL_API dnnl_sycl_interop_memory_create(dnnl_memory_t *memory,
         const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
         dnnl_sycl_interop_memory_kind_t memory_kind, void *handle);
 
+#ifdef DNNL_EXPERIMENTAL_SPARSE
+/// Creates a memory object with multiple handles.
+///
+/// @param memory Output memory object.
+/// @param memory_desc Memory descriptor.
+/// @param engine Engine to use.
+/// @param memory_kind Memory allocation kind to specify the type of handles.
+/// @param nhandles Number of handles.
+/// @param handles Handles of the memory buffers to use as underlying storages.
+///     For each element of the @p handles array the following applies:
+///     - A USM pointer to the user-allocated buffer. In this case the library
+///       doesn't own the buffer. Requires @p memory_kind to be equal to
+///       dnnl_sycl_interop_usm.
+///     - A pointer to a SYCL buffer. In this case the library doesn't own the
+///       buffer. Requires @p memory_kind to be equal to
+///       dnnl_sycl_interop_buffer.
+///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
+///       allocate the buffer that corresponds to the memory allocation kind
+///       @p memory_kind for the memory object. In this case the library
+///       owns the buffer.
+///     - The DNNL_MEMORY_NONE specific value. Instructs the library to
+///       create memory object without an underlying buffer.
+/// @returns #dnnl_success on success and a status describing the error
+///     otherwise.
+dnnl_status_t DNNL_API dnnl_sycl_interop_memory_create_v2(dnnl_memory_t *memory,
+        const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
+        dnnl_sycl_interop_memory_kind_t memory_kind, int nhandles,
+        void **handles);
+#endif
+
 /// Returns the memory allocation kind associated with a memory object.
 ///
 /// @param memory Memory to query.
diff --git a/include/oneapi/dnnl/dnnl_sycl.hpp b/include/oneapi/dnnl/dnnl_sycl.hpp
index b9ddc876ed8..1f7d8f559c1 100644
--- a/include/oneapi/dnnl/dnnl_sycl.hpp
+++ b/include/oneapi/dnnl/dnnl_sycl.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2022 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
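[Editor's note: the multi-handle interop entry points above can be exercised as in the following sketch. It is illustrative only and assumes a build with DNNL_EXPERIMENTAL_SPARSE, an OpenCL-capable engine, and the `memory::desc::csr()` helper from the experimental sparse API — that helper comes from the existing experimental API, not from this patch.]

```
// Hedged sketch, not part of the patch: create a CSR memory where every
// underlying buffer (values, indices, pointers) is allocated by the library.
#include "oneapi/dnnl/dnnl_ocl.hpp"

using namespace dnnl;

memory make_csr_memory(const engine &gpu_engine) {
    const memory::dim M = 64, N = 64, nnz = 128;
    // Assumed experimental sparse API: a CSR descriptor with f32 values and
    // s32 index/pointer buffers. A CSR tensor is backed by three buffers.
    auto md = memory::desc::csr({M, N}, memory::data_type::f32, nnz,
            memory::data_type::s32, memory::data_type::s32);
    // Passing no handles behaves as if every handle were
    // DNNL_MEMORY_ALLOCATE, per the documentation above.
    return ocl_interop::make_memory(
            md, gpu_engine, ocl_interop::memory_kind::usm);
}
```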
@@ -28,8 +28,6 @@
 
 #if __has_include(<sycl/sycl.hpp>)
 #include <sycl/sycl.hpp>
-#elif __has_include(<CL/sycl.hpp>)
-#include <CL/sycl.hpp>
 #else
 #error "Unsupported compiler"
 #endif
@@ -208,6 +206,83 @@ inline memory_kind get_memory_kind(const memory &amemory) {
     return static_cast<memory_kind>(ckind);
 }
 
+#ifdef DNNL_EXPERIMENTAL_SPARSE
+/// Creates a memory object with multiple handles.
+///
+/// @param memory_desc Memory descriptor.
+/// @param aengine Engine to use.
+/// @param kind Memory allocation kind to specify the type of handles.
+/// @param handles Handles of the memory buffers to use as underlying storages.
+///     For each element of the @p handles array the following applies:
+///     - A USM pointer to the user-allocated buffer. In this case the library
+///       doesn't own the buffer. Requires @p memory_kind to be equal to
+///       dnnl::sycl_interop::memory_kind::usm.
+///     - A pointer to a SYCL buffer. In this case the library doesn't own the
+///       buffer. Requires @p memory_kind to be equal to
+///       dnnl::sycl_interop::memory_kind::buffer.
+///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
+///       allocate the buffer that corresponds to the memory allocation kind
+///       @p memory_kind for the memory object. In this case the library
+///       owns the buffer.
+///     - The DNNL_MEMORY_NONE specific value. Instructs the library to
+///       create memory object without an underlying buffer.
+///
+/// If the @p handles vector is not provided the library will allocate all
+/// buffers as if all handles have the special value DNNL_MEMORY_ALLOCATE.
+///
+/// @returns Created memory object.
+inline memory make_memory(const memory::desc &memory_desc,
+        const engine &aengine, memory_kind kind,
+        std::vector<void *> handles = {}) {
+    if (handles.empty()) {
+        const int nhandles = memory_desc.get_num_handles();
+        handles.resize(nhandles, DNNL_MEMORY_ALLOCATE);
+    }
+
+    dnnl_memory_t c_memory;
+    error::wrap_c_api(
+            dnnl_sycl_interop_memory_create_v2(&c_memory, memory_desc.get(),
+                    aengine.get(), convert_to_c(kind), (int)handles.size(),
+                    handles.data()),
+            "could not create a memory");
+    return memory(c_memory);
+}
+
+/// Creates a memory object.
+///
+/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the
+/// constructed memory object will have the underlying buffer set. In this
+/// case, the buffer will be initialized as if:
+/// - dnnl::memory::set_data_handle() had been called, if @p memory_kind is
+///   equal to dnnl::sycl_interop::memory_kind::usm, or
+/// - dnnl::sycl_interop::set_buffer() has been called, if @p memory_kind is
+///   equal to dnnl::sycl_interop::memory_kind::buffer.
+///
+/// @param memory_desc Memory descriptor.
+/// @param aengine Engine to use.
+/// @param kind Memory allocation kind to specify the type of handle.
+/// @param handle Handle of the memory buffer to use as an underlying storage.
+///     - A USM pointer to the user-allocated buffer. In this case the library
+///       doesn't own the buffer. Requires @p memory_kind to be equal to
+///       dnnl::sycl_interop::memory_kind::usm.
+///     - A pointer to a SYCL buffer. In this case the library doesn't own the
+///       buffer. Requires @p memory_kind to be equal to
+///       dnnl::sycl_interop::memory_kind::buffer.
+///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
+///       allocate the buffer that corresponds to the memory allocation kind
+///       @p memory_kind for the memory object. In this case the library
+///       owns the buffer.
+///     - The DNNL_MEMORY_NONE specific value. Instructs the library to
+///       create memory object without an underlying buffer.
+///
+/// @returns Created memory object.
+inline memory make_memory(const memory::desc &memory_desc,
+        const engine &aengine, memory_kind kind, void *handle) {
+    return make_memory(
+            memory_desc, aengine, kind, std::vector<void *> {handle});
+}
+#else
+
 /// Creates a memory object.
 ///
 /// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the
@@ -246,6 +321,7 @@ inline memory make_memory(const memory::desc &memory_desc,
             "could not create a memory");
     return memory(c_memory);
 }
+#endif
 
 /// Constructs a memory object from a SYCL buffer.
 ///
diff --git a/include/oneapi/dnnl/dnnl_threadpool.hpp b/include/oneapi/dnnl/dnnl_threadpool.hpp
index e3ebd0ff251..849465a540c 100644
--- a/include/oneapi/dnnl/dnnl_threadpool.hpp
+++ b/include/oneapi/dnnl/dnnl_threadpool.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2022 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
diff --git a/include/oneapi/dnnl/dnnl_threadpool_iface.hpp b/include/oneapi/dnnl/dnnl_threadpool_iface.hpp
index 271d4db7f22..c3127c1d474 100644
--- a/include/oneapi/dnnl/dnnl_threadpool_iface.hpp
+++ b/include/oneapi/dnnl/dnnl_threadpool_iface.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,8 +14,12 @@
 * limitations under the License.
 *******************************************************************************/
 
+/// @file
+/// Threadpool Interoperability C++ Types
+
 #ifndef ONEAPI_DNNL_DNNL_THREADPOOL_IFACE_HPP
 #define ONEAPI_DNNL_DNNL_THREADPOOL_IFACE_HPP
+// NOLINTBEGIN(readability-identifier-naming)
 
 #include <functional>
 #include <vector>
@@ -57,7 +61,7 @@
     /// waiting for the submitted closures to finish execution on its own.
     static constexpr uint64_t ASYNCHRONOUS = 1;
 
-    virtual ~threadpool_iface() {}
+    virtual ~threadpool_iface() = default;
 };
 
 } // namespace threadpool_interop
@@ -70,4 +74,5 @@ struct threadpool_iface {
 
 /// @} dnnl_api
 
-#endif
+// NOLINTEND(readability-identifier-naming)
+#endif /* ONEAPI_DNNL_DNNL_THREADPOOL_IFACE_HPP */
diff --git a/include/oneapi/dnnl/dnnl_types.h b/include/oneapi/dnnl/dnnl_types.h
index bb385ee2737..8821401352b 100644
--- a/include/oneapi/dnnl/dnnl_types.h
+++ b/include/oneapi/dnnl/dnnl_types.h
@@ -1,5 +1,6 @@
 /*******************************************************************************
-* Copyright 2016-2024 Intel Corporation
+* Copyright 2016-2025 Intel Corporation
+* Copyright 2024-2025 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -55,6 +56,8 @@ typedef enum {
     /// Format kind for sparse tensors.
     dnnl_format_kind_sparse,
 #endif
+    /// Format for sparse data.
+    dnnl_format_sparse,
     /// Parameter to allow internal only format kinds without undefined
     /// behavior. This parameter is chosen to be valid for so long as
     /// sizeof(int) >= 2.
@@ -74,6 +77,8 @@
     /// only be used to create a primitive descriptor to query the
     /// actual memory descriptor (similar to the format tag `any`).
     dnnl_packed,
+    /// Coordinate Sparse Encoding (COO).
+    dnnl_coo,
 } dnnl_sparse_encoding_t;
 #endif
 
@@ -284,6 +289,7 @@
     dnnl_ABcd8a16b2a,
     dnnl_ABcd2b8a4b,
     dnnl_ABcd8a8b,
+    dnnl_ABcd8a32b,
     dnnl_ABcd8a4b,
     /// 4D tensor blocked by 2nd dimension with block size 8
     dnnl_aBcd8b,
@@ -396,6 +402,8 @@
     dnnl_aCBdef16c16b,
     dnnl_aBdefc4b,
     dnnl_aBdefc8b,
+    dnnl_Abcdef4a,
+    dnnl_Abcdef8a,
     dnnl_Abcdef16a,
     dnnl_Abcdef32a,
     dnnl_aBedc16b,
@@ -1035,6 +1043,18 @@
     dnnl_bcad,
     dnnl_cabd,
     dnnl_dabc,
+    dnnl_Ab32a,
+    dnnl_aCBd8b8c,
+    dnnl_aCBde8b8c,
+    dnnl_BAc8a8b,
+    dnnl_BAcd8a8b,
+    dnnl_BAcde8a8b,
+    dnnl_aCBdef8b8c,
+    dnnl_abdEC16e4c,
+    dnnl_abDC16d4c,
+    dnnl_BA24b8a,
+    dnnl_aCB24c8b,
+    dnnl_abDC24d8c,
 
     /// Just a sentinel, not real memory format tag. Must be changed after new
     /// format tag is added.
@@ -1170,10 +1190,12 @@
     /// 5D LSTM projection tensor
     dnnl_ldOi16o = dnnl_abDc16d,
     dnnl_ldOi32o = dnnl_abDc32d,
+    dnnl_ldOI16o4i = dnnl_abDC16d4c,
     dnnl_ldOI32o4i = dnnl_abDC32d4c,
     dnnl_ldIo32i = dnnl_abCd32c,
     /// 6D RNN weights tensor
     dnnl_ldgOi16o = dnnl_abdEc16e,
+    dnnl_ldgOI16o4i = dnnl_abdEC16e4c,
     dnnl_ldgOi32o = dnnl_abdEc32e,
     dnnl_ldgOI32o2i = dnnl_abdEC32e2c,
     dnnl_ldgOI32o4i = dnnl_abdEC32e4c,
@@ -1255,6 +1277,7 @@
     dnnl_OI8i8o = dnnl_AB8b8a,
 
     // weights, 3D
+    dnnl_IOw8o8i = dnnl_BAc8a8b,
     dnnl_IOw16o16i = dnnl_BAc16a16b,
     dnnl_IOw16i16o = dnnl_BAc16b16a,
     dnnl_OIw16i16o = dnnl_ABc16b16a,
@@ -1325,6 +1348,7 @@
 
     // weights, 4D
     dnnl_IOhw16i16o = dnnl_BAcd16b16a,
+    dnnl_IOhw8o8i = dnnl_BAcd8a8b,
     dnnl_IOhw16o16i = dnnl_BAcd16a16b,
     dnnl_Ohwi16o = dnnl_Acdb16a,
     dnnl_OhwI16o2i = dnnl_AcdB16a2b,
@@ -1386,6 +1410,8 @@
     dnnl_OIhw2i8o4i = dnnl_ABcd2b8a4b,
     dnnl_IOhw8o16i2o = dnnl_BAcd8a16b2a,
     dnnl_OIhw8o8i = dnnl_ABcd8a8b,
+    dnnl_OIhw8o32i = dnnl_ABcd8a32b,
+    dnnl_OIhw16o32i = dnnl_ABcd16a32b,
     dnnl_OIhw8o4i = dnnl_ABcd8a4b,
     dnnl_Owhi16o = dnnl_Adcb16a,
     dnnl_OIhw8i32o = dnnl_ABcd8b32a,
@@ -1457,6 +1483,7 @@
     dnnl_OIdhw8o4i = dnnl_ABcde8a4b,
     dnnl_IOdhw16i16o = dnnl_BAcde16b16a,
     dnnl_OIdhw4o8i8o4i = dnnl_ABcde4a8b8a4b,
+    dnnl_IOdhw8o8i = dnnl_BAcde8a8b,
     dnnl_IOdhw16o16i = dnnl_BAcde16a16b,
     dnnl_OIdhw16o16i2o = dnnl_ABcde16a16b2a,
     dnnl_OIdhw8i32o = dnnl_ABcde8b32a,
@@ -1470,6 +1497,7 @@
     dnnl_Goiw16g = dnnl_Abcd16a,
     dnnl_Goiw8g = dnnl_Abcd8a,
     dnnl_Goiw4g = dnnl_Abcd4a,
+    dnnl_gIOw8o8i = dnnl_aCBd8b8c,
     dnnl_gIOw16o16i = dnnl_aCBd16b16c,
     dnnl_gIOw16i16o = dnnl_aCBd16c16b,
     dnnl_gOIw16i16o = dnnl_aBCd16c16b,
@@ -1515,6 +1543,7 @@
 
     // weights w/ groups, 4D
     dnnl_gIOhw16i16o = dnnl_aCBde16c16b,
+    dnnl_gIOhw8o8i = dnnl_aCBde8b8c,
     dnnl_gIOhw16o16i = dnnl_aCBde16b16c,
     dnnl_gOhwi16o = dnnl_aBdec16b,
     dnnl_gOhwI16o2i = dnnl_aBdeC16b2c,
@@ -1582,6 +1611,7 @@
 
     // weights w/ groups, 6D
     dnnl_gIOdhw16i16o = dnnl_aCBdef16c16b,
+    dnnl_gIOdhw8o8i = dnnl_aCBdef8b8c,
     dnnl_gIOdhw16o16i = dnnl_aCBdef16b16c,
     dnnl_gOdhwi16o = dnnl_aBdefc16b,
     dnnl_gOdhwI16o2i = dnnl_aBdefC16b2c,
@@ -1617,6 +1647,8 @@
     dnnl_gIOdhw8o16i2o = dnnl_aCBdef8b16c2b,
     dnnl_gOIdhw8o8i = dnnl_aBCdef8b8c,
     dnnl_gOIdhw8o4i = dnnl_aBCdef8b4c,
+    dnnl_Goidhw4g = dnnl_Abcdef4a,
+    dnnl_Goidhw8g = dnnl_Abcdef8a,
     dnnl_Goidhw16g = dnnl_Abcdef16a,
     dnnl_Goidhw32g = dnnl_Abcdef32a,
     dnnl_gOIdhw2i4o2i = dnnl_aBCdef2c4b2c,
@@ -1989,6 +2021,12 @@
     dnnl_deconvolution,
     /// An element-wise primitive.
     dnnl_eltwise,
+    /// A depthwise primitive.
+    dnnl_depthwise,
+    /// A quantization primitive.
+    dnnl_quantization,
+    /// A binarization primitive.
+ dnnl_binarization, /// An LRN primitive. dnnl_lrn, /// A batch normalization primitive. @@ -2081,6 +2119,12 @@ typedef enum { dnnl_eltwise_mish, /// Eltwise: hardswish dnnl_eltwise_hardswish, + /// Eltwise: hsigmoid + dnnl_eltwise_hsigmoid, + /// Eltwise: round_half_to_even + dnnl_eltwise_round_half_to_even, + /// Eltwise: round_half_away_from_zero + dnnl_eltwise_round_half_away_from_zero, /// Eltwise: ReLU (dst for backward) dnnl_eltwise_relu_use_dst_for_bwd = 0x100, /// Eltwise: hyperbolic tangent non-linearity (tanh) (dst for backward) @@ -2147,6 +2191,10 @@ typedef enum { dnnl_binary_eq = 0x1fffa, /// Binary not equal dnnl_binary_ne = 0x1fffb, + /// Binary select + dnnl_binary_select = 0x1fffc, + /// Binary prelu + dnnl_binary_prelu = 0x1fffd, /// Nearest Neighbor Resampling Method dnnl_resampling_nearest = 0x2fff0, /// Linear Resampling Method @@ -2173,6 +2221,13 @@ typedef enum { dnnl_softmax_accurate = 0x30000, /// Logsoftmax dnnl_softmax_log, + + dnnl_depthwise_scale_shift = 0x3fff0, + dnnl_depthwise_prelu = 0x3fff1, + + dnnl_quantization_quantize_dequantize = 0x4fff0, + dnnl_quantization_quantize = 0x4fff1, + dnnl_binarization_depthwise = 0x4fff2, } dnnl_alg_kind_t; /// Flags for normalization primitives. @@ -2259,7 +2314,12 @@ typedef enum { /// A `size_t` counterpart of the DNNL_RUNTIME_DIM_VAL. /// For instance, this value is returned by dnnl_memory_desc_get_size() if /// either of the dimensions or strides equal to #DNNL_RUNTIME_DIM_VAL. + +#if INTPTR_MAX == INT64_MAX #define DNNL_RUNTIME_SIZE_VAL ((size_t)DNNL_RUNTIME_DIM_VAL) +#else +#define DNNL_RUNTIME_SIZE_VAL ((size_t)INT32_MIN) +#endif /// @cond DO_NOT_DOCUMENT_THIS /// Hex representation for a **special** quiet NAN (!= NAN from math.h) @@ -2291,6 +2351,52 @@ typedef struct dnnl_memory_desc *dnnl_memory_desc_t; /// A memory descriptor handle. typedef const struct dnnl_memory_desc *const_dnnl_memory_desc_t; +/// Sparse encodings. +typedef enum { + dnnl_sparse_encoding_undef = 0, + dnnl_sparse_encoding_any, + dnnl_sparse_encoding_packed, + dnnl_sparse_encoding_csr, + dnnl_sparse_encoding_coo, +} dnnl_sparse_encoding_t; + +/* typedef struct dnnl_sparse_desc *dnnl_sparse_desc_t; */ +/* typedef const struct dnnl_sparse_desc *const_dnnl_sparse_desc_t; */ + +/// Flags for memory special features +typedef enum { + dnnl_memory_extra_flag_none = 0x0U, + /// Indicates the weights have an additional buffer, that depends on the + /// @p compensation_mask. + /// + /// For instance, in 4D case with the compensation mask equals (1 << 0) + /// the additional buffer would consist of OC values: + /// O[oc : 0,OC] = + /// -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) } + dnnl_memory_extra_flag_compensation_conv_s8s8 = 0x1U, + dnnl_memory_extra_flag_scale_adjust = 0x2U, + dnnl_memory_extra_flag_rnn_u8s8_compensation = 0x4U, + dnnl_memory_extra_flag_gpu_rnn_u8s8_compensation + = dnnl_memory_extra_flag_rnn_u8s8_compensation, + dnnl_memory_extra_flag_compensation_conv_asymmetric_src = 0x8U, + dnnl_memory_extra_flag_rnn_s8s8_compensation = 0x16U, +} dnnl_memory_extra_flags_t; + +/// Description of extra information stored in memory +typedef struct { + /// The flags contain arbitrary extra information, such as compensation. 
+ /// @sa dnnl_memory_extra_flags_t + uint64_t flags; + /// Compensation mask + int compensation_mask; + /// Scale applied to the data + float scale_adjust; + /// Compensation mask for asymmetric quantization + int asymm_compensation_mask; + /// For future backwards compatibility + char reserved[60]; +} dnnl_memory_extra_desc_t; + /// @struct dnnl_memory /// An opaque structure to describe a memory. struct dnnl_memory; @@ -2383,6 +2489,7 @@ typedef enum { dnnl_scratchpad_mode_user, } dnnl_scratchpad_mode_t; +/// Rounding mode typedef enum { /// rounding mode dictated by the floating-point environment dnnl_rounding_mode_environment, @@ -2529,6 +2636,12 @@ typedef const struct dnnl_primitive *const_dnnl_primitive_t; /// Bias tensor argument. #define DNNL_ARG_BIAS 41 +/// Reduce tensor argument. +#define DNNL_ARG_REDUCE 42 + +/// Note: when adding a new macro after `DNNL_ARG_REDUCE` please reserve a +/// space for potential indices for `DNNL_ARG_REDUCE`. + /// Mean values tensor argument. #define DNNL_ARG_MEAN 49 /// Variance values tensor argument. @@ -2642,6 +2755,7 @@ typedef const struct dnnl_primitive *const_dnnl_primitive_t; #define DNNL_ARG_ATTR_DROPOUT_SEED 511 /// Output scaling factors provided at execution time. +/// Deprecated value. #define DNNL_ARG_ATTR_OUTPUT_SCALES 513 /// Starting index for source arguments for primitives that take a variable @@ -2804,6 +2918,8 @@ typedef enum { dnnl_query_num_handles_s32, ///< Number of buffers required for a memory /// descriptor #endif + dnnl_query_sparse_encoding, + // Max value to prevent UB for internal use only dnnl_query_t dnnl_query_max = 0x7fff, } dnnl_query_t; @@ -2891,6 +3007,7 @@ typedef enum { dnnl_cpu_isa_avx10_1_512_amx_fp16 = 0x1fef, /// @copydoc dnnl_cpu_isa_avx10_1_512_amx_fp16 dnnl_cpu_isa_avx512_core_amx_fp16 = dnnl_cpu_isa_avx10_1_512_amx_fp16, + dnnl_cpu_isa_avx512_vpopcnt = 0x3fef, } dnnl_cpu_isa_t; /// CPU ISA hints flags diff --git a/include/oneapi/dnnl/dnnl_ukernel.h b/include/oneapi/dnnl/dnnl_ukernel.h index 102b2765373..50cdfb71c0c 100644 --- a/include/oneapi/dnnl/dnnl_ukernel.h +++ b/include/oneapi/dnnl/dnnl_ukernel.h @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -185,13 +185,14 @@ dnnl_status_t DNNL_API dnnl_brgemm_finalize(dnnl_brgemm_t brgemm); /// Returns the packing type expected by a tensor B of a BRGeMM ukernel object. /// -/// @param brgemm BRGeMM ukernel object. -/// @param pack_type Output packing type. Can be `dnnl_brgemm_no_pack` if -/// packing is not expected, and `dnnl_brgemm_pack_32`, otherwise. +/// @param pack_type Output packing type. Can be `dnnl_brgemm_no_trans` if +/// packing is not expected, and `dnnl_pack_type_pack32`, otherwise. +/// @param dt_a Data type of tensor A. +/// @param dt_b Data type of tensor B. /// @returns #dnnl_success on success and a status describing the error /// otherwise. -dnnl_status_t DNNL_API dnnl_brgemm_get_B_pack_type( - const_dnnl_brgemm_t brgemm, dnnl_pack_type_t *pack_type); +dnnl_status_t DNNL_API dnnl_brgemm_get_B_pack_type(dnnl_pack_type_t *pack_type, + dnnl_data_type_t dt_a, dnnl_data_type_t dt_b); /// Returns the size of a scratchpad memory needed for the BRGeMM ukernel /// object. 
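[Editor's note: the reworked query above now depends only on the data types of tensors A and B, so it can be called before any BRGeMM object exists. A minimal sketch under that assumption (requires a build with DNNL_EXPERIMENTAL_UKERNEL; compiles as C or C++):]

```
#include "oneapi/dnnl/dnnl_ukernel.h"

/* Hedged sketch: returns 1 when tensor B must be repacked (pack32) before
 * execution for a bf16 x bf16 BRGeMM, 0 otherwise. */
int b_needs_pack32(void) {
    dnnl_pack_type_t pack_type;
    dnnl_status_t status
            = dnnl_brgemm_get_B_pack_type(&pack_type, dnnl_bf16, dnnl_bf16);
    return status == dnnl_success && pack_type == dnnl_pack_type_pack32;
}
```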
@@ -203,6 +204,17 @@ dnnl_status_t DNNL_API dnnl_brgemm_get_B_pack_type( dnnl_status_t DNNL_API dnnl_brgemm_get_scratchpad_size( const_dnnl_brgemm_t brgemm, size_t *size); +/// Returns the flag indicating when the call to `dnnl_brgemm_execute_postops` +/// is valid. +/// +/// @param brgemm BRGeMM ukernel object. +/// @param valid The flag indicating if `dnnl_brgemm_execute_postops` is valid +/// for a given ukernel object. `1` is for valid and `0`, otherwise. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_is_execute_postops_valid( + const_dnnl_brgemm_t brgemm, int *valid); + /// Initializes the hardware-specific context. If no initialization required, /// returns the success status. /// diff --git a/include/oneapi/dnnl/dnnl_ukernel.hpp b/include/oneapi/dnnl/dnnl_ukernel.hpp index 642123842de..e42895973e3 100644 --- a/include/oneapi/dnnl/dnnl_ukernel.hpp +++ b/include/oneapi/dnnl/dnnl_ukernel.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #ifndef ONEAPI_DNNL_DNNL_UKERNEL_HPP #define ONEAPI_DNNL_DNNL_UKERNEL_HPP +// NOLINTBEGIN(readability-identifier-naming) #include "oneapi/dnnl/dnnl.hpp" #include "oneapi/dnnl/dnnl_ukernel.h" @@ -29,6 +30,8 @@ /// oneDNN namespace namespace dnnl { +#ifdef DNNL_EXPERIMENTAL_UKERNEL + /// @addtogroup dnnl_api_utils /// @{ @@ -59,6 +62,8 @@ struct handle_traits { /// @} dnnl_api_utils +#endif + /// @addtogroup dnnl_api_ukernel Ukernels /// Collection of ukernels /// @{ @@ -68,6 +73,10 @@ namespace ukernel { #ifdef DNNL_EXPERIMENTAL_UKERNEL +/// @addtogroup dnnl_api_ukernel_utils ukernel utils +/// ukernel utility functions +/// @{ + /// Packing specification enum class pack_type { /// Undefined pack type. A guard value. @@ -115,8 +124,8 @@ struct attr_params : public handle { /// Sets tensor B scales arguments to a storage. /// - /// If @ref brgemm::set_B_scales used mask of 2, then at least N values of - /// selected data type are expected. + /// If @ref attr_params::set_B_scales used mask of 2, then at + /// least N values of selected data type are expected. /// /// @param b_scales Pointer to scales storage. void set_B_scales(const void *b_scales) { @@ -136,11 +145,13 @@ struct attr_params : public handle { error::wrap_c_api(status, "could not set D scales argument"); } }; +/// @} dnnl_api_ukernel_utils /// @addtogroup dnnl_api_ukernel_brgemm BRGeMM ukernel /// BRGeMM ukernel routines /// @{ +/// BRGeMM ukernel struct brgemm : public handle { /// Default constructor. Produces an empty object. brgemm() = default; @@ -200,7 +211,7 @@ struct brgemm : public handle { /// /// @param ldd Leading dimension of tensor D. /// @param d_dt Data type of tensor D. - /// @param post_ops Primitive post-operation attributes to extend the kernel + /// @param po Primitive post-operation attributes to extend the kernel /// operations. void set_post_ops(memory::dim ldd, memory::data_type d_dt, const post_ops &po = default_post_ops()) { @@ -258,9 +269,14 @@ struct brgemm : public handle { /// Returns the packing type expected by a tensor B of a BRGeMM ukernel /// object. - pack_type get_B_pack_type() const { + /// + /// @param a_dt Data type of tensor A. + /// @param b_dt Data type of tensor B. 
+    static pack_type get_B_pack_type(
+            memory::data_type a_dt, memory::data_type b_dt) {
         dnnl_pack_type_t c_pack_type;
-        dnnl_status_t status = dnnl_brgemm_get_B_pack_type(get(), &c_pack_type);
+        dnnl_status_t status = dnnl_brgemm_get_B_pack_type(&c_pack_type,
+                memory::convert_to_c(a_dt), memory::convert_to_c(b_dt));
 
         if (status != dnnl_success)
             error::wrap_c_api(status, "could not query B pack type");
@@ -279,6 +295,21 @@ struct brgemm : public handle<dnnl_brgemm_t> {
         return size;
     }
 
+    /// Returns the flag indicating when the call to execute with post
+    /// operations is valid.
+    ///
+    /// Returns `true` for a valid call and `false` otherwise.
+    bool is_execute_postops_valid() const {
+        int valid;
+        dnnl_status_t status
+                = dnnl_brgemm_is_execute_postops_valid(get(), &valid);
+        if (status != dnnl_success)
+            error::wrap_c_api(status,
+                    "could not query a flag for execute postops from a BRGeMM "
+                    "ukernel object");
+        return static_cast<bool>(valid);
+    }
+
     /// Initializes the hardware-specific context. Affects the global state for
     /// all BRGeMM ukernel objects. If no initialization required, returns.
     void set_hw_context() const {
@@ -334,11 +365,11 @@ struct brgemm : public handle<dnnl_brgemm_t> {
     /// @param C Pointer to a tensor C (accumulation buffer).
     /// @param D Pointer to a tensor D (output buffer).
     /// @param scratchpad Pointer to a scratchpad buffer.
-    /// @param binary_po Binary post-op memory buffer. Must be passed If binary
-    ///     post-op was specified at construction call.
+    /// @param params Post-op memory arguments. Must be passed if a binary
+    ///     post-op or scales were set.
     void execute(const void *A, const void *B,
            const std::vector<std::pair<memory::dim, memory::dim>> &A_B_offsets,
-            void *C, void *D, void *scratchpad,
+            const void *C, void *D, void *scratchpad,
             const attr_params &params = default_attr_params()) const {
         // TODO: export batch_element to C API later for user to fill it and
         // pass directly to the call.
@@ -364,7 +395,13 @@
         return ap;
     }
 };
+/// @} dnnl_api_ukernel_brgemm
+
+/// @addtogroup dnnl_api_ukernel_transform Transform ukernel
+/// Transform routines
+/// @{
 
+/// Transform ukernel
 struct transform : public handle<dnnl_transform_t> {
     /// Default constructor. Produces an empty object.
     transform() = default;
@@ -419,7 +456,7 @@ struct transform : public handle<dnnl_transform_t> {
     }
 };
 
-/// @} dnnl_api_ukernel_brgemm
+/// @} dnnl_api_ukernel_transform
 
 #endif
 
@@ -431,4 +468,5 @@ struct transform : public handle<dnnl_transform_t> {
 
 /// @} dnnl_api
 
+// NOLINTEND(readability-identifier-naming)
 #endif /* ONEAPI_DNNL_DNNL_UKERNEL_HPP */
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000000..c29a82348f6
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[tool.black]
+line-length = 80
+include = 'scripts\/.*\.pyi?$'
diff --git a/scripts/generate_dnnl_debug.py b/scripts/generate_dnnl_debug.py
index 84c5b086aad..5c197152c99 100755
--- a/scripts/generate_dnnl_debug.py
+++ b/scripts/generate_dnnl_debug.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 ################################################################################
-# Copyright 2018-2024 Intel Corporation
+# Copyright 2018-2025 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
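[Editor's note: tying the C++ ukernel changes above together: `get_B_pack_type()` is now a static query on data types, and `is_execute_postops_valid()` selects between the two `execute()` overloads. A hedged sketch of the resulting flow (assumes DNNL_EXPERIMENTAL_UKERNEL; sizes and leading dimensions are illustrative, not taken from the patch):]

```
#include <cstdint>
#include <vector>
#include "oneapi/dnnl/dnnl_ukernel.hpp"

using namespace dnnl;

void run_f32_brgemm(const float *A, const float *B, float *C) {
    const memory::dim M = 16, N = 64, K = 64;
    // Query packing requirements before creating the ukernel object.
    ukernel::pack_type pt = ukernel::brgemm::get_B_pack_type(
            memory::data_type::f32, memory::data_type::f32);
    if (pt != ukernel::pack_type::no_trans) {
        // B would have to be repacked with the transform ukernel first.
    }

    ukernel::brgemm brg(M, N, K, /*batch_size=*/1, /*lda=*/K, /*ldb=*/N,
            /*ldc=*/N, memory::data_type::f32, memory::data_type::f32,
            memory::data_type::f32);
    brg.finalize();
    brg.generate();

    std::vector<uint8_t> scratchpad(brg.get_scratchpad_size());
    brg.set_hw_context();
    // Without post-ops configured, the plain execute() overload applies.
    if (!brg.is_execute_postops_valid())
        brg.execute(A, B, {{0, 0}}, C, scratchpad.data());
    ukernel::brgemm::release_hw_context();
}
```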
@@ -24,43 +24,20 @@ import xml.etree.ElementTree as ET -def banner(year_from): - year_now = str(datetime.datetime.now().year) - banner_year = ( - year_from if year_now == year_from else "%s-%s" % (year_from, year_now) - ) +def template(body, banner): return """\ -/******************************************************************************* -* Copyright %s Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - +%s // DO NOT EDIT, AUTO-GENERATED // Use this script to update the file: scripts/%s // clang-format off -""" % ( - banner_year, +%s""" % ( + banner, os.path.basename(__file__), + body ) - -def template(body, year_from): - return "%s%s" % (banner(year_from), body) - - def header(body): return ( """\ @@ -129,6 +106,7 @@ def header_benchdnn(body): #ifdef DNNL_EXPERIMENTAL_SPARSE const char *sparse_encoding2str(dnnl_sparse_encoding_t encoding); #endif +const char *sparse_encoding2str(dnnl_sparse_encoding_t encoding); /* engine kind */ const char *engine_kind2str(dnnl_engine_kind_t kind); @@ -183,6 +161,9 @@ def source_benchdnn(body): return dnnl_sparse_encoding2str(encoding); } #endif +const char *sparse_encoding2str(dnnl_sparse_encoding_t encoding) { + return dnnl_sparse_encoding2str(encoding); +} const char *engine_kind2str(dnnl_engine_kind_t kind) { return dnnl_engine_kind2str(kind); @@ -320,7 +301,7 @@ def str_to_func(enum, values, is_dnnl=True): return func -def generate(ifile, banner_years): +def generate(ifile, banners): h_body, s_body = "", "" h_benchdnn_body, s_benchdnn_body = "", "" root = ET.parse(ifile).getroot() @@ -361,7 +342,7 @@ def generate(ifile, banner_years): header_benchdnn(h_benchdnn_body), source_benchdnn(s_benchdnn_body), ] - return [template(b, y) for b, y in zip(bodies, banner_years)] + return [template(b, y) for b, y in zip(bodies, banners)] def usage(): @@ -380,7 +361,6 @@ def usage(): ) sys.exit(1) - for arg in sys.argv: if "-help" in arg: usage() @@ -396,12 +376,12 @@ def usage(): "%s/../tests/benchdnn/dnnl_debug_autogenerated.cpp" % script_root, ) -banner_years = [] +banners = [] for file_path in file_paths: with open(file_path, "r") as f: - m = re.search(r"Copyright (.*) Intel", f.read()) - banner_years.append(m.group(1).split("-")[0]) + m = re.match(r'^/\*+\n(\*.*\n)+\*+/\n', f.read()) + banners.append('' if m == None else m.group(0)) -for file_path, file_body in zip(file_paths, generate(ifile, banner_years)): +for file_path, file_body in zip(file_paths, generate(ifile, banners)): with open(file_path, "w") as f: f.write(file_body) diff --git a/scripts/synthdnn/README.md b/scripts/synthdnn/README.md new file mode 100644 index 00000000000..6fd3ac003f6 --- /dev/null +++ b/scripts/synthdnn/README.md @@ -0,0 +1,27 @@ +# Synthdnn + +Synthdnn is a suite of scripts for collecting and analyzing oneDNN performance +across a randomly generated data. 
The general architecture is intended to follow +a data pipeline composed of synthetic problem generation, data collection, and +data analysis. The `synthdnn.py` script provides a command line interface to +these tools. Sample Usage: + + +Problem Generation: +``` +python3 synthdnn.py [sampling controls] -b +``` +Performance Data Collection: +``` +python3 synthdnn.py collect --engine= --collect -b +``` + +Problem Generation and Performance Data Collection: +``` +python3 synthdnn.py [sampling controls] --engine= --collect +``` + +Report Generation: Not yet implemented. +``` + +See `synthdnn.py -h` for additional details. diff --git a/scripts/synthdnn/matmul/primitive.py b/scripts/synthdnn/matmul/primitive.py new file mode 100644 index 00000000000..0a30fcd6781 --- /dev/null +++ b/scripts/synthdnn/matmul/primitive.py @@ -0,0 +1,216 @@ +################################################################################ +# Copyright 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import itertools + + +class Dims: + def __init__(self, b, m, n, k): + # b is a list due to variable size + self.b = b + self.m = m + self.n = n + self.k = k + + def __str__(self): + a_dims = self.b + [self.m, self.k] + b_dims = self.b + [self.k, self.n] + a_str = "x".join([str(x) for x in a_dims]) + b_str = "x".join([str(x) for x in b_dims]) + return f"{a_str}:{b_str}" + + def __eq__(self, other): + return (self.b, self.m, self.n, self.k) == ( + other.b, + other.m, + other.n, + other.k, + ) + + def __hash__(self): + return hash((self.b, self.m, self.n, self.k)) + + +class Layouts: + class Layout: + def __init__(self, layout): + self.A, self.B, self.C = layout.split(":") + + def benchdnn_str(self): + return f"--stag={self.A} --wtag={self.B} --dtag={self.C}" + + def __init__(self, layouts, ndims): + if layouts == "all": + self.values = self.supported(ndims) + else: + self.values = [self.Layout(x) for x in layouts.split(",")] + + def __iter__(self): + return iter(self.values) + + @staticmethod + def supported(ndims): + if ndims < 2 or ndims > 6: + raise RuntimeError(f"No support for ndims={ndims}") + dims_base = "abcdef" + gemm_kn = dims_base[ndims - 1] + gemm_mk = dims_base[ndims - 2] + perms = [ + "".join(p) + for p in itertools.permutations(dims_base[:ndims]) + if p[-1] == gemm_kn or p[-1] == gemm_mk + ] + perms.insert(0, "any") + return [ + Layouts.Layout(f"{a}:{b}:{c}") + for a, b, c in itertools.product(perms, perms, perms) + if c == "any" or c[-1] == gemm_kn + ] + + +class Types: + class Type: + def __init__(self, type_str): + s = type_str.split("(") + self.A, self.B, self.C = s[0].split(":") + self.A, self.B, self.C = self.wildcard_match(self.A, self.B, self.C) + if len(s) < 2: + self.mode = None + else: + self.mode = s[1].strip(")") + + @staticmethod + def wildcard_match(A, B, C): + wildcard_match = A + B = B.replace("*", wildcard_match) + C = C.replace("*", wildcard_match) + return [A, B, C] + + def __str__(self): 
+ mode_str = "" + if self.mode: + mode_str = f"({self.mode})" + return f"{self.A}:{self.B}:{self.C}{mode_str}" + + def benchdnn_str(self): + mode_str = "" + if not self.mode is None: + mode_str = f"--attr-fpmath={self.mode}" + return f"--dt={self.A}:{self.B}:{self.C} {mode_str}" + + def __eq__(self, other): + return (self.A, self.B, self.C, self.mode) == ( + other.A, + other.B, + other.C, + other.mode, + ) + + def __init__(self, types): + if types == "all": + self.values = self.supported() + else: + self.values = [self.Type(x) for x in types.split(",")] + + def __str__(self): + return ",".join([str(x) for x in self.values]) + + def __iter__(self): + return iter(self.values) + + @staticmethod + def supported(): + support_matrix = [ + [["f64"], ["f64"], ["f64"]], + [["f32"], ["f32"], ["f32"]], + [["f32"], ["u8", "s8"], ["f32", "f16", "bf16"]], + [ + ["f16", "bf16"], + ["*", "u8", "s8", "u4", "s4"], + ["f32", "*", "u8", "s8"], + ], + [["u8", "s8"], ["u8"], ["f32", "bf16", "f16", "s32", "u8", "s8"]], + [ + ["f8_e5m2", "f8_e4m3"], + ["f8_e5m2", "f8_e4m3"], + ["f32", "bf16", "f16", "f8_e5m2", "f8_e4m3"], + ], + ] + + def is_int_type(t): + return t in ["u4", "s4", "u8", "s8", "s32"] + + def get_accumulator(wei): + if is_int_type(wei): + return "s32" + if wei == "f64": + return "f64" + return "f32" + + def get_fpmath_modes(src, wei, dst): + src, wei, dst = Types.Type.wildcard_match(src, wei, dst) + if get_accumulator(wei) == "f32": + ret = [""] + if "f32" in [src, wei]: + ret.append("(tf32)") + if "f32" in [src, wei] and not "f16" in [src, wei]: + ret.append("(bf16)") + if "f32" in [src, wei] and not "bf16" in [src, wei]: + ret.append("(f16)") + return ret + if ( + get_accumulator(wei) == "s32" + and not is_int_type(dst) + and not is_int_type(src) + ): + ret = [] + if "f32" in [src, wei]: + ret.append("(strict:true)") + ret.append("(tf32:true)") + if "f16" not in [src, wei]: + ret.append("(bf16:true)") + if "bf16" not in [src, wei]: + ret.append("(f16:true)") + return ret + return [""] + + out = [] + for c in support_matrix: + for src, wei, dst in itertools.product(c[0], c[1], c[2]): + for math in get_fpmath_modes(src, wei, dst): + out.append(Types.Type(f"{src}:{wei}:{dst}{math}")) + return out + + +# Kind represents problem parameters that do not make sense to consider +# in aggregate for optimization purposes as these features require significant +# changes within generated implementations or the implementation dispatching. +class Kind: + def __init__(self, layout, type): + self.layout = layout + self.type = type + + def benchdnn_str(self): + return f"{self.layout.benchdnn_str()} {self.type.benchdnn_str()}" + + +class Primitive: + def __init__(self, kind, dims): + self.kind: Kind = kind + self.dims = dims + + def benchdnn_str(self): + return f"{self.kind.benchdnn_str()} {self.dims}" diff --git a/scripts/synthdnn/matmul/sampler.py b/scripts/synthdnn/matmul/sampler.py new file mode 100755 index 00000000000..77a50737559 --- /dev/null +++ b/scripts/synthdnn/matmul/sampler.py @@ -0,0 +1,157 @@ +################################################################################ +# Copyright 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import itertools +import random +import math + +from matmul.primitive import * + + +class Region: + def __init__(self, line): + restrictions = [] + for x in line.split(":"): + if len(x) <= 0 or x[0] != "(" or x[-1] != ")": + raise RuntimeError( + f"Unable to parse restrictions: {x} in {line}" + ) + restrictions.append(x[1:-1]) + if len(restrictions) != 3: + raise RuntimeError(f"Invalid number of restrictions in {line}") + + self.min = [int(x) for x in restrictions[0].split(",")] + self.max = [int(x) for x in restrictions[1].split(",")] + self.alignment = [int(x) for x in restrictions[2].split(",")] + + if len(self.min) != len(self.max) or len(self.min) != len( + self.alignment + ): + raise RuntimeError( + f"Inconsistent number of dimensions between restrictions in {line}" + ) + + self.ndims = len(self.min) + + def __str__(self): + str_min = ",".join([str(x) for x in self.min]) + str_max = ",".join([str(x) for x in self.max]) + str_alignment = ",".join([str(x) for x in self.alignment]) + return f"({str_min}):({str_max}):({str_alignment})" + + +class Sampler: + def __init__(self, samples, mode, types, layouts, region): + self.layouts = layouts + self.mode = mode + self.types = types + self.region = region + self.samples = samples + + random.seed("oneDNN Matmul") + self.kinds = [Kind(x, y) for x, y in itertools.product(layouts, types)] + random.shuffle(self.kinds) + self.dim_sampler = self.DimSampler(region) + + def __str__(self): + return f"-s {self.samples} -m {self.mode} -l {self.layouts} -r {self.region} -t {self.types}" + + def __iter__(self): + if self.mode == "zip": + return self.ZipIter(self.samples, self.kinds, self.dim_sampler) + elif self.mode == "product": + return self.ProductIter(self.samples, self.kinds, self.dim_sampler) + else: + raise RuntimeError(f"Unknown iteration mode {self.mode}") + + # Itertools.product seems to break on an infinite sampler + class ProductIter: + def __init__(self, samples, kinds, dim_sampler): + self.dim_sampler = dim_sampler + self.kinds = kinds + self.kinds_iter = iter(self.kinds) + self.rem_samples = samples + + def __next__(self): + if self.rem_samples == 0: + raise StopIteration + + try: + self.k = next(self.kinds_iter) + self.s = next(self.dim_sampler) + except StopIteration: + self.kinds_iter = iter(self.kinds) + self.k = next(self.kinds_iter) + self.s = next(self.dim_sampler) + self.rem_samples = self.rem_samples - 1 + + return Primitive(self.k, self.s) + + class ZipIter: + def __init__(self, samples, kinds, dim_sampler): + self.dim_sampler = dim_sampler + self.kinds_iter = itertools.cycle(kinds) + self.rem_samples = samples + + def __next__(self): + if self.rem_samples == 0: + raise StopIteration + + self.rem_samples = self.rem_samples - 1 + k = next(self.kinds_iter) + s = next(self.dim_sampler) + + return Primitive(k, s) + + class DimSampler: + def __init__(self, region): + self.region = region + self.seen = set() + if region.ndims < 3: + raise RuntimeError( + f"Insufficient dimensions for matmul operation, expected at least 3, but got 
{region.ndims}" + ) + + def __next__(self): + + # Sample from a power distribution as most problem features occur + # when some dimension is small. In addition, small problems often + # require less time to run enabling faster data collection + def get_sample(minval, maxval, align): + assert minval <= maxval, "Sample bounds are out of order" + if minval == maxval: + return minval + x = round( + pow(2, random.uniform(math.log2(minval), math.log2(maxval))) + ) + return (x // align) * align + + for _ in range(1000): + dims = [0] * self.region.ndims + for i in range(self.region.ndims): + dims[i] = get_sample( + self.region.min[i], + self.region.max[i], + self.region.alignment[i], + ) + dims_tuple = tuple(dims) + if dims_tuple not in self.seen: + self.seen.add(dims_tuple) + return Dims(dims[:-3], dims[-3], dims[-2], dims[-1]) + + raise RuntimeError( + f"Cannot sample >{len(self.seen)} problems in region {self.region}" + ) diff --git a/scripts/synthdnn/synthdnn.py b/scripts/synthdnn/synthdnn.py new file mode 100755 index 00000000000..7c0f170e3e5 --- /dev/null +++ b/scripts/synthdnn/synthdnn.py @@ -0,0 +1,205 @@ +#! /bin/python3 +################################################################################ +# Copyright 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +import os +import sys +from tempfile import NamedTemporaryFile + +from matmul import sampler as matmul_sampler +from matmul import primitive as matmul + + +def log(output): + print("synthdnn: " + output) + + +def error(output): + print("synthdnn: error: " + output) + exit(1) + + +def write_batch_file(batch_file, samples, optional_args): + batch_file.write("#### Auto-generated by synthdnn\n") + batch_file.write(f"#### python3 synthdnn.py {' '.join(sys.argv[1:])}\n\n") + for s in samples: + batch_file.write(f"--reset {optional_args}{s.benchdnn_str()}\n") + batch_file.flush() + + +def setup_collect_args(parser, req): + parser.add_argument( + "-b", + "--batch-file", + required=req, + default=None, + help="batch file used for the operation", + ) + + # Interface with benchdnn + nargs = 1 + if not req: + nargs = "?" 
+    parser.add_argument(
+        "benchdnn", nargs=nargs, help="path to benchdnn executable"
+    )
+    parser.add_argument(
+        "--engine", default="cpu", help="engine used for benchdnn execution"
+    )
+    parser.add_argument(
+        "--impl",
+        default=None,
+        help="implementation to use in benchdnn execution",
+    )
+    parser.add_argument(
+        "--skip-impl",
+        default=None,
+        help="implementation to skip in benchdnn execution",
+    )
+    parser.add_argument(
+        "--collect",
+        default="corr",
+        help="benchdnn collection type, can be one of [corr, perf]",
+    )
+    parser.add_argument("-n", "--name", default="", help="sample name")
+
+
+def setup_collect_subparser(subparsers):
+    collect_parser = subparsers.add_parser(
+        "collect", help="call with -h for information"
+    )
+    collect_parser.add_argument(
+        "--subprogram_main", default=collect_main, help=argparse.SUPPRESS
+    )
+
+    setup_collect_args(collect_parser, True)
+
+
+def get_optional_args(args):
+    optional_args = []
+    if args.impl:
+        optional_args.append(f"--impl={args.impl}")
+    if args.skip_impl:
+        optional_args.append(f"--skip-impl={args.skip_impl}")
+
+    if len(optional_args) > 0:
+        return " ".join(optional_args) + " "
+
+    return ""
+
+
+def collect_main(args):
+    # args.benchdnn may be a list depending on command line setup
+    benchdnn = args.benchdnn
+    if isinstance(benchdnn, list):
+        benchdnn = benchdnn[0]
+
+    if not os.path.exists(benchdnn):
+        error(f"cannot execute {benchdnn}, no such file exists")
+
+    if args.collect == "corr":
+        benchdnn_args = f"--engine={args.engine} --matmul --mode-modifier=P {get_optional_args(args)}"
+    elif args.collect == "perf":
+        benchdnn_args = f"--engine={args.engine} --matmul --mode=F --cold-cache=all --perf-template=sample,{args.name},%prb%,%0Gflops%,%0Gbw% --memory-kind=usm_device --attr-scratchpad=user {get_optional_args(args)}"
+        if args.name.find(",") != -1:
+            error(f"sample name {args.name} contains invalid character: ,")
+    else:
+        error(f"unknown collection method {args.collect}")
+
+    cmd = f"{benchdnn} {benchdnn_args} --batch={args.batch_file}"
+    log(f"executing: {cmd}")
+    ret = os.system(cmd)
+    if ret != 0:
+        error(f"execution of {cmd} failed with return code {ret}")
+    log("execution complete")
+
+
+def setup_matmul_subparser(subparsers):
+    matmul_parser = subparsers.add_parser(
+        "matmul", help="call with -h for information"
+    )
+    matmul_parser.add_argument(
+        "--subprogram_main", default=matmul_main, help=argparse.SUPPRESS
+    )
+
+    # Data Collection shortcut
+    setup_collect_args(matmul_parser, False)
+
+    # Sampler Arguments
+    matmul_parser.add_argument(
+        "-l",
+        "--layouts",
+        default="all",
+        help='stag:wtag:dtag, comma separated list of layouts or "all" for every supported layout',
+    )
+    matmul_parser.add_argument(
+        "-m",
+        "--iter-mode",
+        default="zip",
+        help="iteration mode, must be one of zip or product",
+    )
+    matmul_parser.add_argument(
+        "-r",
+        "--region",
+        default="(1,1,1,1):(8,8192,8192,8192):(1,1,1,1)",
+        help="([b_min,]m_min,n_min,k_min):([b_max,]m_max,n_max,k_max):([b_align,]m_align,n_align,k_align)",
+    )
+    matmul_parser.add_argument(
+        "-s", "--samples", default=1000, help="number of samples to collect"
+    )
+    matmul_parser.add_argument(
+        "-t",
+        "--types",
+        default="all",
+        help='dt:dt:dt(optional fpmath-mode), comma separated list of type configurations or "all" for every supported type',
+    )
+
+
+def matmul_main(args):
+    batch_file = (
+        open(args.batch_file, "w+t") if args.batch_file is not None else None
+    )
+    if args.benchdnn is not None and batch_file is None:
+        batch_file = NamedTemporaryFile("w+t")
+
+    
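# For illustration: the default region "(1,1,1,1):(8,8192,8192,8192):(1,1,1,1)"
+    # parses into min=[1,1,1,1], max=[8,8192,8192,8192], alignment=[1,1,1,1],
+    # i.e. four dimensions (batch, m, n, k) with no alignment constraint.
+    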
region = matmul_sampler.Region(args.region)
+    types = matmul.Types(args.types)
+    layouts = matmul.Layouts(args.layouts, region.ndims - 1)
+    samples = matmul_sampler.Sampler(
+        int(args.samples), args.iter_mode, types, layouts, region
+    )
+    if batch_file:
+        log(f"generating batch file: {batch_file.name}")
+        write_batch_file(batch_file, samples, get_optional_args(args))
+        log("generation complete")
+    else:
+        write_batch_file(sys.stdout, samples, get_optional_args(args))
+
+    if args.benchdnn:
+        args.batch_file = batch_file.name
+        collect_main(args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    subparsers = parser.add_subparsers(
+        help="primitive targeted for data collection"
+    )
+    setup_collect_subparser(subparsers)
+    setup_matmul_subparser(subparsers)
+    args = parser.parse_args()
+    args.subprogram_main(args)
diff --git a/scripts/verbose_converter/README.md b/scripts/verbose_converter/README.md
index 3a7b3af3e26..983a5ecc26a 100644
--- a/scripts/verbose_converter/README.md
+++ b/scripts/verbose_converter/README.md
@@ -1,7 +1,7 @@
 # Verbose log converter
 
 Verbose log converter is a tool that allows to convert [oneDNN
-verbose](https://oneapi-src.github.io/oneDNN/dev_guide_verbose.html)
+verbose](https://uxlfoundation.github.io/oneDNN/dev_guide_verbose.html)
 output to various outputs (input files for benchdnn and execution
 statistics breakdown at this time). The tool can be extended to produce
 other types of output by adding generators.
diff --git a/scripts/verbose_converter/src/__init__.py b/scripts/verbose_converter/src/__init__.py
new file mode 100644
index 00000000000..3acfa281439
--- /dev/null
+++ b/scripts/verbose_converter/src/__init__.py
@@ -0,0 +1,18 @@
+################################################################################
+# Copyright 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# This file marks this directory as a package and is needed to allow relative
+# imports. See https://docs.python.org/3/tutorial/modules.html#packages.
diff --git a/scripts/verbose_converter/src/benchdnn_generator.py b/scripts/verbose_converter/src/benchdnn_generator.py
index 9ed56f199cf..6131c9d70d5 100644
--- a/scripts/verbose_converter/src/benchdnn_generator.py
+++ b/scripts/verbose_converter/src/benchdnn_generator.py
@@ -1,5 +1,5 @@
 ################################################################################
-# Copyright 2020-2024 Intel Corporation
+# Copyright 2020-2025 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,831 +14,848 @@
 # limitations under the License.
 
################################################################################ +import logging +from collections import defaultdict +from typing import Dict, List, Mapping, Optional, Set, cast -def everyone_is(list, value="None"): - if [value == "None"]: - value = list[0] - return [e for e in list if e != value] == [] - - -primitives_with_algs = ( - "binary", - "convolution", - "deconvolution", - "eltwise", - "lrn", - "pooling", - "reduction", - "resampling", - "rnn", -) - - -def alg_remove_primitive(alg): - for p in primitives_with_algs: - if alg.find(p) != -1: - alg = alg[(alg.find(p) + len(p) + 1) :] - return alg - - -def convert_driver(prop_kind): - driver = { - "batch_normalization": "bnorm", - "binary": "binary", - "brgemm": "brgemm", - "concat": "concat", - "convolution": "conv", - "deconvolution": "deconv", - "eltwise": "eltwise", - "group_normalization": "gnorm", - "inner_product": "ip", - "layer_normalization": "lnorm", - "lrn": "lrn", - "matmul": "matmul", - "pooling": "pool", - "prelu": "prelu", - "reduction": "reduction", - "reorder": "reorder", - "resampling": "resampling", - "rnn": "rnn", - "shuffle": "shuffle", - "softmax": "softmax", - "sum": "sum", - }.get(prop_kind) - return driver - - -def convert_engine(engine): - return f"--engine={engine}" - - -def convert_dir(entry): - # get base direction - dir = { - "forward_training": "FWD_D", - "forward_inference": "FWD_I", - "backward_data": "BWD_D", - "backward_weights": "BWD_W", - "backward": "BWD_DW", - }.get(entry["prop_kind"]) - - if not dir: - return "" +from . import ir - found_bias = [ - e for e in entry["mds"] if "bia" == e["arg"] and e["data_type"] != "undef" - ] - dir = "FWD_B" if "FWD" in dir and found_bias else dir - dir = "BWD_WB" if dir == "BWD_W" and found_bias else dir - if entry["prim_kind"] == "rnn": - return f"--prop={dir}" - else: - return f"--dir={dir}" - - -def convert_aux(entry): - if entry.get("aux") != None: - alg = entry["aux"]["alg"] if entry["aux"].get("alg") != None else "" - pk = entry["prim_kind"] - if pk == "convolution": - str = "" - alg = alg_remove_primitive(alg) - algs = {"winograd": "WINO", "direct": "direct"} - alg = algs.get(alg) - if alg != None: - str = f"--alg={alg}" - return str - if pk == "eltwise": - alpha = entry["aux"]["alpha"] - beta = entry["aux"]["beta"] - alg += f" --alpha={alpha} --beta={beta}" - return f"--alg={alg}" - elif pk == "concat": - axis = entry["aux"]["axis"] - return f"--axis={axis}" - elif pk in [ - "batch_normalization", - "layer_normalization", - "group_normalization", - ]: - flags = entry["aux"]["flags"] - return f"--flags={flags}" - elif pk == "lrn": - str = "" - alg = alg_remove_primitive(alg) - algs = {"across_channels": "ACROSS", "within_channel": "WITHIN"} - alg = algs.get(alg) - if alg != None: - str = f"--alg={alg}" - return str - elif pk == "reduction": - p = entry["aux"]["p"] - eps = entry["aux"]["eps"] - alg += f" --p={p} --eps={eps}" - return f"--alg={alg}" - elif pk == "rnn": - str = "" - algs = { - "vanilla_rnn": "VANILLA_RNN", - "vanilla_lstm": "VANILLA_LSTM", - "vanilla_gru": "VANILLA_GRU", - "vanilla_augru": "VANILLA_AUGRU", - "lbr_gru": "LBR_GRU", - "lbr_augru": "LBR_AUGRU", - } - alg = algs.get(alg) - if alg != None: - str += f"--alg={alg}" - ir_dir = entry["aux"]["direction"] - dirs = { - "unidirectional_left2right": "left2right", - "unidirectional_right2left": "right2left", - "bidirectional_sum": "sum", - "bidirectional_concat": "concat", - } - dir = dirs.get(ir_dir) - if dir is not None: - str += f" --direction={dir}" - ir_act = 
entry["aux"]["activation"] - acts = { - "eltwise_relu": "RELU", - "eltwise_logistic": "LOGISTIC", - "eltwise_tanh": "TANH", - } - act = acts.get(ir_act) - if act is not None: - str += f" --activation={act}" - flags = entry["aux"]["flags"] - if flags is not None: - str += f" --flags={flags}" - return str - elif pk == "shuffle": - axis = entry["aux"]["axis"] - group = entry["aux"]["group"] - return f"--axis={axis} --group={group}" - elif pk == "softmax": - axis = entry["aux"]["axis"] - return f"--alg={alg} --axis={axis}" - elif pk == "pooling": - return f"--alg={alg}" - elif pk == "matmul": - runtime_dims_masks = ( - entry["aux"]["runtime_dims_masks"] - if entry["aux"].get("runtime_dims_masks") != None - else "" - ) - return f"--runtime_dims_masks={runtime_dims_masks}" - elif pk == "reorder": - runtime_dim_mask = ( - entry["aux"]["runtime-dim-mask"] - if entry["aux"].get("runtime-dim-mask") != None - else "" - ) - return f"--runtime-dim-mask={runtime_dim_mask}" - elif pk == "brgemm": - bs = entry["aux"]["bs"] if entry["aux"].get("bs") != None else "" - beta = entry["aux"]["beta"] if entry["aux"].get("beta") != None else "" - return f"--bs={bs} --beta={beta}" - else: - alg = alg_remove_primitive(alg) - if alg != "": - return f"--alg={alg}" - return "" - - -def convert_bias_mask(mds): - bia_mds = [md for md in mds if md["arg"] == "bia"] - if len(bia_mds) != 0: - bia_md = bia_mds[0] - flags = bia_md["flags"]["value"].split("_") - if len(flags) > 1: - mask = flags[1][4:] - return f"--bia_mask={mask}" - return "" - - -def convert_dts(mds, prim_kind): - def convert_dts_common(mds): - dts = [md["data_type"] for md in mds if md["data_type"] != "undef"] - dt = dts[0] - return f"--dt={dt}" - - def convert_dts_cfg_rnn(mds): - cfg = "--cfg=" - args = ["src_iter", "src_iter_c", "src_layer", "dst_iter", "dst_layer", "bias"] - mds_strip = [md for md in mds if md["arg"] in args] - # ws is not part of cfg - mds_strip = [md for md in mds_strip if "ws" not in md["arg"]] - # bias is not part of cfg - mds_strip = [md for md in mds_strip if "bia" not in md["arg"]] - common_dt = everyone_is([md["data_type"] for md in mds_strip]) - if common_dt and mds_strip[0]["data_type"] in ["f32", "f16"]: - cfg += mds_strip[0]["data_type"] - elif common_dt and mds_strip[0]["data_type"] == "bf16": - cfg += mds_strip[0]["data_type"] - # bias is part of cfg for bf16 - bias_md = [md for md in mds if md["arg"] == "bias"][0] - bias_dt = bias_md["data_type"] - if bias_dt != mds_strip[0]["data_type"]: - cfg += bias_dt - else: - for arg in args: - for md in mds_strip: - if md["arg"] == arg: - # src iter is skipped if it is f32 - if arg == "src_iter_c" and md["data_type"] == "f16": - continue - cfg += md["data_type"] - return cfg - - def convert_dts_all(mds): - dts = "" - md_args = "" - for md in mds: - md_arg = md["arg"][0] - if md_args.find(md_arg) == -1: - md_dt = md["data_type"] - dts += f" --{md_arg}dt={md_dt}" - md_args += md_arg - return dts - def convert_dts_prelu(mds): - data_md = [md for md in mds if "data" in md["arg"]][0] - weights_md = [md for md in mds if "wei" in md["arg"]][0] - - data_dt = data_md["data_type"] - weights_dt = weights_md["data_type"] - - return f" --sdt={data_dt}:{weights_dt}" - - # --dt=SRC_DT[:WEI_DT][:DST_DT] - def convert_dts_multiple(mds): - dts = "--dt=" - has_fused_dw = 0 - for md in mds: - md_dt = md["data_type"] - md_arg = md["arg"] - if md_arg == "src_fused": - has_fused_dw = 1 - # Fused dw defines dst_dt by src_fused argument - # Note: assumes the order in mds is 'src_fused', then 'dst'. 
- if has_fused_dw == 1 and md_arg == "dst": - continue +def maybe_make_any_tag(md: ir.MemoryDescriptor): + return "any" if "a" in md.properties else md.tag - if md_arg == "src": - dts += f"{md_dt}" - elif md_arg == "wei": - dts += f":{md_dt}" - elif md_arg == "dst" or md_arg == "src_fused": - dts += f":{md_dt}" - else: - dts += f"" - return dts - def convert_dts_multiple_src(mds): - src_dts = "" - dts = "" - first_src = True - for md in mds: - md_dt = md["data_type"] - md_arg = md["arg"] - if md_arg == "src": - if not first_src: - src_dts += f":{md_dt}" - else: - src_dts += f" --{md_arg[0]}dt={md_dt}" - first_src = False - else: - if md_dt != "undef": - dts += f" --{md_arg[0]}dt={md_dt}" - return src_dts + dts - - def convert_dts_with_bias(mds): - dt = convert_dts_multiple(mds) - mds_bias = [md for md in mds if "bia" in md["arg"]] - if len(mds_bias) != 0: - md_bias = mds_bias[0] - bias_dt = md_bias["data_type"] - dt += " " + f"--bia_dt={bias_dt}" - return dt - - def convert_dts_with_ss(mds): - dt = convert_dts_multiple(mds) - mds_scale = [md for md in mds if "scale" in md["arg"]] - mds_shift = [md for md in mds if "shift" in md["arg"]] - - if len(mds_scale) != 0: - md_scale = mds_scale[0] - scale_dt = md_scale["data_type"] - dt += " " + f"--ss_dt={scale_dt}" - elif len(mds_shift) != 0: - md_shift = mds_shift[0] - shift_dt = md_shift["data_type"] - dt += " " + f"--ss_dt={shift_dt}" - - return dt - - convert_dts = { - "batch_normalization": convert_dts_common, - "binary": convert_dts_multiple_src, - "brgemm": convert_dts_multiple, - "concat": convert_dts_all, - "convolution": convert_dts_multiple, - "deconvolution": convert_dts_multiple, - "eltwise": convert_dts_common, - "inner_product": convert_dts_multiple, - "group_normalization": convert_dts_multiple, - "layer_normalization": convert_dts_with_ss, - "lrn": convert_dts_common, - "matmul": convert_dts_with_bias, - "pooling": convert_dts_multiple, - "prelu": convert_dts_prelu, - "reduction": convert_dts_all, - "reorder": convert_dts_all, - "resampling": convert_dts_all, - "rnn": convert_dts_cfg_rnn, - "shuffle": convert_dts_common, - "softmax": convert_dts_all, - "sum": convert_dts_multiple_src, - } +def attribute_flag(name: str): + def wrapper(converter: "Converter"): + attr = getattr(converter.entry.exts, name) + flag_name = name.replace("_", "-") + if attr is None: + return "" + return f"--attr-{flag_name}={attr!s}" + + return property(wrapper) + + +class ConverterMeta(type): + driver: str + + +class Converter(metaclass=ConverterMeta): + def __init__(self, entry: ir.Entry): + self.entry = entry + + def _get_dir(self): + dirs = { + "forward_training": "FWD_D", + "forward_inference": "FWD_I", + "backward_data": "BWD_D", + "backward_weights": "BWD_W", + "backward": "BWD_DW", + } + + if self.entry.prop_kind not in dirs: + return "" + + return dirs[self.entry.prop_kind] + + def _get_alg(self): + return self.entry.aux.get("alg") + + @staticmethod + def _get_policies(): + return "common", "per_oc" + + @staticmethod + def _get_policy_map(): + return 0, 1, 1, 1 + + def policy(self, mask: int): + policies = self._get_policies() + policy_map = self._get_policy_map() - convert = convert_dts.get(prim_kind) - if convert != None: - return convert(mds) - # FIXME: Error handling. 
Throw an error if get() is used, but None returned - return "" - - -def convert_tags(mds, prim_kind): - def convert_tags_common(mds): - tags = [md["tag"] for md in mds if md["tag"] != ""] - tag = tags[0] - return f"--tag={tag}" if tag else "" - - def convert_tags_all(mds): - tags = "" - has_fused_dw = 0 - for md in mds: - md_arg = md["arg"] - md_arg_abbr = md["arg"][0] - if md_arg == "src_fused": - has_fused_dw = 1 - md_arg_abbr = "d" - - # Fused dw defines dst_dt by src_fused argument - # Note: assumes the order in mds is 'src_fused', then 'dst'. - if has_fused_dw == 1 and md_arg == "dst": + if mask >= len(policy_map) or policy_map[mask] >= len(policies): + return "per_tensor" + return policies[policy_map[mask]] + + @property + def engine(self): + return f"--engine={self.entry.engine}" + + @property + def dir(self): + if self._get_dir(): + return f"--dir={self._get_dir()}" + return "" + + @property + def bias_mask(self): + return "" + + @property + def dts(self): + for md in self.entry.mds: + if md.data_type == "undef": continue - # skip bias and dw_fused weights - if md_arg_abbr == "b" or md_arg == "wei_fused": + return f"--dt={md.data_type}" + return "" + + @property + def tags(self): + for md in self.entry.mds: + if not md.tag: continue + return f"--tag={md.tag}" # XXX: Don't use maybe_make_any_tag + return "" + + @property + def flags(self): + return "" + + def _get_nondefault_args(self, values, defaults): + parts: List[str] = [] + pairs = list(zip(values, defaults)) + seen_nondefault = False + for value, default in reversed(pairs): + if value != default: + seen_nondefault = True + if seen_nondefault: + parts.append(str(value)) + return list(reversed(parts)) + + def _convert_dw_post_op(self, po: ir.DepthwisePostOp): + return f"dw:{po.ksp}:{po.dst_dt}" + + def _convert_sum_post_op(self, po: ir.SumPostOp): + values = po.scale, po.zp, po.dt + args = self._get_nondefault_args(values, defaults=(1.0, 0, "")) + return ":".join(["sum"] + args) + + def _convert_prelu_post_op(self, po: ir.PreLUPostOp): + if po.mask != 0: + return f"prelu:{self.policy(po.mask)}" + return "prelu" + + def _convert_eltwise_post_op(self, po: ir.EltwisePostOp): + values = po.alpha, po.beta, po.scale + args = self._get_nondefault_args(values, defaults=(0.0, 0.0, 1.0)) + return ":".join([po.alg] + args) + + def _convert_binary_post_op(self, po: ir.BinaryPostOp): + if po.tag != "any": + return f"{po.alg}:{po.dt}:{po.mask}:{po.tag}" + return f"{po.alg}:{po.dt}:{po.mask}" + + @property + def post_ops(self): + post_ops = self.entry.exts.post_ops + if post_ops is None: + return "" + results = [] + for post_op in post_ops: + if post_op.alg == "dw": + dw_po = cast(ir.DepthwisePostOp, post_op) + results.append(self._convert_dw_post_op(dw_po)) + elif post_op.alg == "sum": + sum_po = cast(ir.SumPostOp, post_op) + results.append(self._convert_sum_post_op(sum_po)) + elif post_op.alg == "prelu": + prelu_po = cast(ir.PreLUPostOp, post_op) + results.append(self._convert_prelu_post_op(prelu_po)) + elif post_op.alg.startswith("binary_"): + binary_po = cast(ir.BinaryPostOp, post_op) + results.append(self._convert_binary_post_op(binary_po)) + elif post_op.alg.startswith("eltwise_"): + eltwise_po = cast(ir.EltwisePostOp, post_op) + results.append(self._convert_eltwise_post_op(eltwise_po)) + return "--attr-post-ops=" + "+".join(results) + + def _get_quantization( + self, + params: Optional[Mapping[str, ir.QuantizationParam]], + def_value: float, + def_type: str, + ): + if params is None: + return "" + results = [] + for arg, param in 
params.items(): + policy = self.policy(param.mask) + result = f"{arg}:{policy}" + if policy == "common": + result += f":{def_value}" + dt = param.data_type + groups = param.groups + if dt != def_type or groups: + result += f":{dt}" + if groups: + result += f":{groups}" + results.append(result) + return "+".join(results) + + @property + def scales(self): + params = self._get_quantization(self.entry.exts.scales, 0.5, "f32") + return f"--attr-scales={params}" + + @property + def zero_points(self): + params = self._get_quantization(self.entry.exts.zero_points, 1, "s32") + return f"--attr-zero-points={params}" + + @property + def rounding_mode(self): + rounding_modes = self.entry.exts.rounding_mode + if rounding_modes is None: + return "" + results = [] + for arg, mode in rounding_modes.items(): + results.append(f"{arg}:{mode!s}") + return "--attr-rounding-mode=" + "+".join(results) + + scratchpad_mode = attribute_flag("scratchpad") + fpmath_mode = attribute_flag("fpmath") + acc_mode = attribute_flag("acc_mode") + + @property + def dropout(self): + dropout = self.entry.exts.dropout + if dropout is None: + return "" + # Use default p=0.5 and seed=12345 since those values are user data and + # can't be obtained properly. + result = "0.5:12345" + if dropout.tag: + result += f":{dropout.tag}" + return f"--attr-dropout={result}" + + deterministic = attribute_flag("deterministic") + + @property + def attrs(self): + attrs = ( + self.post_ops, + self.scales, + self.zero_points, + self.scratchpad_mode, + self.fpmath_mode, + self.acc_mode, + self.rounding_mode, + self.dropout, + self.deterministic, + ) + return " ".join(attr for attr in attrs if attr) + + @property + def aux(self): + alg = self._get_alg() + if alg is not None: + return f"--alg={alg}" + return "" - if "a" in md["properties"]: - tags += f" --{md_arg_abbr}tag=any" + @property + def shapes(self): + return self.entry.shapes + + +class AlgorithmMixin: + entry: ir.Entry + + def _get_alg(self): + alg = self.entry.aux.get("alg") + if alg is None: + return None + return alg.split(self.entry.prim_kind, 1)[1][1:] + + +class MultiSourceMixin: + entry: ir.Entry + + @property + def dts(self): + src_dts: List[str] = [] + other_dts: Dict[str, str] = {} + for md in self.entry.mds: + dt = md.data_type + if md.arg == "src": + src_dts.append(dt) + elif dt != "undef": + other_dts[md.arg[0]] = dt + sdt_flags = "--sdt=" + ":".join(src_dts) + other_dt_flags = " ".join(f"--{k}dt={v}" for k, v in other_dts.items()) + return f"{sdt_flags} {other_dt_flags}".strip() + + @property + def tags(self): + src_tags: List[str] = [] + other_tags: Dict[str, str] = {} + for md in self.entry.mds: + if md.arg == "src": + src_tags.append(maybe_make_any_tag(md)) + elif md.tag: + other_tags[md.arg[0]] = maybe_make_any_tag(md) + stag_flags = "--stag=" + ":".join(src_tags) + other_tag_flags = " ".join( + f"--{k}tag={v}" for k, v in other_tags.items() + ) + return f"{stag_flags} {other_tag_flags}".strip() + + +class CommonDataTypeMixin: + entry: ir.Entry + + @property + def dts(self): + dts: Dict[str, str] = {} + for md in self.entry.mds: + c = md.arg[0] + if c in dts: + continue + dts[c] = md.data_type + return " ".join(f"--{k}dt={v}" for k, v in dts.items()) + + +class TagTripletMixin: + entry: ir.Entry + + @property + def tags(self): + md_map = {md.arg: md for md in self.entry.mds} + has_fused_dw = "src_fused" in md_map + # Fused dw defines dst tag by src_fused argument + dst_name = "src_fused" if has_fused_dw else "dst" + tags = [] + if "src" in md_map: + md = md_map["src"] + tag = 
maybe_make_any_tag(md) + tags.append(f"--stag={tag}") + if "wei" in md_map: + md = md_map["wei"] + tag = maybe_make_any_tag(md) # pass wtag any for cases with compensation - elif md_arg_abbr == "w" and md["flags"]["value"] != "f0": - tags += f" --{md_arg_abbr}tag=any" + if str(md.flags.value) != "f0": + tag = "any" + tags.append(f"--wtag={tag}") + if dst_name in md_map: + md = md_map[dst_name] + tag = maybe_make_any_tag(md) + tags.append(f"--dtag={tag}") + return " ".join(tags) + + +class StridesMixin(TagTripletMixin): + @property + def tags(self): + tags = [] + strides = [] + + def add_strides_or_tag(arg, md): + tag = maybe_make_any_tag(md) + if arg == "wei" and str(md.flags.value) != "f0": + tag = "any" + if tag != "any" and tag.lower() == tag and md.strides: + strides.append(md.strides) else: - md_tag = md["tag"] - tags += f" --{md_arg_abbr}tag={md_tag}" - return tags - - def convert_tags_and_strides(mds): - tags = "" - strides = f" --strides=" - for md in mds: - md_arg = md["arg"][0] - # skip bias - if md_arg == "b": + tags.append(f"--{arg[0]}tag={tag}") + strides.append("") + + md_map = {md.arg: md for md in self.entry.mds} + args = "src", "wei", "dst" + for arg in args: + if arg not in md_map: continue + md = md_map[arg] + add_strides_or_tag(arg, md) + stride_flag = "--strides=" + ":".join(strides) + return " ".join(tags + [stride_flag]) + + +class MultiDataTypeMixin: + entry: ir.Entry + + @property + def dts(self): + dt_map = {md.arg: md.data_type for md in self.entry.mds} + # Fused dw defines dst_dt by src_fused argument + has_fused_dw = "src_fused" in dt_map + dst_name = "src_fused" if has_fused_dw else "dst" + dts = [ + dt_map.get("src", ""), + dt_map.get("wei", ""), + dt_map.get(dst_name, ""), + ] + return "--dt=" + ":".join(dt for dt in dts if dt) + + +class MultiDataTypeWithBiasMixin(MultiDataTypeMixin): + @property + def dts(self): + dts = super().dts + for md in self.entry.mds: + if md.arg != "bia": + continue + return f"{dts} --bia-dt={md.data_type}".strip() + return dts - if "a" in md["properties"]: - tags += f" --{md_arg}tag=any" - # pass wtag any for cases with compensation - elif md_arg == "w" and md["flags"]["value"] != "f0": - tags += f" --{md_arg}tag=any" - else: - md_strides = md["strides"] - - def tag_has_blocks(string): - for l in string: - if l.isupper(): - return True - return False - - md_tag_has_blocks = tag_has_blocks(md["tag"]) - if md_strides != "" and not md_tag_has_blocks: - strides += f"{md_strides}" - else: - md_tag = md["tag"] - tags += f" --{md_arg}tag={md_tag}" - if md_arg != "d": - strides += f":" - - tags += strides - return tags - # --tag=SRC_TAG[:WEI_TAG][:DST_TAG] - def convert_tags_multiple(mds): - tags = "--tag=" - for md in mds: - md_tag = md["tag"] - md_arg = md["arg"] - if md_arg == "src" or md_arg == "wei" or md_arg == "dst": - if md_arg != "src": - tags += f":" - if "a" in md["properties"]: - tags += f"any" - else: - tags += f"{md_tag}" - else: - tags += f"" - return tags - - def convert_tags_multiple_src(mds): - src_tags = "" - tags = "" - first_src = False - for md in mds: - md_tag = md["tag"] - md_arg = md["arg"] - if md_arg == "src": - if first_src: - if "a" in md["properties"]: - src_tags += f":any" - else: - src_tags += f":{md_tag}" - else: - if "a" in md["properties"]: - src_tags += f" --{md_arg[0]}tag=any" - else: - src_tags += f" --{md_arg[0]}tag={md_tag}" - first_src = True - else: - if md_tag != "": - if "a" in md["properties"]: - tags += f" --{md_arg[0]}tag=any" - else: - tags += f" --{md_arg[0]}tag={md_tag}" - return src_tags 
+ tags - - def convert_tags_prelu(mds): - # FIXME: fix benchdnn input template - data_md = [md for md in mds if "data" in md["arg"]][0] - weights_md = [md for md in mds if "wei" in md["arg"]][0] - - data_tag = data_md["tag"] - if "a" in data_md["properties"]: - data_tag = "any" - weights_tag = weights_md["tag"] - if "a" in weights_md["properties"]: - weights_tag = "any" - - return f" --stag={data_tag}:{weights_tag}" - - def convert_tags_rnn(mds): - tags = "--tag=" - with_proj = "" - with_peep = "" - skip_colon = True +class NormalizationMixin: + entry: ir.Entry - # Tags for backward are driven by diff tensors, query them instead of - # forward tensors. Latter will always have `any` format. - has_diff_tensors = False - for md in mds: - if md["arg"].find("diff") != -1: - has_diff_tensors = True + @property + def aux(self): + flags = self.entry.aux.get("flags") + if flags is not None: + return f"--flags={flags}" + return "" - for md in mds: - md_arg = md["arg"] - md_tag = md["tag"] - if has_diff_tensors == True: - if md_arg in ["diff_src_layer", "diff_wei_layer", "diff_dst_layer"]: - if not skip_colon: - tags += f":" - if "a" in md["properties"]: - tags += f"any" - else: - tags += f"{md_tag}" - skip_colon = False - else: - if md_arg in ["src_layer", "wei_layer", "dst_layer"]: - if not skip_colon: - tags += f":" - if "a" in md["properties"]: - tags += f"any" - else: - tags += f"{md_tag}" - skip_colon = False - - if md_arg == "wei_proj" and md_tag != "undef": - with_proj = " --with-projection=true" - if md_arg == "wei_peephole" and md_tag != "undef": - with_peep = " --with-peephole=true" - - return tags + with_proj + with_peep - - def convert_tags_lnorm(mds): - tag = convert_tags_multiple(mds) - stat_md = "" - for md in mds: - if md["arg"] == "stats": - stat_tag = md["tag"] - - return f"{tag} --stat_tag={stat_tag}" - - cvt_tags = { - "batch_normalization": convert_tags_common, - "binary": convert_tags_multiple_src, - "concat": convert_tags_multiple_src, - "convolution": convert_tags_all, - "deconvolution": convert_tags_all, - "eltwise": convert_tags_common, - "inner_product": convert_tags_all, - "group_normalization": convert_tags_multiple, - "layer_normalization": convert_tags_lnorm, - "lrn": convert_tags_common, - "matmul": convert_tags_and_strides, - "pooling": convert_tags_common, - "prelu": convert_tags_prelu, - "reduction": convert_tags_all, - "reorder": convert_tags_and_strides, - "resampling": convert_tags_common, - "rnn": convert_tags_rnn, - "shuffle": convert_tags_common, - "softmax": convert_tags_all, - "sum": convert_tags_multiple_src, - } - convert = cvt_tags.get(prim_kind) - if convert: - return convert(mds) - return "" - - -def convert_flags(mds, prim_kind): - def convert_flags_reorder(mds): - def convert_flag(prefix, md): - flag = "" - flag_fields = md.get("flags") - if flag_fields != None: - cvt = {"s8_comp_mask": "s8s8_comp", "zp_comp_mask": "zp_comp"} - for f in cvt.keys(): - value = flag_fields.get(f) - if value != None: - benchdnn_flag = cvt[f] + ":" + value - if flag == "": - flag = benchdnn_flag - else: - flag += "+" + benchdnn_flag - if flag != "": - return f"--{prefix}flag={flag}" - else: - return "" +class BatchNormalizationConverter(NormalizationMixin, Converter): + driver: str = "bnorm" - flags = "" - # FIXME: fix benchdnn input template - input_md = [md for md in mds if "src" in md["arg"]][0] - output_md = [md for md in mds if "dst" in md["arg"]][0] - iflag = convert_flag("i", input_md) - oflag = convert_flag("o", output_md) +class BinaryConverter(AlgorithmMixin, 
MultiSourceMixin, Converter): + driver: str = "binary" - if iflag != "": - flags += iflag - if oflag != "": - flags += " " + oflag - return flags + @property + def shapes(self): + return self.entry.shapes.split(" ", 1)[0] - def convert_flags_rnn(mds): - for md in mds: - md_arg = md["arg"] - if md_arg == "src_iter" or md_arg == "src_layer": - md_strides = md["strides"] - if md_strides != "": - return f"--trivial-strides=false" - return f"--trivial-strides=true" +class BRGEMMConverter(MultiDataTypeMixin, Converter): + driver: str = "brgemm" - cvt_flags = { - "rnn": convert_flags_rnn, - "reorder": convert_flags_reorder, - } + @property + def aux(self): + bs = self.entry.aux.get("bs", "") + beta = self.entry.aux.get("beta", "") + return f"--bs={bs} --beta={beta}" + + +class ConcatConverter(CommonDataTypeMixin, MultiSourceMixin, Converter): + driver: str = "concat" - convert = cvt_flags.get(prim_kind) - if convert: - return convert(mds) - return "" + @property + def aux(self): + axis = self.entry.aux.get("axis") + if axis is None: + return "" + return f"--axis={axis}" -def extract_attr(attrs, type): - start_idx = attrs.find(type) - if start_idx == -1: +class ConvolutionConverter( + AlgorithmMixin, + TagTripletMixin, + MultiDataTypeWithBiasMixin, + Converter, +): + driver: str = "conv" + + @property + def aux(self): + alg = self._get_alg() + if alg is not None: + return f"--alg={alg}" return "" - start_idx += len(type) + 1 - end_symbol = ";" - if type == "post_ops": - start_idx += 1 - end_symbol = "'" - end_idx = attrs.find(end_symbol, start_idx) - if type == "post_ops": - start_idx -= 1 - end_idx += 1 - return attrs[start_idx:end_idx] - - -def convert_scale_policy(value, prim_kind): - if prim_kind == "reorder": - masks = {0: "common", 1: "per_dim_0", 2: "per_dim_1", 3: "per_dim_01"} - elif prim_kind == "matmul": - masks = { - 0: "common", - 1: "per_oc", - 2: "per_oc", - 3: "per_ocic", - 4: "per_oc", - 6: "per_ocic", - 8: "per_oc", - 12: "per_ocic", - } - else: - masks = {0: "common", 1: "per_oc", 2: "per_oc", 3: "per_oc"} - - mask = masks.get(int(value)) - if mask: - return mask - # this is a workaround for tensors with mask more than 4 - return "per_tensor" - - -def convert_zp_policy(value, prim_kind): - if prim_kind == "matmul": - masks = { - 0: "common", - 2: "per_oc", - 3: "per_ocic", - 4: "per_oc", - 6: "per_ocic", - 12: "per_ocic", - } - else: - masks = {0: "common", 2: "per_dim_1"} - mask = masks.get(int(value)) - if mask: - return mask - # this is a workaround for tensors with mask more than 4 - return "per_tensor" - - -def convert_post_ops(post_ops, prim_kind): - def convert_binary_post_op(post_op): - po = post_op["alg"] + ":" + post_op["dt"] + ":" + post_op["mask"] - if post_op["tag"] != None: - po += ":" + post_op["tag"] - return po - - def convert_dw_post_op(post_op): - po = post_op["alg"] + ":" + post_op["ksp"] + ":" + post_op["dst_dt"] - return po - - def convert_eltwise_post_op(post_op): - benchdnn_p_op = post_op["alg"] - alpha = post_op["alpha"] - beta = post_op["beta"] - scale = post_op["scale"] - if alpha != "1.0": - benchdnn_p_op += ":" + alpha - if beta != "0.0": - benchdnn_p_op += ":" + beta - if alpha != "1.0": - benchdnn_p_op += ":" + scale - return benchdnn_p_op - - def convert_sum_post_op(post_op): - benchdnn_p_op = post_op["alg"] - if post_op["scale"] != 1.0: - benchdnn_p_op += ":" + post_op["scale"] - if post_op["zp"] != 0: - benchdnn_p_op += ":" + post_op["zp"] - if post_op["dt"] != "": - benchdnn_p_op += ":" + post_op["dt"] - return benchdnn_p_op - - def 
convert_prelu_post_op(post_op): - benchdnn_p_op = post_op["alg"] - if post_op["mask"] != 0: - policy = convert_scale_policy(post_op["mask"], prim_kind) - benchdnn_p_op += ":" + policy - return benchdnn_p_op - - convert = { - "binary": convert_binary_post_op, - "dw": convert_dw_post_op, - "eltwise": convert_eltwise_post_op, - "sum": convert_sum_post_op, - "prelu": convert_prelu_post_op, - } - benchdnn_postops = "" - for e in post_ops: - for k in convert.keys(): - if k in e["alg"]: - cvt = convert.get(k) - if benchdnn_postops != "": - benchdnn_postops += "+" - benchdnn_postops += cvt(e) - break - return benchdnn_postops +class DeconvolutionConverter(ConvolutionConverter): + driver: str = "deconv" -def convert_quantization(q_param, prim_kind, def_value, def_type): - res = [] - for arg in q_param.keys(): - p = q_param[arg] - policy = convert_scale_policy(p["mask"], prim_kind) - benchdnn_p = arg + ":" + policy - if policy == "common": - benchdnn_p += ":" + def_value - dt = p["data_type"] - groups = p["groups"] - if dt != def_type or groups != "": - benchdnn_p += ":" + dt - if groups != "": - benchdnn_p += ":" + groups - res.append(benchdnn_p) - return "+".join(res) +class EltwiseConverter(Converter): + driver: str = "eltwise" + @property + def aux(self): + alpha = self.entry.aux.get("alpha") + beta = self.entry.aux.get("beta") + flags = [f"--alg={self._get_alg()}"] + if alpha is not None: + flags.append(f"--alpha={alpha}") + if beta is not None: + flags.append(f"--beta={beta}") + return " ".join(flags) -def convert_scales(scales, prim_kind): - return convert_quantization( - q_param=scales, prim_kind=prim_kind, def_value="0.5", def_type="f32" - ) +class GroupNormalizationConverter( + MultiDataTypeMixin, + BatchNormalizationConverter, +): + driver: str = "gnorm" -def convert_zero_points(zero_points, prim_kind): - return convert_quantization( - q_param=zero_points, prim_kind=prim_kind, def_value="1", def_type="s32" - ) + # --tag=SRC_TAG[:WEI_TAG][:DST_TAG] + @property + def tags(self): + tag_map = {md.arg: maybe_make_any_tag(md) for md in self.entry.mds} + args = "src", "wei", "dst" + tags = [tag_map[arg] for arg in args if arg in tag_map] + return "--tag=" + ":".join(tags) + + +class InnerProductConverter( + TagTripletMixin, MultiDataTypeWithBiasMixin, Converter +): + driver: str = "ip" + + +class LayerNormalizationConverter(GroupNormalizationConverter): + driver: str = "lnorm" + + @property + def dts(self): + dts = super().dts + shift_flag = None + for md in self.entry.mds: + if "scale" in md.arg: + return f"{dts} --ss_dt={md.data_type}".strip() + if "shift" in md.arg and shift_flag is None: + shift_flag = f"--ss_dt={md.data_type}" + if shift_flag is not None: + return f"{dts} {shift_flag}".strip() + return dts -def convert_rounding_mode(rounding_modes, prim_kind): - res = [] - for arg in rounding_modes.keys(): - res.append(arg + ":" + rounding_modes[arg]) - return "+".join(res) + @property + def tags(self): + tags = super().tags + for md in self.entry.mds: + if md.arg == "stats": + tags = f"{tags} --stat_tag={maybe_make_any_tag(md)}" + return tags.strip() -def convert_scratchpad_mode(scratchpad_mode, prim_kind): - return scratchpad_mode +class LRNConverter(AlgorithmMixin, Converter): + driver: str = "lrn" -def convert_fpmath_mode(fpmath_mode, prim_kind): - return fpmath_mode + @property + def aux(self): + alg = self._get_alg() + algs = {"across_channels": "ACROSS", "within_channel": "WITHIN"} + if alg not in algs: + return "" + return f"--alg={algs[alg]}" -def convert_acc_mode(acc_mode, 
prim_kind): - return acc_mode +class MatmulConverter(StridesMixin, MultiDataTypeWithBiasMixin, Converter): + driver: str = "matmul" + @staticmethod + def _get_policies(): + return "common", "per_oc", "per_ocic" -def convert_dropout(dropout, prim_kind): - ret = dropout["p"] - if dropout["seed"] != None: - ret += ":" + dropout["seed"] - if dropout["tag"] != None: - ret += ":" + dropout["tag"] - return ret + @staticmethod + def _get_policy_map(): + return 0, 1, 1, 2, 1, 3, 2, 3, 1, 3, 3, 3, 2 + @property + def bias_mask(self): + for md in self.entry.mds: + if md.arg != "bia": + continue + if "_" in md.flags.value: + mask = md.flags.value.split("_")[1][4:] + return f"--bia_mask={mask}" + return "" -def convert_deterministic(deterministic, prim_kind): - return deterministic + @property + def aux(self): + rt_dim_masks = self.entry.aux.get("runtime_dims_masks", "") + return f"--runtime_dims_masks={rt_dim_masks}" -def convert_attrs(exts, prim_kind): - converters = { - "attr-post-ops": convert_post_ops, - "attr-scales": convert_scales, - "attr-zero-points": convert_zero_points, - "attr-scratchpad": convert_scratchpad_mode, - "attr-fpmath": convert_fpmath_mode, - "attr-acc": convert_acc_mode, - "attr-rounding-mode": convert_rounding_mode, - "attr-dropout": convert_dropout, - "attr-deterministic": convert_deterministic, - } +class PoolingConverter(MultiDataTypeMixin, Converter): + driver: str = "pool" + + @property + def aux(self): + return f"--alg={self._get_alg()}" + + +class PreLUConverter(Converter): + driver: str = "prelu" - benchdnn_attrs = "" - for e in converters.keys(): - attr = exts.get(e) - if attr != None: - if benchdnn_attrs != "": - benchdnn_attrs += " " - benchdnn_attrs += f"--{e}=" + converters[e](attr, prim_kind) - return benchdnn_attrs + @property + def dts(self): + data_dt, wei_dt = "", "" + for md in self.entry.mds: + if "data" in md.arg and not data_dt: + data_dt = md.data_type + if "wei" in md.arg and not wei_dt: + wei_dt = md.data_type + if data_dt and wei_dt: + break + return f"--sdt={data_dt}:{wei_dt}" + + @property + def tags(self): + data_tag, wei_tag = "", "" + for md in self.entry.mds: + if "data" in md.arg and not data_tag: + data_tag = maybe_make_any_tag(md) + if "wei" in md.arg and not wei_tag: + wei_tag = maybe_make_any_tag(md) + if data_tag and wei_tag: + break + return f"--stag={data_tag}:{wei_tag}" + + +class ReductionConverter( + AlgorithmMixin, + TagTripletMixin, + CommonDataTypeMixin, + Converter, +): + driver: str = "reduction" + + @property + def aux(self): + p = self.entry.aux.get("p") + eps = self.entry.aux.get("eps") + args = [f"--alg={self._get_alg()}"] + if p is not None: + args.append(f"--p={p}") + if eps is not None: + args.append(f"--eps={eps}") + return " ".join(args) + + +class ReorderConverter(StridesMixin, CommonDataTypeMixin, Converter): + driver: str = "reorder" + + def _convert_flag(self, prefix, md: ir.MemoryDescriptor): + flags = [] + fields = md.flags + if fields.s8_comp_mask is not None: + flags.append(f"s8s8_comp:{fields.s8_comp_mask}") + if fields.zp_comp_mask is not None: + flags.append(f"zp_comp:{fields.zp_comp_mask}") + if flags: + return f"--{prefix}flag=" + "+".join(flags) + return "" + + @staticmethod + def _get_policies(): + return "common", "per_dim_0", "per_dim_1", "per_dim_01" + @staticmethod + def _get_policy_map(): + return 0, 1, 2, 3 -def convert_shapes(shapes, prim_kind): - if prim_kind == "binary": - shapes = shapes.split(" ")[0] - return f"{shapes}" + @property + def flags(self): + flags = {} + for md in self.entry.mds: + 
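# Reorder has a single input and output, so only the first src-like and
+            # dst-like descriptors are converted; the loop exits once both are set.
+            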
if "src" in md.arg and "src" not in flags: + flags["src"] = self._convert_flag("i", md) + elif "dst" in md.arg and "dst" not in flags: + flags["dst"] = self._convert_flag("o", md) + + if "src" in flags and "dst" in flags: + break + iflag = flags.get("src", "") + oflag = flags.get("dst", "") + return f"{iflag} {oflag}".strip() + + @property + def aux(self): + mask = self.entry.aux.get("runtime-dim-mask") + if mask: + return f"--runtime-dim-mask={mask}" + return "" + + +class ResamplingConverter(AlgorithmMixin, CommonDataTypeMixin, Converter): + driver: str = "resampling" + + +class RNNConverter(AlgorithmMixin, Converter): + driver: str = "rnn" + + @property + def flags(self): + for md in self.entry.mds: + if md.arg not in ("src_iter", "src_layer"): + continue + if md.strides == "": + continue + return "--trivial-strides=false" + return "--trivial-strides=true" + + def _get_flag_from(self, flag_name, flag_values): + flag = self.entry.aux.get(flag_name) + if flag is None or flag not in flag_values: + return "" + return f"--{flag_name}={flag_values[flag]}" + + @property + def aux(self): + algs = { + "vanilla_rnn": "VANILLA_RNN", + "vanilla_lstm": "VANILLA_LSTM", + "vanilla_gru": "VANILLA_GRU", + "vanilla_augru": "VANILLA_AUGRU", + "lbr_gru": "LBR_GRU", + "lbr_augru": "LBR_AUGRU", + } + dirs = { + "unidirectional_left2right": "left2right", + "unidirectional_right2left": "right2left", + "bidirectional_sum": "sum", + "bidirectional_concat": "concat", + } + acts = { + "eltwise_relu": "RELU", + "eltwise_logistic": "LOGISTIC", + "eltwise_tanh": "TANH", + } + all_flags = [ + self._get_flag_from("alg", algs), + self._get_flag_from("direction", dirs), + self._get_flag_from("activation", acts), + ] + flags = self.entry.aux.get("flags") + if flags is not None: + all_flags.append(f"--flags={flags}") + return " ".join(flag for flag in all_flags if flag) + + @property + def dir(self): + dir = self._get_dir() + return f"--prop={dir}" + + @property + def dts(self): + args = ["src_iter", "src_iter_c", "src_layer", "dst_iter", "dst_layer"] + cfg_dts: str + common_dt = True + shared_dt = None + bias_dt = None + md_map: Dict[Optional[str], ir.MemoryDescriptor] = {} + for md in self.entry.mds: + md_map[md.arg] = md + if md.arg == "bias": + bias_dt = md.data_type + elif md.arg in args: + if shared_dt is None: + shared_dt = md.data_type + elif md.data_type != shared_dt: + common_dt = False + if common_dt and shared_dt in ["f32", "f16"]: + cfg_dts = shared_dt + elif common_dt and shared_dt == "bf16": + cfg_dts = shared_dt + # bias is part of cfg for bf16 + if bias_dt is not None and bias_dt != shared_dt: + cfg_dts += bias_dt + else: + cfg_dts = "" + for arg in args: + if arg not in md_map: + continue + md = md_map[arg] + # src iter is skipped if it is f16 + if arg == "src_iter_c" and md.data_type == "f16": + continue + cfg_dts += md.data_type + return f"--cfg={cfg_dts}" + + @property + def tags(self): + # Tags for backward are driven by diff tensors, query them instead of + # forward tensors. Latter will always have `any` format. 
+ has_diff_tensors = False + for md in self.entry.mds: + if "diff" in md.arg: + has_diff_tensors = True + break + + layer_names = ["src_layer", "wei_layer", "dst_layer"] + if has_diff_tensors: + layer_names = [f"diff_{name}" for name in layer_names] + tags = [] + other_flags = [] + for md in self.entry.mds: + arg = md.arg + tag = maybe_make_any_tag(md) + if arg in layer_names: + tags.append(tag) + elif md.tag == "undef": + continue + elif arg == "wei_proj": + other_flags.append("--with-projection=true") + elif arg == "wei_peephole": + other_flags.append("--with-peephole=true") + tag_flag = "--tag=" + ":".join(tags) + return " ".join([tag_flag] + other_flags) + + +class ShuffleConverter(Converter): + driver: str = "shuffle" + + @property + def aux(self): + axis = self.entry.aux.get("axis") + group = self.entry.aux.get("group") + args = [] + if axis is not None: + args.append(f"--axis={axis}") + if group is not None: + args.append(f"--group={group}") + return " ".join(args) + + +class SoftmaxConverter(TagTripletMixin, CommonDataTypeMixin, Converter): + driver: str = "softmax" + + @property + def aux(self): + axis = self.entry.aux.get("axis") + flags = f"--alg={self._get_alg()}" + if axis is not None: + flags += f" --axis={axis}" + return flags + + +class SumConverter(MultiSourceMixin, Converter): + driver: str = "sum" + + +class ZeroPadConverter(Converter): + driver: str = "zeropad" + + @property + def dts(self): + return f"--dt={self.entry.mds[0].data_type}" + + @property + def tags(self): + return f"--tag={maybe_make_any_tag(self.entry.mds[0])}" + + +def get_converter(primitive: str) -> ConverterMeta: + converters: Dict[str, ConverterMeta] = { + "batch_normalization": BatchNormalizationConverter, + "binary": BinaryConverter, + "brgemm": BRGEMMConverter, + "concat": ConcatConverter, + "convolution": ConvolutionConverter, + "deconvolution": DeconvolutionConverter, + "eltwise": EltwiseConverter, + "group_normalization": GroupNormalizationConverter, + "inner_product": InnerProductConverter, + "layer_normalization": LayerNormalizationConverter, + "lrn": LRNConverter, + "matmul": MatmulConverter, + "pooling": PoolingConverter, + "prelu": PreLUConverter, + "reduction": ReductionConverter, + "reorder": ReorderConverter, + "resampling": ResamplingConverter, + "rnn": RNNConverter, + "shuffle": ShuffleConverter, + "softmax": SoftmaxConverter, + "sum": SumConverter, + "zero_pad": ZeroPadConverter, + } + return converters[primitive] class InputGenerator: @@ -846,49 +863,39 @@ class InputGenerator: Generates an input for benchdnn from internal representation. 
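
+    A minimal illustrative use (assuming `entries` is the dict of parsed
+    verbose entries produced by LogParser):
+
+        generator = InputGenerator()
+        batches = generator.generate(entries, split_by_driver=True)
+        for driver, cases in batches.items():
+            print(f"--{driver}", cases)
+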
""" - def __init__(self, writer): - self.__writer = writer + def __init__(self, logger: Optional[logging.Logger] = None): + self.logger = logger + + def _generate_case(self, entry: ir.Entry): + Converter = get_converter(entry.prim_kind) + converter = Converter(entry) + args = [ + "--reset", + "--allow-enum-tags-only=0", + converter.engine, + converter.dir, + converter.aux, + converter.bias_mask, + converter.dts, + converter.tags, + converter.flags, + converter.attrs, + converter.shapes, + ] + return converter.driver, " ".join(arg for arg in args if arg) def generate(self, input, split_by_driver=False): - data = {} - - def generate_case(entry, add_driver=True): - case = "" - if add_driver: - case += "--" + convert_driver(entry["prim_kind"]) - # reset everything, because benchdnn is a state machine and options - # affect all following test cases - case += " --reset" - # allow extended set of tags - case += " --allow-enum-tags-only=0" - - case += " " + convert_engine(entry["engine"]) - # XXX: direction depends on mds (FWD_B is forward + defined bias md) - case += " " + convert_dir(entry) - case += " " + convert_aux(entry) - if entry["prim_kind"] == "matmul": - case += " " + convert_bias_mask(entry["mds"]) - # XXX: data types configuration is not unified across drivers - case += " " + convert_dts(entry["mds"], entry["prim_kind"]) - case += " " + convert_tags(entry["mds"], entry["prim_kind"]) - case += " " + convert_flags(entry["mds"], entry["prim_kind"]) - case += " " + convert_attrs(entry["exts"], entry["prim_kind"]) - case += " " + convert_shapes(entry["shapes"], entry["prim_kind"]) - return case - - if split_by_driver: - for key, value in input.items(): - case = generate_case(value, False) + "\n" - driver_cases = data.get(convert_driver(value["prim_kind"])) - if driver_cases: - data[convert_driver(value["prim_kind"])] += case - else: - data[convert_driver(value["prim_kind"])] = case - else: - for key, value in input.items(): - case = generate_case(value, True) + "\n" - if data.get("all"): - data["all"] += case - else: - data["all"] = case - return data + missing: Set[str] = set() + data: Dict[str, List[str]] = defaultdict(list) + for value in input.values(): + try: + driver, args = self._generate_case(value) + except KeyError as e: + if self.logger is not None and str(e) not in missing: + missing.add(str(e)) + self.logger.warning(f"Missing converter: {e!s}") + continue + if not split_by_driver: + driver, args = "all", f"--{driver} {args}" + data[driver].append(args) + return {k: "\n".join(v) for k, v in data.items()} diff --git a/scripts/verbose_converter/src/breakdown_generator.py b/scripts/verbose_converter/src/breakdown_generator.py index 23ed3ada8d7..772f4c0dcf1 100644 --- a/scripts/verbose_converter/src/breakdown_generator.py +++ b/scripts/verbose_converter/src/breakdown_generator.py @@ -1,5 +1,5 @@ ################################################################################ -# Copyright 2022-2023 Intel Corporation +# Copyright 2022-2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,35 +14,44 @@ # limitations under the License. ################################################################################ +from collections import defaultdict +from typing import Any, Dict, List + +from . 
import ir
+
+
+class Aggregate:
+    def __init__(self):
+        self.occurrences = 0
+        self.time = 0.0
+
+    def add(self, occurrence: float):
+        self.occurrences += 1
+        self.time += occurrence
+
+    def __iter__(self):
+        yield self.occurrences
+        yield self.time
+
 
 class BreakdownGenerator:
     """
     Generates an execution statistics breakdown from the internal representation.
     """
 
-    def __init__(self, writer):
-        self.__writer = writer
+    def __init__(self, _: Any = None):  # Maintain old interface
+        pass
 
-    def generate(self, input, agg_keys):
-        data = {}
-        output = {}
+    def generate(self, input: Dict[int, ir.Entry], agg_keys: List[str]):
+        data: Dict[str, Aggregate] = defaultdict(Aggregate)
         ofs = ","
 
+        if not input:
+            return {}
+
         def key2str(key, value):
             def mds2str(mds):
-                md_fields = [
-                    "arg",
-                    "data_type",
-                    "properties",
-                    "format_kind",
-                    "tag",
-                    "strides",
-                ]
-                ffs = ":"
-                mdfs = " "
-                return mdfs.join(
-                    [ffs.join([arg[field] for field in md_fields]) for arg in mds]
-                )
+                return " ".join(map(str, mds))
 
             def aux2str(aux):
                 auxfs = " "
@@ -56,66 +65,62 @@ def aux2str(aux):
             return str(value)
 
         # Gather occurrences and aggregate time statistics
-        total_time = 0
-        for key, value in input.items():
-            item_key = ofs.join([key2str(k, value[k]) for k in agg_keys])
-            occ, time = data.get(item_key, (0, 0.0))
-            data[item_key] = (occ + 1, time + float(value["time"]))
-            total_time += float(value["time"])
+        total_time: float = 0
+        for value in input.values():
+            item_key = ofs.join(key2str(k, getattr(value, k)) for k in agg_keys)
+            data[item_key].add(value.time)
+            total_time += value.time
 
         # sort keys by decreasing total time
-        sorted_item_keys = sorted(
-            data, key=lambda t: data.__getitem__(t)[1], reverse=True
-        )
-
-        cum_entry = 0
-        cum_time = 0
-        avg_call = 0
-        sorted_avg_call = {}
-        sorted_cum_time = {}
-        for key in sorted_item_keys:
-            cum_entry = cum_entry + 1
-            cum_time = cum_time + data[key][1]
-            avg_call = avg_call + (data[key][0] - avg_call) / cum_entry
+        sorted_keys = sorted(data, key=lambda t: data[t].time, reverse=True)
+
+        cum_entry: int = 0
+        cum_time: float = 0
+        avg_call: float = 0
+        sorted_avg_call: Dict[str, float] = {}
+        sorted_cum_time: Dict[str, float] = {}
+        for key in sorted_keys:
+            item = data[key]
+            cum_entry += 1
+            cum_time = cum_time + item.time
+            avg_call = avg_call + (item.occurrences - avg_call) / cum_entry
             sorted_avg_call[key] = avg_call
             sorted_cum_time[key] = cum_time
 
-        output["all"] = (
-            ofs.join(
-                agg_keys
-                + [
-                    "ncalls",
-                    "time(ms)",
-                    "overall%",
-                    "agg_ncalls(avg)",
-                    "agg_time(ms)",
-                    "agg_overall%",
-                ]
-            )
-            + "\n"
-        )
+        fixed_keys = [
+            "ncalls",
+            "time(ms)",
+            "overall%",
+            "agg_ncalls(avg)",
+            "agg_time(ms)",
+            "agg_overall%",
+        ]
+
+        output = ofs.join(agg_keys + fixed_keys)
 
         def str_num(s):
-            return "{val:.2f}".format(val=s)
+            return f"{s:.2f}"
 
         def str_pct(s):
-            return "{val:.2f}".format(val=s * 100)
-
-        ors = "\n"
-        output["all"] += ors.join(
-            [
-                ofs.join(
-                    [
-                        str(item_key),
-                        str(data[item_key][0]),
-                        str_num(data[item_key][1]),
-                        str_pct(data[item_key][1] / total_time),
-                        str_num(sorted_avg_call[item_key]),
-                        str_num(sorted_cum_time[item_key]),
-                        str_pct(sorted_cum_time[item_key] / total_time),
-                    ]
-                )
-                for item_key in sorted_item_keys
+            return f"{s * 100:.2f}"
+
+        def safe_div(n, d):
+            # Assumption: 0 <= n <= d
+            # If the assumption is broken, we can still raise ZeroDivisionError
+            return 1 if n == d == 0 else n / d
+
+        for key in sorted_keys:
+            item = data[key]
+            avg_call = sorted_avg_call[key]
+            cum_time = sorted_cum_time[key]
+            fields = [
+                str(key),
+                str(item.occurrences),
+                
str_num(item.time), + str_pct(safe_div(item.time, total_time)), + str_num(avg_call), + str_num(cum_time), + str_pct(safe_div(cum_time, total_time)), ] - ) - return output + output += "\n" + ofs.join(fields) + return {"all": output} diff --git a/scripts/verbose_converter/src/dnnl_parser.py b/scripts/verbose_converter/src/dnnl_parser.py index 88bfcba3405..78d24da59bd 100644 --- a/scripts/verbose_converter/src/dnnl_parser.py +++ b/scripts/verbose_converter/src/dnnl_parser.py @@ -13,6 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################ +from typing import Iterable, List, Tuple + +from . import ir, parse + + +class LoggingContext: + def __init__(self, logger): + self.logger = logger + + def __enter__(self): + return self + + def __exit__(self, type, value, _): + if type is not None and issubclass(type, parse.ParseError): + self.logger.warning(str(value)) + return True class LogParser: @@ -21,32 +37,10 @@ class LogParser: representation. """ - def __init__(self, writer, input=""): - # each data entry is a dictionary that consists of: - # engine(str), - # primitive(str), - # implementation(str), - # prop_kind(str), - # aux({field(str) : value(str)}), - # mds( - # { - # arg(str): { - # data_type(str), - # properties(str), - # format_kind(str), - # tag(str), - # strides(str), - # flags(str), - # } - # } - # ) - # shapes(str) - # extensions(str) - # time(float) - self.__raw_data = [] - self.__data = {} - self.__writer = writer - self.__input = input + def __init__(self, logger, input: Iterable[str] = ()): + self.input = input + self.error_handler = LoggingContext(logger) + self.data: List[Tuple[str, ir.Entry]] = [] def process(self, filter_events): """ @@ -61,436 +55,8 @@ def process(self, filter_events): None """ - def convert_primitive(log_entry, template, version): - """ - Converts oneDNN verbose primitive entry into the internal - representation. 
- """ - - def split_arg_dt(arg_dt): - def buffer(dt): - return {"data": dt, "offset": 0} - - def eof(buf): - return buf["offset"] >= len(buf["data"]) - - def get_data(buf): - if eof(buf): - return None - return buf["data"][buf["offset"] :] - - def read_int(buf): - data = get_data(buf) - if not data: - return None - if data[0] not in "123456789": - return None - for n, c in enumerate(data): - if c not in "0123456789": - buf["offset"] += n - return int(data[:n]) - buf["offset"] += len(data) - return int(data) - - def read_literal(buf, literal): - data = get_data(buf) - if not data: - return None - if not data.startswith(literal): - return None - buf["offset"] += len(literal) - return True - - def parse_int_type(dt): - buf = buffer(dt) - if not (read_literal(buf, "u") or read_literal(buf, "s")): - return False - if not read_int(buf): - return False - return eof(buf) - - def parse_float_type(dt): - buf = buffer(dt) - read_literal(buf, "b") # ignore b in bf16 - if not read_literal(buf, "f"): - return False - if not read_int(buf): - return False - if eof(buf): - return True # f16, f32, f64 - if not read_literal(buf, "_e"): - return False - if not read_int(buf): - return False - if not read_literal(buf, "m"): - return False - if not read_int(buf): - return False - return eof(buf) # f8_eXmY - - parts = arg_dt.split("_") - for split in range(1, len(parts)): - input_parts = parts[:split] - dt_parts = parts[split:] - dt = "_".join(dt_parts) - if dt == "undef": - return "_".join(input_parts), dt - if parse_int_type(dt) or parse_float_type(dt): - return "_".join(input_parts), dt - - def convert_mds(log_mds, version): - mds = [] - for md in log_mds.split(" "): - fields = md.split(":") - idx = 0 - - # if version >= 1: - # arg:dt:properties:format_kind:tag:strides:flags - ## ^ - # else: - # arg_dt:properties:format_kind:tag:strides:flags - # (note) Legacy way could have collisions with `arg` and - # `dt` since `_` used as a delimiter and as a part of the - # name. - arg = None - data_type = None - if int(version) >= 1: - arg = fields[idx] - idx += 1 - data_type = fields[idx] - idx += 1 - else: - arg_dt = fields[idx] - idx += 1 - arg, data_type = split_arg_dt(arg_dt) - - properties = fields[idx] - idx += 1 - format_kind = fields[idx] - idx += 1 - tag = fields[idx] - idx += 1 - - # Add compatibility for v3.1 verbose and below, - # when strides delimeter is absent. - # TODO: remove eventually. 
- strides = "" - if "f" not in fields[idx] and format_kind != "undef": - strides = fields[idx] - idx += 1 - - flags = {} - flags["value"] = fields[idx] - idx += 1 - if len(fields) > idx: - flag_fields = fields[idx:] - for f in flag_fields: - if f[:3] == "s8m": - flags["s8_comp_mask"] = f[3:] - if f[:3] == "zpm": - flags["zp_comp_mask"] = f[3:] - - mds.append( - { - "arg": arg, - "data_type": data_type, - "properties": properties, - "format_kind": format_kind, - "tag": tag, - "strides": strides, - "flags": flags, - } - ) - return mds - - def convert_aux(log_aux, version): - aux = {} - if log_aux == "": - return aux - for log_aux_l in log_aux.split(" "): - # Handle strings like NAME:VAL1[:VAL2[:VAL3...]] - res = log_aux_l.split(":") - field = res[0] - value = "" - last_idx = len(res) - 1 - for i in range(1, last_idx): - val_i = res[i] - value += f"{val_i}:" - val_n = res[last_idx] - value += f"{val_n}" - aux[field] = value - return aux - - def convert_prim_kind(prim_kind, version): - return prim_kind - - def convert_exts(exts, version): - def extract_attr(attrs, type): - start_idx = attrs.find(type) - if start_idx == -1: - return "" - - start_idx += len(type) + 1 - end_symbol = " " - end_idx = attrs.find(end_symbol, start_idx) - if end_idx == -1: - end_idx = None - return attrs[start_idx:end_idx] - - def convert_structure_to_ir_seq(ir, value): - params = value.split(":") - fields = list(ir.keys()) - ir.update( - (fields[i], params[i]) - for i in range(0, min(len(params), len(fields))) - ) - return ir - - def convert_post_ops(value): - def convert_binary_post_op(value): - p_op = {"alg": "", "dt": "f32", "mask": "0", "tag": None} - p_op = convert_structure_to_ir_seq(p_op, value) - p_op["prim_kind"] = "binary" - return p_op - - def convert_dw_post_op(value): - p_op = { - "alg": "", - "ksp": "", - "dst_dt": "f32", - "wei_dt": "f32", - "scales": {"mask": "0", "value": None}, - } - params = value.split(":") - len_params = len(params) - p_op["alg"] = params[0] - p_op["ksp"] = params[1] - if len_params > 2: - p_op["dst_dt"] = params[2] - if len_params > 3: - p_op["wei_dt"] = "s8" - p_op["scales"]["mask"] = params[3] - if len_params > 4: - p_op["scales"]["value"] = params[4] - return p_op - - def convert_eltwise_post_op(value): - p_op = { - "alg": "", - "alpha": "1.0", - "beta": "0.0", - "scale": "1.0", - } - return convert_structure_to_ir_seq(p_op, value) - - def convert_sum_post_op(value): - p_op = {"alg": "", "scale": "1.0", "zp": "0", "dt": ""} - return convert_structure_to_ir_seq(p_op, value) - - def convert_prelu_post_op(value): - p_op = {"alg": "", "mask": "0"} - return convert_structure_to_ir_seq(p_op, value) - - convert = { - "binary": convert_binary_post_op, - "dw": convert_dw_post_op, - "eltwise": convert_eltwise_post_op, - "sum": convert_sum_post_op, - "prelu": convert_prelu_post_op, - } - - entries = value.split("+") - postops = [] - for e in entries: - for k in convert.keys(): - if k in e: - cvt = convert.get(k) - postops.append(cvt(e)) - break - return postops - - def convert_scales(value): - res = {} - scales = value.split("+") - for s in scales: - arg = s[: s.find(":")] - s_wo_arg = s[s.find(":") + 1 :] - scale_dict = {"mask": "0", "data_type": "f32", "groups": ""} - res[arg] = convert_structure_to_ir_seq(scale_dict, s_wo_arg) - return res - - def convert_zero_points(value): - res = {} - zp_value = value.split("+") - for zp in zp_value: - arg = zp[: zp.find(":")] - zp_value_wo_arg = zp[zp.find(":") + 1 :] - zp_dict = {"mask": "0", "data_type": "s32", "groups": ""} - res[arg] = 
convert_structure_to_ir_seq(zp_dict, zp_value_wo_arg) - return res - - def convert_rounding_mode(value): - res = {} - rounding_modes = value.split("+") - for r in rounding_modes: - arg = r[: r.find(":")] - res[arg] = r[r.find(":") + 1 :] - return res - - def convert_scratchpad_mode(value): - return value - - def convert_fpmath_mode(value): - return value - - def convert_acc_mode(value): - return value - - def convert_dropout(value): - res = {"p": 0} - elems = value.split(":") - res["p"] = elems[0] - if len(elems) > 1: - res["seed"] = elems[1] - if len(elems) > 2: - res["tag"] = elems[2] - return res - - def convert_deterministic(value): - return value - - converters = { - "attr-post-ops": convert_post_ops, - "attr-scales": convert_scales, - "attr-zero-points": convert_zero_points, - "attr-scratchpad": convert_scratchpad_mode, - "attr-fpmath": convert_fpmath_mode, - "attr-acc": convert_acc_mode, - "attr-rounding-mode": convert_rounding_mode, - "attr-dropout": convert_dropout, - "attr-deterministic": convert_deterministic, - } - attrs = {} - for e in converters.keys(): - attr = extract_attr(exts, e) - if attr != "": - attrs[e] = converters[e](attr) - return attrs - - def convert_pass(v, version): - return v - - convert = { - "prim_kind": convert_prim_kind, - "mds": convert_mds, - "aux": convert_aux, - "exts": convert_exts, - } - - dnnl_to_ir = { - "engine": "engine", - "prim_kind": "primitive", - "impl": "implementation", - "prop_kind": "prop_kind", - "mds": "memory_descriptors", - "exts": "attributes", - "aux": "auxiliary", - "shapes": "problem_desc", - "time": "exec_time", - "timestamp": "timestamp", - } - - ir_req = [ - "engine", - "prim_kind", - "impl", - "prop_kind", - "mds", - "exts", - "aux", - "shapes", - ] - - entry = {} - - t = template.split(",") - for key, value in dnnl_to_ir.items(): - notification_level = "WARN" if key in ir_req else "INFO" - try: - idx = t.index(value) - if idx != -1: - cvt = convert.get(key) - if cvt is None: - cvt = convert_pass - field = log_entry[idx] - try: - entry[key] = cvt(field, version) - except: - self.__writer.print( - f"Parser: parsing entry error: {field}: {value}", - notification_level, - ) - else: - self.__writer.print( - f"Parser: Unknown entry: {value}", notification_level - ) - except: - self.__writer.print( - f"Parser: skipping empty entry: {key}", notification_level - ) - return entry - - # `verbose_template` should have `component` field as second entry, but - # since it gets discarded for compatibility with previous verbose - # outputs, it's not in the final version of the string. - # Restore `component` when the least compatible library version's - # verbose output will contain it. - verbose_template = ( - "onednn_verbose,operation,engine,primitive," - + "implementation,prop_kind,memory_descriptors,attributes," - + "auxiliary,problem_desc" - ) - - i = len(self.__data) - for line in self.__input: - self.__raw_data.append(line.rstrip()) - l_raw = line.split(",") - marker = l_raw[0] - if marker != "onednn_verbose": - continue - - verbose_version = 0 - # Check for version presence, discard 'v' from numerical version, - # and discard version entry for compatibility reasons. - # Note: to compare against `version`, one must use int() function - # call as arg is passed as `str` object! - if l_raw[1][0] == "v" and l_raw[1][1].isdigit(): - verbose_version = l_raw[1].lstrip("v") - l_raw.pop(1) - - # Discard a timestamp when it's supplied in a standalone line. - # TODO: update verbose_template instead. 
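For context, the verbose lines this loop consumes (and that the new parse module below takes over) look roughly like the following; the field values are illustrative, not from a real run:

    line = "onednn_verbose,v1,primitive,exec,cpu,convolution,jit:avx2,forward_training,...,0.0509"
    marker, version, component, event = line.split(",")[:4]
    # marker == "onednn_verbose", version == "v1" (optional),
    # component == "primitive" (optional), event == "exec"

The optional "vN" version, standalone timestamp, and component fields are peeled off before the remaining fields are matched against the template.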
- if l_raw[1].split(".")[0].isdigit(): - l_raw.pop(1) - # Skip Graph component as not supported - if l_raw[1] == "graph": - continue - # Remove a component from the line if presented (see a comment above) - if l_raw[1] == "primitive" or l_raw[1] == "ukernel": - l_raw.pop(1) - - event = l_raw[1].split(":")[0] - if event == "info": - opt = l_raw[2] - if opt.split(":")[0] == "template": - verbose_template = "onednn_verbose," + line.split(":")[1] - if event in filter_events: - l_converted = convert_primitive( - l_raw, verbose_template + ",exec_time", verbose_version - ) - if l_converted: - self.__data[i] = l_converted - i = i + 1 + parser = parse.Parser(self.input, filter_events, self.error_handler) + self.data = list(parser) def get_data(self): """ @@ -505,7 +71,7 @@ def get_data(self): data """ - return self.__data + return {i: entry for i, (_, entry) in enumerate(self.data)} def dump(self, converted=False): """ @@ -513,18 +79,16 @@ def dump(self, converted=False): Parameters ---------- - converted (default: False) -- If True dump() prints data in internal - represenataion, otherwise prints data in the original form. + converted (default: False) -- If truthy, prints data in internal + representation, otherwise prints data in the original form. Returns ------- None """ - if converted: - [ - self.__writer.print(f"{key}, {value}", "STDIO") - for key, value in self.__data.items() - ] - else: - [self.__writer.print(d, "STDIO") for d in self.__raw_data] + for i, (line, entry) in enumerate(self.data): + if converted: + print(f"{i}, {entry!r}") + else: + print(line) diff --git a/scripts/verbose_converter/src/ir.py b/scripts/verbose_converter/src/ir.py new file mode 100644 index 00000000000..8c399ce084a --- /dev/null +++ b/scripts/verbose_converter/src/ir.py @@ -0,0 +1,436 @@ +################################################################################ +# Copyright 2024-2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import enum +import string +from abc import abstractmethod +from collections.abc import MutableMapping +from dataclasses import MISSING, dataclass, fields +from typing import Dict, List, Optional, Union + + +def alias(attr): + def getter(self): + return getattr(self, attr) + + def setter(self, value): + return setattr(self, attr, value) + + def deleter(self): + return delattr(self, attr) + + return property(getter, setter, deleter, attr) + + +def hash_str(obj): + return getattr(obj.__class__, "__hash_str__", str)(obj) + + +@dataclass(eq=False) +class Mapping(MutableMapping): + def __getitem__(self, item): + try: + value = getattr(self, item) + if isinstance(value, int): + value = str(value) + elif isinstance(value, float): + value = str(value) + # The verbose converter assumes defaults are 1.0, whereas + # oneDNN assumes defaults are 0.0. 
This is a workaround so that
+                # these values are not accidentally dropped: they are kept
+                # as 0 or 1 and always passed through to the benchdnn
+                # reproducer.
+                if value[-2:] == ".0":
+                    value = value[:-2]
+            return value
+        except AttributeError:
+            raise KeyError(item)
+
+    def __setitem__(self, item, value):
+        setattr(self, item, value)
+
+    def __delitem__(self, item):
+        delattr(self, item)
+
+    def __len__(self):
+        return len(fields(self))
+
+    def __iter__(self):
+        for field in fields(self):
+            yield field.name
+
+    def __hash__(self):
+        return hash(hash_str(self))
+
+    def __eq__(self, other):
+        if not isinstance(other, self.__class__):
+            return False
+        return hash_str(self) == hash_str(other)
+
+    def __str__(self):
+        raise NotImplementedError
+
+    def __hash_str__(self):
+        return str(self)
+
+    def __repr__(self):
+        child_reprs = []
+        for key, value in self.items():
+            child_reprs.append(f"{key!r}: {value!r}")
+        return "{" + ", ".join(child_reprs) + "}"
+
+
+@dataclass(eq=False)
+class MemoryDescriptor(Mapping):
+    @dataclass(eq=False)
+    class Flags(Mapping):
+        value: str
+        s8_comp_mask: Optional[str] = None
+        zp_comp_mask: Optional[str] = None
+        scale_adjust: float = 1.0
+
+        def __str__(self):
+            my_str = self.value
+            if self.s8_comp_mask is not None:
+                my_str += f":s8m{self.s8_comp_mask}"
+            if self.zp_comp_mask is not None:
+                my_str += f":zpm{self.zp_comp_mask}"
+            if self.scale_adjust != 1.0:
+                my_str += f":sa{self.scale_adjust}"
+            return my_str
+
+    arg: str
+    data_type: str
+    properties: str
+    format_kind: str
+    tag: str
+    flags: Flags
+    strides: str = ""  # Pre-v3.1 does not have strides
+
+    padding = alias("properties")
+
+    def __len__(self):
+        return 1 + super().__len__()
+
+    def __iter__(self):
+        yield from super().__iter__()
+        yield "padding"
+
+    def _format(self, tag: str, convert) -> str:
+        header = f"{self.arg}:{self.data_type}"
+        return ":".join(
+            [
+                header,
+                self.properties,
+                self.format_kind,
+                tag,
+                self.strides,
+                convert(self.flags),
+            ]
+        )
+
+    def __str__(self):
+        return self._format(self.tag, str)
+
+    def __hash_str__(self):
+        tag = self.tag
+        if "a" not in self.properties:
+            return self._format(tag, hash_str)
+        for i, c in enumerate(tag):
+            if not c.isalpha():
+                return self._format(string.ascii_lowercase[:i], hash_str)
+        return self._format(string.ascii_lowercase[: len(tag)], hash_str)
+
+
+@dataclass(eq=False)
+class Dropout(Mapping):
+    tag: Optional[str] = None
+
+    def __str__(self):
+        return self.tag or ""
+
+
+class FormattedMapping(Mapping):
+    @abstractmethod
+    def _format(self, _) -> str:
+        raise NotImplementedError
+
+    def __str__(self):
+        return self._format(str)
+
+    def __hash_str__(self):
+        return self._format(hash_str)
+
+
+@dataclass(eq=False)
+class PostOp(FormattedMapping):
+    alg: str
+
+    def _format(self, convert):
+        required_args = []
+        optional_args = []
+        seen_non_default = False
+        for field in reversed(fields(self)):
+            if field.name == "alg":
+                continue
+            value = getattr(self, field.name)
+            if field.default is MISSING:
+                required_args.append(value)
+                continue
+            if not seen_non_default and value == field.default:
+                continue
+            seen_non_default = True
+            optional_args.append(value)
+        args = [self.alg] + required_args[::-1] + optional_args[::-1]
+        return ":".join(map(convert, args))
+
+
+@dataclass(eq=False)
+class SumPostOp(PostOp):
+    alg: str = "sum"
+    scale: float = 1.0
+    zp: int = 0
+    dt: str = ""
+
+
+@dataclass(eq=False)
+class DepthwiseScales(Mapping):
+    mask: int = 0
+    value: Optional[str] = None
+
+    def __str__(self):
+        if self.value is not
None: + return f"{self.mask}:{self.value}" + if self.mask != 0: + return str(self.mask) + return "" + + +@dataclass(eq=False) +class KSPMixin: + ksp: str + + +@dataclass(eq=False) +class DepthwisePostOp(PostOp, KSPMixin): + alg: str = "dw" + dst_dt: str = "f32" + wei_dt: str = "f32" + scales: DepthwiseScales = DepthwiseScales() + + def __len__(self): + return 1 + super().__len__() + + def __iter__(self): + yield "alg" + yield from super().__iter__() + + +@dataclass(eq=False) +class PreLUPostOp(PostOp): + alg: str = "prelu" + mask: int = 0 + has_scaleshift: bool = False + + def __getitem__(self, item): + if item == "has_scaleshift": + return "true" if self.has_scaleshift else "" + return super().__getitem__(item) + + def __str__(self): + if self.has_scaleshift: + return f"{self.alg}:{self.mask}:true" + return f"{self.alg}:{self.mask}" + + +@dataclass(eq=False) +class EltwisePostOp(PostOp): + alpha: float = 0.0 + beta: float = 0.0 + scale: float = 1.0 + + +@dataclass(eq=False) +class BinaryPostOp(PostOp): + dt: str + mask: int = 0 + tag: str = "any" + + +@dataclass(eq=False) +class QuantizationParam(Mapping): + value: float + data_type: str + mask: int = 0 + groups: str = "" + + def __str__(self): + if self.groups: + return f"{self.mask}:{self.data_type}:{self.groups}" + return f"{self.mask}:{self.data_type}" + + +@dataclass(eq=False) +class Scale(QuantizationParam): + value: float = 1.0 + data_type: str = "f32" + + +@dataclass(eq=False) +class ZeroPoint(QuantizationParam): + value: int = 0 + data_type: str = "s32" + + +class CompositeAttribute: + def __str__(self): + raise NotImplementedError + + +@dataclass(eq=False) +class FPMathMode(CompositeAttribute): + mode: str + apply_to_int: bool = False + + def __str__(self): + a2i_str = ":true" if self.apply_to_int else "" + return self.mode + a2i_str + + +class RoundingMode(CompositeAttribute, enum.Enum): + ENVIRONMENT = "environment" + STOCHASTIC = "stochastic" + + def __str__(self): + return self.value + + +Attribute = Union[ + str, # acc-mode, etc + FPMathMode, + Dropout, + List[PostOp], + Dict[str, Scale], + Dict[str, ZeroPoint], + Dict[str, RoundingMode], + Scale, # oscale +] + + +@dataclass(eq=False) +class Attributes(FormattedMapping): + acc_mode: Optional[str] = None + deterministic: Optional[str] = None + dropout: Optional[Dropout] = None + fpmath: Optional[FPMathMode] = None + oscale: Optional[Scale] = None + post_ops: Optional[List[PostOp]] = None + rounding_mode: Optional[Dict[str, RoundingMode]] = None + scales: Optional[Dict[str, Scale]] = None + scratchpad: Optional[str] = None + zero_points: Optional[Dict[str, ZeroPoint]] = None + + acc = alias("acc_mode") + + @staticmethod + def _field_name_to_attr_name(field_name: str): + return "attr-" + field_name.replace("_", "-") + + def _attr_name_to_field_name(self, item: str): + original_item = item + for field in fields(self): + if item == self._field_name_to_attr_name(field.name): + return field.name + raise KeyError(original_item) + + def __getitem__(self, item: str): + value = getattr(self, self._attr_name_to_field_name(item)) + if value is None: + raise KeyError(item) + return value + + def __setitem__(self, item: str, value: Attribute): + return setattr(self, self._attr_name_to_field_name(item), value) + + def __delitem__(self, item: str): + setattr(self, self._attr_name_to_field_name(item), None) + + def __iter__(self): + for field in fields(self): + if getattr(self, field.name) is not None: + yield self._field_name_to_attr_name(field.name) + + def __len__(self): + return 
len(list(iter(self))) + + def _format(self, convert): + parts = [] + for key, attr in self.items(): + if isinstance(attr, list): + sub_parts = "+".join(map(convert, attr)) + parts.append(f"{key}:{sub_parts}") + elif isinstance(attr, dict): + converted = (f"{k}:{convert(v)}" for k, v in attr.items()) + combined = "+".join(converted) + parts.append(f"{key}:{combined}") + else: + parts.append(f"{key}:{convert(attr)}") + return " ".join(parts) + + +@dataclass(eq=False) +class HashableEntry(FormattedMapping): + operation: str + engine: str + prim_kind: str + impl: str + prop_kind: str + aux: Dict[str, str] + mds: List[MemoryDescriptor] + shapes: str + exts: Attributes + + def _format(self, convert): + parts = [ + self.operation, + self.engine, + self.prim_kind, + self.impl, + self.prop_kind, + " ".join(map(convert, self.mds)), + convert(self.exts), + " ".join(f"{k}:{convert(v)}" for k, v in self.aux.items()), + self.shapes, + ] + return ",".join(parts) + + def __str__(self): + return f"onednn_verbose,v1,primitive,{super().__str__()},0" + + +class Entry(HashableEntry): + def __init__( + self, + *, + time=0.0, + timestamp: Optional[float] = None, + version: int = 0, + **kwargs, + ): + self.time = time + self.timestamp = timestamp + self.version = version + super().__init__(**kwargs) diff --git a/scripts/verbose_converter/src/parse.py b/scripts/verbose_converter/src/parse.py new file mode 100644 index 00000000000..0b20d98bb8e --- /dev/null +++ b/scripts/verbose_converter/src/parse.py @@ -0,0 +1,659 @@ +################################################################################ +# Copyright 2024-2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import string +from contextlib import nullcontext +from typing import ( + ContextManager, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, +) + +from . 
import ir + +__all__ = ["Parser"] + + +class ParseSpec: + digits = list(string.digits) + + def __init__(self, buf: str): + self._buf = buf + self.offset = 0 + + def __str__(self): + return self.buf + + @property + def buf(self): + return self._buf[self.offset :] + + @property + def eof(self): + return self.offset >= len(self._buf) + + def peek(self, n=1): + return self.buf[:n] + + def seek(self, n=1): + self._read(n) + + def _read(self, n: int) -> str: + token = self._buf[self.offset : self.offset + n] + self.offset += n + return token + + def _find_str(self) -> int: + buf = ParseSpec(self.buf) + while not buf.eof and buf.peek() not in ("+", ":"): + buf.seek() + return buf.offset + + def _find_uint(self) -> int: + buf = ParseSpec(self.buf) + if buf.eof or buf.peek() not in self.digits: + return 0 + + if not buf.read_literal("0"): + while buf.read_one_of(*self.digits): + pass + return buf.offset + + def _find_int(self) -> int: + buf = ParseSpec(self.buf) + buf.read_one_of("-", "+") + return buf.offset + buf._find_uint() + + def _find_float(self) -> int: + buf = ParseSpec(self.buf) + buf.read_one_of("-", "+") + if buf.eof or buf.peek() not in ["."] + self.digits: + return 0 # ignore [+/-][e...] + if not buf.read_literal("0"): + while buf.read_one_of(*self.digits): + pass + # else: we already read a 0. + if buf.read_literal("."): + while buf.read_one_of(*self.digits): + pass + if buf.read_literal("e"): + buf.read_one_of("-", "+") + if not buf.read_one_of(*self.digits): + return 0 # ignore [+/-][X][.Y]e[+/-] + while buf.read_one_of(*self.digits): + pass + return buf.offset + + def _find_literal(self, literal): + if self.buf.startswith(literal): + return len(literal) + return 0 + + def read_str(self) -> str: + return self._read(self._find_str()) + + def read_literal(self, literal: str) -> Optional[str]: + offset = self._find_literal(literal) + if offset == len(literal): + return self._read(offset) + return None + + def read_one_of(self, *literals: str) -> Optional[str]: + for literal in literals: + if self.read_literal(literal) is not None: + return literal + return None + + def read_uint(self) -> Optional[int]: + offset = self._find_uint() + if offset: + return int(self._read(offset)) + return None + + def read_int(self) -> Optional[int]: + offset = self._find_int() + if offset: + return int(self._read(offset)) + return None + + def read_float(self) -> Optional[float]: + offset = self._find_float() + if offset: + return float(self._read(offset)) + return None + + +class ParseError(ValueError): + pass + + +class InvalidEntryError(ParseError): + pass + + +class ParserImpl: + default_template = ( + "operation,engine,primitive,implementation,prop_kind," + + "memory_descriptors,attributes,auxiliary,problem_desc,exec_time" + ) + _version_map: Dict[int, type] = {} + + @staticmethod + def parse_aux(aux: str): + parsed: Dict[str, str] = {} + if aux == "": + return parsed + for aux_l in aux.split(): + # Handle strings like NAME:VAL1[:VAL2[:VAL3...]] + field, *values = aux_l.split(":", 1) + parsed[field] = values[0] if values else "" + return parsed + + def parse_mds(self, descriptors): + try: + return list(map(self.parse_md, descriptors.split())) + except ValueError: + raise ValueError(f"Could not parse mds {descriptors}") + + @staticmethod + def is_bit_layout(dt): + buf = ParseSpec(dt) + if not buf.read_literal("e"): + return False + if buf.read_uint() is None: + return False + if not buf.read_literal("m"): + return False + if buf.read_uint() is None: + return False + return buf.eof # eXmY + + def 
is_float_type(self, dt): + buf = ParseSpec(dt) + buf.read_literal("b") # ignore b in bf16 + if not buf.read_literal("f"): + return False + if buf.read_uint() is None: + return False + if buf.eof: + return True # bf16, f16, f32, f64 + if not buf.read_literal("_"): + return False + return self.is_bit_layout(buf.buf) # fZ_eXmY + + @staticmethod + def is_int_type(dt): + buf = ParseSpec(dt) + if not buf.read_one_of("u", "s"): + return False + if buf.read_uint() is None: + return False + return buf.eof + + def is_data_type(self, dt): + return ( + dt == "undef" + or self.is_int_type(dt) + or self.is_float_type(dt) + or self.is_bit_layout(dt) + ) + + @staticmethod + def parse_md_flags(flags, fields): + flags = ir.MemoryDescriptor.Flags(value=flags or "f0") + for field in fields: + if field[:3] == "s8m": + flags.s8_comp_mask = field[3:] + elif field[:3] == "zpm": + flags.zp_comp_mask = field[3:] + elif field[:2] == "sa": + flags.scale_adjust = float(field[2:]) + return flags + + def parse_md(self, descriptor): + fields = descriptor.split(":") + arg_dt, properties, format_kind, tag = fields[:4] + arg_dt_parts = arg_dt.split("_") + for i in range(1, len(arg_dt_parts)): + arg = "_".join(arg_dt_parts[:i]) + dt = "_".join(arg_dt_parts[i:]) + if self.is_data_type(dt): + break + else: + if len(arg_dt_parts) != 1 or not self.is_data_type(arg_dt): + raise ParseError( + f"Could not parse memory descriptor {descriptor}" + ) + arg, dt = "data", arg_dt + + strides = "" + if "f" not in fields[4] and format_kind != "undef": + strides = fields[4] + flags = self.parse_md_flags(fields[5], fields[6:]) + else: + flags = self.parse_md_flags(fields[4], fields[5:]) + return ir.MemoryDescriptor( + arg=arg, + data_type=dt, + properties=properties, + format_kind=format_kind, + tag=tag, + strides=strides, + flags=flags, + ) + + def parse_attrs(self, attrs): + exts = ir.Attributes() + for attr in attrs.split(): + spec = ParseSpec(attr) + name, args = spec.read_str(), "" + if spec.read_literal(":"): + args = spec.buf + if name in ("attr-acc-mode", "attr-acc"): + exts.acc_mode = self.parse_acc_mode(args) + elif name == "attr-deterministic": + exts.deterministic = self.parse_deterministic(args) + elif name == "attr-dropout": + exts.dropout = self.parse_dropout(args) + elif name == "attr-fpmath": + exts.fpmath = self.parse_fpmath_mode(args) + # Kept for compatibility with v2.7 and below. 
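In effect, parse_attrs receives a space-separated list of attribute entries and splits each on its first ":" into a name and an argument string before dispatching on the name; for example (values made up):

    attrs = "attr-scratchpad:user attr-post-ops:eltwise_relu+sum:2"
    for attr in attrs.split():
        name, _, args = attr.partition(":")
        # ("attr-scratchpad", "user"), then ("attr-post-ops", "eltwise_relu+sum:2")

The "attr-oscale" branch that follows covers the single output scale emitted by v2.7 and older logs in place of the per-argument "attr-scales".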
+            elif name == "attr-oscale":
+                exts.oscale = self.parse_oscale(args)
+            elif name == "attr-post-ops":
+                exts.post_ops = self.parse_post_ops(args)
+            elif name == "attr-rounding-mode":
+                exts.rounding_mode = self.parse_rounding_modes(args)
+            elif name == "attr-scales":
+                exts.scales = self.parse_scales(args)
+            elif name == "attr-scratchpad":
+                exts.scratchpad = self.parse_scratchpad_mode(args)
+            elif name == "attr-zero-points":
+                exts.zero_points = self.parse_zero_points(args)
+        return exts
+
+    def parse_post_ops(self, post_ops: str):
+        spec = ParseSpec(post_ops)
+        parsed: List[ir.PostOp] = []
+        while True:
+            alg = spec.read_str()
+            if alg == "sum":
+                parsed.append(self.parse_sum_post_op(spec))
+            elif alg == "dw":
+                parsed.append(self.parse_dw_post_op(spec))
+            elif alg == "prelu":
+                parsed.append(self.parse_prelu_post_op(spec))
+            elif alg.startswith("eltwise_"):
+                parsed.append(self.parse_eltwise_post_op(spec, alg))
+            elif alg.startswith("binary_"):
+                parsed.append(self.parse_binary_post_op(spec, alg))
+            else:
+                raise ParseError(f"Unexpected post-op: {alg}")
+            if not spec.read_literal("+"):
+                break
+        return parsed
+
+    @staticmethod
+    def parse_sum_post_op(spec) -> ir.SumPostOp:
+        post_op = ir.SumPostOp()
+        if spec.read_literal(":"):
+            post_op.scale = spec.read_float()
+            if spec.read_literal(":"):
+                post_op.zp = spec.read_int()
+                if spec.read_literal(":"):
+                    post_op.dt = spec.read_str()
+        return post_op
+
+    @staticmethod
+    def parse_dw_post_op(spec) -> ir.DepthwisePostOp:
+        if not spec.read_literal(":"):
+            raise ParseError("Expected argument for depthwise post-op")
+        ksp = spec.read_str()
+        post_op = ir.DepthwisePostOp(ksp=ksp)
+        # Give each entry its own scales object; mutating the shared
+        # class-level default instance would leak state across entries.
+        post_op.scales = ir.DepthwiseScales()
+        if spec.read_literal(":"):
+            post_op.dst_dt = spec.read_str()
+            if spec.read_literal(":"):
+                post_op.wei_dt = "s8"
+                post_op.scales.mask = spec.read_uint()
+                if spec.read_literal(":"):
+                    post_op.scales.value = spec.read_str()
+        return post_op
+
+    @staticmethod
+    def parse_prelu_post_op(spec) -> ir.PreLUPostOp:
+        post_op = ir.PreLUPostOp()
+        if spec.read_literal(":"):
+            post_op.mask = spec.read_uint()
+            if spec.read_literal(":"):
+                post_op.has_scaleshift = spec.read_str() == "true"
+        return post_op
+
+    @staticmethod
+    def parse_eltwise_post_op(spec, alg) -> ir.EltwisePostOp:
+        post_op = ir.EltwisePostOp(alg=alg)
+        if spec.read_literal(":"):
+            post_op.alpha = spec.read_float()
+            if spec.read_literal(":"):
+                post_op.beta = spec.read_float()
+                if spec.read_literal(":"):
+                    post_op.scale = spec.read_float()
+        return post_op
+
+    @staticmethod
+    def parse_binary_post_op(spec, alg) -> ir.BinaryPostOp:
+        if not spec.read_literal(":"):
+            raise ParseError("Expected data type for binary post-op")
+        dt = spec.read_str()
+        post_op = ir.BinaryPostOp(alg=alg, dt=dt)
+        if spec.read_literal(":"):
+            post_op.mask = spec.read_uint()
+            if spec.read_literal(":"):
+                post_op.tag = spec.read_str()
+        return post_op
+
+    @staticmethod
+    def parse_dropout(args: str) -> ir.Dropout:
+        return ir.Dropout(tag=args if args else None)
+
+    @staticmethod
+    def parse_per_argument(attr, name, parse):
+        spec = ParseSpec(attr)
+        parsed = {}
+        while True:
+            arg = spec.read_str()
+            if not spec.read_literal(":"):
+                raise ParseError(f"Expected mask for {arg} {name}")
+            parsed[arg] = parse(spec)
+            if not spec.read_literal("+"):
+                break
+        return parsed
+
+    def parse_scales(self, scales: str):
+        return self.parse_per_argument(scales, "scale", self.parse_scale)
+
+    @staticmethod
+    def parse_quantization_param(spec, read_value, param_type):
+        # Old style: mask[:[value[*]|*]]
+        # New style: mask[:data_type[:groups]]
+
param = param_type() + param.mask = spec.read_uint() + if spec.read_literal(":"): + value = read_value() + if value is not None: + param.value = value + spec.read_literal("*") + elif spec.read_literal("*"): + pass + elif not spec.eof: # new style + param.data_type = spec.read_str() + if spec.read_literal(":"): + param.groups = spec.read_str() + return param + + # v2.7 and below + def parse_oscale(self, oscale: str): + spec = ParseSpec(oscale) + return self.parse_scale(spec) + + def parse_scale(self, spec) -> ir.Scale: + return self.parse_quantization_param(spec, spec.read_float, ir.Scale) + + def parse_zero_points(self, zps: str): + return self.parse_per_argument(zps, "zero point", self.parse_zero_point) + + def parse_zero_point(self, spec) -> ir.ZeroPoint: + return self.parse_quantization_param(spec, spec.read_int, ir.ZeroPoint) + + @staticmethod + def parse_fpmath_mode(mathmode: str) -> ir.FPMathMode: + spec = ParseSpec(mathmode) + mode = spec.read_str() + apply_to_int = False + if spec.read_literal(":"): + apply_to_int = spec.read_str() == "true" + return ir.FPMathMode(mode=mode, apply_to_int=apply_to_int) + + @staticmethod + def parse_rounding_mode(rounding_mode: str) -> ir.RoundingMode: + rm = rounding_mode.lower() + for member in ir.RoundingMode.__members__.values(): + if str(member) == rm: + return member + else: + raise ValueError(f"Invalid rounding mode {rounding_mode}") + + def parse_rounding_modes(self, rounding_modes: str): + spec = ParseSpec(rounding_modes) + modes: Dict[str, ir.RoundingMode] = {} + while True: + arg = spec.read_str() + if not spec.read_literal(":"): + raise ValueError("Expected rounding mode") + mode = self.parse_rounding_mode(spec.read_str()) + modes[arg] = mode + if not spec.read_literal("+"): + break + return modes + + identity = staticmethod(lambda x: x) + + # Additional attributes + parse_acc_mode = identity + parse_deterministic = identity + parse_scratchpad_mode = identity + + # Additional template components + parse_operation = identity + parse_prim_kind = identity + parse_prop_kind = identity + parse_engine = identity + parse_impl = identity + parse_shapes = identity + parse_time = staticmethod(float) + parse_timestamp = staticmethod(float) + + def dnnl_to_ir(self): + return { + "operation": ("operation", self.parse_operation, True), + "engine": ("engine", self.parse_engine, True), + "primitive": ("prim_kind", self.parse_prim_kind, True), + "implementation": ("impl", self.parse_impl, True), + "prop_kind": ("prop_kind", self.parse_prop_kind, True), + "memory_descriptors": ("mds", self.parse_mds, True), + "attributes": ("exts", self.parse_attrs, True), + "auxiliary": ("aux", self.parse_aux, True), + "problem_desc": ("shapes", self.parse_shapes, True), + "exec_time": ("time", self.parse_time, False), + "timestamp": ("timestamp", self.parse_timestamp, False), + } + + def parse(self, line: str, template: Optional[str]): + if template is None: + template = self.default_template + entry = {} + fields = template.rstrip().split(",") + values = line.rstrip().split(",") + mapping = self.dnnl_to_ir() + min_fields = sum((mapping[field][2] for field in fields)) + max_fields = len(fields) + if len(values) < min_fields: + raise InvalidEntryError("parse error: too few fields to parse") + if len(values) > max_fields: + raise InvalidEntryError("parse error: too many fields to parse") + mapped = dict(zip(fields, values)) + for field, (key, parse, reqd) in mapping.items(): + if field not in mapped: + if not reqd: + continue + raise InvalidEntryError(f"parse error: 
missing {field} field") + value = mapped[field] + try: + entry[key] = parse(value) + except (ParseError, ValueError) as e: + raise ParseError(f"parse error: {field}: {value} ({e!s})") + return entry + + +def register(*, version: int): + def registrar(impl: type): + ParserImpl._version_map[version] = impl + return impl + + return registrar + + +@register(version=0) +class LegacyParserImpl(ParserImpl): + pass + + +@register(version=1) +class V1ParserImpl(ParserImpl): + def parse_md(self, descriptor): + fields = descriptor.split(":") + return ir.MemoryDescriptor( + arg=fields[0], + data_type=fields[1], + properties=fields[2], + format_kind=fields[3], + tag=fields[4], + strides=fields[5], + flags=self.parse_md_flags(fields[6], fields[7:]), + ) + + +class Parser: + _parser_impls: Dict[int, ParserImpl] = {} + _default_events = "exec", "create", "create_nested" + + def __init__( + self, + input: Iterable[str], + events: Iterable[str] = _default_events, + error_handler: ContextManager = nullcontext(), + ): + self.input = input + self.events = set(events) + self.error_handler = error_handler + + def _fix_template(self, template) -> Optional[str]: + return template + + @staticmethod + def _parse_leading_fields(input: Iterable[str]): + MARKER = "onednn_verbose" + for line in map(str.rstrip, input): + if not line.startswith(f"{MARKER},"): + continue + try: + _, operation, args = line.split(",", 2) + except ValueError: + continue + version = 0 + if operation.startswith("v"): + try: + version = int(operation[1:]) + except ValueError: + pass + else: + operation, args = args.split(",", 1) + timestamp = None + try: + timestamp = float(operation) + except ValueError: + pass + else: + operation, args = args.split(",", 1) + component = "primitive" + if operation in ("graph", "primitive", "ukernel"): + component = operation + operation, args = args.split(",", 1) + yield line, version, timestamp, component, operation, args + + def __iter__(self) -> Iterator[Tuple[str, ir.Entry]]: + template = None + cache: Dict[str, dict] = {} + errors: Set[str] = set() + parsed = self._parse_leading_fields(self.input) + for line, version, timestamp, component, operation, args in parsed: + if component == "graph": + continue + event = operation.split(":", 1)[0] + if event == "info": + for marker in ("template", "prim_template"): + if not args.startswith(f"{marker}:"): + continue + fixed_template = self._fix_template(args[len(marker) + 1 :]) + if fixed_template is not None: + break + else: + continue + first_component, rest = fixed_template.split(",", 1) + # Timestamp is usually out of order with respect to the + # template because of missing component for "graph", + # "primitive", "ukernel", etc. 
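The template line being repaired here is announced by the library itself, roughly as "onednn_verbose,info,template:..." (field order illustrative):

    template = "timestamp,operation,engine,primitive,..."
    first_component, rest = template.split(",", 1)
    # first_component == "timestamp"; keeping only `rest` aligns the
    # template with payloads whose timestamp and component fields were
    # already consumed by _parse_leading_fields.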
+ if first_component == "timestamp": + fixed_template = rest + if template != fixed_template: + template = fixed_template + cache.clear() + continue + if event not in self.events: + continue + leading_args, last_arg = args.rsplit(",", 1) + try: + time = float(last_arg) + except ValueError: + time = 0.0 + leading_args = args + key = f"v{version},{component},{operation},{leading_args}" + if key in errors: + continue + success = False + with self.error_handler: + if key in cache: + params = dict(cache[key]) + params.update(time=time, timestamp=timestamp) + else: + new_line = f"{operation},{args}" + params = self.parse(new_line, template, version) + cache[key] = dict(params) + if timestamp is not None: + params.update(timestamp=timestamp) + yield line, ir.Entry(version=version, **params) + success = True + if not success: + errors.add(key) + + def items(self) -> Iterable[Tuple[int, Tuple[str, ir.Entry]]]: + yield from enumerate(self) + + @staticmethod + def _get_impl(version: int = 0) -> ParserImpl: + if version not in Parser._parser_impls: + if version not in ParserImpl._version_map: + raise ParseError(f"No parsers registered for version {version}") + Parser._parser_impls[version] = ParserImpl._version_map[version]() + return Parser._parser_impls[version] + + def parse(self, line: str, template: Optional[str], version: int = 0): + impl = self._get_impl(version) + return impl.parse(line, template) diff --git a/scripts/verbose_converter/src/utils.py b/scripts/verbose_converter/src/utils.py index 69cdcf7161c..39cc525a3d9 100644 --- a/scripts/verbose_converter/src/utils.py +++ b/scripts/verbose_converter/src/utils.py @@ -1,5 +1,5 @@ ################################################################################ -# Copyright 2021-2023 Intel Corporation +# Copyright 2021-2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,19 +14,39 @@ # limitations under the License. 
################################################################################ +import functools import sys -status = {"SUCCESS": 0, "FAILED": 1} + +@functools.total_ordering +class Version: + def __init__(self, major: int, minor: int, patch: int): + self.major = major + self.minor = minor + self.patch = patch + + @property + def _as_tuple(self): + return self.major, self.minor, self.patch + + def __lt__(self, other): + return self._as_tuple < other._as_tuple + + def __eq__(self, other): + return self._as_tuple == other._as_tuple def get_version(): - version = sys.version.split(" ")[0].split(".") - return {"major": int(version[0]), "minor": int(version[1]), "fix": int(version[2])} + return Version(*map(int, sys.version.split(" ")[0].split("."))) def check_version(): - v = get_version() - if not (v["major"] >= 3 and v["minor"] >= 6): - print("ERROR: unsupported python version") - return status.get("FAILED") - return status.get("SUCCESS") + return get_version() >= Version(3, 7, 0) + + +def dedent(multiline): + lines = multiline.split("\n") + if len(lines) == 1: + return lines[0].strip() + indent = min(len(line) - len(line.lstrip()) for line in lines[1:]) + return (lines[0] + "\n".join(line[indent:] for line in lines[1:])).strip() diff --git a/scripts/verbose_converter/src/writer.py b/scripts/verbose_converter/src/writer.py deleted file mode 100644 index 5ec54d7918a..00000000000 --- a/scripts/verbose_converter/src/writer.py +++ /dev/null @@ -1,30 +0,0 @@ -################################################################################ -# Copyright 2021-2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -################################################################################ - - -class Writer: - def __init__(self, verbose_level=0): - self.__verbose_level = int(verbose_level) - self.__file = None - - def print(self, string, type): - if type == "WARN": - print(f"{type}: {string}") - if type == "INFO": - if self.__verbose_level > 0: - print(string) - if type == "STDIO": - print(string) diff --git a/scripts/verbose_converter/tests/benchdnn_test.py b/scripts/verbose_converter/tests/benchdnn_test.py index db0d108c8d0..0bbe0a9fd83 100755 --- a/scripts/verbose_converter/tests/benchdnn_test.py +++ b/scripts/verbose_converter/tests/benchdnn_test.py @@ -15,139 +15,157 @@ # limitations under the License. 
################################################################################ -import sys, os, subprocess - import argparse +import os +import subprocess +import sys from argparse import RawTextHelpFormatter +from collections import defaultdict +from typing import Dict, List -# add parent dir to sys.path to make verbose_converter visible for test -current_dir = os.path.dirname(os.path.realpath(__file__)) -parent_dir = os.path.dirname(current_dir) -sys.path.append(parent_dir) -import verbose_converter -from src import benchdnn_generator as benchdnn_gen +class TestingException(RuntimeError): + def __init__(self, msg): + from src.utils import dedent # type: ignore[import-not-found] + + super().__init__(dedent(msg)) + -status = {"SUCCESS": 0, "FAILED": 1} +class FailedCase(TestingException): + def __init__(self, status: str, repro: str): + super().__init__(f"Failed case: {status}: {repro}") def convert_dir_benchdnn2verbose(dir): - return { + mapping = { "FWD_D": "forward_training", "FWD_B": "forward_training", "FWD_I": "forward_inference", "BWD_D": "backward_data", "BWD_W": "backward_weights", "BWD_DW": "backward", - }.get(dir) - - -def filter_verbose(benchdnn_verbose, driver): - v = "" - benchdnn_prop_kind = None - - for test_case in benchdnn_verbose.split("__REPRO"): - verbose_lines = test_case.split("\n") - # `start` with `1` as there's a leftover from previous REPRO line. - for idx, l in enumerate(verbose_lines, start=1): - # Parse header - if l.find("create: ") != -1: - # detect prop kind in benchdnn log - dir = "--prop=" if driver == "rnn" else "--dir=" - dir_start = l.find(dir) - if dir_start != -1: - dir_end = l.find(" ", dir_start) - benchdnn_prop_kind = convert_dir_benchdnn2verbose( - l[dir_start + len(dir) : dir_end] - ) - else: - benchdnn_prop_kind = None - else: - # detect driver - l_s = l.split(",") - primitive_idx = 5 - d = ( - benchdnn_gen.convert_driver(l_s[primitive_idx]) - if len(l_s) > primitive_idx - else "" - ) - if ( - len(l_s) > primitive_idx - and l_s[0] == "onednn_verbose" - and d == driver - ): - # filter out additional forward calls, it's located in two - # positions after primitive_kind. - verbose_prop_kind = l_s[primitive_idx + 2] - if ( - benchdnn_prop_kind != None - and verbose_prop_kind != benchdnn_prop_kind - ): - continue - # Filter out fill reorders. Only the last one is actual. - # `len - 1` due to status piece left in `verbose_lines` as - # a product of split by `__REPRO`. - if d == "reorder" and idx != len(verbose_lines) - 1: - continue - # Filter out transform routine till it's properly supported. - # Use impl name for that due to it's the only difference - # between two ukernel calls. 
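Both the deleted filter_verbose here and its replacement below key on benchdnn's per-case result lines, which look roughly like this (the status and repro values are illustrative):

    line = "4:SKIPPED (Case not supported) __REPRO: --matmul 10x20:20x30"
    _, status_info, repro = map(str.strip, line.split(":", 2))
    status = status_info.rsplit(None, 1)[0].split("(", 1)[0].strip()
    # status == "SKIPPED", repro == "--matmul 10x20:20x30"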
-                    impl_name = l_s[5]
-                    if d == "brgemm" and impl_name == "pack_B":
-                        continue
-
-                    # found primitive creation for the test case
-                    # remove time
-                    l_wo_time = "".join(f + "," for f in l.split(",")[0:-1])[0:-1]
-                    v += l_wo_time + "\n"
+    }
+    return mapping.get(dir, "undef")
+
+
+def filter_verbose(verbose: str, driver: str, filter_event: str):
+    found_cases: List[str] = []
+    tentative_cases: Dict[str, List[str]] = defaultdict(list)
+    for line in verbose.split("\n"):
+        if "__REPRO" in line:
+            # n: STATUS (Status message) __REPRO: repro
+            _, status_info, repro = map(str.strip, line.split(":", 2))
+            status_and_message = status_info.rsplit(None, 1)[0]
+            status = status_and_message.split("(", 1)[0].strip()
+            # workaround for nvim-treesitter indent bug: )
+            argname = "prop" if driver == "rnn" else "dir"
+            known_prop_kind: str = "undef"
+            for part in repro.split():
+                if part.startswith(f"--{argname}="):
+                    value = part[len(argname) + 3 :]
+                    known_prop_kind = convert_dir_benchdnn2verbose(value)
                     break
-    return [status.get("SUCCESS"), ""], v
-
-def generate_verbose(path_to_benchdnn, driver, batch):
+            cases = tentative_cases[known_prop_kind]
+            tentative_cases.clear()
+            if status == "SKIPPED":
+                continue
+            elif "FAILED" in status:
+                raise FailedCase(status, repro)
+            elif not cases:
+                continue
+            found_cases.append(cases[-1])
+        elif line.startswith("onednn_verbose,"):
+            # Detect driver
+            parts = line.split(",")
+            try:
+                float(parts[2])  # check for timestamp
+            except ValueError:
+                pass
+            else:
+                parts.pop(2)
+            try:
+                component = parts[2]
+                event, *_ = parts[3].split(":", 1)
+                primitive = parts[5]
+                impl_name = parts[6]
+                prop_kind = parts[7]
+            except IndexError:
+                continue
+            if component != "primitive" or event not in filter_event:
+                continue
+            if get_driver(primitive) != driver:
+                continue
+            # Filter out the transform routine until it's properly
+            # supported. Use the impl name for that, since it's the only
+            # difference between the two ukernel calls.
+            if driver == "brgemm" and impl_name == "pack_B":
+                continue
+            # Remove primitive creation/run time
+            try:
+                float(parts[-1])
+            except ValueError:
+                continue
+            without_time = ",".join(parts[:-1])
+            # Filter out fill reorders. Only the last one is real.
+            tentative_cases[prop_kind].append(without_time)
+            if prop_kind != "undef":
+                # In case the reproducer uses the default prop kind
+                tentative_cases["undef"].append(without_time)
+    return "\n".join(found_cases)
+
+
+def generate_verbose(path_to_benchdnn, engine, driver, batch):
     benchdnn_exe = path_to_benchdnn + "/benchdnn"
     sub_env = os.environ.copy()
     sub_env["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "0"
     # Runtime dimension require execution verbose output.
     # BRGEMM driver through ukernel API supports verbose only at execution.
-    sub_env["ONEDNN_VERBOSE"] = "2"
+    profile_mode = "create"
     benchdnn_mode = "I"
-    if driver == "matmul" or driver == "reorder" or driver == "brgemm":
-        sub_env["ONEDNN_VERBOSE"] = "1"
+    if driver in ("matmul", "reorder", "brgemm"):
+        profile_mode = "exec"
         benchdnn_mode = "R"
+    # Add extra noise (dispatch, etc.)
to ensure it gets filtered out + sub_env["ONEDNN_VERBOSE"] = f"dispatch,error,check,profile_{profile_mode}" sub_args = [ benchdnn_exe, + f"--engine={engine}", f"--{driver}", f"--mode={benchdnn_mode}", - f"-v1", f"--batch={batch}", ] try: - sub = subprocess.run(sub_args, capture_output=True, text=True, env=sub_env) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError) as e: - return [ - status.get("FAILED"), - f"subprocess.run() raised exception: " + f"{e.stdout}", - ], "" - except BaseException as e: - return [ - status.get("FAILED"), - f"subprocess.run() raised exception: " + f"{e.args}\n{e.stdout}", - ], "" + sub = subprocess.run( + sub_args, + capture_output=True, + text=True, + env=sub_env, + ) + except Exception as e: + raise TestingException( + f"subprocess.run() raised exception: {e!s}" + ) from None + if sub.returncode != 0: # most likely converter generated incorrect batch file - return [ - status.get("FAILED"), - f"subprocess.run() returned {sub.returncode},\n" - + f"args: {sub_args}\nstderr: {sub.stderr}", - ], "" + raise TestingException( + f""" + subprocess.run() returned {sub.returncode}, + args: {sub_args} + stderr: {sub.stderr} + """ + ) - return filter_verbose(sub.stdout, driver=driver) + filter_event = "exec" if benchdnn_mode == "R" else "create" + return filter_verbose(sub.stdout, driver, filter_event) def generate_batch(verbose, driver): + import verbose_converter # type: ignore[import-not-found] + verbose = verbose.splitlines() aggregate_opts = [ "engine", @@ -159,8 +177,7 @@ def generate_batch(verbose, driver): "alg_kind", "shapes", ] - s, data = verbose_converter.convert( - verbose_level=0, + data = verbose_converter.convert( parser="oneDNN", input=verbose, action="generate", @@ -168,68 +185,117 @@ def generate_batch(verbose, driver): split_output=True, agg_keys=aggregate_opts, ) - if s != status.get("SUCCESS"): - return [s, f"verbose_converter.convert() returned {s}"], "" - filename = "test.generated" - for key, value in data.items(): - # remove -- from driver name - driver_filename = key + "." + filename - of = open(driver_filename, "w") - print(value, file=of) - return [s, ""], driver + "." 
+ filename + filename = f"{driver}.test.generated" + output = data.get(driver, "") + with open(filename, "w") as fd: + fd.write(f"{output}\n") + return filename def compare(driver, ref_v, comp_v): - ref_lines = ref_v.splitlines() - ref_lines = [l for l in ref_lines if driver in l] - comp_lines = comp_v.splitlines() - len(comp_lines) - comp_lines = [l for l in comp_lines if driver in l] - len(comp_lines) - - for r, c in zip(ref_lines, comp_lines): - if r != c: - ref_log_filename = f"{driver}.reference.log" - com_log_filename = f"{driver}.computed.log" - ref_log = open(ref_log_filename, "w") - com_log = open(com_log_filename, "w") - print(ref_v, file=ref_log) - print(comp_v, file=com_log) - return status.get("FAILED"), f"verboses do not match,\nref: {r}\ncom: {c}" - - return status.get("SUCCESS"), "" - - -def test(path_to_benchdnn, driver, batch): - s, ref_verbose = generate_verbose(path_to_benchdnn, driver, batch) - if s[0] != status.get("SUCCESS"): - return s - # XXX: Maybe generate batch and run becndhnn for each verbose line + def filter_lines(lines): + for line in lines.splitlines(): + if driver in line: + yield line + + def without_impl(verbose_line): + parts = verbose_line.split(",") + return ",".join(parts[:6] + parts[7:]) + + def find_named_entry(name, entries): + for entry in entries: + entry_name, *entry_args = entry.split(":") + if entry_name == name: + return entry_args + return None + + def accept_results(r, c): + if r == c: + return True + + # TODO: Handle cases with non-unique md tags + # * multiple size-1 dimensions with the same stride + # * multiple dimensions with 0 stride + if driver == "matmul": + # In matmul cases with runtime dims that resolve to ones, the bias + # memory descriptor will potentially have the wrong mask printed in + # the verbose line. We do not maintain enough information to always + # print the correct mask, but the reference and computed verbose + # lines will match, up to implementation name. + parts = r.split(",") + mds = parts[8].split() + aux = parts[10].split() + shapes = parts[11].split(":", 1) + wei, act = list(map(lambda x: list(map(int, x.split("x"))), shapes)) + if find_named_entry("bia", mds) is None: + return False + rt_dim_mask = find_named_entry("runtime_dims_masks", aux) + if rt_dim_mask is None: + return False + wei_mask, act_mask = list(map(int, rt_dim_mask)) + if wei[-2] == 1 and wei_mask & (1 << (len(wei) - 2)): + return without_impl(r) == without_impl(c) + if act[-1] == 1 and act_mask & (1 << (len(act) - 1)): + return without_impl(r) == without_impl(c) + elif driver == "sum": + # There is no information in a sum verbose line about scales, so if + # dispatch depends on particular scale values, the implementation + # may change with default scales. In this case, we check that the + # rest of the verbose line is the same. + return without_impl(r) == without_impl(c) + return False + + file_map = {"reference": ref_v, "computed": comp_v} + for r, c in zip(filter_lines(ref_v), filter_lines(comp_v)): + if accept_results(r, c): + continue + for log_type, content in file_map.items(): + with open(f"{driver}.{log_type}.log", "w") as fd: + fd.write(content) + raise TestingException( + f""" + verboses do not match + ref: {r} + com: {c} + """ + ) + + +def test(path_to_benchdnn, engine, driver, batch): + ref_verbose = generate_verbose(path_to_benchdnn, engine, driver, batch) + # XXX: Maybe generate batch and run benchdnn for each verbose line # separately to detect error on case level and not on batch level? 
# The reason behind testing on batch level is that ref_verbose generator # might introduce multiple verbose lines for single line in batch file - s, gen_batch = generate_batch(ref_verbose, driver) - if s[0] != status.get("SUCCESS"): - return s - s, verbose = generate_verbose(path_to_benchdnn, driver, gen_batch) - if s[0] != status.get("SUCCESS"): - return s - - return compare(driver, ref_verbose, verbose) + com_batch = generate_batch(ref_verbose, driver) + com_verbose = generate_verbose(path_to_benchdnn, engine, driver, com_batch) + compare(driver, ref_verbose, com_verbose) + # XXX: Maybe run an additional loop + # ref -> ref verbose -> com 1 -> com 1 verbose -> com 2 -> com 2 verbose + # Comparing com 1 and com 2 verbose instead would address the special cases + # in accept_results. We can even compare just the cases where ref and com 1 + # don't match. def main(): + relpath = "../../../build/tests/benchdnn" realpath = os.path.dirname(os.path.realpath(__file__)) - print(realpath) - realpath_benchdnn = realpath + "/../../../build/tests/benchdnn" + realpath_benchdnn = os.path.realpath(f"{realpath}/{relpath}") args_parser = argparse.ArgumentParser( description="benchdnn test", formatter_class=RawTextHelpFormatter ) + args_parser.add_argument( + "-e", + "--engine", + default="cpu", + choices=("cpu", "gpu"), + help="Engine to use to run tests", + ) args_parser.add_argument( "-d", "--dataset", - default=realpath + "/" + "dataset_simple", + default=f"{realpath}/dataset_simple", help="input with benchdnn batch files", ) args_parser.add_argument( @@ -241,23 +307,49 @@ def main(): args_parser.add_argument( "-i", "--inputs_path", - default=realpath_benchdnn + "/" + "inputs", + default=f"{realpath_benchdnn}/inputs", help="Path to benchdnn batch files", ) args = args_parser.parse_args() + failed = False with open(args.dataset, "r") as dataset: for case in dataset.readlines(): - if case[0] != "#" and case[0] != "\n": - [driver, batch] = case.split(",") - batch = batch.split("\n")[0] - batch_file_path = args.inputs_path + "/" + driver + "/" + batch - s = test(args.benchdnn_path, driver, batch_file_path) - s_str = "PASSED" if s[0] == status.get("SUCCESS") else "FAILED" - print(f"BENCHDNN TEST: {driver}, {batch}: {s_str} " + s[1]) + case = case.split("#", 1)[0].strip() + if not case: + continue + driver, batch = case.split(",") + batch = batch.split("\n", 1)[0] + batch_file_path = f"{args.inputs_path}/{driver}/{batch}" + test_info = f"BENCHDNN TEST: {args.engine}, {driver}, {batch}" + try: + test(args.benchdnn_path, args.engine, driver, batch_file_path) + except Exception as e: + print(f"{test_info}: FAILED {e!s}") + failed = True + else: + print(f"{test_info}: PASSED") + return failed - return status.get("SUCCESS") + +def get_driver(primitive: str): + import src.benchdnn_generator as bg # type: ignore[import-not-found] + + try: + converter = bg.get_converter(primitive) + except KeyError: + return None + else: + return converter.driver +# Add parent dir to sys.path to make verbose_converter visible for test +current_dir = os.path.dirname(os.path.realpath(__file__)) +parent_dir = os.path.dirname(current_dir) +sys.path.append(parent_dir) + if __name__ == "__main__": - main() + try: + sys.exit(main()) + except KeyboardInterrupt: + sys.exit(0) diff --git a/scripts/verbose_converter/tests/dataset_simple b/scripts/verbose_converter/tests/dataset_simple index 90776b29557..586cd6a5bc6 100644 --- a/scripts/verbose_converter/tests/dataset_simple +++ b/scripts/verbose_converter/tests/dataset_simple @@ -1,5 +1,5 @@ 
diff --git a/scripts/verbose_converter/tests/dataset_simple b/scripts/verbose_converter/tests/dataset_simple
index 90776b29557..586cd6a5bc6 100644
--- a/scripts/verbose_converter/tests/dataset_simple
+++ b/scripts/verbose_converter/tests/dataset_simple
@@ -1,5 +1,5 @@
 ################################################################################
-# Copyright 2021-2023 Intel Corporation
+# Copyright 2021-2025 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,7 +27,7 @@ ip,shapes_ci
 lnorm,shapes_ci
 lrn,shapes_ci
 matmul,shapes_2d_ci
-pooling,shapes_basic
+pool,shapes_basic
 prelu,shapes_ci
 reduction,shapes_ci
 resampling,shapes_ci
diff --git a/scripts/verbose_converter/verbose_converter.py b/scripts/verbose_converter/verbose_converter.py
index 680cd2252a8..5479330dadb 100755
--- a/scripts/verbose_converter/verbose_converter.py
+++ b/scripts/verbose_converter/verbose_converter.py
@@ -15,82 +15,99 @@
 # limitations under the License.
 ################################################################################

-import sys
-
 import argparse
+import logging
+import sys
 from argparse import RawTextHelpFormatter
+from typing import IO, Dict, Iterable, List
+
+from src.benchdnn_generator import InputGenerator  # type: ignore
+from src.breakdown_generator import BreakdownGenerator  # type: ignore
+from src.dnnl_parser import LogParser  # type: ignore
+from src.utils import check_version  # type: ignore
+
+default_events = "exec", "create"
+stream_handler = logging.StreamHandler(sys.stderr)
+fmt = logging.Formatter(fmt="{levelname}: {name}: {message}", style="{")
+# workaround for nvim-treesitter indent bug: }
+stream_handler.setFormatter(fmt)
+logger = logging.getLogger("verbose_converter")
+logger.setLevel(logging.CRITICAL + 10)  # off
+logger.addHandler(stream_handler)
+
+
+def one_line(multiline: str):
+    return " ".join(map(str.strip, multiline.split("\n"))).strip()

-from src import utils
-from src import writer

+class ConverterError(RuntimeError):
+    pass
+
+
+def generate(generator, parser: LogParser, *args):
+    return generator.generate(parser.get_data(), *args)


 def convert(
-    verbose_level,
-    parser,
-    input,
-    action,
-    generator,
-    split_output,
-    agg_keys,
-    events=["create", "exec"],
-):
-    status = utils.check_version()
-    if status != utils.status.get("SUCCESS"):
-        return status
-
-    logger = writer.Writer(verbose_level=verbose_level)
-    log_parser = None
-    if parser == "oneDNN":
-        from src import dnnl_parser
+    parser: str,
+    input: Iterable[str],
+    action: str,
+    generator: str,
+    split_output: bool,
+    agg_keys: List[str],
+    events: Iterable[str] = default_events,
+) -> Dict[str, str]:
+    if not check_version():
+        raise ConverterError("Unsupported Python version")

-        log_parser = dnnl_parser.LogParser(logger, input)
+    log_parser: LogParser
+    if parser == "oneDNN":
+        log_parser = LogParser(logger, input)
     else:
-        logger.print("Error: unsupported parser", "STDIO")
-        return utils.status.get("FAILED")
+        raise ConverterError("Unsupported parser")

-    logger.print(f"Processing input ...", "INFO")
+    logger.info("Processing input ...")
     log_parser.process(events)

-    output = None
     if action == "dumpIR":
-        logger.print(f"Dumping data from input...", "INFO")
+        logger.info("Dumping data from input...")
         log_parser.dump(True)
-
-    if action == "generate":
-        logger.print(f"Generating output ...", "INFO")
+        return {}
+    elif action == "generate":
+        logger.info("Generating output ...")
         if generator == "benchdnn":
-            from src import benchdnn_generator
-
-            gen = benchdnn_generator.InputGenerator(logger)
-            output = gen.generate(log_parser.get_data(), split_output)
+            if "create_nested" in events:
+                logger.warning(
+                    one_line(
+                        """
+                        Benchdnn arguments generated from create_nested events
+                        may not work!
+                        """
+                    )
+                )
+            return generate(InputGenerator(logger), log_parser, split_output)
         elif generator == "breakdown":
-            from src import breakdown_generator
-
-            gen = breakdown_generator.BreakdownGenerator(logger)
-            output = gen.generate(log_parser.get_data(), agg_keys)
+            return generate(BreakdownGenerator(logger), log_parser, agg_keys)
         else:
-            logger.print("Error: unsupported generator", "STDIO")
-            return utils.status.get("FAILED")
-
-    return utils.status.get("SUCCESS"), output
+            raise ConverterError("Unsupported generator")
+    else:
+        raise ConverterError("Unsupported action")


-def validate_option(value, supported_values, str):
-    if not value in supported_values:
-        print(f"ERROR: {str}")
-        return utils.status.get("FAILED")
-    return utils.status.get("SUCCESS")
+def validate_option(value, supported_values, message):
+    if value not in supported_values:
+        raise ConverterError(message)


-def main():
-    status = utils.check_version()
-    if status != utils.status.get("SUCCESS"):
-        return status
+def main() -> int:
+    if not check_version():
+        logger.error("Unsupported Python version")
+        return 1

     action_opts = ["generate", "dumpIR"]
     generator_opts = ["benchdnn", "breakdown"]
     parser_opts = ["oneDNN"]
-    verbose_opts = ["0", "1"]
+    verbose_opts = [0, 1]
     aggregate_opts = [
         "engine",
         "prim_kind",
@@ -101,7 +118,7 @@ def main():
         "aux",
         "shapes",
     ]
-    event_opts = ["exec", "create"]
+    event_opts = list(default_events) + ["create_nested"]
     args_parser = argparse.ArgumentParser(
         description="oneDNN log converter", formatter_class=RawTextHelpFormatter
     )
@@ -132,12 +149,18 @@ def main():
         "--aggregate",
         nargs="+",
         default=aggregate_opts,
-        help=f"aggregates statistics on the specified keys (default: all keys but time).\nValues: {aggregate_opts}",
+        help=one_line(
+            f"""
+            aggregates statistics on the specified keys (default: all keys but
+            time). Values: {aggregate_opts}
+            """
+        ),
     )
     args_parser.add_argument(
         "-v",
         "--verbose_level",
-        default="0",
+        default=0,
+        type=int,
         help=f"verbose level (default: 0). Values: {verbose_opts}.",
     )
     args_parser.add_argument(
@@ -153,26 +176,31 @@ def main():
         "-e",
         "--events",
         nargs="+",
-        default=event_opts,
-        help=f"events to parse (default: create and exec).\nValues: {event_opts}.",
+        default=list(default_events),
+        help=one_line(
+            f"""
+            events to parse (default: create and exec). Values: {event_opts}.
+            """
+        ),
     )
     args = args_parser.parse_args()

     # validate options
-    status = validate_option(args.action, action_opts, "Unknown action value")
-    if status != utils.status.get("SUCCESS"):
-        return status
-    status = validate_option(
-        args.verbose_level, verbose_opts, "Unknown verbose_level value"
-    )
-    if status != utils.status.get("SUCCESS"):
-        return status
-    status = validate_option(args.parser, parser_opts, "Unknown parser value")
-    if status != utils.status.get("SUCCESS"):
-        return status
-    status = validate_option(args.generator, generator_opts, "Unknown generator value")
-    if status != utils.status.get("SUCCESS"):
-        return status
+    logger.setLevel(logging.ERROR)
+    try:
+        validate_option(args.action, action_opts, "Unknown action value")
+        validate_option(
+            args.verbose_level, verbose_opts, "Unknown verbose level"
+        )
+        validate_option(args.parser, parser_opts, "Unknown parser value")
+        validate_option(
+            args.generator, generator_opts, "Unknown generator value"
+        )
+        for event in args.events:
+            validate_option(event, event_opts, "Unknown event")
+    except ConverterError as e:
+        logger.error(str(e))
+        return 1

     input_data = []
     if args.input == "stdin":
         for line in sys.stdin:
             input_data.append(line)
         else:
-            print("WARN: no input was provided to the script")
+            logger.warning("No input was provided to the script")
             args_parser.print_help()
     else:
         try:
             input_data = open(args.input, "r").readlines()
         except BaseException as e:
-            print(f"Error while reading input: {e}")
-
-    output = None
+            logger.error(f"While reading input: {e!s}")
+            return 1

-    event_sets = args.events if args.generator == 'breakdown' else [args.events]
+    event_sets = (
+        [[e] for e in args.events]
+        if args.generator == "breakdown"
+        else [args.events]
+    )
+    verbosity_levels = [logging.WARNING, logging.INFO]
+    logger.setLevel(verbosity_levels[args.verbose_level])
     for events in event_sets:
-        status, output = convert(
-            verbose_level=args.verbose_level,
-            parser=args.parser,
-            input=input_data,
-            action=args.action,
-            generator=args.generator,
-            split_output=args.split,
-            agg_keys=args.aggregate,
-            events=events
-        )
-
-        if status != utils.status.get("SUCCESS"):
-            return status
+        try:
+            output = convert(
+                parser=args.parser,
+                input=input_data,
+                action=args.action,
+                generator=args.generator,
+                split_output=args.split,
+                agg_keys=args.aggregate,
+                events=events,
+            )
+        except ConverterError as e:
+            logger.error(str(e))
+            return 1

-        if output != None:
+        for key, value in output.items():
+            fd: IO
+            filename = args.output
+            if args.split:
+                filename += f".{key}"
             if args.output != "stdout":
-                if output != None:
-                    for key, value in output.items():
-                        filename = args.output
-                        if args.split == True:
-                            filename += "." + key
-                        of = open(filename, "w")
-                        if args.generator == "breakdown":
-                            print(f"Event: {events}", file=of)
-                        print(value, end="", file=of)
+                fd = open(filename, "w")
+            else:
+                fd = sys.stdout
+            if args.generator == "breakdown":
+                fd.write(f"Event: {events[0]}\n")
+            fd.write(f"{value}\n")
             else:
-                if args.generator == "breakdown":
-                    print(f"Event: {events}")
-                for key, value in output.items():
-                    if args.split == False:
-                        print(f"{value}")
-                    else:
-                        print(f"--{key}\n{value}")
+            if args.split:
+                fd.write(f"--{key}\n")
+            fd.write(f"{value}\n")
+            if args.output != "stdout":
+                fd.close()
+    return 0


 if __name__ == "__main__":
-    main()
+    try:
+        sys.exit(main())
+    except KeyboardInterrupt:
+        sys.exit(0)
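After this rewrite, `convert()` returns a `{name: text}` dictionary and signals failure by raising `ConverterError` instead of returning status tuples. A hedged sketch of a minimal caller, assuming the script's directory is on `sys.path` and using a placeholder log file name:

```python
from verbose_converter import ConverterError, convert

try:
    output = convert(
        parser="oneDNN",
        input=open("dnnl_verbose.log"),  # any iterable of verbose log lines
        action="generate",
        generator="benchdnn",
        split_output=False,
        agg_keys=[],
    )
except ConverterError as e:
    print(f"conversion failed: {e}")
else:
    for name, text in output.items():  # one entry per generated output
        print(text)
```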
" "Either disable Graph API with ONEDNN_BUILD_GRAPH=OFF or change GPU " - "vendor to INTEL with ONEDNN_GPU_VENDOR=INTEL.") + "vendor to INTEL or NVIDIA.") endif() if (NOT DNNL_ENABLE_PRIMITIVE STREQUAL "ALL") @@ -164,22 +183,12 @@ if(ONEDNN_BUILD_GRAPH) "primitive selection with ONEDNN_ENABLE_PRIMITIVE=ALL.") endif() - if(ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_BACKEND) - add_definitions_with_host_compiler(-DDNNL_ENABLE_COMPILER_BACKEND) - endif() if(ONEDNN_ENABLE_GRAPH_DUMP) message(STATUS "Graph artifacts dump is enabled") add_definitions_with_host_compiler(-DDNNL_ENABLE_GRAPH_DUMP) endif() add_subdirectory(graph) - if(ONEDNN_EXPERIMENTAL_GRAPH_COMPILER_BACKEND AND TARGET dnnl_graphcompiler_llvm_lib_exclude_string) - get_property(GC_EXCLUDE_LIBS TARGET dnnl_graphcompiler_llvm_lib_exclude_string PROPERTY INTERFACE_LINK_LIBRARIES) - if(DEFINED GC_EXCLUDE_LIBS AND NOT MSVC AND NOT APPLE) - # set the LLVM symbols as hidden, or all LLVM symbols will be exported - append(CMAKE_SHARED_LINKER_FLAGS "-Wl,--exclude-libs=${GC_EXCLUDE_LIBS}") - endif() - endif() else() # If graph component is not built, remove the headers from build and installation. list(REMOVE_ITEM HEADERS_SUBDIR @@ -196,8 +205,7 @@ get_property(SHARED_LIB_DEPS GLOBAL PROPERTY DNNL_SUBDIR_EXTRA_SHARED_LIBS) add_library(${LIB_PACKAGE_NAME} ${DNNL_LIBRARY_TYPE} ${VERSION_RESOURCE_FILE} ${HEADERS_ROOT} ${HEADERS_SUBDIR} ${LIB_DEPS}) -# LINK_PRIVATE for cmake 2.8.11 compatibility -target_link_libraries(${LIB_PACKAGE_NAME} LINK_PRIVATE ${STATIC_LIB_DEPS} ${SHARED_LIB_DEPS}) +target_link_libraries(${LIB_PACKAGE_NAME} PRIVATE ${STATIC_LIB_DEPS} ${SHARED_LIB_DEPS}) set_property(TARGET ${LIB_PACKAGE_NAME} PROPERTY OUTPUT_NAME ${DNNL_LIBRARY_NAME}) set_property(TARGET ${LIB_PACKAGE_NAME} PROPERTY VERSION "${DNNL_VERSION_MAJOR}.${DNNL_VERSION_MINOR}") @@ -209,13 +217,20 @@ target_include_directories(${LIB_PACKAGE_NAME} PUBLIC $ ) -target_link_libraries_build(${LIB_PACKAGE_NAME} - "${EXTRA_SHARED_LIBS};${EXTRA_STATIC_LIBS}") +target_link_libraries(${LIB_PACKAGE_NAME} PUBLIC "$") target_link_libraries_install(${LIB_PACKAGE_NAME} "${EXTRA_SHARED_LIBS}") if(DNNL_LIBRARY_TYPE STREQUAL "STATIC") target_link_libraries_install(${LIB_PACKAGE_NAME} "${EXTRA_STATIC_LIBS}") endif() +foreach(object_library IN LISTS LIB_DEPS) + string(REPLACE "$" "" object_library "${object_library}") + + # explicitly set compile PDB name as with Ninja, all targets have the same pdb name like vc.pdb + set_target_properties(${object_library} PROPERTIES COMPILE_PDB_NAME ${object_library}) +endforeach() + set(LIB_EXPORT_NAME "${LIB_PACKAGE_NAME}-targets") install(TARGETS ${LIB_PACKAGE_NAME} EXPORT "${LIB_EXPORT_NAME}" @@ -232,7 +247,7 @@ foreach(header ${HEADERS_SUBDIR}) endforeach() string(TOUPPER "${LIB_PACKAGE_NAME}::" LIB_NAMESPACE) -if(DNNL_INSTALL_MODE STREQUAL "BUNDLE_V2" AND WIN32) +if(DNNL_INSTALL_MODE STREQUAL "BUNDLE" AND WIN32) # Config file for binary distribution needs to define a mapping # DEBUG -> RELWITHMDD so that proper library (dnnld) is picked up for the # DEBUG configuration. @@ -263,7 +278,7 @@ install(EXPORT ${LIB_EXPORT_NAME} # Apply a workaround to CMake config file to make it work with symlinks. # The patched config file is only used in oneAPI binary distribution. 
-if(UNIX AND DNNL_INSTALL_MODE STREQUAL "BUNDLE_V2") +if(UNIX AND DNNL_INSTALL_MODE STREQUAL "BUNDLE") install(CODE "file(READ \"${CMAKE_INSTALL_PREFIX}/${LIB_CONFIG_INSTALL_DIR}/${LIB_PACKAGE_NAME}-targets.cmake\" TARGETS_CONTENT)") install(CODE "string(REPLACE \"get_filename_component(_IMPORT_PREFIX \\\"\\\${CMAKE_CURRENT_LIST_FILE}\\\" PATH)\" @@ -273,7 +288,7 @@ if(UNIX AND DNNL_INSTALL_MODE STREQUAL "BUNDLE_V2") endif() # Install custom find modules for transitive dependencies -if(DNNL_CPU_THREADING_RUNTIME STREQUAL "TBB") +if("${DNNL_CPU_THREADING_RUNTIME}" MATCHES "^(TBB|TBB_AUTO)$") if(WIN32) install(FILES "../cmake/win/TBBConfig.cmake" RENAME "FindTBB.cmake" DESTINATION ${LIB_CONFIG_INSTALL_DIR}) @@ -298,6 +313,14 @@ if(DNNL_BLAS_VENDOR STREQUAL "ACCELERATE") DESTINATION ${LIB_CONFIG_INSTALL_DIR}) endif() +if(DNNL_SYCL_CUDA) + install(FILES + "../cmake/FindcuBLAS.cmake" + "../cmake/FindcublasLt.cmake" + "../cmake/FindcuDNN.cmake" + DESTINATION ${LIB_CONFIG_INSTALL_DIR}) +endif() + # On Windows we need to add dnnl.dll path to CTESTCONFIG_PATH which is later # passed to ctest and Visual Studio solutions if(WIN32) diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index ffb7d2c3831..5d698d8d0e1 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2019-2024 Intel Corporation +# Copyright 2019-2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ file(GLOB SOURCES if(DNNL_ENABLE_JIT_PROFILING OR DNNL_ENABLE_ITT_TASKS) if(DNNL_TARGET_ARCH STREQUAL "AARCH64" OR DNNL_TARGET_ARCH STREQUAL "X64") file(GLOB ITT_SOURCES - ${CMAKE_CURRENT_SOURCE_DIR}/ittnotify/*.[ch] + ${PROJECT_SOURCE_DIR}/third_party/ittnotify/*.c ) list(APPEND SOURCES ${ITT_SOURCES}) @@ -33,10 +33,10 @@ if(DNNL_ENABLE_JIT_PROFILING OR DNNL_ENABLE_ITT_TASKS) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DITT_API_IPT_SUPPORT") if(UNIX OR MINGW) enable_language(ASM) - set(ITT_PT ${CMAKE_CURRENT_SOURCE_DIR}/ittnotify/ittptmark64.S) + set(ITT_PT ${PROJECT_SOURCE_DIR}/third_party/ittnotify/ittptmark64.S) else() enable_language(ASM_MASM) - set(ITT_PT ${CMAKE_CURRENT_SOURCE_DIR}/ittnotify/ittptmark64.asm) + set(ITT_PT ${PROJECT_SOURCE_DIR}/third_party/ittnotify/ittptmark64.asm) endif() list(APPEND SOURCES ${ITT_PT}) endif() @@ -49,13 +49,11 @@ if(NOT DNNL_CPU_RUNTIME STREQUAL "THREADPOOL") endif() if(NOT DNNL_EXPERIMENTAL_LOGGING) - # avoid building and linking spdlog if logging support is not enabled - list(REMOVE_ITEM SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/spdlog/*") list(REMOVE_ITEM SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/logging.cpp") - list(REMOVE_ITEM SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/logging.hpp") endif() set(OBJ_LIB ${LIB_PACKAGE_NAME}_common) add_library(${OBJ_LIB} OBJECT ${SOURCES}) set_property(GLOBAL APPEND PROPERTY DNNL_LIB_DEPS $) +enable_conditional_compilation4(${OBJ_LIB}) diff --git a/src/common/batch_normalization_pd.hpp b/src/common/batch_normalization_pd.hpp index 577f3bb75a5..6cd6ce47d96 100644 --- a/src/common/batch_normalization_pd.hpp +++ b/src/common/batch_normalization_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in 
compliance with the License. @@ -97,7 +97,7 @@ struct batch_normalization_pd_t : public primitive_desc_t { float alpha() const { const auto &p = attr()->post_ops_; - const bool entry_size_ok = p.entry_.size() > 0; + const bool entry_size_ok = !p.entry_.empty(); assert(entry_size_ok || fuse_norm_relu() || fuse_norm_add_relu()); if (entry_size_ok) return p.entry_[0].eltwise.alpha; return 0.f; @@ -126,16 +126,15 @@ struct batch_normalization_pd_t : public primitive_desc_t { memory_desc_t ws_md_; - batch_normalization_pd_t(const batch_normalization_desc_t *adesc, + batch_normalization_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const batch_normalization_fwd_pd_t *hint_fwd_pd) : primitive_desc_t(attr, base_pkind) - , desc_(*adesc) + , desc_(*op_desc_t::to_desc(adesc)) , hint_fwd_pd_(hint_fwd_pd) , src_md_(desc_.src_desc) , stat_md_(desc_.stat_desc) - , scaleshift_md_(desc_.scaleshift_desc) - , ws_md_() {} + , scaleshift_md_(desc_.scaleshift_desc) {} virtual status_t init_default_ws(size_t bits_per_element) { const auto src_mdw = memory_desc_wrapper(src_md_); @@ -149,14 +148,16 @@ struct batch_normalization_pd_t : public primitive_desc_t { } }; +// NOLINTBEGIN(google-default-arguments) struct batch_normalization_fwd_pd_t : public batch_normalization_pd_t { - typedef batch_normalization_fwd_pd_t base_class; - typedef batch_normalization_fwd_pd_t hint_class; + using base_class = batch_normalization_fwd_pd_t; + using hint_class = batch_normalization_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (arg == DNNL_ARG_SRC) return arg_usage_t::input; - if (arg == DNNL_ARG_SRC_1 && fuse_norm_add_relu()) - return arg_usage_t::input; + if (arg == DNNL_ARG_SRC_1) + return fuse_norm_add_relu() ? arg_usage_t::input + : arg_usage_t::unused; if (arg == DNNL_ARG_DST) return arg_usage_t::output; if (utils::one_of(arg, DNNL_ARG_MEAN, DNNL_ARG_VARIANCE)) { @@ -165,11 +166,14 @@ struct batch_normalization_fwd_pd_t : public batch_normalization_pd_t { return arg_usage_t::unused; } - if (arg == DNNL_ARG_SCALE && use_scale()) return arg_usage_t::input; - if (arg == DNNL_ARG_SHIFT && use_shift()) return arg_usage_t::input; + if (arg == DNNL_ARG_SCALE) + return use_scale() ? arg_usage_t::input : arg_usage_t::unused; + if (arg == DNNL_ARG_SHIFT) + return use_shift() ? arg_usage_t::input : arg_usage_t::unused; - if (arg == DNNL_ARG_WORKSPACE && !types::is_zero_md(workspace_md())) - return arg_usage_t::output; + if (arg == DNNL_ARG_WORKSPACE) + return !types::is_zero_md(workspace_md()) ? 
arg_usage_t::output + : arg_usage_t::unused; return primitive_desc_t::arg_usage(arg); } @@ -230,7 +234,7 @@ struct batch_normalization_fwd_pd_t : public batch_normalization_pd_t { protected: memory_desc_t dst_md_; - batch_normalization_fwd_pd_t(const batch_normalization_desc_t *adesc, + batch_normalization_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const batch_normalization_fwd_pd_t *hint_fwd_pd) : batch_normalization_pd_t(adesc, attr, hint_fwd_pd) @@ -247,30 +251,36 @@ struct batch_normalization_fwd_pd_t : public batch_normalization_pd_t { weights_md()->data_type == data_type::f32); } }; +// NOLINTEND(google-default-arguments) +// NOLINTBEGIN(google-default-arguments) struct batch_normalization_bwd_pd_t : public batch_normalization_pd_t { - typedef batch_normalization_bwd_pd_t base_class; - typedef batch_normalization_fwd_pd_t hint_class; + using base_class = batch_normalization_bwd_pd_t; + using hint_class = batch_normalization_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_MEAN, DNNL_ARG_VARIANCE, DNNL_ARG_DIFF_DST)) return arg_usage_t::input; - if (arg == DNNL_ARG_SCALE && use_scale()) return arg_usage_t::input; - if (arg == DNNL_ARG_SHIFT && use_shift()) return arg_usage_t::input; + if (arg == DNNL_ARG_SCALE) + return use_scale() ? arg_usage_t::input : arg_usage_t::unused; + if (arg == DNNL_ARG_SHIFT) + return use_shift() ? arg_usage_t::input : arg_usage_t::unused; - if (arg == DNNL_ARG_WORKSPACE && !types::is_zero_md(workspace_md())) - return arg_usage_t::input; + if (arg == DNNL_ARG_WORKSPACE) + return !types::is_zero_md(workspace_md()) ? arg_usage_t::input + : arg_usage_t::unused; if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output; - if (arg == DNNL_ARG_DIFF_SRC_1 && fuse_norm_add_relu()) - return arg_usage_t::output; - - if (arg == DNNL_ARG_DIFF_SCALE && use_scale()) - return arg_usage_t::output; - if (arg == DNNL_ARG_DIFF_SHIFT && use_shift()) - return arg_usage_t::output; + if (arg == DNNL_ARG_DIFF_SRC_1) + return fuse_norm_add_relu() ? arg_usage_t::output + : arg_usage_t::unused; + + if (arg == DNNL_ARG_DIFF_SCALE) + return use_scale() ? arg_usage_t::output : arg_usage_t::unused; + if (arg == DNNL_ARG_DIFF_SHIFT) + return use_shift() ? arg_usage_t::output : arg_usage_t::unused; return primitive_desc_t::arg_usage(arg); } @@ -341,7 +351,7 @@ struct batch_normalization_bwd_pd_t : public batch_normalization_pd_t { memory_desc_t diff_dst_md_; memory_desc_t diff_scaleshift_md_; - batch_normalization_bwd_pd_t(const batch_normalization_desc_t *adesc, + batch_normalization_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const batch_normalization_fwd_pd_t *hint_fwd_pd) : batch_normalization_pd_t(adesc, attr, hint_fwd_pd) @@ -366,6 +376,7 @@ struct batch_normalization_bwd_pd_t : public batch_normalization_pd_t { diff_weights_md()->data_type)); } }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/binary.cpp b/src/common/binary.cpp index 2948b5b6beb..570e6eddc3e 100644 --- a/src/common/binary.cpp +++ b/src/common/binary.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
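The pattern in this hunk replaces compound `if (arg == X && cond)` checks with an explicit three-way classification, so every known argument reports input, output, or unused instead of falling through. A hedged Python sketch of the same contract for the forward pass (names are illustrative, not the library API):

```python
# Illustrative model of arg_usage(): each known argument maps to "input",
# "output", or "unused" depending on the descriptor's flags.
def bnorm_fwd_arg_usage(arg, use_scale, use_shift, fuse_norm_add_relu):
    if arg == "SRC":
        return "input"
    if arg == "SRC_1":  # second input of the fused add
        return "input" if fuse_norm_add_relu else "unused"
    if arg == "SCALE":
        return "input" if use_scale else "unused"
    if arg == "SHIFT":
        return "input" if use_shift else "unused"
    if arg == "DST":
        return "output"
    return "unused"

assert bnorm_fwd_arg_usage("SHIFT", True, False, False) == "unused"
```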
diff --git a/src/common/binary.cpp b/src/common/binary.cpp
index 2948b5b6beb..570e6eddc3e 100644
--- a/src/common/binary.cpp
+++ b/src/common/binary.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2023 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -48,23 +48,24 @@ status_t binary_attr_check(const binary_desc_t &desc, const engine_t *engine,

     // Check attributes
     const data_type_t dst_dt = desc.dst_desc.data_type;
-    auto attr_mask = smask_t::post_ops | smask_t::scales_runtime;
+    auto attr_mask = smask_t::post_ops | smask_t::scales;
     VCHECK_BINARY_UNIMPL(attr->has_default_values(attr_mask, dst_dt),
             VERBOSE_UNSUPPORTED_ATTR);

     // Check scales
     if (!attr->scales_.has_default_values()) {
-        VCHECK_BINARY_UNIMPL(attr->scales_.has_default_values(
-                                     {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1}),
+        static const std::vector<int> supported_args {
+                DNNL_ARG_SRC_0, DNNL_ARG_SRC_1};
+        VCHECK_BINARY_UNIMPL(attr->scales_.has_default_values(supported_args),
                 VERBOSE_UNSUPPORTED_SCALES_CFG);
-        const auto &sc = attr->scales_;
-        const int mask_src_0 = sc.get(DNNL_ARG_SRC_0).mask_;
-        const int mask_src_1 = sc.get(DNNL_ARG_SRC_1).mask_;
+        for (int arg : supported_args) {
+            if (attr->scales_.has_default_values(arg)) continue;

-        VCHECK_BINARY_UNIMPL(utils::everyone_is(0, mask_src_0, mask_src_1),
-                VERBOSE_UNSUPPORTED_SCALES_CFG);
+            const int mask = attr->scales_.get_mask(arg);
+            VCHECK_BINARY_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
+        }
     }

     // Check post-ops
@@ -77,30 +78,24 @@
         // Check sum
         VCHECK_BINARY_UNIMPL(po.check_sum_consistency(dst_dt, false, true),
                 VERBOSE_UNSUPPORTED_POSTOP);
-    }

+        // Note: verbose support is inside the call.
+        CHECK(po.validate_binary_with_dst_consistency(&desc.dst_desc));
+    }
     return status::success;
 }

-status_t dnnl_binary_primitive_desc_create(
-        primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
-        alg_kind_t alg_kind, const memory_desc_t *src0_md,
-        const memory_desc_t *src1_md, const memory_desc_t *dst_md,
-        const primitive_attr_t *attr) {
+status_t binary_md_check(const engine_t *engine, alg_kind_t alg_kind,
+        const memory_desc_t *src0_md, const memory_desc_t *src1_md,
+        const memory_desc_t *src2_md, const memory_desc_t *dst_md) {

     VCHECK_BINARY(!any_null(src0_md, src1_md, dst_md), VERBOSE_NULL_ARG);
-    VCHECK_BINARY(
-            one_of(alg_kind, binary_add, binary_mul, binary_max, binary_min,
-                    binary_div, binary_sub, binary_ge, binary_gt, binary_le,
-                    binary_lt, binary_eq, binary_ne),
-            VERBOSE_BAD_ALGORITHM);
+    VCHECK_BINARY(IMPLICATION(alg_kind == binary_select, src2_md != nullptr),
+            VERBOSE_NULL_ARG);
+
+    // TODO - Add support for mutual or bi-directional broadcasts
     VCHECK_BINARY(!memory_desc_wrapper(src0_md).format_any(),
             VERBOSE_UNSUPPORTED_TAG_S, "src0");

-    auto bod = binary_desc_t();
-    bod.primitive_kind = primitive_kind::binary;
-    bod.alg_kind = alg_kind;
-
     VCONDCHECK(primitive, create, check, binary,
             !memory_desc_wrapper(src0_md).has_runtime_dims_or_strides(),
             status::unimplemented, VERBOSE_RUNTIMEDIM_UNSUPPORTED);
@@ -111,10 +106,6 @@
             !memory_desc_wrapper(dst_md).has_runtime_dims_or_strides(),
             status::unimplemented, VERBOSE_RUNTIMEDIM_UNSUPPORTED);

-    bod.src_desc[0] = *src0_md;
-    bod.src_desc[1] = *src1_md;
-    bod.dst_desc = *dst_md;
-
     const int ndims = dst_md->ndims;
     const dims_t &dims = dst_md->dims;

@@ -122,8 +113,19 @@
     VCHECK_BINARY(
             src0_md->ndims == ndims, VERBOSE_INCONSISTENT_NDIMS, "src0", "dst");
     VCHECK_BINARY(
             src1_md->ndims == ndims, VERBOSE_INCONSISTENT_NDIMS, "src1", "dst");
+
+    if (src2_md != nullptr) {
+        VCONDCHECK(primitive, create, check, binary,
+                !memory_desc_wrapper(src2_md).has_runtime_dims_or_strides(),
+                status::unimplemented, VERBOSE_RUNTIMEDIM_UNSUPPORTED);
+        VCHECK_BINARY(src2_md->ndims == ndims, VERBOSE_INCONSISTENT_NDIMS,
+                "src2", "dst");
+        VCHECK_BINARY(
+                src2_md->data_type == data_type::s8, VERBOSE_UNSUPPORTED_DT);
+    }
+
     for (int d = 0; d < ndims; ++d) {
-        //dims must equal eachother or equal 1 (broadcast)
+        //dims must equal each other or equal 1 (broadcast)
         VCHECK_BINARY(utils::one_of(src0_md->dims[d], 1, dims[d]),
                 VERBOSE_BAD_DIM, "src0", d);
         VCHECK_BINARY(utils::one_of(src1_md->dims[d], 1, dims[d]),
@@ -131,7 +133,49 @@
         VCHECK_BINARY(IMPLICATION(src0_md->dims[d] != dims[d],
                               src1_md->dims[d] == dims[d]),
                 VERBOSE_INCONSISTENT_DIM, "src1", d, "dst", d);
+
+        if (src2_md != nullptr) {
+            VCHECK_BINARY(utils::one_of(src2_md->dims[d], 1, dims[d]),
+                    VERBOSE_BAD_DIM, "src2", d);
+            VCHECK_BINARY(IMPLICATION(src0_md->dims[d] != dims[d],
+                                  src2_md->dims[d] == src0_md->dims[d]),
+                    VERBOSE_INCONSISTENT_DIM, "src0", d, "src2", d);
+        }
     }
+
+    return status::success;
+}
+
+status_t dnnl_binary_primitive_desc_create(
+        primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
+        alg_kind_t alg_kind, const memory_desc_t *src0_md,
+        const memory_desc_t *src1_md, const memory_desc_t *dst_md,
+        const primitive_attr_t *attr) {
+
+    return dnnl_binary_primitive_desc_create_v2(primitive_desc_iface, engine,
+            alg_kind, src0_md, src1_md, nullptr, dst_md, attr);
+}
+
+status_t dnnl_binary_primitive_desc_create_v2(
+        primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
+        alg_kind_t alg_kind, const memory_desc_t *src0_md,
+        const memory_desc_t *src1_md, const memory_desc_t *src2_md,
+        const memory_desc_t *dst_md, const primitive_attr_t *attr) {
+    VCHECK_BINARY(
+            one_of(alg_kind, binary_add, binary_mul, binary_max, binary_min,
+                    binary_div, binary_sub, binary_ge, binary_gt, binary_le,
+                    binary_lt, binary_eq, binary_ne, binary_select, binary_prelu),
+            VERBOSE_BAD_ALGORITHM);
+
+    CHECK(binary_md_check(engine, alg_kind, src0_md, src1_md, src2_md, dst_md));
+
+    auto bod = binary_desc_t();
+    bod.primitive_kind = primitive_kind::binary;
+    bod.alg_kind = alg_kind;
+
+    bod.src_desc[0] = *src0_md;
+    bod.src_desc[1] = *src1_md;
+    if (alg_kind == binary_select) bod.src_desc[2] = *src2_md;
+    bod.dst_desc = *dst_md;

     CHECK(binary_attr_check(bod, engine, attr));

     return primitive_desc_create(primitive_desc_iface, engine,
diff --git a/src/common/binary_pd.hpp b/src/common/binary_pd.hpp
index aa0d2cb23cf..2ec1c31bb5b 100644
--- a/src/common/binary_pd.hpp
+++ b/src/common/binary_pd.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2024 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -37,11 +37,12 @@
 namespace dnnl {
 namespace impl {

+// NOLINTBEGIN(google-default-arguments)
 struct binary_pd_t : public primitive_desc_t {
     static constexpr auto base_pkind = primitive_kind::binary;

-    typedef binary_pd_t base_class;
-    typedef binary_pd_t hint_class;
+    using base_class = binary_pd_t;
+    using hint_class = binary_pd_t;

     const binary_desc_t *desc() const { return &desc_; }
     const op_desc_t *op_desc() const override {
@@ -59,7 +60,8 @@ struct binary_pd_t : public primitive_desc_t {
     }

     arg_usage_t arg_usage(int arg) const override {
-        if (arg == DNNL_ARG_SRC_0 || arg == DNNL_ARG_SRC_1)
+        if (arg == DNNL_ARG_SRC_0 || arg == DNNL_ARG_SRC_1
+                || arg == DNNL_ARG_SRC_2)
             return arg_usage_t::input;

         if (arg == DNNL_ARG_DST) return arg_usage_t::output;
@@ -72,6 +74,7 @@ struct binary_pd_t : public primitive_desc_t {
         switch (arg) {
             case DNNL_ARG_SRC_0: return src_md(0);
             case DNNL_ARG_SRC_1: return src_md(1);
+            case DNNL_ARG_SRC_2: return src_md(2);
             case DNNL_ARG_DST: return dst_md(0, user_input);
             default: return primitive_desc_t::arg_md(arg);
         }
@@ -81,6 +84,7 @@ struct binary_pd_t : public primitive_desc_t {
             int index = 0, bool user_input = false) const override {
         if (index == 0) return user_input ? &desc()->src_desc[0] : &src0_md_;
         if (index == 1) return user_input ? &desc()->src_desc[1] : &src1_md_;
+        if (index == 2) return user_input ? &desc()->src_desc[2] : &src2_md_;
         return &glob_zero_md;
     }
     const memory_desc_t *dst_md(
@@ -89,7 +93,9 @@ struct binary_pd_t : public primitive_desc_t {
         return &glob_zero_md;
     }

-    int n_inputs() const override { return 2 + n_binary_po_inputs(); }
+    int n_inputs() const override {
+        return 2 + n_binary_po_inputs() + static_cast<int>(is_ternary_op());
+    }
     int n_outputs() const override { return 1; }

     const dims_t &broadcast_dims() const { return broadcast_dims_; }
@@ -106,21 +112,29 @@ struct binary_pd_t : public primitive_desc_t {
         return src0_d.consistent_with(src1_d);
     }

+    bool is_ternary_op() const {
+        const memory_desc_wrapper src2_d(src_md(2));
+        return !src2_d.is_zero()
+                && (desc()->alg_kind == alg_kind::binary_select);
+    }
+
 protected:
     binary_desc_t desc_;

     memory_desc_t src0_md_;
     memory_desc_t src1_md_;
+    memory_desc_t src2_md_;
     memory_desc_t dst_md_;

     dims_t broadcast_dims_;

-    binary_pd_t(const binary_desc_t *adesc, const primitive_attr_t *attr,
+    binary_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
             const binary_pd_t *hint_fwd_pd)
         : primitive_desc_t(attr, base_pkind)
-        , desc_(*adesc)
+        , desc_(*op_desc_t::to_desc(adesc))
        , src0_md_(desc_.src_desc[0])
        , src1_md_(desc_.src_desc[1])
+        , src2_md_(desc_.src_desc[2])
        , dst_md_(desc_.dst_desc) {
         init_broadcast_dims();
     }
@@ -134,6 +148,14 @@ struct binary_pd_t : public primitive_desc_t {
             }
         }

+        if (is_ternary_op() && src2_md_.format_kind == format_kind::any) {
+            const memory_desc_wrapper src_d(src_md(0));
+            if (src_d.is_blocking_desc()) {
+                CHECK(memory_desc_init_by_blocking_desc(
+                        src2_md_, src_d.blocking_desc()));
+            }
+        }
+
         if (dst_md_.format_kind == format_kind::any) {
             const memory_desc_wrapper src_d(src_md(0));
             if (src_d.is_blocking_desc()) {
@@ -158,10 +180,13 @@ struct binary_pd_t : public primitive_desc_t {
     bool attr_scales_ok(const std::vector<int> &supported_args
             = {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1, DNNL_ARG_DST}) const {
-        bool ok = attr()->scales_.has_default_values(supported_args);
-        for (int arg : supported_args) {
-            const auto &mask = attr()->scales_.get(arg).mask_;
-            ok = ok && (mask == 0);
+        const auto &scales = attr()->scales_;
+        bool ok = scales.has_default_values(supported_args);
+
+        for (const auto &arg : supported_args) {
+            if (scales.has_default_values(arg)) continue;
+
+            ok = ok && scales.get_mask(arg) == 0;
         }
         return ok;
     }
@@ -176,6 +201,7 @@ struct binary_pd_t : public primitive_desc_t {
                 = (dims_A[d] == dims_B[d] && dims_A[d] != 1) ? 0 : 1;
         }
     };
+// NOLINTEND(google-default-arguments)

 } // namespace impl
 } // namespace dnnl
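The new `binary_select` algorithm makes the binary primitive ternary: `src2` is an s8 condition tensor that picks between `src0` and `src1` elementwise, with each of its dimensions required to equal the dst dimension or 1 (broadcast). A hedged NumPy sketch of the intended semantics, not the library's implementation:

```python
import numpy as np

def binary_select_ref(src0, src1, src2):
    """Reference for dst = src2 ? src0 : src1, with broadcasting."""
    cond = np.broadcast_to(src2 != 0, src0.shape)
    return np.where(cond, src0, np.broadcast_to(src1, src0.shape))

src0 = np.arange(6, dtype=np.float32).reshape(2, 3)
src1 = np.zeros((2, 3), dtype=np.float32)
cond = np.array([[1], [0]], dtype=np.int8)  # broadcast along dim 1
print(binary_select_ref(src0, src1, cond))  # row 0 from src0, row 1 from src1
```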
diff --git a/src/common/broadcast_strategy.cpp b/src/common/broadcast_strategy.cpp
index 6123f917cac..8ab33aa1d54 100644
--- a/src/common/broadcast_strategy.cpp
+++ b/src/common/broadcast_strategy.cpp
@@ -34,6 +34,7 @@ broadcasting_strategy_t get_rhs_arg_broadcasting_strategy(
     static const bcast_set_t all_bcast_strategies {
             broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
             broadcasting_strategy_t::per_oc_spatial,
+            broadcasting_strategy_t::per_oc_d,
             broadcasting_strategy_t::shared_axes,
             broadcasting_strategy_t::per_mb,
             broadcasting_strategy_t::per_mb_spatial,
@@ -164,6 +165,17 @@ bool is_spatial_bcast(const std::bitset mask,
     return spatial_bcast;
 }

+// Check if mask corresponds to per oc_d
+// true if dim == 4 and mask = [1, 0, 0, 1]
+bool is_per_oc_d_bcast(const std::bitset mask,
+        const memory_desc_t &rhs_arg_md, const memory_desc_wrapper &dst_d) {
+    const dims_t &rdims = rhs_arg_md.dims;
+    const dims_t &ddims = dst_d.dims();
+    if (rhs_arg_md.ndims != 4) return false;
+    if (!mask.test(0) || !mask.test(3)) return false;
+    if (rdims[1] != ddims[1] || rdims[2] != ddims[2]) return false;
+    return true;
+}

 bool bcast_strategy_enabled(const bcast_set_t &supported_strategy_set,
         const broadcasting_strategy_t &bcast) {
@@ -254,7 +266,10 @@ broadcasting_strategy_t get_rhs_arg_broadcasting_strategy(
     else if (is_spatial_bcast(mask, dst_d)
             && is_enabled(broadcasting_strategy_t::spatial))
         bcast = broadcasting_strategy_t::spatial;
-    else if (is_enabled(broadcasting_strategy_t::shared_axes))
+    else if (is_per_oc_d_bcast(mask, rhs_arg_md, dst_d)
+            && is_enabled(broadcasting_strategy_t::per_oc_d)) {
+        bcast = broadcasting_strategy_t::per_oc_d;
+    } else if (is_enabled(broadcasting_strategy_t::shared_axes))
         bcast = broadcasting_strategy_t::shared_axes;

     return bcast;
diff --git a/src/common/broadcast_strategy.hpp b/src/common/broadcast_strategy.hpp
index 8b10e205fff..1ec98e52363 100644
--- a/src/common/broadcast_strategy.hpp
+++ b/src/common/broadcast_strategy.hpp
@@ -34,6 +34,7 @@ enum class broadcasting_strategy_t {
     per_oc, // [1, c, 1, 1, 1] // Channel-wise
     per_oc_spatial, // [1, c, 1, 1, 1] specific case for binary kernel nchw format
     per_mb, // [n, 1, 1, 1, 1] // broadcast per batch
+    per_oc_d, // [a, b, c, d] -> [1, b, c, 1]; [n, g, oc/g, sp] --> [1, g, oc/g, 1] specific case for ncsp matmul reduction.
     per_mb_spatial, // [n, 1, d, h, w] // Broadcast only channel
     per_mb_w, // [n, 1, 1, 1, w] // Broadcast per batch and width
     per_w, // [1, 1, 1, 1, w] // Broadcast per width
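In this mask convention a set bit marks a broadcast (size-1) dimension, so `per_oc_d` matches a 4D rhs shaped [1, b, c, 1] against a dst shaped [a, b, c, d]. A small Python sketch mirroring `is_per_oc_d_bcast` (names and inputs are illustrative):

```python
def is_per_oc_d_bcast(rhs_dims, dst_dims):
    """True for a 4D rhs shaped [1, b, c, 1] against dst [a, b, c, d]."""
    if len(rhs_dims) != 4:
        return False
    bcast = [r == 1 for r in rhs_dims]  # set entries mark broadcast dims
    if not (bcast[0] and bcast[3]):
        return False
    return rhs_dims[1] == dst_dims[1] and rhs_dims[2] == dst_dims[2]

assert is_per_oc_d_bcast([1, 8, 16, 1], [4, 8, 16, 32])
assert not is_per_oc_d_bcast([1, 8, 1, 1], [4, 8, 16, 32])
```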
diff --git a/src/common/c_types_map.hpp b/src/common/c_types_map.hpp
index a299936466c..64abe75b796 100644
--- a/src/common/c_types_map.hpp
+++ b/src/common/c_types_map.hpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
-* Copyright 2016-2024 Intel Corporation
+* Copyright 2016-2025 Intel Corporation
+* Copyright 2024-2025 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -90,6 +91,9 @@ const alg_kind_t eltwise_gelu_tanh = dnnl_eltwise_gelu_tanh;
 const alg_kind_t eltwise_gelu_erf = dnnl_eltwise_gelu_erf;
 const alg_kind_t eltwise_hardswish = dnnl_eltwise_hardswish;
 const alg_kind_t eltwise_hardsigmoid = dnnl_eltwise_hardsigmoid;
+const alg_kind_t eltwise_hsigmoid = dnnl_eltwise_hsigmoid;
+const alg_kind_t eltwise_round_half_to_even = dnnl_eltwise_round_half_to_even;
+const alg_kind_t eltwise_round_half_away_from_zero = dnnl_eltwise_round_half_away_from_zero;
 const alg_kind_t eltwise_relu_use_dst_for_bwd
         = dnnl_eltwise_relu_use_dst_for_bwd;
 const alg_kind_t eltwise_tanh_use_dst_for_bwd
@@ -126,6 +130,8 @@ const alg_kind_t binary_le = dnnl_binary_le;
 const alg_kind_t binary_lt = dnnl_binary_lt;
 const alg_kind_t binary_eq = dnnl_binary_eq;
 const alg_kind_t binary_ne = dnnl_binary_ne;
+const alg_kind_t binary_select = dnnl_binary_select;
+const alg_kind_t binary_prelu = dnnl_binary_prelu;
 const alg_kind_t resampling_nearest = dnnl_resampling_nearest;
 const alg_kind_t resampling_linear = dnnl_resampling_linear;
 const alg_kind_t reduction_max = dnnl_reduction_max;
@@ -141,11 +147,23 @@ const alg_kind_t reduction_norm_lp_power_p_sum
         = dnnl_reduction_norm_lp_power_p_sum;
 const alg_kind_t softmax_accurate = dnnl_softmax_accurate;
 const alg_kind_t softmax_log = dnnl_softmax_log;
+const alg_kind_t depthwise_scale_shift = dnnl_depthwise_scale_shift;
+const alg_kind_t depthwise_prelu = dnnl_depthwise_prelu;
+const alg_kind_t quantization_quantize_dequantize = dnnl_quantization_quantize_dequantize;
+const alg_kind_t quantization_quantize = dnnl_quantization_quantize;
+const alg_kind_t binarization_depthwise = dnnl_binarization_depthwise;
+
+// Internal only alg kinds.
+const alg_kind_t internal_only_start = (alg_kind_t)(1 << 12);
+// GPU only via jit_eltwise injector.
+const alg_kind_t eltwise_stochastic_round
+        = (alg_kind_t)(internal_only_start + 1);
 } // namespace alg_kind

 using data_type_t = dnnl_data_type_t;
 namespace data_type {
 const data_type_t undef = dnnl_data_type_undef;
+const data_type_t f4_e3m0 = dnnl_f4_e3m0;
+const data_type_t f4_e2m1 = dnnl_f4_e2m1;
 const data_type_t e8m0 = dnnl_e8m0;
 const data_type_t f8_e5m2 = dnnl_f8_e5m2;
 const data_type_t f8_e4m3 = dnnl_f8_e4m3;
@@ -161,9 +179,11 @@ const data_type_t u4 = dnnl_u4;
 const data_type_t boolean = dnnl_boolean;
 const data_type_t data_type_max = dnnl_data_type_max;

+const data_type_t bin = dnnl_bin;
+const data_type_t nf4 = dnnl_nf4;
+
 // Not exposed through API as all current uses are internal only
 const data_type_t tf32 = static_cast<data_type_t>(1 << 8);
-
 } // namespace data_type

 using fpmath_mode_t = dnnl_fpmath_mode_t;
@@ -202,16 +222,18 @@ using sparse_encoding_t = dnnl_sparse_encoding_t;
 namespace sparse_encoding {
 const sparse_encoding_t undef = dnnl_sparse_encoding_undef;
 const sparse_encoding_t csr = dnnl_csr;
+const sparse_encoding_t coo = dnnl_coo;
 const sparse_encoding_t packed = dnnl_packed;
 } // namespace sparse_encoding
 #else
 // Declare dummy values to avoid guarding internal implementation.
-using sparse_encoding_t = int;
-namespace sparse_encoding {
-const sparse_encoding_t undef = 0;
-const sparse_encoding_t csr = 1;
-const sparse_encoding_t packed = 2;
-} // namespace sparse_encoding
+// using sparse_encoding_t = int;
+// namespace sparse_encoding {
+// const sparse_encoding_t undef = 0;
+// const sparse_encoding_t csr = 1;
+// const sparse_encoding_t packed = 2;
+// const sparse_encoding_t coo = 3;
+// } // namespace sparse_encoding
 #endif

 using format_kind_t = dnnl_format_kind_t;
@@ -223,13 +245,15 @@ const format_kind_t opaque = dnnl_format_kind_opaque;
 #ifdef DNNL_EXPERIMENTAL_SPARSE
 const format_kind_t sparse = dnnl_format_kind_sparse;
 #else
-const format_kind_t sparse = static_cast<format_kind_t>(4);
+// const format_kind_t sparse = static_cast<format_kind_t>(4);
 #endif

 // Internal only format kinds.
 const format_kind_t internal_only_start = (format_kind_t)(1 << 8);
 const format_kind_t wino = internal_only_start;
 const format_kind_t rnn_packed = (format_kind_t)(internal_only_start + 1);
+const format_kind_t cublaslt_blocked = (format_kind_t)(internal_only_start + 2);
+const format_kind_t sparse = dnnl_format_sparse;
 } // namespace format_kind

 #ifdef DNNL_EXPERIMENTAL_PROFILING
@@ -248,6 +272,8 @@ const profiling_data_kind_t internal_only_start
         = (profiling_data_kind_t)(1 << 8);
 const profiling_data_kind_t cycles
         = (profiling_data_kind_t)(internal_only_start + 1);
+const profiling_data_kind_t time_per_kernel
+        = (profiling_data_kind_t)(internal_only_start + 2);
 } // namespace profiling_data_kind

 using format_tag_t = dnnl_format_tag_t;
@@ -358,6 +384,9 @@ const format_tag_t aCB16b16c = dnnl_aCB16b16c;
 const format_tag_t aCB16b32c = dnnl_aCB16b32c;
 const format_tag_t aCB16b48c = dnnl_aCB16b48c;
 const format_tag_t aCB16b64c = dnnl_aCB16b64c;
+const format_tag_t BA24b8a = dnnl_BA24b8a;
+const format_tag_t aCB24c8b = dnnl_aCB24c8b;
+const format_tag_t abDC24d8c = dnnl_abDC24d8c;
 const format_tag_t aCB16b16c2b = dnnl_aCB16b16c2b;
 const format_tag_t aCB16b32c2b = dnnl_aCB16b32c2b;
 const format_tag_t aCB16b48c2b = dnnl_aCB16b48c2b;
@@ -369,6 +398,7 @@ const format_tag_t aCB16b64c4b = dnnl_aCB16b64c4b;
 const format_tag_t Ab4a = dnnl_Ab4a;
 const format_tag_t Ab8a = dnnl_Ab8a;
+const format_tag_t Ab32a = dnnl_Ab32a;
 const format_tag_t Abc16a = dnnl_Abc16a;
 const format_tag_t ABc16a16b = dnnl_ABc16a16b;
 const format_tag_t ABc4a2b = dnnl_ABc4a2b;
@@ -471,6 +501,7 @@ const format_tag_t aBCd4b4c = dnnl_aBCd4b4c;
 const format_tag_t ABcd8a16b2a = dnnl_ABcd8a16b2a;
 const format_tag_t BAcd8a16b2a = dnnl_BAcd8a16b2a;
 const format_tag_t ABcd8a8b = dnnl_ABcd8a8b;
+const format_tag_t ABcd8a32b = dnnl_ABcd8a32b;
 const format_tag_t ABcd8a4b = dnnl_ABcd8a4b;
 const format_tag_t ABcd8a2b = dnnl_ABcd8a2b;
 const format_tag_t aBcd8b = dnnl_aBcd8b;
@@ -615,6 +646,7 @@ const format_tag_t aBdefc16b = dnnl_aBdefc16b;
 const format_tag_t aBdefC16b2c = dnnl_aBdefC16b2c;
 const format_tag_t aBdefC16b4c = dnnl_aBdefC16b4c;
 const format_tag_t aCBdef16c16b = dnnl_aCBdef16c16b;
+const format_tag_t aCBdef8b8c = dnnl_aCBdef8b8c;
 const format_tag_t aCBdef16b16c = dnnl_aCBdef16b16c;
 const format_tag_t aBdefc4b = dnnl_aBdefc4b;
 const format_tag_t aBdefc8b = dnnl_aBdefc8b;
@@ -629,8 +661,10 @@ const format_tag_t Acb4a = dnnl_Acb4a;
 const format_tag_t Acb8a = dnnl_Acb8a;
 const format_tag_t AcB8a2b = dnnl_AcB8a2b;
 const format_tag_t AcB8a4b = dnnl_AcB8a4b;
+const format_tag_t aCBd8b8c = dnnl_aCBd8b8c;
 const format_tag_t aCBd16b16c = dnnl_aCBd16b16c;
 const format_tag_t aCBd16c16b = dnnl_aCBd16c16b;
+const format_tag_t aCBde8b8c = dnnl_aCBde8b8c;
 const format_tag_t aCBde16b16c = dnnl_aCBde16b16c;
 const format_tag_t aCBde16c16b = dnnl_aCBde16c16b;
 const format_tag_t Acdb16a = dnnl_Acdb16a;
@@ -649,7 +683,9 @@ const format_tag_t AcdeB8a2b = dnnl_AcdeB8a2b;
 const format_tag_t AcdeB8a4b = dnnl_AcdeB8a4b;
 const format_tag_t Acedb16a = dnnl_Acedb16a;
 const format_tag_t Adcb16a = dnnl_Adcb16a;
+const format_tag_t BAc8a8b = dnnl_BAc8a8b;
 const format_tag_t BAc16a16b = dnnl_BAc16a16b;
+const format_tag_t BAcd8a8b = dnnl_BAcd8a8b;
 const format_tag_t BAcd16a16b = dnnl_BAcd16a16b;
 const format_tag_t ABc32a16b = dnnl_ABc32a16b;
 const format_tag_t ABcd32a16b = dnnl_ABcd32a16b;
@@ -658,6 +694,7 @@ const format_tag_t ABc40a16b = dnnl_ABc40a16b;
 const format_tag_t ABcd40a16b = dnnl_ABcd40a16b;
 const format_tag_t ABcde40a16b = dnnl_ABcde40a16b;
 const format_tag_t ABc32a32b = dnnl_ABc32a32b;
+const format_tag_t BAcde8a8b = dnnl_BAcde8a8b;
 const format_tag_t BAcde16a16b = dnnl_BAcde16a16b;
 const format_tag_t ABcd32a32b = dnnl_ABcd32a32b;
 const format_tag_t ABcde32a32b = dnnl_ABcde32a32b;
@@ -666,6 +703,8 @@ const format_tag_t ABcd40a32b = dnnl_ABcd40a32b;
 const format_tag_t ABcde40a32b = dnnl_ABcde40a32b;
 const format_tag_t BAcde16b16a = dnnl_BAcde16b16a;
 const format_tag_t aBdec32b = dnnl_aBdec32b;
+const format_tag_t Abcdef4a = dnnl_Abcdef4a;
+const format_tag_t Abcdef8a = dnnl_Abcdef8a;
 const format_tag_t Abcdef16a = dnnl_Abcdef16a;
 const format_tag_t Abcdef32a = dnnl_Abcdef32a;
 const format_tag_t Acdb32a = dnnl_Acdb32a;
@@ -689,6 +728,7 @@ const format_tag_t AB32a32b8a2b = dnnl_AB32a32b8a2b;
 const format_tag_t AB8a2b = dnnl_AB8a2b;
 const format_tag_t abDc16d = dnnl_abDc16d;
 const format_tag_t abDc32d = dnnl_abDc32d;
+const format_tag_t abDC16d4c = dnnl_abDC16d4c;
 const format_tag_t abDC32d4c = dnnl_abDC32d4c;
 const format_tag_t abCd4c = dnnl_abCd4c;
 const format_tag_t abCde4c = dnnl_abCde4c;
@@ -698,6 +738,7 @@ const format_tag_t abCde32c = dnnl_abCde32c;
 const format_tag_t abCdef32c = dnnl_abCdef32c;
 const format_tag_t abdEc16e = dnnl_abdEc16e;
 const format_tag_t abdEc32e = dnnl_abdEc32e;
+const format_tag_t abdEC16e4c = dnnl_abdEC16e4c;
 const format_tag_t abdEC32e2c = dnnl_abdEC32e2c;
 const format_tag_t abdEC32e4c = dnnl_abdEC32e4c;
 const format_tag_t abdEC64e2c = dnnl_abdEC64e2c;
@@ -1163,7 +1204,10 @@ const format_tag_t IOhw16i16o = dnnl_IOhw16i16o;
 const format_tag_t Ohwi32o = dnnl_Ohwi32o;
 const format_tag_t gIOhw16i16o = dnnl_gIOhw16i16o;
 const format_tag_t gOhwi32o = dnnl_gOhwi32o;
+const format_tag_t Goidhw4g = dnnl_Goidhw4g;
+const format_tag_t Goidhw8g = dnnl_Goidhw8g;
 const format_tag_t Goidhw16g = dnnl_Goidhw16g;
+const format_tag_t IOw8o8i = dnnl_IOw8o8i;
 const format_tag_t IOw16o16i = dnnl_IOw16o16i;
 const format_tag_t IOw16i16o = dnnl_IOw16i16o;
 const format_tag_t gIOw16i16o = dnnl_gIOw16i16o;
@@ -1219,7 +1263,9 @@ const format_tag_t Owi4o = dnnl_Owi4o;
 const format_tag_t Owi8o = dnnl_Owi8o;
 const format_tag_t OwI8o2i = dnnl_OwI8o2i;
 const format_tag_t OwI8o4i = dnnl_OwI8o4i;
+const format_tag_t IOdhw8o8i = dnnl_IOdhw8o8i;
 const format_tag_t IOdhw16o16i = dnnl_IOdhw16o16i;
+const format_tag_t IOhw8o8i = dnnl_IOhw8o8i;
 const format_tag_t IOhw16o16i = dnnl_IOhw16o16i;
 const format_tag_t Ohwi16o = dnnl_Ohwi16o;
 const format_tag_t OhwI16o2i = dnnl_OhwI16o2i;
@@ -1272,6 +1318,8 @@ const format_tag_t OhwI8i8o = dnnl_OhwI8i8o;
 const format_tag_t OIhw8o16i2o = dnnl_OIhw8o16i2o;
 const format_tag_t IOhw8o16i2o = dnnl_IOhw8o16i2o;
 const format_tag_t OIhw8o8i = dnnl_OIhw8o8i;
+const format_tag_t OIhw8o32i = dnnl_OIhw8o32i;
+const format_tag_t OIhw16o32i = dnnl_OIhw16o32i;
 const format_tag_t OIhw8o4i = dnnl_OIhw8o4i;
 const format_tag_t Owhi16o = dnnl_Owhi16o;
 const format_tag_t Odwhi16o = dnnl_Odwhi16o;
@@ -1327,6 +1375,7 @@ const format_tag_t OIdhw8i8o = dnnl_OIdhw8i8o;
 const format_tag_t OdhwI8i8o = dnnl_OdhwI8i8o;
 const format_tag_t OIdhw8o8i = dnnl_OIdhw8o8i;
 const format_tag_t OIdhw8o4i = dnnl_OIdhw8o4i;
+const format_tag_t gIOw8o8i = dnnl_gIOw8o8i;
 const format_tag_t gIOw16o16i = dnnl_gIOw16o16i;
 const format_tag_t Goiw16g = dnnl_Goiw16g;
 const format_tag_t Goiw8g = dnnl_Goiw8g;
@@ -1355,7 +1404,9 @@ const format_tag_t gOwi4o = dnnl_gOwi4o;
 const format_tag_t gOwi8o = dnnl_gOwi8o;
 const format_tag_t gOwI8o2i = dnnl_gOwI8o2i;
 const format_tag_t gOwI8o4i = dnnl_gOwI8o4i;
+const format_tag_t gIOdhw8o8i = dnnl_gIOdhw8o8i;
 const format_tag_t gIOdhw16o16i = dnnl_gIOdhw16o16i;
+const format_tag_t gIOhw8o8i = dnnl_gIOhw8o8i;
 const format_tag_t gIOhw16o16i = dnnl_gIOhw16o16i;
 const format_tag_t gOhwi16o = dnnl_gOhwi16o;
 const format_tag_t gOhwI16o2i = dnnl_gOhwI16o2i;
@@ -1454,10 +1505,12 @@ const format_tag_t gOIhw4o8i2o = dnnl_gOIhw4o8i2o;
 const format_tag_t gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o;
 const format_tag_t ldOi16o = dnnl_ldOi16o;
 const format_tag_t ldOi32o = dnnl_ldOi32o;
+const format_tag_t ldOI16o4i = dnnl_ldOI16o4i;
 const format_tag_t ldOI32o4i = dnnl_ldOI32o4i;
 const format_tag_t ldIo32i = dnnl_ldIo32i;
 const format_tag_t ldgOi16o = dnnl_ldgOi16o;
 const format_tag_t ldgOi32o = dnnl_ldgOi32o;
+const format_tag_t ldgOI16o4i = dnnl_ldgOI16o4i;
 const format_tag_t ldgOI32o2i = dnnl_ldgOI32o2i;
 const format_tag_t ldgOI32o4i = dnnl_ldgOI32o4i;
 const format_tag_t ldgOI64o2i = dnnl_ldgOI64o2i;
@@ -1894,6 +1947,15 @@ const rnn_flags_t diff_weights_overwrite
         = dnnl_rnn_flags_diff_weights_overwrite;
 } // namespace rnn_flags

+using sparse_encoding_t = dnnl_sparse_encoding_t;
+namespace sparse_encoding {
+const sparse_encoding_t undef = dnnl_sparse_encoding_undef;
+const sparse_encoding_t any = dnnl_sparse_encoding_any;
+const sparse_encoding_t packed = dnnl_sparse_encoding_packed;
+const sparse_encoding_t csr = dnnl_sparse_encoding_csr;
+const sparse_encoding_t coo = dnnl_sparse_encoding_coo;
+} // namespace sparse_encoding
+
 using engine_kind_t = dnnl_engine_kind_t;
 namespace engine_kind {
 const engine_kind_t any_engine = dnnl_any_engine;
@@ -1906,6 +1968,7 @@ enum runtime_kind_t {
     dnnl_runtime_seq,
     dnnl_runtime_omp,
     dnnl_runtime_tbb,
+    dnnl_runtime_tbb_auto,
     dnnl_runtime_threadpool,
     dnnl_runtime_ocl,
     dnnl_runtime_sycl,
@@ -1916,6 +1979,7 @@ const runtime_kind_t none = dnnl_runtime_none;
 const runtime_kind_t seq = dnnl_runtime_seq;
 const runtime_kind_t omp = dnnl_runtime_omp;
 const runtime_kind_t tbb = dnnl_runtime_tbb;
+const runtime_kind_t tbb_auto = dnnl_runtime_tbb_auto;
 const runtime_kind_t threadpool = dnnl_runtime_threadpool;
 const runtime_kind_t ocl = dnnl_runtime_ocl;
 const runtime_kind_t sycl = dnnl_runtime_sycl;
@@ -1945,6 +2009,9 @@ const primitive_kind_t reduction = dnnl_reduction;
 const primitive_kind_t softmax = dnnl_softmax;
 const primitive_kind_t layer_normalization = dnnl_layer_normalization;
 const primitive_kind_t group_normalization = dnnl_group_normalization;
+const primitive_kind_t depthwise = dnnl_depthwise;
+const primitive_kind_t quantization = dnnl_quantization;
+const primitive_kind_t binarization = dnnl_binarization;

 // Internal only primitive kinds.
 const primitive_kind_t internal_only_start = (primitive_kind_t)(1 << 12);
@@ -2025,17 +2092,26 @@ const query_t sparse_encoding = dnnl_query_sparse_encoding;
 const query_t nnz_s64 = dnnl_query_nnz_s64;
 const query_t num_handles_s32 = dnnl_query_num_handles_s32;
 #else
-const query_t sparse_encoding = static_cast<query_t>(266);
-const query_t nnz_s64 = static_cast<query_t>(267);
-const query_t num_handles_s32 = static_cast<query_t>(268);
+// const query_t sparse_encoding = static_cast<query_t>(266);
+// const query_t nnz_s64 = static_cast<query_t>(267);
+// const query_t num_handles_s32 = static_cast<query_t>(268);
 #endif

 // Internal only query kinds.
 const query_t internal_only_start = (query_t)(1 << 12);
 const query_t zero_pad_d = internal_only_start;
 const query_t preferred_gpu_threads_per_eu = (query_t)(internal_only_start + 1);
+const query_t sparse_encoding = dnnl_query_sparse_encoding;
 } // namespace query

+// There are no external values to map to because this is an internal feature
+// for now.
+using matmul_reduce_kind_t = int;
+namespace matmul_reduce_kind {
+const matmul_reduce_kind_t undef = 0;
+const matmul_reduce_kind_t src = 1;
+} // namespace matmul_reduce_kind
+
 using rnn_direction_t = dnnl_rnn_direction_t;

 using engine_t = dnnl_engine;
diff --git a/src/common/cache_blob_id.cpp b/src/common/cache_blob_id.cpp
index aedcc393bfd..b6b9de9c553 100644
--- a/src/common/cache_blob_id.cpp
+++ b/src/common/cache_blob_id.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2021-2023 Intel Corporation
+* Copyright 2021-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -19,8 +19,8 @@
 #include "common/dnnl_thread.hpp"
 #include "common/engine.hpp"
 #include "common/primitive_desc.hpp"
+#include "common/primitive_serialization.hpp"
 #include "common/serialization.hpp"
-#include "common/serialization_stream.hpp"

 namespace dnnl {
 namespace impl {
@@ -38,45 +38,43 @@ const std::vector<uint8_t> &cache_blob_id_t::get(
         return sstream_.get_data();
     }

-    if (pd->op_desc()->kind == primitive_kind::zero_pad) {
-        return sstream_.get_data();
-    }
+    if (pd->kind() == primitive_kind::zero_pad) { return sstream_.get_data(); }

     assert(engine->kind() == engine_kind::gpu
             && engine->runtime_kind() == runtime_kind::ocl);
     const auto init_id = [&]() {
-        serialization::serialize_desc(sstream_, pd->op_desc());
-        serialization::serialize_attr(sstream_, *pd->attr());
+        serialize_desc(sstream_, pd->op_desc());
+        serialize(sstream_, *pd->attr());

         const int nthr = engine->kind() == engine_kind::gpu
                 ? 0
                 : dnnl_get_max_threads();
-        sstream_.write(&nthr);
+        sstream_.append(nthr);

         for (const auto &md : pd->hint_mds(false /* is_hint */)) {
-            serialization::serialize_md(sstream_, md);
+            serialize(sstream_, md);
         }

-        sstream_.write(&engine_kind);
+        sstream_.append(engine_kind);
         // TODO: blob object can probably be re-used for different runtimes
         // if the engine kind is the same. Check this assumption when extending
         // this API to DPCPP runtime.
-        sstream_.write(&runtime_kind);
+        sstream_.append(runtime_kind);

         engine->serialize_device(sstream_);

         auto pd_iterator_offset = pd->pd_iterator_offset();
-        sstream_.write(&pd_iterator_offset);
+        sstream_.append(pd_iterator_offset);

         auto pd_skip_idx = pd->skip_idx();
-        sstream_.write(&pd_skip_idx);
+        sstream_.append(pd_skip_idx);

         auto version = dnnl_version();
-        sstream_.write(&version->major);
-        sstream_.write(&version->minor);
-        sstream_.write(&version->patch);
+        sstream_.append(version->major);
+        sstream_.append(version->minor);
+        sstream_.append(version->patch);

-        sstream_.write(version->hash, std::strlen(version->hash));
+        sstream_.append_array(std::strlen(version->hash), version->hash);

         is_initialized_ = true;
     };
diff --git a/src/common/cache_blob_id.hpp b/src/common/cache_blob_id.hpp
index 53c0f002709..46eadf217da 100644
--- a/src/common/cache_blob_id.hpp
+++ b/src/common/cache_blob_id.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2021 Intel Corporation
+* Copyright 2021-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@
 #include
 #include

-#include "common/serialization_stream.hpp"
+#include "common/serialization.hpp"

 namespace dnnl {
 namespace impl {
diff --git a/src/common/compiler_workarounds.hpp b/src/common/compiler_workarounds.hpp
index 17beeb84b72..bedbd8f82d4 100644
--- a/src/common/compiler_workarounds.hpp
+++ b/src/common/compiler_workarounds.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -17,11 +17,6 @@
 #ifndef COMPILER_WORKAROUNDS_HPP
 #define COMPILER_WORKAROUNDS_HPP

-#if (defined __GNUC__) && (!defined(__INTEL_COMPILER)) \
-        && (!defined(__INTEL_LLVM_COMPILER)) && (!defined(__clang__major__))
-#define NEED_GCC_WA_CHECK 1
-#endif
-
 // Workaround 01: clang.
 //
 // Clang has an issue [1] with `#pragma omp simd` that might lead to segfault.
@@ -32,7 +27,7 @@
 // vectorization for clang altogether for now.
 //
 // [1] https://bugs.llvm.org/show_bug.cgi?id=48104
-#if (defined __clang_major__) && (__clang_major__ >= 6)
+#if (defined __clang_major__) && (__clang_major__ < 13)
 #define CLANG_WA_01_SAFE_TO_USE_OMP_SIMD 0
 #else
 #define CLANG_WA_01_SAFE_TO_USE_OMP_SIMD 1
@@ -40,48 +35,15 @@

 // Workaround 02: clang.
 //
-// Clang 6+ generates incorrect code with OMP_SIMD in some particular cases.
+// Clang generates incorrect code with OMP_SIMD in some particular cases.
 // Unlike CLANG_WA_01_SAFE_TO_USE_OMP_SIMD, the issue happens even with -O3.
-#if (defined __clang_major__) && (__clang_major__ >= 6)
+#if (defined __clang_major__) && (__clang_major__ < 13)
 #define CLANG_WA_02_SAFE_TO_USE_OMP_SIMD 0
 #else
 #define CLANG_WA_02_SAFE_TO_USE_OMP_SIMD 1
 #endif

-// Workaround 03: GCC
-//
-// For very large functions with too much control flow (i.e. if, switch, goto
-// statements), GCC 7 may struggle to perform optimizations based on tree
-// dominator (i.e. -ftree-dominator-opts, which is enabled with O1), thereby
-// producing an internal compiler error (ICE). Specifically, it seems that the
-// jump threading optimization is the culprit, which cannot be disabled on its
-// own. There is no reliable way to reproduce the ICE, therefore it is not clear
-// which __GCC_MINOR__ version fixes issue.
-#if (defined NEED_GCC_WA_CHECK) && (__GNUC__ == 7)
-#define GCC_WA_NO_TREE_DOMINATOR_OPTS 1
-#else
-#define GCC_WA_NO_TREE_DOMINATOR_OPTS 0
-#endif
-
-// Workaround 04: GCC
-//
-// GCC 10 & 11 && 12 (at least versiona 10.1, 10.3 & 11.1, 12.2) report false positives
-// in xbyak when -Warray-bounds build setting is on
-#if (defined NEED_GCC_WA_CHECK) && (__GNUC__ >= 10)
-#pragma GCC diagnostic ignored "-Warray-bounds"
-#endif
-
-// Workaround 05: GCC
-//
-// NOTE: inside lambda, type cast variables captured by reference using
-// either c-like "(type)var" or functional "type(var)" notation in order
-// to avoid gcc7 bug with c++14 standard
-// (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=83204).
-#if (defined NEED_GCC_WA_CHECK) && (__GNUC__ <= 7)
-#define GCC_WA_LAMBDA_C_CAST
-#endif
-
-// Workaround 05: c++17 vs c++20
+// Workaround 03: MSVC c++17 vs c++20
 //
 // C++17/20 are contradictory wrt capturing this and using default '=' capture.
 // - C++17 and before have to return a warning for the [=, this] capture as
diff --git a/src/common/concat.cpp b/src/common/concat.cpp
index d686df416f8..df4c65bc00b 100644
--- a/src/common/concat.cpp
+++ b/src/common/concat.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2023 Intel Corporation
+* Copyright 2018-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -51,13 +51,25 @@ status_t concat_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
         attr = &default_attr();
     else {
         using smask_t = primitive_attr_t::skip_mask_t;
-        VCHECK_CONCAT_UNIMPL(attr->has_default_values(smask_t::scales_runtime),
+        VCHECK_CONCAT_UNIMPL(attr->has_default_values(smask_t::scales),
                 VERBOSE_UNSUPPORTED_ATTR);
         const auto &scales = attr->scales_;
-        if (!scales.has_default_values())
-            for (const auto &s : scales.scales_)
-                VCHECK_CONCAT_UNIMPL(
-                        s.second.mask_ == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
+        if (!scales.has_default_values()) {
+            std::vector<int> supported_args(n);
+            for (int i = 0; i < n; i++) {
+                supported_args[i] = DNNL_ARG_MULTIPLE_SRC + i;
+            }
+            VCHECK_CONCAT_UNIMPL(
+                    attr->scales_.has_default_values(supported_args),
+                    VERBOSE_UNSUPPORTED_SCALES_CFG);
+
+            for (int arg : supported_args) {
+                if (scales.has_default_values(arg)) continue;
+
+                int mask = scales.get_mask(arg);
+                VCHECK_CONCAT_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
+            }
+        }
     }

     const int ndims = src_mds[0]->ndims;
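Concat accepts one scale per source tensor, so the supported argument list is built dynamically as `DNNL_ARG_MULTIPLE_SRC + i`, and any non-default entry must use mask 0 (a single scalar scale per tensor). A hedged Python sketch of that validation rule; the constant value and helper names are illustrative:

```python
DNNL_ARG_MULTIPLE_SRC = 1024  # illustrative stand-in for the C constant

def concat_scales_ok(scale_masks, n_srcs):
    """scale_masks: {arg: mask} for the scales the user actually set."""
    supported = {DNNL_ARG_MULTIPLE_SRC + i for i in range(n_srcs)}
    if not set(scale_masks) <= supported:
        return False  # scales attached to an unsupported argument
    # each non-default scale must be a single scalar: mask == 0
    return all(mask == 0 for mask in scale_masks.values())

assert concat_scales_ok({DNNL_ARG_MULTIPLE_SRC + 1: 0}, n_srcs=3)
assert not concat_scales_ok({DNNL_ARG_MULTIPLE_SRC: 2}, n_srcs=3)
```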
@@ -37,13 +37,14 @@ namespace dnnl { namespace impl { +// NOLINTBEGIN(google-default-arguments) struct concat_pd_t : public primitive_desc_t { const concat_desc_t *desc() const { return &desc_; } const op_desc_t *op_desc() const override { return reinterpret_cast(this->desc()); } - ~concat_pd_t() = default; + ~concat_pd_t() override = default; arg_usage_t arg_usage(int arg) const override { if (arg >= DNNL_ARG_MULTIPLE_SRC @@ -95,7 +96,6 @@ struct concat_pd_t : public primitive_desc_t { * use this auxiliary array iff init() returned success */ std::vector src_image_mds_; -protected: concat_desc_t desc_; concat_pd_t(const primitive_attr_t *attr, const memory_desc_t *dst_md, @@ -112,14 +112,14 @@ struct concat_pd_t : public primitive_desc_t { init_desc(); } - concat_pd_t(const concat_pd_t &other) : primitive_desc_t(other) { - n_ = other.n_; - concat_dim_ = other.concat_dim_; - dst_md_ = other.dst_md_; - original_dst_ = other.original_dst_; - src_mds_ = other.src_mds_; - src_image_mds_ = other.src_image_mds_; - + concat_pd_t(const concat_pd_t &other) + : primitive_desc_t(other) + , n_(other.n_) + , concat_dim_(other.concat_dim_) + , dst_md_(other.dst_md_) + , original_dst_(other.original_dst_) + , src_mds_(other.src_mds_) + , src_image_mds_(other.src_image_mds_) { init_desc(); } @@ -266,6 +266,7 @@ struct concat_pd_t : public primitive_desc_t { desc_.src_mds.push_back(&md); } }; +// NOLINTEND(google-default-arguments) #define DECLARE_CONCAT_PD_t(impl_name, ...) \ static status_t create(concat_pd_t **concat_pd, \ @@ -284,6 +285,7 @@ struct concat_pd_t : public primitive_desc_t { &primitive, \ dnnl::impl::engine_t *engine, const cache_blob_t &cache_blob) \ const override { \ + DNNL_PRIMITIVE_CREATE(pd_t) \ return primitive_t::create_primitive_common<__VA_ARGS__, pd_t>( \ primitive, this, engine, false, cache_blob); \ } \ diff --git a/src/common/convolution.cpp b/src/common/convolution.cpp index 9300043adcc..93c55bc64b2 100644 --- a/src/common/convolution.cpp +++ b/src/common/convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
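DNNL_PRIMITIVE_CREATE(pd_t), spliced into the DECLARE_CONCAT_PD_t creation path above, is defined by the dnnl_sel_build.hpp header added later in this patch; in analyzer builds it opens an ITT scoped task named after the pd type. A rough, hypothetical stand-in for what that expansion does (the real macro uses OV_ITT_SCOPED_TASK; trace_task_t here is invented for illustration):

    #include <cstdio>
    #include <string>
    #include <typeinfo>

    // Hypothetical RAII scope: reports when a given pd type enters and
    // leaves its create path, so an analyzer build can record which
    // primitive implementations are actually instantiated.
    struct trace_task_t {
        std::string name;
        explicit trace_task_t(std::string n) : name(std::move(n)) {
            std::printf("begin %s\n", name.c_str());
        }
        ~trace_task_t() { std::printf("end %s\n", name.c_str()); }
    };

    #define TRACE_PRIMITIVE_CREATE(pd_t) \
        trace_task_t trace_task_ { \
                std::string("CREATE$CPUEngine$") + typeid(pd_t).name()};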
@@ -116,6 +116,18 @@ status_t conv_desc_init(convolution_desc_t *conv_desc, prop_kind_t prop_kind, VERBOSE_INCONSISTENT_DIM, "src", 1, "weights", with_groups + 1); VCHECK_CONV(dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0], VERBOSE_INCONSISTENT_DIM, "dst", 1, "weights", with_groups + 0); + // s4/u4/f4 weights requires channels to be multiple of 2 to be byte aligned + VCHECK_CONV(IMPLICATION(utils::one_of(weights_desc->data_type, + data_type::s4, data_type::u4, + data_type::f4_e2m1, data_type::f4_e3m0), + weights_desc->dims[with_groups + 1] % 2 == 0), + VERBOSE_INCONSISTENT_DIM, "weights", with_groups + 1); + // s4/u4/f4 src requires channels to be multiple of 2 to be byte aligned + VCHECK_CONV(IMPLICATION(utils::one_of(src_desc->data_type, data_type::s4, + data_type::u4, data_type::f4_e2m1, + data_type::f4_e3m0), + src_desc->dims[1] % 2 == 0), + VERBOSE_INCONSISTENT_DIM, "src", 1); int sp_dims = src_desc->ndims - 2; utils::array_copy(cd.strides, strides, sp_dims); @@ -136,7 +148,8 @@ status_t conv_desc_init(convolution_desc_t *conv_desc, prop_kind_t prop_kind, dim_t dst = dst_desc->dims[i]; dim_t ker_range = 1 + (ker - 1) * (dil + 1); VCHECK_CONV(str > 0, VERBOSE_BAD_DIM, "strides", i - 2); - VCHECK_CONV(dil >= 0 && pad_l >= 0 && pad_r + str > 0, + //VCHECK_CONV(dil >= 0 && pad_l >= 0 && pad_r + str > 0, // TODO: [dmitrygo] Commented as WA to support dw conv fusing + VCHECK_CONV(dil >= 0 && pad_l >= 0, VERBOSE_INCONSISTENT_PRB); VCHECK_CONV((src - ker_range + pad_l + pad_r) / str + 1 == dst, VERBOSE_INCONSISTENT_PRB); @@ -159,18 +172,26 @@ status_t conv_attr_check(const convolution_desc_t &desc, const engine_t *engine, const data_type_t src_dt = desc.src_desc.data_type; const data_type_t dst_dt = desc.dst_desc.data_type; - auto fwd_attr_mask - = smask_t::post_ops | smask_t::sum_dt | smask_t::fpmath_mode; - - bool is_int8 = utils::one_of(src_dt, data_type::s8, data_type::u8); - if (engine->kind() == engine_kind::gpu) - is_int8 = is_int8 - || utils::one_of(dst_dt, data_type::s8, data_type::u8, - data_type::s32); - if (is_int8) - fwd_attr_mask |= smask_t::scales_runtime - | smask_t::zero_points_runtime - | smask_t::zero_points_runtime_data_type; + auto fwd_attr_mask = smask_t::post_ops | smask_t::sum_dt + | smask_t::fpmath_mode | smask_t::rounding_mode; + const bool is_gpu = engine->kind() == engine_kind::gpu; + + const bool is_int8 = utils::one_of(src_dt, data_type::s8, data_type::u8) + || (is_gpu + && utils::one_of(dst_dt, data_type::s8, data_type::u8, + data_type::s32)); + const bool is_fp8 = is_gpu + && (utils::one_of( + src_dt, data_type::f8_e5m2, data_type::f8_e4m3) + || utils::one_of(dst_dt, data_type::f8_e5m2, + data_type::f8_e4m3)); + const bool enable_quantization = is_int8 || is_fp8; + if (enable_quantization) + fwd_attr_mask |= smask_t::zero_points_data_type + | smask_t::scales_data_type + | smask_t::input_zero_points + | smask_t::output_compensations + | smask_t::weights_zero_points; VCHECK_CONV_UNIMPL(attr->has_default_values(fwd_attr_mask, dst_dt), VERBOSE_UNSUPPORTED_ATTR); @@ -178,27 +199,37 @@ status_t conv_attr_check(const convolution_desc_t &desc, const engine_t *engine, // Check scales if (!attr->scales_.has_default_values()) { const auto &sc = attr->scales_; - const int mask_src = sc.get(DNNL_ARG_SRC).mask_; - const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_; - const int mask_dst = sc.get(DNNL_ARG_DST).mask_; const bool with_groups = desc.src_desc.ndims != desc.weights_desc.ndims; - VCHECK_CONV_UNIMPL(utils::everyone_is(0, mask_src, mask_dst) - && 
utils::one_of(mask_wei, 0, with_groups ? 3 : 1), + VCHECK_CONV_UNIMPL(IMPLICATION(!sc.has_default_values(DNNL_ARG_SRC), + sc.get_mask(DNNL_ARG_SRC) == 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VCHECK_CONV_UNIMPL( + IMPLICATION(!sc.has_default_values(DNNL_ARG_WEIGHTS), + utils::one_of(sc.get_mask(DNNL_ARG_WEIGHTS), 0, + with_groups ? 3 : 1)), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VCHECK_CONV_UNIMPL( + IMPLICATION(!sc.has_default_values(DNNL_ARG_DST), + utils::one_of(sc.get_mask(DNNL_ARG_DST), 0, 2)), VERBOSE_UNSUPPORTED_SCALES_CFG); } // Check zero points if (!attr->zero_points_.has_default_values()) { const auto &zp = attr->zero_points_; - int mask_src = 0, mask_wei = 0, mask_dst = 0; - zp.get(DNNL_ARG_SRC, &mask_src); - zp.get(DNNL_ARG_WEIGHTS, &mask_wei); - zp.get(DNNL_ARG_DST, &mask_dst); - - VCHECK_CONV_UNIMPL((mask_src == 0 || mask_src == 1 << 1) - && (mask_wei == 0) - && (mask_dst == 0 || mask_dst == 1 << 1), + + VCHECK_CONV_UNIMPL(IMPLICATION(!zp.has_default_values(DNNL_ARG_SRC), + utils::one_of(zp.get_mask(DNNL_ARG_SRC), + 0, 1 << 1)), + VERBOSE_UNSUPPORTED_ZP_CFG); + VCHECK_CONV_UNIMPL( + IMPLICATION(!zp.has_default_values(DNNL_ARG_WEIGHTS), + zp.get_mask(DNNL_ARG_WEIGHTS) == 0), + VERBOSE_UNSUPPORTED_ZP_CFG); + VCHECK_CONV_UNIMPL(IMPLICATION(!zp.has_default_values(DNNL_ARG_DST), + utils::one_of(zp.get_mask(DNNL_ARG_DST), + 0, 1 << 1)), VERBOSE_UNSUPPORTED_ZP_CFG); } @@ -207,17 +238,20 @@ status_t conv_attr_check(const convolution_desc_t &desc, const engine_t *engine, const auto &po = attr->post_ops_; using namespace primitive_kind; VCHECK_CONV_UNIMPL(po.has_default_values({binary, eltwise, prelu, - sum, convolution}), + sum, convolution, depthwise, quantization}), VERBOSE_UNSUPPORTED_POSTOP); // Check sum VCHECK_CONV_UNIMPL(po.check_sum_consistency(dst_dt, is_int8, true), VERBOSE_UNSUPPORTED_POSTOP); + + // Note: verbose support is inside the call. + CHECK(po.validate_binary_with_dst_consistency(&desc.dst_desc)); } - } else { - auto bwd_attr_mask = smask_t::fpmath_mode; - VCHECK_CONV_UNIMPL(attr->has_default_values(bwd_attr_mask), - VERBOSE_UNSUPPORTED_ATTR); + // } else { + // auto bwd_attr_mask = smask_t::fpmath_mode | smask_t::accumulation_mode; + // VCHECK_CONV_UNIMPL(attr->has_default_values(bwd_attr_mask), + // VERBOSE_UNSUPPORTED_ATTR); } return status::success; diff --git a/src/common/convolution_pd.hpp b/src/common/convolution_pd.hpp index 123f15d6e95..f85d5cd8ae6 100644 --- a/src/common/convolution_pd.hpp +++ b/src/common/convolution_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
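Restated, the convolution attribute checks above admit: src scales with mask 0; weights scales with mask 0 or per output channel (1, or 3 with groups); dst scales with mask 0 or 2; src/dst zero points with mask 0 or 1 << 1. A sketch of one accepted int8 configuration through the public attr API (illustrative only):

    #include "oneapi/dnnl/dnnl.hpp"

    inline dnnl::primitive_attr make_int8_conv_attr() {
        dnnl::primitive_attr attr;
        // One scale for the whole src tensor (mask = 0).
        attr.set_scales_mask(DNNL_ARG_SRC, 0);
        // Per-output-channel weights scales (mask = 1, no groups).
        attr.set_scales_mask(DNNL_ARG_WEIGHTS, 1);
        // A single src zero point (mask = 0).
        attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
        return attr;
    }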
@@ -182,17 +182,16 @@ struct convolution_pd_t : public primitive_desc_t { convolution_desc_t desc_; const convolution_fwd_pd_t *hint_fwd_pd_; - convolution_pd_t(const convolution_desc_t *adesc, - const primitive_attr_t *attr, + convolution_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const convolution_fwd_pd_t *hint_fwd_pd) : primitive_desc_t(attr, base_pkind) - , desc_(*adesc) + , desc_(*op_desc_t::to_desc(adesc)) , hint_fwd_pd_(hint_fwd_pd) {} bool set_default_formats_common_template(memory_desc_t &src_md, format_tag_t src_tag, memory_desc_t &wei_md, format_tag_t wei_tag, memory_desc_t &dst_md, format_tag_t dst_tag, - memory_desc_t &bia_md) { + memory_desc_t &bia_md) const { using namespace format_tag; #define IS_OK(f) \ @@ -243,9 +242,13 @@ struct convolution_pd_t : public primitive_desc_t { = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const { bool ok = attr()->scales_.has_default_values(supported_args); for (int arg : supported_args) { - const auto &mask = attr()->scales_.get(arg).mask_; + if (attr()->scales_.has_default_values(arg)) continue; + + const auto &mask = attr()->scales_.get_mask(arg); if (arg == DNNL_ARG_WEIGHTS) ok = ok && (mask == 0 || mask == (with_groups() ? 3 : 1)); + else if (arg == DNNL_ARG_DST) + ok = ok && (mask == 0 || mask == 2); else ok = ok && (mask == 0); } @@ -253,15 +256,17 @@ struct convolution_pd_t : public primitive_desc_t { } }; +// NOLINTBEGIN(google-default-arguments) struct convolution_fwd_pd_t : public convolution_pd_t { - typedef convolution_fwd_pd_t base_class; - typedef convolution_fwd_pd_t hint_class; + using base_class = convolution_fwd_pd_t; + using hint_class = convolution_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS)) return arg_usage_t::input; - if (arg == DNNL_ARG_BIAS && with_bias()) return arg_usage_t::input; + if (arg == DNNL_ARG_BIAS) + return with_bias() ? arg_usage_t::input : arg_usage_t::unused; if (arg == DNNL_ARG_DST) return arg_usage_t::output; @@ -299,7 +304,7 @@ struct convolution_fwd_pd_t : public convolution_pd_t { int n_inputs() const override { return 2 + with_bias() + attr_post_op_dw_inputs() + n_binary_po_inputs() - + n_prelu_po_inputs(); + + n_prelu_po_inputs() + n_depthwise_po_inputs() + n_quantization_po_inputs(); } int n_outputs() const override { return 1; } @@ -310,8 +315,7 @@ struct convolution_fwd_pd_t : public convolution_pd_t { memory_desc_t bias_md_; memory_desc_t dst_md_; - convolution_fwd_pd_t(const convolution_desc_t *adesc, - const primitive_attr_t *attr, + convolution_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const convolution_fwd_pd_t *hint_fwd_pd) : convolution_pd_t(adesc, attr, hint_fwd_pd) , src_md_(desc_.src_desc) @@ -329,14 +333,15 @@ struct convolution_fwd_pd_t : public convolution_pd_t { const auto &po = attr_.post_ops_; int conv = po.find(primitive_kind::convolution); if (conv == -1) return 0; - return po.entry_[conv].depthwise_conv.bias_dt == data_type::undef ? 
1
- : 2;
+ return 2;
 }
 };
+// NOLINTEND(google-default-arguments)
+
+// NOLINTBEGIN(google-default-arguments)
 struct convolution_bwd_data_pd_t : public convolution_pd_t {
- typedef convolution_bwd_data_pd_t base_class;
- typedef convolution_fwd_pd_t hint_class;
+ using base_class = convolution_bwd_data_pd_t;
+ using hint_class = convolution_fwd_pd_t;

 arg_usage_t arg_usage(int arg) const override {
 if (utils::one_of(arg, DNNL_ARG_WEIGHTS, DNNL_ARG_DIFF_DST))
@@ -378,7 +383,9 @@ struct convolution_bwd_data_pd_t : public convolution_pd_t {
 return &glob_zero_md;
 }

- int n_inputs() const override { return 2 + with_bias(); }
+ int n_inputs() const override {
+ return 2 + with_bias() + n_depthwise_po_inputs() + n_quantization_po_inputs();
+ }
 int n_outputs() const override { return 1; }

 virtual bool support_bias() const { return false; }
@@ -389,7 +396,7 @@ struct convolution_bwd_data_pd_t : public convolution_pd_t {
 memory_desc_t bias_md_;
 memory_desc_t diff_dst_md_;

- convolution_bwd_data_pd_t(const convolution_desc_t *adesc,
+ convolution_bwd_data_pd_t(const op_desc_t *adesc,
 const primitive_attr_t *attr, const convolution_fwd_pd_t *hint_fwd_pd)
 : convolution_pd_t(adesc, attr, hint_fwd_pd)
@@ -404,12 +411,14 @@ struct convolution_bwd_data_pd_t : public convolution_pd_t {
 weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_);
 }
 };
+// NOLINTEND(google-default-arguments)
+
+// NOLINTBEGIN(google-default-arguments)
 struct convolution_bwd_weights_pd_t : public convolution_pd_t {
- typedef convolution_bwd_weights_pd_t base_class;
- typedef convolution_fwd_pd_t hint_class;
+ using base_class = convolution_bwd_weights_pd_t;
+ using hint_class = convolution_fwd_pd_t;

- convolution_bwd_weights_pd_t(const convolution_desc_t *adesc,
+ convolution_bwd_weights_pd_t(const op_desc_t *adesc,
 const primitive_attr_t *attr, const convolution_fwd_pd_t *hint_fwd_pd)
 : convolution_pd_t(adesc, attr, hint_fwd_pd)
@@ -424,8 +433,8 @@ struct convolution_bwd_weights_pd_t : public convolution_pd_t {
 if (arg == DNNL_ARG_DIFF_WEIGHTS) return arg_usage_t::output;

- if (arg == DNNL_ARG_DIFF_BIAS && with_bias())
- return arg_usage_t::output;
+ if (arg == DNNL_ARG_DIFF_BIAS)
+ return with_bias() ? arg_usage_t::output : arg_usage_t::unused;

 return primitive_desc_t::arg_usage(arg);
 }
@@ -477,6 +486,7 @@ struct convolution_bwd_weights_pd_t : public convolution_pd_t {
 diff_bias_md_);
 }
 };
+// NOLINTEND(google-default-arguments)

 } // namespace impl
 } // namespace dnnl
diff --git a/src/common/deconvolution.cpp b/src/common/deconvolution.cpp
index 00f3f89d037..54352dbcdd3 100644
--- a/src/common/deconvolution.cpp
+++ b/src/common/deconvolution.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2024 Intel Corporation
+* Copyright 2018-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -131,8 +131,12 @@ status_t deconv_desc_init(deconvolution_desc_t *deconv_desc,
 dim_t ker_range = 1 + (ker - 1) * (dil + 1);
 VCHECK_DECONV(str > 0, VERBOSE_BAD_DIM, "strides", i - 2);
- VCHECK_DECONV(dil >= 0 && pad_l >= 0 && pad_r + str > 0,
- VERBOSE_INCONSISTENT_PRB);
+ // VCHECK_DECONV(dil >= 0 && pad_l >= 0 && pad_r + str > 0,
+ // VERBOSE_INCONSISTENT_PRB);
+ // WA: OV has a feature to set the output shape explicitly, which can make the specified output spatial dims larger than the deconv's actual spatial dims.
+ // Extra padding on the spatial dims is then needed; pad_r < 0 && pad_r + str <= 0 in these test cases.
+ VCHECK_DECONV(dil >= 0 && pad_l >= 0,
+ VERBOSE_INCONSISTENT_PRB);
 VCHECK_DECONV((dst - ker_range + pad_l + pad_r) / str + 1 == src,
 VERBOSE_INCONSISTENT_PRB);
 }
@@ -162,9 +166,7 @@ status_t deconv_attr_check(const deconvolution_desc_t &desc,
 is_int8 = is_int8
 || utils::one_of(dst_dt, data_type::s8, data_type::u8,
 data_type::s32);
- if (is_int8)
- fwd_attr_mask
- |= smask_t::scales_runtime | smask_t::zero_points_runtime;
+ if (is_int8) fwd_attr_mask |= smask_t::scales | smask_t::zero_points;

 VCHECK_DECONV_UNIMPL(attr->has_default_values(fwd_attr_mask, dst_dt),
 VERBOSE_UNSUPPORTED_ATTR);
@@ -172,26 +174,38 @@ status_t deconv_attr_check(const deconvolution_desc_t &desc,
 // Check scales
 if (!attr->scales_.has_default_values()) {
 const auto &sc = attr->scales_;
- const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
- const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_;
- const int mask_dst = sc.get(DNNL_ARG_DST).mask_;
 const bool with_groups
 = desc.src_desc.ndims != desc.weights_desc.ndims;
- VCHECK_DECONV_UNIMPL(utils::everyone_is(0, mask_src, mask_dst)
- && utils::one_of(mask_wei, 0, with_groups ? 3 : 1),
+ VCHECK_DECONV_UNIMPL(
+ IMPLICATION(!sc.has_default_values(DNNL_ARG_SRC),
+ sc.get_mask(DNNL_ARG_SRC) == 0),
+ VERBOSE_UNSUPPORTED_SCALES_CFG);
+ VCHECK_DECONV_UNIMPL(
+ IMPLICATION(!sc.has_default_values(DNNL_ARG_WEIGHTS),
+ utils::one_of(sc.get_mask(DNNL_ARG_WEIGHTS), 0,
+ with_groups ? 3 : 1)),
+ VERBOSE_UNSUPPORTED_SCALES_CFG);
+ VCHECK_DECONV_UNIMPL(
+ IMPLICATION(!sc.has_default_values(DNNL_ARG_DST),
+ sc.get_mask(DNNL_ARG_DST) == 0),
 VERBOSE_UNSUPPORTED_SCALES_CFG);
 }

 // Check zero points
 if (!attr->zero_points_.has_default_values()) {
 const auto &zp = attr->zero_points_;
- int mask_src = 0, mask_dst = 0;
- zp.get(DNNL_ARG_SRC, &mask_src);
- zp.get(DNNL_ARG_DST, &mask_dst);
- VCHECK_DECONV_UNIMPL(zp.has_default_values(DNNL_ARG_WEIGHTS)
- && (mask_src == 0 || mask_src == 1 << 1)
- && (mask_dst == 0 || mask_dst == 1 << 1),
+ VCHECK_DECONV_UNIMPL(
+ IMPLICATION(!zp.has_default_values(DNNL_ARG_SRC),
+ utils::one_of(
+ zp.get_mask(DNNL_ARG_SRC), 0, 1 << 1)),
+ VERBOSE_UNSUPPORTED_ZP_CFG);
+ VCHECK_DECONV_UNIMPL(zp.has_default_values(DNNL_ARG_WEIGHTS),
+ VERBOSE_UNSUPPORTED_ZP_CFG);
+ VCHECK_DECONV_UNIMPL(
+ IMPLICATION(!zp.has_default_values(DNNL_ARG_DST),
+ utils::one_of(
+ zp.get_mask(DNNL_ARG_DST), 0, 1 << 1)),
 VERBOSE_UNSUPPORTED_ZP_CFG);
 }
@@ -207,6 +221,9 @@ status_t deconv_attr_check(const deconvolution_desc_t &desc,
 VCHECK_DECONV_UNIMPL(
 po.check_sum_consistency(dst_dt, is_int8, true),
 VERBOSE_UNSUPPORTED_POSTOP);
+
+ // Note: verbose support is inside the call.
+ CHECK(po.validate_binary_with_dst_consistency(&desc.dst_desc));
 }
 } else {
 auto bwd_attr_mask = smask_t::fpmath_mode;
diff --git a/src/common/deconvolution_pd.hpp b/src/common/deconvolution_pd.hpp
index bcc372384ac..62aa5c3bdf5 100644
--- a/src/common/deconvolution_pd.hpp
+++ b/src/common/deconvolution_pd.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2024 Intel Corporation
+* Copyright 2018-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
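For the relaxed deconvolution padding check above, a worked instance of the case the workaround comment describes, with illustrative numbers (the shape identity (dst - ker_range + pad_l + pad_r) / str + 1 == src still has to hold):

    // src = 5, str = 2, ker = 3, dil = 0
    //   => ker_range = 1 + (3 - 1) * (0 + 1) = 3
    // "natural" dst = str * (src - 1) + ker_range - pad_l - pad_r
    // OV requests an enlarged dst = 14 by picking pad_l = 0, pad_r = -3:
    //   (14 - 3 + 0 - 3) / 2 + 1 = 8 / 2 + 1 = 5 == src   // identity holds
    // Note pad_r + str = -1 <= 0, which the old check rejected.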
@@ -163,18 +163,19 @@ struct deconvolution_pd_t : public primitive_desc_t { deconvolution_desc_t desc_; const deconvolution_fwd_pd_t *hint_fwd_pd_; - deconvolution_pd_t(const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, + deconvolution_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const deconvolution_fwd_pd_t *hint_fwd_pd) : primitive_desc_t(attr, base_pkind) - , desc_(*adesc) + , desc_(*op_desc_t::to_desc(adesc)) , hint_fwd_pd_(hint_fwd_pd) {} bool attr_scales_ok(const std::vector &supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const { bool ok = attr()->scales_.has_default_values(supported_args); for (int arg : supported_args) { - const auto &mask = attr()->scales_.get(arg).mask_; + if (attr()->scales_.has_default_values(arg)) continue; + + const auto &mask = attr()->scales_.get_mask(arg); if (arg == DNNL_ARG_WEIGHTS) ok = ok && (mask == 0 || mask == (with_groups() ? 3 : 1)); else @@ -200,15 +201,17 @@ struct deconvolution_pd_t : public primitive_desc_t { } }; +// NOLINTBEGIN(google-default-arguments) struct deconvolution_fwd_pd_t : public deconvolution_pd_t { - typedef deconvolution_fwd_pd_t base_class; - typedef deconvolution_fwd_pd_t hint_class; + using base_class = deconvolution_fwd_pd_t; + using hint_class = deconvolution_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS)) return arg_usage_t::input; - if (arg == DNNL_ARG_BIAS && with_bias()) return arg_usage_t::input; + if (arg == DNNL_ARG_BIAS) + return with_bias() ? arg_usage_t::input : arg_usage_t::unused; if (arg == DNNL_ARG_DST) return arg_usage_t::output; @@ -245,7 +248,7 @@ struct deconvolution_fwd_pd_t : public deconvolution_pd_t { } int n_inputs() const override { - return 2 + with_bias() + n_prelu_po_inputs() + n_binary_po_inputs(); + return 2 + with_bias() + n_prelu_po_inputs() + n_binary_po_inputs() + n_depthwise_po_inputs() + n_quantization_po_inputs(); } int n_outputs() const override { return 1; } @@ -255,8 +258,7 @@ struct deconvolution_fwd_pd_t : public deconvolution_pd_t { memory_desc_t bias_md_; memory_desc_t dst_md_; - deconvolution_fwd_pd_t(const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, + deconvolution_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const deconvolution_fwd_pd_t *hint_fwd_pd) : deconvolution_pd_t(adesc, attr, hint_fwd_pd) , src_md_(desc_.src_desc) @@ -264,10 +266,12 @@ struct deconvolution_fwd_pd_t : public deconvolution_pd_t { , bias_md_(desc_.bias_desc) , dst_md_(desc_.dst_desc) {} }; +// NOLINTEND(google-default-arguments) +// NOLINTBEGIN(google-default-arguments) struct deconvolution_bwd_data_pd_t : public deconvolution_pd_t { - typedef deconvolution_bwd_data_pd_t base_class; - typedef deconvolution_fwd_pd_t hint_class; + using base_class = deconvolution_bwd_data_pd_t; + using hint_class = deconvolution_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (utils::one_of(arg, DNNL_ARG_WEIGHTS, DNNL_ARG_DIFF_DST)) @@ -316,7 +320,7 @@ struct deconvolution_bwd_data_pd_t : public deconvolution_pd_t { memory_desc_t weights_md_; memory_desc_t diff_dst_md_; - deconvolution_bwd_data_pd_t(const deconvolution_desc_t *adesc, + deconvolution_bwd_data_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const deconvolution_fwd_pd_t *hint_fwd_pd) : deconvolution_pd_t(adesc, attr, hint_fwd_pd) @@ -324,10 +328,12 @@ struct deconvolution_bwd_data_pd_t : public deconvolution_pd_t { , weights_md_(desc_.weights_desc) , diff_dst_md_(desc_.diff_dst_desc) {} }; 
+// NOLINTEND(google-default-arguments) +// NOLINTBEGIN(google-default-arguments) struct deconvolution_bwd_weights_pd_t : public deconvolution_pd_t { - typedef deconvolution_bwd_weights_pd_t base_class; - typedef deconvolution_fwd_pd_t hint_class; + using base_class = deconvolution_bwd_weights_pd_t; + using hint_class = deconvolution_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_DIFF_DST)) @@ -335,8 +341,8 @@ struct deconvolution_bwd_weights_pd_t : public deconvolution_pd_t { if (arg == DNNL_ARG_DIFF_WEIGHTS) return arg_usage_t::output; - if (arg == DNNL_ARG_DIFF_BIAS && with_bias()) - return arg_usage_t::output; + if (arg == DNNL_ARG_DIFF_BIAS) + return with_bias() ? arg_usage_t::output : arg_usage_t::unused; return primitive_desc_t::arg_usage(arg); } @@ -381,7 +387,7 @@ struct deconvolution_bwd_weights_pd_t : public deconvolution_pd_t { memory_desc_t diff_bias_md_; memory_desc_t diff_dst_md_; - deconvolution_bwd_weights_pd_t(const deconvolution_desc_t *adesc, + deconvolution_bwd_weights_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const deconvolution_fwd_pd_t *hint_fwd_pd) : deconvolution_pd_t(adesc, attr, hint_fwd_pd) @@ -390,6 +396,7 @@ struct deconvolution_bwd_weights_pd_t : public deconvolution_pd_t { , diff_bias_md_(desc_.diff_bias_desc) , diff_dst_md_(desc_.diff_dst_desc) {} }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/dnnl_debug.cpp b/src/common/dnnl_debug.cpp index 8f2dc99a278..a7aacb41115 100644 --- a/src/common/dnnl_debug.cpp +++ b/src/common/dnnl_debug.cpp @@ -33,6 +33,7 @@ const char *dnnl_runtime2str(unsigned runtime) { case DNNL_RUNTIME_SEQ: return "sequential"; case DNNL_RUNTIME_OMP: return "OpenMP"; case DNNL_RUNTIME_TBB: return "TBB"; + case DNNL_RUNTIME_TBB_AUTO: return "TBB_AUTO"; case DNNL_RUNTIME_OCL: return "OpenCL"; case DNNL_RUNTIME_THREADPOOL: return "threadpool"; #ifdef DNNL_WITH_SYCL @@ -49,8 +50,11 @@ const char *dnnl_fmt_kind2str(dnnl_format_kind_t v) { #ifdef DNNL_EXPERIMENTAL_SPARSE if (v == dnnl_format_kind_sparse) return "sparse"; #endif - if (v == format_kind::wino || v == format_kind::rnn_packed) return "opaque"; + if (v == format_kind::wino || v == format_kind::rnn_packed + || v == format_kind::cublaslt_blocked) + return "opaque"; if (v == dnnl_format_kind_max) return "max"; + if (v == dnnl_format_sparse) return "format_sparse"; assert(!"unknown fmt_kind"); return "unknown fmt_kind"; } diff --git a/src/common/dnnl_debug_autogenerated.cpp b/src/common/dnnl_debug_autogenerated.cpp index 05c3a55f9d7..1cbad069cd3 100644 --- a/src/common/dnnl_debug_autogenerated.cpp +++ b/src/common/dnnl_debug_autogenerated.cpp @@ -1,5 +1,6 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation +* Copyright 2024-2025 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
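The *2str debug helpers being updated in dnnl_debug.cpp above and in the autogenerated file below are part of the public dnnl_debug.h interface. A minimal sketch of typical use, with values that exist upstream:

    #include <cstdio>
    #include "oneapi/dnnl/dnnl_debug.h"

    int main() {
        // Each *2str helper returns the string literal kept in sync here.
        std::printf("%s\n", dnnl_dt2str(dnnl_s4));        // prints "s4"
        std::printf("%s\n", dnnl_fmt_tag2str(dnnl_abcd)); // prints "abcd"
        return 0;
    }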
@@ -57,6 +58,10 @@ const char *dnnl_dt2str(dnnl_data_type_t v) {
 if (v == dnnl_s4) return "s4";
 if (v == dnnl_u4) return "u4";
 if (v == dnnl_e8m0) return "e8m0";
+ if (v == dnnl_f4_e2m1) return "f4_e2m1";
+ if (v == dnnl_f4_e3m0) return "f4_e3m0";
+ if (v == dnnl_bin) return "bin";
+ if (v == dnnl_nf4) return "nf4";
 if (v == dnnl_data_type_max) return "data_type_max";
 assert(!"unknown dt");
 return "unknown dt";
@@ -96,6 +103,7 @@ const char *dnnl_sparse_encoding2str(dnnl_sparse_encoding_t v) {
 if (v == dnnl_sparse_encoding_undef) return "undef";
 if (v == dnnl_csr) return "csr";
 if (v == dnnl_packed) return "packed";
+ if (v == dnnl_coo) return "coo";
 assert(!"unknown sparse_encoding");
 return "unknown sparse_encoding";
 }
@@ -203,6 +211,8 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
 if (v == dnnl_ABcd8a16b2a) return "ABcd8a16b2a";
 if (v == dnnl_ABcd2b8a4b) return "ABcd2b8a4b";
 if (v == dnnl_ABcd8a8b) return "ABcd8a8b";
+ if (v == dnnl_ABcd8a32b) return "ABcd8a32b";
+ if (v == dnnl_ABcd16a32b) return "ABcd16a32b";
 if (v == dnnl_ABcd8a4b) return "ABcd8a4b";
 if (v == dnnl_aBcd8b) return "aBcd8b";
 if (v == dnnl_aBCd4c8b2c) return "aBCd4c8b2c";
@@ -301,6 +311,8 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
 if (v == dnnl_aCBdef16c16b) return "aCBdef16c16b";
 if (v == dnnl_aBdefc4b) return "aBdefc4b";
 if (v == dnnl_aBdefc8b) return "aBdefc8b";
+ if (v == dnnl_Abcdef4a) return "Abcdef4a";
+ if (v == dnnl_Abcdef8a) return "Abcdef8a";
 if (v == dnnl_Abcdef16a) return "Abcdef16a";
 if (v == dnnl_Abcdef32a) return "Abcdef32a";
 if (v == dnnl_aBedc16b) return "aBedc16b";
@@ -940,6 +952,18 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
 if (v == dnnl_bcad) return "bcad";
 if (v == dnnl_cabd) return "cabd";
 if (v == dnnl_dabc) return "dabc";
+ if (v == dnnl_Ab32a) return "Ab32a";
+ if (v == dnnl_aCBd8b8c) return "aCBd8b8c";
+ if (v == dnnl_aCBde8b8c) return "aCBde8b8c";
+ if (v == dnnl_BAc8a8b) return "BAc8a8b";
+ if (v == dnnl_BAcd8a8b) return "BAcd8a8b";
+ if (v == dnnl_BAcde8a8b) return "BAcde8a8b";
+ if (v == dnnl_aCBdef8b8c) return "aCBdef8b8c";
+ if (v == dnnl_abdEC16e4c) return "abdEC16e4c";
+ if (v == dnnl_abDC16d4c) return "abDC16d4c";
+ if (v == dnnl_BA24b8a) return "BA24b8a";
+ if (v == dnnl_aCB24c8b) return "aCB24c8b";
+ if (v == dnnl_abDC24d8c) return "abDC24d8c";
 if (v == dnnl_format_tag_last) return "format_tag_last";
 if (v == dnnl_x) return "x";
 if (v == dnnl_nc) return "nc";
@@ -993,9 +1017,11 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
 if (v == dnnl_ldgo) return "ldgo";
 if (v == dnnl_ldOi16o) return "ldOi16o";
 if (v == dnnl_ldOi32o) return "ldOi32o";
+ if (v == dnnl_ldOI16o4i) return "ldOI16o4i";
 if (v == dnnl_ldOI32o4i) return "ldOI32o4i";
 if (v == dnnl_ldIo32i) return "ldIo32i";
 if (v == dnnl_ldgOi16o) return "ldgOi16o";
+ if (v == dnnl_ldgOI16o4i) return "ldgOI16o4i";
 if (v == dnnl_ldgOi32o) return "ldgOi32o";
 if (v == dnnl_ldgOI32o2i) return "ldgOI32o2i";
 if (v == dnnl_ldgOI32o4i) return "ldgOI32o4i";
@@ -1045,6 +1071,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
 if (v == dnnl_OI8i24o) return "OI8i24o";
 if (v == dnnl_OI8i16o) return "OI8i16o";
 if (v == dnnl_OI8i8o) return "OI8i8o";
+ if (v == dnnl_IOw8o8i) return "IOw8o8i";
 if (v == dnnl_IOw16o16i) return "IOw16o16i";
 if (v == dnnl_IOw16i16o) return "IOw16i16o";
 if (v == dnnl_OIw16i16o) return "OIw16i16o";
@@ -1113,6 +1140,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
 if (v == dnnl_OwI8i16o) return "OwI8i16o";
 if (v == dnnl_OwI8o4i) return "OwI8o4i";
 if (v == dnnl_IOhw16i16o) return "IOhw16i16o";
+ if (v == dnnl_IOhw8o8i) return "IOhw8o8i";
 if (v == dnnl_IOhw16o16i) return "IOhw16o16i";
 if (v == dnnl_Ohwi16o) return "Ohwi16o";
 if (v == dnnl_OhwI16o2i) return "OhwI16o2i";
@@ -1173,6 +1201,8 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
 if (v == dnnl_OIhw8o16i2o) return "OIhw8o16i2o";
 if (v == dnnl_OIhw2i8o4i) return "OIhw2i8o4i";
 if (v == dnnl_IOhw8o16i2o) return "IOhw8o16i2o";
+ if (v == dnnl_OIhw8o32i) return "OIhw8o32i";
+ if (v == dnnl_OIhw16o32i) return "OIhw16o32i";
 if (v == dnnl_OIhw8o8i) return "OIhw8o8i";
 if (v == dnnl_OIhw8o4i) return "OIhw8o4i";
 if (v == dnnl_Owhi16o) return "Owhi16o";
@@ -1243,6 +1273,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
 if (v == dnnl_OIdhw8o4i) return "OIdhw8o4i";
 if (v == dnnl_IOdhw16i16o) return "IOdhw16i16o";
 if (v == dnnl_OIdhw4o8i8o4i) return "OIdhw4o8i8o4i";
+ if (v == dnnl_IOdhw8o8i) return "IOdhw8o8i";
 if (v == dnnl_IOdhw16o16i) return "IOdhw16o16i";
 if (v == dnnl_OIdhw16o16i2o) return "OIdhw16o16i2o";
 if (v == dnnl_OIdhw8i32o) return "OIdhw8i32o";
@@ -1254,6 +1285,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
 if (v == dnnl_Goiw16g) return "Goiw16g";
 if (v == dnnl_Goiw8g) return "Goiw8g";
 if (v == dnnl_Goiw4g) return "Goiw4g";
+ if (v == dnnl_gIOw8o8i) return "gIOw8o8i";
 if (v == dnnl_gIOw16o16i) return "gIOw16o16i";
 if (v == dnnl_gIOw16i16o) return "gIOw16i16o";
 if (v == dnnl_gOIw16i16o) return "gOIw16i16o";
@@ -1297,6 +1329,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
 if (v == dnnl_goIw4i) return "goIw4i";
 if (v == dnnl_goIw32i) return "goIw32i";
 if (v == dnnl_gIOhw16i16o) return "gIOhw16i16o";
+ if (v == dnnl_gIOhw8o8i) return "gIOhw8o8i";
 if (v == dnnl_gIOhw16o16i) return "gIOhw16o16i";
 if (v == dnnl_gOhwi16o) return "gOhwi16o";
 if (v == dnnl_gOhwI16o2i) return "gOhwI16o2i";
@@ -1360,6 +1393,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
 if (v == dnnl_gOIhw4i8o2i) return "gOIhw4i8o2i";
 if (v == dnnl_gOIhw4o8i2o) return "gOIhw4o8i2o";
 if (v == dnnl_gIOdhw16i16o) return "gIOdhw16i16o";
+ if (v == dnnl_gIOdhw8o8i) return "gIOdhw8o8i";
 if (v == dnnl_gIOdhw16o16i) return "gIOdhw16o16i";
 if (v == dnnl_gOdhwi16o) return "gOdhwi16o";
 if (v == dnnl_gOdhwI16o2i) return "gOdhwI16o2i";
@@ -1395,6 +1429,8 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
 if (v == dnnl_gIOdhw8o16i2o) return "gIOdhw8o16i2o";
 if (v == dnnl_gOIdhw8o8i) return "gOIdhw8o8i";
 if (v == dnnl_gOIdhw8o4i) return "gOIdhw8o4i";
+ if (v == dnnl_Goidhw4g) return "Goidhw4g";
+ if (v == dnnl_Goidhw8g) return "Goidhw8g";
 if (v == dnnl_Goidhw16g) return "Goidhw16g";
 if (v == dnnl_Goidhw32g) return "Goidhw32g";
 if (v == dnnl_gOIdhw2i4o2i) return "gOIdhw2i4o2i";
@@ -1751,6 +1787,8 @@ const char *dnnl_prim_kind2str(dnnl_primitive_kind_t v) {
 if (v == dnnl_softmax) return "softmax";
 if (v == dnnl_layer_normalization) return "layer_normalization";
 if (v == dnnl_group_normalization) return "group_normalization";
+ if (v == dnnl_depthwise) return "depthwise";
+ if (v == dnnl_quantization) return "quantization";
 if (v == dnnl_primitive_kind_max) return "primitive_kind_max";
 if (v == dnnl::impl::primitive_kind::sdpa) return "sdpa";
 assert(!"unknown prim_kind");
@@ -1785,6 +1823,9 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) {
 if (v == dnnl_eltwise_round) return "eltwise_round";
 if (v == dnnl_eltwise_mish) return "eltwise_mish";
 if (v == dnnl_eltwise_hardswish) return "eltwise_hardswish";
+ if (v == dnnl_eltwise_hsigmoid) return "eltwise_hsigmoid";
+ if (v == dnnl_eltwise_round_half_to_even) return "eltwise_round_half_to_even";
+ if (v == dnnl_eltwise_round_half_away_from_zero) return "eltwise_round_half_away_from_zero";
 if (v == dnnl_eltwise_relu_use_dst_for_bwd) return "eltwise_relu_use_dst_for_bwd";
 if (v == dnnl_eltwise_tanh_use_dst_for_bwd) return "eltwise_tanh_use_dst_for_bwd";
 if (v == dnnl_eltwise_elu_use_dst_for_bwd) return "eltwise_elu_use_dst_for_bwd";
@@ -1815,6 +1856,8 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) {
 if (v == dnnl_binary_lt) return "binary_lt";
 if (v == dnnl_binary_eq) return "binary_eq";
 if (v == dnnl_binary_ne) return "binary_ne";
+ if (v == dnnl_binary_select) return "binary_select";
+ if (v == dnnl_binary_prelu) return "binary_prelu";
 if (v == dnnl_resampling_nearest) return "resampling_nearest";
 if (v == dnnl_resampling_linear) return "resampling_linear";
 if (v == dnnl_reduction_max) return "reduction_max";
@@ -1828,10 +1871,15 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) {
 if (v == dnnl_reduction_norm_lp_power_p_sum) return "reduction_norm_lp_power_p_sum";
 if (v == dnnl_softmax_accurate) return "softmax_accurate";
 if (v == dnnl_softmax_log) return "softmax_log";
+ if (v == dnnl_depthwise_scale_shift) return "depthwise_scale_shift";
+ if (v == dnnl_depthwise_prelu) return "depthwise_prelu";
+ if (v == dnnl_quantization_quantize_dequantize) return "quantization_quantize_dequantize";
+ if (v == dnnl_quantization_quantize) return "quantization_quantize";
+ if (v == dnnl_binarization_depthwise) return "binarization_depthwise";
 assert(!"unknown alg_kind");
 return "unknown alg_kind";
 }

 const char *dnnl_rnn_flags2str(dnnl_rnn_flags_t v) {
 if (v == dnnl_rnn_flags_undef) return "undef";
 if (v == dnnl_rnn_flags_diff_weights_overwrite) return "rnn_flags_diff_weights_overwrite";
diff --git a/src/common/dnnl_sel_build.hpp b/src/common/dnnl_sel_build.hpp
new file mode 100644
index 00000000000..fee17f6685b
--- /dev/null
+++ b/src/common/dnnl_sel_build.hpp
@@ -0,0 +1,103 @@
+/*******************************************************************************
+* Copyright 2020 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+
+#define DNNL_MACRO_EXPAND(x) x
+
+#define DNNL_MACRO_CAT_(x, y) x ## y
+#define DNNL_MACRO_CAT(x, y) DNNL_MACRO_CAT_(x, y)
+#define DNNL_MACRO_CAT3_(x, y, z) x ## y ## z
+#define DNNL_MACRO_CAT3(x, y, z) DNNL_MACRO_CAT3_(x, y, z)
+
+#define DNNL_MACRO_TOSTRING(...) DNNL_MACRO_TOSTRING_(__VA_ARGS__)
+#define DNNL_MACRO_TOSTRING_(...) #__VA_ARGS__
+
+#define DNNL_MACRO_NARG(...)
DNNL_MACRO_EXPAND( DNNL_MACRO_NARG_(__VA_ARGS__, DNNL_MACRO_RSEQ_N()) ) +#define DNNL_MACRO_NARG_(...) DNNL_MACRO_EXPAND( DNNL_MACRO_ARG_N(__VA_ARGS__) ) +#define DNNL_MACRO_ARG_N(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N +#define DNNL_MACRO_RSEQ_N() 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 + +#define DNNL_MACRO_EVAL_(NAME, N) NAME ## _ ## N +#define DNNL_MACRO_EVAL(NAME, N) DNNL_MACRO_EVAL_(NAME, N) + +#define DNNL_MACRO_OVERLOAD(NAME, ...) \ + DNNL_MACRO_EXPAND( DNNL_MACRO_EVAL(NAME, DNNL_MACRO_EXPAND( DNNL_MACRO_NARG(__VA_ARGS__) ))(__VA_ARGS__) ) + +#if defined(SELECTIVE_BUILD_ANALYZER) + +# include + +namespace dnnl { + +OV_CC_DOMAINS(DNNL) + +} // namespace dnnl + +# define DNNL_CSCOPE(region) OV_SCOPE(DNNL, region) + +# define DNNL_PRIMITIVE_NAME_INIT(pd_t) name = typeid(pd_t).name(); +# define DNNL_PRIMITIVE_CREATE(pd_t) OV_ITT_SCOPED_TASK(dnnl::FACTORY_DNNL, std::string("CREATE$CPUEngine$") + typeid(pd_t).name()); +# define DNNL_PRIMITIVE_IMPL(...) DNNL_MACRO_OVERLOAD(DNNL_PRIMITIVE_IMPL, __VA_ARGS__), +# define DNNL_PRIMITIVE_IMPL_2(expr, type) dnnl::impl::move(expr(type), OV_CC_TOSTRING(type)) +# define DNNL_PRIMITIVE_IMPL_3(expr, type, t1) dnnl::impl::move(expr(type), OV_CC_TOSTRING(type ## _ ## t1)) +# define DNNL_PRIMITIVE_IMPL_4(expr, type, t1, t2) dnnl::impl::move(expr(type), OV_CC_TOSTRING(type ## _ ## t1 ## _ ## t2)) +# define DNNL_PRIMITIVE_IMPL_5(expr, type, t1, t2, t3) dnnl::impl::move(expr(type), OV_CC_TOSTRING(type ## _ ## t1 ## _ ## t2 ## _ ## t3)) +# define DNNL_PRIMITIVE_IMPL_6(expr, type, t1, t2, t3, t4) dnnl::impl::move(expr(type), OV_CC_TOSTRING(type ## _ ## t1 ## _ ## t2 ## _ ## t3 ## _ ## t4)) +# define DNNL_PRIMITIVE_IMPL_7(expr, type, t1, t2, t3, t4, t5) dnnl::impl::move(expr(type), OV_CC_TOSTRING(type ## _ ## t1 ## _ ## t2 ## _ ## t3 ## _ ## t4 ## _ ## t5)) +# define DNNL_PRIMITIVE_IMPL_8(expr, type, t1, t2, t3, t4, t5, t6) dnnl::impl::move(expr(type), OV_CC_TOSTRING(type ## _ ## t1 ## _ ## t2 ## _ ## t3 ## _ ## t4 ## _ ## t5 ## _ ## t6)) +# define DNNL_PRIMITIVE_IMPL_9(expr, type, t1, t2, t3, t4, t5, t6, t7) dnnl::impl::move(expr(type), OV_CC_TOSTRING(type ## _ ## t1 ## _ ## t2 ## _ ## t3 ## _ ## t4 ## _ ## t5 ## _ ## t6 ## _ ## t7)) + +#elif defined(SELECTIVE_BUILD) + +# include + +# define DNNL_CSCOPE(region) OV_SCOPE(DNNL, region) + +# define DNNL_OBJ_BUILDER_0(...) +# define DNNL_OBJ_BUILDER_1(...) __VA_ARGS__, +# define DNNL_OBJ_BUILDER(name, ...) OV_CC_EXPAND(OV_CC_CAT(DNNL_OBJ_BUILDER_, OV_CC_EXPAND(OV_CC_SCOPE_IS_ENABLED(OV_CC_CAT(DNNL_, name))))(__VA_ARGS__)) + +# define DNNL_PRIMITIVE_NAME_INIT(pd_t) +# define DNNL_PRIMITIVE_CREATE(pd_t) +# define DNNL_PRIMITIVE_IMPL(...) 
DNNL_MACRO_OVERLOAD(DNNL_PRIMITIVE_IMPL, __VA_ARGS__) +# define DNNL_PRIMITIVE_IMPL_2(expr, type) DNNL_OBJ_BUILDER(type, expr(type)) +# define DNNL_PRIMITIVE_IMPL_3(expr, type, t1) DNNL_OBJ_BUILDER(type ## _ ## t1, expr(type)) +# define DNNL_PRIMITIVE_IMPL_4(expr, type, t1, t2) DNNL_OBJ_BUILDER(type ## _ ## t1 ## _ ## t2, expr(type)) +# define DNNL_PRIMITIVE_IMPL_5(expr, type, t1, t2, t3) DNNL_OBJ_BUILDER(type ## _ ## t1 ## _ ## t2 ## _ ## t3, expr(type)) +# define DNNL_PRIMITIVE_IMPL_6(expr, type, t1, t2, t3, t4) DNNL_OBJ_BUILDER(type ## _ ## t1 ## _ ## t2 ## _ ## t3 ## _ ## t4, expr(type)) +# define DNNL_PRIMITIVE_IMPL_7(expr, type, t1, t2, t3, t4, t5) DNNL_OBJ_BUILDER(type ## _ ## t1 ## _ ## t2 ## _ ## t3 ## _ ## t4 ## _ ## t5, expr(type)) +# define DNNL_PRIMITIVE_IMPL_8(expr, type, t1, t2, t3, t4, t5, t6) DNNL_OBJ_BUILDER(type ## _ ## t1 ## _ ## t2 ## _ ## t3 ## _ ## t4 ## _ ## t5 ## _ ## t6, expr(type)) +# define DNNL_PRIMITIVE_IMPL_9(expr, type, t1, t2, t3, t4, t5, t6, t7) DNNL_OBJ_BUILDER(type ## _ ## t1 ## _ ## t2 ## _ ## t3 ## _ ## t4 ## _ ## t5 ## _ ## t6 ## _ ## t7, expr(type)) + +#else + +# define DNNL_CSCOPE(region) + +# define DNNL_PRIMITIVE_NAME_INIT(pd_t) +# define DNNL_PRIMITIVE_CREATE(pd_t) +# define DNNL_PRIMITIVE_IMPL(...) DNNL_MACRO_OVERLOAD(DNNL_PRIMITIVE_IMPL, __VA_ARGS__), +# define DNNL_PRIMITIVE_IMPL_2(expr, type) expr(type) +# define DNNL_PRIMITIVE_IMPL_3(expr, type, t1) expr(type) +# define DNNL_PRIMITIVE_IMPL_4(expr, type, t1, t2) expr(type) +# define DNNL_PRIMITIVE_IMPL_5(expr, type, t1, t2, t3) expr(type) +# define DNNL_PRIMITIVE_IMPL_6(expr, type, t1, t2, t3, t4) expr(type) +# define DNNL_PRIMITIVE_IMPL_7(expr, type, t1, t2, t3, t4, t5) expr(type) +# define DNNL_PRIMITIVE_IMPL_8(expr, type, t1, t2, t3, t4, t5, t6) expr(type) +# define DNNL_PRIMITIVE_IMPL_9(expr, type, t1, t2, t3, t4, t5, t6, t7) expr(type) + +#endif diff --git a/src/common/dnnl_thread.cpp b/src/common/dnnl_thread.cpp new file mode 100644 index 00000000000..e28f92b3557 --- /dev/null +++ b/src/common/dnnl_thread.cpp @@ -0,0 +1,102 @@ +#include + +#include "dnnl_thread.hpp" + +#if defined(DNNL_ENABLE_ITT_TASKS) +#include "common/ittnotify.hpp" +#endif + +#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL +#include "counting_barrier.hpp" +#endif + +namespace dnnl { +namespace impl { + +void parallel(int nthr, const std::function &f) { + nthr = adjust_num_threads(nthr, INT64_MAX); +#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ + for (int i = 0; i < nthr; ++i) { + f(i, nthr); + } +#else +#if defined(DNNL_ENABLE_ITT_TASKS) + auto task_primitive_kind = itt::primitive_task_get_current_kind(); + bool itt_enable = itt::get_itt(itt::__itt_task_level_high); +#endif + if (nthr == 1) { + f(0, 1); + return; + } +#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP +#pragma omp parallel num_threads(nthr) + { + int nthr_ = omp_get_num_threads(); + int ithr_ = omp_get_thread_num(); + assert(nthr_ == nthr); +#if defined(DNNL_ENABLE_ITT_TASKS) + if (ithr_ && itt_enable) itt::primitive_task_start(task_primitive_kind); +#endif + f(ithr_, nthr_); +#if defined(DNNL_ENABLE_ITT_TASKS) + if (ithr_ && itt_enable) itt::primitive_task_end(); +#endif + } +#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB + tbb::parallel_for( + 0, nthr, + [&](int ithr) { +#if defined(DNNL_ENABLE_ITT_TASKS) + bool mark_task = itt::primitive_task_get_current_kind() + == primitive_kind::undefined; + if (mark_task && itt_enable) + itt::primitive_task_start(task_primitive_kind); +#endif + f(ithr, nthr); +#if 
defined(DNNL_ENABLE_ITT_TASKS)
+ if (mark_task && itt_enable) itt::primitive_task_end();
+#endif
+ },
+ tbb::static_partitioner());
+#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB_AUTO
+ tbb::parallel_for(
+ 0, nthr, [&](int ithr) { f(ithr, nthr); });
+#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
+ using namespace dnnl::impl::threadpool_utils;
+ dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool();
+ if (!tp || dnnl_in_parallel()) {
+ threadpool_utils::deactivate_threadpool();
+ for (int ithr = 0; ithr < nthr; ithr++) {
+ f(ithr, nthr);
+ }
+ threadpool_utils::activate_threadpool(tp);
+ } else {
+ bool async = tp->get_flags()
+ & dnnl::threadpool_interop::threadpool_iface::ASYNCHRONOUS;
+ counting_barrier_t b;
+ if (async) b.init(nthr);
+ tp->parallel_for(nthr, [&, tp](int ithr, int nthr) {
+ bool is_master = threadpool_utils::get_active_threadpool() == tp;
+ if (!is_master) {
+ threadpool_utils::activate_threadpool(tp);
+#if defined(DNNL_ENABLE_ITT_TASKS)
+ if (itt_enable) itt::primitive_task_start(task_primitive_kind);
+#endif
+ }
+ f(ithr, nthr);
+ if (!is_master) {
+#if defined(DNNL_ENABLE_ITT_TASKS)
+ if (itt_enable) itt::primitive_task_end();
+#endif
+ threadpool_utils::deactivate_threadpool();
+ }
+ if (async) b.notify();
+ });
+ if (async) b.wait();
+ }
+#endif
+#endif
+}
+
+} // namespace impl
+} // namespace dnnl
diff --git a/src/common/dnnl_thread.hpp b/src/common/dnnl_thread.hpp
index 6122819a308..2f13770dec5 100644
--- a/src/common/dnnl_thread.hpp
+++ b/src/common/dnnl_thread.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2017-2023 Intel Corporation
+* Copyright 2017-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -75,7 +75,7 @@ inline void dnnl_thr_barrier() {
 #pragma omp barrier
 }

-#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB
+#elif (DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB || DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB_AUTO)
 #include "tbb/parallel_for.h"
 #include "tbb/task_arena.h"
 #define DNNL_THR_SYNC 0
@@ -184,25 +184,25 @@ inline int dnnl_get_current_num_threads() {
 #define OMP_GET_NUM_THREADS() 1
 #endif

-// MSVC still supports omp 2.0 only
-#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+// Disable OMP SIMD in the following scenarios:
+// * For MSVC, which only supports OpenMP 2.0. Note that VS2019 offers the
+//   SIMD functionality with the -openmp:experimental compilation switch,
+//   which enables additional OpenMP features that are not available when
+//   using the plain -openmp switch.
+// * In debug mode on Windows, to avoid incorrect code generation by the
+//   Intel(R) oneAPI DPC++/C++ Compiler.
+#if defined(_MSC_VER) && (_MSC_VER < 1900) \
+ && ((!defined(__clang__) && !defined(__INTEL_COMPILER)) \
+ || defined(_DEBUG))
 #define collapse(x)
 #define PRAGMA_OMP_SIMD(...)
 #else
 #define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__))
-#endif // defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
-
-// process simdlen; it is supported for Clang >= 3.9; ICC >= 17.0; GCC >= 6.1
-// No support on Windows.
-#if (defined(__clang_major__) \ - && (__clang_major__ < 3 \ - || (__clang_major__ == 3 && __clang_minor__ < 9))) \ - || (defined(__INTEL_COMPILER) && __INTEL_COMPILER < 1700) \ - || (!defined(__INTEL_COMPILER) && !defined(__clang__) \ - && (defined(_MSC_VER) || __GNUC__ < 6 \ - || (__GNUC__ == 6 && __GNUC_MINOR__ < 1))) -#define simdlen(x) -#endif // long simdlen if +#endif // defined(_MSC_VER) && ((!defined(__clang__) && !defined(__INTEL_COMPILER)) || defined(_DEBUG)) + +#if defined(DNNL_ENABLE_ITT_TASKS) +#include "common/ittnotify.hpp" +#endif namespace dnnl { namespace impl { @@ -282,87 +282,7 @@ inline int adjust_num_threads(int nthr, dim_t work_amount) { #endif } -static inline void parallel(int nthr, const std::function &f) { - nthr = adjust_num_threads(nthr, INT64_MAX); -#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ - for (int i = 0; i < nthr; ++i) { - f(i, nthr); - } -#else -#if defined(DNNL_ENABLE_ITT_TASKS) - auto task_primitive_kind = itt::primitive_task_get_current_kind(); - bool itt_enable = itt::get_itt(itt::__itt_task_level_high); -#endif - if (nthr == 1) { - f(0, 1); - return; - } -#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP -#pragma omp parallel num_threads(nthr) - { - int nthr_ = omp_get_num_threads(); - int ithr_ = omp_get_thread_num(); - assert(nthr_ == nthr); -#if defined(DNNL_ENABLE_ITT_TASKS) - if (ithr_ && itt_enable) itt::primitive_task_start(task_primitive_kind); -#endif - f(ithr_, nthr_); -#if defined(DNNL_ENABLE_ITT_TASKS) - if (ithr_ && itt_enable) itt::primitive_task_end(); -#endif - } -#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB - tbb::parallel_for( - 0, nthr, - [&](int ithr) { -#if defined(DNNL_ENABLE_ITT_TASKS) - bool mark_task = itt::primitive_task_get_current_kind() - == primitive_kind::undefined; - if (mark_task && itt_enable) - itt::primitive_task_start(task_primitive_kind); -#endif - f(ithr, nthr); -#if defined(DNNL_ENABLE_ITT_TASKS) - if (mark_task && itt_enable) itt::primitive_task_end(); -#endif - }, - tbb::static_partitioner()); -#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL - using namespace dnnl::impl::threadpool_utils; - dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool(); - if (!tp || dnnl_in_parallel()) { - threadpool_utils::deactivate_threadpool(); - for (int ithr = 0; ithr < nthr; ithr++) { - f(ithr, nthr); - } - threadpool_utils::activate_threadpool(tp); - } else { - bool async = tp->get_flags() - & dnnl::threadpool_interop::threadpool_iface::ASYNCHRONOUS; - counting_barrier_t b; - if (async) b.init(nthr); - tp->parallel_for(nthr, [&, tp](int ithr, int nthr) { - bool is_master = threadpool_utils::get_active_threadpool() == tp; - if (!is_master) { - threadpool_utils::activate_threadpool(tp); -#if defined(DNNL_ENABLE_ITT_TASKS) - if (itt_enable) itt::primitive_task_start(task_primitive_kind); -#endif - } - f(ithr, nthr); - if (!is_master) { -#if defined(DNNL_ENABLE_ITT_TASKS) - if (itt_enable) itt::primitive_task_end(); -#endif - threadpool_utils::deactivate_threadpool(); - } - if (async) b.notify(); - }); - if (async) b.wait(); - } -#endif -#endif -} +void DNNL_API parallel(int nthr, const std::function &f); // XXX: IMPORTANT!!! // Keep the functions below static. 
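With parallel() now only declared here and defined out of line in dnnl_thread.cpp, a sketch of the internal usage pattern it supports, using the same balance211-based splitting the for_nd helpers rely on (internal API; assumes dnnl_thread.hpp and utils.hpp are included):

    #include <cstddef>

    void process_range(std::size_t begin, std::size_t end); // assumed elsewhere

    void run_balanced(std::size_t work) {
        using namespace dnnl::impl;
        int nthr = adjust_num_threads(
                dnnl_get_current_num_threads(), static_cast<dim_t>(work));
        parallel(nthr, [&](int ithr, int nthr_) {
            std::size_t start {0}, end {0};
            balance211(work, nthr_, ithr, start, end);
            process_range(start, end); // each thread gets a contiguous slice
        });
    }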
@@ -652,6 +572,171 @@ static inline void parallel_nd(dim_t D0, dim_t D1, dim_t D2, dim_t D3, dim_t D4, }); } +template +void parallel_legacy(int nthr, F f) { + nthr = adjust_num_threads(nthr, INT64_MAX); +#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ + assert(nthr == 1); + f(0, 1); +#else +#if defined(DNNL_ENABLE_ITT_TASKS) + auto task_primitive_kind = itt::primitive_task_get_current_kind(); + bool itt_enable = itt::get_itt(itt::__itt_task_level_high); +#endif + if (nthr == 1) { + f(0, 1); + return; + } +#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP +#pragma omp parallel num_threads(nthr) + { + int nthr_ = omp_get_num_threads(); + int ithr_ = omp_get_thread_num(); + assert(nthr_ == nthr); +#if defined(DNNL_ENABLE_ITT_TASKS) + if (ithr_ && itt_enable) itt::primitive_task_start(task_primitive_kind); +#endif + f(ithr_, nthr_); +#if defined(DNNL_ENABLE_ITT_TASKS) + if (ithr_ && itt_enable) itt::primitive_task_end(); +#endif + } +#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB + tbb::parallel_for( + 0, nthr, + [&](int ithr) { +#if defined(DNNL_ENABLE_ITT_TASKS) + bool mark_task = itt::primitive_task_get_current_kind() + == primitive_kind::undefined; + if (mark_task && itt_enable) + itt::primitive_task_start(task_primitive_kind); +#endif + f(ithr, nthr); +#if defined(DNNL_ENABLE_ITT_TASKS) + if (mark_task && itt_enable) itt::primitive_task_end(); +#endif + }, + tbb::static_partitioner()); +#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB_AUTO + tbb::parallel_for( + 0, nthr, [&](int ithr) { f(ithr, nthr); }); +#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL + using namespace dnnl::impl::threadpool_utils; + dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool(); + if (!tp || dnnl_in_parallel()) { + threadpool_utils::deactivate_threadpool(); + for (int ithr = 0; ithr < nthr; ithr++) { + f(ithr, nthr); + } + threadpool_utils::activate_threadpool(tp); + } else { + bool async = tp->get_flags() + & dnnl::threadpool_interop::threadpool_iface::ASYNCHRONOUS; + counting_barrier_t b; + if (async) b.init(nthr); + tp->parallel_for(nthr, [&, tp](int ithr, int nthr) { + bool is_master = threadpool_utils::get_active_threadpool() == tp; + if (!is_master) { + threadpool_utils::activate_threadpool(tp); +#if defined(DNNL_ENABLE_ITT_TASKS) + if (itt_enable) itt::primitive_task_start(task_primitive_kind); +#endif + } + f(ithr, nthr); + if (!is_master) { +#if defined(DNNL_ENABLE_ITT_TASKS) + if (itt_enable) itt::primitive_task_end(); +#endif + threadpool_utils::deactivate_threadpool(); + } + if (async) b.notify(); + }); + if (async) b.wait(); + } +#endif +#endif +} + +template +void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, F f) { + T0 start {0}, end {0}; + balance211(D0, nthr, ithr, start, end); + for (T0 d0 = start; d0 < end; ++d0) + f(d0); +} + +template +void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + if (work_amount == 0) return; + size_t start {0}, end {0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0 {0}; + T1 d1 {0}; + T2 d2 {0}; + T3 d3 {0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); + } +} + +template +void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) { + const 
size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + if (work_amount == 0) return; + size_t start {0}, end {0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0 {0}; + T1 d1 {0}; + T2 d2 {0}; + T3 d3 {0}; + T4 d4 {0}; + T5 d5 {0}; + utils::nd_iterator_init( + start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3, d4, d5); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + } +} + +template +void parallel_nd_legacy(const T0 &D0, F f) { + const size_t work_amount = (size_t)D0; + int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount); + if (nthr) + parallel_legacy(nthr, [&](int ithr, int nthr) { for_nd_legacy(ithr, nthr, D0, f); }); +} + +template +void parallel_nd_legacy(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount); + if (nthr) + parallel_legacy(nthr, [&](int ithr, int nthr) { + for_nd_legacy(ithr, nthr, D0, D1, D2, D3, f); + }); +} + +template +void parallel_nd_legacy(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, + const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount); + if (nthr) + parallel_legacy(nthr, [&](int ithr, int nthr) { + for_nd_legacy(ithr, nthr, D0, D1, D2, D3, D4, D5, f); + }); +} + } // namespace impl } // namespace dnnl diff --git a/src/common/dnnl_traits.hpp b/src/common/dnnl_traits.hpp index cefdf1a80ee..4f9b8029282 100644 --- a/src/common/dnnl_traits.hpp +++ b/src/common/dnnl_traits.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
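A sketch of how the parallel_nd_legacy helpers added to dnnl_thread.hpp above are meant to be called; the flattened D0 * D1 * D2 * D3 index space is balanced across threads, and touch() is a placeholder:

    #include <cstdint>

    void touch(int64_t n, int64_t c, int64_t h, int64_t w); // placeholder

    void walk_nchw(int64_t N, int64_t C, int64_t H, int64_t W) {
        dnnl::impl::parallel_nd_legacy(N, C, H, W,
                [&](int64_t n, int64_t c, int64_t h, int64_t w) {
                    touch(n, c, h, w);
                });
    }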
@@ -17,171 +17,166 @@
 #ifndef COMMON_DNNL_TRAITS_HPP
 #define COMMON_DNNL_TRAITS_HPP

-#include <assert.h>
-#include <stdint.h>
-
-#include "oneapi/dnnl/dnnl.h"
-
 #include "bfloat16.hpp"
 #include "c_types_map.hpp"
 #include "float16.hpp"
+#include "float4.hpp"
 #include "float8.hpp"
-#include "nstl.hpp"
-#include "opdesc.hpp"
-#include "utils.hpp"
-#include "z_magic.hpp"
+#include "int4.hpp"
+
+#include <cstdint>

 namespace dnnl {
 namespace impl {

 template <data_type_t>
-struct prec_traits {}; /* ::type -> float */
+struct prec_traits_t {}; /* ::type -> float */
 template <typename>
-struct data_traits {}; /* ::data_type -> f32 */
+struct data_traits_t {}; /* ::data_type -> f32 */
 template <int>
-struct typesize_traits {}; /* ::data_type_size -> f32 */
+struct typesize_traits_t {}; /* ::data_type_size -> f32 */
 template <primitive_kind_t>
-struct pkind_traits {}; /* ::desc_type, ::query_d */
+struct pkind_traits_t {}; /* ::desc_type, ::query_d */

 template <>
-struct prec_traits<data_type::e8m0> {
- typedef float8_e8m0_t type;
+struct prec_traits_t<data_type::f4_e3m0> {
+ using type = float4_e3m0_t;
+};
+template <>
+struct prec_traits_t<data_type::f4_e2m1> {
+ using type = float4_e2m1_t;
 };
 template <>
-struct prec_traits<data_type::f8_e5m2> {
- typedef float8_e5m2_t type;
+struct prec_traits_t<data_type::e8m0> {
+ using type = float8_e8m0_t;
 };
 template <>
-struct prec_traits<data_type::f8_e4m3> {
- typedef float8_e4m3_t type;
+struct prec_traits_t<data_type::f8_e5m2> {
+ using type = float8_e5m2_t;
 };
 template <>
-struct prec_traits<data_type::f16> {
- typedef float16_t type;
+struct prec_traits_t<data_type::f8_e4m3> {
+ using type = float8_e4m3_t;
 };
 template <>
-struct prec_traits<data_type::bf16> {
- typedef bfloat16_t type;
+struct prec_traits_t<data_type::f16> {
+ using type = float16_t;
 };
 template <>
-struct prec_traits<data_type::f32> {
- typedef float type;
+struct prec_traits_t<data_type::bf16> {
+ using type = bfloat16_t;
 };
 template <>
-struct prec_traits<data_type::f64> {
- typedef double type;
+struct prec_traits_t<data_type::f32> {
+ using type = float;
 };
 template <>
-struct prec_traits<data_type::s32> {
- typedef int32_t type;
+struct prec_traits_t<data_type::f64> {
+ using type = double;
 };
 template <>
-struct prec_traits<data_type::s8> {
- typedef int8_t type;
+struct prec_traits_t<data_type::s32> {
+ using type = int32_t;
 };
 template <>
-struct prec_traits<data_type::u8> {
- typedef uint8_t type;
+struct prec_traits_t<data_type::s8> {
+ using type = int8_t;
 };
 template <>
-struct prec_traits<data_type::s4> {
- typedef int4_t type;
+struct prec_traits_t<data_type::u8> {
+ using type = uint8_t;
 };
 template <>
-struct prec_traits<data_type::u4> {
- typedef uint4_t type;
+struct prec_traits_t<data_type::s4> {
+ using type = int4_t;
 };
 template <>
-struct prec_traits<data_type::boolean> {
- typedef bool type;
+struct prec_traits_t<data_type::u4> {
+ using type = uint4_t;
+};
+template <>
+struct prec_traits_t<data_type::boolean> {
+ using type = bool;
+};
+
+template <> struct prec_traits_t<data_type::bin> {
+ using type = uint8_t;
+};
+
+template <> struct prec_traits_t<data_type::nf4> {
+ using type = uint8_t;
 };

 template <>
-struct data_traits<float8_e5m2_t> {
+struct data_traits_t<float4_e3m0_t> {
+ static constexpr data_type_t data_type = data_type::f4_e3m0;
+};
+template <>
+struct data_traits_t<float4_e2m1_t> {
+ static constexpr data_type_t data_type = data_type::f4_e2m1;
+};
+template <>
+struct data_traits_t<float8_e8m0_t> {
+ static constexpr data_type_t data_type = data_type::e8m0;
+};
+template <>
+struct data_traits_t<float8_e5m2_t> {
 static constexpr data_type_t data_type = data_type::f8_e5m2;
 };
 template <>
-struct data_traits<float8_e4m3_t> {
+struct data_traits_t<float8_e4m3_t> {
 static constexpr data_type_t data_type = data_type::f8_e4m3;
 };
 template <>
-struct data_traits<float16_t> {
+struct data_traits_t<float16_t> {
 static constexpr data_type_t data_type = data_type::f16;
 };
 template <>
-struct data_traits<bfloat16_t> {
+struct data_traits_t<bfloat16_t> {
 static constexpr data_type_t data_type = data_type::bf16;
 };
 template <>
-struct data_traits<float> {
+struct data_traits_t<float> {
 static constexpr data_type_t data_type = data_type::f32;
 };
 template <>
-struct data_traits<int32_t> {
+struct data_traits_t<int32_t> {
 static constexpr data_type_t data_type = data_type::s32;
 };
 template <>
-struct data_traits<int8_t> {
+struct data_traits_t<int8_t> {
 static constexpr data_type_t data_type = data_type::s8;
 };
 template <>
-struct data_traits<uint8_t> {
+struct data_traits_t<uint8_t> {
 static constexpr data_type_t data_type = data_type::u8;
 };
 template <>
-struct data_traits<int4_t> {
+struct data_traits_t<int4_t> {
 static constexpr data_type_t data_type = data_type::s4;
 };
 template <>
-struct data_traits<uint4_t> {
+struct data_traits_t<uint4_t> {
 static constexpr data_type_t data_type = data_type::u4;
 };
 template <>
-struct data_traits<bool> {
+struct data_traits_t<bool> {
 static constexpr data_type_t data_type = data_type::boolean;
 };

 template <>
-struct typesize_traits<4> {
- typedef float type;
+struct typesize_traits_t<4> {
+ using type = float;
 };
 template <>
-struct typesize_traits<2> {
- typedef int16_t type;
+struct typesize_traits_t<2> {
+ using type = int16_t;
 };
 template <>
-struct typesize_traits<1> {
- typedef uint8_t type;
+struct typesize_traits_t<1> {
+ using type = uint8_t;
 };

-#define PKIND_TRAITS_INST(op) \
- template <> \
- struct pkind_traits<primitive_kind::op> { \
- typedef CONCAT2(op, _desc_t) desc_type; \
- }
-PKIND_TRAITS_INST(convolution);
-PKIND_TRAITS_INST(deconvolution);
-PKIND_TRAITS_INST(shuffle);
-PKIND_TRAITS_INST(eltwise);
-PKIND_TRAITS_INST(softmax);
-PKIND_TRAITS_INST(pooling);
-PKIND_TRAITS_INST(prelu);
-PKIND_TRAITS_INST(lrn);
-PKIND_TRAITS_INST(batch_normalization);
-PKIND_TRAITS_INST(group_normalization);
-PKIND_TRAITS_INST(layer_normalization);
-PKIND_TRAITS_INST(inner_product);
-PKIND_TRAITS_INST(rnn);
-PKIND_TRAITS_INST(gemm);
-PKIND_TRAITS_INST(zero_pad);
-PKIND_TRAITS_INST(binary);
-PKIND_TRAITS_INST(matmul);
-PKIND_TRAITS_INST(resampling);
-PKIND_TRAITS_INST(reduction);
-PKIND_TRAITS_INST(sum);
-PKIND_TRAITS_INST(sdpa);
-#undef PKIND_TRAITS_INST
-
 } // namespace impl
 } // namespace dnnl
diff --git a/src/common/eltwise.cpp b/src/common/eltwise.cpp
index 356584a54d0..ba7675120f8 100644
--- a/src/common/eltwise.cpp
+++ b/src/common/eltwise.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2016-2023 Intel Corporation
+* Copyright 2016-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -57,7 +57,8 @@ status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
 VCHECK_ELTWISE(
 IMPLICATION(!is_fwd, !any_null(diff_src_desc, diff_dst_desc)),
 VERBOSE_NULL_ARG);
- VCHECK_ELTWISE(IMPLICATION(alg_kind == eltwise_round, is_fwd),
+ VCHECK_ELTWISE(IMPLICATION(one_of(alg_kind, eltwise_round, eltwise_hsigmoid,
+ eltwise_round_half_away_from_zero, eltwise_round_half_to_even), is_fwd),
 VERBOSE_BAD_PROPKIND);
 VCHECK_ELTWISE(
 IMPLICATION(is_fwd, !memory_desc_wrapper(src_desc).format_any()),
@@ -136,6 +137,9 @@ status_t eltwise_attr_check(const eltwise_desc_t &desc, const engine_t *engine,
 using namespace primitive_kind;
 VCHECK_ELTWISE_IMPL(po.has_default_values({binary}),
 VERBOSE_UNSUPPORTED_POSTOP);
+
+ // Note: verbose support is inside the call.
+ CHECK(po.validate_binary_with_dst_consistency(&desc.dst_desc)); } } else { VCHECK_ELTWISE_IMPL(false, VERBOSE_UNSUPPORTED_ATTR); diff --git a/src/common/eltwise_pd.hpp b/src/common/eltwise_pd.hpp index e315f5866c8..4f2d43e33da 100644 --- a/src/common/eltwise_pd.hpp +++ b/src/common/eltwise_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -102,10 +102,10 @@ struct eltwise_pd_t : public primitive_desc_t { memory_desc_t src_md_; memory_desc_t dst_md_; - eltwise_pd_t(const eltwise_desc_t *adesc, const primitive_attr_t *attr, + eltwise_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const eltwise_fwd_pd_t *hint_fwd_pd) : primitive_desc_t(attr, base_pkind) - , desc_(*adesc) + , desc_(*op_desc_t::to_desc(adesc)) , hint_fwd_pd_(hint_fwd_pd) , src_md_(desc_.src_desc) , dst_md_(desc_.dst_desc) {} @@ -116,9 +116,10 @@ struct eltwise_pd_t : public primitive_desc_t { } }; +// NOLINTBEGIN(google-default-arguments) struct eltwise_fwd_pd_t : public eltwise_pd_t { - typedef eltwise_fwd_pd_t base_class; - typedef eltwise_fwd_pd_t hint_class; + using base_class = eltwise_fwd_pd_t; + using hint_class = eltwise_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (arg == DNNL_ARG_SRC) return arg_usage_t::input; @@ -158,7 +159,7 @@ struct eltwise_fwd_pd_t : public eltwise_pd_t { return one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_swish, eltwise_gelu_tanh, eltwise_gelu_erf, eltwise_round, - eltwise_hardswish) + eltwise_hardswish, eltwise_round_half_away_from_zero, eltwise_round_half_to_even) || one_of(alg, eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd, eltwise_elu_use_dst_for_bwd, @@ -179,7 +180,7 @@ struct eltwise_fwd_pd_t : public eltwise_pd_t { } protected: - eltwise_fwd_pd_t(const eltwise_desc_t *adesc, const primitive_attr_t *attr, + eltwise_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const eltwise_fwd_pd_t *hint_fwd_pd) : eltwise_pd_t(adesc, attr, hint_fwd_pd) {} @@ -190,14 +191,18 @@ struct eltwise_fwd_pd_t : public eltwise_pd_t { == status::success); } }; +// NOLINTEND(google-default-arguments) +// NOLINTBEGIN(google-default-arguments) struct eltwise_bwd_pd_t : public eltwise_pd_t { - typedef eltwise_bwd_pd_t base_class; - typedef eltwise_fwd_pd_t hint_class; + using base_class = eltwise_bwd_pd_t; + using hint_class = eltwise_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { - if (use_dst() ? arg == DNNL_ARG_DST : arg == DNNL_ARG_SRC) - return arg_usage_t::input; + if (arg == DNNL_ARG_SRC) + return !use_dst() ? arg_usage_t::input : arg_usage_t::unused; + if (arg == DNNL_ARG_DST) + return use_dst() ? 
arg_usage_t::input : arg_usage_t::unused; if (arg == DNNL_ARG_DIFF_DST) return arg_usage_t::input; if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output; @@ -281,7 +286,7 @@ struct eltwise_bwd_pd_t : public eltwise_pd_t { memory_desc_t diff_src_md_; memory_desc_t diff_dst_md_; - eltwise_bwd_pd_t(const eltwise_desc_t *adesc, const primitive_attr_t *attr, + eltwise_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const eltwise_fwd_pd_t *hint_fwd_pd) : eltwise_pd_t(adesc, attr, hint_fwd_pd) , diff_src_md_(desc_.diff_src_desc) @@ -298,6 +303,7 @@ struct eltwise_bwd_pd_t : public eltwise_pd_t { == status::success); } }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/engine.cpp b/src/common/engine.cpp index 9fdd4e1f073..8768b3e0c1f 100644 --- a/src/common/engine.cpp +++ b/src/common/engine.cpp @@ -88,7 +88,19 @@ status_t dnnl_engine_create( try { auto ef = get_engine_factory(kind, get_default_runtime(kind)); VERROR_ENGINE(ef != nullptr, invalid_arguments, - VERBOSE_INVALID_ENGINE_KIND, dnnl_engine_kind2str(kind)); + VERBOSE_INVALID_ENGINE_KIND, "", dnnl_engine_kind2str(kind)); + + auto s_runtime_kind = dnnl_runtime2str(kind == engine_kind::cpu + ? dnnl_version()->cpu_runtime + : dnnl_version()->gpu_runtime); + + VERROR_ENGINE(ef->count() > 0, invalid_arguments, + "%s %s devices queried but not found", + get_default_runtime(kind) == runtime_kind::none + ? "" + : s_runtime_kind, + dnnl_engine_kind2str(kind)); + VERROR_ENGINE(index < ef->count(), invalid_arguments, VERBOSE_INVALID_ENGINE_IDX, ef->count(), dnnl_engine_kind2str(kind), index); diff --git a/src/common/engine.hpp b/src/common/engine.hpp index 5195b4aad28..159358fca66 100644 --- a/src/common/engine.hpp +++ b/src/common/engine.hpp @@ -67,8 +67,13 @@ struct dnnl_engine : public dnnl::impl::c_compatible { /** create memory storage */ virtual dnnl::impl::status_t create_memory_storage( dnnl::impl::memory_storage_t **storage, unsigned flags, size_t size, - void *handle) - = 0; + void *handle) { + assert(impl()); + if (!impl()) return dnnl::impl::status::runtime_error; + return impl()->create_memory_storage( + storage, this, flags, size, handle); + } + dnnl::impl::status_t create_memory_storage( dnnl::impl::memory_storage_t **storage, size_t size) { return create_memory_storage( @@ -187,6 +192,8 @@ inline runtime_kind_t get_default_runtime(engine_kind_t kind) { return runtime_kind::omp; #elif DNNL_CPU_RUNTIME == DNNL_RUNTIME_TBB return runtime_kind::tbb; +#elif DNNL_CPU_RUNTIME == DNNL_RUNTIME_TBB_AUTO + return runtime_kind::tbb_auto; #elif DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL return runtime_kind::threadpool; #elif DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL diff --git a/src/common/engine_impl.hpp b/src/common/engine_impl.hpp index d84cd6b9763..e42b8fa27fa 100644 --- a/src/common/engine_impl.hpp +++ b/src/common/engine_impl.hpp @@ -62,6 +62,12 @@ class engine_impl_t { virtual status_t init() { return status::success; } + virtual status_t create_memory_storage(memory_storage_t **storage, + engine_t *engine, unsigned flags, size_t size, void *handle) const { + assert(!"unexpected"); + return status::runtime_error; + } + virtual status_t create_stream_impl( impl::stream_impl_t **stream_impl, unsigned flags) const { auto *si = new impl::stream_impl_t(flags); diff --git a/src/common/experimental.cpp b/src/common/experimental.cpp index 6ec2f545a34..e32e9b433da 100644 --- a/src/common/experimental.cpp +++ b/src/common/experimental.cpp @@ -1,5 +1,5 @@ 
/******************************************************************************* -* Copyright 2022 Intel Corporation +* Copyright 2022-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,6 +33,16 @@ bool DNNL_API use_bnorm_stats_one_pass() { return stats_onepass_algo; } +bool use_gpu_conv_v2() { +#ifdef DNNL_EXPERIMENTAL + static const bool is_enabled + = getenv_int_user("EXPERIMENTAL_GPU_CONV_V2", 0); +#else + static const bool is_enabled = false; +#endif + return is_enabled; +} + } // namespace experimental } // namespace impl } // namespace dnnl diff --git a/src/common/experimental.hpp b/src/common/experimental.hpp index c6efb96fd9d..7cdfcbd7d83 100644 --- a/src/common/experimental.hpp +++ b/src/common/experimental.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022 Intel Corporation +* Copyright 2022-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,9 +21,10 @@ namespace impl { namespace experimental { bool use_bnorm_stats_one_pass(); +bool use_gpu_conv_v2(); } // namespace experimental } // namespace impl } // namespace dnnl -#endif \ No newline at end of file +#endif diff --git a/src/common/float16.hpp b/src/common/float16.hpp index 5449967d5d6..1125fd11e43 100644 --- a/src/common/float16.hpp +++ b/src/common/float16.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ #include "bit_cast.hpp" #include "oneapi/dnnl/dnnl.h" +#include "cpu/platform.hpp" namespace dnnl { namespace impl { @@ -39,7 +40,7 @@ struct float16_t { float16_t &operator=(float f); operator float() const; - float f() { return (float)(*this); } + float f() const { return (float)(*this); } float16_t &operator+=(float16_t a) { (*this) = float(f() + a.f()); diff --git a/src/common/float4.cpp b/src/common/float4.cpp new file mode 100644 index 00000000000..34eaae4b4c5 --- /dev/null +++ b/src/common/float4.cpp @@ -0,0 +1,162 @@ +/******************************************************************************* +* Copyright 2024-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/
+
+#include <cmath>
+
+#include "common/bit_cast.hpp"
+#include "common/dnnl_thread.hpp"
+#include "common/float16.hpp"
+#include "common/float4.hpp"
+#include "common/utils.hpp"
+
+namespace dnnl {
+namespace impl {
+
+uint8_t float2e2m1(float f) {
+    uint32_t f_raw = float2int(f);
+    uint32_t sign = f_raw & 0x80000000;
+
+    // There is no NaN or infinity in e2m1, for now we just return zero
+    // TODO: figure if there is a standard value to return
+    uint32_t naninf_mask = 0x7f800000;
+    if ((f_raw & naninf_mask) == naninf_mask) return 0x07 | (sign >> 28);
+
+    // we convert with naive closest value computation out of 8
+    float e2m1_val_table[8] = {0.0f, .5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
+
+    float abs_f = fmin(e2m1_val_table[7], int2float(f_raw ^ sign));
+
+    int idx = 0;
+    float min_diff = ::fabsf(e2m1_val_table[idx] - abs_f);
+    uint8_t raw_bits = idx;
+    for (++idx; idx < 8; ++idx) {
+        float diff = ::fabsf(e2m1_val_table[idx] - abs_f);
+        if (diff < min_diff) {
+            min_diff = diff;
+            raw_bits = idx;
+        }
+        // Special case for midpoint, we round to even (so even index)
+        if ((diff == min_diff) && !(idx & 1)) raw_bits = idx;
+    }
+    assert(raw_bits < 8);
+    // reapply sign
+    if (sign) raw_bits = raw_bits | 0x08;
+    assert(raw_bits < 16);
+    return raw_bits;
+}
+
+float4_e2m1_t &float4_e2m1_t::operator=(bfloat16_t f) {
+    float f32 = f;
+    raw_bits_ = float2e2m1(f32);
+    return *this;
+}
+
+float4_e2m1_t &float4_e2m1_t::operator=(float16_t f) {
+    float f32 = f;
+    raw_bits_ = float2e2m1(f32);
+    return *this;
+}
+
+float4_e2m1_t &float4_e2m1_t::operator=(float f) {
+    raw_bits_ = float2e2m1(f);
+    return *this;
+}
+
+float4_e2m1_t::operator float() const {
+    // List of e2m1 values. The index of each value maps to its encoding.
+    static const float e2m1_table[16] = {0.0f, .5f, 1.0f, 1.5f, 2.0f, 3.0f,
+            4.0f, 6.0f, -0.0f, -.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
+    assert(raw_bits_ < 16);
+    return e2m1_table[raw_bits_];
+}
+
+float4_e2m1_t::operator float16_t() const {
+    // List of e2m1 values. The index of each value maps to its encoding.
+ static const float16_t e2m1_table[16] = {0.0f, .5f, 1.0f, 1.5f, 2.0f, 3.0f, + 4.0f, 6.0f, -0.0f, -.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f}; + assert(raw_bits_ < 16); + return e2m1_table[raw_bits_]; +} + +uint8_t float2e3m0(float f) { + uint32_t f_raw = float2int(f); + uint32_t sign = f_raw & 0x80000000; + + // There is no NaN or infinity in e3m0, we just return maxval + uint32_t naninf_mask = 0x7f800000; + if ((f_raw & naninf_mask) == naninf_mask) return 0x7 | (sign >> 28); + + // we convert with naive closest value computation out of 8 + float e3m0_val_table[8] = {0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f}; + + float abs_f = fmin(e3m0_val_table[7], int2float(f_raw ^ sign)); + + int idx = 0; + float min_diff = ::fabsf(e3m0_val_table[idx] - abs_f); + uint8_t raw_bits = idx; + for (++idx; idx < 8; ++idx) { + float diff = ::fabsf(e3m0_val_table[idx] - abs_f); + if (diff < min_diff) { + min_diff = diff; + raw_bits = idx; + } + // Special case for midpoint, we round to even (so even index) + if ((diff == min_diff) && !(idx & 1)) raw_bits = idx; + } + assert(raw_bits < 8); + // reapply sign + if (sign) raw_bits = raw_bits | 0x08; + assert(raw_bits < 16); + return raw_bits; +} + +float4_e3m0_t &float4_e3m0_t::operator=(bfloat16_t f) { + float f32 = f; + raw_bits_ = float2e3m0(f32); + return *this; +} + +float4_e3m0_t &float4_e3m0_t::operator=(float16_t f) { + float f32 = f; + raw_bits_ = float2e3m0(f32); + return *this; +} + +float4_e3m0_t &float4_e3m0_t::operator=(float f) { + raw_bits_ = float2e3m0(f); + return *this; +} + +float4_e3m0_t::operator float() const { + // List of e3m0 values. The index of each value maps to its encoding. + static const float e3m0_table[16] + = {0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f, -0.0f, -.25f, + -.5f, -1.0f, -2.0f, -4.0f, -8.0f, -16.0f}; + assert(raw_bits_ < 16); + return e3m0_table[raw_bits_]; +} + +float4_e3m0_t::operator float16_t() const { + // List of e3m0 values. The index of each value maps to its encoding. + static const float16_t e3m0_table[16] + = {0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f, -0.0f, -.25f, + -.5f, -1.0f, -2.0f, -4.0f, -8.0f, -16.0f}; + assert(raw_bits_ < 16); + return e3m0_table[raw_bits_]; +} + +} // namespace impl +} // namespace dnnl diff --git a/src/common/float4.hpp b/src/common/float4.hpp new file mode 100644 index 00000000000..44be31d9d0a --- /dev/null +++ b/src/common/float4.hpp @@ -0,0 +1,78 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef COMMON_FLOAT4_HPP +#define COMMON_FLOAT4_HPP + +#include +#include + +#include "common/bfloat16.hpp" +#include "common/float16.hpp" + +namespace dnnl { +namespace impl { + +struct float4_e2m1_t { + uint8_t raw_bits_; + float4_e2m1_t() = default; + constexpr float4_e2m1_t(uint8_t r, bool = true) : raw_bits_(r) {} + float4_e2m1_t(float f) { (*this) = f; } + float4_e2m1_t(float16_t f) { (*this) = f; } + float4_e2m1_t(bfloat16_t f) { (*this) = f; } + + float4_e2m1_t DNNL_API &operator=(float f); + float4_e2m1_t DNNL_API &operator=(float16_t f); + float4_e2m1_t DNNL_API &operator=(bfloat16_t f); + + DNNL_API operator float() const; + DNNL_API operator float16_t() const; + DNNL_API operator bfloat16_t() const; + + float4_e2m1_t &operator+=(const float a) { + (*this) = float {*this} + a; + return *this; + } +}; +static_assert(sizeof(float4_e2m1_t) == 1, "float4_e2m1_t must be 1 byte"); + +struct float4_e3m0_t { + uint8_t raw_bits_; + float4_e3m0_t() = default; + constexpr float4_e3m0_t(uint8_t r, bool = true) : raw_bits_(r) {} + float4_e3m0_t(float f) { (*this) = f; } + float4_e3m0_t(float16_t f) { (*this) = f; } + float4_e3m0_t(bfloat16_t f) { (*this) = f; } + + float4_e3m0_t DNNL_API &operator=(float f); + float4_e3m0_t DNNL_API &operator=(float16_t f); + float4_e3m0_t DNNL_API &operator=(bfloat16_t f); + + DNNL_API operator float() const; + DNNL_API operator float16_t() const; + DNNL_API operator bfloat16_t() const; + + float4_e3m0_t &operator+=(const float a) { + (*this) = float {*this} + a; + return *this; + } +}; +static_assert(sizeof(float4_e3m0_t) == 1, "float4_e3m0_t must be 1 byte"); + +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/common/float8.hpp b/src/common/float8.hpp index a32e9bbca7d..e8f09cb0467 100644 --- a/src/common/float8.hpp +++ b/src/common/float8.hpp @@ -72,6 +72,7 @@ struct float8_e4m3_t { return *this; } }; +static_assert(sizeof(float8_e4m3_t) == 1, "float8_e4m3_t must be 1 byte"); void cvt_f8_e5m2_to_float(float *out, const float8_e5m2_t *inp, size_t nelems); void cvt_f8_e4m3_to_float(float *out, const float8_e4m3_t *inp, size_t nelems); @@ -85,8 +86,6 @@ void add_floats_and_cvt_to_f8_e5m2(float8_e5m2_t *out, const float *inp0, void add_floats_and_cvt_to_f8_e4m3(float8_e4m3_t *out, const float *inp0, const float *inp1, size_t nelems); -static_assert(sizeof(float8_e5m2_t) == 1, "float8_e4m3_t must be 1 byte"); - #if DNNL_X64 namespace cpu { namespace x64 { diff --git a/src/common/gemm.cpp b/src/common/gemm.cpp index 6a2578cd5ca..52da4a1be84 100644 --- a/src/common/gemm.cpp +++ b/src/common/gemm.cpp @@ -85,7 +85,7 @@ std::string get_descriptor(dim_t M, dim_t N, dim_t K) { if (!is_src_ab && lda != M) ss << "lda:" << lda << " "; \ if (is_wei_ab && ldb != N) ss << "ldb:" << ldb << " "; \ if (!is_wei_ab && ldb != K) ss << "ldb:" << ldb << " "; \ - if (alpha != 1.f) ss << "attr-oscale:common:" << alpha << " "; \ + if (alpha != 1.f) ss << "attr-scales:src:common:" << alpha << " "; \ if (beta != 0.f) ss << "attr-post-ops:sum:" << beta << " "; \ ss << ",," << get_descriptor(M, N, K); \ VPROF(start_ms, primitive, exec, VERBOSE_profile, ss.str().c_str(), \ diff --git a/src/common/gemm_pd.hpp b/src/common/gemm_pd.hpp index d7c4f2c650e..2180aada880 100644 --- a/src/common/gemm_pd.hpp +++ b/src/common/gemm_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 
Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,11 +36,12 @@ namespace impl { VCHECK(primitive, create, dispatch, gemm, (f), "%s," msg, \ this->info(engine), ##__VA_ARGS__) +// NOLINTBEGIN(google-default-arguments) struct gemm_pd_t : public primitive_desc_t { static constexpr auto base_pkind = primitive_kind::gemm; - typedef gemm_pd_t base_class; - typedef gemm_pd_t hint_class; + using base_class = gemm_pd_t; + using hint_class = gemm_pd_t; const gemm_desc_t *desc() const { return &desc_; } const op_desc_t *op_desc() const override { @@ -91,9 +92,10 @@ struct gemm_pd_t : public primitive_desc_t { // resolve the 'any' tags. gemm_desc_t desc_; - gemm_pd_t(const gemm_desc_t *adesc, const primitive_attr_t *attr, + gemm_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const hint_class *hint_fwd_pd) - : primitive_desc_t(attr, base_pkind), desc_(*adesc) {} + : primitive_desc_t(attr, base_pkind) + , desc_(*op_desc_t::to_desc(adesc)) {} // By default, we just resolve 'any' with blocked layout and trivial strides bool set_default_format(memory_desc_t *md) { @@ -121,6 +123,7 @@ struct gemm_pd_t : public primitive_desc_t { return ok; } }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/gemm_types.hpp b/src/common/gemm_types.hpp index c23a7c9f58d..7dac69356e3 100644 --- a/src/common/gemm_types.hpp +++ b/src/common/gemm_types.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ #include #include "common/c_types_map.hpp" #include "common/memory_desc.hpp" +#include "common/opdesc.hpp" namespace dnnl { namespace impl { @@ -49,20 +50,23 @@ const sum_ab_t sum_none = dnnl_sum_none; // A descriptor for a matrix multiplication (gemm) operation. To make the // interface consistent, the descriptor represent the GEMM operation in row // major. -struct gemm_desc_t { - // The kind of primitive. Used for self identifying the primitive - // descriptor. Must be #dnnl_gemm. - dnnl_primitive_kind_t primitive_kind; +struct gemm_desc_t : public op_desc_t { + gemm_desc_t() : op_desc_t(primitive_kind::gemm) {} + + std::unique_ptr clone() const override { + return utils::make_unique(*this); + } + memory_desc_t a_desc; memory_desc_t b_desc; memory_desc_t c_desc; memory_desc_t bias_desc; // Type for accumulating A*B. - dnnl_data_type_t acc_type; + dnnl_data_type_t acc_type {}; // Sum across k dimension in either A or B tensor // and output to sum_ab tensor. - sum_ab_t sum_ab; - dnnl_data_type_t sum_ab_type; + sum_ab_t sum_ab {}; + dnnl_data_type_t sum_ab_type {}; // These accessors are to be used by the GEMM implementation. Because the // GEMM implementation currently assumes column major. These accessors @@ -73,7 +77,8 @@ struct gemm_desc_t { // Simplified accessors that comply to GEMM API static transpose_t get_trans(const memory_desc_t &md) { if (!md.ndims) return transpose::notrans; // arbitrary - return md.format_desc.blocking.strides[md.ndims - 1] != 1 + return md.dims[md.ndims - 1] != 1 + && md.format_desc.blocking.strides[md.ndims - 1] != 1 ? 
transpose::trans : transpose::notrans; } @@ -116,9 +121,16 @@ struct gemm_desc_t { // This assumes that one of the dimensions has strides 1 static dnnl_dim_t get_ld(const memory_desc_t &md) { auto strides = md.format_desc.blocking.strides; - assert(strides[md.ndims - 1] == 1 || strides[md.ndims - 2] == 1); - return strides[md.ndims - 1] != 1 ? strides[md.ndims - 1] - : strides[md.ndims - 2]; + assert(md.dims[md.ndims - 1] == 1 || strides[md.ndims - 1] == 1 + || md.dims[md.ndims - 2] == 1 || strides[md.ndims - 2] == 1); + switch (get_trans(md)) { + case transpose::trans: + return md.dims[md.ndims - 1] > 1 ? strides[md.ndims - 1] + : md.dims[md.ndims - 2]; + default: + return md.dims[md.ndims - 2] > 1 ? strides[md.ndims - 2] + : md.dims[md.ndims - 1]; + } } // Leading dimension of A. dnnl_dim_t lda() const { return get_ld(b_desc); } diff --git a/src/common/gemm_utils.hpp b/src/common/gemm_utils.hpp index 65045a7d911..23afbb2c2a0 100644 --- a/src/common/gemm_utils.hpp +++ b/src/common/gemm_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -101,9 +101,10 @@ static inline status_t create_2d_desc(memory_desc_t *md_2d, int d0, int d1, } } -static inline gemm_desc_t create_gemm_desc(const memory_desc_t *a_md, - const memory_desc_t *b_md, const memory_desc_t *c_md, - const memory_desc_t *bias_md, data_type_t acc_dt, engine_t *engine, +static inline status_t create_gemm_desc(gemm_desc_t *_gemm_desc, + const memory_desc_t *a_md, const memory_desc_t *b_md, + const memory_desc_t *c_md, const memory_desc_t *bias_md, + data_type_t acc_dt, engine_t *engine, sum_ab_t sum_ab = sum_ab::sum_none, data_type_t sum_ab_dt = data_type::undef) { auto gemm_desc = gemm_desc_t(); @@ -121,7 +122,8 @@ static inline gemm_desc_t create_gemm_desc(const memory_desc_t *a_md, data_type::f16, a_md->data_type, b_md->data_type)) { gemm_desc.acc_type = data_type::f16; } - return gemm_desc; + *_gemm_desc = gemm_desc; + return status::success; } static inline status_t create_gemm_pd( @@ -131,8 +133,9 @@ static inline status_t create_gemm_pd( data_type_t acc_dt, const primitive_attr_t *attr, bool skip_ref = false, sum_ab_t sum_ab = sum_ab::sum_none, data_type_t sum_ab_dt = data_type::undef) { - auto gemm_desc = create_gemm_desc( - a_md, b_md, c_md, bias_md, acc_dt, engine, sum_ab, sum_ab_dt); + gemm_desc_t gemm_desc; + CHECK(create_gemm_desc(&gemm_desc, a_md, b_md, c_md, bias_md, acc_dt, + engine, sum_ab, sum_ab_dt)); primitive_attr_t gemm_attr = *attr; @@ -141,7 +144,7 @@ static inline status_t create_gemm_pd( gemm_pd_ = *(++it); if (!gemm_pd_) return status::unimplemented; - if (skip_ref && strstr(gemm_pd_.get()->name(), "ref") != NULL) + if (skip_ref && strstr(gemm_pd_->name(), "ref") != nullptr) return status::unimplemented; return status::success; @@ -156,8 +159,11 @@ static inline bool is_md_gemm_compatible_plain_format( if (blk_desc.inner_nblks != 0) return false; - return (blk_desc.strides[md->ndims - 1] == 1) - || (!is_dst && blk_desc.strides[md->ndims - 2] == 1); + return (md->dims[md->ndims - 1] == 1 + || blk_desc.strides[md->ndims - 1] == 1) + || (!is_dst + && (md->dims[md->ndims - 2] == 1 + || blk_desc.strides[md->ndims - 2] == 1)); } } // namespace impl diff --git a/src/common/group_normalization.cpp b/src/common/group_normalization.cpp index 
4e0abf3b6a2..ddc44734119 100644
--- a/src/common/group_normalization.cpp
+++ b/src/common/group_normalization.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2023 Intel Corporation
+* Copyright 2023-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -147,19 +147,25 @@ status_t group_normalization_attr_check(const group_normalization_desc_t &desc,
         const bool is_int8 = utils::one_of(src_dt, data_type::s8, data_type::u8)
                 || utils::one_of(dst_dt, data_type::s8, data_type::u8);
-        if (is_int8) fwd_attr_mask |= smask_t::scales_runtime;
+        if (is_int8) fwd_attr_mask |= smask_t::scales;
 
         VCHECK_GNORM_UNIMPL(attr->has_default_values(fwd_attr_mask, dst_dt),
                 VERBOSE_UNSUPPORTED_ATTR);
 
         // Check scales
         if (!attr->scales_.has_default_values()) {
-            const auto &sc = attr->scales_;
-            const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
-            const int mask_dst = sc.get(DNNL_ARG_DST).mask_;
-
-            VCHECK_GNORM_UNIMPL(utils::everyone_is(0, mask_src, mask_dst),
+            static const std::vector<int> supported_args {
+                    DNNL_ARG_SRC, DNNL_ARG_DST};
+            VCHECK_GNORM_UNIMPL(
+                    attr->scales_.has_default_values(supported_args),
                     VERBOSE_UNSUPPORTED_SCALES_CFG);
+
+            for (int arg : supported_args) {
+                if (attr->scales_.has_default_values(arg)) continue;
+
+                const int mask = attr->scales_.get_mask(arg);
+                VCHECK_GNORM_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
+            }
         }
 
         // Check post-ops
@@ -168,6 +174,9 @@ status_t group_normalization_attr_check(const group_normalization_desc_t &desc,
             using namespace primitive_kind;
             VCHECK_GNORM_UNIMPL(po.has_default_values({binary, eltwise}),
                     VERBOSE_UNSUPPORTED_POSTOP);
+
+            // Note: verbose support is inside the call.
+            CHECK(po.validate_binary_with_dst_consistency(&desc.dst_desc));
         }
     } else {
         VCHECK_GNORM_UNIMPL(false, VERBOSE_UNSUPPORTED_ATTR);
diff --git a/src/common/group_normalization_pd.hpp b/src/common/group_normalization_pd.hpp
index 313f36f378d..7b908050efc 100644
--- a/src/common/group_normalization_pd.hpp
+++ b/src/common/group_normalization_pd.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2023-2024 Intel Corporation
+* Copyright 2023-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
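For context on the mask checks added above: a scales mask of 0 requests a single scale value common to the whole tensor, which is the only configuration group_normalization_attr_check() accepts for src and dst. A short sketch of how a caller would attach such scales through the public C++ API (engine, memory descriptors, and the int8 src/dst setup that actually triggers this path are elided):

#include "oneapi/dnnl/dnnl.hpp"

dnnl::primitive_attr make_gnorm_int8_attr() {
    dnnl::primitive_attr attr;
    // mask == 0: one common scale per argument; any per-channel mask would
    // be rejected with VERBOSE_UNSUPPORTED_SCALES_CFG by the check above.
    attr.set_scales_mask(DNNL_ARG_SRC, 0);
    attr.set_scales_mask(DNNL_ARG_DST, 0);
    return attr;
}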
@@ -100,20 +100,21 @@ struct group_normalization_pd_t : public primitive_desc_t { memory_desc_t stat_md_; memory_desc_t scaleshift_md_; - group_normalization_pd_t(const group_normalization_desc_t *adesc, + group_normalization_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const group_normalization_fwd_pd_t *hint_fwd_pd) : primitive_desc_t(attr, base_pkind) - , desc_(*adesc) + , desc_(*op_desc_t::to_desc(adesc)) , hint_fwd_pd_(hint_fwd_pd) , src_md_(desc_.src_desc) , stat_md_(desc_.stat_desc) , scaleshift_md_(desc_.scaleshift_desc) {} }; +// NOLINTBEGIN(google-default-arguments) struct group_normalization_fwd_pd_t : public group_normalization_pd_t { - typedef group_normalization_fwd_pd_t base_class; - typedef group_normalization_fwd_pd_t hint_class; + using base_class = group_normalization_fwd_pd_t; + using hint_class = group_normalization_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (arg == DNNL_ARG_SRC) return arg_usage_t::input; @@ -123,8 +124,10 @@ struct group_normalization_fwd_pd_t : public group_normalization_pd_t { return arg_usage_t::unused; } - if (arg == DNNL_ARG_SCALE && use_scale()) return arg_usage_t::input; - if (arg == DNNL_ARG_SHIFT && use_shift()) return arg_usage_t::input; + if (arg == DNNL_ARG_SCALE) + return use_scale() ? arg_usage_t::input : arg_usage_t::unused; + if (arg == DNNL_ARG_SHIFT) + return use_shift() ? arg_usage_t::input : arg_usage_t::unused; if (arg == DNNL_ARG_DST) return arg_usage_t::output; return primitive_desc_t::arg_usage(arg); @@ -175,7 +178,7 @@ struct group_normalization_fwd_pd_t : public group_normalization_pd_t { protected: memory_desc_t dst_md_; - group_normalization_fwd_pd_t(const group_normalization_desc_t *adesc, + group_normalization_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const hint_class *hint_fwd_pd) : group_normalization_pd_t(adesc, attr, hint_fwd_pd) , dst_md_(desc_.dst_desc) {} @@ -190,43 +193,44 @@ struct group_normalization_fwd_pd_t : public group_normalization_pd_t { return IMPLICATION(use_scale() || use_shift(), weights_md()->data_type == data_type::f32); } - bool attr_scales_ok() const { + bool attr_scales_ok(const std::vector &supported_args + = {DNNL_ARG_SRC, DNNL_ARG_DST}) const { using namespace data_type; const auto &scales = attr()->scales_; - const std::vector supported_args({DNNL_ARG_SRC, DNNL_ARG_DST}); bool ok = scales.has_default_values(supported_args); for (const auto &arg : supported_args) { - const auto &sc = scales.get(arg); - if (!sc.has_default_values()) { + if (!scales.has_default_values(arg)) { const data_type_t dt = arg_md(arg)->data_type; - ok = ok && utils::one_of(dt, s8, u8) && sc.mask_ == 0; + ok = ok && utils::one_of(dt, s8, u8); + ok = ok && scales.get_mask(arg) == 0; } } return ok; } }; +// NOLINTEND(google-default-arguments) +// NOLINTBEGIN(google-default-arguments) struct group_normalization_bwd_pd_t : public group_normalization_pd_t { - typedef group_normalization_bwd_pd_t base_class; - typedef group_normalization_fwd_pd_t hint_class; + using base_class = group_normalization_bwd_pd_t; + using hint_class = group_normalization_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_MEAN, DNNL_ARG_VARIANCE, DNNL_ARG_DIFF_DST)) return arg_usage_t::input; - if (arg == DNNL_ARG_SCALE && use_scale()) return arg_usage_t::input; - - if (arg == DNNL_ARG_WORKSPACE && !types::is_zero_md(workspace_md())) - return arg_usage_t::input; + if (arg == DNNL_ARG_SCALE) + return use_scale() ? 
arg_usage_t::input : arg_usage_t::unused; if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output; - if (arg == DNNL_ARG_DIFF_SCALE && use_scale()) - return arg_usage_t::output; - if (arg == DNNL_ARG_DIFF_SHIFT && use_shift()) - return arg_usage_t::output; + if (arg == DNNL_ARG_DIFF_SCALE) + return use_scale() ? arg_usage_t::output : arg_usage_t::unused; + if (arg == DNNL_ARG_DIFF_SHIFT) + return use_shift() ? arg_usage_t::output : arg_usage_t::unused; + return primitive_desc_t::arg_usage(arg); } @@ -285,7 +289,7 @@ struct group_normalization_bwd_pd_t : public group_normalization_pd_t { memory_desc_t diff_dst_md_; memory_desc_t diff_scaleshift_md_; - group_normalization_bwd_pd_t(const group_normalization_desc_t *adesc, + group_normalization_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const hint_class *hint_fwd_pd) : group_normalization_pd_t(adesc, attr, hint_fwd_pd) , diff_src_md_(desc_.diff_src_desc) @@ -309,6 +313,7 @@ struct group_normalization_bwd_pd_t : public group_normalization_pd_t { diff_weights_md()->data_type)); } }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/impl_list_item.hpp b/src/common/impl_list_item.hpp index 45a8d0aee1f..a18f4d49a9f 100644 --- a/src/common/impl_list_item.hpp +++ b/src/common/impl_list_item.hpp @@ -19,6 +19,7 @@ #include "c_types_map.hpp" #include "primitive_desc.hpp" +#include "dnnl_sel_build.hpp" #include "utils.hpp" namespace dnnl { @@ -89,24 +90,35 @@ struct impl_list_item_t { : public type_deduction_helper_t {}; template - constexpr impl_list_item_t(type_deduction_helper_t) - : create_pd_func_(&primitive_desc_t::create< - typename type_deduction_helper_t::type>) {} + impl_list_item_t(type_deduction_helper_t) { + using deduced_pd_t = typename type_deduction_helper_t::type; + create_pd_func_ = &primitive_desc_t::create; + DNNL_PRIMITIVE_NAME_INIT(pd_t); + } template - constexpr impl_list_item_t(concat_type_deduction_helper_t) - : create_concat_pd_func_( - concat_type_deduction_helper_t::type::create) {} + impl_list_item_t(concat_type_deduction_helper_t) { + using deduced_pd_t = + typename concat_type_deduction_helper_t::type; + create_concat_pd_func_ = deduced_pd_t::create; + DNNL_PRIMITIVE_NAME_INIT(pd_t); + } template - constexpr impl_list_item_t(sum_type_deduction_helper_t) - : create_sum_pd_func_(sum_type_deduction_helper_t::type::create) { + impl_list_item_t(sum_type_deduction_helper_t) { + using deduced_pd_t = typename sum_type_deduction_helper_t::type; + create_sum_pd_func_ = deduced_pd_t::create; + DNNL_PRIMITIVE_NAME_INIT(pd_t); } template - constexpr impl_list_item_t(reorder_type_deduction_helper_t) - : create_reorder_pd_func_( - reorder_type_deduction_helper_t::type::create) {} + impl_list_item_t(reorder_type_deduction_helper_t) { + using deduced_pd_t = + typename reorder_type_deduction_helper_t::type; + create_reorder_pd_func_ = deduced_pd_t::create; + DNNL_PRIMITIVE_NAME_INIT(pd_t); + } + explicit operator bool() const { return !utils::everyone_is(nullptr, create_pd_func_, @@ -127,6 +139,10 @@ struct impl_list_item_t { return -1; } +#if defined(SELECTIVE_BUILD_ANALYZER) + const char *name = {}; +#endif + private: status_t operator()(primitive_desc_t **pd, const op_desc_t *adesc, const primitive_attr_t *attr, engine_t *engine, @@ -206,6 +222,15 @@ struct impl_list_item_t { engine_t *, const primitive_attr_t *); }; +#if defined(SELECTIVE_BUILD_ANALYZER) +inline impl_list_item_t&& move(impl_list_item_t &&t, const char *name) { + OV_ITT_SCOPED_TASK( + 
dnnl::FACTORY_DNNL, + openvino::itt::handle(std::string("REG$CPUEngine$") + t.name + "$" + name)); + return static_cast(t); +} +#endif + } // namespace impl } // namespace dnnl diff --git a/src/common/impl_registration.hpp b/src/common/impl_registration.hpp index 6625f21653d..d6dc6c36079 100644 --- a/src/common/impl_registration.hpp +++ b/src/common/impl_registration.hpp @@ -27,9 +27,9 @@ #define REG_BWD_D_PK(...) __VA_ARGS__ #else #define REG_BWD_PK(...) \ - { nullptr } + { nullptr }, #define REG_BWD_D_PK(...) \ - { nullptr } + { nullptr }, #endif // Primitives section @@ -56,7 +56,7 @@ #define REG_CONCAT_P(...) __VA_ARGS__ #else #define REG_CONCAT_P(...) \ - { nullptr } + { nullptr }, #endif #if BUILD_PRIMITIVE_ALL || BUILD_CONVOLUTION @@ -128,7 +128,7 @@ #define REG_MATMUL_P(...) __VA_ARGS__ #else #define REG_MATMUL_P(...) \ - { nullptr } + { nullptr }, #endif #if BUILD_PRIMITIVE_ALL || BUILD_POOLING @@ -149,7 +149,7 @@ #define REG_REDUCTION_P(...) __VA_ARGS__ #else #define REG_REDUCTION_P(...) \ - { nullptr } + { nullptr }, #endif #if BUILD_PRIMITIVE_ALL || BUILD_REORDER @@ -245,4 +245,10 @@ #define REG_XE2_ISA(...) #endif +#if BUILD_PRIMITIVE_GPU_ISA_ALL || BUILD_XE3 +#define REG_XE3_ISA(...) __VA_ARGS__ +#else +#define REG_XE3_ISA(...) +#endif + #endif diff --git a/src/common/inner_product.cpp b/src/common/inner_product.cpp index 8375869cec0..05c314a7645 100644 --- a/src/common/inner_product.cpp +++ b/src/common/inner_product.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -109,36 +109,60 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine, using smask_t = primitive_attr_t::skip_mask_t; if (attr == nullptr) return status::success; - if (attr->has_default_values()) return status::success; + const data_type_t src_dt = desc.src_desc.data_type; + const data_type_t wei_dt = desc.weights_desc.data_type; + bool is_weight_compression = (one_of(src_dt, data_type::f32, data_type::bf16) && + one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) || + (one_of(src_dt, data_type::f32) && one_of(wei_dt, data_type::f16, data_type::bf16)); + auto attr_mask = smask_t::none; + // From oneDNN 3.5, those checks must be skipped if wei_decomp is enabled + // reference from src/plugins/intel_cpu/thirdparty/onednn/src/common/matmul.cpp:L62 + if (is_weight_compression) { + attr_mask |= smask_t::zero_points_data_type; + attr_mask |= smask_t::zero_points_groups; + attr_mask |= smask_t::scales_data_type; + attr_mask |= smask_t::scales_groups; + } + if (attr->has_default_values(attr_mask)) return status::success; // Check attributes if (utils::one_of(desc.prop_kind, prop_kind::forward_inference, prop_kind::forward_training)) { const data_type_t src_dt = desc.src_desc.data_type; const data_type_t dst_dt = desc.dst_desc.data_type; + const data_type_t wei_dt = desc.weights_desc.data_type; - auto fwd_attr_mask - = smask_t::post_ops | smask_t::sum_dt | smask_t::fpmath_mode; + auto fwd_attr_mask = smask_t::post_ops | smask_t::sum_dt + | smask_t::fpmath_mode | smask_t::accumulation_mode; bool is_int8 = utils::one_of(src_dt, data_type::s8, data_type::u8); if (engine->kind() == engine_kind::gpu) is_int8 = is_int8 || utils::one_of(dst_dt, data_type::s8, data_type::u8, 
data_type::s32); - if (is_int8) fwd_attr_mask |= smask_t::scales_runtime; + if (engine->kind() == engine_kind::cpu) + is_int8 |= one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1); + if (is_int8) fwd_attr_mask |= smask_t::scales | smask_t::zero_points | smask_t::src_dyn_quant_params; + + if (is_weight_compression) { + fwd_attr_mask |= attr_mask; + } VCHECK_IP_UNIMPL(attr->has_default_values(fwd_attr_mask, dst_dt), VERBOSE_UNSUPPORTED_ATTR); // Check scales if (!attr->scales_.has_default_values()) { const auto &sc = attr->scales_; - const int mask_src = sc.get(DNNL_ARG_SRC).mask_; - const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_; - const int mask_dst = sc.get(DNNL_ARG_DST).mask_; - - VCHECK_IP_UNIMPL(utils::everyone_is(0, mask_src, mask_dst) - && utils::one_of(mask_wei, 0, 1), + VCHECK_IP_UNIMPL(IMPLICATION(!sc.has_default_values(DNNL_ARG_SRC), + sc.get_mask(DNNL_ARG_SRC) == 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VCHECK_IP_UNIMPL( + IMPLICATION(!sc.has_default_values(DNNL_ARG_WEIGHTS), + utils::one_of(sc.get_mask(DNNL_ARG_WEIGHTS), 0, 1)), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VCHECK_IP_UNIMPL(IMPLICATION(!sc.has_default_values(DNNL_ARG_DST), + sc.get_mask(DNNL_ARG_DST) == 0), VERBOSE_UNSUPPORTED_SCALES_CFG); } @@ -153,9 +177,12 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine, // Check sum VCHECK_IP_UNIMPL(po.check_sum_consistency(dst_dt, is_int8, true), VERBOSE_UNSUPPORTED_POSTOP); + + // Note: verbose support is inside the call. + CHECK(po.validate_binary_with_dst_consistency(&desc.dst_desc)); } } else { - auto bwd_attr_mask = smask_t::fpmath_mode; + auto bwd_attr_mask = smask_t::fpmath_mode | smask_t::accumulation_mode; VCHECK_IP_UNIMPL(attr->has_default_values(bwd_attr_mask), VERBOSE_UNSUPPORTED_ATTR); } diff --git a/src/common/inner_product_pd.hpp b/src/common/inner_product_pd.hpp index 6dbdfcafc64..4e71f61aa17 100644 --- a/src/common/inner_product_pd.hpp +++ b/src/common/inner_product_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
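To make the accepted weights masks above concrete: mask == 0 means one scale for the whole weights tensor, while mask == (1 << 0) means one scale per output channel (dimension 0 of the OC x IC weights). A self-contained sketch of per-OC dequantization, independent of the library internals:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int OC = 2, IC = 3;
    std::vector<int8_t> wei = {1, 2, 3, 4, 5, 6}; // quantized weights, OC x IC
    std::vector<float> wei_scales = {0.5f, 0.25f}; // mask == (1 << 0): one per OC
    std::vector<float> src = {1.f, 1.f, 1.f};
    for (int oc = 0; oc < OC; ++oc) {
        float acc = 0.f;
        for (int ic = 0; ic < IC; ++ic)
            acc += src[ic] * (float)wei[oc * IC + ic];
        // Dequantize once per output channel; with mask == 0 a single
        // wei_scales[0] would apply to both rows instead.
        printf("dst[%d] = %g\n", oc, acc * wei_scales[oc]);
    }
    return 0;
}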
@@ -44,13 +44,6 @@ struct inner_product_fwd_pd_t; struct inner_product_pd_t : public primitive_desc_t { static constexpr auto base_pkind = primitive_kind::inner_product; - inner_product_pd_t(const inner_product_desc_t *adesc, - const primitive_attr_t *attr, - const inner_product_fwd_pd_t *hint_fwd_pd) - : primitive_desc_t(attr, base_pkind) - , desc_(*adesc) - , hint_fwd_pd_(hint_fwd_pd) {} - const inner_product_desc_t *desc() const { return &desc_; } const op_desc_t *op_desc() const override { return reinterpret_cast(this->desc()); @@ -139,10 +132,16 @@ struct inner_product_pd_t : public primitive_desc_t { inner_product_desc_t desc_; const inner_product_fwd_pd_t *hint_fwd_pd_; + inner_product_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(attr, base_pkind) + , desc_(*op_desc_t::to_desc(adesc)) + , hint_fwd_pd_(hint_fwd_pd) {} + bool set_default_formats_common_template(memory_desc_t &src_md, format_tag_t src_tag, memory_desc_t &wei_md, format_tag_t wei_tag, memory_desc_t &dst_md, format_tag_t dst_tag, - memory_desc_t &bia_md) { + memory_desc_t &bia_md) const { using namespace format_tag; #define IS_OK(f) \ @@ -185,7 +184,9 @@ struct inner_product_pd_t : public primitive_desc_t { = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const { bool ok = attr()->scales_.has_default_values(supported_args); for (auto arg : supported_args) { - int mask = attr()->scales_.get(arg).mask_; + if (attr()->scales_.has_default_values(arg)) continue; + + int mask = attr()->scales_.get_mask(arg); if (arg == DNNL_ARG_WEIGHTS) ok = ok && (mask == 0 || mask == (1 << 0)); else @@ -195,24 +196,17 @@ struct inner_product_pd_t : public primitive_desc_t { } }; +// NOLINTBEGIN(google-default-arguments) struct inner_product_fwd_pd_t : public inner_product_pd_t { - typedef inner_product_fwd_pd_t base_class; - typedef inner_product_fwd_pd_t hint_class; - - inner_product_fwd_pd_t(const inner_product_desc_t *adesc, - const primitive_attr_t *attr, - const inner_product_fwd_pd_t *hint_fwd_pd) - : inner_product_pd_t(adesc, attr, hint_fwd_pd) - , src_md_(desc_.src_desc) - , weights_md_(desc_.weights_desc) - , bias_md_(desc_.bias_desc) - , dst_md_(desc_.dst_desc) {} + using base_class = inner_product_fwd_pd_t; + using hint_class = inner_product_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS)) return arg_usage_t::input; - if (arg == DNNL_ARG_BIAS && with_bias()) return arg_usage_t::input; + if (arg == DNNL_ARG_BIAS) + return with_bias() ? 
arg_usage_t::input : arg_usage_t::unused; if (arg == DNNL_ARG_DST) return arg_usage_t::output; @@ -259,24 +253,26 @@ struct inner_product_fwd_pd_t : public inner_product_pd_t { memory_desc_t bias_md_; memory_desc_t dst_md_; + inner_product_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) {} + bool set_default_formats_common( format_tag_t src_tag, format_tag_t wei_tag, format_tag_t dst_tag) { return set_default_formats_common_template(src_md_, src_tag, weights_md_, wei_tag, dst_md_, dst_tag, bias_md_); } }; +// NOLINTEND(google-default-arguments) +// NOLINTBEGIN(google-default-arguments) struct inner_product_bwd_data_pd_t : public inner_product_pd_t { - typedef inner_product_bwd_data_pd_t base_class; - typedef inner_product_fwd_pd_t hint_class; - - inner_product_bwd_data_pd_t(const inner_product_desc_t *adesc, - const primitive_attr_t *attr, - const inner_product_fwd_pd_t *hint_fwd_pd) - : inner_product_pd_t(adesc, attr, hint_fwd_pd) - , diff_src_md_(desc_.diff_src_desc) - , weights_md_(desc_.weights_desc) - , diff_dst_md_(desc_.diff_dst_desc) {} + using base_class = inner_product_bwd_data_pd_t; + using hint_class = inner_product_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (utils::one_of(arg, DNNL_ARG_WEIGHTS, DNNL_ARG_DIFF_DST)) @@ -324,6 +320,14 @@ struct inner_product_bwd_data_pd_t : public inner_product_pd_t { memory_desc_t weights_md_; memory_desc_t diff_dst_md_; + inner_product_bwd_data_pd_t(const op_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , diff_dst_md_(desc_.diff_dst_desc) {} + bool set_default_formats_common(format_tag_t diff_src_tag, format_tag_t wei_tag, format_tag_t diff_dst_tag) { memory_desc_t dummy_md; @@ -331,19 +335,12 @@ struct inner_product_bwd_data_pd_t : public inner_product_pd_t { weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, dummy_md); } }; +// NOLINTEND(google-default-arguments) +// NOLINTBEGIN(google-default-arguments) struct inner_product_bwd_weights_pd_t : public inner_product_pd_t { - typedef inner_product_bwd_weights_pd_t base_class; - typedef inner_product_fwd_pd_t hint_class; - - inner_product_bwd_weights_pd_t(const inner_product_desc_t *adesc, - const primitive_attr_t *attr, - const inner_product_fwd_pd_t *hint_fwd_pd) - : inner_product_pd_t(adesc, attr, hint_fwd_pd) - , src_md_(desc_.src_desc) - , diff_weights_md_(desc_.diff_weights_desc) - , diff_bias_md_(desc_.diff_bias_desc) - , diff_dst_md_(desc_.diff_dst_desc) {} + using base_class = inner_product_bwd_weights_pd_t; + using hint_class = inner_product_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_DIFF_DST)) @@ -351,8 +348,8 @@ struct inner_product_bwd_weights_pd_t : public inner_product_pd_t { if (arg == DNNL_ARG_DIFF_WEIGHTS) return arg_usage_t::output; - if (arg == DNNL_ARG_DIFF_BIAS && with_bias()) - return arg_usage_t::output; + if (arg == DNNL_ARG_DIFF_BIAS) + return with_bias() ? 
arg_usage_t::output : arg_usage_t::unused; return primitive_desc_t::arg_usage(arg); } @@ -397,6 +394,15 @@ struct inner_product_bwd_weights_pd_t : public inner_product_pd_t { memory_desc_t diff_bias_md_; memory_desc_t diff_dst_md_; + inner_product_bwd_weights_pd_t(const op_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) {} + bool set_default_formats_common(format_tag_t src_tag, format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) { return set_default_formats_common_template(src_md_, src_tag, @@ -404,6 +410,7 @@ struct inner_product_bwd_weights_pd_t : public inner_product_pd_t { diff_bias_md_); } }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/int4.hpp b/src/common/int4.hpp index 2e692d13a5a..6658eac9921 100644 --- a/src/common/int4.hpp +++ b/src/common/int4.hpp @@ -25,41 +25,23 @@ namespace dnnl { namespace impl { -enum class int4_extract_t : uint8_t { low_half = 0, high_half = 4 }; - -inline uint8_t extract_half_byte(uint8_t val, int4_extract_t half) { - uint8_t shift = static_cast(half); - return (val >> shift) & 0xF; -} - -inline uint8_t insert_half_byte(uint8_t src, uint8_t val, int4_extract_t half) { - uint8_t shift = static_cast(half); - uint8_t mask = half == int4_extract_t::high_half ? 0x0F : 0xF0; - return (src & mask) | (uint8_t)(val << shift); -} - struct uint4_t { template ::value>::type> - constexpr uint4_t(IntegerType raw) : raw_(raw) {} + constexpr uint4_t(IntegerType raw) : raw_bits_(static_cast(raw)) { +#if __cplusplus >= 201402L + assert(0 <= raw && raw <= std::numeric_limits::max()); +#endif + } uint4_t(float val_f32) { uint8_t val_uint8 = static_cast(val_f32); - raw_ = val_uint8 & 0xF; + raw_bits_ = val_uint8 & 0xF; } - operator float() const { return (float)raw_; } - - uint8_t insert(uint8_t src, int4_extract_t half) const { - return insert_half_byte(src, raw_, half); - } + operator float() const { return (float)raw_bits_; } - static uint4_t extract(uint8_t val, int4_extract_t half) { - return uint4_t(extract_half_byte(val, half)); - } - -private: - uint8_t raw_; + uint8_t raw_bits_; }; static_assert(sizeof(uint4_t) == 1, "uint4_t must be 1 byte"); @@ -68,30 +50,21 @@ struct int4_t { template ::value>::type> - constexpr int4_t(IntegerType i) : raw_(static_cast(i)) {} + constexpr int4_t(IntegerType i) : raw_bits_(static_cast(i)) {} int4_t(float val_f32) { int8_t val_int8 = static_cast(val_f32); bool negative = val_f32 < 0; // positive numbers have the most significant bit set to 0 // negative numbers have the most significant bit set to 1 - raw_ = negative ? (val_int8 & 0xF) | 0x8 : val_int8 & 0x7; + raw_bits_ = negative ? (val_int8 & 0xF) | 0x8 : val_int8 & 0x7; } operator float() const { - float sign = (raw_ & (1 << 3)) ? -1.f : 1.f; - return sign * (float)(sign == -1 ? (~raw_ & 0xF) + 1 : raw_); - } - - uint8_t insert(uint8_t src, int4_extract_t half) const { - return insert_half_byte(src, raw_, half); - } - - static int4_t extract(uint8_t val, int4_extract_t half) { - return int4_t(extract_half_byte(val, half)); + float sign = (raw_bits_ & (1 << 3)) ? -1.f : 1.f; + return sign * (float)(sign == -1 ? 
(~raw_bits_ & 0xF) + 1 : raw_bits_); } -private: - uint8_t raw_; + uint8_t raw_bits_; }; static_assert(sizeof(int4_t) == 1, "int4_t must be 1 byte"); diff --git a/src/common/ittnotify.cpp b/src/common/ittnotify.cpp index e9c9dfa8404..2994962a997 100644 --- a/src/common/ittnotify.cpp +++ b/src/common/ittnotify.cpp @@ -18,8 +18,8 @@ #include "utils.hpp" #if defined(DNNL_ENABLE_ITT_TASKS) -#include "common/ittnotify/ittnotify.h" #include "dnnl_debug.h" +#include "ittnotify/ittnotify.h" #endif namespace dnnl { @@ -80,12 +80,16 @@ void primitive_task_start(primitive_kind_t kind) { CASE(layer_normalization), CASE(group_normalization), CASE(sdpa), + CASE(depthwise), + CASE(quantization), }; #undef CASE int kind_idx = (int)kind; assert(kind_idx >= 0); - assert((size_t)kind_idx - < sizeof(prim_kind_itt_strings) / sizeof(prim_kind_itt_strings[0])); + if (kind_idx < primitive_kind::internal_only_start) { + assert((size_t)kind_idx < sizeof(prim_kind_itt_strings) + / sizeof(prim_kind_itt_strings[0])); + } __itt_task_begin(itt_domain(), __itt_null, __itt_null, prim_kind_itt_strings[kind_idx]); thread_primitive_kind = kind; diff --git a/src/common/ittnotify.hpp b/src/common/ittnotify.hpp index b1ec4b7e248..71a51394bbb 100644 --- a/src/common/ittnotify.hpp +++ b/src/common/ittnotify.hpp @@ -24,7 +24,9 @@ namespace dnnl { namespace impl { namespace itt { -typedef enum { +// GCC treats using and typedef differently for enums and structs +// https://stackoverflow.com/questions/48613758 +typedef enum { // NOLINT(modernize-use-using) __itt_task_level_none = 0, __itt_task_level_low, __itt_task_level_high diff --git a/src/common/ittnotify/ittnotify.h b/src/common/ittnotify/ittnotify.h deleted file mode 100644 index d3df4b5e380..00000000000 --- a/src/common/ittnotify/ittnotify.h +++ /dev/null @@ -1,4459 +0,0 @@ -/* - Copyright (C) 2005-2019 Intel Corporation - - SPDX-License-Identifier: GPL-2.0-only OR BSD-3-Clause -*/ -#ifndef _ITTNOTIFY_H_ -#define _ITTNOTIFY_H_ - -/** -@file -@brief Public User API functions and types -@mainpage - -The Instrumentation and Tracing Technology API (ITT API) is used to -annotate a user's program with additional information -that can be used by correctness and performance tools. The user inserts -calls in their program. Those calls generate information that is collected -at runtime, and used by Intel(R) Threading Tools. - -@section API Concepts -The following general concepts are used throughout the API. - -@subsection Unicode Support -Many API functions take character string arguments. On Windows, there -are two versions of each such function. The function name is suffixed -by W if Unicode support is enabled, and by A otherwise. Any API function -that takes a character string argument adheres to this convention. - -@subsection Conditional Compilation -Many users prefer having an option to modify ITT API code when linking it -inside their runtimes. ITT API header file provides a mechanism to replace -ITT API function names inside your code with empty strings. To do this, -define the macros INTEL_NO_ITTNOTIFY_API during compilation and remove the -static library from the linker script. - -@subsection Domains -[see domains] -Domains provide a way to separate notification for different modules or -libraries in a program. Domains are specified by dotted character strings, -e.g. TBB.Internal.Control. - -A mechanism (to be specified) is provided to enable and disable -domains. By default, all domains are enabled. 
-@subsection Named Entities and Instances -Named entities (frames, regions, tasks, and markers) communicate -information about the program to the analysis tools. A named entity often -refers to a section of program code, or to some set of logical concepts -that the programmer wants to group together. - -Named entities relate to the programmer's static view of the program. When -the program actually executes, many instances of a given named entity -may be created. - -The API annotations denote instances of named entities. The actual -named entities are displayed using the analysis tools. In other words, -the named entities come into existence when instances are created. - -Instances of named entities may have instance identifiers (IDs). Some -API calls use instance identifiers to create relationships between -different instances of named entities. Other API calls associate data -with instances of named entities. - -Some named entities must always have instance IDs. In particular, regions -and frames always have IDs. Task and markers need IDs only if the ID is -needed in another API call (such as adding a relation or metadata). - -The lifetime of instance IDs is distinct from the lifetime of -instances. This allows various relationships to be specified separate -from the actual execution of instances. This flexibility comes at the -expense of extra API calls. - -The same ID may not be reused for different instances, unless a previous -[ref] __itt_id_destroy call for that ID has been issued. -*/ - -/** @cond exclude_from_documentation */ -#ifndef ITT_OS_WIN -# define ITT_OS_WIN 1 -#endif /* ITT_OS_WIN */ - -#ifndef ITT_OS_LINUX -# define ITT_OS_LINUX 2 -#endif /* ITT_OS_LINUX */ - -#ifndef ITT_OS_MAC -# define ITT_OS_MAC 3 -#endif /* ITT_OS_MAC */ - -#ifndef ITT_OS_FREEBSD -# define ITT_OS_FREEBSD 4 -#endif /* ITT_OS_FREEBSD */ - -#ifndef ITT_OS -# if defined WIN32 || defined _WIN32 -# define ITT_OS ITT_OS_WIN -# elif defined( __APPLE__ ) && defined( __MACH__ ) -# define ITT_OS ITT_OS_MAC -# elif defined( __FreeBSD__ ) -# define ITT_OS ITT_OS_FREEBSD -# else -# define ITT_OS ITT_OS_LINUX -# endif -#endif /* ITT_OS */ - -#ifndef ITT_PLATFORM_WIN -# define ITT_PLATFORM_WIN 1 -#endif /* ITT_PLATFORM_WIN */ - -#ifndef ITT_PLATFORM_POSIX -# define ITT_PLATFORM_POSIX 2 -#endif /* ITT_PLATFORM_POSIX */ - -#ifndef ITT_PLATFORM_MAC -# define ITT_PLATFORM_MAC 3 -#endif /* ITT_PLATFORM_MAC */ - -#ifndef ITT_PLATFORM_FREEBSD -# define ITT_PLATFORM_FREEBSD 4 -#endif /* ITT_PLATFORM_FREEBSD */ - -#ifndef ITT_PLATFORM -# if ITT_OS==ITT_OS_WIN -# define ITT_PLATFORM ITT_PLATFORM_WIN -# elif ITT_OS==ITT_OS_MAC -# define ITT_PLATFORM ITT_PLATFORM_MAC -# elif ITT_OS==ITT_OS_FREEBSD -# define ITT_PLATFORM ITT_PLATFORM_FREEBSD -# else -# define ITT_PLATFORM ITT_PLATFORM_POSIX -# endif -#endif /* ITT_PLATFORM */ - -#if defined(_UNICODE) && !defined(UNICODE) -#define UNICODE -#endif - -#include -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#include -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#include -#if defined(UNICODE) || defined(_UNICODE) -#include -#endif /* UNICODE || _UNICODE */ -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -#ifndef ITTAPI_CDECL -# if ITT_PLATFORM==ITT_PLATFORM_WIN -# define ITTAPI_CDECL __cdecl -# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -# if defined _M_IX86 || defined __i386__ -# define ITTAPI_CDECL __attribute__ ((cdecl)) -# else /* _M_IX86 || __i386__ */ -# define ITTAPI_CDECL /* actual only on x86 platform */ -# endif /* _M_IX86 || __i386__ */ -# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ 
-#endif /* ITTAPI_CDECL */ - -#ifndef STDCALL -# if ITT_PLATFORM==ITT_PLATFORM_WIN -# define STDCALL __stdcall -# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -# if defined _M_IX86 || defined __i386__ -# define STDCALL __attribute__ ((stdcall)) -# else /* _M_IX86 || __i386__ */ -# define STDCALL /* supported only on x86 platform */ -# endif /* _M_IX86 || __i386__ */ -# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* STDCALL */ - -#define ITTAPI ITTAPI_CDECL -#define LIBITTAPI ITTAPI_CDECL - -/* TODO: Temporary for compatibility! */ -#define ITTAPI_CALL ITTAPI_CDECL -#define LIBITTAPI_CALL ITTAPI_CDECL - -#if ITT_PLATFORM==ITT_PLATFORM_WIN -/* use __forceinline (VC++ specific) */ -#if defined(__MINGW32__) && !defined(__cplusplus) -#define ITT_INLINE static __inline__ __attribute__((__always_inline__,__gnu_inline__)) -#else -#define ITT_INLINE static __forceinline -#endif /* __MINGW32__ */ - -#define ITT_INLINE_ATTRIBUTE /* nothing */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -/* - * Generally, functions are not inlined unless optimization is specified. - * For functions declared inline, this attribute inlines the function even - * if no optimization level was specified. - */ -#ifdef __STRICT_ANSI__ -#define ITT_INLINE static -#define ITT_INLINE_ATTRIBUTE __attribute__((unused)) -#else /* __STRICT_ANSI__ */ -#define ITT_INLINE static inline -#define ITT_INLINE_ATTRIBUTE __attribute__((always_inline, unused)) -#endif /* __STRICT_ANSI__ */ -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -/** @endcond */ - -#ifdef INTEL_ITTNOTIFY_ENABLE_LEGACY -# if ITT_PLATFORM==ITT_PLATFORM_WIN -# pragma message("WARNING!!! Deprecated API is used. Please undefine INTEL_ITTNOTIFY_ENABLE_LEGACY macro") -# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -# warning "Deprecated API is used. Please undefine INTEL_ITTNOTIFY_ENABLE_LEGACY macro" -# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -# include "legacy/ittnotify.h" -#endif /* INTEL_ITTNOTIFY_ENABLE_LEGACY */ - -/** @cond exclude_from_documentation */ -/* Helper macro for joining tokens */ -#define ITT_JOIN_AUX(p,n) p##n -#define ITT_JOIN(p,n) ITT_JOIN_AUX(p,n) - -#ifdef ITT_MAJOR -#undef ITT_MAJOR -#endif -#ifdef ITT_MINOR -#undef ITT_MINOR -#endif -#define ITT_MAJOR 3 -#define ITT_MINOR 0 - -/* Standard versioning of a token with major and minor version numbers */ -#define ITT_VERSIONIZE(x) \ - ITT_JOIN(x, \ - ITT_JOIN(_, \ - ITT_JOIN(ITT_MAJOR, \ - ITT_JOIN(_, ITT_MINOR)))) - -#ifndef INTEL_ITTNOTIFY_PREFIX -# define INTEL_ITTNOTIFY_PREFIX __itt_ -#endif /* INTEL_ITTNOTIFY_PREFIX */ -#ifndef INTEL_ITTNOTIFY_POSTFIX -# define INTEL_ITTNOTIFY_POSTFIX _ptr_ -#endif /* INTEL_ITTNOTIFY_POSTFIX */ - -#define ITTNOTIFY_NAME_AUX(n) ITT_JOIN(INTEL_ITTNOTIFY_PREFIX,n) -#define ITTNOTIFY_NAME(n) ITT_VERSIONIZE(ITTNOTIFY_NAME_AUX(ITT_JOIN(n,INTEL_ITTNOTIFY_POSTFIX))) - -#define ITTNOTIFY_VOID(n) (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n) -#define ITTNOTIFY_DATA(n) (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n) - -#define ITTNOTIFY_VOID_D0(n,d) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d) -#define ITTNOTIFY_VOID_D1(n,d,x) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x) -#define ITTNOTIFY_VOID_D2(n,d,x,y) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y) -#define ITTNOTIFY_VOID_D3(n,d,x,y,z) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? 
(void)0 : ITTNOTIFY_NAME(n)(d,x,y,z) -#define ITTNOTIFY_VOID_D4(n,d,x,y,z,a) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z,a) -#define ITTNOTIFY_VOID_D5(n,d,x,y,z,a,b) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b) -#define ITTNOTIFY_VOID_D6(n,d,x,y,z,a,b,c) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b,c) -#define ITTNOTIFY_DATA_D0(n,d) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d) -#define ITTNOTIFY_DATA_D1(n,d,x) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x) -#define ITTNOTIFY_DATA_D2(n,d,x,y) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y) -#define ITTNOTIFY_DATA_D3(n,d,x,y,z) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z) -#define ITTNOTIFY_DATA_D4(n,d,x,y,z,a) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z,a) -#define ITTNOTIFY_DATA_D5(n,d,x,y,z,a,b) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b) -#define ITTNOTIFY_DATA_D6(n,d,x,y,z,a,b,c) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b,c) - -#ifdef ITT_STUB -#undef ITT_STUB -#endif -#ifdef ITT_STUBV -#undef ITT_STUBV -#endif -#define ITT_STUBV(api,type,name,args) \ - typedef type (api* ITT_JOIN(ITTNOTIFY_NAME(name),_t)) args; \ - extern ITT_JOIN(ITTNOTIFY_NAME(name),_t) ITTNOTIFY_NAME(name); -#define ITT_STUB ITT_STUBV -/** @endcond */ - -#ifdef __cplusplus -extern "C" { -#endif /* __cplusplus */ - -/** @cond exclude_from_gpa_documentation */ -/** - * @defgroup public Public API - * @{ - * @} - */ - -/** - * @defgroup control Collection Control - * @ingroup public - * General behavior: application continues to run, but no profiling information is being collected - * - * Pausing occurs not only for the current thread but for all process as well as spawned processes - * - Intel(R) Parallel Inspector and Intel(R) Inspector XE: - * - Does not analyze or report errors that involve memory access. - * - Other errors are reported as usual. Pausing data collection in - * Intel(R) Parallel Inspector and Intel(R) Inspector XE - * only pauses tracing and analyzing memory access. - * It does not pause tracing or analyzing threading APIs. - * . - * - Intel(R) Parallel Amplifier and Intel(R) VTune(TM) Amplifier XE: - * - Does continue to record when new threads are started. - * . - * - Other effects: - * - Possible reduction of runtime overhead. - * . 
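The collection-control entry points declared just below are typically used to exclude an uninteresting phase from profiling; a minimal sketch, with the application phases invented for illustration:

    #include <ittnotify.h>

    static void load_input_files(void) { /* hypothetical setup phase */ }
    static void run_kernels(void)      { /* hypothetical hot path */ }

    int main(void) {
        __itt_pause();      /* nothing is collected while loading */
        load_input_files();
        __itt_resume();     /* collect only the region of interest */
        run_kernels();
        return 0;
    }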
- * @{ - */ -/** @brief Pause collection */ -void ITTAPI __itt_pause(void); -/** @brief Resume collection */ -void ITTAPI __itt_resume(void); -/** @brief Detach collection */ -void ITTAPI __itt_detach(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, pause, (void)) -ITT_STUBV(ITTAPI, void, resume, (void)) -ITT_STUBV(ITTAPI, void, detach, (void)) -#define __itt_pause ITTNOTIFY_VOID(pause) -#define __itt_pause_ptr ITTNOTIFY_NAME(pause) -#define __itt_resume ITTNOTIFY_VOID(resume) -#define __itt_resume_ptr ITTNOTIFY_NAME(resume) -#define __itt_detach ITTNOTIFY_VOID(detach) -#define __itt_detach_ptr ITTNOTIFY_NAME(detach) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_pause() -#define __itt_pause_ptr 0 -#define __itt_resume() -#define __itt_resume_ptr 0 -#define __itt_detach() -#define __itt_detach_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_pause_ptr 0 -#define __itt_resume_ptr 0 -#define __itt_detach_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} control group */ -/** @endcond */ - -/** - * @defgroup Intel Processor Trace control - * API from this group provides control over collection and analysis of Intel Processor Trace (Intel PT) data - * Information about Intel Processor Trace technology can be found here (Volume 3 chapter 35): - * https://software.intel.com/sites/default/files/managed/39/c5/325462-sdm-vol-1-2abcd-3abcd.pdf - * Use this API to mark particular code regions for loading detailed performance statistics. - * This mode makes your analysis faster and more accurate. - * @{ -*/ -typedef unsigned char __itt_pt_region; - -/** - * @brief function saves a region name marked with Intel PT API and returns a region id. - * Only 7 names can be registered. Attempts to register more names will be ignored and a region id with auto names will be returned. 
- * For automatic naming of regions pass NULL as function parameter -*/ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -__itt_pt_region ITTAPI __itt_pt_region_createA(const char *name); -__itt_pt_region ITTAPI __itt_pt_region_createW(const wchar_t *name); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_pt_region_create __itt_pt_region_createW -#else /* UNICODE */ -# define __itt_pt_region_create __itt_pt_region_createA -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -__itt_pt_region ITTAPI __itt_pt_region_create(const char *name); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(ITTAPI, __itt_pt_region, pt_region_createA, (const char *name)) -ITT_STUB(ITTAPI, __itt_pt_region, pt_region_createW, (const wchar_t *name)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(ITTAPI, __itt_pt_region, pt_region_create, (const char *name)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_pt_region_createA ITTNOTIFY_DATA(pt_region_createA) -#define __itt_pt_region_createA_ptr ITTNOTIFY_NAME(pt_region_createA) -#define __itt_pt_region_createW ITTNOTIFY_DATA(pt_region_createW) -#define __itt_pt_region_createW_ptr ITTNOTIFY_NAME(pt_region_createW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_pt_region_create ITTNOTIFY_DATA(pt_region_create) -#define __itt_pt_region_create_ptr ITTNOTIFY_NAME(pt_region_create) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_pt_region_createA(name) (__itt_pt_region)0 -#define __itt_pt_region_createA_ptr 0 -#define __itt_pt_region_createW(name) (__itt_pt_region)0 -#define __itt_pt_region_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_pt_region_create(name) (__itt_pt_region)0 -#define __itt_pt_region_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_pt_region_createA_ptr 0 -#define __itt_pt_region_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_pt_region_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief function contains a special code pattern identified on the post-processing stage and - * marks the beginning of a code region targeted for Intel PT analysis - * @param[in] region - region id, 0 <= region < 8 -*/ -void __itt_mark_pt_region_begin(__itt_pt_region region); -/** - * @brief function contains a special code pattern identified on the post-processing stage and - * marks the end of a code region targeted for Intel PT analysis - * @param[in] region - region id, 0 <= region < 8 -*/ -void __itt_mark_pt_region_end(__itt_pt_region region); -/** @} Intel PT control group*/ - -/** - * @defgroup threads Threads - * @ingroup public - * Give names to threads - * @{ - */ -/** - * @brief Sets thread name of calling thread - * @param[in] name - name of thread - */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -void ITTAPI __itt_thread_set_nameA(const char *name); -void ITTAPI __itt_thread_set_nameW(const wchar_t *name); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_thread_set_name __itt_thread_set_nameW -# define __itt_thread_set_name_ptr __itt_thread_set_nameW_ptr -#else /* UNICODE */ -# define __itt_thread_set_name 
__itt_thread_set_nameA -# define __itt_thread_set_name_ptr __itt_thread_set_nameA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -void ITTAPI __itt_thread_set_name(const char *name); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUBV(ITTAPI, void, thread_set_nameA, (const char *name)) -ITT_STUBV(ITTAPI, void, thread_set_nameW, (const wchar_t *name)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUBV(ITTAPI, void, thread_set_name, (const char *name)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_thread_set_nameA ITTNOTIFY_VOID(thread_set_nameA) -#define __itt_thread_set_nameA_ptr ITTNOTIFY_NAME(thread_set_nameA) -#define __itt_thread_set_nameW ITTNOTIFY_VOID(thread_set_nameW) -#define __itt_thread_set_nameW_ptr ITTNOTIFY_NAME(thread_set_nameW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_thread_set_name ITTNOTIFY_VOID(thread_set_name) -#define __itt_thread_set_name_ptr ITTNOTIFY_NAME(thread_set_name) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_thread_set_nameA(name) -#define __itt_thread_set_nameA_ptr 0 -#define __itt_thread_set_nameW(name) -#define __itt_thread_set_nameW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_thread_set_name(name) -#define __itt_thread_set_name_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_thread_set_nameA_ptr 0 -#define __itt_thread_set_nameW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_thread_set_name_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** @cond exclude_from_gpa_documentation */ - -/** - * @brief Mark current thread as ignored from this point on, for the duration of its existence. 
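A minimal sketch of the thread-naming call declared above, using the POSIX spelling (on Windows the A/W-suffixed variants apply); the worker thread is invented for illustration:

    #include <ittnotify.h>
    #include <pthread.h>

    static void *worker(void *arg) {
        (void)arg;
        /* Label this thread in the profiler's timeline. */
        __itt_thread_set_name("worker-pool-0");
        /* ... thread body ... */
        return NULL;
    }

    int main(void) {
        pthread_t t;
        pthread_create(&t, NULL, worker, NULL);
        pthread_join(t, NULL);
        return 0;
    }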
- */ -void ITTAPI __itt_thread_ignore(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, thread_ignore, (void)) -#define __itt_thread_ignore ITTNOTIFY_VOID(thread_ignore) -#define __itt_thread_ignore_ptr ITTNOTIFY_NAME(thread_ignore) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_thread_ignore() -#define __itt_thread_ignore_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_thread_ignore_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} threads group */ - -/** - * @defgroup suppress Error suppression - * @ingroup public - * General behavior: application continues to run, but errors are suppressed - * - * @{ - */ - -/*****************************************************************//** - * @name group of functions used for error suppression in correctness tools - *********************************************************************/ -/** @{ */ -/** - * @hideinitializer - * @brief possible value for suppression mask - */ -#define __itt_suppress_all_errors 0x7fffffff - -/** - * @hideinitializer - * @brief possible value for suppression mask (suppresses errors from threading analysis) - */ -#define __itt_suppress_threading_errors 0x000000ff - -/** - * @hideinitializer - * @brief possible value for suppression mask (suppresses errors from memory analysis) - */ -#define __itt_suppress_memory_errors 0x0000ff00 - -/** - * @brief Start suppressing errors identified in mask on this thread - */ -void ITTAPI __itt_suppress_push(unsigned int mask); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, suppress_push, (unsigned int mask)) -#define __itt_suppress_push ITTNOTIFY_VOID(suppress_push) -#define __itt_suppress_push_ptr ITTNOTIFY_NAME(suppress_push) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_suppress_push(mask) -#define __itt_suppress_push_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_suppress_push_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Undo the effects of the matching call to __itt_suppress_push - */ -void ITTAPI __itt_suppress_pop(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, suppress_pop, (void)) -#define __itt_suppress_pop ITTNOTIFY_VOID(suppress_pop) -#define __itt_suppress_pop_ptr ITTNOTIFY_NAME(suppress_pop) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_suppress_pop() -#define __itt_suppress_pop_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_suppress_pop_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @enum __itt_model_disable - * @brief Enumerator for the disable methods - */ -typedef enum __itt_suppress_mode { - __itt_unsuppress_range, - __itt_suppress_range -} __itt_suppress_mode_t; - -/** - * @enum __itt_collection_state - * @brief Enumerator for collection state. All non-work states have negative values. 
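A minimal sketch of the suppression window formed by the push/pop pair above, with the racy-counter scenario invented for illustration:

    #include <ittnotify.h>

    void bump_stat_counter(volatile int *counter) {
        /* Intentionally racy statistics counter: hide it from memory
           analysis for this window only. */
        __itt_suppress_push(__itt_suppress_memory_errors);
        ++*counter;
        __itt_suppress_pop(); /* the matching pop restores checking */
    }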
- */ -typedef enum { - __itt_collection_uninitialized = 0, /* uninitialized */ - __itt_collection_init_fail = 1, /* failed to init */ - __itt_collection_collector_absent = 2, /* non work state collector exists */ - __itt_collection_collector_exists = 3, /* work state collector exists */ - __itt_collection_init_successful = 4 /* success to init */ -} __itt_collection_state; - -/** - * @brief Mark a range of memory for error suppression or unsuppression for error types included in mask - */ -void ITTAPI __itt_suppress_mark_range(__itt_suppress_mode_t mode, unsigned int mask, void * address, size_t size); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, suppress_mark_range, (__itt_suppress_mode_t mode, unsigned int mask, void * address, size_t size)) -#define __itt_suppress_mark_range ITTNOTIFY_VOID(suppress_mark_range) -#define __itt_suppress_mark_range_ptr ITTNOTIFY_NAME(suppress_mark_range) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_suppress_mark_range(mask) -#define __itt_suppress_mark_range_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_suppress_mark_range_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Undo the effect of a matching call to __itt_suppress_mark_range. If not matching - * call is found, nothing is changed. - */ -void ITTAPI __itt_suppress_clear_range(__itt_suppress_mode_t mode, unsigned int mask, void * address, size_t size); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, suppress_clear_range, (__itt_suppress_mode_t mode, unsigned int mask, void * address, size_t size)) -#define __itt_suppress_clear_range ITTNOTIFY_VOID(suppress_clear_range) -#define __itt_suppress_clear_range_ptr ITTNOTIFY_NAME(suppress_clear_range) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_suppress_clear_range(mask) -#define __itt_suppress_clear_range_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_suppress_clear_range_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} */ -/** @} suppress group */ - -/** - * @defgroup sync Synchronization - * @ingroup public - * Indicate user-written synchronization code - * @{ - */ -/** - * @hideinitializer - * @brief possible value of attribute argument for sync object type - */ -#define __itt_attr_barrier 1 - -/** - * @hideinitializer - * @brief possible value of attribute argument for sync object type - */ -#define __itt_attr_mutex 2 - -/** -@brief Name a synchronization object -@param[in] addr Handle for the synchronization object. You should -use a real address to uniquely identify the synchronization object. -@param[in] objtype null-terminated object type string. If NULL is -passed, the name will be "User Synchronization". -@param[in] objname null-terminated object name string. If NULL, -no name will be assigned to the object. 
-@param[in] attribute one of [#__itt_attr_barrier, #__itt_attr_mutex] - */ - -#if ITT_PLATFORM==ITT_PLATFORM_WIN -void ITTAPI __itt_sync_createA(void *addr, const char *objtype, const char *objname, int attribute); -void ITTAPI __itt_sync_createW(void *addr, const wchar_t *objtype, const wchar_t *objname, int attribute); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_sync_create __itt_sync_createW -# define __itt_sync_create_ptr __itt_sync_createW_ptr -#else /* UNICODE */ -# define __itt_sync_create __itt_sync_createA -# define __itt_sync_create_ptr __itt_sync_createA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -void ITTAPI __itt_sync_create (void *addr, const char *objtype, const char *objname, int attribute); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUBV(ITTAPI, void, sync_createA, (void *addr, const char *objtype, const char *objname, int attribute)) -ITT_STUBV(ITTAPI, void, sync_createW, (void *addr, const wchar_t *objtype, const wchar_t *objname, int attribute)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUBV(ITTAPI, void, sync_create, (void *addr, const char* objtype, const char* objname, int attribute)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_sync_createA ITTNOTIFY_VOID(sync_createA) -#define __itt_sync_createA_ptr ITTNOTIFY_NAME(sync_createA) -#define __itt_sync_createW ITTNOTIFY_VOID(sync_createW) -#define __itt_sync_createW_ptr ITTNOTIFY_NAME(sync_createW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_sync_create ITTNOTIFY_VOID(sync_create) -#define __itt_sync_create_ptr ITTNOTIFY_NAME(sync_create) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_sync_createA(addr, objtype, objname, attribute) -#define __itt_sync_createA_ptr 0 -#define __itt_sync_createW(addr, objtype, objname, attribute) -#define __itt_sync_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_sync_create(addr, objtype, objname, attribute) -#define __itt_sync_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_sync_createA_ptr 0 -#define __itt_sync_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_sync_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** -@brief Rename a synchronization object - -You can use the rename call to assign or reassign a name to a given -synchronization object. -@param[in] addr handle for the synchronization object. -@param[in] name null-terminated object name string. 
-*/ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -void ITTAPI __itt_sync_renameA(void *addr, const char *name); -void ITTAPI __itt_sync_renameW(void *addr, const wchar_t *name); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_sync_rename __itt_sync_renameW -# define __itt_sync_rename_ptr __itt_sync_renameW_ptr -#else /* UNICODE */ -# define __itt_sync_rename __itt_sync_renameA -# define __itt_sync_rename_ptr __itt_sync_renameA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -void ITTAPI __itt_sync_rename(void *addr, const char *name); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUBV(ITTAPI, void, sync_renameA, (void *addr, const char *name)) -ITT_STUBV(ITTAPI, void, sync_renameW, (void *addr, const wchar_t *name)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUBV(ITTAPI, void, sync_rename, (void *addr, const char *name)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_sync_renameA ITTNOTIFY_VOID(sync_renameA) -#define __itt_sync_renameA_ptr ITTNOTIFY_NAME(sync_renameA) -#define __itt_sync_renameW ITTNOTIFY_VOID(sync_renameW) -#define __itt_sync_renameW_ptr ITTNOTIFY_NAME(sync_renameW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_sync_rename ITTNOTIFY_VOID(sync_rename) -#define __itt_sync_rename_ptr ITTNOTIFY_NAME(sync_rename) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_sync_renameA(addr, name) -#define __itt_sync_renameA_ptr 0 -#define __itt_sync_renameW(addr, name) -#define __itt_sync_renameW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_sync_rename(addr, name) -#define __itt_sync_rename_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_sync_renameA_ptr 0 -#define __itt_sync_renameW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_sync_rename_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - @brief Destroy a synchronization object. - @param addr Handle for the synchronization object. 
- */ -void ITTAPI __itt_sync_destroy(void *addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, sync_destroy, (void *addr)) -#define __itt_sync_destroy ITTNOTIFY_VOID(sync_destroy) -#define __itt_sync_destroy_ptr ITTNOTIFY_NAME(sync_destroy) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_sync_destroy(addr) -#define __itt_sync_destroy_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_sync_destroy_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/*****************************************************************//** - * @name group of functions is used for performance measurement tools - *********************************************************************/ -/** @{ */ -/** - * @brief Enter spin loop on user-defined sync object - */ -void ITTAPI __itt_sync_prepare(void* addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, sync_prepare, (void *addr)) -#define __itt_sync_prepare ITTNOTIFY_VOID(sync_prepare) -#define __itt_sync_prepare_ptr ITTNOTIFY_NAME(sync_prepare) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_sync_prepare(addr) -#define __itt_sync_prepare_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_sync_prepare_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Quit spin loop without acquiring spin object - */ -void ITTAPI __itt_sync_cancel(void *addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, sync_cancel, (void *addr)) -#define __itt_sync_cancel ITTNOTIFY_VOID(sync_cancel) -#define __itt_sync_cancel_ptr ITTNOTIFY_NAME(sync_cancel) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_sync_cancel(addr) -#define __itt_sync_cancel_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_sync_cancel_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Successful spin loop completion (sync object acquired) - */ -void ITTAPI __itt_sync_acquired(void *addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, sync_acquired, (void *addr)) -#define __itt_sync_acquired ITTNOTIFY_VOID(sync_acquired) -#define __itt_sync_acquired_ptr ITTNOTIFY_NAME(sync_acquired) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_sync_acquired(addr) -#define __itt_sync_acquired_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_sync_acquired_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Start sync object releasing code. Is called before the lock release call. 
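Taken together, the sync annotations above and just below map onto a user-written lock roughly as sketched here; the spin lock itself is invented for illustration and uses GCC atomic builtins:

    #include <ittnotify.h>

    typedef struct { volatile long locked; } spinlock_t;

    void spinlock_init(spinlock_t *l) {
        l->locked = 0;
        __itt_sync_create(l, "spinlock_t", "frame queue lock", __itt_attr_mutex);
    }

    void spinlock_acquire(spinlock_t *l) {
        __itt_sync_prepare(l);                          /* entering the spin loop */
        while (__sync_lock_test_and_set(&l->locked, 1))
            ;                                           /* spinning */
        __itt_sync_acquired(l);                         /* loop exited with the lock held */
    }

    void spinlock_release(spinlock_t *l) {
        __itt_sync_releasing(l);                        /* called before the actual release */
        __sync_lock_release(&l->locked);
    }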
- */ -void ITTAPI __itt_sync_releasing(void* addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, sync_releasing, (void *addr)) -#define __itt_sync_releasing ITTNOTIFY_VOID(sync_releasing) -#define __itt_sync_releasing_ptr ITTNOTIFY_NAME(sync_releasing) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_sync_releasing(addr) -#define __itt_sync_releasing_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_sync_releasing_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} */ - -/** @} sync group */ - -/**************************************************************//** - * @name group of functions is used for correctness checking tools - ******************************************************************/ -/** @{ */ -/** - * @ingroup legacy - * @deprecated Legacy API - * @brief Fast synchronization which does no require spinning. - * - This special function is to be used by TBB and OpenMP libraries only when they know - * there is no spin but they need to suppress TC warnings about shared variable modifications. - * - It only has corresponding pointers in static library and does not have corresponding function - * in dynamic library. - * @see void __itt_sync_prepare(void* addr); - */ -void ITTAPI __itt_fsync_prepare(void* addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, fsync_prepare, (void *addr)) -#define __itt_fsync_prepare ITTNOTIFY_VOID(fsync_prepare) -#define __itt_fsync_prepare_ptr ITTNOTIFY_NAME(fsync_prepare) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_fsync_prepare(addr) -#define __itt_fsync_prepare_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_fsync_prepare_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @ingroup legacy - * @deprecated Legacy API - * @brief Fast synchronization which does no require spinning. - * - This special function is to be used by TBB and OpenMP libraries only when they know - * there is no spin but they need to suppress TC warnings about shared variable modifications. - * - It only has corresponding pointers in static library and does not have corresponding function - * in dynamic library. - * @see void __itt_sync_cancel(void *addr); - */ -void ITTAPI __itt_fsync_cancel(void *addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, fsync_cancel, (void *addr)) -#define __itt_fsync_cancel ITTNOTIFY_VOID(fsync_cancel) -#define __itt_fsync_cancel_ptr ITTNOTIFY_NAME(fsync_cancel) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_fsync_cancel(addr) -#define __itt_fsync_cancel_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_fsync_cancel_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @ingroup legacy - * @deprecated Legacy API - * @brief Fast synchronization which does no require spinning. - * - This special function is to be used by TBB and OpenMP libraries only when they know - * there is no spin but they need to suppress TC warnings about shared variable modifications. - * - It only has corresponding pointers in static library and does not have corresponding function - * in dynamic library. 
- * @see void __itt_sync_acquired(void *addr); - */ -void ITTAPI __itt_fsync_acquired(void *addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, fsync_acquired, (void *addr)) -#define __itt_fsync_acquired ITTNOTIFY_VOID(fsync_acquired) -#define __itt_fsync_acquired_ptr ITTNOTIFY_NAME(fsync_acquired) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_fsync_acquired(addr) -#define __itt_fsync_acquired_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_fsync_acquired_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @ingroup legacy - * @deprecated Legacy API - * @brief Fast synchronization which does no require spinning. - * - This special function is to be used by TBB and OpenMP libraries only when they know - * there is no spin but they need to suppress TC warnings about shared variable modifications. - * - It only has corresponding pointers in static library and does not have corresponding function - * in dynamic library. - * @see void __itt_sync_releasing(void* addr); - */ -void ITTAPI __itt_fsync_releasing(void* addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, fsync_releasing, (void *addr)) -#define __itt_fsync_releasing ITTNOTIFY_VOID(fsync_releasing) -#define __itt_fsync_releasing_ptr ITTNOTIFY_NAME(fsync_releasing) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_fsync_releasing(addr) -#define __itt_fsync_releasing_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_fsync_releasing_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} */ - -/** - * @defgroup model Modeling by Intel(R) Parallel Advisor - * @ingroup public - * This is the subset of itt used for modeling by Intel(R) Parallel Advisor. - * This API is called ONLY using annotate.h, by "Annotation" macros - * the user places in their sources during the parallelism modeling steps. - * - * site_begin/end and task_begin/end take the address of handle variables, - * which are writeable by the API. Handles must be 0 initialized prior - * to the first call to begin, or may cause a run-time failure. - * The handles are initialized in a multi-thread safe way by the API if - * the handle is 0. The commonly expected idiom is one static handle to - * identify a site or task. If a site or task of the same name has already - * been started during this collection, the same handle MAY be returned, - * but is not required to be - it is unspecified if data merging is done - * based on name. These routines also take an instance variable. Like - * the lexical instance, these must be 0 initialized. Unlike the lexical - * instance, this is used to track a single dynamic instance. - * - * API used by the Intel(R) Parallel Advisor to describe potential concurrency - * and related activities. User-added source annotations expand to calls - * to these procedures to enable modeling of a hypothetical concurrent - * execution serially. 
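The modeling entry points declared just below are normally reached through the Annotation macros in annotate.h; calling them directly, a sketch looks roughly like this (the site and task names are invented for illustration):

    #include <ittnotify.h>

    void process_rows(int n) {
        __itt_model_site_beginA("row_site");        /* potential parallel site */
        for (int i = 0; i < n; ++i) {
            __itt_model_task_beginA("row_task");    /* potential task per iteration */
            /* ... loop body that might become a parallel task ... */
            __itt_model_task_end_2();
        }
        __itt_model_site_end_2();
    }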
- * @{ - */ -#if !defined(_ADVISOR_ANNOTATE_H_) || defined(ANNOTATE_EXPAND_NULL) - -typedef void* __itt_model_site; /*!< @brief handle for lexical site */ -typedef void* __itt_model_site_instance; /*!< @brief handle for dynamic instance */ -typedef void* __itt_model_task; /*!< @brief handle for lexical site */ -typedef void* __itt_model_task_instance; /*!< @brief handle for dynamic instance */ - -/** - * @enum __itt_model_disable - * @brief Enumerator for the disable methods - */ -typedef enum { - __itt_model_disable_observation, - __itt_model_disable_collection -} __itt_model_disable; - -#endif /* !_ADVISOR_ANNOTATE_H_ || ANNOTATE_EXPAND_NULL */ - -/** - * @brief ANNOTATE_SITE_BEGIN/ANNOTATE_SITE_END support. - * - * site_begin/end model a potential concurrency site. - * site instances may be recursively nested with themselves. - * site_end exits the most recently started but unended site for the current - * thread. The handle passed to end may be used to validate structure. - * Instances of a site encountered on different threads concurrently - * are considered completely distinct. If the site name for two different - * lexical sites match, it is unspecified whether they are treated as the - * same or different for data presentation. - */ -void ITTAPI __itt_model_site_begin(__itt_model_site *site, __itt_model_site_instance *instance, const char *name); -#if ITT_PLATFORM==ITT_PLATFORM_WIN -void ITTAPI __itt_model_site_beginW(const wchar_t *name); -#endif -void ITTAPI __itt_model_site_beginA(const char *name); -void ITTAPI __itt_model_site_beginAL(const char *name, size_t siteNameLen); -void ITTAPI __itt_model_site_end (__itt_model_site *site, __itt_model_site_instance *instance); -void ITTAPI __itt_model_site_end_2(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, model_site_begin, (__itt_model_site *site, __itt_model_site_instance *instance, const char *name)) -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUBV(ITTAPI, void, model_site_beginW, (const wchar_t *name)) -#endif -ITT_STUBV(ITTAPI, void, model_site_beginA, (const char *name)) -ITT_STUBV(ITTAPI, void, model_site_beginAL, (const char *name, size_t siteNameLen)) -ITT_STUBV(ITTAPI, void, model_site_end, (__itt_model_site *site, __itt_model_site_instance *instance)) -ITT_STUBV(ITTAPI, void, model_site_end_2, (void)) -#define __itt_model_site_begin ITTNOTIFY_VOID(model_site_begin) -#define __itt_model_site_begin_ptr ITTNOTIFY_NAME(model_site_begin) -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_model_site_beginW ITTNOTIFY_VOID(model_site_beginW) -#define __itt_model_site_beginW_ptr ITTNOTIFY_NAME(model_site_beginW) -#endif -#define __itt_model_site_beginA ITTNOTIFY_VOID(model_site_beginA) -#define __itt_model_site_beginA_ptr ITTNOTIFY_NAME(model_site_beginA) -#define __itt_model_site_beginAL ITTNOTIFY_VOID(model_site_beginAL) -#define __itt_model_site_beginAL_ptr ITTNOTIFY_NAME(model_site_beginAL) -#define __itt_model_site_end ITTNOTIFY_VOID(model_site_end) -#define __itt_model_site_end_ptr ITTNOTIFY_NAME(model_site_end) -#define __itt_model_site_end_2 ITTNOTIFY_VOID(model_site_end_2) -#define __itt_model_site_end_2_ptr ITTNOTIFY_NAME(model_site_end_2) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_model_site_begin(site, instance, name) -#define __itt_model_site_begin_ptr 0 -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_model_site_beginW(name) -#define __itt_model_site_beginW_ptr 0 -#endif -#define __itt_model_site_beginA(name) -#define 
__itt_model_site_beginA_ptr 0 -#define __itt_model_site_beginAL(name, siteNameLen) -#define __itt_model_site_beginAL_ptr 0 -#define __itt_model_site_end(site, instance) -#define __itt_model_site_end_ptr 0 -#define __itt_model_site_end_2() -#define __itt_model_site_end_2_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_model_site_begin_ptr 0 -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_model_site_beginW_ptr 0 -#endif -#define __itt_model_site_beginA_ptr 0 -#define __itt_model_site_beginAL_ptr 0 -#define __itt_model_site_end_ptr 0 -#define __itt_model_site_end_2_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief ANNOTATE_TASK_BEGIN/ANNOTATE_TASK_END support - * - * task_begin/end model a potential task, which is contained within the most - * closely enclosing dynamic site. task_end exits the most recently started - * but unended task. The handle passed to end may be used to validate - * structure. It is unspecified if bad dynamic nesting is detected. If it - * is, it should be encoded in the resulting data collection. The collector - * should not fail due to construct nesting issues, nor attempt to directly - * indicate the problem. - */ -void ITTAPI __itt_model_task_begin(__itt_model_task *task, __itt_model_task_instance *instance, const char *name); -#if ITT_PLATFORM==ITT_PLATFORM_WIN -void ITTAPI __itt_model_task_beginW(const wchar_t *name); -void ITTAPI __itt_model_iteration_taskW(const wchar_t *name); -#endif -void ITTAPI __itt_model_task_beginA(const char *name); -void ITTAPI __itt_model_task_beginAL(const char *name, size_t taskNameLen); -void ITTAPI __itt_model_iteration_taskA(const char *name); -void ITTAPI __itt_model_iteration_taskAL(const char *name, size_t taskNameLen); -void ITTAPI __itt_model_task_end (__itt_model_task *task, __itt_model_task_instance *instance); -void ITTAPI __itt_model_task_end_2(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, model_task_begin, (__itt_model_task *task, __itt_model_task_instance *instance, const char *name)) -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUBV(ITTAPI, void, model_task_beginW, (const wchar_t *name)) -ITT_STUBV(ITTAPI, void, model_iteration_taskW, (const wchar_t *name)) -#endif -ITT_STUBV(ITTAPI, void, model_task_beginA, (const char *name)) -ITT_STUBV(ITTAPI, void, model_task_beginAL, (const char *name, size_t taskNameLen)) -ITT_STUBV(ITTAPI, void, model_iteration_taskA, (const char *name)) -ITT_STUBV(ITTAPI, void, model_iteration_taskAL, (const char *name, size_t taskNameLen)) -ITT_STUBV(ITTAPI, void, model_task_end, (__itt_model_task *task, __itt_model_task_instance *instance)) -ITT_STUBV(ITTAPI, void, model_task_end_2, (void)) -#define __itt_model_task_begin ITTNOTIFY_VOID(model_task_begin) -#define __itt_model_task_begin_ptr ITTNOTIFY_NAME(model_task_begin) -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_model_task_beginW ITTNOTIFY_VOID(model_task_beginW) -#define __itt_model_task_beginW_ptr ITTNOTIFY_NAME(model_task_beginW) -#define __itt_model_iteration_taskW ITTNOTIFY_VOID(model_iteration_taskW) -#define __itt_model_iteration_taskW_ptr ITTNOTIFY_NAME(model_iteration_taskW) -#endif -#define __itt_model_task_beginA ITTNOTIFY_VOID(model_task_beginA) -#define __itt_model_task_beginA_ptr ITTNOTIFY_NAME(model_task_beginA) -#define __itt_model_task_beginAL ITTNOTIFY_VOID(model_task_beginAL) -#define __itt_model_task_beginAL_ptr ITTNOTIFY_NAME(model_task_beginAL) -#define 
__itt_model_iteration_taskA ITTNOTIFY_VOID(model_iteration_taskA) -#define __itt_model_iteration_taskA_ptr ITTNOTIFY_NAME(model_iteration_taskA) -#define __itt_model_iteration_taskAL ITTNOTIFY_VOID(model_iteration_taskAL) -#define __itt_model_iteration_taskAL_ptr ITTNOTIFY_NAME(model_iteration_taskAL) -#define __itt_model_task_end ITTNOTIFY_VOID(model_task_end) -#define __itt_model_task_end_ptr ITTNOTIFY_NAME(model_task_end) -#define __itt_model_task_end_2 ITTNOTIFY_VOID(model_task_end_2) -#define __itt_model_task_end_2_ptr ITTNOTIFY_NAME(model_task_end_2) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_model_task_begin(task, instance, name) -#define __itt_model_task_begin_ptr 0 -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_model_task_beginW(name) -#define __itt_model_task_beginW_ptr 0 -#endif -#define __itt_model_task_beginA(name) -#define __itt_model_task_beginA_ptr 0 -#define __itt_model_task_beginAL(name, siteNameLen) -#define __itt_model_task_beginAL_ptr 0 -#define __itt_model_iteration_taskA(name) -#define __itt_model_iteration_taskA_ptr 0 -#define __itt_model_iteration_taskAL(name, siteNameLen) -#define __itt_model_iteration_taskAL_ptr 0 -#define __itt_model_task_end(task, instance) -#define __itt_model_task_end_ptr 0 -#define __itt_model_task_end_2() -#define __itt_model_task_end_2_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_model_task_begin_ptr 0 -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_model_task_beginW_ptr 0 -#endif -#define __itt_model_task_beginA_ptr 0 -#define __itt_model_task_beginAL_ptr 0 -#define __itt_model_iteration_taskA_ptr 0 -#define __itt_model_iteration_taskAL_ptr 0 -#define __itt_model_task_end_ptr 0 -#define __itt_model_task_end_2_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief ANNOTATE_LOCK_ACQUIRE/ANNOTATE_LOCK_RELEASE support - * - * lock_acquire/release model a potential lock for both lockset and - * performance modeling. Each unique address is modeled as a separate - * lock, with invalid addresses being valid lock IDs. Specifically: - * no storage is accessed by the API at the specified address - it is only - * used for lock identification. Lock acquires may be self-nested and are - * unlocked by a corresponding number of releases. - * (These closely correspond to __itt_sync_acquired/__itt_sync_releasing, - * but may not have identical semantics.) 
- */ -void ITTAPI __itt_model_lock_acquire(void *lock); -void ITTAPI __itt_model_lock_acquire_2(void *lock); -void ITTAPI __itt_model_lock_release(void *lock); -void ITTAPI __itt_model_lock_release_2(void *lock); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, model_lock_acquire, (void *lock)) -ITT_STUBV(ITTAPI, void, model_lock_acquire_2, (void *lock)) -ITT_STUBV(ITTAPI, void, model_lock_release, (void *lock)) -ITT_STUBV(ITTAPI, void, model_lock_release_2, (void *lock)) -#define __itt_model_lock_acquire ITTNOTIFY_VOID(model_lock_acquire) -#define __itt_model_lock_acquire_ptr ITTNOTIFY_NAME(model_lock_acquire) -#define __itt_model_lock_acquire_2 ITTNOTIFY_VOID(model_lock_acquire_2) -#define __itt_model_lock_acquire_2_ptr ITTNOTIFY_NAME(model_lock_acquire_2) -#define __itt_model_lock_release ITTNOTIFY_VOID(model_lock_release) -#define __itt_model_lock_release_ptr ITTNOTIFY_NAME(model_lock_release) -#define __itt_model_lock_release_2 ITTNOTIFY_VOID(model_lock_release_2) -#define __itt_model_lock_release_2_ptr ITTNOTIFY_NAME(model_lock_release_2) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_model_lock_acquire(lock) -#define __itt_model_lock_acquire_ptr 0 -#define __itt_model_lock_acquire_2(lock) -#define __itt_model_lock_acquire_2_ptr 0 -#define __itt_model_lock_release(lock) -#define __itt_model_lock_release_ptr 0 -#define __itt_model_lock_release_2(lock) -#define __itt_model_lock_release_2_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_model_lock_acquire_ptr 0 -#define __itt_model_lock_acquire_2_ptr 0 -#define __itt_model_lock_release_ptr 0 -#define __itt_model_lock_release_2_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief ANNOTATE_RECORD_ALLOCATION/ANNOTATE_RECORD_DEALLOCATION support - * - * record_allocation/deallocation describe user-defined memory allocator - * behavior, which may be required for correctness modeling to understand - * when storage is not expected to be actually reused across threads. 
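A minimal sketch of how a user-defined allocator might report logical allocation lifetimes through the calls declared just below; the bump allocator is invented for illustration and omits alignment handling:

    #include <ittnotify.h>
    #include <stddef.h>

    static char arena[1 << 16]; /* hypothetical bump-allocator backing store */
    static size_t arena_top;

    void *pool_alloc(size_t size) {
        void *p = &arena[arena_top];
        arena_top += size;
        __itt_model_record_allocation(p, size); /* block is logically live now */
        return p;
    }

    void pool_free(void *p) {
        __itt_model_record_deallocation(p); /* arena storage may be reused later */
    }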
- */ -void ITTAPI __itt_model_record_allocation (void *addr, size_t size); -void ITTAPI __itt_model_record_deallocation(void *addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, model_record_allocation, (void *addr, size_t size)) -ITT_STUBV(ITTAPI, void, model_record_deallocation, (void *addr)) -#define __itt_model_record_allocation ITTNOTIFY_VOID(model_record_allocation) -#define __itt_model_record_allocation_ptr ITTNOTIFY_NAME(model_record_allocation) -#define __itt_model_record_deallocation ITTNOTIFY_VOID(model_record_deallocation) -#define __itt_model_record_deallocation_ptr ITTNOTIFY_NAME(model_record_deallocation) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_model_record_allocation(addr, size) -#define __itt_model_record_allocation_ptr 0 -#define __itt_model_record_deallocation(addr) -#define __itt_model_record_deallocation_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_model_record_allocation_ptr 0 -#define __itt_model_record_deallocation_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief ANNOTATE_INDUCTION_USES support - * - * Note particular storage is inductive through the end of the current site - */ -void ITTAPI __itt_model_induction_uses(void* addr, size_t size); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, model_induction_uses, (void *addr, size_t size)) -#define __itt_model_induction_uses ITTNOTIFY_VOID(model_induction_uses) -#define __itt_model_induction_uses_ptr ITTNOTIFY_NAME(model_induction_uses) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_model_induction_uses(addr, size) -#define __itt_model_induction_uses_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_model_induction_uses_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief ANNOTATE_REDUCTION_USES support - * - * Note particular storage is used for reduction through the end - * of the current site - */ -void ITTAPI __itt_model_reduction_uses(void* addr, size_t size); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, model_reduction_uses, (void *addr, size_t size)) -#define __itt_model_reduction_uses ITTNOTIFY_VOID(model_reduction_uses) -#define __itt_model_reduction_uses_ptr ITTNOTIFY_NAME(model_reduction_uses) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_model_reduction_uses(addr, size) -#define __itt_model_reduction_uses_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_model_reduction_uses_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief ANNOTATE_OBSERVE_USES support - * - * Have correctness modeling record observations about uses of storage - * through the end of the current site - */ -void ITTAPI __itt_model_observe_uses(void* addr, size_t size); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, model_observe_uses, (void *addr, size_t size)) -#define __itt_model_observe_uses ITTNOTIFY_VOID(model_observe_uses) -#define __itt_model_observe_uses_ptr ITTNOTIFY_NAME(model_observe_uses) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_model_observe_uses(addr, size) -#define __itt_model_observe_uses_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define 
__itt_model_observe_uses_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief ANNOTATE_CLEAR_USES support - * - * Clear the special handling of a piece of storage related to induction, - * reduction or observe_uses - */ -void ITTAPI __itt_model_clear_uses(void* addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, model_clear_uses, (void *addr)) -#define __itt_model_clear_uses ITTNOTIFY_VOID(model_clear_uses) -#define __itt_model_clear_uses_ptr ITTNOTIFY_NAME(model_clear_uses) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_model_clear_uses(addr) -#define __itt_model_clear_uses_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_model_clear_uses_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief ANNOTATE_DISABLE_*_PUSH/ANNOTATE_DISABLE_*_POP support - * - * disable_push/disable_pop push and pop disabling based on a parameter. - * Disabling observations stops processing of memory references during - * correctness modeling, and all annotations that occur in the disabled - * region. This allows description of code that is expected to be handled - * specially during conversion to parallelism or that is not recognized - * by tools (e.g. some kinds of synchronization operations.) - * This mechanism causes all annotations in the disabled region, other - * than disable_push and disable_pop, to be ignored. (For example, this - * might validly be used to disable an entire parallel site and the contained - * tasks and locking in it for data collection purposes.) - * The disable for collection is a more expensive operation, but reduces - * collector overhead significantly. This applies to BOTH correctness data - * collection and performance data collection. For example, a site - * containing a task might only enable data collection for the first 10 - * iterations. Both performance and correctness data should reflect this, - * and the program should run as close to full speed as possible when - * collection is disabled. 
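A minimal sketch of the disable push/pop idiom described above, limiting collection to the first iterations; the loop body is invented for illustration:

    #include <ittnotify.h>

    static void iterate_once(void) { /* hypothetical loop body */ }

    void run_iterations(int n) {
        for (int i = 0; i < n; ++i) {
            if (i == 10) /* model only the first 10 iterations */
                __itt_model_disable_push(__itt_model_disable_collection);
            iterate_once();
        }
        if (n > 10)
            __itt_model_disable_pop(); /* re-enable collection afterwards */
    }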
- */ -void ITTAPI __itt_model_disable_push(__itt_model_disable x); -void ITTAPI __itt_model_disable_pop(void); -void ITTAPI __itt_model_aggregate_task(size_t x); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, model_disable_push, (__itt_model_disable x)) -ITT_STUBV(ITTAPI, void, model_disable_pop, (void)) -ITT_STUBV(ITTAPI, void, model_aggregate_task, (size_t x)) -#define __itt_model_disable_push ITTNOTIFY_VOID(model_disable_push) -#define __itt_model_disable_push_ptr ITTNOTIFY_NAME(model_disable_push) -#define __itt_model_disable_pop ITTNOTIFY_VOID(model_disable_pop) -#define __itt_model_disable_pop_ptr ITTNOTIFY_NAME(model_disable_pop) -#define __itt_model_aggregate_task ITTNOTIFY_VOID(model_aggregate_task) -#define __itt_model_aggregate_task_ptr ITTNOTIFY_NAME(model_aggregate_task) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_model_disable_push(x) -#define __itt_model_disable_push_ptr 0 -#define __itt_model_disable_pop() -#define __itt_model_disable_pop_ptr 0 -#define __itt_model_aggregate_task(x) -#define __itt_model_aggregate_task_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_model_disable_push_ptr 0 -#define __itt_model_disable_pop_ptr 0 -#define __itt_model_aggregate_task_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} model group */ - -/** - * @defgroup heap Heap - * @ingroup public - * Heap group - * @{ - */ - -typedef void* __itt_heap_function; - -/** - * @brief Create an identification for heap function - * @return non-zero identifier or NULL - */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -__itt_heap_function ITTAPI __itt_heap_function_createA(const char* name, const char* domain); -__itt_heap_function ITTAPI __itt_heap_function_createW(const wchar_t* name, const wchar_t* domain); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_heap_function_create __itt_heap_function_createW -# define __itt_heap_function_create_ptr __itt_heap_function_createW_ptr -#else -# define __itt_heap_function_create __itt_heap_function_createA -# define __itt_heap_function_create_ptr __itt_heap_function_createA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -__itt_heap_function ITTAPI __itt_heap_function_create(const char* name, const char* domain); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(ITTAPI, __itt_heap_function, heap_function_createA, (const char* name, const char* domain)) -ITT_STUB(ITTAPI, __itt_heap_function, heap_function_createW, (const wchar_t* name, const wchar_t* domain)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(ITTAPI, __itt_heap_function, heap_function_create, (const char* name, const char* domain)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_heap_function_createA ITTNOTIFY_DATA(heap_function_createA) -#define __itt_heap_function_createA_ptr ITTNOTIFY_NAME(heap_function_createA) -#define __itt_heap_function_createW ITTNOTIFY_DATA(heap_function_createW) -#define __itt_heap_function_createW_ptr ITTNOTIFY_NAME(heap_function_createW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_heap_function_create ITTNOTIFY_DATA(heap_function_create) -#define __itt_heap_function_create_ptr ITTNOTIFY_NAME(heap_function_create) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ 
-#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_heap_function_createA(name, domain) (__itt_heap_function)0 -#define __itt_heap_function_createA_ptr 0 -#define __itt_heap_function_createW(name, domain) (__itt_heap_function)0 -#define __itt_heap_function_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_heap_function_create(name, domain) (__itt_heap_function)0 -#define __itt_heap_function_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_heap_function_createA_ptr 0 -#define __itt_heap_function_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_heap_function_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Record an allocation begin occurrence. - */ -void ITTAPI __itt_heap_allocate_begin(__itt_heap_function h, size_t size, int initialized); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, heap_allocate_begin, (__itt_heap_function h, size_t size, int initialized)) -#define __itt_heap_allocate_begin ITTNOTIFY_VOID(heap_allocate_begin) -#define __itt_heap_allocate_begin_ptr ITTNOTIFY_NAME(heap_allocate_begin) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_heap_allocate_begin(h, size, initialized) -#define __itt_heap_allocate_begin_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_heap_allocate_begin_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Record an allocation end occurrence. - */ -void ITTAPI __itt_heap_allocate_end(__itt_heap_function h, void** addr, size_t size, int initialized); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, heap_allocate_end, (__itt_heap_function h, void** addr, size_t size, int initialized)) -#define __itt_heap_allocate_end ITTNOTIFY_VOID(heap_allocate_end) -#define __itt_heap_allocate_end_ptr ITTNOTIFY_NAME(heap_allocate_end) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_heap_allocate_end(h, addr, size, initialized) -#define __itt_heap_allocate_end_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_heap_allocate_end_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Record a free begin occurrence. - */ -void ITTAPI __itt_heap_free_begin(__itt_heap_function h, void* addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, heap_free_begin, (__itt_heap_function h, void* addr)) -#define __itt_heap_free_begin ITTNOTIFY_VOID(heap_free_begin) -#define __itt_heap_free_begin_ptr ITTNOTIFY_NAME(heap_free_begin) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_heap_free_begin(h, addr) -#define __itt_heap_free_begin_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_heap_free_begin_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Record a free end occurrence. 
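A minimal sketch of wrapping an allocator with the heap begin/end pairs above (__itt_heap_free_end is declared just below); the wrapper and domain names are invented for illustration:

    #include <ittnotify.h>
    #include <stdlib.h>

    static __itt_heap_function alloc_fn, free_fn;

    void heap_tracing_init(void) {
        alloc_fn = __itt_heap_function_create("my_malloc", "mylib");
        free_fn  = __itt_heap_function_create("my_free",  "mylib");
    }

    void *my_malloc(size_t size) {
        void *p;
        __itt_heap_allocate_begin(alloc_fn, size, 0 /* not initialized */);
        p = malloc(size);
        __itt_heap_allocate_end(alloc_fn, &p, size, 0);
        return p;
    }

    void my_free(void *p) {
        __itt_heap_free_begin(free_fn, p);
        free(p);
        __itt_heap_free_end(free_fn, p);
    }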
- */ -void ITTAPI __itt_heap_free_end(__itt_heap_function h, void* addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, heap_free_end, (__itt_heap_function h, void* addr)) -#define __itt_heap_free_end ITTNOTIFY_VOID(heap_free_end) -#define __itt_heap_free_end_ptr ITTNOTIFY_NAME(heap_free_end) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_heap_free_end(h, addr) -#define __itt_heap_free_end_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_heap_free_end_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Record a reallocation begin occurrence. - */ -void ITTAPI __itt_heap_reallocate_begin(__itt_heap_function h, void* addr, size_t new_size, int initialized); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, heap_reallocate_begin, (__itt_heap_function h, void* addr, size_t new_size, int initialized)) -#define __itt_heap_reallocate_begin ITTNOTIFY_VOID(heap_reallocate_begin) -#define __itt_heap_reallocate_begin_ptr ITTNOTIFY_NAME(heap_reallocate_begin) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_heap_reallocate_begin(h, addr, new_size, initialized) -#define __itt_heap_reallocate_begin_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_heap_reallocate_begin_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Record a reallocation end occurrence. - */ -void ITTAPI __itt_heap_reallocate_end(__itt_heap_function h, void* addr, void** new_addr, size_t new_size, int initialized); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, heap_reallocate_end, (__itt_heap_function h, void* addr, void** new_addr, size_t new_size, int initialized)) -#define __itt_heap_reallocate_end ITTNOTIFY_VOID(heap_reallocate_end) -#define __itt_heap_reallocate_end_ptr ITTNOTIFY_NAME(heap_reallocate_end) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_heap_reallocate_end(h, addr, new_addr, new_size, initialized) -#define __itt_heap_reallocate_end_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_heap_reallocate_end_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** @brief internal access begin */ -void ITTAPI __itt_heap_internal_access_begin(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, heap_internal_access_begin, (void)) -#define __itt_heap_internal_access_begin ITTNOTIFY_VOID(heap_internal_access_begin) -#define __itt_heap_internal_access_begin_ptr ITTNOTIFY_NAME(heap_internal_access_begin) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_heap_internal_access_begin() -#define __itt_heap_internal_access_begin_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_heap_internal_access_begin_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** @brief internal access end */ -void ITTAPI __itt_heap_internal_access_end(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, heap_internal_access_end, (void)) -#define __itt_heap_internal_access_end ITTNOTIFY_VOID(heap_internal_access_end) -#define __itt_heap_internal_access_end_ptr ITTNOTIFY_NAME(heap_internal_access_end) -#else /* 
INTEL_NO_ITTNOTIFY_API */ -#define __itt_heap_internal_access_end() -#define __itt_heap_internal_access_end_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_heap_internal_access_end_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** @brief record memory growth begin */ -void ITTAPI __itt_heap_record_memory_growth_begin(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, heap_record_memory_growth_begin, (void)) -#define __itt_heap_record_memory_growth_begin ITTNOTIFY_VOID(heap_record_memory_growth_begin) -#define __itt_heap_record_memory_growth_begin_ptr ITTNOTIFY_NAME(heap_record_memory_growth_begin) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_heap_record_memory_growth_begin() -#define __itt_heap_record_memory_growth_begin_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_heap_record_memory_growth_begin_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** @brief record memory growth end */ -void ITTAPI __itt_heap_record_memory_growth_end(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, heap_record_memory_growth_end, (void)) -#define __itt_heap_record_memory_growth_end ITTNOTIFY_VOID(heap_record_memory_growth_end) -#define __itt_heap_record_memory_growth_end_ptr ITTNOTIFY_NAME(heap_record_memory_growth_end) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_heap_record_memory_growth_end() -#define __itt_heap_record_memory_growth_end_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_heap_record_memory_growth_end_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Specify the type of heap detection/reporting to modify. - */ -/** - * @hideinitializer - * @brief Report on memory leaks. - */ -#define __itt_heap_leaks 0x00000001 - -/** - * @hideinitializer - * @brief Report on memory growth. 
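
A minimal sketch of how the heap hooks above can instrument a custom allocator, assuming a collector is attached; the my_malloc/MyAllocator names and the wrapper functions are illustrative, not part of the API:

    #include <stdlib.h>
    #include "ittnotify.h"

    static __itt_heap_function g_malloc_fn; /* one handle per allocation routine */

    void heap_tracing_init(void)
    {
        g_malloc_fn = __itt_heap_function_create("my_malloc", "MyAllocator");
    }

    void* my_malloc(size_t size)
    {
        void* p;
        __itt_heap_allocate_begin(g_malloc_fn, size, 0 /* memory not initialized */);
        p = malloc(size);
        __itt_heap_allocate_end(g_malloc_fn, &p, size, 0);
        return p;
    }

    void my_free(void* p)
    {
        __itt_heap_free_begin(g_malloc_fn, p);
        free(p);
        __itt_heap_free_end(g_malloc_fn, p);
    }
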
 - */
-#define __itt_heap_growth 0x00000002
-
-
-/** @brief Reset heap detection for the categories selected by reset_mask */
-void ITTAPI __itt_heap_reset_detection(unsigned int reset_mask);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, heap_reset_detection, (unsigned int reset_mask))
-#define __itt_heap_reset_detection ITTNOTIFY_VOID(heap_reset_detection)
-#define __itt_heap_reset_detection_ptr ITTNOTIFY_NAME(heap_reset_detection)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_heap_reset_detection(reset_mask)
-#define __itt_heap_reset_detection_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_heap_reset_detection_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/** @brief Report on the heap categories selected by record_mask */
-void ITTAPI __itt_heap_record(unsigned int record_mask);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, heap_record, (unsigned int record_mask))
-#define __itt_heap_record ITTNOTIFY_VOID(heap_record)
-#define __itt_heap_record_ptr ITTNOTIFY_NAME(heap_record)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_heap_record(record_mask)
-#define __itt_heap_record_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_heap_record_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/** @} heap group */
-/** @endcond */
-/* ========================================================================== */
-
-/**
- * @defgroup domains Domains
- * @ingroup public
- * Domains group
- * @{
- */
-
-/** @cond exclude_from_documentation */
-#pragma pack(push, 8)
-
-typedef struct ___itt_domain
-{
- volatile int flags; /*!< Zero if disabled, non-zero if enabled. The meaning of different non-zero values is reserved to the runtime */
- const char* nameA; /*!< Copy of original name in ASCII. */
-#if defined(UNICODE) || defined(_UNICODE)
- const wchar_t* nameW; /*!< Copy of original name in UNICODE. */
-#else /* UNICODE || _UNICODE */
- void* nameW;
-#endif /* UNICODE || _UNICODE */
- int extra1; /*!< Reserved to the runtime */
- void* extra2; /*!< Reserved to the runtime */
- struct ___itt_domain* next;
-} __itt_domain;
-
-#pragma pack(pop)
-/** @endcond */
-
-/**
- * @ingroup domains
- * @brief Create a domain.
- * Create a domain using some domain name; the URI naming style is recommended.
- * Because the set of domains is expected to be static over the application's
- * execution time, there is no mechanism to destroy a domain.
- * Any domain can be accessed by any thread in the process, regardless of
- * which thread created the domain. This call is thread-safe.
- * @param[in] name The name of the domain
- */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-__itt_domain* ITTAPI __itt_domain_createA(const char *name);
-__itt_domain* ITTAPI __itt_domain_createW(const wchar_t *name);
-#if defined(UNICODE) || defined(_UNICODE)
-# define __itt_domain_create __itt_domain_createW
-# define __itt_domain_create_ptr __itt_domain_createW_ptr
-#else /* UNICODE */
-# define __itt_domain_create __itt_domain_createA
-# define __itt_domain_create_ptr __itt_domain_createA_ptr
-#endif /* UNICODE */
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-__itt_domain* ITTAPI __itt_domain_create(const char *name);
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-ITT_STUB(ITTAPI, __itt_domain*, domain_createA, (const char *name))
-ITT_STUB(ITTAPI, __itt_domain*, domain_createW, (const wchar_t *name))
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-ITT_STUB(ITTAPI, __itt_domain*, domain_create, (const char *name))
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_domain_createA ITTNOTIFY_DATA(domain_createA)
-#define __itt_domain_createA_ptr ITTNOTIFY_NAME(domain_createA)
-#define __itt_domain_createW ITTNOTIFY_DATA(domain_createW)
-#define __itt_domain_createW_ptr ITTNOTIFY_NAME(domain_createW)
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_domain_create ITTNOTIFY_DATA(domain_create)
-#define __itt_domain_create_ptr ITTNOTIFY_NAME(domain_create)
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#else /* INTEL_NO_ITTNOTIFY_API */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_domain_createA(name) (__itt_domain*)0
-#define __itt_domain_createA_ptr 0
-#define __itt_domain_createW(name) (__itt_domain*)0
-#define __itt_domain_createW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_domain_create(name) (__itt_domain*)0
-#define __itt_domain_create_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_domain_createA_ptr 0
-#define __itt_domain_createW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_domain_create_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-/** @} domains group */
-
-/**
- * @defgroup ids IDs
- * @ingroup public
- * IDs group
- * @{
- */
-
-/** @cond exclude_from_documentation */
-#pragma pack(push, 8)
-
-typedef struct ___itt_id
-{
- unsigned long long d1, d2, d3;
-} __itt_id;
-
-#pragma pack(pop)
-/** @endcond */
-
-static const __itt_id __itt_null = { 0, 0, 0 };
-
-/**
- * @ingroup ids
- * @brief Create an ID without domain control.
- * This is a convenience function to initialize an __itt_id structure. The function
- * does not affect the collector runtime in any way. After you make the ID with this
- * function, you still must create it with the __itt_id_create function before using the ID
- * to identify a named entity.
- * @param[in] addr The address of the object; high QWORD of the ID value.
- * @param[in] extra The extra data to uniquely identify the object; low QWORD of the ID value.
- */
-
-ITT_INLINE __itt_id ITTAPI __itt_id_make(void* addr, unsigned long long extra) ITT_INLINE_ATTRIBUTE;
-ITT_INLINE __itt_id ITTAPI __itt_id_make(void* addr, unsigned long long extra)
-{
- __itt_id id = __itt_null;
- id.d1 = (unsigned long long)((uintptr_t)addr);
- id.d2 = (unsigned long long)extra;
- id.d3 = (unsigned long long)0; /* Reserved. Must be zero */
- return id;
-}
-
-/**
- * @ingroup ids
- * @brief Create an instance of identifier.
- * This establishes the beginning of the lifetime of an instance of
- * the given ID in the trace. Once this lifetime starts, the ID
- * can be used to tag named entity instances in calls such as
- * __itt_task_begin, and to specify relationships among
- * identified named entity instances, using the \ref relations APIs.
- * Instance IDs are not domain specific!
- * @param[in] domain The domain controlling the execution of this call.
- * @param[in] id The ID to create.
- */
-void ITTAPI __itt_id_create(const __itt_domain *domain, __itt_id id);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, id_create, (const __itt_domain *domain, __itt_id id))
-#define __itt_id_create(d,x) ITTNOTIFY_VOID_D1(id_create,d,x)
-#define __itt_id_create_ptr ITTNOTIFY_NAME(id_create)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_id_create(domain,id)
-#define __itt_id_create_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_id_create_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/**
- * @ingroup ids
- * @brief Destroy an instance of identifier.
- * This ends the lifetime of the current instance of the given ID value in the trace.
- * Any relationships that are established after this lifetime ends are invalid.
- * This call must be performed before the given ID value can be reused for a different
- * named entity instance.
- * @param[in] domain The domain controlling the execution of this call.
- * @param[in] id The ID to destroy.
- */
-void ITTAPI __itt_id_destroy(const __itt_domain *domain, __itt_id id);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, id_destroy, (const __itt_domain *domain, __itt_id id))
-#define __itt_id_destroy(d,x) ITTNOTIFY_VOID_D1(id_destroy,d,x)
-#define __itt_id_destroy_ptr ITTNOTIFY_NAME(id_destroy)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_id_destroy(domain,id)
-#define __itt_id_destroy_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_id_destroy_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-/** @} ids group */
-
-/**
- * @defgroup handles String Handles
- * @ingroup public
- * String Handles group
- * @{
- */
-
-/** @cond exclude_from_documentation */
-#pragma pack(push, 8)
-
-typedef struct ___itt_string_handle
-{
- const char* strA; /*!< Copy of original string in ASCII. */
-#if defined(UNICODE) || defined(_UNICODE)
- const wchar_t* strW; /*!< Copy of original string in UNICODE. */
-#else /* UNICODE || _UNICODE */
- void* strW;
-#endif /* UNICODE || _UNICODE */
- int extra1; /*!< Reserved. Must be zero */
- void* extra2; /*!< Reserved. Must be zero */
- struct ___itt_string_handle* next;
-} __itt_string_handle;
-
-#pragma pack(pop)
-/** @endcond */
-
-/**
- * @ingroup handles
- * @brief Create a string handle.
- * Create and return a handle value that can be associated with a string.
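
For illustration, the full ID lifecycle under the contract above; the domain name and traced object are hypothetical placeholders. __itt_id_make only fills in the structure locally, while __itt_id_create and __itt_id_destroy bracket the instance's lifetime in the trace:

    #include "ittnotify.h"

    void tag_object(void* obj /* hypothetical traced object */)
    {
        __itt_domain* d = __itt_domain_create("com.example.app");
        __itt_id id = __itt_id_make(obj, 0); /* local initialization only */
        __itt_id_create(d, id);              /* lifetime starts in the trace */
        /* ... use `id` to tag tasks or establish relations ... */
        __itt_id_destroy(d, id);             /* lifetime ends; value may be reused */
    }
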
- * Consecutive calls to __itt_string_handle_create with the same name - * return the same value. Because the set of string handles is expected to remain - * static during the application's execution time, there is no mechanism to destroy a string handle. - * Any string handle can be accessed by any thread in the process, regardless of which thread created - * the string handle. This call is thread-safe. - * @param[in] name The input string - */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -__itt_string_handle* ITTAPI __itt_string_handle_createA(const char *name); -__itt_string_handle* ITTAPI __itt_string_handle_createW(const wchar_t *name); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_string_handle_create __itt_string_handle_createW -# define __itt_string_handle_create_ptr __itt_string_handle_createW_ptr -#else /* UNICODE */ -# define __itt_string_handle_create __itt_string_handle_createA -# define __itt_string_handle_create_ptr __itt_string_handle_createA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -__itt_string_handle* ITTAPI __itt_string_handle_create(const char *name); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(ITTAPI, __itt_string_handle*, string_handle_createA, (const char *name)) -ITT_STUB(ITTAPI, __itt_string_handle*, string_handle_createW, (const wchar_t *name)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(ITTAPI, __itt_string_handle*, string_handle_create, (const char *name)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_string_handle_createA ITTNOTIFY_DATA(string_handle_createA) -#define __itt_string_handle_createA_ptr ITTNOTIFY_NAME(string_handle_createA) -#define __itt_string_handle_createW ITTNOTIFY_DATA(string_handle_createW) -#define __itt_string_handle_createW_ptr ITTNOTIFY_NAME(string_handle_createW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_string_handle_create ITTNOTIFY_DATA(string_handle_create) -#define __itt_string_handle_create_ptr ITTNOTIFY_NAME(string_handle_create) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_string_handle_createA(name) (__itt_string_handle*)0 -#define __itt_string_handle_createA_ptr 0 -#define __itt_string_handle_createW(name) (__itt_string_handle*)0 -#define __itt_string_handle_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_string_handle_create(name) (__itt_string_handle*)0 -#define __itt_string_handle_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_string_handle_createA_ptr 0 -#define __itt_string_handle_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_string_handle_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} handles group */ - -/** @cond exclude_from_documentation */ -typedef unsigned long long __itt_timestamp; -/** @endcond */ - -#define __itt_timestamp_none ((__itt_timestamp)-1LL) - -/** @cond exclude_from_gpa_documentation */ - -/** - * @ingroup timestamps - * @brief Return timestamp corresponding to the current moment. 
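
Since consecutive calls with the same name return the same value, a common pattern is to create handles once and cache the pointers; a small sketch (the name is illustrative):

    void cache_handles(void)
    {
        /* Repeated calls with the same string return the same handle. */
        __itt_string_handle* h1 = __itt_string_handle_create("compute_phase");
        __itt_string_handle* h2 = __itt_string_handle_create("compute_phase");
        /* h1 == h2 holds here, so caching the pointer in a global is safe. */
        (void)h1; (void)h2;
    }
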
- * This returns the timestamp in the format that is the most relevant for the current
- * host or platform (RDTSC, QPC, and others). You can use the "<" operator to
- * compare __itt_timestamp values.
- */
-__itt_timestamp ITTAPI __itt_get_timestamp(void);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUB(ITTAPI, __itt_timestamp, get_timestamp, (void))
-#define __itt_get_timestamp ITTNOTIFY_DATA(get_timestamp)
-#define __itt_get_timestamp_ptr ITTNOTIFY_NAME(get_timestamp)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_get_timestamp() ((__itt_timestamp)0)
-#define __itt_get_timestamp_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_get_timestamp_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-/** @} timestamps */
-/** @endcond */
-
-/** @cond exclude_from_gpa_documentation */
-
-/**
- * @defgroup regions Regions
- * @ingroup public
- * Regions group
- * @{
- */
-/**
- * @ingroup regions
- * @brief Begin a region instance.
- * Successive calls to __itt_region_begin with the same ID are ignored
- * until a call to __itt_region_end with the same ID.
- * @param[in] domain The domain for this region instance
- * @param[in] id The instance ID for this region instance. Must not be __itt_null
- * @param[in] parentid The instance ID for the parent of this region instance, or __itt_null
- * @param[in] name The name of this region
- */
-void ITTAPI __itt_region_begin(const __itt_domain *domain, __itt_id id, __itt_id parentid, __itt_string_handle *name);
-
-/**
- * @ingroup regions
- * @brief End a region instance.
- * The first call to __itt_region_end with a given ID ends the
- * region. Successive calls with the same ID are ignored, as are
- * calls that do not have a matching __itt_region_begin call.
- * @param[in] domain The domain for this region instance
- * @param[in] id The instance ID for this region instance
- */
-void ITTAPI __itt_region_end(const __itt_domain *domain, __itt_id id);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, region_begin, (const __itt_domain *domain, __itt_id id, __itt_id parentid, __itt_string_handle *name))
-ITT_STUBV(ITTAPI, void, region_end, (const __itt_domain *domain, __itt_id id))
-#define __itt_region_begin(d,x,y,z) ITTNOTIFY_VOID_D3(region_begin,d,x,y,z)
-#define __itt_region_begin_ptr ITTNOTIFY_NAME(region_begin)
-#define __itt_region_end(d,x) ITTNOTIFY_VOID_D1(region_end,d,x)
-#define __itt_region_end_ptr ITTNOTIFY_NAME(region_end)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_region_begin(d,x,y,z)
-#define __itt_region_begin_ptr 0
-#define __itt_region_end(d,x)
-#define __itt_region_end_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_region_begin_ptr 0
-#define __itt_region_end_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-/** @} regions group */
-
-/**
- * @defgroup frames Frames
- * @ingroup public
- * Frames are similar to regions, but are intended to be easier to use and to implement.
- * In particular:
- * - Frames always represent periods of elapsed time
- * - By default, frames have no nesting relationships
- * @{
- */
-
-/**
- * @ingroup frames
- * @brief Begin a frame instance.
- * Successive calls to __itt_frame_begin with the
- * same ID are ignored until a call to __itt_frame_end with the same ID.
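
A short sketch of the region bracketing described above; the domain, region name, and request object are illustrative:

    void handle_request(void* request)
    {
        __itt_domain* d = __itt_domain_create("com.example.server");
        __itt_string_handle* name = __itt_string_handle_create("handle_request");
        __itt_id rid = __itt_id_make(request, 0);

        __itt_region_begin(d, rid, __itt_null, name);
        /* ... process the request ... */
        __itt_region_end(d, rid);
    }
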
- * @param[in] domain The domain for this frame instance - * @param[in] id The instance ID for this frame instance or NULL - */ -void ITTAPI __itt_frame_begin_v3(const __itt_domain *domain, __itt_id *id); - -/** - * @ingroup frames - * @brief End a frame instance. - * The first call to __itt_frame_end with a given ID - * ends the frame. Successive calls with the same ID are ignored, as are - * calls that do not have a matching __itt_frame_begin call. - * @param[in] domain The domain for this frame instance - * @param[in] id The instance ID for this frame instance or NULL for current - */ -void ITTAPI __itt_frame_end_v3(const __itt_domain *domain, __itt_id *id); - -/** - * @ingroup frames - * @brief Submits a frame instance. - * Successive calls to __itt_frame_begin or __itt_frame_submit with the - * same ID are ignored until a call to __itt_frame_end or __itt_frame_submit - * with the same ID. - * Passing special __itt_timestamp_none value as "end" argument means - * take the current timestamp as the end timestamp. - * @param[in] domain The domain for this frame instance - * @param[in] id The instance ID for this frame instance or NULL - * @param[in] begin Timestamp of the beginning of the frame - * @param[in] end Timestamp of the end of the frame - */ -void ITTAPI __itt_frame_submit_v3(const __itt_domain *domain, __itt_id *id, - __itt_timestamp begin, __itt_timestamp end); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, frame_begin_v3, (const __itt_domain *domain, __itt_id *id)) -ITT_STUBV(ITTAPI, void, frame_end_v3, (const __itt_domain *domain, __itt_id *id)) -ITT_STUBV(ITTAPI, void, frame_submit_v3, (const __itt_domain *domain, __itt_id *id, __itt_timestamp begin, __itt_timestamp end)) -#define __itt_frame_begin_v3(d,x) ITTNOTIFY_VOID_D1(frame_begin_v3,d,x) -#define __itt_frame_begin_v3_ptr ITTNOTIFY_NAME(frame_begin_v3) -#define __itt_frame_end_v3(d,x) ITTNOTIFY_VOID_D1(frame_end_v3,d,x) -#define __itt_frame_end_v3_ptr ITTNOTIFY_NAME(frame_end_v3) -#define __itt_frame_submit_v3(d,x,b,e) ITTNOTIFY_VOID_D3(frame_submit_v3,d,x,b,e) -#define __itt_frame_submit_v3_ptr ITTNOTIFY_NAME(frame_submit_v3) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_frame_begin_v3(domain,id) -#define __itt_frame_begin_v3_ptr 0 -#define __itt_frame_end_v3(domain,id) -#define __itt_frame_end_v3_ptr 0 -#define __itt_frame_submit_v3(domain,id,begin,end) -#define __itt_frame_submit_v3_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_frame_begin_v3_ptr 0 -#define __itt_frame_end_v3_ptr 0 -#define __itt_frame_submit_v3_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} frames group */ -/** @endcond */ - -/** - * @defgroup taskgroup Task Group - * @ingroup public - * Task Group - * @{ - */ -/** - * @ingroup task_groups - * @brief Denotes a task_group instance. - * Successive calls to __itt_task_group with the same ID are ignored. - * @param[in] domain The domain for this task_group instance - * @param[in] id The instance ID for this task_group instance. Must not be __itt_null. - * @param[in] parentid The instance ID for the parent of this task_group instance, or __itt_null. 
- * @param[in] name The name of this task_group
- */
-void ITTAPI __itt_task_group(const __itt_domain *domain, __itt_id id, __itt_id parentid, __itt_string_handle *name);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, task_group, (const __itt_domain *domain, __itt_id id, __itt_id parentid, __itt_string_handle *name))
-#define __itt_task_group(d,x,y,z) ITTNOTIFY_VOID_D3(task_group,d,x,y,z)
-#define __itt_task_group_ptr ITTNOTIFY_NAME(task_group)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_task_group(d,x,y,z)
-#define __itt_task_group_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_task_group_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-/** @} taskgroup group */
-
-/**
- * @defgroup tasks Tasks
- * @ingroup public
- * A task instance represents a piece of work performed by a particular
- * thread for a period of time. A call to __itt_task_begin creates a
- * task instance. This becomes the current instance for that task on that
- * thread. A following call to __itt_task_end on the same thread ends the
- * instance. There may be multiple simultaneous instances of tasks with the
- * same name on different threads. If an ID is specified, the task instance
- * receives that ID. Nested tasks are allowed.
- *
- * Note: The task is defined by the bracketing of __itt_task_begin and
- * __itt_task_end on the same thread. If some scheduling mechanism causes
- * task switching (the thread executes a different user task) or thread
- * switching (the user task switches to a different thread) then this breaks
- * the notion of current instance. Additional API calls are required to
- * deal with that possibility.
- * @{
- */
-
-/**
- * @ingroup tasks
- * @brief Begin a task instance.
- * @param[in] domain The domain for this task
- * @param[in] taskid The instance ID for this task instance, or __itt_null
- * @param[in] parentid The parent instance to which this task instance belongs, or __itt_null
- * @param[in] name The name of this task
- */
-void ITTAPI __itt_task_begin(const __itt_domain *domain, __itt_id taskid, __itt_id parentid, __itt_string_handle *name);
-
-/**
- * @ingroup tasks
- * @brief Begin a task instance.
- * @param[in] domain The domain for this task
- * @param[in] taskid The identifier for this task instance (may be 0)
- * @param[in] parentid The parent of this task (may be 0)
- * @param[in] fn The pointer to the function you are tracing
- */
-void ITTAPI __itt_task_begin_fn(const __itt_domain *domain, __itt_id taskid, __itt_id parentid, void* fn);
-
-/**
- * @ingroup tasks
- * @brief End the current task instance.
- * @param[in] domain The domain for this task
- */
-void ITTAPI __itt_task_end(const __itt_domain *domain);
-
-/**
- * @ingroup tasks
- * @brief Begin an overlapped task instance.
- * @param[in] domain The domain for this task.
- * @param[in] taskid The identifier for this task instance, *cannot* be __itt_null.
- * @param[in] parentid The parent of this task, or __itt_null.
- * @param[in] name The name of this task.
- */
-void ITTAPI __itt_task_begin_overlapped(const __itt_domain* domain, __itt_id taskid, __itt_id parentid, __itt_string_handle* name);
-
-/**
- * @ingroup tasks
- * @brief End an overlapped task instance.
- * @param[in] domain The domain for this task
- * @param[in] taskid The explicit ID of the finished task
- */
-void ITTAPI __itt_task_end_overlapped(const __itt_domain *domain, __itt_id taskid);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, task_begin, (const __itt_domain *domain, __itt_id id, __itt_id parentid, __itt_string_handle *name))
-ITT_STUBV(ITTAPI, void, task_begin_fn, (const __itt_domain *domain, __itt_id id, __itt_id parentid, void* fn))
-ITT_STUBV(ITTAPI, void, task_end, (const __itt_domain *domain))
-ITT_STUBV(ITTAPI, void, task_begin_overlapped, (const __itt_domain *domain, __itt_id taskid, __itt_id parentid, __itt_string_handle *name))
-ITT_STUBV(ITTAPI, void, task_end_overlapped, (const __itt_domain *domain, __itt_id taskid))
-#define __itt_task_begin(d,x,y,z) ITTNOTIFY_VOID_D3(task_begin,d,x,y,z)
-#define __itt_task_begin_ptr ITTNOTIFY_NAME(task_begin)
-#define __itt_task_begin_fn(d,x,y,z) ITTNOTIFY_VOID_D3(task_begin_fn,d,x,y,z)
-#define __itt_task_begin_fn_ptr ITTNOTIFY_NAME(task_begin_fn)
-#define __itt_task_end(d) ITTNOTIFY_VOID_D0(task_end,d)
-#define __itt_task_end_ptr ITTNOTIFY_NAME(task_end)
-#define __itt_task_begin_overlapped(d,x,y,z) ITTNOTIFY_VOID_D3(task_begin_overlapped,d,x,y,z)
-#define __itt_task_begin_overlapped_ptr ITTNOTIFY_NAME(task_begin_overlapped)
-#define __itt_task_end_overlapped(d,x) ITTNOTIFY_VOID_D1(task_end_overlapped,d,x)
-#define __itt_task_end_overlapped_ptr ITTNOTIFY_NAME(task_end_overlapped)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_task_begin(domain,id,parentid,name)
-#define __itt_task_begin_ptr 0
-#define __itt_task_begin_fn(domain,id,parentid,fn)
-#define __itt_task_begin_fn_ptr 0
-#define __itt_task_end(domain)
-#define __itt_task_end_ptr 0
-#define __itt_task_begin_overlapped(domain,taskid,parentid,name)
-#define __itt_task_begin_overlapped_ptr 0
-#define __itt_task_end_overlapped(domain,taskid)
-#define __itt_task_end_overlapped_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_task_begin_ptr 0
-#define __itt_task_begin_fn_ptr 0
-#define __itt_task_end_ptr 0
-#define __itt_task_begin_overlapped_ptr 0
-#define __itt_task_end_overlapped_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-/** @} tasks group */
-
-
-/**
- * @defgroup markers Markers
- * Markers represent a single discrete event in time. Markers have a scope,
- * described by an enumerated type __itt_scope. Markers are created by
- * the API call __itt_marker. A marker instance can be given an ID for use in
- * adding metadata.
- * @{
- */
-
-/**
- * @brief Describes the scope of an event object in the trace.
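
A minimal sketch of the per-thread bracketing described above (the domain and handle names are illustrative; later sketches in this section reuse these globals):

    static __itt_domain* g_domain;
    static __itt_string_handle* g_sh_compute;

    void tracing_init(void)
    {
        g_domain = __itt_domain_create("com.example.engine");
        g_sh_compute = __itt_string_handle_create("compute");
    }

    void compute(void)
    {
        __itt_task_begin(g_domain, __itt_null, __itt_null, g_sh_compute);
        /* ... work performed on this thread; nested task_begin calls are allowed ... */
        __itt_task_end(g_domain);
    }
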
- */
-typedef enum
-{
- __itt_scope_unknown = 0,
- __itt_scope_global,
- __itt_scope_track_group,
- __itt_scope_track,
- __itt_scope_task,
- __itt_scope_marker
-} __itt_scope;
-
-/** @cond exclude_from_documentation */
-#define __itt_marker_scope_unknown __itt_scope_unknown
-#define __itt_marker_scope_global __itt_scope_global
-#define __itt_marker_scope_process __itt_scope_track_group
-#define __itt_marker_scope_thread __itt_scope_track
-#define __itt_marker_scope_task __itt_scope_task
-/** @endcond */
-
-/**
- * @ingroup markers
- * @brief Create a marker instance
- * @param[in] domain The domain for this marker
- * @param[in] id The instance ID for this marker or __itt_null
- * @param[in] name The name for this marker
- * @param[in] scope The scope for this marker
- */
-void ITTAPI __itt_marker(const __itt_domain *domain, __itt_id id, __itt_string_handle *name, __itt_scope scope);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, marker, (const __itt_domain *domain, __itt_id id, __itt_string_handle *name, __itt_scope scope))
-#define __itt_marker(d,x,y,z) ITTNOTIFY_VOID_D3(marker,d,x,y,z)
-#define __itt_marker_ptr ITTNOTIFY_NAME(marker)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_marker(domain,id,name,scope)
-#define __itt_marker_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_marker_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-/** @} markers group */
-
-/**
- * @defgroup metadata Metadata
- * The metadata API is used to attach extra information to named
- * entities. Metadata can be attached to an identified named entity by ID,
- * or to the current entity (which is always a task).
- *
- * Conceptually metadata has a type (what kind of metadata), a key (the
- * name of the metadata), and a value (the actual data). The encoding of
- * the value depends on the type of the metadata.
- *
- * The type of metadata is specified by an enumerated type __itt_metadata_type.
- * @{
- */
-
-/**
- * @ingroup parameters
- * @brief Describes the type of metadata
- */
-typedef enum {
- __itt_metadata_unknown = 0,
- __itt_metadata_u64, /**< Unsigned 64-bit integer */
- __itt_metadata_s64, /**< Signed 64-bit integer */
- __itt_metadata_u32, /**< Unsigned 32-bit integer */
- __itt_metadata_s32, /**< Signed 32-bit integer */
- __itt_metadata_u16, /**< Unsigned 16-bit integer */
- __itt_metadata_s16, /**< Signed 16-bit integer */
- __itt_metadata_float, /**< Signed 32-bit floating-point */
- __itt_metadata_double /**< Signed 64-bit floating-point */
-} __itt_metadata_type;
-
-/**
- * @ingroup parameters
- * @brief Add metadata to an instance of a named entity.
- * @param[in] domain The domain controlling the call
- * @param[in] id The identifier of the instance to which the metadata is to be added, or __itt_null to add to the current task
- * @param[in] key The name of the metadata
- * @param[in] type The type of the metadata
- * @param[in] count The number of elements of the given type. If count == 0, no metadata will be added.
- * @param[in] data The metadata itself -*/ -void ITTAPI __itt_metadata_add(const __itt_domain *domain, __itt_id id, __itt_string_handle *key, __itt_metadata_type type, size_t count, void *data); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, metadata_add, (const __itt_domain *domain, __itt_id id, __itt_string_handle *key, __itt_metadata_type type, size_t count, void *data)) -#define __itt_metadata_add(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(metadata_add,d,x,y,z,a,b) -#define __itt_metadata_add_ptr ITTNOTIFY_NAME(metadata_add) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_metadata_add(d,x,y,z,a,b) -#define __itt_metadata_add_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_metadata_add_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @ingroup parameters - * @brief Add string metadata to an instance of a named entity. - * @param[in] domain The domain controlling the call - * @param[in] id The identifier of the instance to which the metadata is to be added, or __itt_null to add to the current task - * @param[in] key The name of the metadata - * @param[in] data The metadata itself - * @param[in] length The number of characters in the string, or -1 if the length is unknown but the string is null-terminated -*/ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -void ITTAPI __itt_metadata_str_addA(const __itt_domain *domain, __itt_id id, __itt_string_handle *key, const char *data, size_t length); -void ITTAPI __itt_metadata_str_addW(const __itt_domain *domain, __itt_id id, __itt_string_handle *key, const wchar_t *data, size_t length); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_metadata_str_add __itt_metadata_str_addW -# define __itt_metadata_str_add_ptr __itt_metadata_str_addW_ptr -#else /* UNICODE */ -# define __itt_metadata_str_add __itt_metadata_str_addA -# define __itt_metadata_str_add_ptr __itt_metadata_str_addA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -void ITTAPI __itt_metadata_str_add(const __itt_domain *domain, __itt_id id, __itt_string_handle *key, const char *data, size_t length); -#endif - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUBV(ITTAPI, void, metadata_str_addA, (const __itt_domain *domain, __itt_id id, __itt_string_handle *key, const char *data, size_t length)) -ITT_STUBV(ITTAPI, void, metadata_str_addW, (const __itt_domain *domain, __itt_id id, __itt_string_handle *key, const wchar_t *data, size_t length)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUBV(ITTAPI, void, metadata_str_add, (const __itt_domain *domain, __itt_id id, __itt_string_handle *key, const char *data, size_t length)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_metadata_str_addA(d,x,y,z,a) ITTNOTIFY_VOID_D4(metadata_str_addA,d,x,y,z,a) -#define __itt_metadata_str_addA_ptr ITTNOTIFY_NAME(metadata_str_addA) -#define __itt_metadata_str_addW(d,x,y,z,a) ITTNOTIFY_VOID_D4(metadata_str_addW,d,x,y,z,a) -#define __itt_metadata_str_addW_ptr ITTNOTIFY_NAME(metadata_str_addW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_metadata_str_add(d,x,y,z,a) ITTNOTIFY_VOID_D4(metadata_str_add,d,x,y,z,a) -#define __itt_metadata_str_add_ptr ITTNOTIFY_NAME(metadata_str_add) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN 
-#define __itt_metadata_str_addA(d,x,y,z,a)
-#define __itt_metadata_str_addA_ptr 0
-#define __itt_metadata_str_addW(d,x,y,z,a)
-#define __itt_metadata_str_addW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_metadata_str_add(d,x,y,z,a)
-#define __itt_metadata_str_add_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_metadata_str_addA_ptr 0
-#define __itt_metadata_str_addW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_metadata_str_add_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/**
- * @ingroup parameters
- * @brief Add metadata to an instance of a named entity.
- * @param[in] domain The domain controlling the call
- * @param[in] scope The scope of the instance to which the metadata is to be added
- * @param[in] key The name of the metadata
- * @param[in] type The type of the metadata
- * @param[in] count The number of elements of the given type. If count == 0, no metadata will be added.
- * @param[in] data The metadata itself
-*/
-void ITTAPI __itt_metadata_add_with_scope(const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, __itt_metadata_type type, size_t count, void *data);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, metadata_add_with_scope, (const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, __itt_metadata_type type, size_t count, void *data))
-#define __itt_metadata_add_with_scope(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(metadata_add_with_scope,d,x,y,z,a,b)
-#define __itt_metadata_add_with_scope_ptr ITTNOTIFY_NAME(metadata_add_with_scope)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_metadata_add_with_scope(d,x,y,z,a,b)
-#define __itt_metadata_add_with_scope_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_metadata_add_with_scope_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/**
- * @ingroup parameters
- * @brief Add string metadata to an instance of a named entity.
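
By way of example, a marker followed by numeric and string metadata attached to the current task, reusing the globals from the task sketch above; the keys and values are illustrative:

    void compute_with_metadata(void)
    {
        unsigned long long bytes = 4096;

        __itt_marker(g_domain, __itt_null,
                     __itt_string_handle_create("checkpoint"), __itt_scope_global);

        __itt_task_begin(g_domain, __itt_null, __itt_null, g_sh_compute);
        /* __itt_null as the id argument attaches metadata to the current task. */
        __itt_metadata_add(g_domain, __itt_null,
                           __itt_string_handle_create("buffer_bytes"),
                           __itt_metadata_u64, 1, &bytes);
        __itt_metadata_str_add(g_domain, __itt_null,
                               __itt_string_handle_create("phase"),
                               "warm-up", -1 /* null-terminated */);
        __itt_task_end(g_domain);
    }
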
- * @param[in] domain The domain controlling the call
- * @param[in] scope The scope of the instance to which the metadata is to be added
- * @param[in] key The name of the metadata
- * @param[in] data The metadata itself
- * @param[in] length The number of characters in the string, or -1 if the length is unknown but the string is null-terminated
-*/
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-void ITTAPI __itt_metadata_str_add_with_scopeA(const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, const char *data, size_t length);
-void ITTAPI __itt_metadata_str_add_with_scopeW(const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, const wchar_t *data, size_t length);
-#if defined(UNICODE) || defined(_UNICODE)
-# define __itt_metadata_str_add_with_scope __itt_metadata_str_add_with_scopeW
-# define __itt_metadata_str_add_with_scope_ptr __itt_metadata_str_add_with_scopeW_ptr
-#else /* UNICODE */
-# define __itt_metadata_str_add_with_scope __itt_metadata_str_add_with_scopeA
-# define __itt_metadata_str_add_with_scope_ptr __itt_metadata_str_add_with_scopeA_ptr
-#endif /* UNICODE */
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-void ITTAPI __itt_metadata_str_add_with_scope(const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, const char *data, size_t length);
-#endif
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-ITT_STUBV(ITTAPI, void, metadata_str_add_with_scopeA, (const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, const char *data, size_t length))
-ITT_STUBV(ITTAPI, void, metadata_str_add_with_scopeW, (const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, const wchar_t *data, size_t length))
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-ITT_STUBV(ITTAPI, void, metadata_str_add_with_scope, (const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, const char *data, size_t length))
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_metadata_str_add_with_scopeA(d,x,y,z,a) ITTNOTIFY_VOID_D4(metadata_str_add_with_scopeA,d,x,y,z,a)
-#define __itt_metadata_str_add_with_scopeA_ptr ITTNOTIFY_NAME(metadata_str_add_with_scopeA)
-#define __itt_metadata_str_add_with_scopeW(d,x,y,z,a) ITTNOTIFY_VOID_D4(metadata_str_add_with_scopeW,d,x,y,z,a)
-#define __itt_metadata_str_add_with_scopeW_ptr ITTNOTIFY_NAME(metadata_str_add_with_scopeW)
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_metadata_str_add_with_scope(d,x,y,z,a) ITTNOTIFY_VOID_D4(metadata_str_add_with_scope,d,x,y,z,a)
-#define __itt_metadata_str_add_with_scope_ptr ITTNOTIFY_NAME(metadata_str_add_with_scope)
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#else /* INTEL_NO_ITTNOTIFY_API */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_metadata_str_add_with_scopeA(d,x,y,z,a)
-#define __itt_metadata_str_add_with_scopeA_ptr 0
-#define __itt_metadata_str_add_with_scopeW(d,x,y,z,a)
-#define __itt_metadata_str_add_with_scopeW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_metadata_str_add_with_scope(d,x,y,z,a)
-#define __itt_metadata_str_add_with_scope_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_metadata_str_add_with_scopeA_ptr 0
-#define __itt_metadata_str_add_with_scopeW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_metadata_str_add_with_scope_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** @} metadata group */ - -/** - * @defgroup relations Relations - * Instances of named entities can be explicitly associated with other - * instances using instance IDs and the relationship API calls. - * - * @{ - */ - -/** - * @ingroup relations - * @brief The kind of relation between two instances is specified by the enumerated type __itt_relation. - * Relations between instances can be added with an API call. The relation - * API uses instance IDs. Relations can be added before or after the actual - * instances are created and persist independently of the instances. This - * is the motivation for having different lifetimes for instance IDs and - * the actual instances. - */ -typedef enum -{ - __itt_relation_is_unknown = 0, - __itt_relation_is_dependent_on, /**< "A is dependent on B" means that A cannot start until B completes */ - __itt_relation_is_sibling_of, /**< "A is sibling of B" means that A and B were created as a group */ - __itt_relation_is_parent_of, /**< "A is parent of B" means that A created B */ - __itt_relation_is_continuation_of, /**< "A is continuation of B" means that A assumes the dependencies of B */ - __itt_relation_is_child_of, /**< "A is child of B" means that A was created by B (inverse of is_parent_of) */ - __itt_relation_is_continued_by, /**< "A is continued by B" means that B assumes the dependencies of A (inverse of is_continuation_of) */ - __itt_relation_is_predecessor_to /**< "A is predecessor to B" means that B cannot start until A completes (inverse of is_dependent_on) */ -} __itt_relation; - -/** - * @ingroup relations - * @brief Add a relation to the current task instance. - * The current task instance is the head of the relation. - * @param[in] domain The domain controlling this call - * @param[in] relation The kind of relation - * @param[in] tail The ID for the tail of the relation - */ -void ITTAPI __itt_relation_add_to_current(const __itt_domain *domain, __itt_relation relation, __itt_id tail); - -/** - * @ingroup relations - * @brief Add a relation between two instance identifiers. 
- * @param[in] domain The domain controlling this call - * @param[in] head The ID for the head of the relation - * @param[in] relation The kind of relation - * @param[in] tail The ID for the tail of the relation - */ -void ITTAPI __itt_relation_add(const __itt_domain *domain, __itt_id head, __itt_relation relation, __itt_id tail); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, relation_add_to_current, (const __itt_domain *domain, __itt_relation relation, __itt_id tail)) -ITT_STUBV(ITTAPI, void, relation_add, (const __itt_domain *domain, __itt_id head, __itt_relation relation, __itt_id tail)) -#define __itt_relation_add_to_current(d,x,y) ITTNOTIFY_VOID_D2(relation_add_to_current,d,x,y) -#define __itt_relation_add_to_current_ptr ITTNOTIFY_NAME(relation_add_to_current) -#define __itt_relation_add(d,x,y,z) ITTNOTIFY_VOID_D3(relation_add,d,x,y,z) -#define __itt_relation_add_ptr ITTNOTIFY_NAME(relation_add) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_relation_add_to_current(d,x,y) -#define __itt_relation_add_to_current_ptr 0 -#define __itt_relation_add(d,x,y,z) -#define __itt_relation_add_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_relation_add_to_current_ptr 0 -#define __itt_relation_add_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} relations group */ - -/** @cond exclude_from_documentation */ -#pragma pack(push, 8) - -typedef struct ___itt_clock_info -{ - unsigned long long clock_freq; /*!< Clock domain frequency */ - unsigned long long clock_base; /*!< Clock domain base timestamp */ -} __itt_clock_info; - -#pragma pack(pop) -/** @endcond */ - -/** @cond exclude_from_documentation */ -typedef void (ITTAPI *__itt_get_clock_info_fn)(__itt_clock_info* clock_info, void* data); -/** @endcond */ - -/** @cond exclude_from_documentation */ -#pragma pack(push, 8) - -typedef struct ___itt_clock_domain -{ - __itt_clock_info info; /*!< Most recent clock domain info */ - __itt_get_clock_info_fn fn; /*!< Callback function pointer */ - void* fn_data; /*!< Input argument for the callback function */ - int extra1; /*!< Reserved. Must be zero */ - void* extra2; /*!< Reserved. Must be zero */ - struct ___itt_clock_domain* next; -} __itt_clock_domain; - -#pragma pack(pop) -/** @endcond */ - -/** - * @ingroup clockdomains - * @brief Create a clock domain. - * Certain applications require the capability to trace their application using - * a clock domain different than the CPU, for instance the instrumentation of events - * that occur on a GPU. - * Because the set of domains is expected to be static over the application's execution time, - * there is no mechanism to destroy a domain. - * Any domain can be accessed by any thread in the process, regardless of which thread created - * the domain. This call is thread-safe. 
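
For instance, expressing "consumer cannot start until producer completes" with the enum above; the two objects are illustrative, and g_domain is reused from the earlier task sketch:

    void link_tasks(void* producer_obj, void* consumer_obj)
    {
        __itt_id producer = __itt_id_make(producer_obj, 0);
        __itt_id consumer = __itt_id_make(consumer_obj, 0);
        __itt_id_create(g_domain, producer);
        __itt_id_create(g_domain, consumer);
        /* head = consumer, tail = producer: "consumer is dependent on producer" */
        __itt_relation_add(g_domain, consumer, __itt_relation_is_dependent_on, producer);
    }
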
- * @param[in] fn A pointer to a callback function which retrieves alternative CPU timestamps
- * @param[in] fn_data Argument for a callback function; may be NULL
- */
-__itt_clock_domain* ITTAPI __itt_clock_domain_create(__itt_get_clock_info_fn fn, void* fn_data);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUB(ITTAPI, __itt_clock_domain*, clock_domain_create, (__itt_get_clock_info_fn fn, void* fn_data))
-#define __itt_clock_domain_create ITTNOTIFY_DATA(clock_domain_create)
-#define __itt_clock_domain_create_ptr ITTNOTIFY_NAME(clock_domain_create)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_clock_domain_create(fn,fn_data) (__itt_clock_domain*)0
-#define __itt_clock_domain_create_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_clock_domain_create_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/**
- * @ingroup clockdomains
- * @brief Recalculate clock domain frequencies and clock base timestamps.
- */
-void ITTAPI __itt_clock_domain_reset(void);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, clock_domain_reset, (void))
-#define __itt_clock_domain_reset ITTNOTIFY_VOID(clock_domain_reset)
-#define __itt_clock_domain_reset_ptr ITTNOTIFY_NAME(clock_domain_reset)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_clock_domain_reset()
-#define __itt_clock_domain_reset_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_clock_domain_reset_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/**
- * @ingroup clockdomain
- * @brief Create an instance of identifier. This establishes the beginning of the lifetime of
- * an instance of the given ID in the trace. Once this lifetime starts, the ID can be used to
- * tag named entity instances in calls such as __itt_task_begin, and to specify relationships among
- * identified named entity instances, using the \ref relations APIs.
- * @param[in] domain The domain controlling the execution of this call.
- * @param[in] clock_domain The clock domain controlling the execution of this call.
- * @param[in] timestamp The user defined timestamp.
- * @param[in] id The ID to create.
- */
-void ITTAPI __itt_id_create_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id);
-
-/**
- * @ingroup clockdomain
- * @brief Destroy an instance of identifier. This ends the lifetime of the current instance of the
- * given ID value in the trace. Any relationships that are established after this lifetime ends are
- * invalid. This call must be performed before the given ID value can be reused for a different
- * named entity instance.
- * @param[in] domain The domain controlling the execution of this call.
- * @param[in] clock_domain The clock domain controlling the execution of this call.
- * @param[in] timestamp The user defined timestamp.
- * @param[in] id The ID to destroy.
- */ -void ITTAPI __itt_id_destroy_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, id_create_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id)) -ITT_STUBV(ITTAPI, void, id_destroy_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id)) -#define __itt_id_create_ex(d,x,y,z) ITTNOTIFY_VOID_D3(id_create_ex,d,x,y,z) -#define __itt_id_create_ex_ptr ITTNOTIFY_NAME(id_create_ex) -#define __itt_id_destroy_ex(d,x,y,z) ITTNOTIFY_VOID_D3(id_destroy_ex,d,x,y,z) -#define __itt_id_destroy_ex_ptr ITTNOTIFY_NAME(id_destroy_ex) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_id_create_ex(domain,clock_domain,timestamp,id) -#define __itt_id_create_ex_ptr 0 -#define __itt_id_destroy_ex(domain,clock_domain,timestamp,id) -#define __itt_id_destroy_ex_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_id_create_ex_ptr 0 -#define __itt_id_destroy_ex_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @ingroup clockdomain - * @brief Begin a task instance. - * @param[in] domain The domain for this task - * @param[in] clock_domain The clock domain controlling the execution of this call. - * @param[in] timestamp The user defined timestamp. - * @param[in] taskid The instance ID for this task instance, or __itt_null - * @param[in] parentid The parent instance to which this task instance belongs, or __itt_null - * @param[in] name The name of this task - */ -void ITTAPI __itt_task_begin_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id taskid, __itt_id parentid, __itt_string_handle* name); - -/** - * @ingroup clockdomain - * @brief Begin a task instance. - * @param[in] domain The domain for this task - * @param[in] clock_domain The clock domain controlling the execution of this call. - * @param[in] timestamp The user defined timestamp. - * @param[in] taskid The identifier for this task instance, or __itt_null - * @param[in] parentid The parent of this task, or __itt_null - * @param[in] fn The pointer to the function you are tracing - */ -void ITTAPI __itt_task_begin_fn_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id taskid, __itt_id parentid, void* fn); - -/** - * @ingroup clockdomain - * @brief End the current task instance. - * @param[in] domain The domain for this task - * @param[in] clock_domain The clock domain controlling the execution of this call. - * @param[in] timestamp The user defined timestamp. 
- */ -void ITTAPI __itt_task_end_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, task_begin_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id, __itt_id parentid, __itt_string_handle *name)) -ITT_STUBV(ITTAPI, void, task_begin_fn_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id, __itt_id parentid, void* fn)) -ITT_STUBV(ITTAPI, void, task_end_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp)) -#define __itt_task_begin_ex(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(task_begin_ex,d,x,y,z,a,b) -#define __itt_task_begin_ex_ptr ITTNOTIFY_NAME(task_begin_ex) -#define __itt_task_begin_fn_ex(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(task_begin_fn_ex,d,x,y,z,a,b) -#define __itt_task_begin_fn_ex_ptr ITTNOTIFY_NAME(task_begin_fn_ex) -#define __itt_task_end_ex(d,x,y) ITTNOTIFY_VOID_D2(task_end_ex,d,x,y) -#define __itt_task_end_ex_ptr ITTNOTIFY_NAME(task_end_ex) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_task_begin_ex(domain,clock_domain,timestamp,id,parentid,name) -#define __itt_task_begin_ex_ptr 0 -#define __itt_task_begin_fn_ex(domain,clock_domain,timestamp,id,parentid,fn) -#define __itt_task_begin_fn_ex_ptr 0 -#define __itt_task_end_ex(domain,clock_domain,timestamp) -#define __itt_task_end_ex_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_task_begin_ex_ptr 0 -#define __itt_task_begin_fn_ex_ptr 0 -#define __itt_task_end_ex_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @defgroup counters Counters - * @ingroup public - * Counters are user-defined objects with a monotonically increasing - * value. Counter values are 64-bit unsigned integers. - * Counters have names that can be displayed in - * the tools. 
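
A sketch of replaying device-side work through a custom clock domain; the frequency value, the t0/t1 device timestamps, and all names are assumptions for illustration, and g_domain comes from the earlier task sketch:

    /* Callback the collector invokes to (re)query the device clock. */
    static void ITTAPI get_device_clock(__itt_clock_info* info, void* data)
    {
        (void)data;
        info->clock_freq = 1000000000ULL; /* ticks per second; device-specific */
        info->clock_base = 0;             /* device timestamp at the time origin */
    }

    void replay_kernel(unsigned long long t0, unsigned long long t1)
    {
        static __itt_clock_domain* cd = NULL;
        if (cd == NULL) /* clock domains cannot be destroyed; create once */
            cd = __itt_clock_domain_create(get_device_clock, NULL);
        __itt_task_begin_ex(g_domain, cd, t0, __itt_null, __itt_null,
                            __itt_string_handle_create("kernel"));
        __itt_task_end_ex(g_domain, cd, t1);
    }
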
- * @{
- */
-
-/**
- * @brief Opaque structure for counter identification
- */
-/** @cond exclude_from_documentation */
-
-typedef struct ___itt_counter* __itt_counter;
-
-/**
- * @brief Create an unsigned 64-bit integer counter with the given name/domain
- *
- * After __itt_counter_create() is called, __itt_counter_inc(id), __itt_counter_inc_delta(id, delta),
- * __itt_counter_set_value(id, value_ptr) or __itt_counter_set_value_ex(id, clock_domain, timestamp, value_ptr)
- * can be used to change the value of the counter, where value_ptr is a pointer to an unsigned 64-bit integer
- *
- * The call is equivalent to __itt_counter_create_typed(name, domain, __itt_metadata_u64)
- */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-__itt_counter ITTAPI __itt_counter_createA(const char *name, const char *domain);
-__itt_counter ITTAPI __itt_counter_createW(const wchar_t *name, const wchar_t *domain);
-#if defined(UNICODE) || defined(_UNICODE)
-# define __itt_counter_create __itt_counter_createW
-# define __itt_counter_create_ptr __itt_counter_createW_ptr
-#else /* UNICODE */
-# define __itt_counter_create __itt_counter_createA
-# define __itt_counter_create_ptr __itt_counter_createA_ptr
-#endif /* UNICODE */
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-__itt_counter ITTAPI __itt_counter_create(const char *name, const char *domain);
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-ITT_STUB(ITTAPI, __itt_counter, counter_createA, (const char *name, const char *domain))
-ITT_STUB(ITTAPI, __itt_counter, counter_createW, (const wchar_t *name, const wchar_t *domain))
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-ITT_STUB(ITTAPI, __itt_counter, counter_create, (const char *name, const char *domain))
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_counter_createA ITTNOTIFY_DATA(counter_createA)
-#define __itt_counter_createA_ptr ITTNOTIFY_NAME(counter_createA)
-#define __itt_counter_createW ITTNOTIFY_DATA(counter_createW)
-#define __itt_counter_createW_ptr ITTNOTIFY_NAME(counter_createW)
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_counter_create ITTNOTIFY_DATA(counter_create)
-#define __itt_counter_create_ptr ITTNOTIFY_NAME(counter_create)
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#else /* INTEL_NO_ITTNOTIFY_API */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_counter_createA(name, domain) (__itt_counter)0
-#define __itt_counter_createA_ptr 0
-#define __itt_counter_createW(name, domain) (__itt_counter)0
-#define __itt_counter_createW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_counter_create(name, domain) (__itt_counter)0
-#define __itt_counter_create_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_counter_createA_ptr 0
-#define __itt_counter_createW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_counter_create_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/**
- * @brief Increment the unsigned 64-bit integer counter value
- *
- * Calling this function on a counter that is not an unsigned 64-bit integer has no effect
- */
-void ITTAPI __itt_counter_inc(__itt_counter id);
-
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, counter_inc, (__itt_counter id))
-#define __itt_counter_inc ITTNOTIFY_VOID(counter_inc)
-#define __itt_counter_inc_ptr ITTNOTIFY_NAME(counter_inc)
-#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_counter_inc(id) -#define __itt_counter_inc_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_counter_inc_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** - * @brief Increment the unsigned 64 bits integer counter value with x - * - * Calling this function to non-unsigned 64 bits integer counters has no effect - */ -void ITTAPI __itt_counter_inc_delta(__itt_counter id, unsigned long long value); - -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, counter_inc_delta, (__itt_counter id, unsigned long long value)) -#define __itt_counter_inc_delta ITTNOTIFY_VOID(counter_inc_delta) -#define __itt_counter_inc_delta_ptr ITTNOTIFY_NAME(counter_inc_delta) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_counter_inc_delta(id, value) -#define __itt_counter_inc_delta_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_counter_inc_delta_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Decrement the unsigned 64 bits integer counter value - * - * Calling this function to non-unsigned 64 bits integer counters has no effect - */ -void ITTAPI __itt_counter_dec(__itt_counter id); - -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, counter_dec, (__itt_counter id)) -#define __itt_counter_dec ITTNOTIFY_VOID(counter_dec) -#define __itt_counter_dec_ptr ITTNOTIFY_NAME(counter_dec) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_counter_dec(id) -#define __itt_counter_dec_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_counter_dec_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** - * @brief Decrement the unsigned 64 bits integer counter value with x - * - * Calling this function to non-unsigned 64 bits integer counters has no effect - */ -void ITTAPI __itt_counter_dec_delta(__itt_counter id, unsigned long long value); - -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, counter_dec_delta, (__itt_counter id, unsigned long long value)) -#define __itt_counter_dec_delta ITTNOTIFY_VOID(counter_dec_delta) -#define __itt_counter_dec_delta_ptr ITTNOTIFY_NAME(counter_dec_delta) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_counter_dec_delta(id, value) -#define __itt_counter_dec_delta_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_counter_dec_delta_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @ingroup counters - * @brief Increment a counter by one. - * The first call with a given name creates a counter by that name and sets its - * value to zero. Successive calls increment the counter value. - * @param[in] domain The domain controlling the call. Counter names are not domain specific. - * The domain argument is used only to enable or disable the API calls. - * @param[in] name The name of the counter - */ -void ITTAPI __itt_counter_inc_v3(const __itt_domain *domain, __itt_string_handle *name); - -/** - * @ingroup counters - * @brief Increment a counter by the value specified in delta. - * @param[in] domain The domain controlling the call. Counter names are not domain specific. - * The domain argument is used only to enable or disable the API calls. 
- * @param[in] name The name of the counter - * @param[in] delta The amount by which to increment the counter - */ -void ITTAPI __itt_counter_inc_delta_v3(const __itt_domain *domain, __itt_string_handle *name, unsigned long long delta); - -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, counter_inc_v3, (const __itt_domain *domain, __itt_string_handle *name)) -ITT_STUBV(ITTAPI, void, counter_inc_delta_v3, (const __itt_domain *domain, __itt_string_handle *name, unsigned long long delta)) -#define __itt_counter_inc_v3(d,x) ITTNOTIFY_VOID_D1(counter_inc_v3,d,x) -#define __itt_counter_inc_v3_ptr ITTNOTIFY_NAME(counter_inc_v3) -#define __itt_counter_inc_delta_v3(d,x,y) ITTNOTIFY_VOID_D2(counter_inc_delta_v3,d,x,y) -#define __itt_counter_inc_delta_v3_ptr ITTNOTIFY_NAME(counter_inc_delta_v3) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_counter_inc_v3(domain,name) -#define __itt_counter_inc_v3_ptr 0 -#define __itt_counter_inc_delta_v3(domain,name,delta) -#define __itt_counter_inc_delta_v3_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_counter_inc_v3_ptr 0 -#define __itt_counter_inc_delta_v3_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - - -/** - * @ingroup counters - * @brief Decrement a counter by one. - * The first call with a given name creates a counter by that name and sets its - * value to zero. Successive calls decrement the counter value. - * @param[in] domain The domain controlling the call. Counter names are not domain specific. - * The domain argument is used only to enable or disable the API calls. - * @param[in] name The name of the counter - */ -void ITTAPI __itt_counter_dec_v3(const __itt_domain *domain, __itt_string_handle *name); - -/** - * @ingroup counters - * @brief Decrement a counter by the value specified in delta. - * @param[in] domain The domain controlling the call. Counter names are not domain specific. - * The domain argument is used only to enable or disable the API calls. 
- * @param[in] name The name of the counter - * @param[in] delta The amount by which to decrement the counter - */ -void ITTAPI __itt_counter_dec_delta_v3(const __itt_domain *domain, __itt_string_handle *name, unsigned long long delta); - -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, counter_dec_v3, (const __itt_domain *domain, __itt_string_handle *name)) -ITT_STUBV(ITTAPI, void, counter_dec_delta_v3, (const __itt_domain *domain, __itt_string_handle *name, unsigned long long delta)) -#define __itt_counter_dec_v3(d,x) ITTNOTIFY_VOID_D1(counter_dec_v3,d,x) -#define __itt_counter_dec_v3_ptr ITTNOTIFY_NAME(counter_dec_v3) -#define __itt_counter_dec_delta_v3(d,x,y) ITTNOTIFY_VOID_D2(counter_dec_delta_v3,d,x,y) -#define __itt_counter_dec_delta_v3_ptr ITTNOTIFY_NAME(counter_dec_delta_v3) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_counter_dec_v3(domain,name) -#define __itt_counter_dec_v3_ptr 0 -#define __itt_counter_dec_delta_v3(domain,name,delta) -#define __itt_counter_dec_delta_v3_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_counter_dec_v3_ptr 0 -#define __itt_counter_dec_delta_v3_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** @} counters group */ - - -/** - * @brief Set the counter value - */ -void ITTAPI __itt_counter_set_value(__itt_counter id, void *value_ptr); - -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, counter_set_value, (__itt_counter id, void *value_ptr)) -#define __itt_counter_set_value ITTNOTIFY_VOID(counter_set_value) -#define __itt_counter_set_value_ptr ITTNOTIFY_NAME(counter_set_value) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_counter_set_value(id, value_ptr) -#define __itt_counter_set_value_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_counter_set_value_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Set the counter value - */ -void ITTAPI __itt_counter_set_value_ex(__itt_counter id, __itt_clock_domain *clock_domain, unsigned long long timestamp, void *value_ptr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, counter_set_value_ex, (__itt_counter id, __itt_clock_domain *clock_domain, unsigned long long timestamp, void *value_ptr)) -#define __itt_counter_set_value_ex ITTNOTIFY_VOID(counter_set_value_ex) -#define __itt_counter_set_value_ex_ptr ITTNOTIFY_NAME(counter_set_value_ex) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_counter_set_value_ex(id, clock_domain, timestamp, value_ptr) -#define __itt_counter_set_value_ex_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_counter_set_value_ex_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Create a typed counter with given name/domain - * - * After __itt_counter_create_typed() is called, __itt_counter_inc(id), __itt_counter_inc_delta(id, delta), - * __itt_counter_set_value(id, value_ptr) or __itt_counter_set_value_ex(id, clock_domain, timestamp, value_ptr) - * can be used to change the value of the counter - */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -__itt_counter ITTAPI __itt_counter_create_typedA(const char *name, const char *domain, __itt_metadata_type type); -__itt_counter ITTAPI __itt_counter_create_typedW(const wchar_t *name, const wchar_t *domain, __itt_metadata_type type); -#if defined(UNICODE) || defined(_UNICODE) -# define 
__itt_counter_create_typed __itt_counter_create_typedW -# define __itt_counter_create_typed_ptr __itt_counter_create_typedW_ptr -#else /* UNICODE */ -# define __itt_counter_create_typed __itt_counter_create_typedA -# define __itt_counter_create_typed_ptr __itt_counter_create_typedA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -__itt_counter ITTAPI __itt_counter_create_typed(const char *name, const char *domain, __itt_metadata_type type); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(ITTAPI, __itt_counter, counter_create_typedA, (const char *name, const char *domain, __itt_metadata_type type)) -ITT_STUB(ITTAPI, __itt_counter, counter_create_typedW, (const wchar_t *name, const wchar_t *domain, __itt_metadata_type type)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(ITTAPI, __itt_counter, counter_create_typed, (const char *name, const char *domain, __itt_metadata_type type)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_counter_create_typedA ITTNOTIFY_DATA(counter_create_typedA) -#define __itt_counter_create_typedA_ptr ITTNOTIFY_NAME(counter_create_typedA) -#define __itt_counter_create_typedW ITTNOTIFY_DATA(counter_create_typedW) -#define __itt_counter_create_typedW_ptr ITTNOTIFY_NAME(counter_create_typedW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_counter_create_typed ITTNOTIFY_DATA(counter_create_typed) -#define __itt_counter_create_typed_ptr ITTNOTIFY_NAME(counter_create_typed) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_counter_create_typedA(name, domain, type) -#define __itt_counter_create_typedA_ptr 0 -#define __itt_counter_create_typedW(name, domain, type) -#define __itt_counter_create_typedW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_counter_create_typed(name, domain, type) -#define __itt_counter_create_typed_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_counter_create_typedA_ptr 0 -#define __itt_counter_create_typedW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_counter_create_typed_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Destroy the counter identified by the pointer previously returned by __itt_counter_create() or - * __itt_counter_create_typed() - */ -void ITTAPI __itt_counter_destroy(__itt_counter id); - -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, counter_destroy, (__itt_counter id)) -#define __itt_counter_destroy ITTNOTIFY_VOID(counter_destroy) -#define __itt_counter_destroy_ptr ITTNOTIFY_NAME(counter_destroy) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_counter_destroy(id) -#define __itt_counter_destroy_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_counter_destroy_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} counters group */ - -/** - * @ingroup markers - * @brief Create a marker instance. - * @param[in] domain The domain for this marker - * @param[in] clock_domain The clock domain controlling the execution of this call. - * @param[in] timestamp The user defined timestamp. 
- * @param[in] id The instance ID for this marker, or __itt_null - * @param[in] name The name for this marker - * @param[in] scope The scope for this marker - */ -void ITTAPI __itt_marker_ex(const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id, __itt_string_handle *name, __itt_scope scope); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, marker_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id, __itt_string_handle *name, __itt_scope scope)) -#define __itt_marker_ex(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(marker_ex,d,x,y,z,a,b) -#define __itt_marker_ex_ptr ITTNOTIFY_NAME(marker_ex) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_marker_ex(domain,clock_domain,timestamp,id,name,scope) -#define __itt_marker_ex_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_marker_ex_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @ingroup clockdomain - * @brief Add a relation to the current task instance. - * The current task instance is the head of the relation. - * @param[in] domain The domain controlling this call - * @param[in] clock_domain The clock domain controlling the execution of this call. - * @param[in] timestamp The user defined timestamp. - * @param[in] relation The kind of relation - * @param[in] tail The ID for the tail of the relation - */ -void ITTAPI __itt_relation_add_to_current_ex(const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_relation relation, __itt_id tail); - -/** - * @ingroup clockdomain - * @brief Add a relation between two instance identifiers. - * @param[in] domain The domain controlling this call - * @param[in] clock_domain The clock domain controlling the execution of this call. - * @param[in] timestamp The user defined timestamp. 
- * @param[in] head The ID for the head of the relation - * @param[in] relation The kind of relation - * @param[in] tail The ID for the tail of the relation - */ -void ITTAPI __itt_relation_add_ex(const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id head, __itt_relation relation, __itt_id tail); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, relation_add_to_current_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_relation relation, __itt_id tail)) -ITT_STUBV(ITTAPI, void, relation_add_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id head, __itt_relation relation, __itt_id tail)) -#define __itt_relation_add_to_current_ex(d,x,y,z,a) ITTNOTIFY_VOID_D4(relation_add_to_current_ex,d,x,y,z,a) -#define __itt_relation_add_to_current_ex_ptr ITTNOTIFY_NAME(relation_add_to_current_ex) -#define __itt_relation_add_ex(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(relation_add_ex,d,x,y,z,a,b) -#define __itt_relation_add_ex_ptr ITTNOTIFY_NAME(relation_add_ex) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_relation_add_to_current_ex(domain,clock_domain,timestame,relation,tail) -#define __itt_relation_add_to_current_ex_ptr 0 -#define __itt_relation_add_ex(domain,clock_domain,timestamp,head,relation,tail) -#define __itt_relation_add_ex_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_relation_add_to_current_ex_ptr 0 -#define __itt_relation_add_ex_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** @cond exclude_from_documentation */ -typedef enum ___itt_track_group_type -{ - __itt_track_group_type_normal = 0 -} __itt_track_group_type; -/** @endcond */ - -/** @cond exclude_from_documentation */ -#pragma pack(push, 8) - -typedef struct ___itt_track_group -{ - __itt_string_handle* name; /*!< Name of the track group */ - struct ___itt_track* track; /*!< List of child tracks */ - __itt_track_group_type tgtype; /*!< Type of the track group */ - int extra1; /*!< Reserved. Must be zero */ - void* extra2; /*!< Reserved. Must be zero */ - struct ___itt_track_group* next; -} __itt_track_group; - -#pragma pack(pop) -/** @endcond */ - -/** - * @brief Placeholder for custom track types. Currently, "normal" custom track - * is the only available track type. - */ -typedef enum ___itt_track_type -{ - __itt_track_type_normal = 0 -#ifdef INTEL_ITTNOTIFY_API_PRIVATE - , __itt_track_type_queue -#endif /* INTEL_ITTNOTIFY_API_PRIVATE */ -} __itt_track_type; - -/** @cond exclude_from_documentation */ -#pragma pack(push, 8) - -typedef struct ___itt_track -{ - __itt_string_handle* name; /*!< Name of the track group */ - __itt_track_group* group; /*!< Parent group to a track */ - __itt_track_type ttype; /*!< Type of the track */ - int extra1; /*!< Reserved. Must be zero */ - void* extra2; /*!< Reserved. Must be zero */ - struct ___itt_track* next; -} __itt_track; - -#pragma pack(pop) -/** @endcond */ - -/** - * @brief Create logical track group. 
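- *
- * A minimal usage sketch (the group/track names are illustrative assumptions,
- * and the char variants of the create calls are assumed):
- * @code
- *     __itt_track_group* g = __itt_track_group_create(
- *         __itt_string_handle_create("io"), __itt_track_group_type_normal);
- *     __itt_track* t = __itt_track_create(
- *         g, __itt_string_handle_create("disk"), __itt_track_type_normal);
- *     __itt_set_track(t); // subsequent events are placed on this track
- * @endcode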
- */ -__itt_track_group* ITTAPI __itt_track_group_create(__itt_string_handle* name, __itt_track_group_type track_group_type); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(ITTAPI, __itt_track_group*, track_group_create, (__itt_string_handle* name, __itt_track_group_type track_group_type)) -#define __itt_track_group_create ITTNOTIFY_DATA(track_group_create) -#define __itt_track_group_create_ptr ITTNOTIFY_NAME(track_group_create) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_track_group_create(name) (__itt_track_group*)0 -#define __itt_track_group_create_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_track_group_create_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Create logical track. - */ -__itt_track* ITTAPI __itt_track_create(__itt_track_group* track_group, __itt_string_handle* name, __itt_track_type track_type); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(ITTAPI, __itt_track*, track_create, (__itt_track_group* track_group,__itt_string_handle* name, __itt_track_type track_type)) -#define __itt_track_create ITTNOTIFY_DATA(track_create) -#define __itt_track_create_ptr ITTNOTIFY_NAME(track_create) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_track_create(track_group,name,track_type) (__itt_track*)0 -#define __itt_track_create_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_track_create_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Set the logical track. - */ -void ITTAPI __itt_set_track(__itt_track* track); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, set_track, (__itt_track *track)) -#define __itt_set_track ITTNOTIFY_VOID(set_track) -#define __itt_set_track_ptr ITTNOTIFY_NAME(set_track) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_set_track(track) -#define __itt_set_track_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_set_track_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/* ========================================================================== */ -/** @cond exclude_from_gpa_documentation */ -/** - * @defgroup events Events - * @ingroup public - * Events group - * @{ - */ -/** @brief user event type */ -typedef int __itt_event; - -/** - * @brief Create an event notification - * @note name or namelen being null/name and namelen not matching, user event feature not enabled - * @return non-zero event identifier upon success and __itt_err otherwise - */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -__itt_event LIBITTAPI __itt_event_createA(const char *name, int namelen); -__itt_event LIBITTAPI __itt_event_createW(const wchar_t *name, int namelen); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_event_create __itt_event_createW -# define __itt_event_create_ptr __itt_event_createW_ptr -#else -# define __itt_event_create __itt_event_createA -# define __itt_event_create_ptr __itt_event_createA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -__itt_event LIBITTAPI __itt_event_create(const char *name, int namelen); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(LIBITTAPI, __itt_event, 
event_createA, (const char *name, int namelen)) -ITT_STUB(LIBITTAPI, __itt_event, event_createW, (const wchar_t *name, int namelen)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(LIBITTAPI, __itt_event, event_create, (const char *name, int namelen)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_event_createA ITTNOTIFY_DATA(event_createA) -#define __itt_event_createA_ptr ITTNOTIFY_NAME(event_createA) -#define __itt_event_createW ITTNOTIFY_DATA(event_createW) -#define __itt_event_createW_ptr ITTNOTIFY_NAME(event_createW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_event_create ITTNOTIFY_DATA(event_create) -#define __itt_event_create_ptr ITTNOTIFY_NAME(event_create) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_event_createA(name, namelen) (__itt_event)0 -#define __itt_event_createA_ptr 0 -#define __itt_event_createW(name, namelen) (__itt_event)0 -#define __itt_event_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_event_create(name, namelen) (__itt_event)0 -#define __itt_event_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_event_createA_ptr 0 -#define __itt_event_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_event_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Record an event occurrence. - * @return __itt_err upon failure (invalid event id/user event feature not enabled) - */ -int LIBITTAPI __itt_event_start(__itt_event event); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(LIBITTAPI, int, event_start, (__itt_event event)) -#define __itt_event_start ITTNOTIFY_DATA(event_start) -#define __itt_event_start_ptr ITTNOTIFY_NAME(event_start) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_event_start(event) (int)0 -#define __itt_event_start_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_event_start_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Record an event end occurrence. - * @note It is optional if events do not have durations. 
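- *
- * A minimal usage sketch (the event name is an illustrative assumption; the
- * char variant of __itt_event_create is assumed):
- * @code
- *     __itt_event e = __itt_event_create("frame", 5);
- *     __itt_event_start(e);
- *     // ... measured activity ...
- *     __itt_event_end(e);
- * @endcode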
- * @return __itt_err upon failure (invalid event id/user event feature not enabled) - */ -int LIBITTAPI __itt_event_end(__itt_event event); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(LIBITTAPI, int, event_end, (__itt_event event)) -#define __itt_event_end ITTNOTIFY_DATA(event_end) -#define __itt_event_end_ptr ITTNOTIFY_NAME(event_end) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_event_end(event) (int)0 -#define __itt_event_end_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_event_end_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} events group */ - - -/** - * @defgroup arrays Arrays Visualizer - * @ingroup public - * Visualize arrays - * @{ - */ - -/** - * @enum __itt_av_data_type - * @brief Defines types of arrays data (for C/C++ intrinsic types) - */ -typedef enum -{ - __itt_e_first = 0, - __itt_e_char = 0, /* 1-byte integer */ - __itt_e_uchar, /* 1-byte unsigned integer */ - __itt_e_int16, /* 2-byte integer */ - __itt_e_uint16, /* 2-byte unsigned integer */ - __itt_e_int32, /* 4-byte integer */ - __itt_e_uint32, /* 4-byte unsigned integer */ - __itt_e_int64, /* 8-byte integer */ - __itt_e_uint64, /* 8-byte unsigned integer */ - __itt_e_float, /* 4-byte floating */ - __itt_e_double, /* 8-byte floating */ - __itt_e_last = __itt_e_double -} __itt_av_data_type; - -/** - * @brief Save an array data to a file. - * Output format is defined by the file extension. The csv and bmp formats are supported (bmp - for 2-dimensional array only). - * @param[in] data - pointer to the array data - * @param[in] rank - the rank of the array - * @param[in] dimensions - pointer to an array of integers, which specifies the array dimensions. - * The size of dimensions must be equal to the rank - * @param[in] type - the type of the array, specified as one of the __itt_av_data_type values (for intrinsic types) - * @param[in] filePath - the file path; the output format is defined by the file extension - * @param[in] columnOrder - defines how the array is stored in the linear memory. - * It should be 1 for column-major order (e.g. in FORTRAN) or 0 - for row-major order (e.g. in C). 
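- *
- * A minimal usage sketch (array contents and file name are illustrative assumptions):
- * @code
- *     float m[2][3] = {{1, 2, 3}, {4, 5, 6}};
- *     int dims[2] = {2, 3};
- *     __itt_av_save(m, 2, dims, __itt_e_float, "m.csv", 0); // row-major C array
- * @endcode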
- */
-
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-int ITTAPI __itt_av_saveA(void *data, int rank, const int *dimensions, int type, const char *filePath, int columnOrder);
-int ITTAPI __itt_av_saveW(void *data, int rank, const int *dimensions, int type, const wchar_t *filePath, int columnOrder);
-#if defined(UNICODE) || defined(_UNICODE)
-# define __itt_av_save __itt_av_saveW
-# define __itt_av_save_ptr __itt_av_saveW_ptr
-#else /* UNICODE */
-# define __itt_av_save __itt_av_saveA
-# define __itt_av_save_ptr __itt_av_saveA_ptr
-#endif /* UNICODE */
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-int ITTAPI __itt_av_save(void *data, int rank, const int *dimensions, int type, const char *filePath, int columnOrder);
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-ITT_STUB(ITTAPI, int, av_saveA, (void *data, int rank, const int *dimensions, int type, const char *filePath, int columnOrder))
-ITT_STUB(ITTAPI, int, av_saveW, (void *data, int rank, const int *dimensions, int type, const wchar_t *filePath, int columnOrder))
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-ITT_STUB(ITTAPI, int, av_save, (void *data, int rank, const int *dimensions, int type, const char *filePath, int columnOrder))
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_av_saveA ITTNOTIFY_DATA(av_saveA)
-#define __itt_av_saveA_ptr ITTNOTIFY_NAME(av_saveA)
-#define __itt_av_saveW ITTNOTIFY_DATA(av_saveW)
-#define __itt_av_saveW_ptr ITTNOTIFY_NAME(av_saveW)
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_av_save ITTNOTIFY_DATA(av_save)
-#define __itt_av_save_ptr ITTNOTIFY_NAME(av_save)
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#else /* INTEL_NO_ITTNOTIFY_API */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_av_saveA(data, rank, dimensions, type, filePath, columnOrder) (int)0
-#define __itt_av_saveA_ptr 0
-#define __itt_av_saveW(data, rank, dimensions, type, filePath, columnOrder) (int)0
-#define __itt_av_saveW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_av_save(data, rank, dimensions, type, filePath, columnOrder) (int)0
-#define __itt_av_save_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_av_saveA_ptr 0
-#define __itt_av_saveW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_av_save_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-void ITTAPI __itt_enable_attach(void);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, enable_attach, (void))
-#define __itt_enable_attach ITTNOTIFY_VOID(enable_attach)
-#define __itt_enable_attach_ptr ITTNOTIFY_NAME(enable_attach)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_enable_attach()
-#define __itt_enable_attach_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_enable_attach_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/** @cond exclude_from_gpa_documentation */
-
-/** @} arrays group */
-
-/** @endcond */
-
-/**
- * @brief Module load notification
- * This API is used to report the necessary information when bypassing the default system loader.
- * Notification should be done immediately after this module is loaded into process memory.
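- *
- * A minimal sketch for a custom loader (base address, size, and path are
- * illustrative assumptions):
- * @code
- *     __itt_module_load(base, (char*)base + size, "/opt/app/plugin.so");
- * @endcode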
- * @param[in] start_addr - module start address - * @param[in] end_addr - module end address - * @param[in] path - file system full path to the module - */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -void ITTAPI __itt_module_loadA(void *start_addr, void *end_addr, const char *path); -void ITTAPI __itt_module_loadW(void *start_addr, void *end_addr, const wchar_t *path); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_module_load __itt_module_loadW -# define __itt_module_load_ptr __itt_module_loadW_ptr -#else /* UNICODE */ -# define __itt_module_load __itt_module_loadA -# define __itt_module_load_ptr __itt_module_loadA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -void ITTAPI __itt_module_load(void *start_addr, void *end_addr, const char *path); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(ITTAPI, void, module_loadA, (void *start_addr, void *end_addr, const char *path)) -ITT_STUB(ITTAPI, void, module_loadW, (void *start_addr, void *end_addr, const wchar_t *path)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(ITTAPI, void, module_load, (void *start_addr, void *end_addr, const char *path)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_module_loadA ITTNOTIFY_VOID(module_loadA) -#define __itt_module_loadA_ptr ITTNOTIFY_NAME(module_loadA) -#define __itt_module_loadW ITTNOTIFY_VOID(module_loadW) -#define __itt_module_loadW_ptr ITTNOTIFY_NAME(module_loadW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_module_load ITTNOTIFY_VOID(module_load) -#define __itt_module_load_ptr ITTNOTIFY_NAME(module_load) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_module_loadA(start_addr, end_addr, path) -#define __itt_module_loadA_ptr 0 -#define __itt_module_loadW(start_addr, end_addr, path) -#define __itt_module_loadW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_module_load(start_addr, end_addr, path) -#define __itt_module_load_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_module_loadA_ptr 0 -#define __itt_module_loadW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_module_load_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Report module unload - * This API is used to report necessary information in case of bypassing default system loader. - * Notification should be done just before the module is unloaded from process memory. 
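- *
- * E.g., a custom loader that mapped a module at address base would report
- * (illustrative sketch):
- * @code
- *     __itt_module_unload(base);
- * @endcode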
- * @param[in] addr - base address of the loaded module
- */
-void ITTAPI __itt_module_unload(void *addr);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, module_unload, (void *addr))
-#define __itt_module_unload ITTNOTIFY_VOID(module_unload)
-#define __itt_module_unload_ptr ITTNOTIFY_NAME(module_unload)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_module_unload(addr)
-#define __itt_module_unload_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_module_unload_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/** @cond exclude_from_documentation */
-typedef enum
-{
-    __itt_module_type_unknown = 0,
-    __itt_module_type_elf,
-    __itt_module_type_coff
-} __itt_module_type;
-/** @endcond */
-
-/** @cond exclude_from_documentation */
-typedef enum
-{
-    itt_section_type_unknown,
-    itt_section_type_bss,  /* notifies that the section contains uninitialized data. These are the relevant section types and the modules that contain them:
-                            * ELF module: SHT_NOBITS section type
-                            * COFF module: IMAGE_SCN_CNT_UNINITIALIZED_DATA section type
-                            */
-    itt_section_type_data, /* notifies that the section contains initialized data. These are the relevant section types and the modules that contain them:
-                            * ELF module: SHT_PROGBITS section type
-                            * COFF module: IMAGE_SCN_CNT_INITIALIZED_DATA section type
-                            */
-    itt_section_type_text  /* notifies that the section contains executable code. These are the relevant section types and the modules that contain them:
-                            * ELF module: SHT_PROGBITS section type
-                            * COFF module: IMAGE_SCN_CNT_CODE section type
-                            */
-} __itt_section_type;
-/** @endcond */
-
-/**
- * @hideinitializer
- * @brief Bit mask that detects a section attribute indicating whether a section can be executed as code.
- * These are the relevant section attributes and the modules that contain them:
- * ELF module: PF_X section attribute
- * COFF module: IMAGE_SCN_MEM_EXECUTE attribute
- */
-#define __itt_section_exec 0x20000000
-
-/**
- * @hideinitializer
- * @brief Bit mask that detects a section attribute indicating whether a section can be read.
- * These are the relevant section attributes and the modules that contain them:
- * ELF module: PF_R attribute
- * COFF module: IMAGE_SCN_MEM_READ attribute
- */
-#define __itt_section_read 0x40000000
-
-/**
- * @hideinitializer
- * @brief Bit mask that detects a section attribute indicating whether a section can be written to.
- * These are the relevant section attributes and the modules that contain them:
- * ELF module: PF_W attribute
- * COFF module: IMAGE_SCN_MEM_WRITE attribute
- */
-#define __itt_section_write 0x80000000
-
-/** @cond exclude_from_documentation */
-#pragma pack(push, 8)
-
-typedef struct ___itt_section_info
-{
-    const char* name;        /*!< Section name in UTF8 */
-    __itt_section_type type; /*!< Section content and semantics description */
-    size_t flags;            /*!< Section bit flags that describe attributes using bit mask
-                              * Zero if disabled, non-zero if enabled
-                              */
-    void* start_addr;        /*!< Section load (relocated) start address */
-    size_t size;             /*!< Section size */
-    size_t file_offset;      /*!< Section file offset */
-} __itt_section_info;
-
-#pragma pack(pop)
-/** @endcond */
-
-/** @cond exclude_from_documentation */
-#pragma pack(push, 8)
-
-typedef struct ___itt_module_object
-{
-    unsigned int version;          /*!< API version */
-    __itt_id module_id;            /*!< Unique identifier. This is unchanged for sections that belong to the same module */
-    __itt_module_type module_type; /*!< Binary module format */
-    const char* module_name;       /*!< Unique module name or path to module in UTF8
-                                    * Contains the module name when module_buffer and module_size exist
-                                    * Contains the module path when module_buffer and module_size are absent
-                                    * module_name remains the same for a given module_id
-                                    */
-    void* module_buffer;           /*!< Module buffer content */
-    size_t module_size;            /*!< Module buffer size */
-    /*!< If module_buffer and module_size exist, the binary module is dumped onto the system.
-     * If module_buffer and module_size do not exist,
-     * the binary module already exists on the system.
-     * The module_name parameter contains the path to the module.
-     */
-    __itt_section_info* section_array; /*!< Reference to section information */
-    size_t section_number;
-} __itt_module_object;
-
-#pragma pack(pop)
-/** @endcond */
-
-/**
- * @brief Load module content and its loaded (relocated) sections.
- * This API is useful to save a module, or to specify its location on the system, and to report information about the loaded sections.
- * The target module is saved on the system if the module buffer content and size are available.
- * If the module buffer content and size are unavailable, the module name contains the path to the existing binary module.
- * @param[in] module_obj - provides module and section information, along with unique module identifiers (name, module ID)
- * which bind the binary module to particular sections.
- */
-void ITTAPI __itt_module_load_with_sections(__itt_module_object* module_obj);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, module_load_with_sections, (__itt_module_object* module_obj))
-#define __itt_module_load_with_sections ITTNOTIFY_VOID(module_load_with_sections)
-#define __itt_module_load_with_sections_ptr ITTNOTIFY_NAME(module_load_with_sections)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_module_load_with_sections(module_obj)
-#define __itt_module_load_with_sections_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_module_load_with_sections_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/**
- * @brief Unload a module and its loaded (relocated) sections.
- * This API notifies that the module and its sections were unloaded.
- * @param[in] module_obj - provides module and section information, along with unique module identifiers (name, module ID)
- * which bind the binary module to particular sections.
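- *
- * A minimal sketch pairing the load/unload notifications (all field values are
- * illustrative assumptions; a real loader fills them from its own bookkeeping):
- * @code
- *     __itt_module_object obj = {0};
- *     obj.module_type = __itt_module_type_elf;
- *     obj.module_name = "/opt/app/plugin.so"; // module already present on disk
- *     obj.section_array = sections;           // loader-provided section table
- *     obj.section_number = n_sections;
- *     __itt_module_load_with_sections(&obj);
- *     // ... just before unmapping the module:
- *     __itt_module_unload_with_sections(&obj);
- * @endcode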
- */ -void ITTAPI __itt_module_unload_with_sections(__itt_module_object* module_obj); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, module_unload_with_sections, (__itt_module_object* module_obj)) -#define __itt_module_unload_with_sections ITTNOTIFY_VOID(module_unload_with_sections) -#define __itt_module_unload_with_sections_ptr ITTNOTIFY_NAME(module_unload_with_sections) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_module_unload_with_sections(module_obj) -#define __itt_module_unload_with_sections_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_module_unload_with_sections_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** @cond exclude_from_documentation */ -#pragma pack(push, 8) - -typedef struct ___itt_histogram -{ - const __itt_domain* domain; /*!< Domain of the histogram*/ - const char* nameA; /*!< Name of the histogram */ -#if defined(UNICODE) || defined(_UNICODE) - const wchar_t* nameW; -#else /* UNICODE || _UNICODE */ - void* nameW; -#endif /* UNICODE || _UNICODE */ - __itt_metadata_type x_type; /*!< Type of the histogram X axis */ - __itt_metadata_type y_type; /*!< Type of the histogram Y axis */ - int extra1; /*!< Reserved to the runtime */ - void* extra2; /*!< Reserved to the runtime */ - struct ___itt_histogram* next; -} __itt_histogram; - -#pragma pack(pop) -/** @endcond */ - -/** - * @brief Create a typed histogram instance with given name/domain. - * @param[in] domain The domain controlling the call. - * @param[in] name The name of the histogram. - * @param[in] x_type The type of the X axis in histogram (may be 0 to calculate batch statistics). - * @param[in] y_type The type of the Y axis in histogram. 
-*/ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -__itt_histogram* ITTAPI __itt_histogram_createA(const __itt_domain* domain, const char* name, __itt_metadata_type x_type, __itt_metadata_type y_type); -__itt_histogram* ITTAPI __itt_histogram_createW(const __itt_domain* domain, const wchar_t* name, __itt_metadata_type x_type, __itt_metadata_type y_type); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_histogram_create __itt_histogram_createW -# define __itt_histogram_create_ptr __itt_histogram_createW_ptr -#else /* UNICODE */ -# define __itt_histogram_create __itt_histogram_createA -# define __itt_histogram_create_ptr __itt_histogram_createA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -__itt_histogram* ITTAPI __itt_histogram_create(const __itt_domain* domain, const char* name, __itt_metadata_type x_type, __itt_metadata_type y_type); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(ITTAPI, __itt_histogram*, histogram_createA, (const __itt_domain* domain, const char* name, __itt_metadata_type x_type, __itt_metadata_type y_type)) -ITT_STUB(ITTAPI, __itt_histogram*, histogram_createW, (const __itt_domain* domain, const wchar_t* name, __itt_metadata_type x_type, __itt_metadata_type y_type)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(ITTAPI, __itt_histogram*, histogram_create, (const __itt_domain* domain, const char* name, __itt_metadata_type x_type, __itt_metadata_type y_type)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_histogram_createA ITTNOTIFY_DATA(histogram_createA) -#define __itt_histogram_createA_ptr ITTNOTIFY_NAME(histogram_createA) -#define __itt_histogram_createW ITTNOTIFY_DATA(histogram_createW) -#define __itt_histogram_createW_ptr ITTNOTIFY_NAME(histogram_createW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_histogram_create ITTNOTIFY_DATA(histogram_create) -#define __itt_histogram_create_ptr ITTNOTIFY_NAME(histogram_create) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_histogram_createA(domain, name, x_type, y_type) (__itt_histogram*)0 -#define __itt_histogram_createA_ptr 0 -#define __itt_histogram_createW(domain, name, x_type, y_type) (__itt_histogram*)0 -#define __itt_histogram_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_histogram_create(domain, name, x_type, y_type) (__itt_histogram*)0 -#define __itt_histogram_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_histogram_createA_ptr 0 -#define __itt_histogram_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_histogram_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Submit statistics for a histogram instance. - * @param[in] hist Pointer to the histogram instance to which the histogram statistic is to be dumped. - * @param[in] length The number of elements in dumped axis data array. - * @param[in] x_data The X axis dumped data itself (may be NULL to calculate batch statistics). - * @param[in] y_data The Y axis dumped data itself. 
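- *
- * A minimal usage sketch (assumes d is a previously created __itt_domain
- * pointer and the char variant of __itt_histogram_create; data values are
- * illustrative):
- * @code
- *     __itt_histogram* h = __itt_histogram_create(d, "latency", __itt_metadata_u64, __itt_metadata_u64);
- *     unsigned long long x[3] = {1, 2, 4}, y[3] = {10, 7, 3};
- *     __itt_histogram_submit(h, 3, x, y);
- * @endcode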
-*/
-void ITTAPI __itt_histogram_submit(__itt_histogram* hist, size_t length, void* x_data, void* y_data);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, histogram_submit, (__itt_histogram* hist, size_t length, void* x_data, void* y_data))
-#define __itt_histogram_submit ITTNOTIFY_VOID(histogram_submit)
-#define __itt_histogram_submit_ptr ITTNOTIFY_NAME(histogram_submit)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_histogram_submit(hist, length, x_data, y_data)
-#define __itt_histogram_submit_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_histogram_submit_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-
-/**
-* @brief Returns the current collection state at the moment of the call
-* @return the collection state as an enum __itt_collection_state
-*/
-__itt_collection_state __itt_get_collection_state(void);
-
-/**
-* @brief Releases the resources allocated by the static part of the ITT API.
-* This API should be called from the library destructor.
-* @return void
-*/
-void __itt_release_resources(void);
-/** @endcond */
-
-#ifdef __cplusplus
-}
-#endif /* __cplusplus */
-
-#endif /* _ITTNOTIFY_H_ */
-
-#ifdef INTEL_ITTNOTIFY_API_PRIVATE
-
-#ifndef _ITTNOTIFY_PRIVATE_
-#define _ITTNOTIFY_PRIVATE_
-
-#ifdef __cplusplus
-extern "C" {
-#endif /* __cplusplus */
-
-/**
- * @ingroup clockdomain
- * @brief Begin an overlapped task instance.
- * @param[in] domain The domain for this task
- * @param[in] clock_domain The clock domain controlling the execution of this call.
- * @param[in] timestamp The user defined timestamp.
- * @param[in] taskid The identifier for this task instance, *cannot* be __itt_null.
- * @param[in] parentid The parent of this task, or __itt_null.
- * @param[in] name The name of this task.
- */
-void ITTAPI __itt_task_begin_overlapped_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id taskid, __itt_id parentid, __itt_string_handle* name);
-
-/**
- * @ingroup clockdomain
- * @brief End an overlapped task instance.
- * @param[in] domain The domain for this task
- * @param[in] clock_domain The clock domain controlling the execution of this call.
- * @param[in] timestamp The user defined timestamp.
- * @param[in] taskid Explicit ID of the finished task
- */
-void ITTAPI __itt_task_end_overlapped_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id taskid);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, task_begin_overlapped_ex, (const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id taskid, __itt_id parentid, __itt_string_handle* name))
-ITT_STUBV(ITTAPI, void, task_end_overlapped_ex, (const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id taskid))
-#define __itt_task_begin_overlapped_ex(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(task_begin_overlapped_ex,d,x,y,z,a,b)
-#define __itt_task_begin_overlapped_ex_ptr ITTNOTIFY_NAME(task_begin_overlapped_ex)
-#define __itt_task_end_overlapped_ex(d,x,y,z) ITTNOTIFY_VOID_D3(task_end_overlapped_ex,d,x,y,z)
-#define __itt_task_end_overlapped_ex_ptr ITTNOTIFY_NAME(task_end_overlapped_ex)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_task_begin_overlapped_ex(domain,clock_domain,timestamp,taskid,parentid,name)
-#define __itt_task_begin_overlapped_ex_ptr 0
-#define __itt_task_end_overlapped_ex(domain,clock_domain,timestamp,taskid)
-#define __itt_task_end_overlapped_ex_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_task_begin_overlapped_ex_ptr 0
-#define __itt_task_end_overlapped_ptr 0
-#define __itt_task_end_overlapped_ex_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/**
- * @defgroup marks_internal Marks
- * @ingroup internal
- * Marks group
- * @warning Internal API:
- *   - It is not shipped outside of Intel
- *   - It is delivered to internal Intel teams using e-mail or SVN access only
- * @{
- */
-/** @brief user mark type */
-typedef int __itt_mark_type;
-
-/**
- * @brief Creates a user mark type with the specified name using a char or Unicode string.
- * @param[in] name - name of the mark to create
- * @return Returns a handle to the mark type
- */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-__itt_mark_type ITTAPI __itt_mark_createA(const char *name);
-__itt_mark_type ITTAPI __itt_mark_createW(const wchar_t *name);
-#if defined(UNICODE) || defined(_UNICODE)
-# define __itt_mark_create __itt_mark_createW
-# define __itt_mark_create_ptr __itt_mark_createW_ptr
-#else /* UNICODE */
-# define __itt_mark_create __itt_mark_createA
-# define __itt_mark_create_ptr __itt_mark_createA_ptr
-#endif /* UNICODE */
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-__itt_mark_type ITTAPI __itt_mark_create(const char *name);
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-ITT_STUB(ITTAPI, __itt_mark_type, mark_createA, (const char *name))
-ITT_STUB(ITTAPI, __itt_mark_type, mark_createW, (const wchar_t *name))
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-ITT_STUB(ITTAPI, __itt_mark_type, mark_create, (const char *name))
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_mark_createA ITTNOTIFY_DATA(mark_createA)
-#define __itt_mark_createA_ptr ITTNOTIFY_NAME(mark_createA)
-#define __itt_mark_createW ITTNOTIFY_DATA(mark_createW)
-#define __itt_mark_createW_ptr ITTNOTIFY_NAME(mark_createW)
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_mark_create ITTNOTIFY_DATA(mark_create)
-#define __itt_mark_create_ptr ITTNOTIFY_NAME(mark_create)
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#else /* INTEL_NO_ITTNOTIFY_API */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_mark_createA(name) (__itt_mark_type)0
-#define __itt_mark_createA_ptr 0
-#define __itt_mark_createW(name) (__itt_mark_type)0
-#define __itt_mark_createW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_mark_create(name) (__itt_mark_type)0
-#define __itt_mark_create_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_mark_createA_ptr 0
-#define __itt_mark_createW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_mark_create_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/**
- * @brief Creates a "discrete" user mark of the specified type with an optional string parameter (char or Unicode).
- *
- * - A "discrete" mark is placed into the collection results on success. It appears in overtime view(s) as a special tick sign.
- * - The call is "synchronous" - the function returns after the mark is actually added to the results.
- * - This function is useful, for example, to mark different phases of an application
- *   (the beginning of the next mark automatically means the end of the current region).
- * - Can be used together with "continuous" marks (see below) within the same collection session
- * @param[in] mt - mark, created by the __itt_mark_create(const char* name) function
- * @param[in] parameter - string parameter of the mark
- * @return Returns zero value in case of success, non-zero value otherwise.
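- *
- * A minimal usage sketch (phase names are illustrative assumptions):
- * @code
- *     __itt_mark_type phase = __itt_mark_create("phase");
- *     __itt_mark(phase, "parse");   // places a discrete mark labeled "parse"
- *     __itt_mark(phase, "codegen"); // implicitly ends the "parse" phase
- * @endcode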
- */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -int ITTAPI __itt_markA(__itt_mark_type mt, const char *parameter); -int ITTAPI __itt_markW(__itt_mark_type mt, const wchar_t *parameter); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_mark __itt_markW -# define __itt_mark_ptr __itt_markW_ptr -#else /* UNICODE */ -# define __itt_mark __itt_markA -# define __itt_mark_ptr __itt_markA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -int ITTAPI __itt_mark(__itt_mark_type mt, const char *parameter); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(ITTAPI, int, markA, (__itt_mark_type mt, const char *parameter)) -ITT_STUB(ITTAPI, int, markW, (__itt_mark_type mt, const wchar_t *parameter)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(ITTAPI, int, mark, (__itt_mark_type mt, const char *parameter)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_markA ITTNOTIFY_DATA(markA) -#define __itt_markA_ptr ITTNOTIFY_NAME(markA) -#define __itt_markW ITTNOTIFY_DATA(markW) -#define __itt_markW_ptr ITTNOTIFY_NAME(markW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_mark ITTNOTIFY_DATA(mark) -#define __itt_mark_ptr ITTNOTIFY_NAME(mark) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_markA(mt, parameter) (int)0 -#define __itt_markA_ptr 0 -#define __itt_markW(mt, parameter) (int)0 -#define __itt_markW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_mark(mt, parameter) (int)0 -#define __itt_mark_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_markA_ptr 0 -#define __itt_markW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_mark_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Use this if necessary to create a "discrete" user event type (mark) for process - * rather then for one thread - * @see int __itt_mark(__itt_mark_type mt, const char* parameter); - */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -int ITTAPI __itt_mark_globalA(__itt_mark_type mt, const char *parameter); -int ITTAPI __itt_mark_globalW(__itt_mark_type mt, const wchar_t *parameter); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_mark_global __itt_mark_globalW -# define __itt_mark_global_ptr __itt_mark_globalW_ptr -#else /* UNICODE */ -# define __itt_mark_global __itt_mark_globalA -# define __itt_mark_global_ptr __itt_mark_globalA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -int ITTAPI __itt_mark_global(__itt_mark_type mt, const char *parameter); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(ITTAPI, int, mark_globalA, (__itt_mark_type mt, const char *parameter)) -ITT_STUB(ITTAPI, int, mark_globalW, (__itt_mark_type mt, const wchar_t *parameter)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(ITTAPI, int, mark_global, (__itt_mark_type mt, const char *parameter)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_mark_globalA ITTNOTIFY_DATA(mark_globalA) -#define 
__itt_mark_globalA_ptr ITTNOTIFY_NAME(mark_globalA)
-#define __itt_mark_globalW ITTNOTIFY_DATA(mark_globalW)
-#define __itt_mark_globalW_ptr ITTNOTIFY_NAME(mark_globalW)
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_mark_global ITTNOTIFY_DATA(mark_global)
-#define __itt_mark_global_ptr ITTNOTIFY_NAME(mark_global)
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#else /* INTEL_NO_ITTNOTIFY_API */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_mark_globalA(mt, parameter) (int)0
-#define __itt_mark_globalA_ptr 0
-#define __itt_mark_globalW(mt, parameter) (int)0
-#define __itt_mark_globalW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_mark_global(mt, parameter) (int)0
-#define __itt_mark_global_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_mark_globalA_ptr 0
-#define __itt_mark_globalW_ptr 0
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#define __itt_mark_global_ptr 0
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/**
- * @brief Creates an "end" point for a "continuous" mark with the specified name.
- *
- * - Returns zero value in case of success, non-zero value otherwise.
- *   It also returns a non-zero value when the preceding "begin" point for the
- *   mark with the same name was not created or failed to be created.
- * - The mark of "continuous" type is placed into the collection results on
- *   success. It appears in overtime view(s) as a special tick
- *   sign (different from the "discrete" mark) together with a line from the
- *   corresponding "begin" mark to the "end" mark.
- * @note Continuous marks can overlap and be nested inside each other.
- *       A discrete mark can be nested inside a marked region
- * @param[in] mt - mark, created by the __itt_mark_create(const char* name) function
- * @return Returns zero value in case of success, non-zero value otherwise.
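- *
- * A minimal usage sketch, assuming the "begin" point is placed with __itt_mark()
- * as described above (names are illustrative):
- * @code
- *     __itt_mark_type region = __itt_mark_create("io");
- *     __itt_mark(region, "read"); // "begin" point of the continuous mark
- *     // ... marked region ...
- *     __itt_mark_off(region);     // matching "end" point
- * @endcode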
- */ -int ITTAPI __itt_mark_off(__itt_mark_type mt); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(ITTAPI, int, mark_off, (__itt_mark_type mt)) -#define __itt_mark_off ITTNOTIFY_DATA(mark_off) -#define __itt_mark_off_ptr ITTNOTIFY_NAME(mark_off) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_mark_off(mt) (int)0 -#define __itt_mark_off_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_mark_off_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Use this if necessary to create an "end" point for mark of process - * @see int __itt_mark_off(__itt_mark_type mt); - */ -int ITTAPI __itt_mark_global_off(__itt_mark_type mt); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(ITTAPI, int, mark_global_off, (__itt_mark_type mt)) -#define __itt_mark_global_off ITTNOTIFY_DATA(mark_global_off) -#define __itt_mark_global_off_ptr ITTNOTIFY_NAME(mark_global_off) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_mark_global_off(mt) (int)0 -#define __itt_mark_global_off_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_mark_global_off_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} marks group */ - -/** - * @defgroup counters_internal Counters - * @ingroup internal - * Counters group - * @{ - */ - - -/** - * @defgroup stitch Stack Stitching - * @ingroup internal - * Stack Stitching group - * @{ - */ -/** - * @brief opaque structure for counter identification - */ -typedef struct ___itt_caller *__itt_caller; - -/** - * @brief Create the stitch point e.g. a point in call stack where other stacks should be stitched to. - * The function returns a unique identifier which is used to match the cut points with corresponding stitch points. - */ -__itt_caller ITTAPI __itt_stack_caller_create(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(ITTAPI, __itt_caller, stack_caller_create, (void)) -#define __itt_stack_caller_create ITTNOTIFY_DATA(stack_caller_create) -#define __itt_stack_caller_create_ptr ITTNOTIFY_NAME(stack_caller_create) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_stack_caller_create() (__itt_caller)0 -#define __itt_stack_caller_create_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_stack_caller_create_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Destroy the information about stitch point identified by the pointer previously returned by __itt_stack_caller_create() - */ -void ITTAPI __itt_stack_caller_destroy(__itt_caller id); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, stack_caller_destroy, (__itt_caller id)) -#define __itt_stack_caller_destroy ITTNOTIFY_VOID(stack_caller_destroy) -#define __itt_stack_caller_destroy_ptr ITTNOTIFY_NAME(stack_caller_destroy) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_stack_caller_destroy(id) -#define __itt_stack_caller_destroy_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_stack_caller_destroy_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Sets the cut point. 
Stack from each event which occurs after this call will be cut
- * at the same stack level the function was called and stitched to the corresponding stitch point.
- */
-void ITTAPI __itt_stack_callee_enter(__itt_caller id);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, stack_callee_enter, (__itt_caller id))
-#define __itt_stack_callee_enter ITTNOTIFY_VOID(stack_callee_enter)
-#define __itt_stack_callee_enter_ptr ITTNOTIFY_NAME(stack_callee_enter)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_stack_callee_enter(id)
-#define __itt_stack_callee_enter_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_stack_callee_enter_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/**
- * @brief This function eliminates the cut point which was set by latest __itt_stack_callee_enter().
- */
-void ITTAPI __itt_stack_callee_leave(__itt_caller id);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, stack_callee_leave, (__itt_caller id))
-#define __itt_stack_callee_leave ITTNOTIFY_VOID(stack_callee_leave)
-#define __itt_stack_callee_leave_ptr ITTNOTIFY_NAME(stack_callee_leave)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_stack_callee_leave(id)
-#define __itt_stack_callee_leave_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_stack_callee_leave_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-/** @} stitch group */
-
-/* ***************************************************************************************************************************** */
-
-#include <stdarg.h>
-
-/** @cond exclude_from_documentation */
-typedef enum __itt_error_code
-{
- __itt_error_success = 0, /*!< no error */
- __itt_error_no_module = 1, /*!< module can't be loaded */
- /* %1$s -- library name; win: %2$d -- system error code; unx: %2$s -- system error message. */
- __itt_error_no_symbol = 2, /*!< symbol not found */
- /* %1$s -- library name, %2$s -- symbol name. */
- __itt_error_unknown_group = 3, /*!< unknown group specified */
- /* %1$s -- env var name, %2$s -- group name. */
- __itt_error_cant_read_env = 4, /*!< GetEnvironmentVariable() failed */
- /* %1$s -- env var name, %2$d -- system error. */
- __itt_error_env_too_long = 5, /*!< variable value too long */
- /* %1$s -- env var name, %2$d -- actual length of the var, %3$d -- max allowed length. */
- __itt_error_system = 6 /*!< pthread_mutexattr_init or pthread_mutex_init failed */
- /* %1$s -- function name, %2$d -- errno.
*/
-} __itt_error_code;
-
-typedef void (__itt_error_handler_t)(__itt_error_code code, va_list);
-__itt_error_handler_t* __itt_set_error_handler(__itt_error_handler_t*);
-
-const char* ITTAPI __itt_api_version(void);
-/** @endcond */
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-#define __itt_error_handler ITT_JOIN(INTEL_ITTNOTIFY_PREFIX, error_handler)
-void __itt_error_handler(__itt_error_code code, va_list args);
-extern const int ITTNOTIFY_NAME(err);
-#define __itt_err ITTNOTIFY_NAME(err)
-ITT_STUB(ITTAPI, const char*, api_version, (void))
-#define __itt_api_version ITTNOTIFY_DATA(api_version)
-#define __itt_api_version_ptr ITTNOTIFY_NAME(api_version)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_api_version() (const char*)0
-#define __itt_api_version_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_api_version_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-
-#ifdef __cplusplus
-}
-#endif /* __cplusplus */
-
-#endif /* _ITTNOTIFY_PRIVATE_ */
-
-#endif /* INTEL_ITTNOTIFY_API_PRIVATE */
diff --git a/src/common/ittnotify/legacy/ittnotify.h b/src/common/ittnotify/legacy/ittnotify.h
deleted file mode 100644
index 0215db72963..00000000000
--- a/src/common/ittnotify/legacy/ittnotify.h
+++ /dev/null
@@ -1,992 +0,0 @@
-/*
- Copyright (C) 2005-2019 Intel Corporation
-
- SPDX-License-Identifier: GPL-2.0-only OR BSD-3-Clause
-*/
-#ifndef _LEGACY_ITTNOTIFY_H_
-#define _LEGACY_ITTNOTIFY_H_
-
-/**
- * @file
- * @brief Legacy User API functions and types
- */
-
-/** @cond exclude_from_documentation */
-#ifndef ITT_OS_WIN
-# define ITT_OS_WIN 1
-#endif /* ITT_OS_WIN */
-
-#ifndef ITT_OS_LINUX
-# define ITT_OS_LINUX 2
-#endif /* ITT_OS_LINUX */
-
-#ifndef ITT_OS_MAC
-# define ITT_OS_MAC 3
-#endif /* ITT_OS_MAC */
-
-#ifndef ITT_OS_FREEBSD
-# define ITT_OS_FREEBSD 4
-#endif /* ITT_OS_FREEBSD */
-
-#ifndef ITT_OS
-# if defined WIN32 || defined _WIN32
-# define ITT_OS ITT_OS_WIN
-# elif defined( __APPLE__ ) && defined( __MACH__ )
-# define ITT_OS ITT_OS_MAC
-# elif defined( __FreeBSD__ )
-# define ITT_OS ITT_OS_FREEBSD
-# else
-# define ITT_OS ITT_OS_LINUX
-# endif
-#endif /* ITT_OS */
-
-#ifndef ITT_PLATFORM_WIN
-# define ITT_PLATFORM_WIN 1
-#endif /* ITT_PLATFORM_WIN */
-
-#ifndef ITT_PLATFORM_POSIX
-# define ITT_PLATFORM_POSIX 2
-#endif /* ITT_PLATFORM_POSIX */
-
-#ifndef ITT_PLATFORM_MAC
-# define ITT_PLATFORM_MAC 3
-#endif /* ITT_PLATFORM_MAC */
-
-#ifndef ITT_PLATFORM_FREEBSD
-# define ITT_PLATFORM_FREEBSD 4
-#endif /* ITT_PLATFORM_FREEBSD */
-
-#ifndef ITT_PLATFORM
-# if ITT_OS==ITT_OS_WIN
-# define ITT_PLATFORM ITT_PLATFORM_WIN
-# elif ITT_OS==ITT_OS_MAC
-# define ITT_PLATFORM ITT_PLATFORM_MAC
-# elif ITT_OS==ITT_OS_FREEBSD
-# define ITT_PLATFORM ITT_PLATFORM_FREEBSD
-# else
-# define ITT_PLATFORM ITT_PLATFORM_POSIX
-# endif
-#endif /* ITT_PLATFORM */
-
-#if defined(_UNICODE) && !defined(UNICODE)
-#define UNICODE
-#endif
-
-#include <stddef.h>
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#include <tchar.h>
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#include <stdint.h>
-#if defined(UNICODE) || defined(_UNICODE)
-#include <wchar.h>
-#endif /* UNICODE || _UNICODE */
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-
-#ifndef ITTAPI_CDECL
-# if ITT_PLATFORM==ITT_PLATFORM_WIN
-# define ITTAPI_CDECL __cdecl
-# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-# if defined _M_IX86 || defined __i386__
-# define ITTAPI_CDECL __attribute__ ((cdecl))
-# else /* _M_IX86 || __i386__ */
-# define ITTAPI_CDECL /* actual only on x86 platform */
-# endif /*
_M_IX86 || __i386__ */ -# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* ITTAPI_CDECL */ - -#ifndef STDCALL -# if ITT_PLATFORM==ITT_PLATFORM_WIN -# define STDCALL __stdcall -# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -# if defined _M_IX86 || defined __i386__ -# define STDCALL __attribute__ ((stdcall)) -# else /* _M_IX86 || __i386__ */ -# define STDCALL /* supported only on x86 platform */ -# endif /* _M_IX86 || __i386__ */ -# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* STDCALL */ - -#define ITTAPI ITTAPI_CDECL -#define LIBITTAPI ITTAPI_CDECL - -/* TODO: Temporary for compatibility! */ -#define ITTAPI_CALL ITTAPI_CDECL -#define LIBITTAPI_CALL ITTAPI_CDECL - -#if ITT_PLATFORM==ITT_PLATFORM_WIN -/* use __forceinline (VC++ specific) */ -#if defined(__MINGW32__) && !defined(__cplusplus) -#define ITT_INLINE static __inline__ __attribute__((__always_inline__,__gnu_inline__)) -#else -#define ITT_INLINE static __forceinline -#endif /* __MINGW32__ */ - -#define ITT_INLINE_ATTRIBUTE /* nothing */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -/* - * Generally, functions are not inlined unless optimization is specified. - * For functions declared inline, this attribute inlines the function even - * if no optimization level was specified. - */ -#ifdef __STRICT_ANSI__ -#define ITT_INLINE static -#define ITT_INLINE_ATTRIBUTE __attribute__((unused)) -#else /* __STRICT_ANSI__ */ -#define ITT_INLINE static inline -#define ITT_INLINE_ATTRIBUTE __attribute__((always_inline, unused)) -#endif /* __STRICT_ANSI__ */ -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -/** @endcond */ - -/** @cond exclude_from_documentation */ -/* Helper macro for joining tokens */ -#define ITT_JOIN_AUX(p,n) p##n -#define ITT_JOIN(p,n) ITT_JOIN_AUX(p,n) - -#ifdef ITT_MAJOR -#undef ITT_MAJOR -#endif -#ifdef ITT_MINOR -#undef ITT_MINOR -#endif -#define ITT_MAJOR 3 -#define ITT_MINOR 0 - -/* Standard versioning of a token with major and minor version numbers */ -#define ITT_VERSIONIZE(x) \ - ITT_JOIN(x, \ - ITT_JOIN(_, \ - ITT_JOIN(ITT_MAJOR, \ - ITT_JOIN(_, ITT_MINOR)))) - -#ifndef INTEL_ITTNOTIFY_PREFIX -# define INTEL_ITTNOTIFY_PREFIX __itt_ -#endif /* INTEL_ITTNOTIFY_PREFIX */ -#ifndef INTEL_ITTNOTIFY_POSTFIX -# define INTEL_ITTNOTIFY_POSTFIX _ptr_ -#endif /* INTEL_ITTNOTIFY_POSTFIX */ - -#define ITTNOTIFY_NAME_AUX(n) ITT_JOIN(INTEL_ITTNOTIFY_PREFIX,n) -#define ITTNOTIFY_NAME(n) ITT_VERSIONIZE(ITTNOTIFY_NAME_AUX(ITT_JOIN(n,INTEL_ITTNOTIFY_POSTFIX))) - -#define ITTNOTIFY_VOID(n) (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n) -#define ITTNOTIFY_DATA(n) (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n) - -#define ITTNOTIFY_VOID_D0(n,d) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d) -#define ITTNOTIFY_VOID_D1(n,d,x) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x) -#define ITTNOTIFY_VOID_D2(n,d,x,y) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y) -#define ITTNOTIFY_VOID_D3(n,d,x,y,z) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z) -#define ITTNOTIFY_VOID_D4(n,d,x,y,z,a) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z,a) -#define ITTNOTIFY_VOID_D5(n,d,x,y,z,a,b) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b) -#define ITTNOTIFY_VOID_D6(n,d,x,y,z,a,b,c) (d == NULL) ? 
(void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b,c) -#define ITTNOTIFY_DATA_D0(n,d) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d) -#define ITTNOTIFY_DATA_D1(n,d,x) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x) -#define ITTNOTIFY_DATA_D2(n,d,x,y) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y) -#define ITTNOTIFY_DATA_D3(n,d,x,y,z) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z) -#define ITTNOTIFY_DATA_D4(n,d,x,y,z,a) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z,a) -#define ITTNOTIFY_DATA_D5(n,d,x,y,z,a,b) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b) -#define ITTNOTIFY_DATA_D6(n,d,x,y,z,a,b,c) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b,c) - -#ifdef ITT_STUB -#undef ITT_STUB -#endif -#ifdef ITT_STUBV -#undef ITT_STUBV -#endif -#define ITT_STUBV(api,type,name,args) \ - typedef type (api* ITT_JOIN(ITTNOTIFY_NAME(name),_t)) args; \ - extern ITT_JOIN(ITTNOTIFY_NAME(name),_t) ITTNOTIFY_NAME(name); -#define ITT_STUB ITT_STUBV -/** @endcond */ - -#ifdef __cplusplus -extern "C" { -#endif /* __cplusplus */ - -/** - * @defgroup legacy Legacy API - * @{ - * @} - */ - -/** - * @defgroup legacy_control Collection Control - * @ingroup legacy - * General behavior: application continues to run, but no profiling information is being collected - * - * Pausing occurs not only for the current thread but for all process as well as spawned processes - * - Intel(R) Parallel Inspector and Intel(R) Inspector XE: - * - Does not analyze or report errors that involve memory access. - * - Other errors are reported as usual. Pausing data collection in - * Intel(R) Parallel Inspector and Intel(R) Inspector XE - * only pauses tracing and analyzing memory access. - * It does not pause tracing or analyzing threading APIs. - * . - * - Intel(R) Parallel Amplifier and Intel(R) VTune(TM) Amplifier XE: - * - Does continue to record when new threads are started. - * . - * - Other effects: - * - Possible reduction of runtime overhead. - * . 
- * @{ - */ -#ifndef _ITTNOTIFY_H_ -/** @brief Pause collection */ -void ITTAPI __itt_pause(void); -/** @brief Resume collection */ -void ITTAPI __itt_resume(void); -/** @brief Detach collection */ -void ITTAPI __itt_detach(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(ITTAPI, void, pause, (void)) -ITT_STUBV(ITTAPI, void, resume, (void)) -ITT_STUBV(ITTAPI, void, detach, (void)) -#define __itt_pause ITTNOTIFY_VOID(pause) -#define __itt_pause_ptr ITTNOTIFY_NAME(pause) -#define __itt_resume ITTNOTIFY_VOID(resume) -#define __itt_resume_ptr ITTNOTIFY_NAME(resume) -#define __itt_detach ITTNOTIFY_VOID(detach) -#define __itt_detach_ptr ITTNOTIFY_NAME(detach) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_pause() -#define __itt_pause_ptr 0 -#define __itt_resume() -#define __itt_resume_ptr 0 -#define __itt_detach() -#define __itt_detach_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_pause_ptr 0 -#define __itt_resume_ptr 0 -#define __itt_detach_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -#endif /* _ITTNOTIFY_H_ */ -/** @} legacy_control group */ - -/** - * @defgroup legacy_threads Threads - * @ingroup legacy - * Threads group - * @warning Legacy API - * @{ - */ -/** - * @deprecated Legacy API - * @brief Set name to be associated with thread in analysis GUI. - * @return __itt_err upon failure (name or namelen being null,name and namelen mismatched) - */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -int LIBITTAPI __itt_thr_name_setA(const char *name, int namelen); -int LIBITTAPI __itt_thr_name_setW(const wchar_t *name, int namelen); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_thr_name_set __itt_thr_name_setW -# define __itt_thr_name_set_ptr __itt_thr_name_setW_ptr -#else -# define __itt_thr_name_set __itt_thr_name_setA -# define __itt_thr_name_set_ptr __itt_thr_name_setA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -int LIBITTAPI __itt_thr_name_set(const char *name, int namelen); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(LIBITTAPI, int, thr_name_setA, (const char *name, int namelen)) -ITT_STUB(LIBITTAPI, int, thr_name_setW, (const wchar_t *name, int namelen)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(LIBITTAPI, int, thr_name_set, (const char *name, int namelen)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_thr_name_setA ITTNOTIFY_DATA(thr_name_setA) -#define __itt_thr_name_setA_ptr ITTNOTIFY_NAME(thr_name_setA) -#define __itt_thr_name_setW ITTNOTIFY_DATA(thr_name_setW) -#define __itt_thr_name_setW_ptr ITTNOTIFY_NAME(thr_name_setW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_thr_name_set ITTNOTIFY_DATA(thr_name_set) -#define __itt_thr_name_set_ptr ITTNOTIFY_NAME(thr_name_set) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_thr_name_setA(name, namelen) -#define __itt_thr_name_setA_ptr 0 -#define __itt_thr_name_setW(name, namelen) -#define __itt_thr_name_setW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_thr_name_set(name, namelen) -#define __itt_thr_name_set_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if 
ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_thr_name_setA_ptr 0 -#define __itt_thr_name_setW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_thr_name_set_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @deprecated Legacy API - * @brief Mark current thread as ignored from this point on, for the duration of its existence. - */ -void LIBITTAPI __itt_thr_ignore(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(LIBITTAPI, void, thr_ignore, (void)) -#define __itt_thr_ignore ITTNOTIFY_VOID(thr_ignore) -#define __itt_thr_ignore_ptr ITTNOTIFY_NAME(thr_ignore) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_thr_ignore() -#define __itt_thr_ignore_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_thr_ignore_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} legacy_threads group */ - -/** - * @defgroup legacy_sync Synchronization - * @ingroup legacy - * Synchronization group - * @warning Legacy API - * @{ - */ -/** - * @hideinitializer - * @brief possible value of attribute argument for sync object type - */ -#define __itt_attr_barrier 1 - -/** - * @hideinitializer - * @brief possible value of attribute argument for sync object type - */ -#define __itt_attr_mutex 2 - -/** - * @deprecated Legacy API - * @brief Assign a name to a sync object using char or Unicode string - * @param[in] addr - pointer to the sync object. You should use a real pointer to your object - * to make sure that the values don't clash with other object addresses - * @param[in] objtype - null-terminated object type string. If NULL is passed, the object will - * be assumed to be of generic "User Synchronization" type - * @param[in] objname - null-terminated object name string. If NULL, no name will be assigned - * to the object -- you can use the __itt_sync_rename call later to assign - * the name - * @param[in] attribute - one of [#__itt_attr_barrier, #__itt_attr_mutex] values which defines the - * exact semantics of how prepare/acquired/releasing calls work. 
- */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -void ITTAPI __itt_sync_set_nameA(void *addr, const char *objtype, const char *objname, int attribute); -void ITTAPI __itt_sync_set_nameW(void *addr, const wchar_t *objtype, const wchar_t *objname, int attribute); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_sync_set_name __itt_sync_set_nameW -# define __itt_sync_set_name_ptr __itt_sync_set_nameW_ptr -#else /* UNICODE */ -# define __itt_sync_set_name __itt_sync_set_nameA -# define __itt_sync_set_name_ptr __itt_sync_set_nameA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -void ITTAPI __itt_sync_set_name(void *addr, const char* objtype, const char* objname, int attribute); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUBV(ITTAPI, void, sync_set_nameA, (void *addr, const char *objtype, const char *objname, int attribute)) -ITT_STUBV(ITTAPI, void, sync_set_nameW, (void *addr, const wchar_t *objtype, const wchar_t *objname, int attribute)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUBV(ITTAPI, void, sync_set_name, (void *addr, const char *objtype, const char *objname, int attribute)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_sync_set_nameA ITTNOTIFY_VOID(sync_set_nameA) -#define __itt_sync_set_nameA_ptr ITTNOTIFY_NAME(sync_set_nameA) -#define __itt_sync_set_nameW ITTNOTIFY_VOID(sync_set_nameW) -#define __itt_sync_set_nameW_ptr ITTNOTIFY_NAME(sync_set_nameW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_sync_set_name ITTNOTIFY_VOID(sync_set_name) -#define __itt_sync_set_name_ptr ITTNOTIFY_NAME(sync_set_name) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_sync_set_nameA(addr, objtype, objname, attribute) -#define __itt_sync_set_nameA_ptr 0 -#define __itt_sync_set_nameW(addr, objtype, objname, attribute) -#define __itt_sync_set_nameW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_sync_set_name(addr, objtype, objname, attribute) -#define __itt_sync_set_name_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_sync_set_nameA_ptr 0 -#define __itt_sync_set_nameW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_sync_set_name_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @deprecated Legacy API - * @brief Assign a name and type to a sync object using char or Unicode string - * @param[in] addr - pointer to the sync object. You should use a real pointer to your object - * to make sure that the values don't clash with other object addresses - * @param[in] objtype - null-terminated object type string. If NULL is passed, the object will - * be assumed to be of generic "User Synchronization" type - * @param[in] objname - null-terminated object name string. If NULL, no name will be assigned - * to the object -- you can use the __itt_sync_rename call later to assign - * the name - * @param[in] typelen, namelen - a length of string for appropriate objtype and objname parameter - * @param[in] attribute - one of [#__itt_attr_barrier, #__itt_attr_mutex] values which defines the - * exact semantics of how prepare/acquired/releasing calls work. 
- * @return __itt_err upon failure (name or namelen being null,name and namelen mismatched) - */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -int LIBITTAPI __itt_notify_sync_nameA(void *addr, const char *objtype, int typelen, const char *objname, int namelen, int attribute); -int LIBITTAPI __itt_notify_sync_nameW(void *addr, const wchar_t *objtype, int typelen, const wchar_t *objname, int namelen, int attribute); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_notify_sync_name __itt_notify_sync_nameW -#else -# define __itt_notify_sync_name __itt_notify_sync_nameA -#endif -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -int LIBITTAPI __itt_notify_sync_name(void *addr, const char *objtype, int typelen, const char *objname, int namelen, int attribute); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(LIBITTAPI, int, notify_sync_nameA, (void *addr, const char *objtype, int typelen, const char *objname, int namelen, int attribute)) -ITT_STUB(LIBITTAPI, int, notify_sync_nameW, (void *addr, const wchar_t *objtype, int typelen, const wchar_t *objname, int namelen, int attribute)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(LIBITTAPI, int, notify_sync_name, (void *addr, const char *objtype, int typelen, const char *objname, int namelen, int attribute)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_notify_sync_nameA ITTNOTIFY_DATA(notify_sync_nameA) -#define __itt_notify_sync_nameA_ptr ITTNOTIFY_NAME(notify_sync_nameA) -#define __itt_notify_sync_nameW ITTNOTIFY_DATA(notify_sync_nameW) -#define __itt_notify_sync_nameW_ptr ITTNOTIFY_NAME(notify_sync_nameW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_notify_sync_name ITTNOTIFY_DATA(notify_sync_name) -#define __itt_notify_sync_name_ptr ITTNOTIFY_NAME(notify_sync_name) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_notify_sync_nameA(addr, objtype, typelen, objname, namelen, attribute) -#define __itt_notify_sync_nameA_ptr 0 -#define __itt_notify_sync_nameW(addr, objtype, typelen, objname, namelen, attribute) -#define __itt_notify_sync_nameW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_notify_sync_name(addr, objtype, typelen, objname, namelen, attribute) -#define __itt_notify_sync_name_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_notify_sync_nameA_ptr 0 -#define __itt_notify_sync_nameW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_notify_sync_name_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @deprecated Legacy API - * @brief Enter spin loop on user-defined sync object - */ -void LIBITTAPI __itt_notify_sync_prepare(void* addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(LIBITTAPI, void, notify_sync_prepare, (void *addr)) -#define __itt_notify_sync_prepare ITTNOTIFY_VOID(notify_sync_prepare) -#define __itt_notify_sync_prepare_ptr ITTNOTIFY_NAME(notify_sync_prepare) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_notify_sync_prepare(addr) -#define __itt_notify_sync_prepare_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define 
__itt_notify_sync_prepare_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @deprecated Legacy API - * @brief Quit spin loop without acquiring spin object - */ -void LIBITTAPI __itt_notify_sync_cancel(void *addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(LIBITTAPI, void, notify_sync_cancel, (void *addr)) -#define __itt_notify_sync_cancel ITTNOTIFY_VOID(notify_sync_cancel) -#define __itt_notify_sync_cancel_ptr ITTNOTIFY_NAME(notify_sync_cancel) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_notify_sync_cancel(addr) -#define __itt_notify_sync_cancel_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_notify_sync_cancel_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @deprecated Legacy API - * @brief Successful spin loop completion (sync object acquired) - */ -void LIBITTAPI __itt_notify_sync_acquired(void *addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(LIBITTAPI, void, notify_sync_acquired, (void *addr)) -#define __itt_notify_sync_acquired ITTNOTIFY_VOID(notify_sync_acquired) -#define __itt_notify_sync_acquired_ptr ITTNOTIFY_NAME(notify_sync_acquired) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_notify_sync_acquired(addr) -#define __itt_notify_sync_acquired_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_notify_sync_acquired_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @deprecated Legacy API - * @brief Start sync object releasing code. Is called before the lock release call. - */ -void LIBITTAPI __itt_notify_sync_releasing(void* addr); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(LIBITTAPI, void, notify_sync_releasing, (void *addr)) -#define __itt_notify_sync_releasing ITTNOTIFY_VOID(notify_sync_releasing) -#define __itt_notify_sync_releasing_ptr ITTNOTIFY_NAME(notify_sync_releasing) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_notify_sync_releasing(addr) -#define __itt_notify_sync_releasing_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_notify_sync_releasing_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} legacy_sync group */ - -#ifndef _ITTNOTIFY_H_ -/** - * @defgroup legacy_events Events - * @ingroup legacy - * Events group - * @{ - */ - -/** @brief user event type */ -typedef int __itt_event; - -/** - * @brief Create an event notification - * @note name or namelen being null/name and namelen not matching, user event feature not enabled - * @return non-zero event identifier upon success and __itt_err otherwise - */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -__itt_event LIBITTAPI __itt_event_createA(const char *name, int namelen); -__itt_event LIBITTAPI __itt_event_createW(const wchar_t *name, int namelen); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_event_create __itt_event_createW -# define __itt_event_create_ptr __itt_event_createW_ptr -#else -# define __itt_event_create __itt_event_createA -# define __itt_event_create_ptr __itt_event_createA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -__itt_event LIBITTAPI __itt_event_create(const char *name, int namelen); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if 
ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(LIBITTAPI, __itt_event, event_createA, (const char *name, int namelen)) -ITT_STUB(LIBITTAPI, __itt_event, event_createW, (const wchar_t *name, int namelen)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(LIBITTAPI, __itt_event, event_create, (const char *name, int namelen)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_event_createA ITTNOTIFY_DATA(event_createA) -#define __itt_event_createA_ptr ITTNOTIFY_NAME(event_createA) -#define __itt_event_createW ITTNOTIFY_DATA(event_createW) -#define __itt_event_createW_ptr ITTNOTIFY_NAME(event_createW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_event_create ITTNOTIFY_DATA(event_create) -#define __itt_event_create_ptr ITTNOTIFY_NAME(event_create) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_event_createA(name, namelen) (__itt_event)0 -#define __itt_event_createA_ptr 0 -#define __itt_event_createW(name, namelen) (__itt_event)0 -#define __itt_event_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_event_create(name, namelen) (__itt_event)0 -#define __itt_event_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_event_createA_ptr 0 -#define __itt_event_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_event_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Record an event occurrence. - * @return __itt_err upon failure (invalid event id/user event feature not enabled) - */ -int LIBITTAPI __itt_event_start(__itt_event event); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(LIBITTAPI, int, event_start, (__itt_event event)) -#define __itt_event_start ITTNOTIFY_DATA(event_start) -#define __itt_event_start_ptr ITTNOTIFY_NAME(event_start) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_event_start(event) (int)0 -#define __itt_event_start_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_event_start_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @brief Record an event end occurrence. - * @note It is optional if events do not have durations. 
- * @return __itt_err upon failure (invalid event id/user event feature not enabled) - */ -int LIBITTAPI __itt_event_end(__itt_event event); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(LIBITTAPI, int, event_end, (__itt_event event)) -#define __itt_event_end ITTNOTIFY_DATA(event_end) -#define __itt_event_end_ptr ITTNOTIFY_NAME(event_end) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_event_end(event) (int)0 -#define __itt_event_end_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_event_end_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} legacy_events group */ -#endif /* _ITTNOTIFY_H_ */ - -/** - * @defgroup legacy_memory Memory Accesses - * @ingroup legacy - */ - -/** - * @deprecated Legacy API - * @brief Inform the tool of memory accesses on reading - */ -void LIBITTAPI __itt_memory_read(void *addr, size_t size); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(LIBITTAPI, void, memory_read, (void *addr, size_t size)) -#define __itt_memory_read ITTNOTIFY_VOID(memory_read) -#define __itt_memory_read_ptr ITTNOTIFY_NAME(memory_read) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_memory_read(addr, size) -#define __itt_memory_read_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_memory_read_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @deprecated Legacy API - * @brief Inform the tool of memory accesses on writing - */ -void LIBITTAPI __itt_memory_write(void *addr, size_t size); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(LIBITTAPI, void, memory_write, (void *addr, size_t size)) -#define __itt_memory_write ITTNOTIFY_VOID(memory_write) -#define __itt_memory_write_ptr ITTNOTIFY_NAME(memory_write) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_memory_write(addr, size) -#define __itt_memory_write_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_memory_write_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @deprecated Legacy API - * @brief Inform the tool of memory accesses on updating - */ -void LIBITTAPI __itt_memory_update(void *address, size_t size); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUBV(LIBITTAPI, void, memory_update, (void *addr, size_t size)) -#define __itt_memory_update ITTNOTIFY_VOID(memory_update) -#define __itt_memory_update_ptr ITTNOTIFY_NAME(memory_update) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_memory_update(addr, size) -#define __itt_memory_update_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_memory_update_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} legacy_memory group */ - -/** - * @defgroup legacy_state Thread and Object States - * @ingroup legacy - */ - -/** @brief state type */ -typedef int __itt_state_t; - -/** @cond exclude_from_documentation */ -typedef enum __itt_obj_state { - __itt_obj_state_err = 0, - __itt_obj_state_clr = 1, - __itt_obj_state_set = 2, - __itt_obj_state_use = 3 -} __itt_obj_state_t; - -typedef enum __itt_thr_state { - __itt_thr_state_err = 0, - __itt_thr_state_clr = 1, - __itt_thr_state_set = 2 -} __itt_thr_state_t; - -typedef enum __itt_obj_prop { - __itt_obj_prop_watch = 1, - __itt_obj_prop_ignore = 2, - 
__itt_obj_prop_sharable = 3 -} __itt_obj_prop_t; - -typedef enum __itt_thr_prop { - __itt_thr_prop_quiet = 1 -} __itt_thr_prop_t; -/** @endcond */ - -/** - * @deprecated Legacy API - * @brief managing thread and object states - */ -__itt_state_t LIBITTAPI __itt_state_get(void); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(ITTAPI, __itt_state_t, state_get, (void)) -#define __itt_state_get ITTNOTIFY_DATA(state_get) -#define __itt_state_get_ptr ITTNOTIFY_NAME(state_get) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_state_get(void) (__itt_state_t)0 -#define __itt_state_get_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_state_get_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @deprecated Legacy API - * @brief managing thread and object states - */ -__itt_state_t LIBITTAPI __itt_state_set(__itt_state_t s); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(ITTAPI, __itt_state_t, state_set, (__itt_state_t s)) -#define __itt_state_set ITTNOTIFY_DATA(state_set) -#define __itt_state_set_ptr ITTNOTIFY_NAME(state_set) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_state_set(s) (__itt_state_t)0 -#define __itt_state_set_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_state_set_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @deprecated Legacy API - * @brief managing thread and object modes - */ -__itt_thr_state_t LIBITTAPI __itt_thr_mode_set(__itt_thr_prop_t p, __itt_thr_state_t s); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(ITTAPI, __itt_thr_state_t, thr_mode_set, (__itt_thr_prop_t p, __itt_thr_state_t s)) -#define __itt_thr_mode_set ITTNOTIFY_DATA(thr_mode_set) -#define __itt_thr_mode_set_ptr ITTNOTIFY_NAME(thr_mode_set) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_thr_mode_set(p, s) (__itt_thr_state_t)0 -#define __itt_thr_mode_set_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_thr_mode_set_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** - * @deprecated Legacy API - * @brief managing thread and object modes - */ -__itt_obj_state_t LIBITTAPI __itt_obj_mode_set(__itt_obj_prop_t p, __itt_obj_state_t s); - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -ITT_STUB(ITTAPI, __itt_obj_state_t, obj_mode_set, (__itt_obj_prop_t p, __itt_obj_state_t s)) -#define __itt_obj_mode_set ITTNOTIFY_DATA(obj_mode_set) -#define __itt_obj_mode_set_ptr ITTNOTIFY_NAME(obj_mode_set) -#else /* INTEL_NO_ITTNOTIFY_API */ -#define __itt_obj_mode_set(p, s) (__itt_obj_state_t)0 -#define __itt_obj_mode_set_ptr 0 -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#define __itt_obj_mode_set_ptr 0 -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ -/** @} legacy_state group */ - -/** - * @defgroup frames Frames - * @ingroup legacy - * Frames group - * @{ - */ -/** - * @brief opaque structure for frame identification - */ -typedef struct __itt_frame_t *__itt_frame; - -/** - * @brief Create a global frame with given domain - */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -__itt_frame ITTAPI __itt_frame_createA(const char *domain); -__itt_frame ITTAPI __itt_frame_createW(const wchar_t *domain); -#if defined(UNICODE) || defined(_UNICODE) -# define __itt_frame_create 
__itt_frame_createW -# define __itt_frame_create_ptr __itt_frame_createW_ptr -#else /* UNICODE */ -# define __itt_frame_create __itt_frame_createA -# define __itt_frame_create_ptr __itt_frame_createA_ptr -#endif /* UNICODE */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -__itt_frame ITTAPI __itt_frame_create(const char *domain); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/** @cond exclude_from_documentation */ -#ifndef INTEL_NO_MACRO_BODY -#ifndef INTEL_NO_ITTNOTIFY_API -#if ITT_PLATFORM==ITT_PLATFORM_WIN -ITT_STUB(ITTAPI, __itt_frame, frame_createA, (const char *domain)) -ITT_STUB(ITTAPI, __itt_frame, frame_createW, (const wchar_t *domain)) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -ITT_STUB(ITTAPI, __itt_frame, frame_create, (const char *domain)) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_frame_createA ITTNOTIFY_DATA(frame_createA) -#define __itt_frame_createA_ptr ITTNOTIFY_NAME(frame_createA) -#define __itt_frame_createW ITTNOTIFY_DATA(frame_createW) -#define __itt_frame_createW_ptr ITTNOTIFY_NAME(frame_createW) -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_frame_create ITTNOTIFY_DATA(frame_create) -#define __itt_frame_create_ptr ITTNOTIFY_NAME(frame_create) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#else /* INTEL_NO_ITTNOTIFY_API */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_frame_createA(domain) -#define __itt_frame_createA_ptr 0 -#define __itt_frame_createW(domain) -#define __itt_frame_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_frame_create(domain) -#define __itt_frame_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_ITTNOTIFY_API */ -#else /* INTEL_NO_MACRO_BODY */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_frame_createA_ptr 0 -#define __itt_frame_createW_ptr 0 -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#define __itt_frame_create_ptr 0 -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* INTEL_NO_MACRO_BODY */ -/** @endcond */ - -/** @brief Record a frame begin occurrence. */ -void ITTAPI __itt_frame_begin(__itt_frame frame); -/** @brief Record a frame end occurrence. 
*/
-void ITTAPI __itt_frame_end (__itt_frame frame);
-
-/** @cond exclude_from_documentation */
-#ifndef INTEL_NO_MACRO_BODY
-#ifndef INTEL_NO_ITTNOTIFY_API
-ITT_STUBV(ITTAPI, void, frame_begin, (__itt_frame frame))
-ITT_STUBV(ITTAPI, void, frame_end, (__itt_frame frame))
-#define __itt_frame_begin ITTNOTIFY_VOID(frame_begin)
-#define __itt_frame_begin_ptr ITTNOTIFY_NAME(frame_begin)
-#define __itt_frame_end ITTNOTIFY_VOID(frame_end)
-#define __itt_frame_end_ptr ITTNOTIFY_NAME(frame_end)
-#else /* INTEL_NO_ITTNOTIFY_API */
-#define __itt_frame_begin(frame)
-#define __itt_frame_begin_ptr 0
-#define __itt_frame_end(frame)
-#define __itt_frame_end_ptr 0
-#endif /* INTEL_NO_ITTNOTIFY_API */
-#else /* INTEL_NO_MACRO_BODY */
-#define __itt_frame_begin_ptr 0
-#define __itt_frame_end_ptr 0
-#endif /* INTEL_NO_MACRO_BODY */
-/** @endcond */
-/** @} frames group */
-
-#ifdef __cplusplus
-}
-#endif /* __cplusplus */
-
-#endif /* _LEGACY_ITTNOTIFY_H_ */
diff --git a/src/common/layer_normalization.cpp b/src/common/layer_normalization.cpp
index 79ccc98c45c..62804c601dc 100644
--- a/src/common/layer_normalization.cpp
+++ b/src/common/layer_normalization.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2024 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -157,19 +157,25 @@ status_t layer_normalization_attr_check(const layer_normalization_desc_t &desc,
 const bool is_int8 = utils::one_of(src_dt, data_type::s8, data_type::u8)
 || utils::one_of(dst_dt, data_type::s8, data_type::u8);
- if (is_int8) fwd_attr_mask |= smask_t::scales_runtime;
+ if (is_int8) fwd_attr_mask |= smask_t::scales;
 VCHECK_LNORM_UNIMPL(attr->has_default_values(fwd_attr_mask, dst_dt),
 VERBOSE_UNSUPPORTED_ATTR);
 // Check scales
 if (!attr->scales_.has_default_values()) {
- const auto &sc = attr->scales_;
- const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
- const int mask_dst = sc.get(DNNL_ARG_DST).mask_;
-
- VCHECK_LNORM_UNIMPL(utils::everyone_is(0, mask_src, mask_dst),
+ static const std::vector<int> supported_args {
+ DNNL_ARG_SRC, DNNL_ARG_DST};
+ VCHECK_LNORM_UNIMPL(
+ attr->scales_.has_default_values(supported_args),
 VERBOSE_UNSUPPORTED_SCALES_CFG);
+
+ for (int arg : supported_args) {
+ if (attr->scales_.has_default_values(arg)) continue;
+
+ const int mask = attr->scales_.get_mask(arg);
+ VCHECK_LNORM_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
+ }
 }
 // Check post-ops
@@ -178,6 +184,9 @@ status_t layer_normalization_attr_check(const layer_normalization_desc_t &desc,
 using namespace primitive_kind;
 VCHECK_LNORM_UNIMPL(po.has_default_values({binary, eltwise, sum}),
 VERBOSE_UNSUPPORTED_POSTOP);
+
+ // Note: verbose support is inside the call.
+ CHECK(po.validate_binary_with_dst_consistency(&desc.dst_desc));
 }
 } else {
 VCHECK_LNORM_UNIMPL(false, VERBOSE_UNSUPPORTED_ATTR);
diff --git a/src/common/layer_normalization_pd.hpp b/src/common/layer_normalization_pd.hpp
index 242ea372eee..bdd27c47978 100644
--- a/src/common/layer_normalization_pd.hpp
+++ b/src/common/layer_normalization_pd.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2024 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
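For context, the scales check above only admits per-tensor quantization on SRC and DST: a mask of 0 means a single scale value covers the whole tensor. A minimal sketch of a setup that passes this check, assuming the oneDNN v3.x C++ API (dnnl.hpp); the helper name and the epsilon/flags choices are illustrative, not part of this patch:

    // Sketch: int8 layer normalization with per-tensor (mask = 0) scales,
    // the only scales configuration layer_normalization_attr_check accepts.
    dnnl::layer_normalization_forward::primitive_desc make_int8_lnorm_pd(
            const dnnl::engine &eng, const dnnl::memory::desc &src_md,
            const dnnl::memory::desc &dst_md) {
        dnnl::primitive_attr attr;
        attr.set_scales_mask(DNNL_ARG_SRC, 0); // mask 0: one scale per tensor
        attr.set_scales_mask(DNNL_ARG_DST, 0);
        return dnnl::layer_normalization_forward::primitive_desc(eng,
                dnnl::prop_kind::forward_inference, src_md, dst_md,
                /*epsilon=*/1e-5f, dnnl::normalization_flags::none, attr);
    }

The scale values themselves are bound at execution time as one-element memories under DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC and DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST.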
@@ -107,11 +107,11 @@ struct layer_normalization_pd_t : public primitive_desc_t {
 memory_desc_t stat_md_;
 memory_desc_t scaleshift_md_;
- layer_normalization_pd_t(const layer_normalization_desc_t *adesc,
+ layer_normalization_pd_t(const op_desc_t *adesc,
 const primitive_attr_t *attr,
 const layer_normalization_fwd_pd_t *hint_fwd_pd)
 : primitive_desc_t(attr, base_pkind)
- , desc_(*adesc)
+ , desc_(*op_desc_t::to_desc(adesc))
 , hint_fwd_pd_(hint_fwd_pd)
 , src_md_(desc_.src_desc)
 , stat_md_(desc_.stat_desc)
@@ -156,9 +156,10 @@ struct layer_normalization_pd_t : public primitive_desc_t {
 const memory_desc_t &src_desc() const { return desc_.src_desc; }
 };
+// NOLINTBEGIN(google-default-arguments)
 struct layer_normalization_fwd_pd_t : public layer_normalization_pd_t {
- typedef layer_normalization_fwd_pd_t base_class;
- typedef layer_normalization_fwd_pd_t hint_class;
+ using base_class = layer_normalization_fwd_pd_t;
+ using hint_class = layer_normalization_fwd_pd_t;
 arg_usage_t arg_usage(int arg) const override {
 if (arg == DNNL_ARG_SRC) return arg_usage_t::input;
@@ -170,8 +171,10 @@ struct layer_normalization_fwd_pd_t : public layer_normalization_pd_t {
 return arg_usage_t::unused;
 }
- if (arg == DNNL_ARG_SCALE && use_scale()) return arg_usage_t::input;
- if (arg == DNNL_ARG_SHIFT && use_shift()) return arg_usage_t::input;
+ if (arg == DNNL_ARG_SCALE)
+ return use_scale() ? arg_usage_t::input : arg_usage_t::unused;
+ if (arg == DNNL_ARG_SHIFT)
+ return use_shift() ? arg_usage_t::input : arg_usage_t::unused;
 return primitive_desc_t::arg_usage(arg);
 }
@@ -224,7 +227,7 @@ struct layer_normalization_fwd_pd_t : public layer_normalization_pd_t {
 protected:
 memory_desc_t dst_md_;
- layer_normalization_fwd_pd_t(const layer_normalization_desc_t *adesc,
+ layer_normalization_fwd_pd_t(const op_desc_t *adesc,
 const primitive_attr_t *attr,
 const layer_normalization_fwd_pd_t *hint_fwd_pd)
 : layer_normalization_pd_t(adesc, attr, hint_fwd_pd)
@@ -248,34 +251,44 @@ struct layer_normalization_fwd_pd_t : public layer_normalization_pd_t {
 return false;
 }
- bool attr_scales_ok() const {
+ bool attr_scales_ok(const std::vector<int> &supported_args
+ = {DNNL_ARG_SRC, DNNL_ARG_DST}) const {
+ using namespace data_type;
 const auto &scales = attr()->scales_;
- bool ok = true;
- for (const auto &e : scales.scales_) {
- ok = ok && e.second.mask_ == 0;
+ bool ok = scales.has_default_values(supported_args);
+
+ for (const auto &arg : supported_args) {
+ if (!scales.has_default_values(arg)) {
+ // TODO: disallow non-int8 scales?
+ // const data_type_t dt = arg_md(arg)->data_type;
+ // ok = ok && utils::one_of(dt, s8, u8);
+ ok = ok && scales.get_mask(arg) == 0;
+ }
 }
 return ok;
 }
 };
+// NOLINTEND(google-default-arguments)
+// NOLINTBEGIN(google-default-arguments)
 struct layer_normalization_bwd_pd_t : public layer_normalization_pd_t {
- typedef layer_normalization_bwd_pd_t base_class;
- typedef layer_normalization_fwd_pd_t hint_class;
+ using base_class = layer_normalization_bwd_pd_t;
+ using hint_class = layer_normalization_fwd_pd_t;
 arg_usage_t arg_usage(int arg) const override {
 if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_MEAN, DNNL_ARG_VARIANCE,
 DNNL_ARG_DIFF_DST))
 return arg_usage_t::input;
- if (arg == DNNL_ARG_SCALE && use_scale()) return arg_usage_t::input;
- if (arg == DNNL_ARG_SHIFT && use_shift()) return arg_usage_t::input;
+ if (arg == DNNL_ARG_SCALE)
+ return use_scale() ?
arg_usage_t::input : arg_usage_t::unused; if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output; - if (arg == DNNL_ARG_DIFF_SCALE && use_scale()) - return arg_usage_t::output; - if (arg == DNNL_ARG_DIFF_SHIFT && use_shift()) - return arg_usage_t::output; + if (arg == DNNL_ARG_DIFF_SCALE) + return use_scale() ? arg_usage_t::output : arg_usage_t::unused; + if (arg == DNNL_ARG_DIFF_SHIFT) + return use_shift() ? arg_usage_t::output : arg_usage_t::unused; return primitive_desc_t::arg_usage(arg); } @@ -324,7 +337,7 @@ struct layer_normalization_bwd_pd_t : public layer_normalization_pd_t { return index == 0 ? &diff_scaleshift_md_ : &glob_zero_md; } - int n_inputs() const override { return 4 + use_scale() + use_shift(); } + int n_inputs() const override { return 4 + use_scale(); } int n_outputs() const override { return 1 + (desc_.prop_kind == prop_kind::backward) @@ -336,7 +349,7 @@ struct layer_normalization_bwd_pd_t : public layer_normalization_pd_t { memory_desc_t diff_dst_md_; memory_desc_t diff_scaleshift_md_; - layer_normalization_bwd_pd_t(const layer_normalization_desc_t *adesc, + layer_normalization_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const layer_normalization_fwd_pd_t *hint_fwd_pd) : layer_normalization_pd_t(adesc, attr, hint_fwd_pd) @@ -368,6 +381,7 @@ struct layer_normalization_bwd_pd_t : public layer_normalization_pd_t { return false; } }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/logging.cpp b/src/common/logging.cpp index 0eaec490899..bb4f1f9b2a8 100644 --- a/src/common/logging.cpp +++ b/src/common/logging.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
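The attr_scales_ok() rework above replaces the old hard-coded loop over scales_.scales_ with per-argument queries and lets an implementation narrow the argument list it supports. An illustrative use in a backend pd (not taken from this patch; the init() body is a sketch assuming the usual internal status codes):

    // Illustrative backend code: a pd that implements only source scales
    // passes its own list instead of the {DNNL_ARG_SRC, DNNL_ARG_DST} default.
    status_t init(engine_t *engine) {
        if (!attr_scales_ok({DNNL_ARG_SRC}))
            return status::unimplemented;
        // ... remaining implementation-specific checks ...
        return status::success;
    }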
@@ -17,21 +17,20 @@ #include "common/logging.hpp" #include "common/utils.hpp" -#include "common/spdlog/sinks/rotating_file_sink.h" -#include "common/spdlog/spdlog.h" +#include "spdlog/sinks/rotating_file_sink.h" +#include "spdlog/spdlog.h" namespace dnnl { namespace impl { -log_manager_t::log_manager_t() { - +log_manager_t::log_manager_t() + : logfile_path_(getenv_string_user("VERBOSE_LOGFILE")) // enables logging as well as printing to stdout - console_flag_ = getenv_int_user("VERBOSE_LOG_WITH_CONSOLE", 0); + , console_flag_(getenv_int_user("VERBOSE_LOG_WITH_CONSOLE", 0)) { // logging is automatically disabled when no filepath is provided by // DNNL_VERBOSE_LOGFILE // in this case, we fall back to printing to stdout - logfile_path_ = getenv_string_user("VERBOSE_LOGFILE"); if (logfile_path_.empty()) { console_flag_ = true; return; @@ -93,7 +92,7 @@ void log_manager_t::log(const char *msg, log_level_t log_level) const { void log_manager_t::set_log_level(const std::string &vmode_str) const { // The logging level is determined from the verbose mode // with the following order of decreasing priority: - // [trace, debug, info, error, critical, off] + // [trace, debug, info, warn, error, critical, off] spdlog::set_level(spdlog::level::off); if (vmode_str == "-1" || vmode_str == "all") { @@ -104,6 +103,8 @@ void log_manager_t::set_log_level(const std::string &vmode_str) const { || vmode_str.find("profile") != std::string::npos || vmode_str.find("dispatch") != std::string::npos) { spdlog::set_level(spdlog::level::info); + } else if (vmode_str.find("warn") != std::string::npos) { + spdlog::set_level(spdlog::level::warn); } else if (vmode_str.find("check") != std::string::npos) { spdlog::set_level(spdlog::level::err); } else if (vmode_str.find("error") != std::string::npos) { diff --git a/src/common/lrn_pd.hpp b/src/common/lrn_pd.hpp index e7afb7b2bdb..0205c21b213 100644 --- a/src/common/lrn_pd.hpp +++ b/src/common/lrn_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -89,26 +89,27 @@ struct lrn_pd_t : public primitive_desc_t { memory_desc_t src_md_; memory_desc_t ws_md_; - lrn_pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr, + lrn_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const lrn_fwd_pd_t *hint_fwd_pd) : primitive_desc_t(attr, base_pkind) - , desc_(*adesc) + , desc_(*op_desc_t::to_desc(adesc)) , hint_fwd_pd_(hint_fwd_pd) - , src_md_(desc_.src_desc) - , ws_md_() {} + , src_md_(desc_.src_desc) {} }; +// NOLINTBEGIN(google-default-arguments) struct lrn_fwd_pd_t : public lrn_pd_t { - typedef lrn_fwd_pd_t base_class; - typedef lrn_fwd_pd_t hint_class; + using base_class = lrn_fwd_pd_t; + using hint_class = lrn_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (arg == DNNL_ARG_SRC) return arg_usage_t::input; if (arg == DNNL_ARG_DST) return arg_usage_t::output; - if (arg == DNNL_ARG_WORKSPACE && (!types::is_zero_md(workspace_md()))) - return arg_usage_t::output; + if (arg == DNNL_ARG_WORKSPACE) + return !types::is_zero_md(workspace_md()) ? 
arg_usage_t::output
+ : arg_usage_t::unused;
 return primitive_desc_t::arg_usage(arg);
 }
@@ -145,7 +146,7 @@ struct lrn_fwd_pd_t : public lrn_pd_t {
 protected:
 memory_desc_t dst_md_;
- lrn_fwd_pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr,
+ lrn_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
 const lrn_fwd_pd_t *hint_fwd_pd)
 : lrn_pd_t(adesc, attr, hint_fwd_pd), dst_md_(desc_.dst_desc) {}
@@ -156,10 +157,12 @@ struct lrn_fwd_pd_t : public lrn_pd_t {
 == status::success);
 }
 };
+// NOLINTEND(google-default-arguments)
+
+// NOLINTBEGIN(google-default-arguments)
 struct lrn_bwd_pd_t : public lrn_pd_t {
- typedef lrn_bwd_pd_t base_class;
- typedef lrn_fwd_pd_t hint_class;
+ using base_class = lrn_bwd_pd_t;
+ using hint_class = lrn_fwd_pd_t;
 arg_usage_t arg_usage(int arg) const override {
 if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_DIFF_DST))
@@ -167,8 +170,9 @@ struct lrn_bwd_pd_t : public lrn_pd_t {
 if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output;
- if (arg == DNNL_ARG_WORKSPACE && (!types::is_zero_md(workspace_md())))
- return arg_usage_t::input;
+ if (arg == DNNL_ARG_WORKSPACE)
+ return !types::is_zero_md(workspace_md()) ? arg_usage_t::input
+ : arg_usage_t::unused;
 return primitive_desc_t::arg_usage(arg);
 }
@@ -214,7 +218,7 @@ struct lrn_bwd_pd_t : public lrn_pd_t {
 memory_desc_t diff_src_md_;
 memory_desc_t diff_dst_md_;
- lrn_bwd_pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr,
+ lrn_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
 const lrn_fwd_pd_t *hint_fwd_pd)
 : lrn_pd_t(adesc, attr, hint_fwd_pd)
 , diff_src_md_(desc_.diff_src_desc)
@@ -231,6 +235,7 @@ struct lrn_bwd_pd_t : public lrn_pd_t {
 == status::success);
 }
 };
+// NOLINTEND(google-default-arguments)
 } // namespace impl
 } // namespace dnnl
diff --git a/src/common/math_utils.hpp b/src/common/math_utils.hpp
index 0c156dff8db..848e393b6e9 100644
--- a/src/common/math_utils.hpp
+++ b/src/common/math_utils.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2017-2024 Intel Corporation
+* Copyright 2017-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -68,7 +68,8 @@ inline T gcd(T a, T b) {
 return b;
 }
-inline int lcm(int a, int b) {
+template <typename T>
+inline T lcm(T a, T b) {
 a = impl::nstl::abs(a);
 b = impl::nstl::abs(b);
 assert(a > 0 && b > 0);
@@ -88,12 +89,15 @@ inline int ilog2q(size_t v) {
 int p = 0;
 #define CP(pw) \
 do { \
- if (v >= (1ull << pw)) { \
- v >>= pw; \
- p += pw; \
+ if (v >= (1ull << (pw))) { \
+ v >>= (pw); \
+ p += (pw); \
 } \
 } while (0)
+
+#if INTPTR_MAX == INT64_MAX
 CP(32);
+#endif
 CP(16);
 CP(8);
 CP(4);
@@ -238,7 +242,7 @@ template <typename T, typename U = typename utils::remove_reference<T>::type>
 inline U logistic_fwd(T s) {
 // Here we avoid division/inverse by infinity as some architectures have
 // non-standard behavior
- float exp_overflow_bound = 88.72283172607421875;
+ float exp_overflow_bound = 88.72283172607421875f;
 float in = (float)-s;
 return in < exp_overflow_bound ? (U)(1.f / (1.f + ::expf(in))) : 0.f;
 }
@@ -255,7 +259,7 @@ inline U logistic_bwd_use_dst(T dd, T d) {
 template <typename T, typename A,
 typename U = typename utils::remove_reference<T>::type>
 inline U soft_relu_fwd(T s, A alpha) {
- float exp_overflow_bound = 88.72283172607421875;
+ float exp_overflow_bound = 20.f;
 float in = (float)s * (float)alpha;
 float v = (in < exp_overflow_bound ? (U)(::log1pf(::expf(in))) : (U)in);
 return (U)(v / alpha);
 }
@@ -414,6 +418,31 @@ inline U hardswish_bwd(T dd, T s, A alpha, A beta) {
 return v <= 0.f ?
0.f : v >= 1.f ? dd : dd * w; } +template ::type> +inline U hsigmoid_fwd(T s) { + float v = s + 3.0f; + v = v > 0.0f ? v : 0.0f; + v = v < 6.0f ? v : 6.0f; + return (U)(v / 6.0f); +} + +template ::type> +inline U round_half_to_even_fwd(T s) { + float r = ::roundf((float)s); + float d = (float)s - r; + float remainder = ::fmodf(r, 2.0f); + return ((d != 0.5f) && (d != -0.5f)) || (remainder == 0.0f) ? (U)r : + (U)((float)s + d); +} + +template ::type> +inline U round_half_away_from_zero_fwd(T s) { + return (U)(::roundf((float)s)); +} + inline bool is_eltwise_ok( data_type_t src_dt, alg_kind_t alg, float alpha, float beta) { using namespace alg_kind; @@ -426,7 +455,8 @@ inline bool is_eltwise_ok( eltwise_exp, eltwise_gelu_tanh, eltwise_hardsigmoid, eltwise_hardswish, eltwise_swish, eltwise_log, eltwise_clip, eltwise_clip_v2, eltwise_pow, - eltwise_gelu_erf, eltwise_round) + eltwise_gelu_erf, eltwise_round, + eltwise_hsigmoid, eltwise_round_half_away_from_zero, eltwise_round_half_to_even) && IMPLICATION( one_of(alg, eltwise_clip, eltwise_clip_v2), beta >= alpha) && IMPLICATION(alg == eltwise_round, src_dt == dnnl_f32) @@ -514,7 +544,7 @@ inline uint16_t philox8x16(uint32_t idx, uint32_t seed) { // - 1 lsb is used to index 16-bit words within this 32 bit random // value uint32_t r = philox4x32(idx >> 1, seed); - return (uint16_t)(r >> ((idx & 1) * sizeof(uint16_t))); + return (uint16_t)(r >> ((idx & 1) * sizeof(uint16_t) * 8)); } inline uint8_t philox16x8(uint32_t idx, uint32_t seed) { @@ -523,7 +553,7 @@ inline uint8_t philox16x8(uint32_t idx, uint32_t seed) { // - 2 lsb is used to index 8-bit words within this 32 bit random // value uint32_t r = philox4x32(idx >> 2, seed); - return (uint8_t)(r >> ((idx & 3) * sizeof(uint8_t))); + return (uint8_t)(r >> ((idx & 3) * sizeof(uint8_t) * 8)); } inline float stochastic_round_fwd( @@ -551,8 +581,8 @@ inline float stochastic_round_fwd( << (digits(data_type::f32) - digits(dst_dt)); // IMPORTANT: lsb of bias are used. - uint32_t rnd_bias = data_type_size(dst_dt) == 16 ? philox16x8(idx, seed) - : philox8x16(idx, seed); + uint32_t rnd_bias = data_type_size(dst_dt) == 2 ? philox16x8(idx, seed) + : philox8x16(idx, seed); rnd_bias = rnd_bias & ~truncation_mask; uint32_t s_u = utils::bit_cast(s); @@ -567,6 +597,42 @@ inline float stochastic_round_fwd( return r; } +inline float get_bias(const char *bias, size_t offset, data_type_t data_type) { + if (!bias) return 0.0f; + +#define CASE(dt) \ + case dt: return (float)((const prec_traits_t
+inline float get_bias(const char *bias, size_t offset, data_type_t data_type) { + if (!bias) return 0.0f; + +#define CASE(dt) \ + case dt: return (float)((const prec_traits_t<dt>::type *)bias)[offset] + + switch (data_type) { + CASE(data_type::s8); + CASE(data_type::u8); + CASE(data_type::bf16); + CASE(data_type::s32); + CASE(data_type::f32); + default: assert(!"unimplemented"); + } + return 0; // never happens (should probably be a NaN) +#undef CASE +} + +inline float get_sum(char *sum, size_t offset, data_type_t data_type) { + if (!sum) + return 0.0f; + +#define CASE(dt) \ + case dt: return (float)((const prec_traits_t<dt>::type *)sum)[offset] + + switch (data_type) { + CASE(data_type::s8); + CASE(data_type::u8); + CASE(data_type::s32); + CASE(data_type::f32); + default: assert(!"unimplemented"); + } + return 0; // never happens (should probably be a NaN) +#undef CASE +} + } // namespace math } // namespace impl } // namespace dnnl
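The get_bias/get_sum helpers above dispatch on a runtime data_type_t and widen each element to float through prec_traits_t. A hypothetical caller (sketch; the function name and buffer are illustrative, not from the patch):

    // Assumes math_utils.hpp is included; `buf` holds `n` elements whose
    // element type is known only at runtime as `dt` (s8/u8/s32/f32).
    inline float sum_all(char *buf, size_t n, dnnl::impl::data_type_t dt) {
        float acc = 0.f;
        for (size_t i = 0; i < n; ++i)
            acc += dnnl::impl::math::get_sum(buf, i, dt);
        return acc;
    }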
diff --git a/src/common/matmul.cpp b/src/common/matmul.cpp index c0fb43b6b56..12174cc3b24 100644 --- a/src/common/matmul.cpp +++ b/src/common/matmul.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,25 +52,28 @@ status_t matmul_attr_check(const matmul_desc_t &desc, const engine_t *engine, auto attr_mask = smask_t::post_ops | smask_t::sum_dt | smask_t::dropout | smask_t::rounding_mode; // Matmul supports scales for floating point data types - attr_mask |= smask_t::scales_runtime; - attr_mask |= smask_t::scales_runtime_data_type; + attr_mask |= smask_t::scales_data_type; const bool src_is_int8 = utils::one_of(src_dt, data_type::s8, data_type::u8); - if (src_is_int8) attr_mask |= smask_t::zero_points_runtime; + const bool src_is_fp8 + = utils::one_of(src_dt, data_type::f8_e5m2, data_type::f8_e4m3); + if (src_is_int8 || src_is_fp8) attr_mask |= smask_t::zero_points; // Matmul supports zero points for floating point data types as part of // weights decompression. const bool wei_is_int = utils::one_of( wei_dt, data_type::s8, data_type::u8, data_type::s4, data_type::u4); - if (wei_is_int) { - attr_mask |= smask_t::zero_points_runtime_data_type; - attr_mask |= smask_t::zero_points_runtime_groups; - attr_mask |= smask_t::scales_runtime_groups; + const bool wei_is_fp8 + = utils::one_of(wei_dt, data_type::f8_e5m2, data_type::f8_e4m3); + if (wei_is_int || wei_is_fp8) { + attr_mask |= smask_t::zero_points_data_type; + attr_mask |= smask_t::zero_points_groups; + attr_mask |= smask_t::scales_groups; } - // Matmul supports fpmath mode - attr_mask |= smask_t::fpmath_mode; + // Matmul supports fpmath mode and accumulation mode + attr_mask |= smask_t::fpmath_mode | smask_t::accumulation_mode; VCHECK_MATMUL_UNIMPL(attr->has_default_values(attr_mask, dst_dt), VERBOSE_UNSUPPORTED_ATTR); @@ -85,67 +88,173 @@ status_t matmul_attr_check(const matmul_desc_t &desc, const engine_t *engine, int wei_qmask_K = 1 << (ndims_wei - 2); int wei_qmask_N = 1 << (ndims_wei - 1); + int dst_qmask_M = src_qmask_M; + int dst_qmask_N = wei_qmask_N; + // Check scales if (!attr->scales_.has_default_values()) { const auto &sc = attr->scales_; - const int mask_src = sc.get(DNNL_ARG_SRC).mask_; - const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_; - const int mask_dst = sc.get(DNNL_ARG_DST).mask_; - - // Check allowed masks. 
- VCHECK_MATMUL_UNIMPL(utils::one_of(mask_src, 0, src_qmask_K, - src_qmask_M + src_qmask_K) - && utils::one_of(mask_wei, 0, wei_qmask_N, - wei_qmask_N + wei_qmask_K) - && mask_dst == 0, - VERBOSE_UNSUPPORTED_SCALES_CFG); + + dim_t src_scale_group_k = 1; + if (!sc.has_default_values(DNNL_ARG_SRC)) { + const int mask_src = sc.get_mask(DNNL_ARG_SRC); + + VCHECK_MATMUL_UNIMPL(utils::one_of(mask_src, 0, src_qmask_K, + src_qmask_M + src_qmask_K), + VERBOSE_UNSUPPORTED_SCALES_CFG); + + if (!sc.get(DNNL_ARG_SRC).has_default_groups()) { + if (mask_src & src_qmask_K) + src_scale_group_k = sc.get_group(DNNL_ARG_SRC, 1); + } + + // Due to hardware specifics, groups should be multiple of 32. + VCHECK_MATMUL_UNIMPL(IMPLICATION(src_scale_group_k > 1, + src_scale_group_k % 32 == 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); + } + + dim_t wei_scale_group_k = 1; + dim_t wei_scale_group_n = 1; + if (!sc.has_default_values(DNNL_ARG_WEIGHTS)) { + const int mask_wei = sc.get_mask(DNNL_ARG_WEIGHTS); + + // Masks for weights scales can be any - skipping them. + + if (!sc.get(DNNL_ARG_WEIGHTS).has_default_groups()) { + if (mask_wei & wei_qmask_K) + wei_scale_group_k = sc.get_group(DNNL_ARG_WEIGHTS, 0); + if (mask_wei & wei_qmask_N) + wei_scale_group_n = sc.get_group(DNNL_ARG_WEIGHTS, 1); + } + + // Groups per N are solely for weights decompression as it's + // impossible to get performant kernel for a single `k` element in + // chain for regular quantized case. + VCHECK_MATMUL_UNIMPL(IMPLICATION(wei_scale_group_n > 1, + attr->fpmath_.apply_to_int_), + VERBOSE_UNSUPPORTED_SCALES_CFG); + + // Due to hardware specifics, groups should be multiple of 32. + VCHECK_MATMUL_UNIMPL(IMPLICATION(wei_scale_group_k > 1, + wei_scale_group_k % 32 == 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VCHECK_MATMUL_UNIMPL(IMPLICATION(wei_scale_group_n > 1, + wei_scale_group_n % 32 == 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); + } + + if (!sc.has_default_values(DNNL_ARG_DST)) { + const int mask_dst = sc.get_mask(DNNL_ARG_DST); + + if (engine->kind() == engine_kind::gpu) { + VCHECK_MATMUL_UNIMPL( + utils::one_of(mask_dst, 0, dst_qmask_N, dst_qmask_M, + dst_qmask_N + dst_qmask_M), + VERBOSE_UNSUPPORTED_SCALES_CFG); + } else { + VCHECK_MATMUL_UNIMPL( + mask_dst == 0, VERBOSE_UNSUPPORTED_SCALES_CFG); + } + } + // Check dependency between scales. // Source scales groups are supported for int8 source and must divide // or be divided by weights groups when both are greater than 1. - const auto src_scale_group_k = (mask_src & src_qmask_K) - ? sc.get(DNNL_ARG_SRC).group_dims_[1] - : 1; - const auto wei_scale_group_k = (mask_wei & wei_qmask_K) - ? 
sc.get(DNNL_ARG_WEIGHTS).group_dims_[0] - : 1; const bool groups_are_divisible = IMPLICATION( src_scale_group_k > 1 && wei_scale_group_k > 1, (src_scale_group_k % wei_scale_group_k == 0) || (wei_scale_group_k % src_scale_group_k == 0)); - VCHECK_MATMUL_UNIMPL(IMPLICATION(src_scale_group_k > 1, - src_is_int8 && groups_are_divisible), + VCHECK_MATMUL_UNIMPL( + IMPLICATION(src_scale_group_k > 1, + (src_is_int8 || src_is_fp8) && groups_are_divisible), VERBOSE_UNSUPPORTED_SCALES_CFG); } // Check zero points if (!attr->zero_points_.has_default_values()) { const auto &zp = attr->zero_points_; - int mask_src = 0, mask_wei = 0, mask_dst = 0; - zp.get(DNNL_ARG_SRC, &mask_src); - zp.get(DNNL_ARG_WEIGHTS, &mask_wei); - zp.get(DNNL_ARG_DST, &mask_dst); - VCHECK_MATMUL_UNIMPL(mask_src == 0 - || (desc.src_desc.ndims == 2 && mask_src == 1 << 1), - VERBOSE_UNSUPPORTED_ZP_CFG); - VCHECK_MATMUL_UNIMPL(utils::one_of(mask_wei, 0, wei_qmask_N, - wei_qmask_N + wei_qmask_K), - VERBOSE_UNSUPPORTED_ZP_CFG); - VCHECK_MATMUL_UNIMPL(mask_dst == 0 - || (desc.dst_desc.ndims == 2 && mask_dst == 1 << 1), - VERBOSE_UNSUPPORTED_ZP_CFG); + dim_t src_zero_point_group_k = 1; + if (!zp.has_default_values(DNNL_ARG_SRC)) { + const int mask_src = zp.get_mask(DNNL_ARG_SRC); - if (utils::one_of(zp.get_data_type(DNNL_ARG_WEIGHTS), data_type::s4, - data_type::u4)) { - dim_t k = desc.weights_desc.dims[ndims_wei - 2]; - dim_t n = desc.weights_desc.dims[ndims_wei - 1]; - VCHECK_MATMUL_UNIMPL( - IMPLICATION(mask_wei & wei_qmask_K, k % 2 == 0), + VCHECK_MATMUL_UNIMPL(utils::one_of(mask_src, 0, src_qmask_K, + src_qmask_M + src_qmask_K), VERBOSE_UNSUPPORTED_ZP_CFG); - VCHECK_MATMUL_UNIMPL( - IMPLICATION(mask_wei & wei_qmask_N, n % 2 == 0), + + if (!zp.get(DNNL_ARG_SRC).has_default_groups()) { + if (mask_src & src_qmask_K) + src_zero_point_group_k = zp.get_group(DNNL_ARG_SRC, 1); + } + + // Due to hardware specifics, groups should be multiple of 32. + VCHECK_MATMUL_UNIMPL(IMPLICATION(src_zero_point_group_k > 1, + src_zero_point_group_k % 32 == 0), VERBOSE_UNSUPPORTED_ZP_CFG); } + + dim_t wei_zero_point_group_k = 1; + dim_t wei_zero_point_group_n = 1; + if (!zp.has_default_values(DNNL_ARG_WEIGHTS)) { + const int mask_wei = zp.get_mask(DNNL_ARG_WEIGHTS); + + // Masks for weights zero_points can be any - skipping them. + + if (!zp.get(DNNL_ARG_WEIGHTS).has_default_groups()) { + if (mask_wei & wei_qmask_K) + wei_zero_point_group_k = zp.get_group(DNNL_ARG_WEIGHTS, 0); + if (mask_wei & wei_qmask_N) + wei_zero_point_group_n = zp.get_group(DNNL_ARG_WEIGHTS, 1); + } + + // Groups per N are solely for weights decompression as it's + // impossible to get performant kernel for a single `k` element in + // chain for regular quantized case. + VCHECK_MATMUL_UNIMPL(IMPLICATION(wei_zero_point_group_n > 1, + attr->fpmath_.apply_to_int_), + VERBOSE_UNSUPPORTED_ZP_CFG); + + // Due to hardware specifics, groups should be multiple of 32. 
+ VCHECK_MATMUL_UNIMPL(IMPLICATION(wei_zero_point_group_k > 1, + wei_zero_point_group_k % 32 == 0), + VERBOSE_UNSUPPORTED_ZP_CFG); + VCHECK_MATMUL_UNIMPL(IMPLICATION(wei_zero_point_group_n > 1, + wei_zero_point_group_n % 32 == 0), + VERBOSE_UNSUPPORTED_ZP_CFG); + + if (utils::one_of(zp.get_data_type(DNNL_ARG_WEIGHTS), data_type::s4, + data_type::u4)) { + dim_t k = desc.weights_desc.dims[ndims_wei - 2]; + dim_t n = desc.weights_desc.dims[ndims_wei - 1]; + VCHECK_MATMUL_UNIMPL( + IMPLICATION(mask_wei & wei_qmask_K, k % 2 == 0), + VERBOSE_UNSUPPORTED_ZP_CFG); + VCHECK_MATMUL_UNIMPL( + IMPLICATION(mask_wei & wei_qmask_N, n % 2 == 0), + VERBOSE_UNSUPPORTED_ZP_CFG); + } + } + + if (!zp.has_default_values(DNNL_ARG_DST)) { + const int mask_dst = zp.get_mask(DNNL_ARG_DST); + + VCHECK_MATMUL_UNIMPL(mask_dst == 0 + || (desc.dst_desc.ndims == 2 && mask_dst == 1 << 1), + VERBOSE_UNSUPPORTED_ZP_CFG); + } + + // Check dependency between zero_points. + // Source zero_points groups are supported for int8 source and must + // divide or be divided by weights groups when both are greater than 1. + const bool groups_are_divisible = IMPLICATION( + src_zero_point_group_k > 1 && wei_zero_point_group_k > 1, + (src_zero_point_group_k % wei_zero_point_group_k == 0) + || (wei_zero_point_group_k % src_zero_point_group_k + == 0)); + VCHECK_MATMUL_UNIMPL(IMPLICATION(src_zero_point_group_k > 1, + src_is_int8 && groups_are_divisible), + VERBOSE_UNSUPPORTED_ZP_CFG); } // Check post-ops @@ -160,6 +269,9 @@ status_t matmul_attr_check(const matmul_desc_t &desc, const engine_t *engine, VCHECK_MATMUL_UNIMPL( po.check_sum_consistency(dst_dt, src_is_int8, true), VERBOSE_UNSUPPORTED_POSTOP); + + // Note: verbose support is inside the call. + CHECK(po.validate_binary_with_dst_consistency(&desc.dst_desc)); } return status::success; @@ -171,10 +283,16 @@ namespace dnnl { namespace impl { status_t matmul_desc_init(matmul_desc_t *matmul_desc, const memory_desc_t *src_desc, const memory_desc_t *weights_desc, - const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) { + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const memory_desc_t *reduce_desc, matmul_reduce_kind_t reduce_kind) { VCHECK_MATMUL( !any_null(src_desc, weights_desc, dst_desc), VERBOSE_NULL_ARG); + // Note: This is an artificial limitation for the internal `reduce` feature + // to limit the scope to what is actually used. 
+ VCHECK_MATMUL( + IMPLICATION(bias_desc, !reduce_desc), VERBOSE_UNSUPPORTED_BIAS_CFG); + auto op_d = matmul_desc_t(); op_d.primitive_kind = primitive_kind::matmul; @@ -182,8 +300,17 @@ status_t matmul_desc_init(matmul_desc_t *matmul_desc, op_d.weights_desc = *weights_desc; if (bias_desc) op_d.bias_desc = *bias_desc; op_d.dst_desc = *dst_desc; + if (reduce_desc) { + VCHECK_MATMUL(reduce_desc->format_kind != format_kind::any, + VERBOSE_UNSUPPORTED_FORMAT_KIND); + op_d.reduce_desc = *reduce_desc; + op_d.reduce_kind = reduce_kind; + VCHECK_MATMUL(op_d.reduce_kind != matmul_reduce_kind::undef, + VERBOSE_BAD_PARAM); + } const bool with_bias = op_d.bias_desc.ndims != 0; + const bool with_reduce = op_d.reduce_desc.ndims != 0; const int ndims = dst_desc->ndims; VCHECK_MATMUL(ndims >= 2 && ndims <= DNNL_MAX_NDIMS, VERBOSE_BAD_NDIMS, "dst", ndims); @@ -191,6 +318,8 @@ status_t matmul_desc_init(matmul_desc_t *matmul_desc, VERBOSE_INCONSISTENT_NDIMS, "src", "weights"); VCHECK_MATMUL(IMPLICATION(with_bias, op_d.bias_desc.ndims == ndims), VERBOSE_BAD_NDIMS, "bias", op_d.bias_desc.ndims); + VCHECK_MATMUL(IMPLICATION(with_reduce, op_d.reduce_desc.ndims == ndims), + VERBOSE_BAD_NDIMS, "reduce", op_d.reduce_desc.ndims); // check: m, n, k const int m_idx = ndims - 2; @@ -212,15 +341,52 @@ status_t matmul_desc_init(matmul_desc_t *matmul_desc, dst_desc->dims[m_idx])), VERBOSE_INCONSISTENT_DIM, "bias", m_idx, "dst", m_idx); + VCHECK_MATMUL(IMPLICATION(with_reduce, + one_of(op_d.reduce_desc.dims[n_idx], 1, + dst_desc->dims[n_idx])), + VERBOSE_INCONSISTENT_DIM, "reduce", n_idx, "dst", n_idx); + VCHECK_MATMUL(IMPLICATION(with_reduce, + one_of(op_d.reduce_desc.dims[m_idx], 1, + dst_desc->dims[m_idx])), + VERBOSE_INCONSISTENT_DIM, "reduce", m_idx, "dst", m_idx); + const int bia_mask = with_bias ? utils::get_dims_mask(dst_desc->dims, op_d.bias_desc.dims, ndims) : 0; - // s4/u4 requires n to be multiple of 2 - VCHECK_MATMUL(IMPLICATION(utils::one_of(weights_desc->data_type, - data_type::s4, data_type::u4), - weights_desc->dims[n_idx] % 2 == 0), - VERBOSE_BAD_DIM, "weights", n_idx); + using namespace data_type; + if (weights_desc->format_kind == format_kind::blocked + && utils::one_of( + weights_desc->data_type, s4, u4, f4_e2m1, f4_e3m0)) { + const auto &wei_strides = weights_desc->format_desc.blocking.strides; + + int n_unit_strides = 0; + for (int d = 0; d < ndims; d++) { + if (wei_strides[d] == 1) { + n_unit_strides++; + VCHECK_MATMUL( + n_unit_strides <= 1, VERBOSE_BAD_DIM, "weights", d); + } + VCHECK_MATMUL( + IMPLICATION(wei_strides[d] > 1, wei_strides[d] % 2 == 0), + VERBOSE_BAD_DIM, "weights", d); + } + } + if (src_desc->format_kind == format_kind::blocked + && utils::one_of(src_desc->data_type, s4, u4, f4_e2m1, f4_e3m0)) { + const auto &src_strides = src_desc->format_desc.blocking.strides; + + int n_unit_strides = 0; + for (int d = 0; d < ndims; d++) { + if (src_strides[d] == 1) { + n_unit_strides++; + VCHECK_MATMUL(n_unit_strides <= 1, VERBOSE_BAD_DIM, "src", d); + } + VCHECK_MATMUL( + IMPLICATION(src_strides[d] > 1, src_strides[d] % 2 == 0), + VERBOSE_BAD_DIM, "src", d); + } + } // check if other dims match. for (int d = 0; d < ndims - 2; ++d) { @@ -228,6 +394,7 @@ status_t matmul_desc_init(matmul_desc_t *matmul_desc, const dim_t w_dim = weights_desc->dims[d]; const dim_t d_dim = dst_desc->dims[d]; const dim_t b_dim = with_bias ? op_d.bias_desc.dims[d] : 0; + const dim_t r_dim = with_reduce ? 
op_d.reduce_desc.dims[d] : 0; if (one_of(DNNL_RUNTIME_DIM_VAL, s_dim, w_dim, d_dim, b_dim)) { @@ -246,6 +413,8 @@ status_t matmul_desc_init(matmul_desc_t *matmul_desc, VERBOSE_INVALID_BROADCAST, "src", d); VCHECK_MATMUL(IMPLICATION(with_bias, one_of(b_dim, 1, d_dim)), VERBOSE_INCONSISTENT_DIM, "bias", d, "dst", d); + VCHECK_MATMUL(IMPLICATION(with_reduce, one_of(r_dim, 1, d_dim)), + VERBOSE_INCONSISTENT_DIM, "reduce", d, "dst", d); } } @@ -256,6 +425,14 @@ status_t matmul_desc_init(matmul_desc_t *matmul_desc, *matmul_desc = op_d; return status::success; } + +status_t matmul_desc_init(matmul_desc_t *matmul_desc, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) { + return matmul_desc_init(matmul_desc, src_desc, weights_desc, bias_desc, + dst_desc, nullptr, matmul_reduce_kind::undef); +} + } // namespace impl } // namespace dnnl diff --git a/src/common/matmul_pd.hpp b/src/common/matmul_pd.hpp index f1963d7f8a3..76c43308044 100644 --- a/src/common/matmul_pd.hpp +++ b/src/common/matmul_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,15 +36,21 @@ namespace dnnl { namespace impl { +status_t matmul_desc_init(matmul_desc_t *matmul_desc, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const memory_desc_t *reduce_desc, matmul_reduce_kind_t reduce_kind); + status_t matmul_desc_init(matmul_desc_t *matmul_desc, const memory_desc_t *src_desc, const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, const memory_desc_t *dst_desc); +// NOLINTBEGIN(google-default-arguments) struct matmul_pd_t : public primitive_desc_t { static constexpr auto base_pkind = primitive_kind::matmul; - typedef matmul_pd_t base_class; - typedef matmul_pd_t hint_class; + using base_class = matmul_pd_t; + using hint_class = matmul_pd_t; const matmul_desc_t *desc() const { return &desc_; } const op_desc_t *op_desc() const override { @@ -55,8 +61,11 @@ struct matmul_pd_t : public primitive_desc_t { const bool input = utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS); if (input) return arg_usage_t::input; - if (arg == DNNL_ARG_BIAS && with_bias()) return arg_usage_t::input; + if (arg == DNNL_ARG_BIAS) + return with_bias() ? arg_usage_t::input : arg_usage_t::unused; + if (arg == DNNL_ARG_REDUCE) + return with_reduce() ? arg_usage_t::output : arg_usage_t::unused; if (arg == DNNL_ARG_DST) return arg_usage_t::output; return primitive_desc_t::arg_usage(arg); @@ -69,6 +78,7 @@ struct matmul_pd_t : public primitive_desc_t { case DNNL_ARG_WEIGHTS: return weights_md(0); case DNNL_ARG_BIAS: return weights_md(1); case DNNL_ARG_DST: return dst_md(0, user_input); + case DNNL_ARG_REDUCE: return reduce_md(0); default: return primitive_desc_t::arg_md(arg); } } @@ -93,10 +103,16 @@ struct matmul_pd_t : public primitive_desc_t { return &glob_zero_md; } + const memory_desc_t *reduce_md( + int index = 0, bool user_input = false) const { + if (index == 0) return user_input ? 
&desc()->reduce_desc : &reduce_md_; + return &glob_zero_md; + } + int n_inputs() const override { return 2 + with_bias() + n_binary_po_inputs() + n_prelu_po_inputs(); } - int n_outputs() const override { return 1; } + int n_outputs() const override { return 1 + with_reduce(); } bool has_zero_dim_memory() const { return memory_desc_wrapper(src_md(0)).has_zero_dim() @@ -113,6 +129,10 @@ struct matmul_pd_t : public primitive_desc_t { } bool with_bias() const { return bias_md_.ndims != 0; } + bool with_reduce() const { return reduce_md_.ndims != 0; } + + matmul_reduce_kind_t reduce_kind() const { return desc_.reduce_kind; } + bool batched() const { return ndims() > 2; } dim_t batch() const { @@ -159,36 +179,65 @@ struct matmul_pd_t : public primitive_desc_t { return 1 << (wei_ndims - 2); } + int dst_qmask_N() const { return wei_qmask_N(); } + + int dst_qmask_M() const { return src_qmask_M(); } + virtual bool attr_scales_ok(const std::vector<int> &supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const { - if (attr()->scales_.has_default_values()) return true; + const auto &scales = attr()->scales_; + if (scales.has_default_values()) return true; - bool ok = attr()->scales_.has_default_values(supported_args); + bool ok = scales.has_default_values(supported_args); for (int arg : supported_args) { - const auto &sc = attr()->scales_.get(arg); - const auto &mask = sc.mask_; - if (sc.has_default_values()) { continue; } + if (scales.has_default_values(arg)) { continue; } + const auto &mask = scales.get_mask(arg); if (arg == DNNL_ARG_WEIGHTS) { + const auto &g0 = scales.get_group(arg, 0); + const auto &g1 = scales.get_group(arg, 1); + const bool wei_k_group_ok = IMPLICATION(g0 > 1, K() % g0 == 0); + const bool wei_n_group_ok = IMPLICATION(g1 > 1, N() % g1 == 0); + + // Any group is allowed to be greater than 1 but only one at a + // time, not both. + ok = ok + && IMPLICATION(!scales.get(arg).has_default_groups(), + utils::one_of(1, g0, g1) && wei_k_group_ok + && wei_n_group_ok); + + // Mask over K dim is allowed for decompression feature only. 
+ const bool is_decompression_or_dynquant + = utils::one_of(weights_md(0)->data_type, data_type::s8, + data_type::u8, data_type::s4, data_type::u4) + && IMPLICATION( + !types::is_integral_dt(src_md()->data_type), + attr()->fpmath_.apply_to_int_); ok = ok - && utils::one_of(mask, 0, wei_qmask_N(), - wei_qmask_K() + wei_qmask_N()); - ok = ok && utils::one_of(sc.ndims_, 0, 2) - && IMPLICATION(sc.ndims_ == 2, - sc.group_dims_[1] == 1 - && K() % sc.group_dims_[0] == 0); + && IMPLICATION((mask & wei_qmask_K()), + is_decompression_or_dynquant); } else if (arg == DNNL_ARG_SRC) { ok = ok && utils::one_of(mask, 0, src_qmask_K(), src_qmask_M() + src_qmask_K()); - ok = ok && utils::one_of(sc.ndims_, 0, 2); - ok = ok && IMPLICATION((mask & src_qmask_K()), sc.ndims_ == 2); ok = ok - && IMPLICATION(sc.ndims_ == 2, - sc.group_dims_[0] == 1 - && K() % sc.group_dims_[1] == 0); + && IMPLICATION((mask & src_qmask_K()), + !scales.get(arg).has_default_groups()); + ok = ok + && IMPLICATION(!scales.get(arg).has_default_groups(), + scales.get_group(arg, 0) == 1 + && K() % scales.get_group(arg, 1) == 0); + } else if (arg == DNNL_ARG_DST) { + ok = ok + && utils::one_of(mask, 0, dst_qmask_N(), + dst_qmask_M() + dst_qmask_N()); + ok = ok + && IMPLICATION(!scales.get(arg).has_default_groups(), + scales.get_group(arg, 1) == 1 + && (M() % scales.get_group(arg, 0)) + == 0); } else { - ok = ok && (mask == 0); + assert(!"Unsupported arg"); } } return ok; @@ -201,19 +250,22 @@ struct matmul_pd_t : public primitive_desc_t { memory_desc_t weights_md_; memory_desc_t bias_md_; memory_desc_t dst_md_; + memory_desc_t reduce_md_; - matmul_pd_t(const matmul_desc_t *adesc, const primitive_attr_t *attr, + matmul_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const matmul_pd_t *hint_fwd_pd) : primitive_desc_t(attr, base_pkind) - , desc_(*adesc) + , desc_(*op_desc_t::to_desc(adesc)) , src_md_(desc_.src_desc) , weights_md_(desc_.weights_desc) , bias_md_(desc_.bias_desc) - , dst_md_(desc_.dst_desc) {} + , dst_md_(desc_.dst_desc) + , reduce_md_(desc_.reduce_desc) {} // temporary solution to deal with format `any` bool set_default_formats() { - for (auto md : {&src_md_, &weights_md_, &bias_md_, &dst_md_}) { + for (auto md : + {&src_md_, &weights_md_, &bias_md_, &dst_md_, &reduce_md_}) { memory_desc_wrapper mdw(md); if (mdw.format_any()) { if (mdw.has_runtime_dims_or_strides()) return false; @@ -229,9 +281,10 @@ struct matmul_pd_t : public primitive_desc_t { // call this function. bool is_dense_format_kind() { return impl::is_dense_format_kind( - {&src_md_, &weights_md_, &bias_md_, &dst_md_}); + {&src_md_, &weights_md_, &bias_md_, &dst_md_, &reduce_md_}); } }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/memory.cpp b/src/common/memory.cpp index 745bfcb7d15..dd286cc4346 100644 --- a/src/common/memory.cpp +++ b/src/common/memory.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -59,19 +59,15 @@ size_t memory_desc_map_size(const memory_desc_t *md, int index = 0) { auto mdw = memory_desc_wrapper(md); if (mdw.has_runtime_dims_or_strides()) return DNNL_RUNTIME_SIZE_VAL; - if (mdw.offset0() == 0) return mdw.size(index); - memory_desc_t md_no_offset0 = *md; - md_no_offset0.offset0 = 0; - return memory_desc_wrapper(md_no_offset0).size(index) - + md->offset0 * mdw.data_type_size(); + return mdw.size(index, true, true); } } // namespace dnnl_memory::dnnl_memory(dnnl::impl::engine_t *engine, const dnnl::impl::memory_desc_t *md, const std::vector<unsigned> &flags, const std::vector<void *> &handles) - : engine_(engine), md_(*md) { + : engine_(engine), md_(*md), counter_(1) { const size_t nhandles = handles.size(); std::vector<std::unique_ptr<memory_storage_t>> mem_storages( @@ -91,14 +87,27 @@ dnnl_memory::dnnl_memory(dnnl::impl::engine_t *engine, const dnnl::impl::memory_desc_t *md, std::unique_ptr<memory_storage_t> &&memory_storage) - : engine_(engine), md_(*md) { + : engine_(engine), md_(*md), counter_(1) { this->reset_memory_storage(std::move(memory_storage)); } +#ifdef DNNL_EXPERIMENTAL_SPARSE +dnnl_memory::dnnl_memory(dnnl::impl::engine_t *engine, + const dnnl::impl::memory_desc_t *md, + std::vector<std::unique_ptr<memory_storage_t>> + &&memory_storages) + : engine_(engine) + , md_(*md) + , memory_storages_(std::move(memory_storages)) + , counter_(1) {} +#endif + status_t dnnl_memory::set_data_handle(void *handle, int index) const { using namespace dnnl::impl; void *old_handle; - CHECK(memory_storage(index)->get_data_handle(&old_handle)); + auto *ms = memory_storage(index); + if (!ms) return status::invalid_arguments; + CHECK(ms->get_data_handle(&old_handle)); if (handle != old_handle) { CHECK(memory_storage(index)->set_data_handle(handle)); } @@ -154,13 +163,14 @@ status_t dnnl_memory_create(memory_t **memory, const memory_desc_t *md, auto _memory = new memory_t(engine, md, flags, handle_ptr); if (_memory == nullptr) return out_of_memory; if (_memory->memory_storage() == nullptr) { - delete _memory; + _memory->release(); return out_of_memory; } *memory = _memory; return success; } +#ifdef DNNL_EXPERIMENTAL_SPARSE status_t dnnl_memory_create_v2(memory_t **memory, const memory_desc_t *md, engine_t *engine, int nhandles, void **handles) { const bool args_ok = !any_null(memory, engine, handles) && nhandles > 0; @@ -169,8 +179,8 @@ status_t dnnl_memory_create_v2(memory_t **memory, const memory_desc_t *md, #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL if (engine->kind() == engine_kind::gpu) #endif - return dnnl_sycl_interop_memory_create( - memory, md, engine, dnnl_sycl_interop_usm, handles[0]); + return dnnl_sycl_interop_memory_create_v2( + memory, md, engine, dnnl_sycl_interop_usm, nhandles, handles); #endif memory_desc_t z_md = types::zero_md(); if (md == nullptr) md = &z_md; @@ -196,13 +206,14 @@ status_t dnnl_memory_create_v2(memory_t **memory, const memory_desc_t *md, if (_memory == nullptr) return out_of_memory; for (size_t i = 0; i < handles_vec.size(); i++) { if (_memory->memory_storage((int)i) == nullptr) { - delete _memory; + _memory->release(); return out_of_memory; } } *memory = _memory; return success; } +#endif status_t dnnl_memory_get_memory_desc( const memory_t *memory, const memory_desc_t **md) { @@ -288,8 +299,18 @@ status_t dnnl_memory_unmap_data(const memory_t *memory, void *mapped_ptr) { return dnnl_memory_unmap_data_v2(memory, mapped_ptr, 0); } +status_t dnnl_memory_unmap_data_sparse( + const_dnnl_memory_t memory, int index, void *mapped_ptr) { + bool args_ok = !any_null(memory); + if (!args_ok) 
return invalid_arguments; + + return memory->memory_storage()->unmap_data(mapped_ptr, nullptr); +} + status_t dnnl_memory_destroy(memory_t *memory) { - delete memory; + if (memory != nullptr) memory->release(); return success; } diff --git a/src/common/memory.hpp b/src/common/memory.hpp index 3d64ac1a028..5dd0de18248 100644 --- a/src/common/memory.hpp +++ b/src/common/memory.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2023 Intel Corporation +* Copyright 2018-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -55,7 +55,12 @@ struct dnnl_memory : public dnnl::impl::c_compatible { dnnl_memory(dnnl::impl::engine_t *engine, const dnnl::impl::memory_desc_t *md, std::unique_ptr<memory_storage_t> &&memory_storage); - virtual ~dnnl_memory() = default; +#ifdef DNNL_EXPERIMENTAL_SPARSE + dnnl_memory(dnnl::impl::engine_t *engine, + const dnnl::impl::memory_desc_t *md, + std::vector<std::unique_ptr<memory_storage_t>> + &&memory_storage); +#endif /** returns memory's engine */ dnnl::impl::engine_t *engine() const { return engine_; } @@ -77,7 +82,9 @@ struct dnnl_memory : public dnnl::impl::c_compatible { /** returns data handle */ dnnl::impl::status_t get_data_handle(void **handle, int index = 0) const { - return memory_storage(index)->get_data_handle(handle); + auto ms = memory_storage(index); + if (!ms) return dnnl::impl::status::invalid_arguments; + return ms->get_data_handle(handle); } /** sets data handle */ @@ -91,7 +98,15 @@ struct dnnl_memory : public dnnl::impl::c_compatible { size_t get_num_handles() const { return memory_storages_.size(); } + void retain() { counter_++; } + + void release() { + if (--counter_ == 0) { delete this; } + } + protected: + virtual ~dnnl_memory() = default; + dnnl::impl::engine_t *engine_; const dnnl::impl::memory_desc_t md_; @@ -101,6 +116,17 @@ struct dnnl_memory : public dnnl::impl::c_compatible { // Number of storages is larger than 1 only for sparse memory. std::vector<std::unique_ptr<memory_storage_t>> memory_storages_; + std::atomic<int> counter_; }; +namespace dnnl { +namespace impl { + +struct memory_deleter_t { + void operator()(memory_t *m) const { m->release(); } +}; + +} // namespace impl +} // namespace dnnl + #endif
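With dnnl_memory now reference-counted (counter_ starts at 1, the destructor is protected, and dnnl_memory_destroy calls release()), internal owners pair retain()/release() instead of delete. The memory_deleter_t added above lets standard smart pointers hold a reference; a minimal sketch under those assumptions:

    #include <memory>

    using memory_ptr = std::unique_ptr<dnnl::impl::memory_t,
            dnnl::impl::memory_deleter_t>;

    void hold_extra_reference(dnnl::impl::memory_t *m) {
        m->retain();        // counter_: 1 -> 2
        memory_ptr keep(m); // scope exit calls release(): 2 -> 1
    } // the original owner still controls the last reference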
diff --git a/src/common/memory_desc.cpp b/src/common/memory_desc.cpp index 15115eab9e5..fc22cd5dbc7 100644 --- a/src/common/memory_desc.cpp +++ b/src/common/memory_desc.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022-2024 Intel Corporation +* Copyright 2022-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -109,6 +109,7 @@ status_t memory_desc_init_by_strides(memory_desc_t &memory_desc, int ndims, return success; } +#if 0 status_t memory_desc_init_by_csr_encoding(memory_desc_t &memory_desc, int ndims, const dims_t dims, data_type_t data_type, dim_t nnz, data_type_t indices_dt, data_type_t pointers_dt) { @@ -139,6 +140,37 @@ status_t memory_desc_init_by_csr_encoding(memory_desc_t &memory_desc, int ndims, return success; } +#endif + +status_t memory_desc_init_by_coo_encoding(memory_desc_t &memory_desc, int ndims, + const dims_t dims, data_type_t data_type, dim_t nnz, + data_type_t indices_dt) { + if (ndims == 0) { + memory_desc = types::zero_md(); + return success; + } + + // Only ndims <= 2 is supported at this point. + VCHECK_MEMORY(ndims <= 2, unimplemented, VERBOSE_BAD_NDIMS, "", ndims); + + bool args_ok = memory_desc_sanity_check( + ndims, dims, data_type, format_kind::undef); + VCHECK_MEMORY(args_ok, invalid_arguments, VERBOSE_MEM_DESC_CHECK_FAIL); + + auto md = memory_desc_t(); + md.ndims = ndims; + array_copy(md.dims, dims, ndims); + md.data_type = data_type; + array_copy(md.padded_dims, dims, ndims); + md.format_kind = format_kind::sparse; + md.format_desc.sparse_desc.encoding = sparse_encoding::coo; + md.format_desc.sparse_desc.nnz = nnz; + md.format_desc.sparse_desc.metadata_types[0] = indices_dt; + + memory_desc = md; + + return success; +} status_t memory_desc_init_by_packed_encoding(memory_desc_t &memory_desc, int ndims, const dims_t dims, data_type_t data_type, dim_t nnz) { @@ -441,8 +473,9 @@ status_t memory_desc_permute_axes(memory_desc_t &out_memory_desc, VCHECK_MEMORY( !memory_desc_wrapper(in_memory_desc).has_runtime_dims_or_strides(), invalid_arguments, VERBOSE_UNSUPPORTED_MEM_STRIDE); - VCHECK_MEMORY(in_memory_desc.extra.flags == 0, invalid_arguments, - VERBOSE_UNSUPPORTED_MD_FLAG, "extra"); + VCHECK_MEMORY( + check_md_extra_flags_compensation_gpu(in_memory_desc.extra.flags), + invalid_arguments, VERBOSE_UNSUPPORTED_MD_FLAG, "extra"); // verify that perm is indeed a permutation of [0 .. ndims) unsigned occurrence_mask = 0; @@ -500,8 +533,10 @@ status_t memory_desc_init_by_string_tag(memory_desc_t &md, int ndims, pos--; int dim_idx = std::tolower(tag[pos0]) - 'a'; - VCHECK_MEMORY(dim_idx < ndims, invalid_arguments, VERBOSE_BAD_NDIMS, "", - ndims); + VCHECK_MEMORY(dim_idx < ndims, invalid_arguments, + "ndims deduced (%d) from the tag \'%s\' is inconsistent with " + "provided ndims (%d)", + dim_idx + 1, tag.c_str(), ndims); ndims_from_tag = std::max(dim_idx + 1, ndims_from_tag); int block_str_len = pos0 - pos - 1; bool is_blocked = block_str_len > 0; @@ -511,7 +546,9 @@ status_t memory_desc_init_by_string_tag(memory_desc_t &md, int ndims, dim_blocks.emplace_back(dim_idx, block); } VCHECK_MEMORY((ndims_from_tag == ndims), invalid_arguments, - VERBOSE_BAD_NDIMS, "", ndims); + "ndims deduced (%d) from the tag \'%s\' is inconsistent with " + "provided ndims (%d)", + ndims_from_tag, tag.c_str(), ndims); auto &blk = md.format_desc.blocking; @@ -579,6 +616,7 @@ status_t dnnl_memory_desc_create_with_strides(memory_desc_t **memory_desc, return success; } +#if 0 status_t dnnl_memory_desc_create_with_csr_encoding(memory_desc_t **memory_desc, int ndims, const dims_t dims, data_type_t data_type, dim_t nnz, data_type_t indices_dt, data_type_t pointers_dt) { @@ -591,6 +629,53 @@ status_t dnnl_memory_desc_create_with_csr_encoding(memory_desc_t **memory_desc, (*memory_desc) = md.release(); return success; } +#endif + +status_t dnnl_memory_desc_init_sparse(sparse_desc_t **sparse_desc, + sparse_encoding_t encoding) { + if (!sparse_desc) return invalid_arguments; + auto sd = utils::make_unique<sparse_desc_t>(); + + sd->encoding = encoding; + *sparse_desc = sd.release(); + + return success; +} + +status_t dnnl_memory_desc_create_sparse(memory_desc_t **memory_desc, + sparse_encoding_t encoding, int ndims, + const dims_t dims, data_type_t data_type) { + + sparse_desc_t *sd = nullptr; + CHECK(dnnl_memory_desc_init_sparse(&sd, encoding)); + + auto md = utils::make_unique<memory_desc_t>(); + if (!md) return out_of_memory; + + md->ndims = ndims; + array_copy(md->dims, dims, ndims); + md->data_type = data_type; + array_copy(md->padded_dims, dims, ndims); + md->format_kind = format_kind::sparse; + md->format_desc.sparse_desc = *sd; + delete sd; + + *memory_desc = md.release(); + + return success; +}
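Usage sketch for the dnnl_memory_desc_create_sparse entry point added above; the caller, dimensions, and error handling are illustrative assumptions, and dnnl_coo is the C-level encoding value this fork references elsewhere in the patch:

    void make_coo_md() {
        dnnl_memory_desc_t md = nullptr;
        dnnl_dim_t dims[2] = {128, 256};
        dnnl_status_t st = dnnl_memory_desc_create_sparse(
                &md, dnnl_coo, 2, dims, dnnl_f32);
        if (st == dnnl_success) dnnl_memory_desc_destroy(md);
    }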
+status_t dnnl_memory_desc_create_with_coo_encoding(memory_desc_t **memory_desc, + int ndims, const dims_t dims, data_type_t data_type, dim_t nnz, + data_type_t indices_dt) { + if (any_null(memory_desc)) return invalid_arguments; + + auto md = utils::make_unique<memory_desc_t>(); + if (!md) return out_of_memory; + CHECK(memory_desc_init_by_coo_encoding( + *md, ndims, dims, data_type, nnz, indices_dt)); + (*memory_desc) = md.release(); + return success; +} status_t dnnl_memory_desc_create_with_packed_encoding( memory_desc_t **memory_desc, int ndims, const dims_t dims, @@ -679,6 +764,7 @@ status_t dnnl_memory_desc_query( case query::format_kind: switch ((int)md->format_kind) { case format_kind::rnn_packed: + case format_kind::cublaslt_blocked: case format_kind::wino: *(format_kind_t *)result = format_kind::opaque; break; @@ -701,6 +787,10 @@ status_t dnnl_memory_desc_query( if (!is_blocked) return status::invalid_arguments; *(const dims_t **)result = &md->format_desc.blocking.inner_idxs; break; + case query::sparse_encoding: + if (md->format_kind != format_kind::sparse) return status::invalid_arguments; + *(const dnnl_sparse_encoding_t **)result = &md->format_desc.sparse_desc.encoding; + break; default: return status::unimplemented; } return status::success; @@ -728,12 +818,20 @@ status_t dnnl_memory_desc_query_v2( case query::data_type: *(data_type_t *)result = (index == 0) ? md->data_type - : md->format_desc.sparse_desc.metadata_types[index - 1]; + : md->format_desc.sparse_desc.metadata_types + [md->format_desc.sparse_desc.encoding + == sparse_encoding_t:: + dnnl_coo + ? 0 + : index - 1]; break; case query::num_handles_s32: if (is_sparse) { switch (md->format_desc.sparse_desc.encoding) { case sparse_encoding::csr: + case sparse_encoding::coo: + *(int *)result = md->ndims + 1; + break; case sparse_encoding::packed: *(int *)result = 3; break; default: assert(!"unknown encoding"); *(int *)result = 0; } diff --git a/src/common/memory_desc.hpp b/src/common/memory_desc.hpp index b8045a2b144..00efd877a10 100644 --- a/src/common/memory_desc.hpp +++ b/src/common/memory_desc.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022-2023 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,8 @@ namespace dnnl { namespace impl { +enum class cublaslt_memory_format_t { col32_2r_4r4 }; + // Winograd-specific formats enum class wino_memory_format_t { // Undefined memory format, used for empty memory descriptors. @@ -54,7 +56,7 @@ const rnn_packed_memory_format_t ldio_p = rnn_packed_memory_format_t::ldio_p; // TODO: convert to 'enum class'. // Flags for memory special features enum memory_extra_flags_t { - dnnl_memory_extra_flag_none = 0x0U, + dnnl_memory_extra_flag_none = 0u, // Indicates the weights have an additional buffer, that depends on the // @p compensation_mask. 
// @@ -62,13 +64,22 @@ enum memory_extra_flags_t { // the additional buffer would consist of OC values: // O[oc : 0,OC] = // -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) } - dnnl_memory_extra_flag_compensation_conv_s8s8 = 0x1U, - dnnl_memory_extra_flag_scale_adjust = 0x2U, - dnnl_memory_extra_flag_rnn_u8s8_compensation = 0x4U, + dnnl_memory_extra_flag_compensation_conv_s8s8 = 1u, + dnnl_memory_extra_flag_scale_adjust = 2u, + dnnl_memory_extra_flag_rnn_u8s8_compensation = 4u, dnnl_memory_extra_flag_gpu_rnn_u8s8_compensation = dnnl_memory_extra_flag_rnn_u8s8_compensation, - dnnl_memory_extra_flag_compensation_conv_asymmetric_src = 0x8U, - dnnl_memory_extra_flag_rnn_s8s8_compensation = 0x16U, + dnnl_memory_extra_flag_compensation_conv_asymmetric_src = 8u, + dnnl_memory_extra_flag_rnn_s8s8_compensation = 16u, + // This flag has to be kept separate from *compensation_conv_asymmetric_src + // since the GPU precompute algorithm is incompatible with that of the CPU + dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src = 32u, + // This flag depends on *compensation_gpu_conv_asymmetric_src and is used + // when precompute is to be performed for a backward-by-data convolution + dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src_bwd = 64u, + // This flag depends on *compensation_gpu_conv_asymmetric_src and is used + // when IC and OC are swapped to reinterpret a deconv as a BWD_D conv + dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src_swap = 128u, }; // Create aliases for extra flags to preserve the old behavior. @@ -85,8 +96,23 @@ const memory_extra_flags_t rnn_s8s8_compensation = dnnl_memory_extra_flag_rnn_s8s8_compensation; const memory_extra_flags_t compensation_conv_asymmetric_src = dnnl_memory_extra_flag_compensation_conv_asymmetric_src; +const memory_extra_flags_t compensation_gpu_conv_asymmetric_src + = dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src; +const memory_extra_flags_t compensation_gpu_conv_asymmetric_src_bwd + = dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src_bwd; +const memory_extra_flags_t compensation_gpu_conv_asymmetric_src_swap + = dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src_swap; } // namespace memory_extra_flags +inline bool check_md_extra_flags_compensation_gpu(uint64_t flags) { + using namespace memory_extra_flags; + const uint64_t c = compensation_gpu_conv_asymmetric_src; + const uint64_t b = compensation_gpu_conv_asymmetric_src_bwd; + const uint64_t s = compensation_gpu_conv_asymmetric_src_swap; + return (flags == none) || (flags == c) || (flags == (c | b)) + || (flags == (c | b | s)); +} + // Generic description of blocked data layout for most memory formats. struct blocking_desc_t { // The strides between the outermost blocks. @@ -135,6 +161,12 @@ struct rnn_packed_desc_t { size_t size; }; +struct cublaslt_blocked_desc_t { + cublaslt_memory_format_t cublaslt_format; + size_t size; +}; + +#if 0 struct sparse_desc_t { static constexpr int max_metadata_types = 2; // Each encoding defines the number of handles it requires and their @@ -193,6 +225,21 @@ struct sparse_desc_t { // - Use the block number to find an offset in the packed data // - Use the bitmask to unpack the packed data blocking_desc_t packed_desc; +} +#endif + +struct sparse_desc_t { + static constexpr int max_metadata_types = 2; + /// Specifies what encoding is used. + dnnl_sparse_encoding_t encoding; + /// Descriptor for blocked bitmask - opaque. + blocking_desc_t packed_desc; + // Number of non-zero entries. 
+ dnnl_dim_t nnz; + // Metadata types. Each encoding defines how to interpret these. + // - CSR: 0th - index data type + // 1st - pointer data type + dnnl_data_type_t metadata_types[max_metadata_types]; }; // Description of extra information stored in memory @@ -201,7 +248,12 @@ struct memory_extra_desc_t { : flags(0) , compensation_mask(0) , scale_adjust(0.0f) - , asymm_compensation_mask(0) {} + , asymm_compensation_mask(0) + , idhw {0, 0, 0} + , odhw {0, 0, 0} + , pdhw {0, 0, 0} + , ddhw {0, 0, 0} + , dst_size(0) {} // The flags contain arbitrary extra information, such as compensation. // @sa dnnl_memory_extra_flags_t uint64_t flags; @@ -211,6 +263,16 @@ struct memory_extra_desc_t { float scale_adjust; // Compensation mask for asymmetric quantization int asymm_compensation_mask; + // Precomp GPU ZP convolution input spatials + dim_t idhw[3]; + // Precomp GPU ZP convolution output spatials + dim_t odhw[3]; + // Precomp GPU ZP convolution padding spatials + dim_t pdhw[3]; + // Precomp GPU ZP convolution dilation spatials + dim_t ddhw[3]; + // Precomp GPU ZP convolution destination size + dim_t dst_size; }; status_t DNNL_API memory_desc_init_by_tag(memory_desc_t &memory_desc, int ndims, @@ -245,8 +307,7 @@ struct dnnl_memory_desc : public dnnl::impl::c_compatible { , padded_offsets {} , offset0(0) , format_kind(dnnl::impl::format_kind::undef) - , format_desc {} - , extra {} {} + , format_desc {} {} // Number of dimensions int ndims; // Dimensions in the following order: @@ -289,6 +350,8 @@ struct dnnl_memory_desc : public dnnl::impl::c_compatible { dnnl::impl::wino_desc_t wino_desc; // Tensor of packed weights for RNN. dnnl::impl::rnn_packed_desc_t rnn_packed_desc; + // Description of the data layout for memory formats used in cublasLt IMMA kernels. + dnnl::impl::cublaslt_blocked_desc_t cublaslt_blocked_desc; // Description of the sparse encodings. dnnl::impl::sparse_desc_t sparse_desc; // ... other descriptions possible diff --git a/src/common/memory_desc_wrapper.cpp b/src/common/memory_desc_wrapper.cpp index 4d6cb0b92cc..3636d051b18 100644 --- a/src/common/memory_desc_wrapper.cpp +++ b/src/common/memory_desc_wrapper.cpp @@ -1,5 +1,6 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation +* Copyright 2024-2025 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -26,11 +27,10 @@ namespace dnnl { namespace impl { -status_t fill_blocked(memory_desc_t &md, std::initializer_list<int> perm, - std::initializer_list<int> inner_blks, - std::initializer_list<int> inner_idxs) { +template <typename T> +static status_t fill_blocked_impl(memory_desc_t &md, T&& perm, T&& inner_blks, T&& inner_idxs) { const bool ok = true && perm.size() == (size_t)md.ndims - && inner_blks.size() == inner_idxs.size(); + && inner_blks.size() == inner_idxs.size(); if (!ok) return status::invalid_arguments; md.offset0 = 0; @@ -81,6 +81,27 @@ status_t fill_blocked(memory_desc_t &md, std::initializer_list<int> perm, return status::success; } +status_t fill_blocked(memory_desc_t &md, + std::initializer_list<int> perm, + std::initializer_list<int> inner_blks, + std::initializer_list<int> inner_idxs) { + return fill_blocked_impl(md, perm, inner_blks, inner_idxs); +} + +status_t fill_blocked(memory_desc_t &md, + std::vector<int>& perm, + std::vector<int>& inner_blks, + std::vector<int>& inner_idxs) { + return fill_blocked_impl(md, perm, inner_blks, inner_idxs); +} + +status_t fill_blocked(memory_desc_t &md, + std::vector<dim_t>& perm, + std::vector<dim_t>& inner_blks, + std::vector<dim_t>& inner_idxs) { + return fill_blocked_impl(md, perm, inner_blks, inner_idxs); +} + void memory_desc_wrapper::compute_strides_compat(dims_t *strides_compat) const { if (ndims() == 0) return; @@ -125,15 +146,15 @@ void memory_desc_wrapper::compute_strides_compat(dims_t *strides_compat) const { utils::array_copy(strides_compat[1], inner_strides, ndims()); } -status_t memory_desc_wrapper::compute_blocking( - memory_desc_t &memory_desc, format_tag_t tag) { +template <typename F, typename... Args> +status_t process_tag(F f, format_tag_t tag, Args&&... args) { using namespace format_tag; - VCHECK_MEMORY((memory_desc.ndims != 0), status::invalid_arguments, - VERBOSE_BAD_NDIMS, "", 0); + // VCHECK_MEMORY((memory_desc.ndims != 0), status::invalid_arguments, + // VERBOSE_BAD_NDIMS, "", 0); #define C(tag, ... /* perm, inner_blks, inner_idxs */) \ - case tag: return fill_blocked(memory_desc, __VA_ARGS__) + case tag: return f(std::forward<Args>(args)..., __VA_ARGS__)
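Because fill_blocked_impl deduces a single container type T for all three arguments, the overloads above can share one body; each call site just has to pass one container type consistently. A sketch of both call forms (assuming the element types reconstructed here, int and dim_t, match the fork's intent):

    #include <vector>

    void fill_two_descs(dnnl::impl::memory_desc_t &md_a,
            dnnl::impl::memory_desc_t &md_b) {
        using dnnl::impl::fill_blocked;
        // initializer_list<int> path, e.g. tag Ab8a:
        fill_blocked(md_a, {0, 1}, {8}, {0});
        // lvalue std::vector path:
        std::vector<int> perm {0, 1}, blks {8}, idxs {0};
        fill_blocked(md_b, perm, blks, idxs);
    }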
 switch (tag) { C(a, {0}, {}, {}); @@ -192,6 +213,7 @@ status_t memory_desc_wrapper::compute_blocking( C(Ab4a, {0, 1}, {4}, {0}); C(Ab8a, {0, 1}, {8}, {0}); + C(Ab32a, {0, 1}, {32}, {0}); C(BA4b4a, {1, 0}, {4, 4}, {1, 0}); C(BA8b4a, {1, 0}, {8, 4}, {1, 0}); @@ -200,6 +222,9 @@ status_t memory_desc_wrapper::compute_blocking( C(BA16a32b, {1, 0}, {16, 32}, {0, 1}); C(BA16a48b, {1, 0}, {16, 48}, {0, 1}); C(BA16a64b, {1, 0}, {16, 64}, {0, 1}); + C(BA24b8a, {1, 0}, {24, 8}, {1, 0}); + C(aCB24c8b, {0, 2, 1}, {24, 8}, {2, 1}); + C(abDC24d8c, {0, 1, 3, 2}, {24, 8}, {3, 2}); C(BA16a16b2a, {1, 0}, {16, 16, 2}, {0, 1, 0}); C(BA16a32b2a, {1, 0}, {16, 32, 2}, {0, 1, 0}); C(BA16a48b2a, {1, 0}, {16, 48, 2}, {0, 1, 0}); @@ -396,6 +421,7 @@ status_t memory_desc_wrapper::compute_blocking( C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0}); C(BAcd8a16b2a, {1, 0, 2, 3}, {8, 16, 2}, {0, 1, 0}); C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1}); + C(ABcd8a32b, {0, 1, 2, 3}, {8, 32}, {0, 1}); C(ABcd8a4b, {0, 1, 2, 3}, {8, 4}, {0, 1}); C(ABcd8a2b, {0, 1, 2, 3}, {8, 2}, {0, 1}); C(aBcd8b, {0, 1, 2, 3}, {8}, {1}); @@ -515,7 +541,9 @@ status_t memory_desc_wrapper::compute_blocking( C(Acb8a, {0, 2, 1}, {8}, {0}); C(AcB8a2b, {0, 2, 1}, {8, 2}, {0, 1}); C(AcB8a4b, {0, 2, 1}, {8, 4}, {0, 1}); + C(aCBd8b8c, {0, 2, 1, 3}, {8, 8}, {1, 2}); C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2}); + C(aCBde8b8c, {0, 2, 1, 3, 4}, {8, 8}, {1, 2}); C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2}); C(Acdb16a, {0, 2, 3, 1}, {16}, {0}); C(AcdB16a2b, {0, 2, 3, 1}, {16, 2}, {0, 1}); @@ -531,7 +559,9 @@ status_t memory_desc_wrapper::compute_blocking( C(AcdeB8a4b, {0, 2, 3, 4, 1}, {8, 4}, {0, 1}); C(Acedb16a, {0, 2, 4, 3, 1}, {16}, {0}); C(Adcb16a, {0, 3, 2, 1}, {16}, {0}); + C(BAc8a8b, {1, 0, 2}, {8, 8}, {0, 1}); C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1}); + C(BAcd8a8b, {1, 0, 2, 3}, {8, 8}, {0, 1}); C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1}); C(ABc32a16b, {0, 1, 2}, {32, 16}, {0, 1}); C(ABcd32a16b, {0, 1, 2, 3}, {32, 16}, {0, 1}); @@ -584,13 +614,17 @@ status_t memory_desc_wrapper::compute_blocking( C(aBCde2b8c8b2c, {0, 1, 2, 3, 4}, {2, 8, 8, 2}, {1, 2, 1, 2}); C(aBdec32b, {0, 1, 3, 4, 2}, {32}, {1}); C(aCBdef16c16b, {0, 2, 1, 3, 4, 5}, {16, 16}, {2, 1}); + C(aCBdef8b8c, {0, 2, 1, 3, 4, 5}, {8, 8}, {1, 2}); C(aCBdef16b16c, {0, 2, 1, 3, 4, 5}, {16, 16}, {1, 2}); + C(Abcdef4a, {0, 1, 2, 3, 4, 5}, {4}, {0}); + C(Abcdef8a, {0, 1, 2, 3, 4, 5}, {8}, {0}); C(Abcdef16a, {0, 1, 2, 3, 4, 5}, {16}, {0}); C(Abcdef32a, {0, 1, 2, 3, 4, 5}, {32}, {0}); C(aCBd16c16b, {0, 2, 1, 3}, {16, 16}, {2, 1}); C(aCBde16c16b, {0, 2, 1, 3, 4}, {16, 16}, {2, 1}); C(Acdb32a, {0, 2, 3, 1}, {32}, {0}); C(BAcd16b16a, {1, 0, 2, 3}, {16, 16}, {1, 0}); + C(BAcde8a8b, {1, 0, 2, 3, 4}, {8, 8}, {0, 1}); C(BAcde16a16b, {1, 0, 2, 3, 4}, {16, 16}, {0, 1}); C(BAc16b16a, {1, 0, 2}, {16, 16}, {1, 0}); C(aBCd2b4c2b, {0, 1, 2, 3}, {2, 4, 2}, {1, 2, 1}); @@ -611,6 +645,7 @@ status_t memory_desc_wrapper::compute_blocking( C(AB8a2b, {0, 1}, {8, 2}, {0, 1}); C(abDc16d, {0, 1, 3, 2}, {16}, {3}); C(abDc32d, {0, 1, 3, 2}, {32}, {3}); + C(abDC16d4c, {0, 1, 3, 2}, {16, 4}, {3, 2}); C(abDC32d4c, {0, 1, 3, 2}, {32, 4}, {3, 2}); C(abCd4c, {0, 1, 2, 3}, {4}, {2}); C(abCde4c, {0, 1, 2, 3, 4}, {4}, {2}); @@ -620,6 +655,7 @@ status_t memory_desc_wrapper::compute_blocking( C(abCdef32c, {0, 1, 2, 3, 4, 5}, {32}, {2}); C(abdEc16e, {0, 1, 3, 4, 2}, {16}, {4}); 
C(abdEc32e, {0, 1, 3, 4, 2}, {32}, {4}); + C(abdEC16e4c, {0, 1, 3, 4, 2}, {16, 4}, {4, 2}); C(abdEC32e2c, {0, 1, 3, 4, 2}, {32, 2}, {4, 2}); C(abdEC32e4c, {0, 1, 3, 4, 2}, {32, 4}, {4, 2}); C(abdEC64e2c, {0, 1, 3, 4, 2}, {64, 2}, {4, 2}); @@ -991,6 +1027,28 @@ status_t memory_desc_wrapper::compute_blocking( return status::invalid_arguments; } +status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc, format_tag_t tag) { + using fill_blocked_t = status_t(memory_desc_t&, std::initializer_list<int>, std::initializer_list<int>, std::initializer_list<int>); + if (memory_desc.ndims == 0) return status::invalid_arguments; + return process_tag<fill_blocked_t>(fill_blocked, tag, memory_desc); +} + +status_t memory_desc_wrapper::compute_blocking(format_tag_t tag, + std::vector<int> &perm, + std::vector<int> &inner_blks, + std::vector<int> &inner_idxs) { + + auto extract_data = [&](std::initializer_list<int> _perm, + std::initializer_list<int> _inner_blks, + std::initializer_list<int> _inner_idxs) -> status_t { + perm = {_perm.begin(), _perm.end()}; + inner_blks = {_inner_blks.begin(), _inner_blks.end()}; + inner_idxs = {_inner_idxs.begin(), _inner_idxs.end()}; + return status::success; + }; + return process_tag(extract_data, tag); +} + } // namespace impl } // namespace dnnl diff --git a/src/common/memory_desc_wrapper.hpp b/src/common/memory_desc_wrapper.hpp index 0d85e63437d..6cfad59ad81 100644 --- a/src/common/memory_desc_wrapper.hpp +++ b/src/common/memory_desc_wrapper.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,8 +32,20 @@ namespace dnnl { namespace impl { +status_t fill_blocked(memory_desc_t &md, + std::vector<int> &perm, + std::vector<int> &inner_blks, + std::vector<int> &inner_idxs); + +status_t fill_blocked(memory_desc_t &md, + std::vector<dim_t> &perm, + std::vector<dim_t> &inner_blks, + std::vector<dim_t> &inner_idxs); + + 
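The second compute_blocking overload above reuses process_tag with an extracting lambda, so a tag's layout recipe can be inspected without touching a memory descriptor. A hypothetical caller:

    #include <vector>

    void inspect_tag(dnnl::impl::format_tag_t tag) {
        using namespace dnnl::impl;
        std::vector<int> perm, inner_blks, inner_idxs;
        if (memory_desc_wrapper::compute_blocking(
                    tag, perm, inner_blks, inner_idxs)
                != status::success)
            return;
        // For format_tag::ABcd8a8b this yields perm = {0, 1, 2, 3},
        // inner_blks = {8, 8}, inner_idxs = {0, 1} (see the C(...) table).
    }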
/** thin wrapper class over \struct memory_desc_t which allows easy * manipulations with underlying C structure, which is taken by reference */ +// NOLINTNEXTLINE(readability-identifier-naming) struct memory_desc_wrapper : public c_compatible { const memory_desc_t *md_; @@ -67,12 +79,20 @@ struct memory_desc_wrapper : public c_compatible { bool is_rnn_packed_desc() const { return format_kind() == format_kind::rnn_packed; } + bool is_cublaslt_blocked_desc() const { + return format_kind() == format_kind::cublaslt_blocked; + } bool is_sparse_desc() const { return format_kind() == format_kind::sparse; } + bool is_blocking_or_sparse_packed_desc() const { + return is_blocking_desc() + || (is_sparse_desc() + && sparse_desc().encoding == sparse_encoding::packed); + } + const blocking_desc_t &blocking_desc() const { - assert(is_blocking_desc() || is_sparse_packed_desc()); - if (!is_sparse_desc()) return md_->format_desc.blocking; - return sparse_desc().packed_desc; + assert(is_blocking_desc()); + return md_->format_desc.blocking; } const wino_desc_t &wino_desc() const { assert(is_wino_desc()); @@ -82,6 +102,10 @@ struct memory_desc_wrapper : public c_compatible { assert(is_rnn_packed_desc()); return md_->format_desc.rnn_packed_desc; } + const cublaslt_blocked_desc_t &cublaslt_blocked_desc() const { + assert(is_cublaslt_blocked_desc()); + return md_->format_desc.cublaslt_blocked_desc; + } const sparse_desc_t &sparse_desc() const { assert(is_sparse_desc()); @@ -93,20 +117,18 @@ struct memory_desc_wrapper : public c_compatible { return sparse_desc().metadata_types[idx]; } - sparse_encoding_t encoding() const { - assert(is_sparse_desc()); - return sparse_desc().encoding; - } - dim_t nnz() const { assert(is_sparse_desc()); return sparse_desc().nnz; } - const dims_t &strides() const { return blocking_desc().strides; } - const memory_extra_desc_t &extra() const { return md_->extra; } + sparse_encoding_t encoding() const { + assert(is_sparse_desc()); + return sparse_desc().encoding; + } + /* some useful function */ /** returns the number of elements including padding if \param with_padding @@ -142,30 +164,28 @@ struct memory_desc_wrapper : public c_compatible { size_t additional_buffer_data_size(uint64_t flag_select) const { using namespace memory_extra_flags; if (flag_select & compensation_conv_s8s8) return sizeof(int32_t); - if ((flag_select & rnn_u8s8_compensation) - && !types::extra_flag_rnn_s8s8_compensation_is_set(flag_select)) - return sizeof(float); + if (flag_select & rnn_u8s8_compensation) return sizeof(float); if (flag_select & compensation_conv_asymmetric_src) return sizeof(int32_t); + if (flag_select & compensation_gpu_conv_asymmetric_src) + return sizeof(int32_t); return 0; } /** return true if memory format has additional buffer */ bool is_additional_buffer() const { using namespace memory_extra_flags; - // Currently compensation is not required for rnn_s8s8_compensation, - // but it has common bit with rnn_u8s8_compensation constant so we have - // to exclude rnn_s8s8_compensation case explicitly - return ((extra().flags - & (compensation_conv_s8s8 | rnn_u8s8_compensation - | compensation_conv_asymmetric_src)) - && !types::extra_flag_rnn_s8s8_compensation_is_set( - extra().flags)); + return extra().flags + & (compensation_conv_s8s8 | rnn_u8s8_compensation + | compensation_gpu_conv_asymmetric_src + | compensation_conv_asymmetric_src); } /** returns the size required for a particular extra memory buffer */ size_t additional_buffer_size(memory_extra_flags_t flag) const { using namespace memory_extra_flags; + const auto flags = extra().flags; + if (!(flags & flag)) return 0; const auto ndims = this->ndims(); const auto &pdims = padded_dims(); @@ -179,26 +199,26 @@ struct memory_desc_wrapper : public c_compatible { return (size_t)prod * buff_data_size; }; - if (extra().flags & compensation_conv_s8s8) { + if (flag == compensation_conv_s8s8) { return calculate_size(extra().compensation_mask, additional_buffer_data_size(flag)); } - - if ((extra().flags & rnn_u8s8_compensation) - && !types::extra_flag_rnn_s8s8_compensation_is_set( - extra().flags)) { + if (flag == rnn_u8s8_compensation) { return calculate_size(extra().compensation_mask, additional_buffer_data_size(flag)); } - if (extra().flags & compensation_conv_asymmetric_src) { + if (flag == compensation_conv_asymmetric_src) { return calculate_size(extra().asymm_compensation_mask, additional_buffer_data_size(flag)); } + if (flag == compensation_gpu_conv_asymmetric_src) { + return extra().dst_size; + } return 0; } - int blk_size() const { + dim_t blk_size() const { assert(is_blocking_desc() || is_sparse_packed_desc()); const auto &bd = blocking_desc(); return utils::array_product(bd.inner_blks, bd.inner_nblks); @@ -213,18 +233,22 @@ struct memory_desc_wrapper : public c_compatible { buff_size += additional_buffer_size(compensation_conv_s8s8); buff_size += additional_buffer_size(rnn_u8s8_compensation); buff_size += additional_buffer_size(compensation_conv_asymmetric_src); + buff_size + += 
additional_buffer_size(compensation_gpu_conv_asymmetric_src); return buff_size; } - /** returns the size required to store described memory - * note: if offset0 != 0 returns 0 (need to specify the behavior) */ - size_t size(int index = 0, bool include_additional_size = true) const { + /** returns the size required to store described memory. Note: does not + include offset0 by default */ + size_t size(int index = 0, bool include_additional_size = true, + bool include_offset0 = false) const { if (utils::one_of(format_kind(), format_kind::undef, format_kind::any) || is_zero() || has_zero_dim()) return 0; if (utils::one_of(format_kind(), format_kind::blocked, - format_kind::wino, format_kind::rnn_packed) + format_kind::wino, format_kind::rnn_packed, + format_kind::cublaslt_blocked) && index != 0) { return 0; } @@ -235,9 +259,28 @@ struct memory_desc_wrapper : public c_compatible { return wino_desc().size; } else if (is_rnn_packed_desc()) { return rnn_packed_desc().size; + } else if (is_cublaslt_blocked_desc()) { + return cublaslt_blocked_desc().size; + } else if (is_sparse_desc()) { + if (sparse_desc().encoding == sparse_encoding::packed) { + // Only 2D tensors are supported at this point. + assert(ndims() == 2); + // Only OI16i64o4i is supported at this point. + // assert(matches_tag(format_tag::OI16i64o4i)); - TODO: enable for sparse packed. + const size_t metadata = padded_dims()[0] * padded_dims()[1] / 64 + * sizeof(uint64_t); + using comp_tile_len_type = int; + size_t comp_tile_data_size = ceil(static_cast<double>(padded_dims()[0] * padded_dims()[1]) + / (64 * 64 * (64 / sizeof(comp_tile_len_type)))) * 64; + return comp_tile_data_size + (padded_dims()[0] * padded_dims()[1] * data_type_size()) + + metadata + 1000; + // todo: [av] why 1000? + } else { + printf("encoding:%d\n", (int)sparse_desc().encoding), fflush(stdout); + assert(!"unknown sparse encoding"); + return 0; + } } else if (is_blocking_desc()) { - if (offset0() != 0) return 0; - dims_t blocks = {0}; compute_blocks(blocks); @@ -252,11 +295,13 @@ struct memory_desc_wrapper : public c_compatible { } if (max_size == 1 && bd.inner_nblks != 0) { - max_size = utils::array_product(bd.inner_blks, bd.inner_nblks); + max_size = static_cast<dim_t>(blk_size()); } - size_t data_size = max_size * data_type_size() - / sub_byte_data_type_multiplier(); + // `div_up` guarantees a spot in memory for odd number of half-byte + // elements. Crucial case is `1` when simple division returns 0. + size_t data_size = utils::div_up(max_size * data_type_size(), + sub_byte_data_type_multiplier()); if (is_additional_buffer()) { // The additional buffers, typically of data type int32_t, float // are stored at the end of data. Pad the data, so that the @@ -265,7 +310,9 @@ struct memory_desc_wrapper : public c_compatible { data_size = utils::rnd_up(data_size, alignment_in_bytes); } return data_size - + (include_additional_size ? additional_buffer_size() : 0); + + (include_additional_size ? additional_buffer_size() : 0) + + (include_offset0 ? data_type_size() * offset0() : 0); +#if 0 } else if (is_sparse_desc()) { if (sparse_desc().encoding == sparse_encoding::csr) { switch (index) { @@ -283,6 +330,18 @@ struct memory_desc_wrapper : public c_compatible { } default: assert(!"unknown index"); return 0; } + } else if (sparse_desc().encoding == sparse_encoding::coo) { + // Return size for values. + if (index == 0) { + return nnz() * data_type_size(); + } else if (index > 0 && index <= ndims()) { + // Return size for index buffers. + const auto idx_dt = metadata_type(0); + return nnz() * types::data_type_size(idx_dt); + } else { + assert(!"unknown index"); + return 0; + } } else if (sparse_desc().encoding == sparse_encoding::packed) { // If the size is queried from a user-created memory descriptor. if (blocking_desc().strides[0] == 0) return 0; @@ -305,6 +364,7 @@ struct memory_desc_wrapper : public c_compatible { assert(!"unknown sparse encoding"); return 0; } +#endif } else { assert(!"unknown format kind"); return 0;
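Worked example for the div_up change in size() above, using s4/u4-like parameters (data_type_size() == 1, sub_byte_data_type_multiplier() == 2): plain division drops the trailing nibble for odd element counts and returns 0 for a single element, while rounding up does not.

    #include <cstdio>

    int main() {
        const unsigned long mult = 2; // two 4-bit elements per byte
        for (unsigned long nelems : {1ul, 3ul, 4ul}) {
            unsigned long plain = nelems / mult;                // 0, 1, 2
            unsigned long rounded = (nelems + mult - 1) / mult; // 1, 2, 2
            printf("nelems=%lu plain=%lu div_up=%lu\n", nelems, plain, rounded);
        }
        return 0;
    }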
@@ -333,8 +393,7 @@ struct memory_desc_wrapper : public c_compatible {
         if (utils::one_of(format_kind(), format_kind::undef, format_kind::any))
             return false;
         if (has_runtime_dims_or_strides() || has_broadcast()) return false;
-        return nelems(with_padding) * data_type_size()
-                        / sub_byte_data_type_multiplier()
+        return utils::div_up(nelems(with_padding) * data_type_size(),
+                       sub_byte_data_type_multiplier())
                 == size(0, /* include_additional_size = */ false);
     }
 
@@ -418,15 +477,16 @@ struct memory_desc_wrapper : public c_compatible {
      * following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
     /* TODO: revise */
     bool similar_to(const memory_desc_wrapper &rhs, bool with_padding = true,
-            bool with_data_type = true, int dim_start = 0) const;
+            bool with_data_type = true, int dim_start = 0,
+            bool use_weak_cmp = false, bool check_off0 = false,
+            uint64_t stride_mask = 0xffffffffffffffff) const;
 
     /** returns true if one memory can be reordered to another */
     bool consistent_with(const memory_desc_wrapper &rhs) const;
 
     /** returns true if the memory desc corresponds to the given format tag.
      * @sa memory_desc_matches_tag */
-    bool matches_tag(format_tag_t tag) const {
-        return memory_desc_matches_tag(*md_, tag);
+    bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const {
+        return memory_desc_matches_tag(*md_, tag, strides);
     }
 
     /** returns matching tag (or undef if match is not found)
@@ -439,14 +499,30 @@ struct memory_desc_wrapper : public c_compatible {
         return format_tag::undef;
     }
 
+    template <typename... Tags>
+    format_tag_t mb_stride_relaxed_match(Tags... tags) const {
+        dims_t skip_mb_stride {};
+        // See `memory_desc_matches_tag` comment.
+        skip_mb_stride[0] = -1;
+        for (const auto &tag : {tags...})
+            if (matches_tag(tag, skip_mb_stride)) return tag;
+        return format_tag::undef;
+    }
+
     /* offset section */
 
     /** returns physical offset by logical one. logical offset is represented by
      * an array \param pos. if \param is_pos_padded is true \param pos
      * represents the position in already padded area */
     dim_t off_v(const dims_t pos, bool is_pos_padded = false) const {
-        assert(is_blocking_desc() || is_sparse_packed_desc());
-        const blocking_desc_t &blk = blocking_desc();
+        assert(is_blocking_or_sparse_packed_desc());
+
+        const blocking_desc_t &blk = [&]() -> const blocking_desc_t & {
+            if (is_blocking_desc())
+                return blocking_desc();
+            else
+                return sparse_desc().packed_desc;
+        }();
 
         dims_t pos_copy = {0};
         for (int d = 0; d < ndims(); ++d)
@@ -520,7 +596,11 @@ struct memory_desc_wrapper : public c_compatible {
      * user responsibility to adjust the result to get offset within blocks */
     template <typename... Args>
     dim_t blk_off(Args... args) const {
-        return _blk_off<sizeof...(args), Args...>(args...);
+        assert(is_blocking_or_sparse_packed_desc());
+        if (is_blocking_desc()) {
+            return _blk_off<sizeof...(args), Args...>(args...);
+        }
+        return _blk_off_sparse<sizeof...(args), Args...>(args...);
     }
 
     template <bool skip_first, typename T, typename... Args>
     dim_t blk_off(T xn, Args... args) const {
@@ -529,12 +609,27 @@ struct memory_desc_wrapper : public c_compatible {
                 : blk_off<T, Args...>(xn, args...);
     }
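For intuition about the blocked-offset arithmetic that `off_v`/`_blk_off` implement, here is a self-contained sketch with hard-coded strides for an nChw16c-like layout (the shapes and strides are illustrative only, not taken from the library):

#include <cassert>
#include <cstdint>

// Physical offset of (n, c, h, w) for N=2, C=32, H=W=4 in an nChw16c-like
// layout: outer dims are strided, the channel remainder indexes the inner
// 16-wide block.
int64_t off_nChw16c(int64_t n, int64_t c, int64_t h, int64_t w) {
    const int64_t blk = 16;
    const int64_t str_w = blk;                 // 16
    const int64_t str_h = 4 * str_w;           // 64
    const int64_t str_C = 4 * str_h;           // 256, stride of a channel block
    const int64_t str_n = (32 / blk) * str_C;  // 512
    return n * str_n + (c / blk) * str_C + h * str_h + w * str_w + c % blk;
}

int main() {
    assert(off_nChw16c(0, 0, 0, 0) == 0);
    assert(off_nChw16c(0, 17, 0, 0) == 256 + 1); // second block, lane 1
    return 0;
}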
+    /** returns physical offset by logical one. Logical offset is represented
+     * by a tuple of block indices (\param bn, ..., \param b1, \param b0). It
+     * is a user responsibility to adjust the result to get offset within
+     * blocks. If @tparam sub_off0 is true, then offset0() will be subtracted
+     * from the result. */
+    template <bool sub_off0, typename T, typename... Args>
+    dim_t blk_off(T xn, Args... args) const {
+        return blk_off(xn, args...) - sub_off0 * offset0();
+    }
+
     /* static functions section */
     /* TODO: replace with non-static, once md_ becomes non-const ref */
 
     static status_t compute_blocking(
             memory_desc_t &memory_desc, format_tag_t tag);
 
+    static status_t compute_blocking(format_tag_t tag,
+            std::vector<int> &perm,
+            std::vector<int> &inner_blks,
+            std::vector<int> &inner_idxs);
+
 private:
     /* TODO: put logical_offset in utils */
     template <typename T>
@@ -554,41 +649,71 @@ struct memory_desc_wrapper : public c_compatible {
         return offset0();
     }
 
+    template <int ORIG_LEN>
+    dim_t _blk_off_sparse() const {
+        return offset0();
+    }
+
     template <int ORIG_LEN, typename T, typename... Args>
     dim_t _blk_off(T xc, Args... args) const {
-        assert(is_blocking_desc() || is_sparse_packed_desc());
+        assert(is_blocking_desc());
+        constexpr int dc = ORIG_LEN - sizeof...(args) - 1;
+        return xc * blocking_desc().strides[dc]
+                + _blk_off<ORIG_LEN, Args...>(args...);
+    }
+
+    template <int ORIG_LEN, typename T, typename... Args>
+    dim_t _blk_off_sparse(T xc, Args... args) const {
+        assert(is_sparse_desc());
         constexpr int dc = ORIG_LEN - sizeof...(args) - 1;
-        return xc * blocking_desc().strides[dc]
-                + _blk_off<ORIG_LEN, Args...>(args...);
+        return xc * sparse_desc().packed_desc.strides[dc]
+                + _blk_off_sparse<ORIG_LEN, Args...>(args...);
     }
 };
 
 inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
-        bool with_padding, bool with_data_type, int dim_start) const {
+        bool with_padding, bool with_data_type, int dim_start,
+        bool use_weak_cmp, bool check_off0, uint64_t stride_mask) const {
     using namespace utils;
 
     if (one_of(format_kind(), format_kind::undef, format_kind::any))
         return false;
-    if (is_wino_desc() || is_rnn_packed_desc()) return false;
+    if (is_wino_desc() || is_rnn_packed_desc() || is_cublaslt_blocked_desc())
+        return false;
 
     const int ds = dim_start;
     const auto &blk = blocking_desc();
     const auto &r_blk = rhs.blocking_desc();
 
+    auto custom_cpm = use_weak_cmp ? array_cmp_weak<dim_t> : array_cmp<dim_t>;
+    auto cmp_strides = [&]() {
+        if (0xffffffffffffffff == stride_mask) {
+            return custom_cpm(
+                    blk.strides + ds, r_blk.strides + ds, ndims() - ds);
+        } else {
+            for (int i = 0; i < ndims(); ++i) {
+                if (stride_mask & (1ULL << i)) {
+                    if (blk.strides[i] != r_blk.strides[i]
+                            && IMPLICATION(use_weak_cmp,
                                    (blk.strides[i] != DNNL_RUNTIME_DIM_VAL
                                            && r_blk.strides[i]
                                                    != DNNL_RUNTIME_DIM_VAL))) {
+                        return false;
+                    }
+                }
+            }
+        }
+        return true;
+    };
+
     return ndims() == rhs.ndims() && dim_start <= ndims() /* guard */
             && format_kind() == rhs.format_kind()
             && IMPLICATION(with_data_type, data_type() == rhs.data_type())
-            && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds)
-            && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds)
+            && custom_cpm(dims() + ds, rhs.dims() + ds, ndims() - ds)
+            && cmp_strides()
             && blk.inner_nblks == r_blk.inner_nblks
             && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
             && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)
             && IMPLICATION(with_padding,
                     true
-                            && array_cmp(padded_dims() + ds,
+                            && custom_cpm(padded_dims() + ds,
                                     rhs.padded_dims() + ds, ndims() - ds)
-                            && array_cmp(padded_offsets() + ds,
-                                    rhs.padded_offsets() + ds, ndims() - ds));
+                            && custom_cpm(padded_offsets() + ds,
+                                    rhs.padded_offsets() + ds, ndims() - ds))
+            && IMPLICATION(check_off0,
+                    (offset0() == DNNL_RUNTIME_DIM_VAL
+                            || rhs.offset0() == DNNL_RUNTIME_DIM_VAL
+                            || offset0() == rhs.offset0()));
 }
 
 inline bool memory_desc_wrapper::consistent_with(
diff --git a/src/common/memory_storage.hpp b/src/common/memory_storage.hpp
index 822cce0391f..747e53ead5c 100644
--- a/src/common/memory_storage.hpp
+++ b/src/common/memory_storage.hpp
@@ -75,6 +75,14 @@ struct memory_storage_t : public c_compatible {
     /** returns shallow copy */
     virtual std::unique_ptr<memory_storage_t> clone() const = 0;
 
+    /** returns a shallow copy with an offset applied to the accessor pointer
+     * for buffers, to prevent the use of sub-buffers where possible */
+    virtual std::unique_ptr<memory_storage_t> clone_ptr_off(
+            size_t offset) const {
+        assert(!"not expected");
+        return nullptr;
+    }
+
     /** returns true if the pointer associated with the storage is NULL */
     bool is_null() const {
         void *ptr;
diff --git a/src/common/memory_tracking.hpp b/src/common/memory_tracking.hpp
index 9f5ca8612f1..58813ff00aa 100644
--- a/src/common/memory_tracking.hpp
+++ b/src/common/memory_tracking.hpp
@@ -1,6 +1,6 @@
 /*******************************************************************************
-* Copyright 2018-2024 Intel Corporation
-* Copyright 2024 Arm Ltd. and affiliates
+* Copyright 2018-2025 Intel Corporation
+* Copyright 2024-2025 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
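Returning to the extended `similar_to()` above: `stride_mask` selects which dimensions' strides participate in the comparison, one bit per dimension. A hedged sketch of the intended call pattern (the descriptors and mask choice are placeholders):

#include <cstdint>

// Bit i set => compare the stride of dimension i. Clearing bit 0 makes the
// comparison ignore the minibatch stride, mirroring what
// mb_stride_relaxed_match() does for format tags.
const uint64_t all_dims = 0xffffffffffffffffULL;
const uint64_t skip_mb = all_dims & ~uint64_t(1);
// lhs.similar_to(rhs, /*with_padding=*/true, /*with_data_type=*/true,
//         /*dim_start=*/0, /*use_weak_cmp=*/false, /*check_off0=*/false,
//         skip_mb);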
@@ -166,11 +166,14 @@ enum {
     key_brgemm_primitive_buffer_d,
     key_brgemm_primitive_zp_comp_a,
     key_brgemm_primitive_zp_comp_b,
+    key_brgemm_primitive_buffer_reduce,
+    key_brgemm_primitive_decomp_buf,
     key_concat_iptrs,
     key_concat_istrides,
     key_concat_nelems,
     key_concat_optrs,
     key_concat_tent_dst,
+    key_conv_pack_space,
     key_conv_adjusted_scales,
     key_conv_amx_inp_buffer,
     key_conv_amx_tilecfg,
@@ -179,6 +182,8 @@ enum {
     key_conv_amx_wsp_buffer,
     key_conv_bia_reduction,
     key_conv_bias_bf16_convert_wsp,
+    key_conv_bias_f16_convert_wsp,
+    key_conv_bias_s32_convert,
     key_conv_cudnn,
     key_conv_cudnn_algo,
     key_conv_cudnn_filter,
@@ -198,9 +203,13 @@ enum {
     key_conv_gemm_imtr,
     key_conv_gemm_zp_src_comp,
     key_conv_int_dat_in_acc_dt,
+    key_conv_ncsp_dst,
+    key_conv_ncsp_src,
+    key_conv_ncsp_diff_dst,
+    key_conv_ncsp_diff_src,
+    key_conv_ncsp_matmul_dst,
+    key_conv_ncsp_diff_sp_sum,
     key_conv_padded_bias,
-    key_conv_permuted_inputs,
-    key_conv_permuted_outputs,
     key_conv_permuted_weights,
     key_conv_rtus_space,
     key_conv_store_wsp,
@@ -247,20 +256,27 @@ enum {
     key_iprod_dst_bf16_convert_wsp,
     key_iprod_dst_reorder,
     key_iprod_int_dat_in_acc_dt,
+    key_iprod_src_reorder,
+    key_iprod_weights_reorder,
     key_lnorm_inv_sqrtvar,
     key_lnorm_tmp_mean,
     key_lnorm_tmp_var,
     key_lnorm_tmp_diff_ss,
     key_lnorm_reduction,
+    key_matmul_pack_space,
     key_matmul_dst_in_acc_dt,
+    key_matmul_lt_algo_scratch,
+    key_matmul_lt_block_c,
     key_matmul_src_trans,
     key_matmul_wei_trans,
     key_matmul_dst_trans,
     key_matmul_dst_cast_acc,
+    key_matmul_sparse_tmp_ptr,
     key_pool_dst_bf16cvt,
     key_pool_dst_plain2blocked_cvt,
     key_pool_ind_plain2blocked_cvt,
     key_pool_src_bf16cvt,
+    key_pool_src_f32_accum,
     key_pool_src_plain2blocked_cvt,
     key_pool_reduction,
     key_precomputed_scales,
@@ -269,6 +285,7 @@ enum {
     key_reducer_space_bctx,
     key_reduction,
     key_reduction_1,
+    key_reduction_out,
     key_reorder_cross_space,
     key_reorder_space,
     key_reorder_src_scales,
@@ -281,6 +298,9 @@ enum {
     key_reorder_rnn_weights_reduction,
     key_reorder_rnn_weights_transposition,
     key_reorder_rnn_weights_xf16_cvt,
+    key_reorder_cublaslt_src_float,
+    key_reorder_cublaslt_dst_float,
+    key_reorder_cublaslt_generic,
     key_rnn_space,
     key_rnn_bf32_attention_trans,
     key_rnn_bf32_wei_layer_trans,
@@ -302,15 +322,20 @@ enum {
     key_softmax_interim_store,
     key_sum_reduction,
     key_sum_srcs_cvt,
-    key_wino_transformed_weights,
     key_wino_U,
     key_wino_V,
     key_wino_M,
-    key_wino_workspace,
+    key_decompression_scales,
+    key_decompression_zero_points,
+    key_src_quantized,
+    key_src_dequantized_scales,
+    key_src_grouped_sum,
     // These two keys should always be the last ones,
     // even though they are not in alphabetical order
     key_nested,
     key_nested_multiple,
+    key_dw_conv_buffer,
+    key_dw_conv_padded_bias,
 };
 
 enum {
@@ -414,14 +439,8 @@ struct registry_t {
     public:
         common_iterator_t(const void *base_ptr_,
                 const std::unordered_map<int, entry_t> &map,
-                bool is_begin = true) {
-            base_ptr = base_ptr_;
-            if (is_begin) {
-                iter = map.cbegin();
-            } else {
-                iter = map.cend();
-            }
-        }
+                bool is_begin = true)
+            : base_ptr(base_ptr_)
+            , iter(is_begin ? map.cbegin() : map.cend()) {}
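For context on how keys like the ones added above are consumed, a sketch of the usual book/get pairing (the key choice and sizes are illustrative, and this assumes the registrar/grantor API keeps its usual shape):

// At primitive-descriptor creation time: reserve space under a key.
// auto scratchpad = scratchpad_registry().registrar();
// scratchpad.book<float>(memory_tracking::names::key_pool_src_f32_accum,
//         accum_nelems); // accum_nelems is a placeholder
//
// At execution time: retrieve the same buffer through the grantor.
// auto *accum = ctx.get_scratchpad_grantor().get<float>(
//         memory_tracking::names::key_pool_src_f32_accum);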
         common_iterator_t &operator++(int) {
             iter++;
             return *this;
@@ -439,8 +458,8 @@ struct registry_t {
                     (return_type)ptr_start, entry.size};
         }
     };
-    typedef common_iterator_t<void *> iterator;
-    typedef common_iterator_t<const void *> const_iterator;
+    using iterator = common_iterator_t<void *>;
+    using const_iterator = common_iterator_t<const void *>;
     iterator begin(void *base_ptr_) const {
         return iterator(base_ptr_, offset_map_);
     }
diff --git a/src/common/memory_zero_pad.cpp b/src/common/memory_zero_pad.cpp
index ebaedd428bf..77afaedb832 100644
--- a/src/common/memory_zero_pad.cpp
+++ b/src/common/memory_zero_pad.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2024 Intel Corporation
+* Copyright 2018-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
 
 #include "dnnl_thread.hpp"
 #include "dnnl_traits.hpp"
+#include "dnnl_sel_build.hpp"
 #include "stream.hpp"
 #include "type_helpers.hpp"
 #include "utils.hpp"
@@ -25,6 +26,7 @@
 #include "memory.hpp"
 #include "primitive_exec_types.hpp"
 
+using namespace dnnl;
 using namespace dnnl::impl;
 using namespace dnnl::impl::data_type;
 using namespace dnnl::impl::status;
@@ -39,7 +41,7 @@ void typed_zero_pad_blk(const memory_desc_wrapper &m_d, void *data_handle) {
      * This allows the user to create bf16 memory
      * on non-avx512_core machines. */
     using data_t = typename utils::conditional<dt == bf16, uint16_t,
-            typename prec_traits<dt>::type>::type;
+            typename prec_traits_t<dt>::type>::type;
     auto data = reinterpret_cast<data_t *>(data_handle);
     const auto &dims = m_d.dims();
     const auto &pdims = m_d.padded_dims();
@@ -142,7 +144,7 @@ void typed_zero_pad_generic_blocked(
      * This allows the user to create bf16 memory
      * on non-avx512_core machines. */
     using data_t = typename utils::conditional<dt == bf16, uint16_t,
-            typename prec_traits<dt>::type>::type;
+            typename prec_traits_t<dt>::type>::type;
     auto data = reinterpret_cast<data_t *>(data_handle);
     const int ndims = m_d.ndims();
     const auto &dims = m_d.dims();
@@ -204,7 +206,7 @@ status_t typed_zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
     void *mapped_ptr
             = ctx.map_memory_storage(memory_storage, ctx.stream(), map_size);
 
-    auto *data = static_cast<typename prec_traits<dt>::type *>(mapped_ptr);
+    auto *data = static_cast<typename prec_traits_t<dt>::type *>(mapped_ptr);
     auto blk = mdw.blocking_desc();
 
     auto get_blksize = [&](int ind) {
@@ -219,9 +221,11 @@ status_t typed_zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
 #define CASE(blksize_, blk_kind) \
     do { \
         if (blksize == (blksize_)) { \
-            typed_zero_pad_blk<dt, blk_kind, blksize_>(mdw, data); \
-            ctx.unmap_memory_storage( \
-                    memory_storage, mapped_ptr, ctx.stream()); \
+            DNNL_CSCOPE(DNNL_MACRO_CAT3(typed_zero_pad_blk_, blksize_, blk_kind)) { \
+                typed_zero_pad_blk<dt, blk_kind, blksize_>(mdw, data); \
+                ctx.unmap_memory_storage( \
+                        memory_storage, mapped_ptr, ctx.stream()); \
+            } \
             return success; \
         } \
     } while (0)
@@ -280,6 +284,8 @@ static status_t zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
     switch (mdw.data_type()) {
         case f16: return typed_zero_pad<f16>(memory, ctx);
         case bf16: return typed_zero_pad<bf16>(memory, ctx);
+        case f4_e3m0: return typed_zero_pad<f4_e3m0>(memory, ctx);
+        case f4_e2m1: return typed_zero_pad<f4_e2m1>(memory, ctx);
         case e8m0: return typed_zero_pad<e8m0>(memory, ctx);
         case f8_e5m2: return typed_zero_pad<f8_e5m2>(memory, ctx);
         case f8_e4m3: return typed_zero_pad<f8_e4m3>(memory, ctx);
@@ -289,6 +295,8 @@ static status_t zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
         case u8: return typed_zero_pad<u8>(memory, ctx);
         case s4: return typed_zero_pad<s4>(memory, ctx);
         case u4: return typed_zero_pad<u4>(memory, ctx);
+        case bin: return typed_zero_pad<bin>(memory, ctx);
+        case nf4: return typed_zero_pad<nf4>(memory, ctx);
         default: assert(!"memory is undefined"); return unimplemented;
     }
     return unimplemented;
diff --git a/src/common/nstl.hpp b/src/common/nstl.hpp
index 45a6d7c49ac..227ecff67f2 100644
--- a/src/common/nstl.hpp
+++ b/src/common/nstl.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2016-2024 Intel Corporation
+* Copyright 2016-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -29,6 +29,7 @@
 
 #include "bfloat16.hpp"
 #include "float16.hpp"
+#include "float4.hpp"
 #include "float8.hpp"
 #include "int4.hpp"
 #include "internal_defs.hpp"
@@ -54,7 +55,7 @@ void *malloc(size_t size, int alignment);
 #endif
 void free(void *p);
 
-struct c_compatible {
+struct c_compatible { // NOLINT(readability-identifier-naming)
     enum { default_alignment = 64 };
     static void *operator new(size_t sz) {
         return MALLOC(sz, default_alignment);
@@ -83,14 +84,14 @@ struct c_compatible {
 namespace nstl {
 
 template <typename T>
-constexpr const T abs(const T &a) {
+constexpr T abs(const T &a) {
     return a >= 0 ? a : -a;
 }
 
 // Computes the modulus and returns the result as the least positive residue
 // when the divisor > 0.
 template <typename T>
-inline const T modulo(const T &dividend, const T &divisor) {
+inline T modulo(const T &dividend, const T &divisor) {
     static_assert(std::is_integral<T>::value, "T must be an integer type.");
     assert(divisor > 0);
     T result = dividend % divisor;
@@ -100,7 +101,7 @@ inline T modulo(const T &dividend, const T &divisor) {
 
 // Computes the additive inverse modulus and returns the result as the least
 // positive residue when the divisor > 0.
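The "least positive residue" convention above is worth restating with concrete numbers, since C++ `%` returns a negative remainder for negative dividends. A standalone sketch of the same arithmetic:

#include <cassert>

int main() {
    // modulo(-1, 5): C++ gives -1 % 5 == -1; adding the divisor corrects it.
    int r = -1 % 5; // -1
    if (r < 0) r += 5; // 4, the least positive residue
    assert(r == 4);
    // additive_inverse_modulo(-1, 5): the additive inverse of 4 mod 5 is 1.
    assert(5 - r == 1);
    return 0;
}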
 template <typename T>
-inline const T additive_inverse_modulo(const T &dividend, const T &divisor) {
+inline T additive_inverse_modulo(const T &dividend, const T &divisor) {
     static_assert(std::is_integral<T>::value, "T must be an integer type.");
     assert(divisor > 0);
     T result = modulo(dividend, divisor);
@@ -156,6 +157,38 @@ struct numeric_limits : public std::numeric_limits<T> {};
 template <>
 struct numeric_limits<bfloat16_t> : public std::numeric_limits<float> {};
 
+template <>
+struct numeric_limits<float4_e3m0_t> {
+    static constexpr float4_e3m0_t lowest() { return float4_e3m0_t(0xf, true); }
+    // Min normal is equal to the value 1.0
+    static constexpr float4_e3m0_t min() { return float4_e3m0_t(0x1, true); }
+    // Max normal is equal to the value 6.0
+    static constexpr float4_e3m0_t max() { return float4_e3m0_t(0x7, true); }
+
+    static constexpr int bias = 0x3;
+    static constexpr int digits = 1; // 1 implicit bit
+
+    static constexpr float4_e3m0_t epsilon() {
+        return float4_e3m0_t(0x3, true);
+    }
+};
+
+template <>
+struct numeric_limits<float4_e2m1_t> {
+    static constexpr float4_e2m1_t lowest() { return float4_e2m1_t(0xf, true); }
+    // Min normal is equal to the value 1.0
+    static constexpr float4_e2m1_t min() { return float4_e2m1_t(0x2, true); }
+    // Max normal is equal to the value 6.0
+    static constexpr float4_e2m1_t max() { return float4_e2m1_t(0x7, true); }
+
+    static constexpr int bias = 0x1;
+    static constexpr int digits = 2; // 1 explicit + 1 implicit bit
+
+    static constexpr float4_e2m1_t epsilon() {
+        return float4_e2m1_t(0x2, true);
+    }
+};
+
 template <>
 struct numeric_limits<float8_e8m0_t> {
     static constexpr float8_e8m0_t lowest() {
@@ -253,7 +286,7 @@ struct numeric_limits<float8_e4m3_t> {
 };
 
 template <typename T>
-struct is_integral {
+struct is_integral { // NOLINT(readability-identifier-naming)
     static constexpr bool value = false;
 };
 template <>
@@ -282,7 +315,7 @@
 };
 
 template <typename T, typename U>
-struct is_same {
+struct is_same { // NOLINT(readability-identifier-naming)
     static constexpr bool value = false;
 };
 template <typename T>
@@ -310,20 +343,20 @@ struct is_same<T, T> {
 enum nstl_status_t { success = 0, out_of_memory };
 
 template <typename T>
-class vector : public c_compatible {
+class vector : public c_compatible { // NOLINT(readability-identifier-naming)
 private:
     std::vector<T> _impl;
 
 public:
-    typedef typename std::vector<T>::iterator iterator;
-    typedef typename std::vector<T>::const_iterator const_iterator;
-    typedef typename std::vector<T>::size_type size_type;
-    vector() {}
+    using iterator = typename std::vector<T>::iterator;
+    using const_iterator = typename std::vector<T>::const_iterator;
+    using size_type = typename std::vector<T>::size_type;
+    vector() = default;
     vector(size_type n) : _impl(n) {}
     vector(size_type n, const T &value) : _impl(n, value) {}
     template <typename input_iterator>
     vector(input_iterator first, input_iterator last) : _impl(first, last) {}
-    ~vector() {}
+    ~vector() = default;
     size_type size() const { return _impl.size(); }
     T &operator[](size_type i) { return _impl[i]; }
     const T &operator[](size_type i) const { return _impl[i]; }
@@ -339,21 +372,25 @@ class vector : public c_compatible {
     }
     void clear() { _impl.clear(); }
     void push_back(const T &t) { _impl.push_back(t); }
+    template <typename... Args>
+    void emplace_back(Args &&...args) {
+        _impl.emplace_back(std::forward<Args>(args)...);
+    }
     void resize(size_type count) { _impl.resize(count); }
     void reserve(size_type count) { _impl.reserve(count); }
 };
 
 template <typename Key, typename T>
-class map : public c_compatible {
+class map : public c_compatible { // NOLINT(readability-identifier-naming)
 private:
     std::map<Key, T> _impl;
 
 public:
-    typedef typename std::map<Key, T>::iterator iterator;
-    typedef typename std::map<Key, T>::const_iterator const_iterator;
-    typedef typename std::map<Key, T>::size_type size_type;
-    map() {}
-    ~map() {}
+    using iterator = typename std::map<Key, T>::iterator;
+    using const_iterator = typename std::map<Key, T>::const_iterator;
+    using size_type = typename std::map<Key, T>::size_type;
+    map() = default;
+    ~map() = default;
     size_type size() const { return _impl.size(); }
     T &operator[](const Key &k) { return _impl[k]; }
     const T &operator[](const Key &k) const { return _impl[k]; }
@@ -369,10 +406,10 @@ class map : public c_compatible {
 
 // Compile-time sequence of indices (part of C++14)
 template <size_t... Ints>
-struct index_sequence {};
+struct index_sequence {}; // NOLINT(readability-identifier-naming)
 
 template <size_t N, size_t... Next>
-struct make_index_sequence_helper
+struct make_index_sequence_helper // NOLINT(readability-identifier-naming)
        : public make_index_sequence_helper<N - 1, N - 1, Next...> {};
 
 template <size_t... Next>
diff --git a/src/common/opdesc.hpp b/src/common/opdesc.hpp
index 8067ae0ddb6..ea2ff5c9975 100644
--- a/src/common/opdesc.hpp
+++ b/src/common/opdesc.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2024 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -17,30 +17,83 @@
 #ifndef COMMON_OPDESC_HPP
 #define COMMON_OPDESC_HPP
 
-#include <vector>
-
 #include "common/c_types_map.hpp"
-#include "common/gemm_types.hpp"
-#include "common/sdpa_types.hpp"
+#include "common/memory_desc.hpp"
+#include "common/utils.hpp"
+
+#include <memory>
+#include <vector>
 
 namespace dnnl {
 namespace impl {
 
-struct reorder_desc_t {
+#define DECLARE_COMMON_OP_DESC_CLONE(op_desc_kind_t) \
+    std::unique_ptr<op_desc_t> clone() const override { \
+        return utils::make_unique<op_desc_kind_t>(*this); \
+    }
+
+// A base class for all descriptors that allows dispatching between them
+// through a dedicated `kind` field.
+struct op_desc_t {
+    virtual ~op_desc_t() = default;
+
+    virtual std::unique_ptr<op_desc_t> clone() const = 0;
+
+    // Converters to an inherited type.
+    template <typename T>
+    static const T *to_desc(const op_desc_t *op_desc) {
+        static_assert(!std::is_pointer<T>::value,
+                "T is not expected to be a pointer type.");
+        return utils::downcast<const T *>(op_desc);
+    }
+    template <typename T>
+    static T *to_desc(op_desc_t *op_desc) {
+        static_assert(!std::is_pointer<T>::value,
+                "T is not expected to be a pointer type.");
+        return utils::downcast<T *>(op_desc);
+    }
+
+    // The kind of primitive. Used for self-identifying the primitive desc.
     primitive_kind_t primitive_kind;
-    const memory_desc_t *src_md;
-    const memory_desc_t *dst_md;
-    engine_kind_t src_engine_kind;
-    engine_kind_t dst_engine_kind;
-    bool is_cross_engine;
+
+protected:
+    op_desc_t() : primitive_kind(primitive_kind::undefined) {}
+    op_desc_t(primitive_kind_t pk) : primitive_kind(pk) {}
+    op_desc_t(const op_desc_t &) = default;
+    op_desc_t &operator=(const op_desc_t &) = default;
+    op_desc_t(op_desc_t &&) = default;
+    op_desc_t &operator=(op_desc_t &&) = default;
 };
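A sketch of how the new polymorphic base is meant to be consumed by primitive descriptors (the surrounding function is hypothetical; `to_desc` and `clone` are the members defined above):

// status_t init_conv_pd(const op_desc_t *adesc) {
//     if (adesc->primitive_kind != primitive_kind::convolution)
//         return status::invalid_arguments;
//     // Checked downcast to the concrete descriptor type.
//     const auto *conv_d = op_desc_t::to_desc<convolution_desc_t>(adesc);
//     // Polymorphic copy, independent of the caller's descriptor lifetime.
//     std::unique_ptr<op_desc_t> owned = conv_d->clone();
//     return status::success;
// }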
-struct concat_desc_t {
+// A descriptor of a reorder operation.
+struct reorder_desc_t : public op_desc_t {
+    reorder_desc_t() = default;
+    reorder_desc_t(primitive_kind_t primitive_kind, const memory_desc_t *src_md,
+            const memory_desc_t *dst_md, engine_kind_t src_engine_kind,
+            engine_kind_t dst_engine_kind, bool is_cross_engine)
+        : op_desc_t(primitive_kind)
+        , src_md(src_md)
+        , dst_md(dst_md)
+        , src_engine_kind(src_engine_kind)
+        , dst_engine_kind(dst_engine_kind)
+        , is_cross_engine(is_cross_engine) {}
+
+    DECLARE_COMMON_OP_DESC_CLONE(reorder_desc_t);
+
+    const memory_desc_t *src_md {};
+    const memory_desc_t *dst_md {};
+    engine_kind_t src_engine_kind {};
+    engine_kind_t dst_engine_kind {};
+    bool is_cross_engine {};
+};
+
+// A descriptor of a concat operation.
+struct concat_desc_t : public op_desc_t {
     concat_desc_t() = default;
     concat_desc_t(primitive_kind_t primitive_kind, const memory_desc_t *dst_md,
             dim_t n, dim_t concat_dimension, const memory_desc_t *const *src_mds)
-        : primitive_kind(primitive_kind)
+        : op_desc_t(primitive_kind)
         , dst_md(dst_md)
         , n(n)
         , concat_dimension(concat_dimension) {
@@ -48,41 +101,49 @@ struct concat_desc_t {
             this->src_mds.push_back(src_mds[i]);
     }
 
-    primitive_kind_t primitive_kind;
-    const memory_desc_t *dst_md;
-    dim_t n;
-    dim_t concat_dimension;
+    DECLARE_COMMON_OP_DESC_CLONE(concat_desc_t);
+
+    const memory_desc_t *dst_md {};
+    dim_t n {};
+    dim_t concat_dimension {};
     std::vector<const memory_desc_t *> src_mds;
 };
 
-struct sum_desc_t {
+// A descriptor of a sum operation.
+struct sum_desc_t : public op_desc_t {
     sum_desc_t() = default;
     sum_desc_t(primitive_kind_t primitive_kind, const memory_desc_t *dst_md,
             dim_t n, const float *scales, const memory_desc_t *const *src_mds)
-        : primitive_kind(primitive_kind), dst_md(dst_md), n(n), scales(scales) {
+        : op_desc_t(primitive_kind), dst_md(dst_md), n(n), scales(scales) {
         for (dim_t i = 0; i < n; i++)
             this->src_mds.push_back(src_mds[i]);
     }
 
-    primitive_kind_t primitive_kind;
-    const memory_desc_t *dst_md;
-    dim_t n;
-    const float *scales;
+    DECLARE_COMMON_OP_DESC_CLONE(sum_desc_t);
+
+    const memory_desc_t *dst_md {};
+    dim_t n {};
+    const float *scales {};
     std::vector<const memory_desc_t *> src_mds;
 };
 
-struct zero_pad_desc_t {
-    primitive_kind_t primitive_kind;
+// A descriptor of a zero padding operation.
+struct zero_pad_desc_t : public op_desc_t {
+    zero_pad_desc_t() : op_desc_t(primitive_kind::zero_pad) {}
+
+    DECLARE_COMMON_OP_DESC_CLONE(zero_pad_desc_t);
 };
 
-struct inner_product_desc_t {
-    // The kind of primitive. Used for self-identifying the primitive
-    // descriptor. Must be #dnnl_inner_product.
-    primitive_kind_t primitive_kind;
+// A descriptor of an inner product operation.
+struct inner_product_desc_t : public op_desc_t {
+    inner_product_desc_t() : op_desc_t(primitive_kind::inner_product) {}
+
+    DECLARE_COMMON_OP_DESC_CLONE(inner_product_desc_t);
+
     // The kind of propagation. Possible values: forward_training,
     // forward_inference, backward_data,
     // backward_weights, and backward_bias.
-    prop_kind_t prop_kind;
+    prop_kind_t prop_kind {};
     // Source memory descriptor.
     memory_desc_t src_desc;
     // Source gradient memory descriptor.
@@ -100,20 +161,22 @@ struct inner_product_desc_t {
     // Destination gradient memory descriptor.
     memory_desc_t diff_dst_desc;
     // The accumulator data type.
-    data_type_t accum_data_type;
+    data_type_t accum_data_type {};
 };
 
-struct convolution_desc_t {
-    // The kind of primitive. Used for self-identifying the primitive
-    // descriptor. Must be #dnnl_convolution.
-    primitive_kind_t primitive_kind;
+// A descriptor of a convolution operation.
+struct convolution_desc_t : public op_desc_t { + convolution_desc_t() : op_desc_t(primitive_kind::convolution) {} + + DECLARE_COMMON_OP_DESC_CLONE(convolution_desc_t); + // The kind of propagation. Possible values: #dnnl_forward_training, // #dnnl_forward_inference, #dnnl_backward_data, // #dnnl_backward_weights, and #dnnl_backward_bias. - prop_kind_t prop_kind; + prop_kind_t prop_kind {}; // The kind of the convolution algorithm. Possible values: // #dnnl_convolution_direct. - alg_kind_t alg_kind; + alg_kind_t alg_kind {}; // Source memory descriptor. memory_desc_t src_desc; // Source gradient memory descriptor. @@ -131,51 +194,53 @@ struct convolution_desc_t { // Destination gradient memory descriptor. memory_desc_t diff_dst_desc; // Convolution strides in each spatial dimension. - dims_t strides; + dims_t strides {}; // Convolution dilates in each spatial dimension. - dims_t dilates; + dims_t dilates {}; // Padding in each spatial dimension. padding[0] is a padding in the // beginning (@p padding_l), padding[1] is a padding in the end (@p // padding_r). - dims_t padding[2]; + dims_t padding[2] {}; // The accumulator data type. Initialized automatically. - data_type_t accum_data_type; + data_type_t accum_data_type {}; // For internal use only. To mark conv is used for deconv. - bool use_inversion; + bool use_inversion {}; }; // A descriptor of a deconvolution operation. using deconvolution_desc_t = convolution_desc_t; // A descriptor of a shuffle operation. -struct shuffle_desc_t { - // The kind of primitive. Used for self-identifying the primitive - // descriptor. Must be #dnnl_shuffle. - primitive_kind_t primitive_kind; +struct shuffle_desc_t : public op_desc_t { + shuffle_desc_t() : op_desc_t(primitive_kind::shuffle) {} + + DECLARE_COMMON_OP_DESC_CLONE(shuffle_desc_t); + // The kind of propagation. Possible values: #dnnl_forward_training, // #dnnl_forward_inference, and #dnnl_backward_data. - prop_kind_t prop_kind; + prop_kind_t prop_kind {}; // Source or source gradient memory descriptor. memory_desc_t src_desc; // Destination or destination gradient memory descriptor. memory_desc_t dst_desc; // Axis for shuffling. - int axis; + int axis {}; // Number of groups. - dim_t group_size; + dim_t group_size {}; }; // A descriptor of resampling operation. -struct resampling_desc_t { - // The kind of primitive. Used for self-identifying the primitive - // descriptor. Must be #dnnl_resampling. - primitive_kind_t primitive_kind; +struct resampling_desc_t : public op_desc_t { + resampling_desc_t() : op_desc_t(primitive_kind::resampling) {} + + DECLARE_COMMON_OP_DESC_CLONE(resampling_desc_t); + // The kind of propagation. Possible values: #dnnl_forward_training, // #dnnl_forward_inference, #dnnl_backward_data, - prop_kind_t prop_kind; + prop_kind_t prop_kind {}; // The kind of the resampling algorithm. Possible values: // #dnnl_resampling_nearest, #dnnl_resampling_linear. - alg_kind_t alg_kind; + alg_kind_t alg_kind {}; // Source memory descriptor. memory_desc_t src_desc; // Source gradient memory descriptor. @@ -185,7 +250,7 @@ struct resampling_desc_t { // Destination gradient memory descriptor. memory_desc_t diff_dst_desc; // Resampling factor in each spatial dimension. - float factors[DNNL_MAX_NDIMS]; + float factors[DNNL_MAX_NDIMS] {}; }; // A descriptor of a matrix multiplication operation. @@ -195,10 +260,11 @@ struct resampling_desc_t { // // 3D case: // dst[mb, m, n] = src[mb, m, k] * weights[mb, k, n] + bias[mb, m, n] -struct matmul_desc_t { - // The kind of primitive. 
Used for self-identifying the primitive
-    // descriptor. Must be #dnnl_matmul.
-    primitive_kind_t primitive_kind;
+struct matmul_desc_t : public op_desc_t {
+    matmul_desc_t() : op_desc_t(primitive_kind::matmul) {}
+
+    DECLARE_COMMON_OP_DESC_CLONE(matmul_desc_t);
+
     // Source memory descriptor.
     memory_desc_t src_desc;
     // Weights memory descriptor.
@@ -207,18 +273,23 @@
     memory_desc_t bias_desc;
     // Destination memory descriptor.
     memory_desc_t dst_desc;
+    // Reduce memory descriptor.
+    memory_desc_t reduce_desc;
+    // Reduce kind.
+    matmul_reduce_kind_t reduce_kind {};
     // The accumulator data type. Initialized automatically.
-    data_type_t accum_data_type;
+    data_type_t accum_data_type {};
 };
 
-// A descriptor of a element-wise operation.
-struct eltwise_desc_t {
-    // The kind of primitive. Used for self-identifying the primitive
-    // descriptor. Must be #dnnl_eltwise.
-    primitive_kind_t primitive_kind;
+// A descriptor of an element-wise operation.
+struct eltwise_desc_t : public op_desc_t {
+    eltwise_desc_t() : op_desc_t(primitive_kind::eltwise) {}
+
+    DECLARE_COMMON_OP_DESC_CLONE(eltwise_desc_t);
+
     // The kind of propagation. Possible values: #dnnl_forward_training,
     // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data.
-    prop_kind_t prop_kind;
+    prop_kind_t prop_kind {};
     // The kind of eltwise algorithm. Possible values: #dnnl_eltwise_relu,
     // #dnnl_eltwise_tanh, #dnnl_eltwise_elu, #dnnl_eltwise_square,
     // #dnnl_eltwise_abs, #dnnl_eltwise_sqrt, #dnnl_eltwise_linear,
@@ -233,7 +304,7 @@
     // #dnnl_eltwise_logistic_use_dst_for_bwd,
     // #dnnl_eltwise_exp_use_dst_for_bwd,
     // #dnnl_eltwise_clip_v2_use_dst_for_bwd.
-    alg_kind_t alg_kind;
+    alg_kind_t alg_kind {};
     // Source memory descriptor.
     memory_desc_t src_desc;
     // Destination memory descriptor.
@@ -265,17 +336,20 @@
     // - #dnnl_eltwise_mish: @p alpha and @p beta ignored
     // - #dnnl_eltwise_hardswish: @p alpha and @p beta ignored
     // - #dnnl_eltwise_hardsigmoid: @p alpha -- scale, @p beta -- shift
-    float alpha, beta;
+    float alpha {};
+    float beta {};
 };
 
 // A descriptor of a Batch Normalization operation.
-struct batch_normalization_desc_t {
-    // The kind of primitive. Used for self-identifying the primitive
-    // descriptor. Must be #dnnl_batch_normalization.
-    primitive_kind_t primitive_kind;
+struct batch_normalization_desc_t : public op_desc_t {
+    batch_normalization_desc_t()
+        : op_desc_t(primitive_kind::batch_normalization) {}
+
+    DECLARE_COMMON_OP_DESC_CLONE(batch_normalization_desc_t);
+
     // The kind of propagation. Possible values: #dnnl_forward_training,
     // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data.
-    prop_kind_t prop_kind;
+    prop_kind_t prop_kind {};
     // Source memory descriptor.
     memory_desc_t src_desc;
     // Destination memory descriptor.
@@ -293,18 +367,20 @@
     // Statistics (mean or variance) descriptors use 1D #dnnl_x format[Channels].
     memory_desc_t stat_desc;
     // Batch normalization epsilon parameter.
-    float batch_norm_epsilon;
-    unsigned flags;
+    float batch_norm_epsilon {};
+    unsigned flags {};
 };
 
 // A descriptor of a Group Normalization operation.
-struct group_normalization_desc_t {
-    // The kind of primitive. Used for self-identifying the primitive
-    // descriptor. Must be #dnnl_group_normalization.
- primitive_kind_t primitive_kind; +struct group_normalization_desc_t : public op_desc_t { + group_normalization_desc_t() + : op_desc_t(primitive_kind::group_normalization) {} + + DECLARE_COMMON_OP_DESC_CLONE(group_normalization_desc_t); + // The kind of propagation. Possible values: #dnnl_forward_training, // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data. - prop_kind_t prop_kind; + prop_kind_t prop_kind {}; // Source memory descriptor. memory_desc_t src_desc; // Source gradient memory descriptor. @@ -318,10 +394,10 @@ struct group_normalization_desc_t { // format[Batch, groups]. memory_desc_t stat_desc; // Group normalization groups parameter. - dim_t groups; + dim_t groups {}; // Group normalization epsilon parameter. - float group_norm_epsilon; - unsigned flags; + float group_norm_epsilon {}; + unsigned flags {}; // Destination memory descriptor. memory_desc_t dst_desc; // Destination gradient memory descriptor. @@ -329,13 +405,15 @@ struct group_normalization_desc_t { }; // A descriptor of a Layer Normalization operation. -struct layer_normalization_desc_t { - // The kind of primitive. Used for self-identifying the primitive - // descriptor. Must be #dnnl_layer_normalization. - primitive_kind_t primitive_kind; +struct layer_normalization_desc_t : public op_desc_t { + layer_normalization_desc_t() + : op_desc_t(primitive_kind::layer_normalization) {} + + DECLARE_COMMON_OP_DESC_CLONE(layer_normalization_desc_t); + // The kind of propagation. Possible values: #dnnl_forward_training, // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data. - prop_kind_t prop_kind; + prop_kind_t prop_kind {}; // Source memory descriptor. memory_desc_t src_desc; // Source gradient memory descriptor. @@ -353,8 +431,8 @@ struct layer_normalization_desc_t { // (stride[last_dim] == 1) user-provided format. memory_desc_t stat_desc; // Layer normalization epsilon parameter. - float layer_norm_epsilon; - unsigned flags; + float layer_norm_epsilon {}; + unsigned flags {}; // Destination memory descriptor. memory_desc_t dst_desc; // Destination gradient memory descriptor. @@ -362,16 +440,17 @@ struct layer_normalization_desc_t { }; // A descriptor of a Local Response Normalization (LRN) operation. -struct lrn_desc_t { - // The kind of primitive. Used for self-identifying the primitive - // descriptor. Must be #dnnl_lrn. - primitive_kind_t primitive_kind; +struct lrn_desc_t : public op_desc_t { + lrn_desc_t() : op_desc_t(primitive_kind::lrn) {} + + DECLARE_COMMON_OP_DESC_CLONE(lrn_desc_t); + // The kind of propagation. Possible values: #dnnl_forward_training, // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data. - prop_kind_t prop_kind; + prop_kind_t prop_kind {}; // LRN algorithm. Possible values: #dnnl_lrn_within_channel and // #dnnl_lrn_across_channels. - alg_kind_t alg_kind; + alg_kind_t alg_kind {}; // Source memory descriptor. memory_desc_t src_desc; // Destination memory descriptor. @@ -382,26 +461,27 @@ struct lrn_desc_t { memory_desc_t diff_dst_desc; // The number of channels to sum over (for cross-channel LRN) or the side // length of the square region to sum over (for within-channel LRN). - dim_t local_size; + dim_t local_size {}; // LRN alpha parameter. - float lrn_alpha; + float lrn_alpha {}; // LRN beta parameter. - float lrn_beta; + float lrn_beta {}; // LRN k parameter. - float lrn_k; + float lrn_k {}; }; // A descriptor of reduction operation. -struct reduction_desc_t { - // The kind of primitive. Used for self-identifying the primitive - // descriptor. 
Must be #dnnl_reduction. - primitive_kind_t primitive_kind; +struct reduction_desc_t : public op_desc_t { + reduction_desc_t() : op_desc_t(primitive_kind::reduction) {} + + DECLARE_COMMON_OP_DESC_CLONE(reduction_desc_t); + // The kind of reduction algorithm. Possible values: // #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum, // #dnnl_reduction_mul, #dnnl_reduction_mean, #dnnl_reduction_norm_lp_max, // #dnnl_reduction_norm_lp_sum, #dnnl_reduction_norm_lp_power_p_max, // #dnnl_reduction_norm_lp_power_p_sum. - alg_kind_t alg_kind; + alg_kind_t alg_kind {}; // Source memory descriptor. memory_desc_t src_desc; // Destination memory descriptor. @@ -417,26 +497,28 @@ struct reduction_desc_t { // #dnnl_reduction_sum: @p p and @p eps are ignored // #dnnl_reduction_mul: @p p and @p eps are ignored // #dnnl_reduction_mean: @p p and @p eps are ignored - float p, eps; + float p {}; + float eps {}; }; /// A descriptor of a Softmax operation. -struct softmax_desc_t { - // The kind of primitive. Used for self-identifying the primitive - // descriptor. Must be #dnnl_softmax. - primitive_kind_t primitive_kind; +struct softmax_desc_t : public op_desc_t { + softmax_desc_t() : op_desc_t(primitive_kind::softmax) {} + + DECLARE_COMMON_OP_DESC_CLONE(softmax_desc_t); + // The kind of propagation. Possible values: #dnnl_forward_training, // #dnnl_forward_inference, and #dnnl_backward_data. - prop_kind_t prop_kind; + prop_kind_t prop_kind {}; // Source memory descriptor. memory_desc_t src_desc; // Source gradient memory descriptor. memory_desc_t diff_src_desc; // The axis along which to perform the softmax. - int softmax_axis; + int softmax_axis {}; // Softmax algorithm. Possible values: #dnnl_softmax_accurate and // #dnnl_softmax_log. - alg_kind_t alg_kind; + alg_kind_t alg_kind {}; // Destination memory descriptor. memory_desc_t dst_desc; // Destination gradient memory descriptor. @@ -444,28 +526,32 @@ struct softmax_desc_t { }; // A descriptor of a binary operation. -struct binary_desc_t { - // The kind of primitive. Used for self-identifying the primitive - // descriptor. Must be #dnnl_binary. - primitive_kind_t primitive_kind; +struct binary_desc_t : public op_desc_t { + binary_desc_t() : op_desc_t(primitive_kind::binary) {} + + DECLARE_COMMON_OP_DESC_CLONE(binary_desc_t); + // The kind of the binary algorithm. Possible values: // #dnnl_binary_add, #dnnl_binary_mul, #dnnl_binary_max, #dnnl_binary_min, - // #dnnl_binary_div and #dnnl_binary_sub. - alg_kind_t alg_kind; + // #dnnl_binary_div, #dnnl_binary_sub, #dnnl_binary_ge, #dnnl_binary_gt, + // #dnnl_binary_le, #dnnl_binary_lt, #dnnl_binary_eq, #dnnl_binary_ne, + // and #dnnl_binary_select + alg_kind_t alg_kind {}; // Source memory descriptors. - memory_desc_t src_desc[2]; + memory_desc_t src_desc[3] {}; // Destination memory descriptor. memory_desc_t dst_desc; }; /// A descriptor of a PReLU operation. -struct prelu_desc_t { - // The kind of primitive. Used for self-identifying the primitive - // descriptor. Must be #dnnl_prelu. - primitive_kind_t primitive_kind; +struct prelu_desc_t : public op_desc_t { + prelu_desc_t() : op_desc_t(primitive_kind::prelu) {} + + DECLARE_COMMON_OP_DESC_CLONE(prelu_desc_t); + // The kind of propagation. Possible values: #dnnl_forward_training, // #dnnl_forward_inference, #dnnl_backward - prop_kind_t prop_kind; + prop_kind_t prop_kind {}; // Source memory descriptor. memory_desc_t src_desc; // Learnable parameter alpha memory descriptor. 
@@ -482,18 +568,19 @@ struct prelu_desc_t { }; // A descriptor of a pooling operation. -struct pooling_desc_t { - // The kind of primitive. Used for self-identifying the primitive - // descriptor. Must be #dnnl_pooling. - primitive_kind_t primitive_kind; +struct pooling_desc_t : public op_desc_t { + pooling_desc_t() : op_desc_t(primitive_kind::pooling) {} + + DECLARE_COMMON_OP_DESC_CLONE(pooling_desc_t); + // The kind of propagation. Possible values: #dnnl_forward_training, // #dnnl_forward_inference, #dnnl_backward, and #dnnl_backward_data. - prop_kind_t prop_kind; + prop_kind_t prop_kind {}; // The kind of pooling algorithm. // Possible values: #dnnl_pooling_max, // #dnnl_pooling_avg_include_padding, and // #dnnl_pooling_avg_exclude_padding. - alg_kind_t alg_kind; + alg_kind_t alg_kind {}; // Source memory descriptor. memory_desc_t src_desc; // Source gradient memory descriptor. @@ -503,32 +590,33 @@ struct pooling_desc_t { // Destination gradient memory descriptor. memory_desc_t diff_dst_desc; // Pooling kernel strides for spatial dimensions. - dims_t strides; + dims_t strides {}; // Pooling kernel spatial dimensions. - dims_t kernel; + dims_t kernel {}; // Padding in each spatial dimension. padding[0] is a padding in the // beginning (@p padding_l), padding[1] is a padding in the end (@p // padding_r). - dims_t padding[2]; + dims_t padding[2] {}; // The accumulator data type. Initialized automatically. - data_type_t accum_data_type; + data_type_t accum_data_type {}; // Pooling dilations for spatial dimensions. - dims_t dilation; + dims_t dilation {}; }; // A descriptor for an RNN operation. -struct rnn_desc_t { - // The kind of primitive. Used for self-identifying the primitive - // descriptor. Must be #dnnl_rnn. - dnnl_primitive_kind_t primitive_kind; +struct rnn_desc_t : public op_desc_t { + rnn_desc_t() : op_desc_t(primitive_kind::rnn) {} + + DECLARE_COMMON_OP_DESC_CLONE(rnn_desc_t); + // The kind of propagation. Possible values: #dnnl_forward_training, // #dnnl_forward_inference, and #dnnl_backward. - prop_kind_t prop_kind; + prop_kind_t prop_kind {}; // RNN cell kind. Must be one of #dnnl_vanilla_rnn, // #dnnl_vanilla_lstm, #dnnl_vanilla_gru, or #dnnl_lbr_gru. - alg_kind_t cell_kind; + alg_kind_t cell_kind {}; // The direction of RNN primitive execution. - rnn_direction_t direction; + rnn_direction_t direction {}; // Source layer memory descriptor. memory_desc_t src_layer_desc; // Source iteration memory descriptor for hidden state. @@ -584,82 +672,15 @@ struct rnn_desc_t { memory_desc_t diff_weights_projection_desc; // RNN cell flags - unsigned int flags; + unsigned flags {}; // Activation function used for vanilla_rnn cell kind. // Must be either #dnnl_eltwise_relu or #dnnl_eltwise_tanh. 
-    alg_kind_t activation_kind;
-    float alpha;
-    float beta;
+    alg_kind_t activation_kind {};
+    float alpha {};
+    float beta {};
 };
 
-struct op_desc_t {
-    union {
-        primitive_kind_t kind;
-        convolution_desc_t convolution;
-        deconvolution_desc_t deconvolution;
-        shuffle_desc_t shuffle;
-        pooling_desc_t pooling;
-        prelu_desc_t prelu;
-        eltwise_desc_t eltwise;
-        softmax_desc_t softmax;
-        lrn_desc_t lrn;
-        batch_normalization_desc_t batch_normalization;
-        group_normalization_desc_t group_normalization;
-        layer_normalization_desc_t layer_normalization;
-        inner_product_desc_t inner_product;
-        rnn_desc_t rnn;
-        gemm_desc_t gemm;
-        concat_desc_t concat;
-        reorder_desc_t reorder;
-        sum_desc_t sum;
-        binary_desc_t binary;
-        matmul_desc_t matmul;
-        resampling_desc_t resampling;
-        zero_pad_desc_t zero_pad;
-        reduction_desc_t reduction;
-        sdpa_desc_t sdpa;
-    };
-
-#define DECL_CTOR_AND_CONVERTERS(c_type) \
-    op_desc_t(const c_type &) = delete; \
-    static op_desc_t *convert_from_c(c_type *_) { \
-        return reinterpret_cast<op_desc_t *>(_); \
-    } \
-    static const op_desc_t *convert_from_c(const c_type *_) { \
-        return reinterpret_cast<const op_desc_t *>(_); \
-    }
-
-    DECL_CTOR_AND_CONVERTERS(convolution_desc_t);
-    DECL_CTOR_AND_CONVERTERS(shuffle_desc_t);
-    DECL_CTOR_AND_CONVERTERS(pooling_desc_t);
-    DECL_CTOR_AND_CONVERTERS(prelu_desc_t);
-    DECL_CTOR_AND_CONVERTERS(eltwise_desc_t);
-    DECL_CTOR_AND_CONVERTERS(softmax_desc_t);
-    DECL_CTOR_AND_CONVERTERS(lrn_desc_t);
-    DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t);
-    DECL_CTOR_AND_CONVERTERS(group_normalization_desc_t);
-    DECL_CTOR_AND_CONVERTERS(layer_normalization_desc_t);
-    DECL_CTOR_AND_CONVERTERS(inner_product_desc_t);
-    DECL_CTOR_AND_CONVERTERS(rnn_desc_t);
-    DECL_CTOR_AND_CONVERTERS(gemm_desc_t);
-    DECL_CTOR_AND_CONVERTERS(concat_desc_t);
-    DECL_CTOR_AND_CONVERTERS(reorder_desc_t);
-    DECL_CTOR_AND_CONVERTERS(sum_desc_t);
-    DECL_CTOR_AND_CONVERTERS(binary_desc_t);
-    DECL_CTOR_AND_CONVERTERS(matmul_desc_t);
-    DECL_CTOR_AND_CONVERTERS(resampling_desc_t);
-    DECL_CTOR_AND_CONVERTERS(zero_pad_desc_t);
-    DECL_CTOR_AND_CONVERTERS(reduction_desc_t);
-    DECL_CTOR_AND_CONVERTERS(sdpa_desc_t);
-
-    // concat_desc_t and sum_desc_t have data members which have non-trivial
-    // special member functions hence the default destructor is implicitly
-    // deleted by the compiler which causes a warning on Windows so we should
-    // delete the destructor explicitly.
-    ~op_desc_t() = delete;
-
-#undef DECL_CTOR_AND_CONVERTERS
-};
+#undef DECLARE_COMMON_OP_DESC_CLONE
 
 } // namespace impl
 } // namespace dnnl
diff --git a/src/common/optional.hpp b/src/common/optional.hpp
index 83eac9eb70c..93388b3b80c 100644
--- a/src/common/optional.hpp
+++ b/src/common/optional.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2021-2024 Intel Corporation
+* Copyright 2021-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -58,7 +58,7 @@ class optional_t {
     optional_t(const nullopt_t nullopt_) : has_value_(false), dummy {} {}
     optional_t() : optional_t(nullopt) {}
-    optional_t(T object) : has_value_(true), value_(object) {}
+    optional_t(const T &object) : has_value_(true), value_(object) {}
     optional_t(const optional_t &other)
         : has_value_(other.has_value_), dummy {} {
         if (has_value_) new (std::addressof(value_)) T(other.value_);
diff --git a/src/common/pooling.cpp b/src/common/pooling.cpp
index c20685bc6b3..a1fd610fe68 100644
--- a/src/common/pooling.cpp
+++ b/src/common/pooling.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2016-2023 Intel Corporation
+* Copyright 2016-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -117,13 +117,19 @@ status_t pooling_desc_init(pooling_desc_t *pool_desc, prop_kind_t prop_kind,
         VCHECK_POOLING((src - ker_range + pad_l + pad_r) / str + 1 == dst,
                 VERBOSE_INCONSISTENT_PRB)
 
+        // [fork] Initially this check was commented out here and the padding
+        // handling in nchw_pooling was corrected accordingly. After the
+        // rebase to oneDNN v2.7, the nchw_pooling changes caused accuracy
+        // test failures. With the check commented out and no changes in
+        // nchw_pooling, no issues were found.
+
         // It's not allowed for pooling window to be totally placed outside
         // of real source domain for pooling_avg_exclude_padding algorithm
         // due to 0 / 0 ambiguity
-        VCHECK_POOLING(
-                IMPLICATION(alg_kind == pooling_avg_exclude_padding,
-                        (pad_l < ker_range && pad_r < ker_range && dil < src)),
-                VERBOSE_INCONSISTENT_PRB);
+        // VCHECK_POOLING(
+        //         IMPLICATION(alg_kind == pooling_avg_exclude_padding,
+        //                 (pad_l < ker_range && pad_r < ker_range && dil < src)),
+        //         VERBOSE_INCONSISTENT_PRB);
     }
 
     *pool_desc = pd;
@@ -151,8 +157,11 @@ status_t pooling_attr_check(const pooling_desc_t &desc, const engine_t *engine,
     if (!attr->post_ops_.has_default_values()) {
         const auto &po = attr->post_ops_;
         using namespace primitive_kind;
-        VCHECK_POOLING_IMPL(po.has_default_values({binary, eltwise}),
+        VCHECK_POOLING_IMPL(
+                po.has_default_values({binary, eltwise, quantization}),
                 VERBOSE_UNSUPPORTED_POSTOP);
+
+        // Note: verbose support is inside the call.
+        CHECK(po.validate_binary_with_dst_consistency(&desc.dst_desc));
     }
 } else {
     VCHECK_POOLING_IMPL(false, VERBOSE_UNSUPPORTED_ATTR);
diff --git a/src/common/pooling_pd.hpp b/src/common/pooling_pd.hpp
index 0690497ac38..62a7cdec05b 100644
--- a/src/common/pooling_pd.hpp
+++ b/src/common/pooling_pd.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2016-2024 Intel Corporation
+* Copyright 2016-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -146,12 +146,11 @@ struct pooling_pd_t : public primitive_desc_t {
 
     memory_desc_t ws_md_;
 
-    pooling_pd_t(const pooling_desc_t *adesc, const primitive_attr_t *attr,
+    pooling_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
             const pooling_fwd_pd_t *hint_fwd_pd)
         : primitive_desc_t(attr, base_pkind)
-        , desc_(*adesc)
-        , hint_fwd_pd_(hint_fwd_pd)
-        , ws_md_() {}
+        , desc_(*op_desc_t::to_desc<pooling_desc_t>(adesc))
+        , hint_fwd_pd_(hint_fwd_pd) {}
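A worked instance of the 0 / 0 ambiguity that the disabled check in pooling.cpp guarded against (numbers are illustrative):

#include <algorithm>
#include <cassert>

int main() {
    // kernel = 3, pad_l = 3, src = 2: the first output window covers input
    // positions [-3, -1], which lie entirely in padding.
    const int ker = 3, pad_l = 3, src = 2;
    const int lo = std::max(-pad_l, 0);                 // 0
    const int hi = std::min(-pad_l + ker - 1, src - 1); // -1
    const int num_summands = std::max(0, hi - lo + 1);  // 0
    assert(num_summands == 0); // avg_exclude_padding would divide by 0 here
    return 0;
}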
 
     void init_default_ws(data_type_t dt = data_type::undef) {
         ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md();
@@ -161,7 +160,7 @@ struct pooling_pd_t : public primitive_desc_t {
     data_type_t indices_data_type() const {
         /* the simplest way to express 256... */
         const int u8_max = nstl::numeric_limits<
-                typename prec_traits<data_type::u8>::type>::max();
+                typename prec_traits_t<data_type::u8>::type>::max();
         return utils::array_product(desc()->kernel, spatial_ndims()) <= u8_max
                 ? data_type::u8
                 : data_type::s32;
@@ -176,17 +175,19 @@ struct pooling_pd_t : public primitive_desc_t {
     }
 };
 
+// NOLINTBEGIN(google-default-arguments)
 struct pooling_fwd_pd_t : public pooling_pd_t {
-    typedef pooling_fwd_pd_t base_class;
-    typedef pooling_fwd_pd_t hint_class;
+    using base_class = pooling_fwd_pd_t;
+    using hint_class = pooling_fwd_pd_t;
 
     arg_usage_t arg_usage(int arg) const override {
         if (arg == DNNL_ARG_SRC) return arg_usage_t::input;
 
         if (arg == DNNL_ARG_DST) return arg_usage_t::output;
 
-        if (arg == DNNL_ARG_WORKSPACE && (!types::is_zero_md(workspace_md())))
-            return arg_usage_t::output;
+        if (arg == DNNL_ARG_WORKSPACE)
+            return !types::is_zero_md(workspace_md()) ? arg_usage_t::output
+                                                      : arg_usage_t::unused;
 
         return primitive_desc_t::arg_usage(arg);
     }
@@ -215,7 +216,7 @@ struct pooling_fwd_pd_t : public pooling_pd_t {
                 : &glob_zero_md;
     }
 
-    int n_inputs() const override { return 1 + n_binary_po_inputs(); }
+    int n_inputs() const override {
+        return 1 + n_binary_po_inputs() + n_depthwise_po_inputs()
+                + n_quantization_po_inputs();
+    }
     int n_outputs() const override {
         return 1 + (!types::is_zero_md(workspace_md()));
     }
@@ -229,7 +230,7 @@ struct pooling_fwd_pd_t : public pooling_pd_t {
     memory_desc_t src_md_;
     memory_desc_t dst_md_;
 
-    pooling_fwd_pd_t(const pooling_desc_t *adesc, const primitive_attr_t *attr,
+    pooling_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
            const pooling_fwd_pd_t *hint_fwd_pd)
        : pooling_pd_t(adesc, attr, hint_fwd_pd)
        , src_md_(desc_.src_desc)
@@ -245,18 +246,21 @@ struct pooling_fwd_pd_t : public pooling_pd_t {
                dst_md_, src_md_.format_desc.blocking);
    }
 };
+// NOLINTEND(google-default-arguments)
 
+// NOLINTBEGIN(google-default-arguments)
 struct pooling_bwd_pd_t : public pooling_pd_t {
-    typedef pooling_bwd_pd_t base_class;
-    typedef pooling_fwd_pd_t hint_class;
+    using base_class = pooling_bwd_pd_t;
+    using hint_class = pooling_fwd_pd_t;
 
     arg_usage_t arg_usage(int arg) const override {
         if (arg == DNNL_ARG_DIFF_DST) return arg_usage_t::input;
 
         if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output;
 
-        if (arg == DNNL_ARG_WORKSPACE && (!types::is_zero_md(workspace_md())))
-            return arg_usage_t::input;
+        if (arg == DNNL_ARG_WORKSPACE)
+            return !types::is_zero_md(workspace_md()) ? arg_usage_t::input
+                                                      : arg_usage_t::unused;
 
         return primitive_desc_t::arg_usage(arg);
     }
@@ -302,7 +306,7 @@ struct pooling_bwd_pd_t : public pooling_pd_t {
     memory_desc_t diff_src_md_;
     memory_desc_t diff_dst_md_;
 
-    pooling_bwd_pd_t(const pooling_desc_t *adesc, const primitive_attr_t *attr,
+    pooling_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
             const pooling_fwd_pd_t *hint_fwd_pd)
         : pooling_pd_t(adesc, attr, hint_fwd_pd)
         , diff_src_md_(desc_.diff_src_desc)
@@ -338,6 +342,7 @@ struct pooling_bwd_pd_t : public pooling_pd_t {
 
 private:
     std::vector<memory_desc_t> hint_mds_;
 };
+// NOLINTEND(google-default-arguments)
 
 } // namespace impl
 } // namespace dnnl
diff --git a/src/common/prelu_pd.hpp b/src/common/prelu_pd.hpp
index 2b3d96ff9d5..de5305d5b4f 100644
--- a/src/common/prelu_pd.hpp
+++ b/src/common/prelu_pd.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -78,18 +78,19 @@ struct prelu_pd_t : public primitive_desc_t {
     memory_desc_t src_md_;
     memory_desc_t weights_md_;
 
-    prelu_pd_t(const prelu_desc_t *adesc, const primitive_attr_t *attr,
+    prelu_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
             const prelu_fwd_pd_t *hint_fwd_pd)
         : primitive_desc_t(attr, base_pkind)
-        , desc_(*adesc)
+        , desc_(*op_desc_t::to_desc<prelu_desc_t>(adesc))
         , hint_fwd_pd_(hint_fwd_pd)
         , src_md_(desc_.src_desc)
         , weights_md_(desc_.weights_desc) {}
 };
 
+// NOLINTBEGIN(google-default-arguments)
 struct prelu_fwd_pd_t : public prelu_pd_t {
-    typedef prelu_fwd_pd_t base_class;
-    typedef prelu_fwd_pd_t hint_class;
+    using base_class = prelu_fwd_pd_t;
+    using hint_class = prelu_fwd_pd_t;
 
     primitive_desc_t::arg_usage_t arg_usage(int arg) const override {
         if (arg == DNNL_ARG_SRC) return arg_usage_t::input;
@@ -133,7 +134,7 @@ struct prelu_fwd_pd_t : public prelu_pd_t {
 protected:
     memory_desc_t dst_md_;
 
-    prelu_fwd_pd_t(const prelu_desc_t *adesc, const primitive_attr_t *attr,
+    prelu_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
             const prelu_fwd_pd_t *hint_fwd_pd)
         : prelu_pd_t(adesc, attr, hint_fwd_pd), dst_md_(desc_.dst_desc) {}
 
@@ -148,10 +149,12 @@ struct prelu_fwd_pd_t : public prelu_pd_t {
                 == status::success);
     }
 };
+// NOLINTEND(google-default-arguments)
 
+// NOLINTBEGIN(google-default-arguments)
 struct prelu_bwd_pd_t : public prelu_pd_t {
-    typedef prelu_bwd_pd_t base_class;
-    typedef prelu_fwd_pd_t hint_class;
+    using base_class = prelu_bwd_pd_t;
+    using hint_class = prelu_fwd_pd_t;
 
     primitive_desc_t::arg_usage_t arg_usage(int arg) const override {
         if (arg == DNNL_ARG_SRC) return arg_usage_t::input;
@@ -216,7 +219,7 @@ struct prelu_bwd_pd_t : public prelu_pd_t {
     memory_desc_t diff_weights_md_;
     memory_desc_t diff_dst_md_;
 
-    prelu_bwd_pd_t(const prelu_desc_t *adesc, const primitive_attr_t *attr,
+    prelu_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
             const prelu_fwd_pd_t *hint_fwd_pd)
         : prelu_pd_t(adesc, attr, hint_fwd_pd)
         , diff_src_md_(desc_.diff_src_desc)
@@ -242,6 +245,7 @@ struct prelu_bwd_pd_t : public prelu_pd_t {
                 == status::success);
     }
 };
+// NOLINTEND(google-default-arguments)
 
 } // namespace impl
 } // namespace dnnl
diff --git a/src/common/primitive.hpp b/src/common/primitive.hpp
index ba217b521d7..1cc39c86eeb 100644
--- a/src/common/primitive.hpp
+++ b/src/common/primitive.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2016-2024 Intel Corporation
+* Copyright 2016-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -150,25 +150,25 @@ struct nested_scratchpad_t {
 } // namespace impl
 } // namespace dnnl
 
-#define ARG_TYPE(t) \
-    typename std::remove_cv<typename std::remove_pointer<t>::type>::type
+#define ARG_PTR_TYPE(t) \
+    typename std::remove_cv<typename std::remove_pointer<t>::type>::type *
 
 // Returns destination memory which has been zero pad initialized. This macro
 // may result in a failure returned via the `status` input since zero pad
 // may fail.
 #define CTX_OUT_CLEAN_MEM(type, arg, status) \
-    static_cast<ARG_TYPE(type) *>(ctx.host_ptr(arg, true, &status))
+    static_cast<ARG_PTR_TYPE(type)>(ctx.host_ptr(arg, true, &(status)))
 
 // Returns destination memory which may not have been zero pad initialized.
 #define CTX_OUT_MEM_COMMON(type, arg, index) \
-    static_cast<ARG_TYPE(type) *>(ctx.host_ptr(arg, false, nullptr, index))
+    static_cast<ARG_PTR_TYPE(type)>(ctx.host_ptr(arg, false, nullptr, index))
 
 #define CTX_OUT_MEM(type, arg) CTX_OUT_MEM_COMMON(type, arg, 0)
 #define CTX_OUT_MEM0(type, arg) CTX_OUT_MEM_COMMON(type, arg, 0)
 #define CTX_OUT_MEM1(type, arg) CTX_OUT_MEM_COMMON(type, arg, 1)
 #define CTX_OUT_MEM2(type, arg) CTX_OUT_MEM_COMMON(type, arg, 2)
 
 #define CTX_IN_MEM_COMMON(type, arg, index) \
-    static_cast<const ARG_TYPE(type) *>( \
+    static_cast<const ARG_PTR_TYPE(type)>( \
             ctx.host_ptr(arg, false, nullptr, index))
 #define CTX_IN_MEM(type, arg) CTX_IN_MEM_COMMON(type, arg, 0)
 #define CTX_IN_MEM0(type, arg) CTX_IN_MEM_COMMON(type, arg, 0)
diff --git a/src/common/primitive_attr.cpp b/src/common/primitive_attr.cpp
index 09007dd968a..d088682adca 100644
--- a/src/common/primitive_attr.cpp
+++ b/src/common/primitive_attr.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2017-2024 Intel Corporation
+* Copyright 2017-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
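Typical call sites for the renamed macros above (the float type is just an example; the argument kinds come from the public API):

// const auto *src = CTX_IN_MEM(const float *, DNNL_ARG_SRC);
// auto *dst = CTX_OUT_MEM(float *, DNNL_ARG_DST);
// status_t st = status::success;
// auto *clean_dst = CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DST, st);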
@@ -35,12 +35,7 @@ const primitive_attr_t &default_attr() { return default_attr_instance; } -const runtime_scales_t &default_runtime_scale() { - static const runtime_scales_t default_runtime_scale_instance; - return default_runtime_scale_instance; -} - -void scales_t::set_single_scale(float scale) { +void rnn_create_time_scales_t::set_single_scale(float scale) { count_ = 1; mask_ = 0; scales_ = scales_buf_; @@ -51,7 +46,8 @@ void scales_t::set_single_scale(float scale) { } } -status_t scales_t::set(dim_t count, int mask, const float *scales) { +status_t rnn_create_time_scales_t::set( + dim_t count, int mask, const float *scales) { cleanup(); count_ = count; @@ -73,39 +69,25 @@ status_t scales_t::set(dim_t count, int mask, const float *scales) { return status::success; } -status_t zero_points_t::get(int arg, int *mask, data_type_t *dt) const { - if (mask) *mask = get_mask(arg); - if (dt) *dt = get_data_type(arg); - return status::success; -} +template <typename T> +status_t shifts_t<T>::set(int count, int mask, const T *shifts) { + cleanup(); -int zero_points_t::get(int arg) const { - return get_mask(arg); -} - -status_t zero_points_t::set(int arg, int mask, int ndims, const dims_t groups, - data_type_t data_type) { - const bool supported_arg - = utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST); - if (!supported_arg) return status::unimplemented; - - switch (arg) { - case DNNL_ARG_SRC: - is_set_src = true; - mask_src = mask; - break; - case DNNL_ARG_WEIGHTS: - is_set_wei = true; - mask_wei = mask; - data_type_wei = data_type; - group_ndims_wei = ndims; - utils::array_copy(group_dims_wei, groups, group_ndims_wei); - break; - case DNNL_ARG_DST: - is_set_dst = true; - mask_dst = mask; - break; + count_ = count; + mask_ = mask; + + if (count_ == 1) { + shifts_ = shifts_buf_; + utils::array_set(shifts_, shifts[0], shifts_buf_size); + } else { + shifts_ = (T *)impl::malloc(count_ * sizeof(*shifts_), 64); + if (shifts_ == nullptr) + return status::out_of_memory; + + for (int c = 0; c < count_; ++c) + shifts_[c] = shifts[c]; } + return status::success; } @@ -128,37 +110,31 @@ status_t dropout_t::set_default_formats(const memory_desc_t *dst_md) { bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask, dnnl::impl::data_type_t dst_dt) const { using smask_t = skip_mask_t; - // prepare mask for runtime-parameters check - smask_t defined_mask = smask_t::none; - if ((mask & smask_t::oscale_runtime) == smask_t::oscale_runtime) - defined_mask |= smask_t::oscale; - if ((mask & smask_t::scales_runtime) == smask_t::scales_runtime) - defined_mask |= smask_t::scales; - if ((mask & smask_t::zero_points_runtime) == smask_t::zero_points_runtime) - defined_mask |= smask_t::zero_points; bool ok = true; #define CHECK_ARG(x) ok = ok && (x) #define CHECK_MASK(mask_name, mask_field) \ CHECK_ARG(IMPLICATION( \ (bool)(~mask & (mask_name)), (mask_field).has_default_values())) - CHECK_MASK(smask_t::oscale_runtime, output_scales_); CHECK_MASK(smask_t::scales, scales_); - CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::scales_runtime_groups), + CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::scales_groups), scales_.has_default_groups())); - CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::scales_runtime_data_type), + CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::scales_data_type), scales_.has_default_data_type())); CHECK_MASK(smask_t::zero_points, zero_points_); - CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::zero_points_runtime_groups), + CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::zero_points_groups),
zero_points_.has_default_groups())); - CHECK_ARG( - IMPLICATION((bool)(~mask & smask_t::zero_points_runtime_data_type), - zero_points_.has_default_data_type())); + CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::zero_points_data_type), + zero_points_.has_default_data_type())); + CHECK_MASK(smask_t::input_zero_points, input_zero_points_); + CHECK_MASK(smask_t::weights_zero_points, weights_zero_points_); + CHECK_MASK(smask_t::output_compensations, output_compensations_); CHECK_MASK(smask_t::post_ops, post_ops_); CHECK_MASK(smask_t::rnn_data_qparams, rnn_data_qparams_); CHECK_MASK(smask_t::rnn_weights_qparams, rnn_weights_qparams_); CHECK_MASK(smask_t::rnn_weights_projection_qparams, rnn_weights_projection_qparams_); + CHECK_MASK(smask_t::src_dyn_quant_params, src_dyn_quant_params_); CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::sum_dt), post_ops_.sum_with_default_dt(dst_dt))); bool gpu_attr_ok = IMPLICATION((bool)(~mask & smask_t::gpu_attr), @@ -172,7 +148,7 @@ bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask, (bool)(~mask & smask_t::dropout), dropout_.has_default_values())); CHECK_ARG(IMPLICATION((bool)(~mask & smask_t::rounding_mode), rounding_mode_.has_default_values())); - CHECK_ARG(this->defined(defined_mask)); + CHECK_ARG(this->defined(smask_t::none)); bool fpmath_mode_ok = IMPLICATION( (bool)(~mask & smask_t::fpmath_mode) && fpmath_.apply_to_int_, fpmath_.mode_ == fpmath_mode::strict); @@ -188,14 +164,11 @@ bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const { #define CHECK_ARG(x) ok = ok && (x) #define CHECK_MASK(mask_name, mask_field) \ CHECK_ARG(IMPLICATION((bool)(~mask & (mask_name)), (mask_field).defined())) - CHECK_MASK(smask_t::oscale, output_scales_); - CHECK_MASK(smask_t::scales, scales_); - CHECK_MASK(smask_t::zero_points, zero_points_); - CHECK_MASK(smask_t::post_ops, post_ops_); CHECK_MASK(smask_t::rnn_data_qparams, rnn_data_qparams_); CHECK_MASK(smask_t::rnn_weights_qparams, rnn_weights_qparams_); CHECK_MASK(smask_t::rnn_weights_projection_qparams, rnn_weights_projection_qparams_); + CHECK_MASK(smask_t::src_dyn_quant_params, src_dyn_quant_params_); return ok; #undef CHECK_MASK #undef CHECK_ARG @@ -203,6 +176,8 @@ bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const { status_t post_ops_t::append_sum( float scale, int32_t zero_point, data_type_t dt) { + if (is_runtime_value(scale)) return invalid_arguments; + entry_.emplace_back(); auto &e = entry_.back(); e.kind = primitive_kind::sum; @@ -216,6 +191,9 @@ status_t post_ops_t::append_eltwise( float scale, alg_kind_t alg, float alpha, float beta) { if (!math::is_eltwise_ok(data_type::f32, alg, alpha, beta)) return invalid_arguments; + if (is_runtime_value(scale)) return invalid_arguments; + if (is_runtime_value(alpha)) return invalid_arguments; + if (is_runtime_value(beta)) return invalid_arguments; entry_.emplace_back(); auto &e = entry_.back(); @@ -262,7 +240,7 @@ status_t post_ops_t::validate_binary( using namespace alg_kind; bool alg_ok = one_of(alg, binary_add, binary_mul, binary_max, binary_min, binary_div, binary_sub, binary_ge, binary_gt, binary_le, binary_lt, - binary_eq, binary_ne); + binary_eq, binary_ne, binary_prelu); if (!alg_ok) return invalid_arguments; if (!memory_desc_sanity_check(*user_src1_desc)) return invalid_arguments; @@ -313,25 +291,77 @@ status_t post_ops_t::append_prelu(int mask) { return success; } -bool post_ops_t::defined() const { - for (int idx = 0; idx < len(); ++idx) { - auto kind = entry_[idx].kind; - if (kind == 
primitive_kind::sum) { - if (is_runtime_value(entry_[idx].sum.scale)) return false; - } else if (kind == primitive_kind::eltwise) { - const auto &e = entry_[idx].eltwise; - if (is_runtime_value(e.scale) || is_runtime_value(e.alpha) - || is_runtime_value(e.beta)) - return false; - } else if (utils::one_of(kind, primitive_kind::binary, - primitive_kind::prelu, - primitive_kind::convolution)) { - // binary is always defined - } else { - assert(!"unreachable"); - } - } - return true; +status_t post_ops_t::append_depthwise(alg_kind_t alg, size_t offset_size, const size_t* offset) { + using namespace dnnl::impl::alg_kind; + if (len() == post_ops_limit) return out_of_memory; + bool known_alg = one_of(alg, depthwise_scale_shift, depthwise_prelu); + if (!known_alg) + return invalid_arguments; + + entry_.emplace_back(); + auto &e = entry_.back(); + e.kind = primitive_kind::depthwise; + e.depthwise.alg = alg; + array_copy(e.depthwise.offset, offset, offset_size); + + return success; +} + +status_t post_ops_t::append_quantization(alg_kind_t alg, + size_t per_channel_size, const bool* per_channel, + size_t all_default_size, const bool* all_default, + size_t offset_size, const size_t* offset) { + using namespace dnnl::impl::alg_kind; + if (len() == post_ops_limit) return out_of_memory; + bool known_alg = one_of(alg, quantization_quantize_dequantize, quantization_quantize); + if (!known_alg) + return invalid_arguments; + + entry_.emplace_back(); + auto &e = entry_.back(); + e.kind = primitive_kind::quantization; + e.quantization.alg = alg; + + array_copy(e.quantization.per_channel, per_channel, per_channel_size); + array_copy(e.quantization.all_default, all_default, all_default_size); + array_copy(e.quantization.offset, offset, offset_size); + + return success; +} + +status_t post_ops_t::append_binarization(alg_kind_t alg, const float* weights_data, const float* output_mask_data) { + using namespace dnnl::impl::alg_kind; + if (len() == post_ops_limit) return out_of_memory; + bool known_alg = one_of(alg, binarization_depthwise); + if (!known_alg) + return invalid_arguments; + + entry_.emplace_back(); + auto &e = entry_.back(); + e.kind = primitive_kind::binarization; + e.binarization.alg = alg; + e.binarization.weights_data = weights_data; + e.binarization.output_mask_data = output_mask_data; + + return success; +} + +status_t post_ops_t::append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, + dnnl::impl::data_type_t in_dt) { + if (len() == post_ops_limit) return out_of_memory; + + entry_.emplace_back(); + auto &e = entry_.back(); + e.kind = primitive_kind::convolution; + e.depthwise_conv_old.in_h = in_h; + e.depthwise_conv_old.in_w = in_w; + e.depthwise_conv_old.ker_h = ker_h; + e.depthwise_conv_old.ker_w = ker_w; + e.depthwise_conv_old.str_h = str_h; + e.depthwise_conv_old.str_w = str_w; + e.depthwise_conv_old.in_dt = in_dt; + + return success; } status_t post_ops_t::set_default_formats(const memory_desc_t *dst_md) { @@ -398,6 +428,26 @@ bool post_ops_t::check_sum_consistency(const data_type_t dst_dt, && check_sum_consistent_quantization(dst_dt, is_int8); } +status_t post_ops_t::entry_t::validate_binary_with_dst_consistency( + const memory_desc_t *dst_md) const { + if (!is_binary()) return status::success; + + VCHECK_ATTR(dst_md->ndims == binary.user_src1_desc.ndims, + VERBOSE_INCONSISTENT_NDIMS_WITH_VALS, "dst", "bin_po", + dst_md->ndims, binary.user_src1_desc.ndims); + + return status::success; +} + +status_t post_ops_t::validate_binary_with_dst_consistency( + const 
memory_desc_t *dst_md) const { + for (const auto &e : entry_) { + CHECK(e.validate_binary_with_dst_consistency(dst_md)); + } + + return status::success; +} + status_t primitive_attr_t::set_dropout(const memory_desc_t *user_dropout_desc) { if (any_null(user_dropout_desc)) return invalid_arguments; dropout_.user_dropout_desc_ = *user_dropout_desc; @@ -429,8 +479,10 @@ status_t primitive_attr_t::set_accumulation_mode(accumulation_mode_t am) { status_t primitive_attr_t::set_scratchpad_mode( scratchpad_mode_t scratchpad_mode) { - const bool ok = one_of( - scratchpad_mode, scratchpad_mode::library, scratchpad_mode::user); + /* workaround for the name conflict with system struct 'user' in llvm-android toolchain */ + using namespace dnnl::impl::scratchpad_mode; + + const bool ok = one_of(scratchpad_mode, scratchpad_mode::library, scratchpad_mode::user); if (!ok) return invalid_arguments; scratchpad_mode_ = scratchpad_mode; @@ -562,10 +614,18 @@ status_t dnnl_primitive_attr_set_scratchpad_mode( status_t dnnl_primitive_attr_set_scales_mask( primitive_attr_t *attr, int arg, int mask) { - bool ok = attr && mask >= 0 && arg >= 0; - if (!ok) return invalid_arguments; + VCHECK_ATTR(attr, VERBOSE_NULL_ARG); + VCHECK_ATTR(mask >= 0, VERBOSE_BAD_PARAM, "mask"); + VCHECK_ATTR(arg >= 0, VERBOSE_BAD_PARAM, "arg"); return attr->scales_.set(arg, mask); } +status_t dnnl_primitive_attr_set_scales_dims( + primitive_attr_t *attr, int arg, const dims_t dims, int ndims, data_type_t data_type) { + bool ok = attr && arg >= 0 && ndims > 0 + && attr->scales_.has_default_values(); + if (!ok) return invalid_arguments; + return attr->scales_.set_scales(arg, dims, ndims, data_type); +} status_t dnnl_primitive_attr_set_scales(primitive_attr_t *attr, int arg, int mask, int ndims, const dims_t group_dims, data_type_t data_type) { @@ -574,39 +634,47 @@ status_t dnnl_primitive_attr_set_scales(primitive_attr_t *attr, int arg, VCHECK_ATTR(mask >= 0, VERBOSE_BAD_PARAM, "mask"); VCHECK_ATTR(arg >= 0, VERBOSE_BAD_PARAM, "arg"); VCHECK_ATTR(ndims >= 0, VERBOSE_BAD_PARAM, "ndims"); - VCHECK_ATTR(utils::one_of(data_type, f32, bf16, f16, e8m0), - VERBOSE_INVALID_DATATYPE, "scales"); - VCHECK_ATTR(IMPLICATION(!utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS), - data_type == f32 && ndims == 0) - || IMPLICATION(arg == DNNL_ARG_DST, - utils::one_of(data_type, f32, e8m0)), + VCHECK_ATTR( + utils::one_of(data_type, f32, bf16, f16, e8m0, f8_e5m2, f8_e4m3), VERBOSE_INVALID_DATATYPE, "scales"); VCHECK_ATTR(IMPLICATION(ndims, validate_dims(ndims, group_dims)), VERBOSE_BAD_PARAM, "group_dims"); - return attr->scales_.set(arg, mask, ndims, group_dims, data_type); + return attr->scales_.set(arg, mask, data_type, ndims, group_dims); } status_t dnnl_primitive_attr_set_zero_points_mask( primitive_attr_t *attr, int arg, int mask) { - bool ok = attr && mask >= 0; + VCHECK_ATTR(attr, VERBOSE_NULL_ARG); + VCHECK_ATTR(mask >= 0, VERBOSE_BAD_PARAM, "mask"); + return attr->zero_points_.set(arg, mask); +} +status_t dnnl_primitive_attr_set_zero_points_dims( + primitive_attr_t *attr, int arg, const dims_t dims, int ndims, dnnl_data_type_t data_type) { + bool ok = attr && ndims > 0; if (!ok) return invalid_arguments; - return attr->zero_points_.set(arg, mask); + return attr->zero_points_.set_zero_points(arg, dims, ndims, data_type); } -dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points( - dnnl_primitive_attr_t attr, int arg, int mask, int ndims, - const dnnl_dims_t group_dims, dnnl_data_type_t data_type) { +status_t 
dnnl_primitive_attr_set_zero_points(dnnl_primitive_attr_t attr, + int arg, int mask, int ndims, const dnnl_dims_t group_dims, + dnnl_data_type_t data_type) { using namespace data_type; - bool ok = attr && arg >= 0 && mask >= 0 && ndims >= 0 - && utils::one_of(data_type, s32, s8, u8, s4, u4) - && IMPLICATION( - arg != DNNL_ARG_WEIGHTS, data_type == s32 && ndims == 0) - && IMPLICATION(utils::one_of(data_type, s4, u4), mask > 0) - && IMPLICATION(ndims, validate_dims(ndims, group_dims)); - if (!ok) return invalid_arguments; + VCHECK_ATTR(attr, VERBOSE_NULL_ARG); + VCHECK_ATTR(mask >= 0, VERBOSE_BAD_PARAM, "mask"); + VCHECK_ATTR(arg >= 0, VERBOSE_BAD_PARAM, "arg"); + VCHECK_ATTR(ndims >= 0, VERBOSE_BAD_PARAM, "ndims"); + VCHECK_ATTR(utils::one_of(data_type, s32, s8, u8, s4, u4), + VERBOSE_INVALID_DATATYPE, "zero points"); + VCHECK_ATTR(IMPLICATION(utils::one_of(data_type, s4, u4), mask > 0), + VERBOSE_BAD_PARAM, "mask with int4 data type"); + VCHECK_ATTR(IMPLICATION(!utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS), + data_type == s32 && ndims == 0), + VERBOSE_INVALID_DATATYPE, "zero points"); + VCHECK_ATTR(IMPLICATION(ndims, validate_dims(ndims, group_dims)), + VERBOSE_BAD_PARAM, "group_dims"); - return attr->zero_points_.set(arg, mask, ndims, group_dims, data_type); + return attr->zero_points_.set(arg, mask, data_type, ndims, group_dims); } status_t dnnl_primitive_attr_get_rounding( @@ -622,6 +690,33 @@ status_t dnnl_primitive_attr_set_rounding( return attr->rounding_mode_.set(arg, mode); } +status_t dnnl_primitive_attr_set_output_compensations(primitive_attr_t *attr, + int count, int mask) { + bool ok = !any_null(attr) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->output_compensations_.set(count, mask); +} + +status_t dnnl_primitive_attr_set_input_zero_points(primitive_attr_t *attr, + int count, int mask) { + bool ok = !any_null(attr) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->input_zero_points_.set(count, mask); +} + +status_t dnnl_primitive_attr_set_weights_zero_points(primitive_attr_t *attr, + int count, int mask) { + bool ok = !any_null(attr) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->weights_zero_points_.set(count, mask); +} + status_t dnnl_primitive_attr_get_post_ops( const primitive_attr_t *attr, const post_ops_t **post_ops) { if (any_null(attr, post_ops)) return invalid_arguments; @@ -681,19 +776,20 @@ status_t dnnl_post_ops_append_sum( } namespace { -bool simple_get_params_check( +status_t simple_get_params_check( const post_ops_t *post_ops, int index, primitive_kind_t kind) { - bool ok = true && post_ops != nullptr && 0 <= index - && index < post_ops->len() && post_ops->entry_[index].kind == kind; - return ok; + VCHECK_ATTR(post_ops, VERBOSE_NULL_ARG); + VCHECK_ATTR(index >= 0, VERBOSE_BAD_PARAM, "index"); + VCHECK_ATTR(index < post_ops->len(), VERBOSE_BAD_PARAM, "index"); + VCHECK_ATTR( + post_ops->entry_[index].kind == kind, VERBOSE_BAD_PARAM, "kind"); + return status::success; } } // namespace status_t dnnl_post_ops_get_params_sum(const post_ops_t *post_ops, int index, float *scale, int32_t *zero_point, data_type_t *dt) { - bool ok = true - && simple_get_params_check(post_ops, index, primitive_kind::sum); - if (!ok) return invalid_arguments; + CHECK(simple_get_params_check(post_ops, index, primitive_kind::sum)); if (scale) *scale = post_ops->entry_[index].sum.scale; if (zero_point) *zero_point = post_ops->entry_[index].sum.zero_point; @@ -711,15 +807,12 @@ 
status_t dnnl_post_ops_append_eltwise( status_t dnnl_post_ops_get_params_eltwise(const post_ops_t *post_ops, int index, alg_kind_t *alg, float *alpha, float *beta) { - bool ok = true - && simple_get_params_check(post_ops, index, primitive_kind::eltwise) - && !any_null(alpha, beta); - if (!ok) return invalid_arguments; + CHECK(simple_get_params_check(post_ops, index, primitive_kind::eltwise)); const auto &e = post_ops->entry_[index].eltwise; - *alg = e.alg; - *alpha = e.alpha; - *beta = e.beta; + if (alg) *alg = e.alg; + if (alpha) *alpha = e.alpha; + if (beta) *beta = e.beta; return success; } @@ -736,9 +829,8 @@ status_t dnnl_post_ops_append_dw(post_ops_t *post_ops, data_type_t wei_dt, status_t dnnl_post_ops_get_params_dw(const post_ops_t *post_ops, int index, data_type_t *wei_dt, data_type_t *bias_dt, data_type_t *dst_dt, dim_t *kernel, dim_t *stride, dim_t *padding) { - - if (!simple_get_params_check(post_ops, index, primitive_kind::convolution)) - return invalid_arguments; + CHECK(simple_get_params_check( + post_ops, index, primitive_kind::convolution)); const auto &d = post_ops->entry_[index].depthwise_conv; if (wei_dt) *wei_dt = d.wei_dt; @@ -760,8 +852,7 @@ status_t dnnl_post_ops_append_binary(post_ops_t *post_ops, alg_kind_t alg_kind, status_t dnnl_post_ops_get_params_binary(const post_ops_t *post_ops, int index, alg_kind_t *alg_kind, const memory_desc_t **user_src1_desc) { - if (!simple_get_params_check(post_ops, index, primitive_kind::binary)) - return invalid_arguments; + CHECK(simple_get_params_check(post_ops, index, primitive_kind::binary)); const auto &b = post_ops->entry_[index].binary; if (alg_kind) *alg_kind = b.alg; @@ -787,6 +878,45 @@ status_t dnnl_post_ops_get_params_prelu( return success; } +status_t dnnl_post_ops_append_depthwise(dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, size_t offset_size, const size_t* offset) { + if (post_ops == nullptr || offset == nullptr) return invalid_arguments; + + if (offset_size != 2) + return invalid_arguments; + + return post_ops->append_depthwise(alg, offset_size, offset); +} + +status_t dnnl_post_ops_append_quantization(post_ops_t *post_ops, alg_kind_t kind, + size_t per_channel_size, const bool* per_channel, + size_t all_default_size, const bool* all_default, + size_t offset_size, const size_t* offset) { + if (post_ops == nullptr || per_channel == nullptr || all_default == nullptr || offset == nullptr) + return invalid_arguments; + + if (per_channel_size != all_default_size || all_default_size != offset_size || offset_size != 6) + return invalid_arguments; + + return post_ops->append_quantization(kind, per_channel_size, per_channel, all_default_size, all_default, offset_size, offset); +} + +status_t dnnl_post_ops_append_binarization(post_ops_t *post_ops, alg_kind_t kind, const float* weights_data, + const float* output_mask_data) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_binarization(kind, weights_data, output_mask_data); +} + +status_t dnnl_post_ops_append_dw_conv(post_ops_t *post_ops, + int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, + dnnl::impl::data_type_t in_dt) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_dw_conv(in_h, in_w, ker_h, ker_w, str_h, str_w, in_dt); +} + status_t dnnl_primitive_attr_set_rnn_data_qparams( primitive_attr_t *attr, const float scale, const float shift) { if (attr == nullptr) return invalid_arguments; @@ -854,3 +984,22 @@ status_t DNNL_API dnnl_primitive_attr_set_rnn_tparams( return 
attr->rnn_tparams_.set(mode, ngates, scales, cscale); } + +status_t dnnl_primitive_attr_set_src_dyn_quant_params( + primitive_attr_t *attr, const uint64_t group_size) { + if (attr == nullptr) return invalid_arguments; + + return attr->src_dyn_quant_params_.set(group_size); +} + +status_t dnnl_primitive_attr_get_src_dyn_quant_params( + primitive_attr_t *attr, uint64_t* group_size) { + if (attr == nullptr) return invalid_arguments; + + if (group_size) *group_size = attr->src_dyn_quant_params_.get(); + return success; +} + +template struct dnnl::impl::shifts_t<float>; +template struct dnnl::impl::shifts_t<int32_t>; +template struct dnnl::impl::shifts_t<uint8_t>; diff --git a/src/common/primitive_attr.hpp b/src/common/primitive_attr.hpp index 5e1496978ed..c961d884455 100644 --- a/src/common/primitive_attr.hpp +++ b/src/common/primitive_attr.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2024 Intel Corporation +* Copyright 2017-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ #include "c_types_map.hpp" #include "nstl.hpp" +#include "primitive_attr_quant.hpp" #include "type_helpers.hpp" #include "utils.hpp" @@ -36,11 +37,9 @@ namespace dnnl { namespace impl { const primitive_attr_t &default_attr(); -struct runtime_scales_t; -const runtime_scales_t &default_runtime_scale(); struct rnn_data_qparams_t : public c_compatible { - rnn_data_qparams_t() : scale_(1.), shift_(0.) {} + rnn_data_qparams_t() : scale_(1.f), shift_(0.f) {} bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); } bool defined() const { return !is_runtime_value(scale_) && !is_runtime_value(shift_); @@ -129,14 +128,14 @@ struct rnn_tparams_t : public c_compatible { }; // Note: keep for RNN quantization -struct scales_t : public c_compatible { - scales_t() : count_(1), mask_(0), scales_(scales_buf_) { - set_single_scale(1.); +struct rnn_create_time_scales_t : public c_compatible { + rnn_create_time_scales_t() : count_(1), mask_(0), scales_(scales_buf_) { + set_single_scale(1.f); } - ~scales_t() { cleanup(); } + ~rnn_create_time_scales_t() { cleanup(); } - bool operator==(const scales_t &rhs) const { + bool operator==(const rnn_create_time_scales_t &rhs) const { bool ret = count_ == rhs.count_ && mask_ == rhs.mask_ && !utils::any_null(scales_, rhs.scales_) && defined() == rhs.defined() @@ -162,7 +161,7 @@ struct scales_t : public c_compatible { return status::success; } - status_t copy_from(const scales_t &other) { + status_t copy_from(const rnn_create_time_scales_t &other) { return set(other.count_, other.mask_, other.scales_); } @@ -182,296 +181,59 @@ struct scales_t : public c_compatible { scales_ = scales_buf_; } - DNNL_DISALLOW_COPY_AND_ASSIGN(scales_t); + DNNL_DISALLOW_COPY_AND_ASSIGN(rnn_create_time_scales_t); }; -struct runtime_scales_t : public c_compatible { - // Clang-3.8.1 raises an error for a default initialization of a const - // object. Const runtime_scales_t object is used as default_scales.
- // runtime_scales_t() = default; - runtime_scales_t() {} +template <typename T> +struct shifts_t: public c_compatible { + shifts_t(): count_(1), mask_(0), shifts_(shifts_buf_) + { set(0); } - runtime_scales_t &operator=(const runtime_scales_t &rhs) { - mask_ = rhs.mask_; - is_set_ = rhs.is_set_; - ndims_ = rhs.ndims_; - if (ndims_ > 0) utils::array_copy(group_dims_, rhs.group_dims_, ndims_); - data_type_ = rhs.data_type_; - return *this; - } - - status_t set(int mask) { return set(0, mask, {}, data_type::f32); } - - status_t set(int ndims, int mask, const dims_t group_dims, - data_type_t data_type = data_type::f32) { - mask_ = mask; - is_set_ = true; - ndims_ = ndims; - if (ndims > 0) utils::array_copy(group_dims_, group_dims, ndims); - data_type_ = data_type; - return status::success; - } - - bool operator==(const runtime_scales_t &rhs) const { - return mask_ == rhs.mask_ && is_set_ == rhs.is_set_ - && ndims_ == rhs.ndims_ - && IMPLICATION(ndims_ > 0, - utils::array_cmp(group_dims_, rhs.group_dims_, ndims_)) - && data_type_ == rhs.data_type_; - } - - bool has_default_values() const { return *this == default_runtime_scale(); } - - bool has_default_groups() const { return 0 == ndims_; } - bool has_default_data_type() const { return data_type_ == data_type::f32; } - - bool defined() const { return has_default_values(); } - - void reset() { *this = default_runtime_scale(); } - - // TODO: replace with `-1` to remove `is_set_`. - // Hide `mask_` under `private:` to force interface usage. - int mask_ = 0; - bool is_set_ = false; - int ndims_ = 0; - dims_t group_dims_ = {}; - data_type_t data_type_ = data_type::f32; -}; - -struct arg_scales_t : public c_compatible { - arg_scales_t() = default; - - const runtime_scales_t &get(int arg) const { - static const runtime_scales_t default_scales; - const auto it = scales_.find(arg); - if (it == scales_.end()) return default_scales; - return it->second; - } - - status_t set(int arg, const runtime_scales_t &scale) { - if (!check_arg(arg)) return status::invalid_arguments; - scales_[arg] = scale; - return status::success; - } - - bool operator==(const arg_scales_t &rhs) const { - return scales_ == rhs.scales_; - } - - bool has_default_values(const std::vector<int> &skip_args = {}) const { - auto predicate = [](const runtime_scales_t &s) { - return s.has_default_values(); - }; - return has_default_property(skip_args, predicate); - } - - bool has_default_data_type(const std::vector<int> &skip_args = {}) const { - auto predicate = [](const runtime_scales_t &s) { - return s.has_default_data_type(); - }; - return has_default_property(skip_args, predicate); - } - - bool has_default_groups(const std::vector<int> &skip_args = {}) const { - auto predicate = [](const runtime_scales_t &s) { - return s.has_default_groups(); - }; - return has_default_property(skip_args, predicate); - } - - status_t set(int arg, int mask) { - return set(arg, mask, 0, {}, data_type::f32); - } - - status_t set(int arg, int mask, int ndims, const dims_t group_dims, - data_type_t data_type) { - if (!check_arg(arg)) return status::invalid_arguments; - return scales_[arg].set(ndims, mask, group_dims, data_type); - } + ~shifts_t() { cleanup(); } - // TODO: move to `private` and keep a single interface per entry.
- status_t get(int arg, int *mask, bool *is_set, int *ndims = nullptr, - dims_t group_dims = nullptr, - data_type_t *data_type = nullptr) const { - if (!check_arg(arg)) return status::invalid_arguments; - const auto &s = get(arg); - if (mask) *mask = s.mask_; - if (is_set) *is_set = s.is_set_; - if (ndims) *ndims = s.ndims_; - if (group_dims && s.ndims_ > 0) - utils::array_copy(group_dims, s.group_dims_, s.ndims_); - if (data_type) *data_type = s.data_type_; - return status::success; - } - - data_type_t get_data_type(int arg) const { - data_type_t data_type; - auto st = get(arg, nullptr, nullptr, nullptr, nullptr, &data_type); - if (st != status::success) return data_type::undef; - return data_type; - } - - status_t reset(int arg) { - if (!check_arg(arg)) return status::invalid_arguments; - const auto it = scales_.find(arg); - if (it != scales_.end()) scales_.erase(it); - return status::success; - } - - bool defined() const { return has_default_values(); } - - status_t copy_from(const arg_scales_t &other) { - for (auto it = other.scales_.begin(); it != other.scales_.end(); ++it) { - // Find an entry that can match the arguments without constructing a - // new object. - if (scales_.count(it->first) == 1) { - auto &entry = scales_[it->first]; - if (entry == it->second) continue; - } - - CHECK(set(it->first, it->second)); - } - return status::success; + bool operator==(const shifts_t &rhs) const { + bool ret = count_ == rhs.count_ && mask_ == rhs.mask_ + && !utils::any_null(shifts_, rhs.shifts_) + && defined() == rhs.defined() + && IMPLICATION(defined(), + utils::array_cmp(shifts_, rhs.shifts_, count_)); + return ret; } - std::map<int, runtime_scales_t> scales_; - -private: - bool check_arg(int arg) const { - // binary - for (const auto &sa : {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1}) { - if (arg == sa) return true; - } - // concat - if (arg & DNNL_ARG_MULTIPLE_SRC) return true; - // convolution - for (const auto &sa : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { - if (arg == sa) return true; - } - // depth-wise convolution post op - for (const auto &sa : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { - if (arg == (DNNL_ARG_ATTR_POST_OP_DW | sa)) return true; - } - return false; - } - - bool has_default_property(const std::vector<int> &skip_args, - bool (*predicate)(const runtime_scales_t &)) const { - for (const auto &s : scales_) { - if (!predicate(s.second)) { - bool skip = false; - for (const auto &skip_a : skip_args) - if (s.first == skip_a) { - skip = true; - break; - } - if (skip) continue; - return false; - } + bool has_default_values() const { + for (int c = 0; c < count_; ++c) { + if(shifts_[c] != 0) return false; } return true; } -}; - -struct zero_points_t : public c_compatible { - bool operator==(const zero_points_t &rhs) const { - return mask_src == rhs.mask_src && mask_wei == rhs.mask_wei - && mask_dst == rhs.mask_dst && is_set_src == rhs.is_set_src - && is_set_wei == rhs.is_set_wei && is_set_dst == rhs.is_set_dst - && data_type_wei == rhs.data_type_wei - && group_ndims_wei == rhs.group_ndims_wei - && IMPLICATION(group_ndims_wei > 0, - utils::array_cmp(group_dims_wei, rhs.group_dims_wei, - group_ndims_wei)); - } - - // arg-specific checks - bool common(int arg) const { return get_mask(arg) == 0; } - bool defined(int arg) const { return has_default_values(arg); } - bool has_default_values(int arg) const { - return is_set(arg) == false && has_default_data_type(arg); - } - bool has_default_groups(int arg) const { - return IMPLICATION(arg == DNNL_ARG_WEIGHTS, group_ndims_wei == 0); - } - bool
has_default_data_type(int arg) const { - return get_data_type(arg) == data_type::s32; - } - // same checks but for all supported arguments at once - bool common() const { return check_all(&zero_points_t::common); } - bool defined() const { return has_default_values(); } - bool has_default_values() const { - return check_all(&zero_points_t::has_default_values); - } - bool has_default_groups() const { - return check_all(&zero_points_t::has_default_groups); - } - bool has_default_data_type() const { - return check_all(&zero_points_t::has_default_data_type); - } - - status_t get(int arg, int *mask, data_type_t *dt = nullptr) const; - - int get(int arg) const; // Returns 0 if dimension is unset - - data_type_t get_data_type(int arg) const { - if (arg == DNNL_ARG_WEIGHTS) return data_type_wei; - return data_type::s32; - } - const dim_t *get_groups(int arg) const { - if (arg == DNNL_ARG_WEIGHTS) return group_dims_wei; - return nullptr; - } + bool defined() const { return !is_runtime_value(shifts_[0]); } - int get_groups_ndims(int arg) const { - if (arg == DNNL_ARG_WEIGHTS) return group_ndims_wei; - return 0; - } + status_t set(int count, int mask, const T *zero_points); + status_t set(T single_zero_point) { return this->set(1, 0, &single_zero_point); } - status_t set(int arg, int mask, int ndims, const dims_t group_dims, - data_type_t data_type); - - status_t set(int arg, int mask) { - return set(arg, mask, 0, nullptr, data_type::s32); + status_t copy_from(const shifts_t &other) { + return set(other.count_, other.mask_, other.shifts_); } - status_t set(int arg) { return set(arg, 0); } + dim_t count_; + int mask_; + T *shifts_; private: - bool is_set_src = false, is_set_wei = false, is_set_dst = false; - int mask_src = 0, mask_wei = 0, mask_dst = 0; - data_type_t data_type_wei = data_type::s32; - int group_ndims_wei = 0; - dims_t group_dims_wei {}; - - int get_mask(int arg) const { - int mask = 0; - switch (arg) { - case DNNL_ARG_SRC: mask = mask_src; break; - case DNNL_ARG_WEIGHTS: mask = mask_wei; break; - case DNNL_ARG_DST: mask = mask_dst; break; - default: mask = 0; - } - return mask; - } + enum { shifts_buf_size = 16 }; + T shifts_buf_[shifts_buf_size]; - bool is_set(int arg) const { - bool arg_is_set = false; - switch (arg) { - case DNNL_ARG_SRC: arg_is_set = is_set_src; break; - case DNNL_ARG_WEIGHTS: arg_is_set = is_set_wei; break; - case DNNL_ARG_DST: arg_is_set = is_set_dst; break; - default: arg_is_set = 0; - } - return arg_is_set; - } + void cleanup() { + if (shifts_ != shifts_buf_ && shifts_ != nullptr) + impl::free(shifts_); - bool check_all(bool (zero_points_t::*f)(int) const) const { - for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) - if (!(this->*f)(arg)) return false; - return true; + count_ = 1; + mask_ = 0; + shifts_ = shifts_buf_; } + + DNNL_DISALLOW_COPY_AND_ASSIGN(shifts_t); }; struct dropout_t : public c_compatible { @@ -558,6 +320,26 @@ struct fpmath_t : public c_compatible { bool apply_to_int_; }; +struct legacy_zero_points_t : public c_compatible { + bool operator==(const legacy_zero_points_t &rhs) const { + return count_ == rhs.count_ && mask_ == rhs.mask_; + } + + bool has_default_values() const { + return count_ == 0 && mask_ == 0; + } + + status_t set(dim_t count, int mask) { + count_ = count; + mask_ = mask; + + return status::success; + } + + dim_t count_ = 0; + int mask_ = 0; +}; + } // namespace impl } // namespace dnnl @@ -609,14 +391,64 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { int mask; }; + struct depthwise_t { + enum 
depthwise_fields { + scales, + shifts, + + fields_count + }; + + dnnl::impl::alg_kind_t alg; + size_t offset[fields_count]; + }; + + struct quantization_t { + enum quantization_fields { + crop_low, + crop_high, + inp_scale, + inp_shift, + output_scale, + output_shift, + + fields_count + }; + + dnnl::impl::alg_kind_t alg; + bool per_channel[fields_count]; + bool all_default[fields_count]; + size_t offset[fields_count]; + }; + + struct binarization_t { + dnnl::impl::alg_kind_t alg; + const float* weights_data; + const float* output_mask_data; + }; + + struct depthwise_conv_old_t { + int in_h; + int in_w; + int ker_h; + int ker_w; + int str_h; + int str_w; + dnnl::impl::data_type_t in_dt; + }; + dnnl::impl::primitive_kind_t kind = dnnl::impl::primitive_kind::undefined; union { sum_t sum; eltwise_t eltwise; depthwise_conv_t depthwise_conv; + depthwise_conv_old_t depthwise_conv_old; binary_t binary; prelu_t prelu; + depthwise_t depthwise; + quantization_t quantization; + binarization_t binarization; }; bool is_eltwise(bool require_scale_one = false) const { @@ -655,8 +487,23 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { } bool is_like_binary() const { return is_binary() || is_prelu(); } + bool is_depthwise() const { + using namespace dnnl::impl; + return kind == primitive_kind::depthwise; + } - dnnl::impl::status_t set_depthwise_scales(const float *scales); + bool is_quantization() const { + using namespace dnnl::impl; + return kind == primitive_kind::quantization; + } + + bool is_binarization() const { + using namespace dnnl::impl; + return kind == primitive_kind::binarization; + } + + dnnl::impl::status_t validate_binary_with_dst_consistency( + const dnnl::impl::memory_desc_t *dst_desc) const; bool operator==(const entry_t &rhs) const { using namespace dnnl::impl; @@ -676,18 +523,26 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { && sum.dt == rhs.sum.dt; break; case primitive_kind::convolution: - // Depthwise Only - ret = depthwise_conv.kernel == rhs.depthwise_conv.kernel - && depthwise_conv.stride - == rhs.depthwise_conv.stride - && depthwise_conv.padding - == rhs.depthwise_conv.padding - && depthwise_conv.wei_dt - == rhs.depthwise_conv.wei_dt - && depthwise_conv.bias_dt - == rhs.depthwise_conv.bias_dt - && depthwise_conv.dst_dt - == rhs.depthwise_conv.dst_dt; + // todo: [antonvor] uncomment when new behavior of dw convolution fusing from oneDNN 1.6 will be supported + // // Depthwise Only + // ret = depthwise_conv.kernel == rhs.depthwise_conv.kernel + // && depthwise_conv.stride + // == rhs.depthwise_conv.stride + // && depthwise_conv.padding + // == rhs.depthwise_conv.padding + // && depthwise_conv.wei_dt + // == rhs.depthwise_conv.wei_dt + // && depthwise_conv.bias_dt + // == rhs.depthwise_conv.bias_dt + // && depthwise_conv.dst_dt + // == rhs.depthwise_conv.dst_dt; + ret = depthwise_conv_old.in_h == rhs.depthwise_conv_old.in_h + && depthwise_conv_old.in_w == rhs.depthwise_conv_old.in_w + && depthwise_conv_old.ker_h == rhs.depthwise_conv_old.ker_h + && depthwise_conv_old.ker_w == rhs.depthwise_conv_old.ker_w + && depthwise_conv_old.str_h == rhs.depthwise_conv_old.str_h + && depthwise_conv_old.str_w == rhs.depthwise_conv_old.str_w + && depthwise_conv_old.in_dt == rhs.depthwise_conv_old.in_dt; break; case primitive_kind::binary: ret = binary.alg == rhs.binary.alg @@ -697,6 +552,21 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { case primitive_kind::prelu: ret = prelu.mask == rhs.prelu.mask; break; + case primitive_kind::depthwise: + ret = depthwise.alg 
== rhs.depthwise.alg + && array_cmp(depthwise.offset, rhs.depthwise.offset, depthwise.fields_count); + break; + case primitive_kind::quantization: + ret = quantization.alg == rhs.quantization.alg + && array_cmp(quantization.per_channel, rhs.quantization.per_channel, quantization.fields_count) + && array_cmp(quantization.all_default, rhs.quantization.all_default, quantization.fields_count) + && array_cmp(quantization.offset, rhs.quantization.offset, quantization.fields_count); + break; + case primitive_kind::binarization: + ret = binarization.alg == rhs.binarization.alg + && binarization.weights_data == rhs.binarization.weights_data + && binarization.output_mask_data == rhs.binarization.output_mask_data; + break; default: assert(!"unsupported post_op"); } return ret; @@ -707,7 +577,7 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { } }; - dnnl_post_ops() : entry_() {} + dnnl_post_ops() = default; ~dnnl_post_ops() = default; dnnl::impl::status_t append_sum(float scale, int32_t zero_point = 0, @@ -721,6 +591,15 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { dnnl::impl::status_t append_binary(dnnl::impl::alg_kind_t alg, const dnnl::impl::memory_desc_t *user_src1_desc); dnnl::impl::status_t append_prelu(int mask); + dnnl::impl::status_t append_depthwise(dnnl::impl::alg_kind_t alg, size_t offset_size, const size_t* offset); + dnnl::impl::status_t append_quantization(dnnl::impl::alg_kind_t alg, + size_t per_channel_size, const bool* per_channel, + size_t all_default_size, const bool* all_default, + size_t offset_size, const size_t* offset); + dnnl::impl::status_t append_binarization(dnnl::impl::alg_kind_t alg, const float* weights_data, + const float* output_mask_data); + dnnl::impl::status_t append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, + dnnl::impl::data_type_t in_dt); dnnl::impl::status_t prepend_binary(dnnl::impl::alg_kind_t alg, const dnnl::impl::memory_desc_t *user_src1_desc); @@ -743,7 +622,16 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { return dst_dt; } - bool defined() const; + int count(dnnl::impl::primitive_kind_t kind, int start = 0, + int stop = -1) const { + if (stop == -1) stop = len(); + stop = dnnl::impl::nstl::min(stop, len()); + int cnt = 0; + for (int idx = start; idx < stop; ++idx) + if (entry_[idx].kind == kind) cnt++; + return cnt; + } + int len() const { return (int)entry_.size(); } bool has_default_values( const std::vector<dnnl::impl::primitive_kind_t> &skip_pk @@ -777,6 +665,9 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { || entry_[sum_ind].sum.dt == dst_dt; } + dnnl::impl::status_t validate_binary_with_dst_consistency( + const dnnl::impl::memory_desc_t *dst_desc) const; + bool contain(dnnl::impl::primitive_kind_t kind, int index) const { return find(kind, index, index + 1) == index; } @@ -820,7 +711,8 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { return new dnnl_primitive_attr(*this); } - dnnl_primitive_attr(const dnnl_primitive_attr &other) { + dnnl_primitive_attr(const dnnl_primitive_attr &other) + : c_compatible(other) { if (copy_from(other) != dnnl::impl::status::success) is_initialized_ = false; } @@ -828,7 +720,6 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { dnnl::impl::status_t copy_from(const dnnl_primitive_attr &other) { using namespace dnnl::impl; - output_scales_ = other.output_scales_; scales_ = other.scales_; zero_points_ = other.zero_points_; rounding_mode_ = other.rounding_mode_; @@ -844,6 +735,10 @@
CHECK(rnn_tparams_.copy_from(other.rnn_tparams_)); if (other.gpu_attr_) gpu_attr_ = other.gpu_attr_->clone(); dropout_ = other.dropout_; + input_zero_points_ = (other.input_zero_points_); + weights_zero_points_ = (other.weights_zero_points_); + output_compensations_ = (other.output_compensations_); + src_dyn_quant_params_ = other.src_dyn_quant_params_; return status::success; } @@ -852,28 +747,27 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { enum class skip_mask_t : unsigned { none = 0, - oscale = 1u << 0, - oscale_runtime = 1u << 1, - scales = 1u << 2, - scales_runtime = (unsigned)scales | (1u << 3), + scales = 1u << 1, + scales_groups = (unsigned)scales | (1u << 2), + scales_data_type = (unsigned)scales | (1u << 3), zero_points = 1u << 4, - zero_points_runtime = (unsigned)zero_points | (1u << 5), - post_ops = 1u << 6, - rnn_data_qparams = 1u << 7, - rnn_weights_qparams = 1u << 8, - rnn_tparams = 1u << 9, - sum_dt = 1u << 10, - rnn_weights_projection_qparams = 1u << 11, - gpu_attr = 1u << 12, - accumulation_mode = 1u << 13, - fpmath_mode = 1u << 14, - scales_runtime_groups = (unsigned)scales_runtime | (1u << 15), - scales_runtime_data_type = (unsigned)scales_runtime | (1u << 16), - zero_points_runtime_groups = (unsigned)zero_points_runtime | (1u << 17), - zero_points_runtime_data_type - = (unsigned)zero_points_runtime | (1u << 18), - dropout = 1u << 19, - rounding_mode = 1u << 20, + zero_points_groups = (unsigned)zero_points | (1u << 5), + zero_points_data_type = (unsigned)zero_points | (1u << 6), + post_ops = 1u << 7, + sum_dt = 1u << 8, + rnn_data_qparams = 1u << 9, + rnn_weights_qparams = 1u << 10, + rnn_tparams = 1u << 11, + rnn_weights_projection_qparams = 1u << 12, + gpu_attr = 1u << 13, + accumulation_mode = 1u << 14, + fpmath_mode = 1u << 15, + dropout = 1u << 16, + rounding_mode = 1u << 17, + input_zero_points = 1u << 18, + weights_zero_points = 1u << 19, + output_compensations = 1u << 20, + src_dyn_quant_params = 1u << 21, }; /** Returns true if the attributes have default values. 
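// Illustrative aside (an assumption, not part of the patch): composite
// members of the reworked skip_mask_t above embed their parent bit, e.g.
// scales_groups = (unsigned)scales | (1u << 2), so a caller that skips the
// groups check implicitly skips the plain scales check as well. A minimal
// standalone demonstration of the idiom:
//
//   enum class skip_mask_t : unsigned {
//       none = 0,
//       scales = 1u << 1,
//       scales_groups = (unsigned)scales | (1u << 2),
//   };
//   // (unsigned)skip_mask_t::scales_groups & (unsigned)skip_mask_t::scales
//   // is non-zero: masking groups in or out always carries the parent bit.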
@@ -889,7 +783,6 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { bool ret = scratchpad_mode_ == rhs.scratchpad_mode_ && fpmath_ == rhs.fpmath_ && acc_mode_ == rhs.acc_mode_ && deterministic_ == rhs.deterministic_ - && output_scales_ == rhs.output_scales_ && scales_ == rhs.scales_ && zero_points_ == rhs.zero_points_ && post_ops_ == rhs.post_ops_ && rnn_data_qparams_ == rhs.rnn_data_qparams_ @@ -901,7 +794,11 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { && gpu_attr_->is_equal(*rhs.gpu_attr_)) || (!gpu_attr_ && !rhs.gpu_attr_)) && dropout_ == rhs.dropout_ - && rounding_mode_ == rhs.rounding_mode_; + && rounding_mode_ == rhs.rounding_mode_ + && input_zero_points_ == rhs.input_zero_points_ + && weights_zero_points_ == rhs.weights_zero_points_ + && output_compensations_ == rhs.output_compensations_ + && src_dyn_quant_params_ == rhs.src_dyn_quant_params_; return ret; } @@ -964,8 +861,7 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { } // NOTE: make sure that the types below have overloaded comparison operator - dnnl::impl::runtime_scales_t output_scales_; - dnnl::impl::arg_scales_t scales_; + dnnl::impl::scales_t scales_; dnnl::impl::zero_points_t zero_points_; dnnl::impl::scratchpad_mode_t scratchpad_mode_; dnnl::impl::fpmath_t fpmath_; @@ -973,14 +869,20 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { bool deterministic_; dnnl::impl::post_ops_t post_ops_; dnnl::impl::rnn_data_qparams_t rnn_data_qparams_; - dnnl::impl::scales_t rnn_weights_qparams_; - dnnl::impl::scales_t rnn_weights_projection_qparams_; + dnnl::impl::rnn_create_time_scales_t rnn_weights_qparams_; + dnnl::impl::rnn_create_time_scales_t rnn_weights_projection_qparams_; dnnl::impl::rnn_tparams_t rnn_tparams_; dnnl::impl::dropout_t dropout_; dnnl::impl::rnd_mode_t rounding_mode_; std::unique_ptr gpu_attr_; + dnnl::impl::legacy_zero_points_t input_zero_points_; + dnnl::impl::legacy_zero_points_t weights_zero_points_; + dnnl::impl::legacy_zero_points_t output_compensations_; + + dnnl::impl::src_dyn_quant_params_t src_dyn_quant_params_; + dnnl_primitive_attr &operator=(const dnnl_primitive_attr &other) = delete; }; diff --git a/src/common/primitive_attr_quant.cpp b/src/common/primitive_attr_quant.cpp new file mode 100644 index 00000000000..2c954f8d9e6 --- /dev/null +++ b/src/common/primitive_attr_quant.cpp @@ -0,0 +1,289 @@ +/******************************************************************************* +* Copyright 2024-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "common/primitive_attr_quant.hpp" +#include "common/primitive_hashing.hpp" +#include "common/verbose.hpp" + +namespace dnnl { +namespace impl { + +const quant_entry_t &default_quant_entry() { + static const quant_entry_t default_quant_entry; + return default_quant_entry; +} + +size_t quant_entry_t::get_hash() const { + size_t seed = 0; + seed = hash_combine(seed, mask_); + seed = hash_combine(seed, static_cast<size_t>(data_type_)); + seed = hash_combine(seed, group_ndims_); + if (group_ndims_ > 0) + seed = primitive_hashing::get_array_hash( + seed, group_dims_, group_ndims_); + return seed; +} + +void quant_entry_t::serialize(serialization_stream_t &sstream) const { + sstream.append(mask_); + sstream.append(data_type_); + sstream.append_array(group_ndims_, group_dims_); +} + +quant_entry_t quant_entry_t::deserialize(deserializer_t &d) { + quant_entry_t e; + d.pop(e.mask_); + d.pop(e.data_type_); + size_t group_ndims; + d.pop_array(group_ndims, e.group_dims_); + e.group_ndims_ = static_cast<int>(group_ndims); + return e; +} + +std::string quant_entry_t::get_verbose() const { + std::string s; + s.append(std::to_string(get_mask())); + s.append(":").append(dnnl_dt2str(get_data_type())); + s.append(":").append(std::to_string(type_)); + s.append(":"); + if (group_ndims_ > 0) { + s.append(std::to_string(group_dims_[0])) + .append("x") + .append(std::to_string(group_dims_[1])); + } + s.append(":"); + if (get_ndims() > 0) { + s.append(std::to_string(get_dims()[0])) + .append("x") + .append(std::to_string(get_dims()[1])); + } + + return s; +} + +std::ostream &operator<<(std::ostream &ss, const quant_entry_t &e) { + ss << e.get_verbose(); + return ss; +} + +size_t quant_entries_t::get_hash() const { + size_t seed = 0; + // Go through scales for all arguments.
+ for (const auto &e : entries_) { + seed = hash_combine(seed, e.first); + seed = hash_combine(seed, e.second.get_hash()); + } + return seed; +} + +void quant_entries_t::serialize(serialization_stream_t &sstream) const { + sstream.append(entries_.size()); + for (const auto &e : entries_) { + sstream.append(e.first); + sstream.append(e.second); + } +} + +template <typename T> +T deserialize_entries(deserializer_t &d) { + T entries; + size_t size = d.pop<size_t>(); + for (size_t i = 0; i < size; i++) { + int arg = d.pop<int>(); + entries.set(arg, d.pop<quant_entry_t>()); + } + return entries; +} + +std::string quant_entries_t::get_verbose() const { + std::string s; + std::string empty_delim, attr_delim = "+"; + std::string delim = empty_delim; + for (const auto &scale : entries_) { + const auto &q = scale.second; + if (q.has_default_values()) continue; + + int arg = scale.first; + s.append(delim) + .append(arg2str(arg)) + .append(":") + .append(q.get_verbose()); + delim = attr_delim; + } + return s; +} + +scales_t scales_t::deserialize(deserializer_t &d) { + return deserialize_entries<scales_t>(d); +} + +zero_points_t zero_points_t::deserialize(deserializer_t &d) { + return deserialize_entries<zero_points_t>(d); +} + +status_t quant_entry_t::set(int mask, data_type_t data_type, int group_ndims, + const dims_t group_dims) { + type_ = type_ | DNNL; + is_set_ = true; + mask_ = mask; + data_type_ = data_type; + group_ndims_ = group_ndims; + if (group_ndims_ > 0) { + utils::array_copy(group_dims_, group_dims, group_ndims_); + } + return status::success; +} + +status_t quant_entry_t::set_scales(const dims_t dims, int ndims, data_type_t data_type, int mask) { + type_ = type_ | OV_SCALES; + is_set_scale = true; + ndims_scale = ndims; + mask_scale = mask; + data_type_scale = data_type; + if (ndims_scale > 0) { + utils::array_copy(dims_scale, dims, ndims_scale); + } + return status::success; +} + +status_t quant_entry_t::set_zero_points(const dims_t dims, int ndims, data_type_t data_type, int mask) { + type_ = type_ | DNNL; + is_set_wei = true; + ndims_wei = ndims; + mask_wei = mask; + if (ndims_wei > 0) { + utils::array_copy(dims_wei, dims, ndims_wei); + group_ndims_ = ndims; + utils::array_copy(group_dims_, dims, group_ndims_); + } + data_type_wei = data_type; + return status::success; +} + +status_t quant_entry_t::set_zero_points(const dims_t dims, int ndims, data_type_t data_type) { + type_ = type_ | OV_ZERO_POINTS; + is_set_wei = true; + ndims_wei = ndims; + mask_wei = 1; + if (ndims_wei > 0) { + utils::array_copy(dims_wei, dims, ndims_wei); + } + data_type_wei = data_type; + return status::success; +} + +status_t quant_entry_t::set(const quant_entry_t &other) { + type_ = other.type_; + is_set_ = other.is_set_; + mask_ = other.mask_; + data_type_ = other.data_type_; + group_ndims_ = other.group_ndims_; + if(group_ndims_ > 0) + utils::array_copy(group_dims_, other.group_dims_, group_ndims_); + is_set_scale = other.is_set_scale; + mask_scale = other.mask_scale; + data_type_scale = other.data_type_scale; + ndims_scale = other.ndims_scale; + if (ndims_scale > 0) + utils::array_copy(dims_scale, other.dims_scale, ndims_scale); + is_set_wei = other.is_set_wei; + mask_wei = other.mask_wei; + data_type_wei = other.data_type_wei; + ndims_wei = other.ndims_wei; + if(ndims_wei > 0) + utils::array_copy(dims_wei, other.dims_wei, ndims_wei); + return status::success; +} +int quant_entry_t::get_mask() const { + if (is_set_wei) return mask_wei; + if (is_set_) return mask_; + if (is_set_scale) return mask_scale; + return INT_MIN; +} +data_type_t quant_entry_t::get_data_type()
const { + if (is_set_wei) return data_type_wei; + if (is_set_) return data_type_; + if (is_set_scale) return data_type_scale; + return data_type::undef; +} +const dims_t& quant_entry_t::get_dims() const { + if (is_set_wei) return dims_wei; + if (is_set_) return group_dims_; + if (is_set_scale) return dims_scale; + static const dims_t result = {}; + return result; +} + +int quant_entry_t::get_ndims() const { + if (is_set_wei) return ndims_wei; + if (is_set_) return group_ndims_; + if (is_set_scale) return ndims_scale; + return 0; +} +// Note: keep the definition here to satisfy the +// `gtests/internals/test_comparison_operators` linking requirements which +// mandates bodies to be in the header file. +bool quant_entry_t::operator==(const quant_entry_t &rhs) const { + bool result = (type_ == rhs.type_ && is_set_ == rhs.is_set_ + && mask_ == rhs.mask_ + && data_type_ == rhs.data_type_ + && group_ndims_ == rhs.group_ndims_ + && IMPLICATION(group_ndims_ > 0, + utils::array_cmp( + group_dims_, rhs.group_dims_, group_ndims_))); + + if (!result) return false; + result = (is_set_scale == rhs.is_set_scale + && mask_scale == rhs.mask_scale + && data_type_scale == rhs.data_type_scale + && ndims_scale == rhs.ndims_scale + && IMPLICATION(ndims_scale > 0, + utils::array_cmp( + dims_scale, rhs.dims_scale, ndims_scale))); + + if (!result) return false; + result = (is_set_wei == rhs.is_set_wei + && mask_wei == rhs.mask_wei + && data_type_wei == rhs.data_type_wei + && ndims_wei == rhs.ndims_wei + && IMPLICATION(ndims_wei > 0, + utils::array_cmp( + dims_wei, rhs.dims_wei, ndims_wei))); + return result; +} +status_t quant_entries_t::set_scales(int arg, const dims_t dims, int ndims, data_type_t data_type) { + if (!check_arg(arg)) return status::invalid_arguments; + CHECK(entries_[arg].set_scales(dims, ndims, data_type)); + return status::success; +} +status_t quant_entries_t::set_zero_points(int arg, const dims_t dims, int ndims, data_type_t data_type) { + if (arg != DNNL_ARG_WEIGHTS) return status::unimplemented; + CHECK(entries_[arg].set_zero_points(dims, ndims, data_type)); + return status::success; +} +status_t zero_points_t::set(int arg, int mask, data_type_t data_type, int group_ndims, + const dims_t group_dims) { + if (!check_arg(arg)) return status::invalid_arguments; + if (arg == DNNL_ARG_WEIGHTS) { + CHECK(entries_[arg].set_zero_points(group_dims, group_ndims, data_type, mask)); + } else { + CHECK(entries_[arg].set(mask, data_type, group_ndims, group_dims)); + } + return status::success; +} + +} // namespace impl +} // namespace dnnl diff --git a/src/common/primitive_attr_quant.hpp b/src/common/primitive_attr_quant.hpp new file mode 100644 index 00000000000..de541f2a909 --- /dev/null +++ b/src/common/primitive_attr_quant.hpp @@ -0,0 +1,378 @@ +/******************************************************************************* +* Copyright 2024-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef COMMON_PRIMITIVE_ATTR_QUANT_HPP +#define COMMON_PRIMITIVE_ATTR_QUANT_HPP + +// NOTE: Objects declared in this header are moved out from primitive_attr.hpp due +// to micro_sdpa primitive. Int8 support requires at least two primitive_attr +// objects to be used inside sdpa_desc_t object which triggers a deleted +// copy-ctor of primitive_attr_t, which is there because RNN scales still +// rely on static scales and manage dynamically-allocated memory. +// +// As a result, micro_sdpa uses scales and zero-points objects directly and +// requires a dedicated header for that, otherwise, it's going to be a circular +// dependency between headers when it comes to inclusion of opdesc.hpp which +// sdpa_desc_t is a part of. + +#include "common/serialization.hpp" +#include "common/utils.hpp" + +#include <map> +#include <ostream> +#include <string> +#include <vector> + +namespace dnnl { +namespace impl { + +struct quant_entry_t; +const quant_entry_t &default_quant_entry(); + +struct quant_entry_t : public c_compatible { + quant_entry_t() = default; + + // `set(...)` approach is taken over constructors as the usage model assumes + // the change of state of this object but it doesn't require its destruction + // which would come with some performance price which prevails in this case. + status_t set(int mask, data_type_t data_type) { + return set(mask, data_type, 0, {}); + } + status_t set(int mask, data_type_t data_type, int group_ndims, + const dims_t group_dims); + status_t set_scales(const dims_t dims, int ndims, data_type_t data_type = data_type::f32, int mask = 1); + status_t set_zero_points(const dims_t dims, int ndims, data_type_t data_type); + status_t set_zero_points(const dims_t dims, int ndims, data_type_t data_type, int mask); + status_t set(const quant_entry_t &other); + quant_entry_t &operator=(const quant_entry_t &rhs) { + auto st = this->set(rhs); + assert(st == status::success); + UNUSED(st); + return *this; + } + bool has_default_values() const { return *this == default_quant_entry(); } + bool has_default_groups() const { + return this->group_ndims_ == default_quant_entry().group_ndims_; + } + int get_mask() const; + data_type_t get_data_type() const; + const dims_t& get_dims() const; + int get_ndims() const; + dim_t get_group(int d) const { + // If groups were not requested, return `1` for convenience. + if (group_ndims_ == default_quant_entry().group_ndims_) return 1; + // But if they were, any out of bound access would return `0` and likely + // lead to a division by zero which is fast to catch. + if (d >= group_ndims_) return 0; + return group_dims_[d]; + } + + // Note: keep the definition here to satisfy the + // `gtests/internals/test_comparison_operators` linking requirements which + // mandates bodies to be in the header file. + bool operator==(const quant_entry_t &rhs) const; + size_t get_hash() const; + + void serialize(serialization_stream_t &sstream) const; + + static quant_entry_t deserialize(deserializer_t &d); + + std::string get_verbose() const; + +private: + data_type_t data_type_ = data_type::undef; + int group_ndims_ = 0; + dims_t group_dims_ {}; + // Note: INT_MIN is used on purpose to avoid potential issues when + // `(mask & bit)` expression will return `true`. `INT_MIN` is represented + // as `10...0` in bits and will avoid such situations.
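// (Illustrative continuation of the note above, not part of the patch.) A 0
// default could not be told apart from a deliberately set "common" mask of 0,
// and any low bit set by default would make checks such as
// (mask_ & (1 << d)) spuriously true. INT_MIN is 0b1000...0, so:
//
//   int unset = INT_MIN;
//   bool per_dim0 = unset & (1 << 0); // false
//   bool per_dim1 = unset & (1 << 1); // false
//   // get_mask() (see primitive_attr_quant.cpp) returns INT_MIN only when no
//   // variant of the entry was ever set, so callers can distinguish "unset"
//   // from a genuine mask of 0.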
+    int mask_ = INT_MIN;
+    bool is_set_ = false;
+    // openvino extension
+    enum entry_type {
+        NONE = 0,
+        DNNL = 1,
+        OV_SCALES = 2,
+        OV_ZERO_POINTS = 4
+    };
+    int type_ = NONE;
+    // scale
+    bool is_set_scale = false;
+    int ndims_scale = 0;
+    int mask_scale = INT_MIN;
+    dims_t dims_scale {};
+    data_type_t data_type_scale = data_type::undef;
+    // zero_point
+    bool is_set_wei = false;
+    int ndims_wei = 0;
+    int mask_wei = INT_MIN;
+    dims_t dims_wei {};
+    data_type_t data_type_wei = data_type::s32;
+};
+
+std::ostream &operator<<(std::ostream &ss, const quant_entry_t &e);
+
+struct quant_entries_t : public c_compatible {
+    quant_entries_t(data_type_t default_data_type)
+        : default_data_type_(default_data_type) {}
+
+    const quant_entry_t &get(int arg) const {
+        const auto it = entries_.find(arg);
+        if (it == entries_.end()) return default_quant_entry();
+        return it->second;
+    }
+
+    // See the `set(...)` comment in `quant_entry_t` for the explanation of
+    // this design choice.
+    virtual status_t set(int arg, int mask) {
+        return set(arg, mask, default_data_type_, 0, {});
+    }
+    const dims_t &get_dims(int arg) const { return get(arg).get_dims(); }
+    int get_ndims(int arg) const { return get(arg).get_ndims(); }
+    virtual status_t set(int arg, int mask, data_type_t data_type,
+            int group_ndims, const dims_t group_dims) {
+        if (!check_arg(arg)) return status::invalid_arguments;
+        CHECK(entries_[arg].set(mask, data_type, group_ndims, group_dims));
+        return status::success;
+    }
+    status_t set_scales(int arg, const dims_t dims, int ndims,
+            data_type_t data_type = data_type::f32);
+    status_t set_zero_points(
+            int arg, const dims_t dims, int ndims, data_type_t data_type);
+
+    // Use this interface with `default_quant_entry` when a specific entry
+    // needs to be removed.
+    virtual status_t set(int arg, const quant_entry_t &other) {
+        return entries_[arg].set(other);
+    }
+
+    // This interface differs from the one below; it is just a shortcut.
+    bool has_default_values(int arg) const {
+        return get(arg).has_default_values();
+    }
+
+    // This interface is used to make sure that arguments other than
+    // `supported_args` have default values, i.e., that non-allowed
+    // arguments were not passed to the library.
+    bool has_default_values(const std::vector<int> &supported_args = {}) const {
+        auto predicate
+                = [](const quant_entry_t &s) { return s.has_default_values(); };
+        return has_default_property(supported_args, predicate);
+    }
+
+    // This interface checks a specific argument. It exists because
+    // quant_entry_t doesn't have a notion of a default data_type; only this
+    // object does.
+    // Note: once the library unconditionally supports the data type of
+    // scales/zero-points in every implementation, this call can be removed,
+    // since a proper load then requires querying the data type anyway.
+    bool has_default_data_type(int arg) const {
+        // Note: `data_type::undef` represents `default_quant_entry`.
+        return utils::one_of(
+                get(arg).get_data_type(), default_data_type_, data_type::undef);
+    }
+
+    // This interface differs from the one below; it is just a shortcut.
+    bool has_default_groups(int arg) const {
+        return get(arg).has_default_groups();
+    }
+
+    // This interface is used to make sure that arguments other than
+    // `supported_args` have default values, i.e., that non-allowed
+    // arguments were not passed to the library.
+    bool has_default_groups(const std::vector<int> &supported_args = {}) const {
+        auto predicate
+                = [](const quant_entry_t &s) { return s.has_default_groups(); };
+        return has_default_property(supported_args, predicate);
+    }
+
+    int get_mask(int arg) const { return get(arg).get_mask(); }
+    data_type_t get_data_type(int arg) const {
+        return get(arg).get_data_type();
+    }
+    dim_t get_group(int arg, int d) const { return get(arg).get_group(d); }
+
+    bool operator==(const quant_entries_t &rhs) const {
+        return entries_ == rhs.entries_;
+    }
+
+    size_t get_hash() const;
+
+    void serialize(serialization_stream_t &sstream) const;
+
+    std::string get_verbose() const;
+
+protected:
+    // The sorted property of `std::map` is relied upon for hashing.
+    std::map<int, quant_entry_t> entries_;
+    // The value differs depending on the inheritor.
+    data_type_t default_data_type_ = data_type::undef;
+
+    virtual bool check_arg(int arg) const = 0;
+
+    // The function makes sure that if the user specified any argument, only
+    // `supported_args` had their values customized; the rest of the
+    // unsupported arguments were left untouched.
+    bool has_default_property(const std::vector<int> &supported_args,
+            bool (*predicate)(const quant_entry_t &)) const {
+        for (const auto &s : entries_) {
+            // Arg passed the condition, check the next one.
+            if (predicate(s.second)) continue;
+
+            bool allow_non_default = false;
+            for (const auto &supported_arg : supported_args)
+                if (s.first == supported_arg) {
+                    allow_non_default = true;
+                    break;
+                }
+            if (allow_non_default) continue;
+            return false;
+        }
+        return true;
+    }
+};
+
+struct scales_t : public quant_entries_t {
+    scales_t() : quant_entries_t(default_data_type_) {}
+
+    // This interface checks the content of all entries and allows ignoring
+    // certain arguments.
+    // Note: this can't live in `quant_entries_t` because there
+    // `default_data_type_` is not a static member, and `has_default_property`
+    // takes a capture-free function pointer as `predicate`, which can only
+    // reference a static member.
+    bool has_default_data_type(
+            const std::vector<int> &supported_args = {}) const {
+        auto predicate = [](const quant_entry_t &s) {
+            // Note: `data_type::undef` represents `default_quant_entry`.
+            return utils::one_of(
+                    s.get_data_type(), default_data_type_, data_type::undef);
+        };
+        return has_default_property(supported_args, predicate);
+    }
+    // Note: must be present, as the compiler otherwise doesn't see the
+    // overloaded version inside the base class.
+    bool has_default_data_type(int arg) const {
+        return quant_entries_t::has_default_data_type(arg);
+    }
+
+    static scales_t deserialize(deserializer_t &d);
+
+private:
+    static constexpr data_type_t default_data_type_ = data_type::f32;
+
+    bool check_arg(int arg) const override {
+        // regular
+        for (const auto &sa : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
+            if (arg == sa) return true;
+        }
+        // binary
+        for (const auto &sa : {DNNL_ARG_SRC_1}) {
+            if (arg == sa) return true;
+        }
+        // concat
+        if (arg & DNNL_ARG_MULTIPLE_SRC) return true;
+        // depth-wise convolution post op
+        for (const auto &sa : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
+            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | sa)) return true;
+        }
+        // sdpa
+        if (arg == DNNL_ARG_SRC_2) return true;
+        return false;
+    }
+};
+
+struct zero_points_t : public quant_entries_t {
+    zero_points_t() : quant_entries_t(default_data_type_) {}
+
+    // This interface checks the content of all entries and allows ignoring
+    // certain arguments.
+    // Note: this can't live in `quant_entries_t` because there
+    // `default_data_type_` is not a static member, and `has_default_property`
+    // takes a capture-free function pointer as `predicate`, which can only
+    // reference a static member.
+    bool has_default_data_type(
+            const std::vector<int> &supported_args = {}) const {
+        auto predicate = [](const quant_entry_t &s) {
+            // Note: `data_type::undef` represents `default_quant_entry`.
+            return utils::one_of(
+                    s.get_data_type(), default_data_type_, data_type::undef);
+        };
+        return has_default_property(supported_args, predicate);
+    }
+    // Note: must be present, as the compiler otherwise doesn't see the
+    // overloaded version inside the base class.
+    bool has_default_data_type(int arg) const {
+        return quant_entries_t::has_default_data_type(arg);
+    }
+
+    static zero_points_t deserialize(deserializer_t &d);
+    status_t set(int arg, int mask) override {
+        return quant_entries_t::set(arg, mask, default_data_type_, 0, {});
+    }
+    status_t set(int arg, int mask, data_type_t data_type, int group_ndims,
+            const dims_t group_dims) override;
+
+    status_t set(int arg, const quant_entry_t &other) override {
+        return quant_entries_t::set(arg, other);
+    }
+
+private:
+    static constexpr data_type_t default_data_type_ = data_type::s32;
+
+    bool check_arg(int arg) const override {
+        // regular
+        // The gemm internal primitive uses DNNL_ARG_A, DNNL_ARG_B, and
+        // DNNL_ARG_C, which map to DNNL_ARG_WEIGHTS, DNNL_ARG_SRC, and
+        // DNNL_ARG_DST. They are defined in the GPU internals and thus are
+        // not spelled out here.
+        for (const auto &sa : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
+            if (arg == sa) return true;
+        }
+        // sdpa
+        if (arg == DNNL_ARG_SRC_2) return true;
+        return false;
+    }
+};
+
+struct src_dyn_quant_params_t : public c_compatible {
+    src_dyn_quant_params_t() : group_size_(0) {}
+    bool has_default_values() const { return (group_size_ == 0); }
+    bool defined() const { return true; }
+
+    status_t set(uint64_t group_size) {
+        group_size_ = group_size;
+        return status::success;
+    }
+
+    uint64_t get() const { return group_size_; }
+
+    bool operator==(const src_dyn_quant_params_t &rhs) const {
+        return group_size_ == rhs.group_size_;
+    }
+
+private:
+    uint64_t group_size_;
+};
+
+} // namespace impl
+} // namespace dnnl
+
+#endif
diff --git a/src/common/primitive_cache.cpp b/src/common/primitive_cache.cpp
index 31f506c814c..5264b12a0d8 100644
--- a/src/common/primitive_cache.cpp
+++ b/src/common/primitive_cache.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2023 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
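Stepping back from the patch for a moment: the entry lookup in `quant_entries_t` above always goes through `get(arg)`, which falls back to `default_quant_entry()` for anything that was never set. Below is a minimal sketch of that behavior (illustrative only; it assumes the internal `primitive_attr_quant.hpp` header and the `DNNL_ARG_*` macros are visible to the translation unit):

    #include <cassert>
    #include "common/primitive_attr_quant.hpp"

    using namespace dnnl::impl;

    void scales_lookup_sketch() {
        scales_t scales;

        // Nothing set yet: every argument reports defaults, because get()
        // resolves unknown arguments to default_quant_entry().
        assert(scales.has_default_values(DNNL_ARG_WEIGHTS));
        assert(scales.has_default_data_type(DNNL_ARG_WEIGHTS)); // f32 default

        // A per-tensor (mask = 0) f32 scale on weights flips only that entry.
        assert(scales.set(DNNL_ARG_WEIGHTS, /* mask = */ 0) == status::success);
        assert(!scales.has_default_values(DNNL_ARG_WEIGHTS));
        assert(scales.has_default_values(DNNL_ARG_DST));

        // The vector overload checks that nothing outside the allowed set was
        // customized; weights are allowed here, so the property still holds.
        assert(scales.has_default_values({DNNL_ARG_WEIGHTS}));
    }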
@@ -133,6 +133,18 @@ primitive_cache_iface_t::result_t primitive_cache_iface_t::get_or_create( return {std::move(r.value), r.status}; } +status_t set_primitive_cache_capacity( + int primitive_capacity, int kernel_capacity) { + if (primitive_capacity < 0 || kernel_capacity < 0) + return status::invalid_arguments; +#ifndef DNNL_DISABLE_PRIMITIVE_CACHE + auto status = global_primitive_cache().set_capacity(primitive_capacity); + CHECK(status); + return kernel_cache::get().set_capacity(kernel_capacity); +#endif + return status::success; +} + } // namespace impl } // namespace dnnl @@ -148,11 +160,5 @@ dnnl::impl::status_t dnnl_get_primitive_cache_capacity(int *capacity) { } dnnl::impl::status_t dnnl_set_primitive_cache_capacity(int capacity) { - if (capacity < 0) return dnnl::impl::status::invalid_arguments; -#ifndef DNNL_DISABLE_PRIMITIVE_CACHE - auto status = dnnl::impl::global_primitive_cache().set_capacity(capacity); - if (status != dnnl::impl::status::success) return status; - return dnnl::impl::kernel_cache::get().set_capacity(capacity); -#endif - return dnnl::impl::status::success; + return dnnl::impl::set_primitive_cache_capacity(capacity, capacity); } diff --git a/src/common/primitive_cache.hpp b/src/common/primitive_cache.hpp index bab4a8155dd..25f90e8fbf2 100644 --- a/src/common/primitive_cache.hpp +++ b/src/common/primitive_cache.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,8 @@ struct primitive_cache_iface_t { }; primitive_cache_iface_t primitive_cache(); +status_t set_primitive_cache_capacity( + int primitive_capacity, int kernel_capacity); // Undocumented API for testing. status_t DNNL_API get_primitive_cache_size(int *size); diff --git a/src/common/primitive_desc.hpp b/src/common/primitive_desc.hpp index 0939b28a6b2..69b6e52487b 100644 --- a/src/common/primitive_desc.hpp +++ b/src/common/primitive_desc.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ #include "cache_blob.hpp" #include "cache_blob_id.hpp" #include "cache_hit_types.hpp" +#include "dnnl_sel_build.hpp" #include "memory_tracking.hpp" #include "nstl.hpp" #include "opdesc.hpp" @@ -47,6 +48,7 @@ static int po_inputs(const post_ops_t &post_ops, const primitive_kind_t kind) { struct impl_list_item_t; struct primitive_t; // Primitive descriptor implementation +// NOLINTBEGIN(google-default-arguments) struct primitive_desc_t : public c_compatible { primitive_desc_t(const primitive_attr_t *attr, primitive_kind_t kind) : attr_(*attr), kind_(kind), pd_iterator_offset_(0), skip_idx_(-1) { @@ -80,7 +82,7 @@ struct primitive_desc_t : public c_compatible { // doesn't require any special handling since `get_verbose` is `false`. std::string info_with_runtime_dims(engine_t *engine, const memory_desc_t *src_md, const memory_desc_t *wei_md, - const memory_desc_t *bia_md, const memory_desc_t *dst_md) { + const memory_desc_t *bia_md, const memory_desc_t *dst_md) const { std::string info_str = info(engine); // Matmul and reorder are the only primitives supporting runtime dims. 
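Returning to the primitive-cache change above: the public C API keeps its single-knob behavior by forwarding the same value twice, while the new internal entry point lets the two capacities diverge. A hedged sketch of both call patterns (capacity values are arbitrary):

    #include "oneapi/dnnl/dnnl.h"
    #include "common/primitive_cache.hpp"

    void cache_capacity_sketch() {
        // Public C API: one capacity drives both caches, since it now
        // forwards as set_primitive_cache_capacity(capacity, capacity).
        dnnl_set_primitive_cache_capacity(1024);

        // Internal API: the caches can be sized independently, e.g. a small
        // primitive cache backed by a larger kernel cache.
        dnnl::impl::set_primitive_cache_capacity(
                /* primitive_capacity = */ 256, /* kernel_capacity = */ 2048);
    }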
@@ -150,45 +152,60 @@ struct primitive_desc_t : public c_compatible { enum class arg_usage_t { unused, input, output }; virtual arg_usage_t arg_usage(int arg) const { using types::is_zero_md; - if (arg == DNNL_ARG_ATTR_OUTPUT_SCALES - && !attr()->output_scales_.defined()) + + if ((arg & (DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC)) + && !attr()->input_zero_points_.has_default_values()) + return arg_usage_t::input; + if ((arg & (DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS)) + && !attr()->weights_zero_points_.has_default_values()) + return arg_usage_t::input; + if ((arg & (DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST)) + && !attr()->output_compensations_.has_default_values() + && arg != DNNL_ARG_SCRATCHPAD) return arg_usage_t::input; + if (arg & DNNL_ARG_ATTR_ZERO_POINTS) { int zp_arg = arg & ~DNNL_ARG_ATTR_ZERO_POINTS; - if (!attr()->zero_points_.defined(zp_arg)) - return arg_usage_t::input; + return !attr()->zero_points_.has_default_values(zp_arg) + ? arg_usage_t::input + : arg_usage_t::unused; } if (arg & DNNL_ARG_ATTR_SCALES) { int scale_arg = arg & ~DNNL_ARG_ATTR_SCALES; - if (!attr()->scales_.get(scale_arg).defined()) - return arg_usage_t::input; + return !attr()->scales_.has_default_values(scale_arg) + ? arg_usage_t::input + : arg_usage_t::unused; } - if ((arg == (DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0)) - && !attr()->scales_.get(DNNL_ARG_SRC_0).defined()) - return arg_usage_t::input; - if ((arg == (DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1)) - && !attr()->scales_.get(DNNL_ARG_SRC_1).defined()) - return arg_usage_t::input; - if (arg == DNNL_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md())) - return arg_usage_t::output; - if (arg == DNNL_ARG_ATTR_DROPOUT_MASK - && !attr()->dropout_.has_default_values()) - return arg_usage_t::output; - if ((arg == DNNL_ARG_ATTR_DROPOUT_PROBABILITY - || arg == DNNL_ARG_ATTR_DROPOUT_SEED) - && !attr()->dropout_.has_default_values()) - return arg_usage_t::input; - if ((arg == DNNL_ARG_ATTR_ROUNDING_SEED) - && !attr()->rounding_mode_.has_default_values()) - return arg_usage_t::input; + + if (arg == DNNL_ARG_SCRATCHPAD) + return !is_zero_md(scratchpad_md()) ? arg_usage_t::output + : arg_usage_t::unused; + if (arg == DNNL_ARG_ATTR_DROPOUT_MASK) + return !attr()->dropout_.has_default_values() ? arg_usage_t::output + : arg_usage_t::unused; + if (arg == DNNL_ARG_ATTR_DROPOUT_PROBABILITY) + return !attr()->dropout_.has_default_values() ? arg_usage_t::input + : arg_usage_t::unused; + if (arg == DNNL_ARG_ATTR_DROPOUT_SEED) + return !attr()->dropout_.has_default_values() ? arg_usage_t::input + : arg_usage_t::unused; + if (arg == DNNL_ARG_ATTR_ROUNDING_SEED) + return !attr()->rounding_mode_.has_default_values() + ? 
arg_usage_t::input + : arg_usage_t::unused; + for (int idx = 0; idx < attr()->post_ops_.len(); ++idx) { using namespace primitive_kind; - if (post_op_has_proper_input( - attr(), binary, idx, arg, DNNL_ARG_SRC_1) - || post_op_has_proper_input( - attr(), prelu, idx, arg, DNNL_ARG_WEIGHTS)) + if (post_op_has_proper_input(attr(), binary, idx, arg, DNNL_ARG_SRC_1) || + post_op_has_proper_input(attr(), depthwise, idx, arg, DNNL_ARG_SRC_1) || + post_op_has_proper_input(attr(), quantization, idx, arg, DNNL_ARG_SRC_1) || + post_op_has_proper_input(attr(), prelu, idx, arg, DNNL_ARG_WEIGHTS)) return arg_usage_t::input; } + if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)) + return arg_usage_t::input; + if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)) + return arg_usage_t::input; return arg_usage_t::unused; } @@ -354,6 +371,15 @@ struct primitive_desc_t : public c_compatible { int n_prelu_po_inputs() const { return po_inputs(attr()->post_ops_, primitive_kind::prelu); } + + int n_depthwise_po_inputs() const { + return po_inputs(attr()->post_ops_, primitive_kind::depthwise); + } + + int n_quantization_po_inputs() const { + return po_inputs(attr()->post_ops_, primitive_kind::quantization); + } + // The `hint_mds(bool is_hint)` returns a vector of memory descriptors // that might affect the equality of primitive descriptors for backward pass. // @@ -438,7 +464,6 @@ struct primitive_desc_t : public c_compatible { memory_tracking::registry_t scratchpad_registry_; -protected: void init_pd_iterator_offset(int offset) { pd_iterator_offset_ = offset; } void init_skip_idx(int skip_idx) { skip_idx_ = skip_idx; } @@ -460,11 +485,11 @@ struct primitive_desc_t : public c_compatible { /** the only reason why this class is here is the inability of * utils::make_unique() to operate on protected parent classes * of the derivative pd_t's; compilers should optimize it out */ - class pd_t_compat : public pd_t { + class pd_compat_t : public pd_t { public: - pd_t_compat(Args &&...args) : pd_t(std::forward(args)...) {} + pd_compat_t(Args &&...args) : pd_t(std::forward(args)...) {} }; - return utils::make_unique(std::forward(args)...); + return utils::make_unique(std::forward(args)...); } template @@ -472,13 +497,11 @@ struct primitive_desc_t : public c_compatible { const primitive_attr_t *attr, engine_t *engine, const primitive_desc_t *hint_fwd) { using namespace dnnl::impl::status; - using pd_op_desc_t = typename pkind_traits::desc_type; - if (adesc->kind != pd_t::base_pkind) return invalid_arguments; + if (adesc->primitive_kind != pd_t::base_pkind) return invalid_arguments; assert(hint_fwd ? 
hint_fwd->kind() == pd_t::base_pkind : true); auto hint = reinterpret_cast(hint_fwd); - auto _pd - = make_unique_pd((const pd_op_desc_t *)adesc, attr, hint); + auto _pd = make_unique_pd(adesc, attr, hint); if (_pd == nullptr) return out_of_memory; if (!_pd->is_initialized()) return out_of_memory; CHECK(_pd->init(engine)); @@ -488,6 +511,7 @@ struct primitive_desc_t : public c_compatible { friend struct dnnl::impl::impl_list_item_t; }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl @@ -503,6 +527,7 @@ struct primitive_desc_t : public c_compatible { &primitive, \ dnnl::impl::engine_t *engine, const cache_blob_t &cache_blob) \ const override { \ + DNNL_PRIMITIVE_CREATE(pd_t) \ return primitive_t::create_primitive_common( \ primitive, this, engine, use_global_scratchpad, cache_blob); \ } \ diff --git a/src/common/primitive_desc_iface.cpp b/src/common/primitive_desc_iface.cpp index f359a72e6c3..263f29bff66 100644 --- a/src/common/primitive_desc_iface.cpp +++ b/src/common/primitive_desc_iface.cpp @@ -37,7 +37,7 @@ status_t primitive_desc_create(primitive_desc_iface_t **primitive_desc_iface, if (!primitive_desc_iface) return invalid_arguments; - const bool known_primitive_kind = utils::one_of(op_desc->kind, + const bool known_primitive_kind = utils::one_of(op_desc->primitive_kind, batch_normalization, binary, convolution, deconvolution, eltwise, gemm, group_normalization, inner_product, layer_normalization, lrn, matmul, pooling, prelu, reduction, resampling, rnn, sdpa, shuffle, diff --git a/src/common/primitive_desc_iterator.hpp b/src/common/primitive_desc_iterator.hpp index 39fe8f51838..096d20642a6 100644 --- a/src/common/primitive_desc_iterator.hpp +++ b/src/common/primitive_desc_iterator.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2023 Intel Corporation +* Copyright 2018-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,26 +38,19 @@ struct primitive_desc_iterator_t : public c_compatible { int skip_idx = -1) : idx_(-1) , engine_(engine) - , op_desc_(nullptr) + , op_desc_(op_desc->clone()) , attr_(attr ? 
*attr : primitive_attr_t()) , hint_fwd_pd_(hint_fwd_pd) - , impl_list_(nullptr) + , impl_list_(engine_->get_implementation_list(op_desc_.get())) , last_idx_(0) , skip_idx_(skip_idx) , offset_(-1) { - op_desc_ = (op_desc_t *)std::malloc(sizeof(op_desc_t)); - copy_c_op_desc(op_desc_, op_desc); - - impl_list_ = engine_->get_implementation_list(op_desc_); - while (impl_list_[last_idx_]) ++last_idx_; is_initialized_ = is_initialized_ && attr_.is_initialized(); } - ~primitive_desc_iterator_t() { std::free(op_desc_); } - engine_t *engine() const { return engine_; } bool operator==(const primitive_desc_iterator_t &rhs) const { @@ -82,7 +75,7 @@ struct primitive_desc_iterator_t : public c_compatible { std::vector hint_mds; if (hint_fwd_pd_) hint_mds = hint_fwd_pd_->hint_mds(true /* is_hint */); primitive_hashing::key_t key( - engine_, op_desc_, &attr_, offset_, hint_mds, skip_idx_); + engine_, op_desc_.get(), &attr_, offset_, hint_mds, skip_idx_); pd_ = primitive_cache().get_pd(key); if (pd_) { return *this; } @@ -90,8 +83,8 @@ struct primitive_desc_iterator_t : public c_compatible { while (++idx_ != last_idx_) { if (idx_ == skip_idx_) continue; primitive_desc_t *candidate_pd = nullptr; - auto s = impl_list_[idx_](&candidate_pd, op_desc_, &attr_, engine_, - hint_fwd_pd_, offset_, skip_idx_); + auto s = impl_list_[idx_](&candidate_pd, op_desc_.get(), &attr_, + engine_, hint_fwd_pd_, offset_, skip_idx_); if (s == status::success) { pd_.reset(candidate_pd); break; @@ -110,7 +103,7 @@ struct primitive_desc_iterator_t : public c_compatible { int idx_; engine_t *engine_; std::shared_ptr pd_; - op_desc_t *op_desc_; + std::unique_ptr op_desc_; const primitive_attr_t attr_; const primitive_desc_t *hint_fwd_pd_; const impl_list_item_t *impl_list_; @@ -122,7 +115,6 @@ struct primitive_desc_iterator_t : public c_compatible { primitive_desc_iterator_t(engine_t *engine, int last_idx) : idx_(last_idx) , engine_(engine) - , op_desc_(nullptr) , hint_fwd_pd_(nullptr) , impl_list_(nullptr) , last_idx_(last_idx) @@ -133,7 +125,7 @@ struct primitive_desc_iterator_t : public c_compatible { : idx_(other.idx_) , engine_(other.engine_) , pd_(std::move(other.pd_)) - , op_desc_(other.op_desc_) + , op_desc_(std::move(other.op_desc_)) , attr_(other.attr_) , hint_fwd_pd_(other.hint_fwd_pd_) , impl_list_(other.impl_list_) diff --git a/src/common/primitive_exec_types.cpp b/src/common/primitive_exec_types.cpp index db9d582ebaa..4f1c3e6295b 100644 --- a/src/common/primitive_exec_types.cpp +++ b/src/common/primitive_exec_types.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
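A side note on the iterator change above: replacing the malloc'd `op_desc_t` copy with `op_desc->clone()` held in a `std::unique_ptr` removes the hand-written destructor and makes the move constructor correct by construction. A minimal sketch of the same ownership pattern (illustrative; these are not the library's actual types):

    #include <memory>
    #include <utility>

    struct desc_base_t {
        virtual ~desc_base_t() = default;
        virtual std::unique_ptr<desc_base_t> clone() const = 0;
    };

    struct iterator_sketch_t {
        // One deep copy, owned for the iterator's whole lifetime.
        explicit iterator_sketch_t(const desc_base_t &d) : desc_(d.clone()) {}

        // Moving transfers ownership: no free()/delete bookkeeping, and the
        // moved-from object is left holding nullptr, which destructs safely.
        iterator_sketch_t(iterator_sketch_t &&other) noexcept
            : desc_(std::move(other.desc_)) {}

    private:
        std::unique_ptr<desc_base_t> desc_;
    };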
@@ -49,8 +49,7 @@ status_t cvt_primitive_args(const primitive_desc_t *pd, int nargs, case primitive_desc_t::arg_usage_t::input: args[arg] = {mem, true}; n_inputs++; - extra_inputs += (arg == DNNL_ARG_ATTR_OUTPUT_SCALES) - || (arg & DNNL_ARG_ATTR_ZERO_POINTS) + extra_inputs += (arg & DNNL_ARG_ATTR_ZERO_POINTS) || (arg & DNNL_ARG_ATTR_SCALES) // 1x1 + dw conv fusion || (arg @@ -136,7 +135,7 @@ void *exec_ctx_t::host_ptr( if (do_zeropad) status = mem->zero_pad(*this); if (status_) *status_ = status; - auto *mem_storage = mem->memory_storage(index); + auto *mem_storage = mem->memory_storage(); return host_ptr(mem_storage); } diff --git a/src/common/primitive_exec_types.hpp b/src/common/primitive_exec_types.hpp index f52e3399ae2..c7c9c2ac75f 100644 --- a/src/common/primitive_exec_types.hpp +++ b/src/common/primitive_exec_types.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2023 Intel Corporation +* Copyright 2018-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,9 +25,22 @@ #include "memory.hpp" #include "memory_storage.hpp" -#define CTX_IN_STORAGE(arg) \ +// __VA_ARGS__here is an index of the buffer. It is empty unless the memory +// argument is sparse. +#define CTX_IN_STORAGE(arg, ...) CTX_IN_STORAGe##__VA_ARGS__(arg) + +#define CTX_IN_STORAGe(arg) \ (ctx.input(arg) ? *(ctx.input(arg)->memory_storage()) \ : dnnl::impl::memory_storage_t::empty_storage()) +#define CTX_IN_STORAGe0(arg) \ + (ctx.input(arg) ? *ctx.input(arg)->memory_storage(0) \ + : dnnl::impl::memory_storage_t::empty_storage()) +#define CTX_IN_STORAGe1(arg) \ + (ctx.input(arg) ? *ctx.input(arg)->memory_storage(1) \ + : dnnl::impl::memory_storage_t::empty_storage()) +#define CTX_IN_STORAGe2(arg) \ + (ctx.input(arg) ? *ctx.input(arg)->memory_storage(2) \ + : dnnl::impl::memory_storage_t::empty_storage()) // Returns destination memory which may not have been zero pad initialized. #define CTX_OUT_STORAGE(arg) \ diff --git a/src/common/primitive_hashing.cpp b/src/common/primitive_hashing.cpp index 97b657607ec..99a6239ad9c 100644 --- a/src/common/primitive_hashing.cpp +++ b/src/common/primitive_hashing.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,8 @@ * limitations under the License. 
*******************************************************************************/ +#include +#include "primitive_attr.hpp" #include "primitive_desc.hpp" #include "type_helpers.hpp" #include "utils.hpp" @@ -29,7 +31,7 @@ namespace primitive_hashing { key_t::key_t(const engine_t *engine, const op_desc_t *op_desc, const primitive_attr_t *attr, int pd_iterator_offset, const std::vector &hint_mds, int skip_idx) - : primitive_kind_(op_desc->kind) + : primitive_kind_(op_desc->primitive_kind) , op_desc_(op_desc) , attr_(attr) , pd_iterator_offset_(pd_iterator_offset) @@ -53,22 +55,32 @@ bool key_t::operator==(const key_t &rhs) const { && hint_mds_.size() == rhs.hint_mds_.size() && pd_iterator_offset_ == rhs.pd_iterator_offset_ && impl_nthr_ == rhs.impl_nthr_ - && skip_idx_ == rhs.skip_idx_ - && (*attr_) == (*rhs.attr_); - - if (!ret) return false; + && skip_idx_ == rhs.skip_idx_ + && (*attr_) == (*rhs.attr_) + && std::equal( + hint_mds_.begin(), hint_mds_.end(), rhs.hint_mds_.begin()); + + if (!ret) { + // ANCHOR: HASHING_DEBUGINFO_16. + VDEBUGINFO(16, primitive, hashing, "operator==,ret=%d", ret); + return ret; + } #define CASE(pkind) \ case primitive_kind::pkind: \ - ret = cast_to_desc(op_desc_) \ - == cast_to_desc(rhs.op_desc_); \ + ret = *op_desc_t::to_desc(op_desc_) \ + == *op_desc_t::to_desc(rhs.op_desc_); \ break; switch ((int)primitive_kind_) { CASE(batch_normalization) CASE(binary) CASE(concat) - CASE(convolution) + // Use a custom comparison function that ignores alg_kind. + case primitive_kind::convolution: + ret = compare_conv_opdesc(*op_desc_t::to_desc(op_desc_), + *op_desc_t::to_desc(rhs.op_desc_)); + break; CASE(deconvolution) CASE(eltwise) CASE(gemm) @@ -93,231 +105,9 @@ bool key_t::operator==(const key_t &rhs) const { #undef CASE // clang-format on - if (!ret) return false; - - for (size_t i = 0; i < hint_mds_.size(); ++i) - if (hint_mds_[i] != rhs.hint_mds_[i]) return false; - - return true; -} - -// Combine hash of each memory_desc_t data member -size_t get_md_hash(const memory_desc_t &md) { - size_t seed = 0; - seed = get_array_hash(seed, md.dims, md.ndims); - seed = hash_combine(seed, static_cast(md.data_type)); - seed = get_array_hash(seed, md.padded_dims, md.ndims); - seed = get_array_hash(seed, md.padded_offsets, md.ndims); - seed = hash_combine(seed, md.offset0); - seed = hash_combine(seed, static_cast(md.format_kind)); - // format desc - switch ((int)md.format_kind) { - case format_kind::undef: - case format_kind::any: break; - case format_kind::blocked: - for (int i = 0; i < md.ndims; i++) { - if (md.dims[i] == 1 && md.padded_dims[i] == 1) continue; - seed = hash_combine(seed, md.format_desc.blocking.strides[i]); - } - seed = hash_combine(seed, md.format_desc.blocking.inner_nblks); - seed = get_array_hash(seed, md.format_desc.blocking.inner_blks, - md.format_desc.blocking.inner_nblks); - seed = get_array_hash(seed, md.format_desc.blocking.inner_idxs, - md.format_desc.blocking.inner_nblks); - break; - case format_kind::wino: - seed = hash_combine(seed, - static_cast(md.format_desc.wino_desc.wino_format)); - seed = hash_combine(seed, md.format_desc.wino_desc.r); - seed = hash_combine(seed, md.format_desc.wino_desc.alpha); - seed = hash_combine(seed, md.format_desc.wino_desc.ic); - seed = hash_combine(seed, md.format_desc.wino_desc.oc); - seed = hash_combine(seed, md.format_desc.wino_desc.ic_block); - seed = hash_combine(seed, md.format_desc.wino_desc.oc_block); - seed = hash_combine(seed, md.format_desc.wino_desc.ic2_block); - seed = hash_combine(seed, 
md.format_desc.wino_desc.oc2_block); - seed = hash_combine(seed, md.format_desc.wino_desc.adj_scale); - seed = hash_combine(seed, md.format_desc.wino_desc.size); - break; - case format_kind::rnn_packed: - seed = hash_combine(seed, - static_cast(md.format_desc.rnn_packed_desc.format)); - seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n_parts); - seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n); - seed = hash_combine(seed, md.format_desc.rnn_packed_desc.ldb); - { - int n_parts = md.format_desc.rnn_packed_desc.n_parts; - seed = get_array_hash( - seed, md.format_desc.rnn_packed_desc.parts, n_parts); - seed = get_array_hash(seed, - md.format_desc.rnn_packed_desc.part_pack_size, n_parts); - seed = get_array_hash(seed, - md.format_desc.rnn_packed_desc.pack_part, n_parts); - } - seed = hash_combine( - seed, md.format_desc.rnn_packed_desc.offset_compensation); - seed = hash_combine(seed, md.format_desc.rnn_packed_desc.size); - break; -#ifdef DNNL_EXPERIMENTAL_SPARSE - case format_kind::sparse: - seed = hash_combine(seed, - static_cast(md.format_desc.sparse_desc.encoding)); - seed = hash_combine(seed, md.format_desc.sparse_desc.nnz); - seed = get_array_hash(seed, - md.format_desc.sparse_desc.metadata_types, - sparse_desc_t::max_metadata_types); - // User cannot initialize `packed_desc` therefore `packed_desc` - // is always zero initialized. - break; -#endif - default: assert(!"unknown format_kind"); - } - - if (md.extra.flags != dnnl_memory_extra_flag_none) { - seed = hash_combine(seed, md.extra.flags); - if ((md.extra.flags - & (dnnl_memory_extra_flag_compensation_conv_s8s8 - | dnnl_memory_extra_flag_rnn_u8s8_compensation)) - && !types::extra_flag_rnn_s8s8_compensation_is_set( - md.extra.flags)) { - seed = hash_combine(seed, md.extra.compensation_mask); - } - - if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) { - seed = hash_combine(seed, md.extra.scale_adjust); - } - - if (md.extra.flags - & dnnl_memory_extra_flag_compensation_conv_asymmetric_src) { - seed = hash_combine(seed, md.extra.asymm_compensation_mask); - } - } - // Combined hash for a memory descriptor - return seed; -} - -// Combine hash of each primitive_attr_t data member -size_t get_attr_hash(const primitive_attr_t &attr) { - size_t seed = 0; - // scratchpad_mode - seed = hash_combine(seed, static_cast(attr.scratchpad_mode_)); - // fpmath_mode - seed = hash_combine(seed, static_cast(attr.fpmath_.mode_)); - seed = hash_combine(seed, static_cast(attr.fpmath_.apply_to_int_)); - // deterministic - seed = hash_combine(seed, static_cast(attr.deterministic_)); - // acc_mode - seed = hash_combine(seed, static_cast(attr.acc_mode_)); - // rounding_mode - if (!attr.rounding_mode_.has_default_values()) { - for (const auto &e : attr.rounding_mode_.rounding_modes_map_) { - seed = hash_combine(seed, e.first); - seed = hash_combine(seed, static_cast(e.second)); - } - } - - if (!attr.output_scales_.has_default_values()) { - // output_scales: mask - seed = hash_combine(seed, attr.output_scales_.mask_); - } else if (!attr.scales_.has_default_values()) { - // go through scales for all arguments - for (const auto &p : attr.scales_.scales_) { - // scales: arg - seed = hash_combine(seed, p.first); - // scales: mask - seed = hash_combine(seed, p.second.mask_); - // scales: groups - const int ndims = p.second.ndims_; - seed = hash_combine(seed, ndims); - if (ndims > 0) - seed = get_array_hash(seed, p.second.group_dims_, ndims); - // scales: data type - seed = hash_combine(seed, static_cast(p.second.data_type_)); - } - } - // 
zero_points - for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) - if (!attr.zero_points_.has_default_values(arg)) { - const auto &zps = attr.zero_points_; - // zero_points: arg - seed = hash_combine(seed, arg); - int mask = 0; - zps.get(arg, &mask); - // zero_points: mask - seed = hash_combine(seed, mask); - // zero points: groups - const int ndims = zps.get_groups_ndims(arg); - seed = hash_combine(seed, ndims); - if (ndims > 0) - seed = get_array_hash(seed, zps.get_groups(arg), ndims); - // zero points: data type - seed = hash_combine( - seed, static_cast(zps.get_data_type(arg))); - } - // post_ops: entry[:] - for (int i = 0; i < attr.post_ops_.len(); i++) { - const auto &entry = attr.post_ops_.entry_[i]; - switch (entry.kind) { - case primitive_kind::eltwise: - seed = hash_combine( - seed, static_cast(entry.eltwise.alg)); - seed = hash_combine(seed, entry.eltwise.scale); - seed = hash_combine(seed, entry.eltwise.alpha); - seed = hash_combine(seed, entry.eltwise.beta); - break; - case primitive_kind::sum: - seed = hash_combine(seed, entry.sum.scale); - seed = hash_combine(seed, entry.sum.zero_point); - seed = hash_combine(seed, static_cast(entry.sum.dt)); - break; - case primitive_kind::convolution: - seed = hash_combine( - seed, static_cast(entry.depthwise_conv.kernel)); - seed = hash_combine( - seed, static_cast(entry.depthwise_conv.stride)); - seed = hash_combine(seed, - static_cast(entry.depthwise_conv.padding)); - seed = hash_combine( - seed, static_cast(entry.depthwise_conv.wei_dt)); - seed = hash_combine(seed, - static_cast(entry.depthwise_conv.bias_dt)); - seed = hash_combine( - seed, static_cast(entry.depthwise_conv.dst_dt)); - break; - case primitive_kind::binary: - seed = hash_combine( - seed, static_cast(entry.binary.alg)); - seed = hash_combine( - seed, get_md_hash(entry.binary.user_src1_desc)); - break; - case primitive_kind::prelu: - seed = hash_combine( - seed, static_cast(entry.prelu.mask)); - break; - default: assert(!"unknown post_op"); - } - } - // rnn_data_qparams: scale, shift - seed = hash_combine(seed, attr.rnn_data_qparams_.scale_); - seed = hash_combine(seed, attr.rnn_data_qparams_.shift_); - if (!attr.rnn_weights_qparams_.has_default_values()) { - // rnn_weights_qparams: mask - seed = hash_combine(seed, attr.rnn_weights_qparams_.mask_); - // rnn_weights_qparams: count - seed = hash_combine(seed, attr.rnn_weights_qparams_.count_); - // rnn_weights_qparams: scales[:] - seed = get_array_hash(seed, attr.rnn_weights_qparams_.scales_, - attr.rnn_weights_qparams_.count_); - } - if (attr.gpu_attr_) { - seed = hash_combine(seed, attr.gpu_attr_->get_hash()); - } - if (!attr.dropout_.has_default_values()) { - seed = hash_combine( - seed, get_md_hash(attr.dropout_.user_dropout_desc_)); - } - // Combined hash for attributes - return seed; + // ANCHOR: HASHING_DEBUGINFO_16. 
+ VDEBUGINFO(16, primitive, hashing, "operator==,ret=%d", ret); + return ret; } // Functions that compute hash for different op_descs @@ -366,6 +156,8 @@ size_t get_desc_hash(const binary_desc_t &desc) { // Memory descriptors seed = hash_combine(seed, get_md_hash(desc.src_desc[0])); seed = hash_combine(seed, get_md_hash(desc.src_desc[1])); + if (desc.alg_kind == alg_kind::binary_select) + seed = hash_combine(seed, get_md_hash(desc.src_desc[2])); seed = hash_combine(seed, get_md_hash(desc.dst_desc)); // Combined hash for binary op desc return seed; @@ -377,7 +169,18 @@ size_t get_desc_hash(const convolution_desc_t &desc) { // Kinds seed = hash_combine(seed, static_cast(desc.primitive_kind)); seed = hash_combine(seed, static_cast(desc.prop_kind)); - seed = hash_combine(seed, static_cast(desc.alg_kind)); + + // Ignore `alg_kind` to keep hash value consistent for any algorithm. + // + // Background: when a convolution primitive descriptor is created for + // the algorithm `auto` we overwrite `alg_kind` field in `op_desc` when + // store it in the primitive descriptor. Because of that, the `op_desc` + // stored in the primitive descriptor is different from the one user + // passed to oneDNN API. Because of the difference the requested + // primitive descriptor cannot be found in the cache if we hash/compare + // `alg_kind`. + //seed = hash_combine(seed, static_cast(desc.alg_kind)); + // Memory descriptors seed = hash_combine(seed, get_md_hash(desc.src_desc)); seed = hash_combine(seed, get_md_hash(desc.diff_src_desc)); @@ -530,6 +333,9 @@ size_t get_desc_hash(const matmul_desc_t &desc) { seed = hash_combine(seed, get_md_hash(desc.weights_desc)); seed = hash_combine(seed, get_md_hash(desc.bias_desc)); seed = hash_combine(seed, get_md_hash(desc.dst_desc)); + seed = hash_combine(seed, get_md_hash(desc.reduce_desc)); + // Reduce kind. + seed = hash_combine(seed, static_cast(desc.reduce_kind)); // Accumulator type seed = hash_combine(seed, static_cast(desc.accum_data_type)); // Combined hash for matmul op desc @@ -727,11 +533,17 @@ size_t get_desc_hash(const sdpa_desc_t &desc) { seed = hash_combine(seed, get_md_hash(desc.q_desc)); seed = hash_combine(seed, get_md_hash(desc.k_desc)); seed = hash_combine(seed, get_md_hash(desc.v_desc)); + seed = hash_combine(seed, desc.kq_scales.get_hash()); + seed = hash_combine(seed, desc.kq_zero_points.get_hash()); + seed = hash_combine(seed, desc.vs_scales.get_hash()); + seed = hash_combine(seed, desc.vs_zero_points.get_hash()); seed = hash_combine(seed, get_md_hash(desc.dst_desc)); seed = hash_combine(seed, get_md_hash(desc.attn_mask_desc)); // Scale type seed = hash_combine(seed, static_cast(desc.scale_dt)); seed = hash_combine(seed, desc.invert_scale); + seed = hash_combine(seed, desc.kv_head_number); + seed = hash_combine(seed, static_cast(desc.mask_type)); // Combined hash for sdpa desc return seed; } diff --git a/src/common/primitive_hashing.hpp b/src/common/primitive_hashing.hpp index fa33f920e55..655ed95c93d 100644 --- a/src/common/primitive_hashing.hpp +++ b/src/common/primitive_hashing.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
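To make the cache-lookup rationale above concrete, here is a hedged sketch of the invariant that skipping `alg_kind` preserves. `get_desc_hash` and `compare_conv_opdesc` are the functions this patch references; `make_conv_desc` is a hypothetical helper standing in for regular descriptor initialization:

    #include <cassert>

    using namespace dnnl::impl;

    void conv_hash_invariant_sketch() {
        // Two descriptors identical except for alg_kind: the user asks for
        // `auto`, the library stores the algorithm it actually picked.
        convolution_desc_t user_desc = make_conv_desc(alg_kind::convolution_auto);
        convolution_desc_t stored_desc = user_desc;
        stored_desc.alg_kind = alg_kind::convolution_direct;

        // Since alg_kind is excluded from both the hash and the comparison,
        // the user's original descriptor still finds the cached entry.
        assert(primitive_hashing::get_desc_hash(user_desc)
                == primitive_hashing::get_desc_hash(stored_desc));
        assert(compare_conv_opdesc(user_desc, stored_desc));
    }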
@@ -21,11 +21,11 @@ #include #include -#include "c_types_map.hpp" -#include "engine_id.hpp" -#include "oneapi/dnnl/dnnl.h" -#include "primitive_attr.hpp" -#include "type_helpers.hpp" +#include "common/c_types_map.hpp" +#include "common/engine_id.hpp" +#include "common/type_helpers.hpp" +#include "common/verbose.hpp" +#include "common/primitive_hashing_utils.hpp" namespace dnnl { namespace impl { @@ -59,11 +59,6 @@ struct key_t { engine_id_t engine_id_; private: - template - static const desc_t &cast_to_desc(const void *p) { - return *(reinterpret_cast(p)); - } - static primitive_kind_t get_pkind(primitive_kind_t pkind); // Thread ID is not used as part of the key, it's only used to get @@ -72,8 +67,6 @@ struct key_t { std::thread::id thread_id_; }; -size_t get_md_hash(const memory_desc_t &md); -size_t get_attr_hash(const primitive_attr_t &attr); size_t get_desc_hash(const concat_desc_t &desc); size_t get_desc_hash(const batch_normalization_desc_t &desc); size_t get_desc_hash(const binary_desc_t &desc); @@ -97,39 +90,6 @@ size_t get_desc_hash(const softmax_desc_t &desc); size_t get_desc_hash(const sum_desc_t &desc); size_t get_desc_hash(const zero_pad_desc_t &desc); -template -size_t get_array_hash(size_t seed, const T *v, int size) { - for (int i = 0; i < size; i++) { - seed = hash_combine(seed, v[i]); - } - return seed; -} - -template <> -inline size_t get_array_hash( - size_t seed, const memory_desc_t *v, int size) { - for (int i = 0; i < size; i++) { - seed = hash_combine(seed, get_md_hash(v[i])); - } - return seed; -} - -inline size_t get_array_hash( - size_t seed, const std::vector &mds) { - for (const auto *md : mds) - seed = hash_combine(seed, get_md_hash(*md)); - return seed; -} - -template <> -inline size_t get_array_hash( - size_t seed, const data_type_t *v, int size) { - for (int i = 0; i < size; i++) { - seed = hash_combine(seed, static_cast(v[i])); - } - return seed; -} - } // namespace primitive_hashing } // namespace impl } // namespace dnnl @@ -153,11 +113,19 @@ struct hash { seed = hash_combine(seed, hash_combine(0, key.skip_idx_)); seed = hash_combine(seed, key.engine_id_.hash()); + + seed = get_array_hash( + seed, key.hint_mds_.data(), (int)key.hint_mds_.size()); + + const result_type verb_seed_before_desc = seed; + UNUSED(verb_seed_before_desc); + // Combine hash for op_desc with the computed hash #define CASE(pkind) \ case primitive_kind::pkind: \ - seed = hash_combine( \ - seed, get_desc_hash(*(pkind##_desc_t *)key.op_desc_)); \ + seed = hash_combine(seed, \ + get_desc_hash( \ + *op_desc_t::to_desc(key.op_desc_))); \ break; // clang-format off @@ -189,8 +157,13 @@ struct hash { } // clang-format on #undef CASE - seed = get_array_hash( - seed, key.hint_mds_.data(), (int)key.hint_mds_.size()); + + // Note: `16` is just a random number, as debuginfo hasn't received a + // single command center for levels across layers of the library. + // ANCHOR: HASHING_DEBUGINFO_16. 
+ VDEBUGINFO(16, primitive, hashing, + "operator(),seed_before_desc=%zu seed_after_desc=%zu", + verb_seed_before_desc, seed); return seed; } diff --git a/src/common/primitive_hashing_utils.cpp b/src/common/primitive_hashing_utils.cpp new file mode 100644 index 00000000000..6a84ebe4dd3 --- /dev/null +++ b/src/common/primitive_hashing_utils.cpp @@ -0,0 +1,254 @@ +/******************************************************************************* +* Copyright 2019-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "primitive_attr.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "dnnl_thread.hpp" +#include "engine.hpp" +#include "primitive_hashing_utils.hpp" + +namespace dnnl { +namespace impl { +namespace primitive_hashing { + +// Combine hash of each memory_desc_t data member +size_t get_md_hash(const memory_desc_t &md) { + size_t seed = 0; + seed = get_array_hash(seed, md.dims, md.ndims); + seed = hash_combine(seed, static_cast(md.data_type)); + seed = get_array_hash(seed, md.padded_dims, md.ndims); + seed = get_array_hash(seed, md.padded_offsets, md.ndims); + seed = hash_combine(seed, md.offset0); + seed = hash_combine(seed, static_cast(md.format_kind)); + // format desc + switch ((int)md.format_kind) { + case format_kind::undef: + case format_kind::any: break; + case format_kind::blocked: + for (int i = 0; i < md.ndims; i++) { + if (md.dims[i] == 1 && md.padded_dims[i] == 1) continue; + seed = hash_combine(seed, md.format_desc.blocking.strides[i]); + } + seed = hash_combine(seed, md.format_desc.blocking.inner_nblks); + seed = get_array_hash(seed, md.format_desc.blocking.inner_blks, + md.format_desc.blocking.inner_nblks); + seed = get_array_hash(seed, md.format_desc.blocking.inner_idxs, + md.format_desc.blocking.inner_nblks); + break; + case format_kind::wino: + seed = hash_combine(seed, + static_cast(md.format_desc.wino_desc.wino_format)); + seed = hash_combine(seed, md.format_desc.wino_desc.r); + seed = hash_combine(seed, md.format_desc.wino_desc.alpha); + seed = hash_combine(seed, md.format_desc.wino_desc.ic); + seed = hash_combine(seed, md.format_desc.wino_desc.oc); + seed = hash_combine(seed, md.format_desc.wino_desc.ic_block); + seed = hash_combine(seed, md.format_desc.wino_desc.oc_block); + seed = hash_combine(seed, md.format_desc.wino_desc.ic2_block); + seed = hash_combine(seed, md.format_desc.wino_desc.oc2_block); + seed = hash_combine(seed, md.format_desc.wino_desc.adj_scale); + seed = hash_combine(seed, md.format_desc.wino_desc.size); + break; + case format_kind::cublaslt_blocked: + seed = hash_combine(seed, + static_cast(md.format_desc.cublaslt_blocked_desc + .cublaslt_format)); + seed = hash_combine( + seed, (md.format_desc.cublaslt_blocked_desc.size)); + break; + case format_kind::rnn_packed: + seed = hash_combine(seed, + static_cast(md.format_desc.rnn_packed_desc.format)); + seed = hash_combine(seed, 
md.format_desc.rnn_packed_desc.n_parts);
+        seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n);
+        seed = hash_combine(seed, md.format_desc.rnn_packed_desc.ldb);
+        {
+            int n_parts = md.format_desc.rnn_packed_desc.n_parts;
+            seed = get_array_hash(
+                    seed, md.format_desc.rnn_packed_desc.parts, n_parts);
+            seed = get_array_hash(seed,
+                    md.format_desc.rnn_packed_desc.part_pack_size, n_parts);
+            seed = get_array_hash(seed,
+                    md.format_desc.rnn_packed_desc.pack_part, n_parts);
+        }
+        seed = hash_combine(
+                seed, md.format_desc.rnn_packed_desc.offset_compensation);
+        seed = hash_combine(seed, md.format_desc.rnn_packed_desc.size);
+        break;
+#ifdef DNNL_EXPERIMENTAL_SPARSE
+    case format_kind::sparse:
+        seed = hash_combine(seed,
+                static_cast<size_t>(md.format_desc.sparse_desc.encoding));
+        seed = hash_combine(seed, md.format_desc.sparse_desc.nnz);
+        seed = get_array_hash(seed,
+                md.format_desc.sparse_desc.metadata_types,
+                sparse_desc_t::max_metadata_types);
+        // User cannot initialize `packed_desc`, therefore `packed_desc`
+        // is always zero initialized.
+        break;
+#else
+    case format_kind::sparse:
+        seed = hash_combine(seed,
+                static_cast<size_t>(md.format_desc.sparse_desc.encoding));
+        // User cannot initialize `packed_desc`, therefore, at this point,
+        // `packed_desc` is always zero initialized.
+        break;
+#endif
+    default: assert(!"unknown format_kind");
+    }
+
+    if (md.extra.flags != dnnl_memory_extra_flag_none) {
+        seed = hash_combine(seed, md.extra.flags);
+        if (md.extra.flags
+                & (dnnl_memory_extra_flag_compensation_conv_s8s8
+                        | dnnl_memory_extra_flag_rnn_u8s8_compensation)) {
+            seed = hash_combine(seed, md.extra.compensation_mask);
+        }
+
+        if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) {
+            seed = hash_combine(seed, md.extra.scale_adjust);
+        }
+
+        if (md.extra.flags
+                & dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
+            seed = hash_combine(seed, md.extra.asymm_compensation_mask);
+        }
+
+        if (md.extra.flags
+                & dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src) {
+            seed = get_array_hash(seed, md.extra.idhw, 3);
+            seed = get_array_hash(seed, md.extra.odhw, 3);
+            seed = get_array_hash(seed, md.extra.pdhw, 3);
+            seed = get_array_hash(seed, md.extra.ddhw, 3);
+            seed = hash_combine(seed, md.extra.dst_size);
+        }
+    }
+    // Combined hash for a memory descriptor
+    return seed;
+}
+
+// Combine hash of each primitive_attr_t data member
+size_t get_attr_hash(const primitive_attr_t &attr) {
+    size_t seed = 0;
+    // scratchpad_mode
+    seed = hash_combine(seed, static_cast<size_t>(attr.scratchpad_mode_));
+    // fpmath_mode
+    seed = hash_combine(seed, static_cast<size_t>(attr.fpmath_.mode_));
+    seed = hash_combine(seed, static_cast<size_t>(attr.fpmath_.apply_to_int_));
+    // deterministic
+    seed = hash_combine(seed, static_cast<size_t>(attr.deterministic_));
+    // acc_mode
+    seed = hash_combine(seed, static_cast<size_t>(attr.acc_mode_));
+    // rounding_mode
+    if (!attr.rounding_mode_.has_default_values()) {
+        for (const auto &e : attr.rounding_mode_.rounding_modes_map_) {
+            seed = hash_combine(seed, e.first);
+            seed = hash_combine(seed, static_cast<size_t>(e.second));
+        }
+    }
+
+    if (!attr.scales_.has_default_values()) {
+        seed = hash_combine(seed, attr.scales_.get_hash());
+    }
+
+    if (!attr.zero_points_.has_default_values()) {
+        seed = hash_combine(seed, attr.zero_points_.get_hash());
+    }
+
+    // post_ops: entry[:]
+    seed = get_post_op_hash(seed, attr.post_ops_);
+    // rnn_data_qparams: scale, shift
+    seed = hash_combine(seed, attr.rnn_data_qparams_.scale_);
+    seed = hash_combine(seed, attr.rnn_data_qparams_.shift_);
+    if
(!attr.rnn_weights_qparams_.has_default_values()) { + // rnn_weights_qparams: mask + seed = hash_combine(seed, attr.rnn_weights_qparams_.mask_); + // rnn_weights_qparams: count + seed = hash_combine(seed, attr.rnn_weights_qparams_.count_); + // rnn_weights_qparams: scales[:] + seed = get_array_hash(seed, attr.rnn_weights_qparams_.scales_, + attr.rnn_weights_qparams_.count_); + } + if (attr.gpu_attr_) { + seed = hash_combine(seed, attr.gpu_attr_->get_hash()); + } + if (!attr.dropout_.has_default_values()) { + seed = hash_combine( + seed, get_md_hash(attr.dropout_.user_dropout_desc_)); + } + seed = hash_combine(seed, attr.src_dyn_quant_params_.get()); + // Combined hash for attributes + return seed; +} + +// Combine hash of each post_ops::entry_ +size_t get_post_op_hash(size_t seed, const post_ops_t &post_ops) { + for (int i = 0; i < post_ops.len(); i++) { + const auto &entry = post_ops.entry_[i]; + switch (entry.kind) { + case primitive_kind::eltwise: + seed = hash_combine( + seed, static_cast(entry.eltwise.alg)); + seed = hash_combine(seed, entry.eltwise.scale); + seed = hash_combine(seed, entry.eltwise.alpha); + seed = hash_combine(seed, entry.eltwise.beta); + break; + case primitive_kind::sum: + seed = hash_combine(seed, entry.sum.scale); + seed = hash_combine(seed, static_cast(entry.sum.dt)); + break; + case primitive_kind::convolution: + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.in_h)); + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.in_w)); + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.ker_h)); + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.ker_w)); + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.str_h)); + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.str_w)); + seed = hash_combine(seed, static_cast(entry.depthwise_conv_old.in_dt)); + break; + case primitive_kind::binary: + seed = hash_combine( + seed, static_cast(entry.binary.alg)); + seed = hash_combine( + seed, get_md_hash(entry.binary.user_src1_desc)); + break; + case primitive_kind::prelu: + seed = hash_combine( + seed, static_cast(entry.prelu.mask)); + break; + case primitive_kind::depthwise: + seed = hash_combine(seed, static_cast(entry.depthwise.alg)); + seed = get_array_hash(seed, entry.depthwise.offset, entry.depthwise.fields_count); + break; + case primitive_kind::quantization: + seed = hash_combine(seed, static_cast(entry.quantization.alg)); + seed = get_array_hash(seed, entry.quantization.per_channel, entry.quantization.fields_count); + seed = get_array_hash(seed, entry.quantization.all_default, entry.quantization.fields_count); + seed = get_array_hash(seed, entry.quantization.offset, entry.quantization.fields_count); + break; + default: assert(!"unknown post_op"); + } + } + + return seed; +} + +} // namespace primitive_hashing +} // namespace impl +} // namespace dnnl diff --git a/src/common/primitive_hashing_utils.hpp b/src/common/primitive_hashing_utils.hpp new file mode 100644 index 00000000000..27a7490d093 --- /dev/null +++ b/src/common/primitive_hashing_utils.hpp @@ -0,0 +1,71 @@ +/******************************************************************************* +* Copyright 2019-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef COMMON_PRIMITIVE_HASHING_UTILS_HPP +#define COMMON_PRIMITIVE_HASHING_UTILS_HPP + +#include +#include +#include + +#include "common/c_types_map.hpp" +#include "common/engine_id.hpp" +#include "common/type_helpers.hpp" +#include "common/verbose.hpp" + +namespace dnnl { +namespace impl { +struct primitive_desc_t; +namespace primitive_hashing { + +size_t get_md_hash(const memory_desc_t &md); +size_t get_attr_hash(const primitive_attr_t &attr); +size_t get_post_op_hash(size_t seed, const post_ops_t &post_ops); + +template +size_t get_array_hash(size_t seed, const T *v, int size) { + for (int i = 0; i < size; i++) { + seed = hash_combine(seed, v[i]); + } + return seed; +} + +template <> +inline size_t get_array_hash( + size_t seed, const memory_desc_t *v, int size) { + for (int i = 0; i < size; i++) { + seed = hash_combine(seed, get_md_hash(v[i])); + } + return seed; +} + +inline size_t get_array_hash( + size_t seed, const std::vector &mds) { + for (const auto *md : mds) + seed = hash_combine(seed, get_md_hash(*md)); + return seed; +} + +template +size_t get_vector_hash(size_t seed, const std::vector &vec) { + return get_array_hash(seed, vec.data(), vec.size()); +} + +} // namespace primitive_hashing +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/common/primitive_serialization.cpp b/src/common/primitive_serialization.cpp new file mode 100644 index 00000000000..b0132bc197b --- /dev/null +++ b/src/common/primitive_serialization.cpp @@ -0,0 +1,596 @@ +/******************************************************************************* +* Copyright 2021-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "common/primitive_serialization.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +namespace dnnl { +namespace impl { + +status_t serialize_desc( + serialization_stream_t &sstream, const op_desc_t *op_desc) { +#define CASE(pkind) \ + case primitive_kind::pkind: \ + serialize(sstream, *(const pkind##_desc_t *)op_desc); \ + break; + + switch ((int)op_desc->primitive_kind) { + CASE(batch_normalization) + CASE(binary) + CASE(concat) + CASE(convolution) + CASE(deconvolution) + CASE(eltwise) + CASE(gemm) + CASE(group_normalization) + CASE(inner_product) + CASE(layer_normalization) + CASE(lrn) + CASE(matmul) + CASE(pooling) + CASE(prelu) + CASE(reduction) + CASE(reorder) + CASE(resampling) + CASE(rnn) + CASE(sdpa) + CASE(shuffle) + CASE(softmax) + CASE(sum) + default: return status::invalid_arguments; + } +#undef CASE + return status::success; +} + +void serialize(serialization_stream_t &sstream, const memory_desc_t &md) { + sstream.append(md.ndims); + sstream.append_array(md.ndims, md.dims); + sstream.append(md.data_type); + sstream.append_array(md.ndims, md.padded_dims); + sstream.append_array(md.ndims, md.padded_offsets); + sstream.append(md.offset0); + sstream.append(md.format_kind); + // format desc + switch ((int)md.format_kind) { +#ifdef DNNL_EXPERIMENTAL_SPARSE + case format_kind::sparse: +#endif + case format_kind::undef: + case format_kind::any: break; + case format_kind::blocked: + sstream.append_array(md.ndims, md.format_desc.blocking.strides); + sstream.append(md.format_desc.blocking.inner_nblks); + sstream.append_array(md.format_desc.blocking.inner_nblks, + md.format_desc.blocking.inner_blks); + sstream.append_array(md.format_desc.blocking.inner_nblks, + md.format_desc.blocking.inner_idxs); + break; + case format_kind::wino: + sstream.append(md.format_desc.wino_desc.wino_format); + sstream.append(md.format_desc.wino_desc.r); + sstream.append(md.format_desc.wino_desc.alpha); + sstream.append(md.format_desc.wino_desc.ic); + sstream.append(md.format_desc.wino_desc.oc); + sstream.append(md.format_desc.wino_desc.ic_block); + sstream.append(md.format_desc.wino_desc.oc_block); + sstream.append(md.format_desc.wino_desc.ic2_block); + sstream.append(md.format_desc.wino_desc.oc2_block); + sstream.append(md.format_desc.wino_desc.adj_scale); + sstream.append(md.format_desc.wino_desc.size); + break; + case format_kind::cublaslt_blocked: + sstream.append( + md.format_desc.cublaslt_blocked_desc.cublaslt_format); + sstream.append(md.format_desc.cublaslt_blocked_desc.size); + break; + case format_kind::rnn_packed: + sstream.append(md.format_desc.rnn_packed_desc.format); + sstream.append(md.format_desc.rnn_packed_desc.n_parts); + sstream.append(md.format_desc.rnn_packed_desc.n); + sstream.append(md.format_desc.rnn_packed_desc.ldb); + { + int n_parts = md.format_desc.rnn_packed_desc.n_parts; + sstream.append_array( + n_parts, md.format_desc.rnn_packed_desc.parts); + sstream.append_array( + n_parts, md.format_desc.rnn_packed_desc.part_pack_size); + sstream.append_array( + n_parts, md.format_desc.rnn_packed_desc.pack_part); + } + sstream.append(md.format_desc.rnn_packed_desc.offset_compensation); + sstream.append(md.format_desc.rnn_packed_desc.size); + break; + default: assert(!"unknown format_kind"); + } + + if (md.extra.flags != dnnl_memory_extra_flag_none) { + sstream.append(md.extra.flags); + if (md.extra.flags + & (dnnl_memory_extra_flag_compensation_conv_s8s8 + | 
dnnl_memory_extra_flag_rnn_u8s8_compensation)) { + sstream.append(md.extra.compensation_mask); + } + if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) { + sstream.append(md.extra.scale_adjust); + } + if (md.extra.flags + & dnnl_memory_extra_flag_compensation_conv_asymmetric_src) { + sstream.append(md.extra.asymm_compensation_mask); + } + if (md.extra.flags + & dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src) { + sstream.append_array(3, md.extra.idhw); + sstream.append_array(3, md.extra.odhw); + sstream.append_array(3, md.extra.pdhw); + sstream.append_array(3, md.extra.ddhw); + sstream.append(md.extra.dst_size); + } + } +} + +void serialize(serialization_stream_t &sstream, const post_ops_t &post_ops) { + // post_ops: entry[:] + for (int i = 0; i < post_ops.len(); i++) { + const auto &entry = post_ops.entry_[i]; + switch (entry.kind) { + case primitive_kind::eltwise: + sstream.append(entry.eltwise.alg); + sstream.append(entry.eltwise.scale); + sstream.append(entry.eltwise.alpha); + sstream.append(entry.eltwise.beta); + break; + case primitive_kind::sum: + sstream.append(entry.sum.scale); + sstream.append(entry.sum.zero_point); + sstream.append(entry.sum.dt); + break; + case primitive_kind::convolution: + sstream.append(entry.depthwise_conv.kernel); + sstream.append(entry.depthwise_conv.stride); + sstream.append(entry.depthwise_conv.padding); + sstream.append(entry.depthwise_conv.wei_dt); + sstream.append(entry.depthwise_conv.bias_dt); + sstream.append(entry.depthwise_conv.dst_dt); + break; + case primitive_kind::binary: + sstream.append(entry.binary.alg); + serialize(sstream, entry.binary.user_src1_desc); + break; + case primitive_kind::prelu: sstream.append(entry.prelu.mask); break; + default: assert(!"unknown post_op"); + } + } +} + +void serialize(serialization_stream_t &sstream, const primitive_attr_t &attr) { + // scratchpad_mode + sstream.append(attr.scratchpad_mode_); + // fpmath_mode + sstream.append(attr.fpmath_.mode_); + sstream.append(attr.fpmath_.apply_to_int_); + // deterministic + sstream.append(attr.deterministic_); + // acc_mode + sstream.append(attr.acc_mode_); + + if (!attr.scales_.has_default_values()) { + sstream.append('s'); + attr.scales_.serialize(sstream); + } + // zero_points + if (!attr.zero_points_.has_default_values()) { + sstream.append('z'); + attr.zero_points_.serialize(sstream); + } + + // Rounding modes + if (!attr.rounding_mode_.has_default_values()) sstream.append('r'); + for (const auto &e : attr.rounding_mode_.rounding_modes_map_) { + if (!attr.rounding_mode_.has_default_values(e.first)) { + sstream.append(e.first); + sstream.append(e.second); + } + } + + if (!attr.dropout_.has_default_values()) { + sstream.append('d'); + serialize(sstream, attr.dropout_.user_dropout_desc_); + } + + serialize(sstream, attr.post_ops_); + + // rnn_data_qparams: scale, shift + sstream.append(attr.rnn_data_qparams_.scale_); + sstream.append(attr.rnn_data_qparams_.shift_); + if (!attr.rnn_weights_qparams_.has_default_values()) { + // rnn_weights_qparams: mask + sstream.append(attr.rnn_weights_qparams_.mask_); + // rnn_weights_qparams: count + sstream.append(attr.rnn_weights_qparams_.count_); + // rnn_weights_qparams: scales[:] + sstream.append_array(attr.rnn_weights_qparams_.count_, + attr.rnn_weights_qparams_.scales_); + } + if (attr.gpu_attr_) { + attr.gpu_attr_->serialize(sstream); + } else { + int zero = 0; + sstream.append(zero); + } + sstream.append(attr.src_dyn_quant_params_.get()); +} + +void serialize(serialization_stream_t &sstream, const 
concat_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + // Memory descriptors + serialize(sstream, *desc.dst_md); + // N + sstream.append(desc.n); + // Concat dimension + sstream.append(desc.concat_dimension); + // Array of mds + for (int i = 0; i < desc.n; i++) + serialize(sstream, *desc.src_mds[i]); +} + +void serialize(serialization_stream_t &sstream, + const batch_normalization_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.dst_desc); + serialize(sstream, desc.diff_src_desc); + serialize(sstream, desc.diff_dst_desc); + serialize(sstream, desc.scaleshift_desc); + serialize(sstream, desc.diff_scaleshift_desc); + serialize(sstream, desc.stat_desc); + // Epsilon + sstream.append(desc.batch_norm_epsilon); + // Flags + sstream.append(desc.flags); +} + +void serialize(serialization_stream_t &sstream, const binary_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.alg_kind); + // Memory descriptors + serialize(sstream, desc.src_desc[0]); + serialize(sstream, desc.src_desc[1]); + serialize(sstream, desc.src_desc[2]); + serialize(sstream, desc.dst_desc); +} + +// (De-)Convolution +void serialize( + serialization_stream_t &sstream, const convolution_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); + sstream.append(desc.alg_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.diff_src_desc); + serialize(sstream, desc.weights_desc); + serialize(sstream, desc.diff_weights_desc); + serialize(sstream, desc.bias_desc); + serialize(sstream, desc.diff_bias_desc); + serialize(sstream, desc.dst_desc); + serialize(sstream, desc.diff_dst_desc); + // Strides, dilates, padding + sstream.append_array(DNNL_MAX_NDIMS, desc.strides); + sstream.append_array(DNNL_MAX_NDIMS, desc.dilates); + sstream.append_array(DNNL_MAX_NDIMS, desc.padding[0]); + sstream.append_array(DNNL_MAX_NDIMS, desc.padding[1]); + // Accumulator type + sstream.append(desc.accum_data_type); + // Internal member + sstream.append(desc.use_inversion); +} + +// Eltwise +void serialize(serialization_stream_t &sstream, const eltwise_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); + sstream.append(desc.alg_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.dst_desc); + serialize(sstream, desc.diff_src_desc); + serialize(sstream, desc.diff_dst_desc); + // Alpha, beta + sstream.append(desc.alpha); + sstream.append(desc.beta); +} + +void serialize(serialization_stream_t &sstream, const gemm_desc_t &desc) { + // Kind + sstream.append(desc.primitive_kind); + serialize(sstream, desc.a_desc); + serialize(sstream, desc.b_desc); + serialize(sstream, desc.c_desc); + serialize(sstream, desc.bias_desc); + // Accumulator type + sstream.append(desc.acc_type); + sstream.append(desc.sum_ab); + sstream.append(desc.sum_ab_type); +} + +void serialize(serialization_stream_t &sstream, + const group_normalization_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.dst_desc); + serialize(sstream, desc.diff_src_desc); + serialize(sstream, desc.diff_dst_desc); + serialize(sstream, desc.scaleshift_desc); + serialize(sstream, desc.diff_scaleshift_desc); + 
serialize(sstream, desc.stat_desc); + // Groups + sstream.append(desc.groups); + // Epsilon + sstream.append(desc.group_norm_epsilon); + // Flags + sstream.append(desc.flags); +} + +void serialize( + serialization_stream_t &sstream, const inner_product_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.diff_src_desc); + serialize(sstream, desc.weights_desc); + serialize(sstream, desc.diff_weights_desc); + serialize(sstream, desc.bias_desc); + serialize(sstream, desc.diff_bias_desc); + serialize(sstream, desc.dst_desc); + serialize(sstream, desc.diff_dst_desc); + // Accumulator type + sstream.append(desc.accum_data_type); +} + +void serialize(serialization_stream_t &sstream, + const layer_normalization_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.diff_src_desc); + serialize(sstream, desc.data_scaleshift_desc); + serialize(sstream, desc.diff_data_scaleshift_desc); + serialize(sstream, desc.dst_desc); + serialize(sstream, desc.diff_dst_desc); + serialize(sstream, desc.stat_desc); + // Epsilon + sstream.append(desc.layer_norm_epsilon); + // Flags + sstream.append(desc.flags); +} + +void serialize(serialization_stream_t &sstream, const lrn_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); + sstream.append(desc.alg_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.dst_desc); + serialize(sstream, desc.diff_src_desc); + serialize(sstream, desc.diff_dst_desc); + // Local size + sstream.append(desc.local_size); + // Alpha, beta + sstream.append(desc.lrn_alpha); + sstream.append(desc.lrn_beta); + // k + sstream.append(desc.lrn_k); +} + +void serialize(serialization_stream_t &sstream, const matmul_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.weights_desc); + serialize(sstream, desc.bias_desc); + serialize(sstream, desc.dst_desc); + // Accumulator type + sstream.append(desc.accum_data_type); +} + +void serialize(serialization_stream_t &sstream, const pooling_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); + sstream.append(desc.alg_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.diff_src_desc); + serialize(sstream, desc.dst_desc); + serialize(sstream, desc.diff_dst_desc); + // Strides, dilates, padding + sstream.append_array(DNNL_MAX_NDIMS, desc.strides); + sstream.append_array(DNNL_MAX_NDIMS, desc.kernel); + sstream.append_array(DNNL_MAX_NDIMS, desc.padding[0]); + sstream.append_array(DNNL_MAX_NDIMS, desc.padding[1]); + sstream.append_array(DNNL_MAX_NDIMS, desc.dilation); + // Accumulator type + sstream.append(desc.accum_data_type); +} + +void serialize(serialization_stream_t &sstream, const prelu_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.weights_desc); + serialize(sstream, desc.dst_desc); + serialize(sstream, desc.diff_src_desc); + serialize(sstream, desc.diff_weights_desc); + serialize(sstream, desc.diff_dst_desc); +} + +void serialize(serialization_stream_t &sstream, const reduction_desc_t 
&desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.alg_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.dst_desc); + // P, eps + sstream.append(desc.p); + sstream.append(desc.eps); +} + +void serialize(serialization_stream_t &sstream, const reorder_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + // Memory descriptors + serialize(sstream, *desc.src_md); + serialize(sstream, *desc.dst_md); + // Kinds of source and destination engines + sstream.append(desc.src_engine_kind); + sstream.append(desc.dst_engine_kind); + sstream.append(desc.is_cross_engine); +} + +void serialize(serialization_stream_t &sstream, const resampling_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.alg_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.diff_src_desc); + serialize(sstream, desc.dst_desc); + serialize(sstream, desc.diff_dst_desc); + // Factors + sstream.append_array(DNNL_MAX_NDIMS, desc.factors); +} + +void serialize(serialization_stream_t &sstream, const rnn_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); + sstream.append(desc.cell_kind); + sstream.append(desc.direction); + // Memory descriptors + serialize(sstream, desc.src_layer_desc); + serialize(sstream, desc.src_iter_desc); + serialize(sstream, desc.src_iter_c_desc); + serialize(sstream, desc.weights_layer_desc); + serialize(sstream, desc.weights_iter_desc); + serialize(sstream, desc.bias_desc); + serialize(sstream, desc.dst_layer_desc); + serialize(sstream, desc.dst_iter_desc); + serialize(sstream, desc.dst_iter_c_desc); + serialize(sstream, desc.weights_peephole_desc); + serialize(sstream, desc.weights_projection_desc); + serialize(sstream, desc.diff_src_layer_desc); + serialize(sstream, desc.diff_src_iter_desc); + serialize(sstream, desc.diff_src_iter_c_desc); + serialize(sstream, desc.diff_weights_layer_desc); + serialize(sstream, desc.diff_weights_iter_desc); + serialize(sstream, desc.diff_bias_desc); + serialize(sstream, desc.diff_dst_layer_desc); + serialize(sstream, desc.diff_dst_iter_desc); + serialize(sstream, desc.diff_dst_iter_c_desc); + serialize(sstream, desc.diff_weights_peephole_desc); + serialize(sstream, desc.diff_weights_projection_desc); + // Flags + sstream.append(desc.flags); + // Activation kind + sstream.append(desc.activation_kind); + // Alpha, beta + sstream.append(desc.alpha); + sstream.append(desc.beta); +} + +// Shuffle +void serialize(serialization_stream_t &sstream, const shuffle_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.dst_desc); + // Axis + sstream.append(desc.axis); + // Group size + sstream.append(desc.group_size); +} + +void serialize(serialization_stream_t &sstream, const softmax_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); + sstream.append(desc.alg_kind); + // Memory descriptors + serialize(sstream, desc.src_desc); + serialize(sstream, desc.diff_src_desc); + serialize(sstream, desc.dst_desc); + serialize(sstream, desc.diff_dst_desc); + // Axis + sstream.append(desc.softmax_axis); +} + +void serialize(serialization_stream_t &sstream, const sum_desc_t &desc) { + // Kinds + sstream.append(desc.primitive_kind); + // Memory descriptors + serialize(sstream, *desc.dst_md); + // N + 
sstream.append(desc.n); + // Scales + sstream.append_array(desc.n, desc.scales); + // Array of mds + for (int i = 0; i < desc.n; i++) + serialize(sstream, *desc.src_mds[i]); +} + +void serialize(serialization_stream_t &sstream, const sdpa_desc_t &desc) { + // Kind + sstream.append(desc.primitive_kind); + serialize(sstream, desc.q_desc); + serialize(sstream, desc.k_desc); + serialize(sstream, desc.v_desc); + desc.kq_scales.serialize(sstream); + desc.kq_zero_points.serialize(sstream); + desc.vs_scales.serialize(sstream); + desc.vs_zero_points.serialize(sstream); + serialize(sstream, desc.dst_desc); + serialize(sstream, desc.attn_mask_desc); + sstream.append(desc.scale_dt); + sstream.append(desc.invert_scale); + sstream.append(desc.kv_head_number); + sstream.append(desc.mask_type); +} + +} // namespace impl +} // namespace dnnl diff --git a/src/common/primitive_serialization.hpp b/src/common/primitive_serialization.hpp new file mode 100644 index 00000000000..50d87f46be3 --- /dev/null +++ b/src/common/primitive_serialization.hpp @@ -0,0 +1,63 @@ +/******************************************************************************* +* Copyright 2021-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef COMMON_PRIMITIVE_SERIALIZATION_HPP +#define COMMON_PRIMITIVE_SERIALIZATION_HPP + +#include "common/c_types_map.hpp" +#include "common/primitive_attr.hpp" +#include "common/serialization.hpp" +#include "common/type_helpers.hpp" + +namespace dnnl { +namespace impl { + +void serialize(serialization_stream_t &sstream, const post_ops_t &post_ops); +void serialize(serialization_stream_t &sstream, const primitive_attr_t &attr); +void serialize(serialization_stream_t &sstream, const memory_desc_t &md); +void serialize(serialization_stream_t &sstream, const concat_desc_t &desc); +void serialize(serialization_stream_t &sstream, + const batch_normalization_desc_t &desc); +void serialize(serialization_stream_t &sstream, const binary_desc_t &desc); +void serialize(serialization_stream_t &sstream, const convolution_desc_t &desc); +void serialize(serialization_stream_t &sstream, const eltwise_desc_t &desc); +void serialize(serialization_stream_t &sstream, const gemm_desc_t &desc); +void serialize(serialization_stream_t &sstream, + const group_normalization_desc_t &desc); +void serialize( + serialization_stream_t &sstream, const inner_product_desc_t &desc); +void serialize(serialization_stream_t &sstream, + const layer_normalization_desc_t &desc); +void serialize(serialization_stream_t &sstream, const lrn_desc_t &desc); +void serialize(serialization_stream_t &sstream, const matmul_desc_t &desc); +void serialize(serialization_stream_t &sstream, const pooling_desc_t &desc); +void serialize(serialization_stream_t &sstream, const prelu_desc_t &desc); +void serialize(serialization_stream_t &sstream, const reduction_desc_t &desc); +void serialize(serialization_stream_t &sstream, const reorder_desc_t &desc); +void 
serialize(serialization_stream_t &sstream, const resampling_desc_t &desc); +void serialize(serialization_stream_t &sstream, const rnn_desc_t &desc); +void serialize(serialization_stream_t &sstream, const sdpa_desc_t &desc); +void serialize(serialization_stream_t &sstream, const shuffle_desc_t &desc); +void serialize(serialization_stream_t &sstream, const softmax_desc_t &desc); +void serialize(serialization_stream_t &sstream, const sum_desc_t &desc); + +status_t serialize_desc( + serialization_stream_t &sstream, const op_desc_t *op_desc); + +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/common/profiler.hpp b/src/common/profiler.hpp index 8bedb0a8e52..35b5b3a90b8 100644 --- a/src/common/profiler.hpp +++ b/src/common/profiler.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022-2023 Intel Corporation +* Copyright 2022-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -90,8 +90,7 @@ static double get_msec() { // names are copied into long term storage. struct profiler_t { - profiler_t(const std::string &profile_name) - : _profile_name(profile_name), _run_data(), _data() { + profiler_t(const std::string &profile_name) : _profile_name(profile_name) { // Reserve data on construction to reduce chance of recording // reallocation _run_data.reserve(128); @@ -109,14 +108,14 @@ struct profiler_t { // Recording data void stamp(const char *name) { optimization_barrier(); - _run_data.emplace_back(record_t(name, get_msec())); + _run_data.emplace_back(name, get_msec()); assert(_state == RUNNING); optimization_barrier(); } void stop(const char *name) { optimization_barrier(); - _run_data.emplace_back(record_t(name, get_msec())); + _run_data.emplace_back(name, get_msec()); stop(); } @@ -172,7 +171,7 @@ struct profiler_t { T name; prof_time_t time; record_t(T name, prof_time_t time) : name(name), time(time) {} - record_t(std::pair<T, prof_time_t> record) + record_t(const std::pair<T, prof_time_t> &record) : name(record.first), time(record.second) {} // Reversed time ordering bool operator<(const record_t &b) const { return this->time > b.time; } diff --git a/src/common/reduction.cpp b/src/common/reduction.cpp index 6c59114e2c9..1f06a111dfc 100644 --- a/src/common/reduction.cpp +++ b/src/common/reduction.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2023 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -124,6 +124,9 @@ status_t reduction_attr_check(const reduction_desc_t &desc, // Check sum VCHECK_RED_UNIMPL(po.check_sum_consistency(dst_dt, false, true), VERBOSE_UNSUPPORTED_POSTOP); + + // Note: verbose support is inside the call. + CHECK(po.validate_binary_with_dst_consistency(&desc.dst_desc)); } return status::success; diff --git a/src/common/reduction_pd.hpp b/src/common/reduction_pd.hpp index e6a5b448609..211b89fd00a 100644 --- a/src/common/reduction_pd.hpp +++ b/src/common/reduction_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -37,10 +37,11 @@ status_t reduction_desc_init(reduction_desc_t *reduction_desc, alg_kind_t alg_kind, const memory_desc_t *src_desc, const memory_desc_t *dst_desc, float p, float eps); +// NOLINTBEGIN(google-default-arguments) struct reduction_pd_t : public primitive_desc_t { static constexpr auto base_pkind = primitive_kind::reduction; - typedef reduction_pd_t hint_class; + using hint_class = reduction_pd_t; const reduction_desc_t *desc() const { return &desc_; } const op_desc_t *op_desc() const override { @@ -131,16 +132,20 @@ struct reduction_pd_t : public primitive_desc_t { } } + bool has_zero_dim_memory() const { + return memory_desc_wrapper(src_md()).has_zero_dim(); + } + protected: reduction_desc_t desc_; memory_desc_t src_md_; memory_desc_t dst_md_; - reduction_pd_t(const reduction_desc_t *adesc, const primitive_attr_t *attr, + reduction_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const hint_class *hint_fwd) : primitive_desc_t(attr, base_pkind) - , desc_(*adesc) + , desc_(*op_desc_t::to_desc(adesc)) , src_md_(desc_.src_desc) , dst_md_(desc_.dst_desc) {} @@ -161,6 +166,7 @@ struct reduction_pd_t : public primitive_desc_t { return status::success; } }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/reorder.cpp b/src/common/reorder.cpp index cedd98c7eb6..c21fe526dfb 100644 --- a/src/common/reorder.cpp +++ b/src/common/reorder.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2023 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,6 +38,10 @@ namespace impl { VCONDCHECK(primitive, create, check, reorder, (cond), \ status::invalid_arguments, msg, ##__VA_ARGS__); +#define VCHECK_REORDER_UNIMPL(cond, msg, ...) \ + VCONDCHECK(primitive, create, check, reorder, (cond), \ + status::unimplemented, msg, ##__VA_ARGS__); + namespace { engine_t *get_reorder_engine(engine_t *src_engine, engine_t *dst_engine) { auto s_ek = src_engine->kind(); @@ -98,6 +102,48 @@ status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd, zero_points.has_default_values(DNNL_ARG_DST)), VERBOSE_UNSUPPORTED_ZP_CFG); + // Check scales + if (!attr->scales_.has_default_values()) { + static const std::vector<int> supported_args { + DNNL_ARG_SRC, DNNL_ARG_DST}; + VCHECK_REORDER_UNIMPL(attr->scales_.has_default_values(supported_args), + VERBOSE_UNSUPPORTED_SCALES_CFG); + + const auto &sc = attr->scales_; + const auto &sc_src = sc.get(DNNL_ARG_SRC); + const int mask_src = sc.get_mask(DNNL_ARG_SRC); + + VCHECK_REORDER(IMPLICATION(utils::one_of(src_md->data_type, + data_type::s4, data_type::u4), + mask_src > 0), + VERBOSE_INVALID_DATATYPE, "mask for int4 source"); + + if (!sc_src.has_default_groups()) { + const int src_ndims = s_mdw.ndims(); + const bool group_dims_are_consistent + = IMPLICATION(sc_src.get_group(0) > 1, + src_md->dims[src_ndims - 2] % sc_src.get_group(0) + == 0) + && IMPLICATION(sc_src.get_group(1) > 1, + src_md->dims[src_ndims - 1] % sc_src.get_group(1) + == 0); + VCHECK_REORDER(group_dims_are_consistent, + "groups dimensions are not consistent with reorder " + "dimensions"); + + // Groups are always applied to last two dimensions. Check that + // input scale mask is consistent with this limitation.
+ const bool mask_applies_to_last_two_dims + = (mask_src & (1 << (src_ndims - 1))) + && (mask_src & (1 << (src_ndims - 2))); + VCHECK_REORDER(mask_applies_to_last_two_dims, + "mask is not consistent with groups"); + } + + VCHECK_REORDER(sc.get(DNNL_ARG_DST).has_default_groups(), + VERBOSE_UNSUPPORTED_SCALES_CFG); + } + bool is_cross_engine = src_engine != dst_engine && utils::one_of( engine_kind::gpu, src_engine->kind(), dst_engine->kind()); diff --git a/src/common/reorder.hpp b/src/common/reorder.hpp index c254afba76f..c831d2c79c6 100644 --- a/src/common/reorder.hpp +++ b/src/common/reorder.hpp @@ -29,6 +29,10 @@ status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd, engine_t *engine, const memory_desc_t *src_md, const memory_desc_t *dst_md, const primitive_attr_t *attr = nullptr); +status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd, + engine_t *engine, const memory_desc_t *src_md, engine_t *src_engine, + const memory_desc_t *dst_md, engine_t *dst_engine, + const primitive_attr_t *attr = nullptr); } // namespace impl } // namespace dnnl diff --git a/src/common/reorder_pd.hpp b/src/common/reorder_pd.hpp index aea7c6e99de..2eea2c0c246 100644 --- a/src/common/reorder_pd.hpp +++ b/src/common/reorder_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -102,6 +102,7 @@ struct reorder_primitive_desc_iface_t : public dnnl_primitive_desc { dnnl::impl::engine_t *scratchpad_engine_; }; +// NOLINTBEGIN(google-default-arguments) struct reorder_pd_t : public primitive_desc_t { const reorder_desc_t *desc() const { return &desc_; } const op_desc_t *op_desc() const override { @@ -159,10 +160,10 @@ struct reorder_pd_t : public primitive_desc_t { init_desc(src_engine_kind, dst_engine_kind, false); } - reorder_pd_t(const reorder_pd_t &other) : primitive_desc_t(other) { - src_md_ = other.src_md_; - dst_md_ = other.dst_md_; - + reorder_pd_t(const reorder_pd_t &other) + : primitive_desc_t(other) + , src_md_(other.src_md_) + , dst_md_(other.dst_md_) { init_desc(other.desc_.src_engine_kind, other.desc_.dst_engine_kind, other.desc_.is_cross_engine); } @@ -177,7 +178,6 @@ struct reorder_pd_t : public primitive_desc_t { return *this; } -protected: void init_desc(engine_kind_t src_engine_kind, engine_kind_t dst_engine_kind, bool is_cross_engine) { desc_ = reorder_desc_t(); @@ -189,6 +189,7 @@ struct reorder_pd_t : public primitive_desc_t { desc_.is_cross_engine = is_cross_engine; } }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/resampling.cpp b/src/common/resampling.cpp index 98f91292929..cb5b151629c 100644 --- a/src/common/resampling.cpp +++ b/src/common/resampling.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -131,6 +131,9 @@ status_t resampling_attr_check(const resampling_desc_t &desc, // Check sum VCHECK_RS_UNIMPL(po.check_sum_consistency(dst_dt, false, true), VERBOSE_UNSUPPORTED_POSTOP); + + // Note: verbose support is inside the call.
+ CHECK(po.validate_binary_with_dst_consistency(&desc.dst_desc)); } } else { VCHECK_RS_UNIMPL(false, VERBOSE_UNSUPPORTED_ATTR); diff --git a/src/common/resampling_pd.hpp b/src/common/resampling_pd.hpp index 8946d2297e5..f5d3cab4fff 100644 --- a/src/common/resampling_pd.hpp +++ b/src/common/resampling_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,13 +42,6 @@ struct resampling_fwd_pd_t; struct resampling_pd_t : public primitive_desc_t { static constexpr auto base_pkind = primitive_kind::resampling; - resampling_pd_t(const resampling_desc_t *adesc, - const primitive_attr_t *attr, - const resampling_fwd_pd_t *hint_fwd_pd) - : primitive_desc_t(attr, base_pkind) - , desc_(*adesc) - , hint_fwd_pd_(hint_fwd_pd) {} - const resampling_desc_t *desc() const { return &desc_; } const op_desc_t *op_desc() const override { return reinterpret_cast<const op_desc_t *>(this->desc()); @@ -103,6 +96,12 @@ struct resampling_pd_t : public primitive_desc_t { resampling_desc_t desc_; const resampling_fwd_pd_t *hint_fwd_pd_; + resampling_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, + const resampling_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(attr, base_pkind) + , desc_(*op_desc_t::to_desc(adesc)) + , hint_fwd_pd_(hint_fwd_pd) {} + private: const memory_desc_t &src_desc() const { return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; } @@ -112,16 +111,10 @@ struct resampling_pd_t : public primitive_desc_t { } }; +// NOLINTBEGIN(google-default-arguments) struct resampling_fwd_pd_t : public resampling_pd_t { - typedef resampling_fwd_pd_t base_class; - typedef resampling_fwd_pd_t hint_class; - - resampling_fwd_pd_t(const resampling_desc_t *adesc, - const primitive_attr_t *attr, - const resampling_fwd_pd_t *hint_fwd_pd) - : resampling_pd_t(adesc, attr, hint_fwd_pd) - , src_md_(desc_.src_desc) - , dst_md_(desc_.dst_desc) {} + using base_class = resampling_fwd_pd_t; + using hint_class = resampling_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (arg == DNNL_ARG_SRC) return arg_usage_t::input; @@ -155,6 +148,12 @@ struct resampling_fwd_pd_t : public resampling_pd_t { memory_desc_t src_md_; memory_desc_t dst_md_; + resampling_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, + const resampling_fwd_pd_t *hint_fwd_pd) + : resampling_pd_t(adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , dst_md_(desc_.dst_desc) {} + virtual status_t set_default_params( format_tag_t src_tag_hint = format_tag::undef) { if (dst_md()->format_kind != format_kind::any) return status::success; @@ -170,17 +169,12 @@ struct resampling_fwd_pd_t : public resampling_pd_t { } } }; +// NOLINTEND(google-default-arguments) +// NOLINTBEGIN(google-default-arguments) struct resampling_bwd_pd_t : public resampling_pd_t { - typedef resampling_bwd_pd_t base_class; - typedef resampling_fwd_pd_t hint_class; - - resampling_bwd_pd_t(const resampling_desc_t *adesc, - const primitive_attr_t *attr, - const resampling_fwd_pd_t *hint_fwd_pd) - : resampling_pd_t(adesc, attr, hint_fwd_pd) - , diff_src_md_(desc_.diff_src_desc) - , diff_dst_md_(desc_.diff_dst_desc) {} + using base_class = resampling_bwd_pd_t; + using hint_class = resampling_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (arg == DNNL_ARG_DIFF_DST) return arg_usage_t::input; @@ -216,6 +210,12 
@@ struct resampling_bwd_pd_t : public resampling_pd_t { memory_desc_t diff_src_md_; memory_desc_t diff_dst_md_; + resampling_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, + const resampling_fwd_pd_t *hint_fwd_pd) + : resampling_pd_t(adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , diff_dst_md_(desc_.diff_dst_desc) {} + virtual status_t set_default_params() { if (diff_dst_md()->format_kind == format_kind::any && hint_fwd_pd_) { status_t status = memory_desc_init_by_md_and_dt(diff_dst_md_, @@ -232,6 +232,7 @@ struct resampling_bwd_pd_t : public resampling_pd_t { diff_src_md_, diff_dst_md_.format_desc.blocking); } }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/rnn_pd.hpp b/src/common/rnn_pd.hpp index 857ad2e572e..f18e5aaf7de 100644 --- a/src/common/rnn_pd.hpp +++ b/src/common/rnn_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,6 +39,7 @@ namespace impl { struct rnn_fwd_pd_t; +// NOLINTBEGIN(google-default-arguments) struct rnn_pd_t : public primitive_desc_t { static constexpr auto base_pkind = primitive_kind::rnn; @@ -230,10 +231,10 @@ struct rnn_pd_t : public primitive_desc_t { memory_desc_t ws_md_; - rnn_pd_t(const rnn_desc_t *adesc, const primitive_attr_t *attr, + rnn_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const rnn_fwd_pd_t *hint_fwd_pd) : primitive_desc_t(attr, base_pkind) - , desc_(*adesc) + , desc_(*op_desc_t::to_desc(adesc)) , hint_fwd_pd_(hint_fwd_pd) , src_layer_md_(desc_.src_layer_desc) , src_iter_md_(desc_.src_iter_desc) @@ -245,47 +246,53 @@ struct rnn_pd_t : public primitive_desc_t { , bias_md_(desc_.bias_desc) , dst_layer_md_(desc_.dst_layer_desc) , dst_iter_md_(desc_.dst_iter_desc) - , dst_iter_c_md_(desc_.dst_iter_c_desc) - , ws_md_() {} + , dst_iter_c_md_(desc_.dst_iter_c_desc) {} }; +// NOLINTEND(google-default-arguments) +// NOLINTBEGIN(google-default-arguments) struct rnn_fwd_pd_t : public rnn_pd_t { - typedef rnn_fwd_pd_t base_class; - typedef rnn_fwd_pd_t hint_class; + using base_class = rnn_fwd_pd_t; + using hint_class = rnn_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (arg == DNNL_ARG_SRC_LAYER) return arg_usage_t::input; - if (arg == DNNL_ARG_AUGRU_ATTENTION && with_augru_attention()) - return arg_usage_t::input; + if (arg == DNNL_ARG_AUGRU_ATTENTION) + return with_augru_attention() ? arg_usage_t::input + : arg_usage_t::unused; - if (arg == DNNL_ARG_SRC_ITER && with_src_iter()) - return arg_usage_t::input; + if (arg == DNNL_ARG_SRC_ITER) + return with_src_iter() ? arg_usage_t::input : arg_usage_t::unused; - if (arg == DNNL_ARG_SRC_ITER_C && with_src_iter_c()) - return arg_usage_t::input; + if (arg == DNNL_ARG_SRC_ITER_C) + return with_src_iter_c() ? arg_usage_t::input : arg_usage_t::unused; if (utils::one_of(arg, DNNL_ARG_WEIGHTS_LAYER, DNNL_ARG_WEIGHTS_ITER)) return arg_usage_t::input; - if (arg == DNNL_ARG_WEIGHTS_PEEPHOLE && is_lstm_peephole()) - return arg_usage_t::input; + if (arg == DNNL_ARG_WEIGHTS_PEEPHOLE) + return is_lstm_peephole() ? arg_usage_t::input + : arg_usage_t::unused; - if (arg == DNNL_ARG_WEIGHTS_PROJECTION && is_lstm_projection()) - return arg_usage_t::input; + if (arg == DNNL_ARG_WEIGHTS_PROJECTION) + return is_lstm_projection() ? 
arg_usage_t::input + : arg_usage_t::unused; - if (arg == DNNL_ARG_BIAS && with_bias()) return arg_usage_t::input; + if (arg == DNNL_ARG_BIAS) + return with_bias() ? arg_usage_t::input : arg_usage_t::unused; if (arg == DNNL_ARG_DST_LAYER) return arg_usage_t::output; - if (arg == DNNL_ARG_DST_ITER && with_dst_iter()) - return arg_usage_t::output; + if (arg == DNNL_ARG_DST_ITER) + return with_dst_iter() ? arg_usage_t::output : arg_usage_t::unused; - if (arg == DNNL_ARG_DST_ITER_C && with_dst_iter() && is_lstm()) - return arg_usage_t::output; + if (arg == DNNL_ARG_DST_ITER_C) + return with_dst_iter_c() ? arg_usage_t::output + : arg_usage_t::unused; - if (arg == DNNL_ARG_WORKSPACE && is_training()) - return arg_usage_t::output; + if (arg == DNNL_ARG_WORKSPACE) + return is_training() ? arg_usage_t::output : arg_usage_t::unused; return primitive_desc_t::arg_usage(arg); } @@ -323,14 +330,16 @@ struct rnn_fwd_pd_t : public rnn_pd_t { } protected: - rnn_fwd_pd_t(const rnn_desc_t *adesc, const primitive_attr_t *attr, + rnn_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const rnn_fwd_pd_t *hint_fwd_pd) : rnn_pd_t(adesc, attr, hint_fwd_pd) {} }; +// NOLINTEND(google-default-arguments) +// NOLINTBEGIN(google-default-arguments) struct rnn_bwd_pd_t : public rnn_pd_t { - typedef rnn_bwd_pd_t base_class; - typedef rnn_fwd_pd_t hint_class; + using base_class = rnn_bwd_pd_t; + using hint_class = rnn_fwd_pd_t; arg_usage_t arg_usage(int arg) const override { if (utils::one_of(arg, DNNL_ARG_SRC_LAYER, DNNL_ARG_DST_LAYER, @@ -342,53 +351,52 @@ struct rnn_bwd_pd_t : public rnn_pd_t { DNNL_ARG_DIFF_WEIGHTS_LAYER, DNNL_ARG_DIFF_WEIGHTS_ITER)) return arg_usage_t::output; - if (with_augru_attention()) { - if (arg == DNNL_ARG_AUGRU_ATTENTION) return arg_usage_t::input; - if (arg == DNNL_ARG_DIFF_AUGRU_ATTENTION) - return arg_usage_t::output; - } - - if (is_lstm_peephole()) { - if (arg == DNNL_ARG_WEIGHTS_PEEPHOLE) return arg_usage_t::input; - - if (arg == DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE) - return arg_usage_t::output; - } - - if (is_lstm_projection()) { - if (arg == DNNL_ARG_WEIGHTS_PROJECTION) return arg_usage_t::input; - - if (arg == DNNL_ARG_DIFF_WEIGHTS_PROJECTION) - return arg_usage_t::output; - } - - if (with_bias()) { - if (arg == DNNL_ARG_BIAS) return arg_usage_t::input; - - if (arg == DNNL_ARG_DIFF_BIAS) return arg_usage_t::output; - } - - if (with_src_iter()) { - if (arg == DNNL_ARG_SRC_ITER) return arg_usage_t::input; - - if (arg == DNNL_ARG_DIFF_SRC_ITER) return arg_usage_t::output; - } - - if (with_src_iter_c()) { - if (arg == DNNL_ARG_SRC_ITER_C) return arg_usage_t::input; - - if (arg == DNNL_ARG_DIFF_SRC_ITER_C) return arg_usage_t::output; - } - - if (with_dst_iter() - && utils::one_of( - arg, DNNL_ARG_DST_ITER, DNNL_ARG_DIFF_DST_ITER)) - return arg_usage_t::input; - - if (with_dst_iter_c() - && utils::one_of( - arg, DNNL_ARG_DST_ITER_C, DNNL_ARG_DIFF_DST_ITER_C)) - return arg_usage_t::input; + if (arg == DNNL_ARG_AUGRU_ATTENTION) + return with_augru_attention() ? arg_usage_t::input + : arg_usage_t::unused; + if (arg == DNNL_ARG_DIFF_AUGRU_ATTENTION) + return with_augru_attention() ? arg_usage_t::output + : arg_usage_t::unused; + + if (arg == DNNL_ARG_WEIGHTS_PEEPHOLE) + return is_lstm_peephole() ? arg_usage_t::input + : arg_usage_t::unused; + if (arg == DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE) + return is_lstm_peephole() ? arg_usage_t::output + : arg_usage_t::unused; + + if (arg == DNNL_ARG_WEIGHTS_PROJECTION) + return is_lstm_projection() ? 
arg_usage_t::input + : arg_usage_t::unused; + if (arg == DNNL_ARG_DIFF_WEIGHTS_PROJECTION) + return is_lstm_projection() ? arg_usage_t::output + : arg_usage_t::unused; + + if (arg == DNNL_ARG_BIAS) + return with_bias() ? arg_usage_t::input : arg_usage_t::unused; + if (arg == DNNL_ARG_DIFF_BIAS) + return with_bias() ? arg_usage_t::output : arg_usage_t::unused; + + if (arg == DNNL_ARG_SRC_ITER) + return with_src_iter() ? arg_usage_t::input : arg_usage_t::unused; + if (arg == DNNL_ARG_DIFF_SRC_ITER) + return with_src_iter() ? arg_usage_t::output : arg_usage_t::unused; + + if (arg == DNNL_ARG_SRC_ITER_C) + return with_src_iter_c() ? arg_usage_t::input : arg_usage_t::unused; + if (arg == DNNL_ARG_DIFF_SRC_ITER_C) + return with_src_iter_c() ? arg_usage_t::output + : arg_usage_t::unused; + + if (arg == DNNL_ARG_DST_ITER) + return with_dst_iter() ? arg_usage_t::input : arg_usage_t::unused; + if (arg == DNNL_ARG_DIFF_DST_ITER) + return with_dst_iter() ? arg_usage_t::input : arg_usage_t::unused; + + if (arg == DNNL_ARG_DST_ITER_C) + return with_dst_iter_c() ? arg_usage_t::input : arg_usage_t::unused; + if (arg == DNNL_ARG_DIFF_DST_ITER_C) + return with_dst_iter_c() ? arg_usage_t::input : arg_usage_t::unused; if (arg == DNNL_ARG_WORKSPACE) return arg_usage_t::input; @@ -521,7 +529,7 @@ struct rnn_bwd_pd_t : public rnn_pd_t { memory_desc_t diff_dst_iter_md_; memory_desc_t diff_dst_iter_c_md_; - rnn_bwd_pd_t(const rnn_desc_t *adesc, const primitive_attr_t *attr, + rnn_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const rnn_fwd_pd_t *hint_fwd_pd) : rnn_pd_t(adesc, attr, hint_fwd_pd) , diff_src_layer_md_(desc_.diff_src_layer_desc) @@ -536,6 +544,7 @@ struct rnn_bwd_pd_t : public rnn_pd_t { , diff_dst_iter_md_(desc_.diff_dst_iter_desc) , diff_dst_iter_c_md_(desc_.diff_dst_iter_c_desc) {} }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/scratchpad.hpp b/src/common/scratchpad.hpp index f837b75a28c..133b1ee34e8 100644 --- a/src/common/scratchpad.hpp +++ b/src/common/scratchpad.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2020 Intel Corporation +* Copyright 2017-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ namespace dnnl { namespace impl { struct scratchpad_t { - virtual ~scratchpad_t() {} + virtual ~scratchpad_t() = default; virtual const memory_storage_t *get_memory_storage() const = 0; virtual size_t size() const = 0; }; diff --git a/src/common/sdpa_pd.hpp b/src/common/sdpa_pd.hpp index 9d95612cfd7..39a686abbeb 100644 --- a/src/common/sdpa_pd.hpp +++ b/src/common/sdpa_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,11 +27,6 @@ namespace dnnl { namespace impl { -#define DNNL_ARG_QUERIES DNNL_ARG_SRC_0 -#define DNNL_ARG_KEYS DNNL_ARG_SRC_1 -#define DNNL_ARG_VALUES DNNL_ARG_SRC_2 -#define DNNL_ARG_ATTN_MASK DNNL_ARG_SHIFT - #define VDISPATCH_SDPA(cond, msg, ...) 
\ VCONDCHECK(primitive, create, dispatch, sdpa, (cond), \ status::unimplemented, "%s," msg, this->info(engine), \ @@ -41,11 +36,12 @@ namespace impl { VCHECK(primitive, create, dispatch, sdpa, (f), "%s," msg, \ this->info(engine), ##__VA_ARGS__) +// NOLINTBEGIN(google-default-arguments) struct sdpa_pd_t : public primitive_desc_t { static constexpr auto base_pkind = primitive_kind::sdpa; - typedef sdpa_pd_t base_class; - typedef sdpa_pd_t hint_class; + using base_class = sdpa_pd_t; + using hint_class = sdpa_pd_t; const sdpa_desc_t *desc() const { return &desc_; } const op_desc_t *op_desc() const override { @@ -53,8 +49,15 @@ struct sdpa_pd_t : public primitive_desc_t { } arg_usage_t arg_usage(int arg) const override { + // TODO: this is broken for cases when the user passes quantization + // memories unconditionally but the primitive desc is not set up for + // quantization. if (utils::one_of(arg, DNNL_ARG_QUERIES, DNNL_ARG_KEYS, DNNL_ARG_VALUES, - DNNL_ARG_ATTN_MASK, DNNL_ARG_SCALE)) + DNNL_ARG_ATTN_MASK, DNNL_ARG_SCALE, + DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS, + DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES, + DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS, + DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES)) return arg_usage_t::input; if (arg == DNNL_ARG_DST) return arg_usage_t::output; @@ -94,7 +97,9 @@ struct sdpa_pd_t : public primitive_desc_t { const memory_desc_t *val_md() const { return &desc_.v_desc; } const memory_desc_t *attn_mask_md() const { return &desc_.attn_mask_desc; } - int n_inputs() const override { return 3 + int(with_attn_mask()); } + int n_inputs() const override { + return 3 + int(with_attn_mask()) + int(with_attn_scale()); + } int n_outputs() const override { return 1; } bool with_attn_scale() const { @@ -105,12 +110,81 @@ struct sdpa_pd_t : public primitive_desc_t { return (attn_mask_md()->data_type != data_type::undef); } + /// If true, the attention mask is a causal mask + bool with_causal_mask() const { + return desc_.mask_type == attn_mask_type::top_left + || desc_.mask_type == attn_mask_type::bottom_right; + } + + /// If true, dequantize the K tensor using scaling in the KQ matmul + bool with_key_scales() const { + return (!desc()->kq_scales.has_default_values()); + } + + /// If true, dequantize the V tensor using scaling in the VS matmul + bool with_value_scales() const { + return (!desc()->vs_scales.has_default_values()); + } + + /// If true, dequantize the K tensor with zero points in the KQ matmul + bool with_key_zp() const { + return (!desc()->kq_zero_points.has_default_values(DNNL_ARG_WEIGHTS)); + } + + /// If true, dequantize the V tensor with zero points in the VS matmul + bool with_value_zp() const { + return (!desc()->vs_zero_points.has_default_values(DNNL_ARG_WEIGHTS)); + } + + /// Returns the data type of the scales tensor for the KQ matmul + data_type_t key_scales_dt() const { + return desc()->kq_scales.get_data_type(); + } + + /// Returns the data type of the zero points tensor for the KQ matmul + data_type_t key_zp_dt() const { + return desc()->kq_zero_points.get_data_type(DNNL_ARG_WEIGHTS); + } + + /// Returns the data type of the scales tensor for the VS matmul + data_type_t value_scales_dt() const { + return desc()->vs_scales.get_data_type(); + } + + /// Returns the data type of the zero points tensor for the VS matmul + data_type_t value_zp_dt() const { + return desc()->vs_zero_points.get_data_type(DNNL_ARG_WEIGHTS); + } + + // Returns the group size for the quantization parameters for the KQ matmul + int key_group_size() const { + int out = 0; + if 
(with_key_scales()) + out = scale_group_size(desc()->kq_scales, *key_md()); + else if (with_key_zp()) { + out = zp_group_size(desc()->kq_zero_points, *key_md()); + } + return out; + } + + // Returns the group size for the quantization parameters for the VS matmul + int value_group_size() const { + int out = 0; + if (with_value_scales()) + out = scale_group_size(desc()->vs_scales, *val_md()); + else if (with_value_zp()) { + out = zp_group_size(desc()->vs_zero_points, *val_md()); + } + return out; + } + protected: sdpa_desc_t desc_; - sdpa_pd_t(const sdpa_desc_t *adesc, const primitive_attr_t *attr, + sdpa_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const hint_class *hint_fwd_pd) - : primitive_desc_t(attr, base_pkind), desc_(*adesc) {} + : primitive_desc_t(attr, base_pkind) + , desc_(*op_desc_t::to_desc(adesc)) {} bool set_default_format(memory_desc_t *md) { memory_desc_wrapper mdw(md); @@ -132,7 +206,49 @@ struct sdpa_pd_t : public primitive_desc_t { return ok; } + +private: + static int scale_group_size( + const quant_entry_t &scales, const memory_desc_t &desc) { + dim_t out = utils::array_product(desc.dims, desc.ndims); + const auto mask = scales.get_mask(); + if (scales.has_default_groups()) { + for (int idx : mask_iterator(mask)) { + out /= desc.dims[idx]; + } + } else { + for (int idx : mask_iterator(mask)) { + if (idx < 2) { + out /= desc.dims[idx]; + } else { + out /= (desc.dims[idx] / scales.get_group(idx - 2)); + } + } + } + return static_cast<int>(out); + } + + static int zp_group_size( + const zero_points_t &zp, const memory_desc_t &desc) { + dim_t out = utils::array_product(desc.dims, desc.ndims); + if (zp.get(DNNL_ARG_WEIGHTS).has_default_groups()) { + for (int idx : mask_iterator(zp.get_mask(DNNL_ARG_WEIGHTS))) { + out /= desc.dims[idx]; + } + } else { + for (int idx : mask_iterator(zp.get_mask(DNNL_ARG_WEIGHTS))) { + if (idx < 2) { + out /= desc.dims[idx]; + } else { + out /= (desc.dims[idx] + / zp.get_group(DNNL_ARG_WEIGHTS, idx - 2)); + } + } + } + return static_cast<int>(out); + } }; +// NOLINTEND(google-default-arguments) } // namespace impl } // namespace dnnl diff --git a/src/common/sdpa_test_iface.cpp b/src/common/sdpa_test_iface.cpp new file mode 100644 index 00000000000..a7834210f51 --- /dev/null +++ b/src/common/sdpa_test_iface.cpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License.
+*******************************************************************************/ + +#include "common/c_types_map.hpp" +#include "common/primitive_desc_iface.hpp" +#include "common/sdpa_pd.hpp" +#include "common/sdpa_types.hpp" +#include "common/sdpa_utils.hpp" +#include "opdesc.hpp" + +using dnnl::impl::status_t; +using namespace dnnl::impl; + +dnnl_status_t DNNL_API sdpa_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine, + const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc, + const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc, + const_dnnl_memory_desc_t mask_desc, dnnl_data_type_t scale_dt, + bool invert_scale, dnnl_dim_t kv_head_number, int attn_mask_type, + const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr, + const_dnnl_primitive_attr_t vs_attr) { + CHECK(sdpa_desc_check(query_desc, key_desc, value_desc, dst_desc, mask_desc, + engine, attr, kq_attr, vs_attr)); + CHECK(sdpa_attr_check( + query_desc, key_desc, value_desc, engine, attr, kq_attr, vs_attr)); + + dnnl::impl::sdpa_desc_t sdpa_desc = dnnl::impl::create_sdpa_desc(query_desc, + key_desc, value_desc, dst_desc, mask_desc, + (dnnl::impl::data_type_t)scale_dt, invert_scale, kv_head_number, + static_cast<dnnl::impl::attn_mask_type_t>(attn_mask_type), kq_attr, vs_attr); + return dnnl::impl::primitive_desc_create(primitive_desc_iface, engine, + (const dnnl::impl::op_desc_t *)&sdpa_desc, nullptr, attr); +} diff --git a/src/common/sdpa_types.hpp b/src/common/sdpa_types.hpp index 03fc9f67aaa..8b203a9e956 100644 --- a/src/common/sdpa_types.hpp +++ b/src/common/sdpa_types.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,28 +17,75 @@ #ifndef COMMON_SDPA_TYPES_HPP #define COMMON_SDPA_TYPES_HPP -#include +#include "oneapi/dnnl/dnnl_types.h" + #include "common/c_types_map.hpp" #include "common/memory_desc.hpp" +#include "common/primitive_attr_quant.hpp" + +#include <memory> + namespace dnnl { namespace impl { +#define DNNL_ARG_QUERIES DNNL_ARG_SRC_0 +#define DNNL_ARG_KEYS DNNL_ARG_SRC_1 +#define DNNL_ARG_VALUES DNNL_ARG_SRC_2 +#define DNNL_ARG_ATTN_MASK DNNL_ARG_SHIFT + +// NOLINTBEGIN(modernize-use-using) +/// Types of attention mask +typedef enum { + dnnl_attn_mask_undef = 0, + /// explicit attention masks defined in a buffer + dnnl_attn_mask_buffer = 1, + + /// causal mask with the diagonal starting from the top left hand side of + /// the mask tensor + dnnl_attn_mask_top_left = 2, + + /// causal mask with the diagonal starting from the bottom right hand side + /// of the mask tensor + dnnl_attn_mask_bottom_right = 3, +} dnnl_attn_mask_type_t; +// NOLINTEND(modernize-use-using) + +using attn_mask_type_t = dnnl_attn_mask_type_t; +namespace attn_mask_type { +const attn_mask_type_t undef = dnnl_attn_mask_undef; +const attn_mask_type_t buffer = dnnl_attn_mask_buffer; +const attn_mask_type_t top_left = dnnl_attn_mask_top_left; +const attn_mask_type_t bottom_right = dnnl_attn_mask_bottom_right; +} // namespace attn_mask_type + // A descriptor for a scaled dot product attention (SDPA) operation. -struct sdpa_desc_t { - // The kind of primitive. Used for self identifying the primitive - // descriptor. Must be sdpa. 
- dnnl_primitive_kind_t primitive_kind; +struct sdpa_desc_t : public op_desc_t { + sdpa_desc_t() : op_desc_t(primitive_kind::sdpa) {} + + std::unique_ptr<op_desc_t> clone() const override { + return utils::make_unique<sdpa_desc_t>(*this); + } + memory_desc_t q_desc; /* queries */ memory_desc_t k_desc; /* keys */ memory_desc_t v_desc; /* values */ + + // primitive_attr_t can't be used because of deleted copy-ctor, but desc_t + // must be copyable. + quant_entry_t kq_scales; + zero_points_t kq_zero_points; + quant_entry_t vs_scales; + zero_points_t vs_zero_points; + memory_desc_t dst_desc; memory_desc_t attn_mask_desc; - data_type_t scale_dt; + data_type_t scale_dt {}; // invert_scale = false: multiply by scale // invert_scale = true: divide by scale - bool invert_scale; - dim_t kv_head_number; + bool invert_scale {}; + dim_t kv_head_number {}; + + attn_mask_type_t mask_type = attn_mask_type::undef; // Number of queries. dnnl_dim_t queries() const { return q_desc.dims[q_desc.ndims - 2]; } diff --git a/src/common/sdpa_utils.hpp b/src/common/sdpa_utils.hpp index ccba17a2081..c72d6f0a0f4 100644 --- a/src/common/sdpa_utils.hpp +++ b/src/common/sdpa_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,20 +28,135 @@ namespace dnnl { namespace impl { +#define VCHECK_SDPA(f, msg, ...) \ + VCHECK(primitive, create, check, sdpa, (f), msg, ##__VA_ARGS__); + +#define VCHECK_SDPA_COND(cond, msg, ...) \ + VCONDCHECK(primitive, create, check, sdpa, (cond), \ + status::invalid_arguments, msg, ##__VA_ARGS__); + +#define VCHECK_SDPA_ATTR_TYPE( \ + variable_check, variable, attribute_member_name, expected_types) \ + VCONDCHECK(primitive, create, check, sdpa, (variable_check), \ + status::invalid_arguments, VERBOSE_INVALID_DATATYPE, \ + format_verbose_string(#variable attribute_member_name \ + "(%s). must be " expected_types, \ + attr2str(variable).c_str()) \ + .c_str()) + +#define VCHECK_SDPA_UNIMPL(cond, msg, ...) \ + VCONDCHECK(primitive, create, check, sdpa, (cond), status::unimplemented, \ + msg, ##__VA_ARGS__); + +static inline status_t sdpa_desc_check(const memory_desc_t *q_desc, + const memory_desc_t *k_desc, const memory_desc_t *v_desc, + const memory_desc_t *dst_desc, const memory_desc_t *attn_mask_md, + const engine_t *engine, const primitive_attr_t *attr, + const primitive_attr_t *kq_attr, const primitive_attr_t *vs_attr) { + int ndims = dst_desc->ndims; + int r = ndims - 2, c = ndims - 1; + VCHECK_SDPA_COND(utils::everyone_is(ndims, q_desc->ndims, k_desc->ndims, + v_desc->ndims), + "number of dimensions have to match. 
expected: %d q: %d k: %d v: " + "%d", + ndims, q_desc->ndims, k_desc->ndims, v_desc->ndims); + + VCHECK_SDPA_COND(q_desc->dims[c] == k_desc->dims[r], + "q_desc->dims[%d](%s) must match k_desc->dims[%d](%s)", c, + md2dim_str(q_desc).c_str(), r, md2dim_str(k_desc).c_str()); + VCHECK_SDPA_COND(k_desc->dims[c] == v_desc->dims[r], + "k_desc->dims[%d](%s) must match v_desc->dims[%d](%s)", c, + md2dim_str(k_desc).c_str(), r, md2dim_str(v_desc).c_str()); + VCHECK_SDPA_COND(dst_desc->dims[r] == q_desc->dims[r], + "dst_desc->dims[%d](%s) == q_desc->dims[%d](%s)", r, + md2dim_str(dst_desc).c_str(), r, md2dim_str(q_desc).c_str()); + VCHECK_SDPA_COND(dst_desc->dims[c] == v_desc->dims[c], + "dst_desc->dims[%d](%s) == v_desc->dims[%d](%s)", c, + md2dim_str(dst_desc).c_str(), c, md2dim_str(v_desc).c_str()); + + return status::success; +} + +static inline status_t sdpa_attr_check(const memory_desc_t *q_desc, + const memory_desc_t *k_desc, const memory_desc_t *v_desc, + const engine_t *engine, const primitive_attr_t *attr, + const primitive_attr_t *kq_attr, const primitive_attr_t *vs_attr) { + using smask_t = primitive_attr_t::skip_mask_t; + + if (utils::everyone_is(nullptr, attr, kq_attr, vs_attr)) + return status::success; + if (attr && attr->has_default_values() && kq_attr + && kq_attr->has_default_values() && vs_attr + && vs_attr->has_default_values()) { + return status::success; + } + + using namespace dnnl::impl::data_type; + if (kq_attr && !kq_attr->has_default_values()) { + const auto &sc = kq_attr->scales_; + const auto &zp = kq_attr->zero_points_; + if (!sc.has_default_values()) { + const auto &scale_dt = sc.get_data_type(DNNL_ARG_WEIGHTS); + VCHECK_SDPA_ATTR_TYPE(utils::one_of(scale_dt, f16, f32), kq_attr, + "scales", "f16 or f32"); + } + if (!zp.has_default_values()) { + const auto &zp_dt = zp.get_data_type(DNNL_ARG_WEIGHTS); + VCHECK_SDPA_ATTR_TYPE(utils::one_of(zp_dt, s4, u4, u8, s8, s32), + kq_attr, "zero_points", "u4, s4, u8, s8, or s32"); + } + } + + if (vs_attr && !vs_attr->has_default_values()) { + const auto &sc = vs_attr->scales_; + const auto &zp = vs_attr->zero_points_; + + if (!sc.has_default_values()) { + const auto &scale_dt = sc.get_data_type(DNNL_ARG_WEIGHTS); + VCHECK_SDPA_ATTR_TYPE(utils::one_of(scale_dt, f16, f32), vs_attr, + "scales", "f16 or f32"); + } + if (!zp.has_default_values()) { + const auto &zp_dt = zp.get_data_type(DNNL_ARG_WEIGHTS); + VCHECK_SDPA_ATTR_TYPE(utils::one_of(zp_dt, s4, u4, u8, s8, s32), + vs_attr, "zero_points", "u4, s4, u8, s8, or s32"); + } + } + + if (attr) { + smask_t attr_mask = smask_t::none; + VCHECK_SDPA_UNIMPL( + attr->has_default_values(attr_mask), VERBOSE_UNSUPPORTED_ATTR); + } + + return status::success; +} + static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, const memory_desc_t *k_md, const memory_desc_t *v_md, const memory_desc_t *dst_md, const memory_desc_t *attn_mask_md, - data_type_t scale_dt, dim_t kv_head_number, bool invert_scale = false) { + data_type_t scale_dt, bool invert_scale, dim_t kv_head_number, + attn_mask_type_t attn_mask_type, const primitive_attr_t *kq_attr, + const primitive_attr_t *vs_attr) { auto sdpa_desc = sdpa_desc_t(); sdpa_desc.primitive_kind = primitive_kind::sdpa; sdpa_desc.q_desc = *q_md; sdpa_desc.k_desc = *k_md; + if (kq_attr) { + sdpa_desc.kq_scales = kq_attr->scales_.get(DNNL_ARG_WEIGHTS); + sdpa_desc.kq_zero_points = kq_attr->zero_points_; + } + if (vs_attr) { + sdpa_desc.vs_scales = vs_attr->scales_.get(DNNL_ARG_WEIGHTS); + sdpa_desc.vs_zero_points = vs_attr->zero_points_; + } 
sdpa_desc.v_desc = *v_md; sdpa_desc.dst_desc = *dst_md; if (attn_mask_md) sdpa_desc.attn_mask_desc = *attn_mask_md; sdpa_desc.scale_dt = scale_dt; sdpa_desc.invert_scale = invert_scale; sdpa_desc.kv_head_number = kv_head_number; + sdpa_desc.mask_type = attn_mask_type; return sdpa_desc; } @@ -50,26 +165,25 @@ static inline status_t create_sdpa_pd( const memory_desc_t *q_md, const memory_desc_t *k_md, const memory_desc_t *v_md, const memory_desc_t *dst_md, const memory_desc_t *attn_mask_md, data_type_t scale_dt, - bool invert_scale, const primitive_attr_t *attr, dim_t kv_head_number) { - auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, attn_mask_md, - scale_dt, kv_head_number, invert_scale); + bool invert_scale, dim_t kv_head_number, + attn_mask_type_t attn_mask_type, const primitive_attr_t *attr, + const primitive_attr_t *kq_attr = nullptr, + const primitive_attr_t *vs_attr = nullptr) { + CHECK(sdpa_attr_check(q_md, k_md, v_md, engine, attr, kq_attr, vs_attr)); + CHECK(sdpa_desc_check(q_md, k_md, v_md, dst_md, attn_mask_md, engine, attr, + kq_attr, vs_attr)); - int ndims = dst_md->ndims; - int r = ndims - 2, c = ndims - 1; - if (!utils::everyone_is(ndims, q_md->ndims, k_md->ndims, v_md->ndims)) - return status::invalid_arguments; - if (q_md->dims[c] != k_md->dims[r]) return status::invalid_arguments; - if (k_md->dims[c] != v_md->dims[r]) return status::invalid_arguments; - if (dst_md->dims[r] != q_md->dims[r] || dst_md->dims[c] != v_md->dims[c]) - return status::invalid_arguments; + auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, attn_mask_md, + scale_dt, invert_scale, kv_head_number, attn_mask_type, kq_attr, + vs_attr); - primitive_attr_t sdpa_attr = *attr; + primitive_attr_t sdpa_attr = attr ? *attr : default_attr(); primitive_desc_iterator_t it( engine, (op_desc_t *)&sdpa_desc, &sdpa_attr, nullptr); sdpa_pd_ = *(++it); - if (!sdpa_pd_) return status::unimplemented; + VCHECK_SDPA_COND(sdpa_pd_, "failed to create the SDPA primitive"); return status::success; } diff --git a/src/common/serialization.cpp b/src/common/serialization.cpp deleted file mode 100644 index 035733db406..00000000000 --- a/src/common/serialization.cpp +++ /dev/null @@ -1,618 +0,0 @@ -/******************************************************************************* -* Copyright 2021-2024 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "common/serialization.hpp" -#include "common/type_helpers.hpp" -#include "common/utils.hpp" - -namespace dnnl { -namespace impl { -namespace serialization { - -status_t serialize_desc( - serialization_stream_t &sstream, const op_desc_t *op_desc) { -#define CASE(pkind) \ - case primitive_kind::pkind: \ - serialize_desc(sstream, *(const pkind##_desc_t *)op_desc); \ - break; - - switch ((int)op_desc->kind) { - CASE(batch_normalization) - CASE(binary) - CASE(concat) - CASE(convolution) - CASE(deconvolution) - CASE(eltwise) - CASE(gemm) - CASE(group_normalization) - CASE(inner_product) - CASE(layer_normalization) - CASE(lrn) - CASE(matmul) - CASE(pooling) - CASE(prelu) - CASE(reduction) - CASE(reorder) - CASE(resampling) - CASE(rnn) - CASE(sdpa) - CASE(shuffle) - CASE(softmax) - CASE(sum) - default: return status::invalid_arguments; - } -#undef CASE - return status::success; -} - -void serialize_md(serialization_stream_t &sstream, const memory_desc_t &md) { - sstream.write(&md.ndims); - sstream.write(md.dims, md.ndims); - sstream.write(&md.data_type); - sstream.write(md.padded_dims, md.ndims); - sstream.write(md.padded_offsets, md.ndims); - sstream.write(&md.offset0); - sstream.write(&md.format_kind); - // format desc - switch ((int)md.format_kind) { - case format_kind::undef: - case format_kind::any: break; - case format_kind::blocked: - sstream.write(md.format_desc.blocking.strides, md.ndims); - sstream.write(&md.format_desc.blocking.inner_nblks); - sstream.write(md.format_desc.blocking.inner_blks, - md.format_desc.blocking.inner_nblks); - sstream.write(md.format_desc.blocking.inner_idxs, - md.format_desc.blocking.inner_nblks); - break; - case format_kind::wino: - sstream.write(&md.format_desc.wino_desc.wino_format); - sstream.write(&md.format_desc.wino_desc.r); - sstream.write(&md.format_desc.wino_desc.alpha); - sstream.write(&md.format_desc.wino_desc.ic); - sstream.write(&md.format_desc.wino_desc.oc); - sstream.write(&md.format_desc.wino_desc.ic_block); - sstream.write(&md.format_desc.wino_desc.oc_block); - sstream.write(&md.format_desc.wino_desc.ic2_block); - sstream.write(&md.format_desc.wino_desc.oc2_block); - sstream.write(&md.format_desc.wino_desc.adj_scale); - sstream.write(&md.format_desc.wino_desc.size); - break; - case format_kind::rnn_packed: - sstream.write(&md.format_desc.rnn_packed_desc.format); - sstream.write(&md.format_desc.rnn_packed_desc.n_parts); - sstream.write(&md.format_desc.rnn_packed_desc.n); - sstream.write(&md.format_desc.rnn_packed_desc.ldb); - { - int n_parts = md.format_desc.rnn_packed_desc.n_parts; - sstream.write(md.format_desc.rnn_packed_desc.parts, n_parts); - sstream.write( - md.format_desc.rnn_packed_desc.part_pack_size, n_parts); - sstream.write( - md.format_desc.rnn_packed_desc.pack_part, n_parts); - } - sstream.write(&md.format_desc.rnn_packed_desc.offset_compensation); - sstream.write(&md.format_desc.rnn_packed_desc.size); - break; - default: assert(!"unknown format_kind"); - } - - if (md.extra.flags != dnnl_memory_extra_flag_none) { - sstream.write(&md.extra.flags); - if ((md.extra.flags - & (dnnl_memory_extra_flag_compensation_conv_s8s8 - | dnnl_memory_extra_flag_rnn_u8s8_compensation)) - && !types::extra_flag_rnn_s8s8_compensation_is_set( - md.extra.flags)) { - sstream.write(&md.extra.compensation_mask); - } - - if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) { - sstream.write(&md.extra.scale_adjust); - } - - if (md.extra.flags - & 
dnnl_memory_extra_flag_compensation_conv_asymmetric_src) { - sstream.write(&md.extra.asymm_compensation_mask); - } - } -} - -void serialize_post_ops( - serialization_stream_t &sstream, const post_ops_t &post_ops) { - // post_ops: entry[:] - for (int i = 0; i < post_ops.len(); i++) { - const auto &entry = post_ops.entry_[i]; - switch (entry.kind) { - case primitive_kind::eltwise: - sstream.write(&entry.eltwise.alg); - sstream.write(&entry.eltwise.scale); - sstream.write(&entry.eltwise.alpha); - sstream.write(&entry.eltwise.beta); - break; - case primitive_kind::sum: - sstream.write(&entry.sum.scale); - sstream.write(&entry.sum.zero_point); - sstream.write(&entry.sum.dt); - break; - case primitive_kind::convolution: - sstream.write(&entry.depthwise_conv.kernel); - sstream.write(&entry.depthwise_conv.stride); - sstream.write(&entry.depthwise_conv.padding); - sstream.write(&entry.depthwise_conv.wei_dt); - sstream.write(&entry.depthwise_conv.bias_dt); - sstream.write(&entry.depthwise_conv.dst_dt); - break; - case primitive_kind::binary: - sstream.write(&entry.binary.alg); - serialize_md(sstream, entry.binary.user_src1_desc); - break; - case primitive_kind::prelu: sstream.write(&entry.prelu.mask); break; - default: assert(!"unknown post_op"); - } - } -} - -void serialize_attr( - serialization_stream_t &sstream, const primitive_attr_t &attr) { - // scratchpad_mode - sstream.write(&attr.scratchpad_mode_); - // fpmath_mode - sstream.write(&attr.fpmath_.mode_); - sstream.write(&attr.fpmath_.apply_to_int_); - // deterministic - sstream.write(&attr.deterministic_); - // acc_mode - sstream.write(&attr.acc_mode_); - - if (!attr.output_scales_.has_default_values()) { - // output_scales: mask - sstream.write(&attr.output_scales_.mask_); - } else if (!attr.scales_.has_default_values()) { - sstream.write("scale:"); - // go through scales for all arguments - for (const auto &p : attr.scales_.scales_) { - // scales: arg - sstream.write(&p.first); - // scales: mask - sstream.write(&p.second.mask_); - // scales: groups - const int ndims = p.second.ndims_; - sstream.write(&ndims); - if (ndims > 0) sstream.write(p.second.group_dims_, ndims); - // scales: data type - sstream.write(&p.second.data_type_); - } - } - // zero_points - if (!attr.zero_points_.has_default_values()) sstream.write("zp:"); - for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) - if (!attr.zero_points_.has_default_values(arg)) { - const auto &zps = attr.zero_points_; - // zero_points: arg - sstream.write(&arg); - int mask = 0; - data_type_t dt = data_type::s32; - zps.get(arg, &mask, &dt); - // zero_points: mask - sstream.write(&mask); - // zero points: groups - const int ndims = zps.get_groups_ndims(arg); - sstream.write(&ndims); - if (ndims > 0) sstream.write(zps.get_groups(arg), ndims); - // zero_points: data type - sstream.write(&dt); - } - - // Rounding modes - if (!attr.rounding_mode_.has_default_values()) sstream.write("rm:"); - for (const auto &e : attr.rounding_mode_.rounding_modes_map_) { - if (!attr.rounding_mode_.has_default_values(e.first)) { - sstream.write(&e.first); - sstream.write(&e.second); - } - } - - if (!attr.dropout_.has_default_values()) { - sstream.write("dropout:"); - serialize_md(sstream, attr.dropout_.user_dropout_desc_); - } - - serialize_post_ops(sstream, attr.post_ops_); - - // rnn_data_qparams: scale, shift - sstream.write(&attr.rnn_data_qparams_.scale_); - sstream.write(&attr.rnn_data_qparams_.shift_); - if (!attr.rnn_weights_qparams_.has_default_values()) { - // rnn_weights_qparams: mask - 
sstream.write(&attr.rnn_weights_qparams_.mask_); - // rnn_weights_qparams: count - sstream.write(&attr.rnn_weights_qparams_.count_); - // rnn_weights_qparams: scales[:] - sstream.write(attr.rnn_weights_qparams_.scales_, - attr.rnn_weights_qparams_.count_); - } - if (attr.gpu_attr_) { - attr.gpu_attr_->serialize(sstream); - } else { - int zero = 0; - sstream.write(&zero); - } -} - -void serialize_desc( - serialization_stream_t &sstream, const concat_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - // Memory descriptors - serialize_md(sstream, *desc.dst_md); - // N - sstream.write(&desc.n); - // Concat dimension - sstream.write(&desc.concat_dimension); - // Array of mds - for (int i = 0; i < desc.n; i++) - serialize_md(sstream, *desc.src_mds[i]); -} - -void serialize_desc(serialization_stream_t &sstream, - const batch_normalization_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.prop_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.dst_desc); - serialize_md(sstream, desc.diff_src_desc); - serialize_md(sstream, desc.diff_dst_desc); - serialize_md(sstream, desc.scaleshift_desc); - serialize_md(sstream, desc.diff_scaleshift_desc); - serialize_md(sstream, desc.stat_desc); - // Epsilon - sstream.write(&desc.batch_norm_epsilon); - // Flags - sstream.write(&desc.flags); -} - -void serialize_desc( - serialization_stream_t &sstream, const binary_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.alg_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc[0]); - serialize_md(sstream, desc.src_desc[1]); - serialize_md(sstream, desc.dst_desc); -} - -// (De-)Convolution -void serialize_desc( - serialization_stream_t &sstream, const convolution_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.prop_kind); - sstream.write(&desc.alg_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.diff_src_desc); - serialize_md(sstream, desc.weights_desc); - serialize_md(sstream, desc.diff_weights_desc); - serialize_md(sstream, desc.bias_desc); - serialize_md(sstream, desc.diff_bias_desc); - serialize_md(sstream, desc.dst_desc); - serialize_md(sstream, desc.diff_dst_desc); - // Strides, dilates, padding - sstream.write(desc.strides, DNNL_MAX_NDIMS); - sstream.write(desc.dilates, DNNL_MAX_NDIMS); - sstream.write(desc.padding[0], DNNL_MAX_NDIMS); - sstream.write(desc.padding[1], DNNL_MAX_NDIMS); - // Accumulator type - sstream.write(&desc.accum_data_type); - // Internal member - sstream.write(&desc.use_inversion); -} - -// Eltwise -void serialize_desc( - serialization_stream_t &sstream, const eltwise_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.prop_kind); - sstream.write(&desc.alg_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.dst_desc); - serialize_md(sstream, desc.diff_src_desc); - serialize_md(sstream, desc.diff_dst_desc); - // Alpha, beta - sstream.write(&desc.alpha); - sstream.write(&desc.beta); -} - -void serialize_desc(serialization_stream_t &sstream, const gemm_desc_t &desc) { - // Kind - sstream.write(&desc.primitive_kind); - serialize_md(sstream, desc.a_desc); - serialize_md(sstream, desc.b_desc); - serialize_md(sstream, desc.c_desc); - serialize_md(sstream, desc.bias_desc); - // Accumulator type - sstream.write(&desc.acc_type); - sstream.write(&desc.sum_ab); - 
sstream.write(&desc.sum_ab_type); -} - -void serialize_desc(serialization_stream_t &sstream, - const group_normalization_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.prop_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.dst_desc); - serialize_md(sstream, desc.diff_src_desc); - serialize_md(sstream, desc.diff_dst_desc); - serialize_md(sstream, desc.scaleshift_desc); - serialize_md(sstream, desc.diff_scaleshift_desc); - serialize_md(sstream, desc.stat_desc); - // Groups - sstream.write(&desc.groups); - // Epsilon - sstream.write(&desc.group_norm_epsilon); - // Flags - sstream.write(&desc.flags); -} - -void serialize_desc( - serialization_stream_t &sstream, const inner_product_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.prop_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.diff_src_desc); - serialize_md(sstream, desc.weights_desc); - serialize_md(sstream, desc.diff_weights_desc); - serialize_md(sstream, desc.bias_desc); - serialize_md(sstream, desc.diff_bias_desc); - serialize_md(sstream, desc.dst_desc); - serialize_md(sstream, desc.diff_dst_desc); - // Accumulator type - sstream.write(&desc.accum_data_type); -} - -void serialize_desc(serialization_stream_t &sstream, - const layer_normalization_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.prop_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.diff_src_desc); - serialize_md(sstream, desc.data_scaleshift_desc); - serialize_md(sstream, desc.diff_data_scaleshift_desc); - serialize_md(sstream, desc.dst_desc); - serialize_md(sstream, desc.diff_dst_desc); - serialize_md(sstream, desc.stat_desc); - // Epsilon - sstream.write(&desc.layer_norm_epsilon); - // Flags - sstream.write(&desc.flags); -} - -void serialize_desc(serialization_stream_t &sstream, const lrn_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.prop_kind); - sstream.write(&desc.alg_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.dst_desc); - serialize_md(sstream, desc.diff_src_desc); - serialize_md(sstream, desc.diff_dst_desc); - // Local size - sstream.write(&desc.local_size); - // Alpha, beta - sstream.write(&desc.lrn_alpha); - sstream.write(&desc.lrn_beta); - // k - sstream.write(&desc.lrn_k); -} - -void serialize_desc( - serialization_stream_t &sstream, const matmul_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.weights_desc); - serialize_md(sstream, desc.bias_desc); - serialize_md(sstream, desc.dst_desc); - // Accumulator type - sstream.write(&desc.accum_data_type); -} - -void serialize_desc( - serialization_stream_t &sstream, const pooling_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.prop_kind); - sstream.write(&desc.alg_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.diff_src_desc); - serialize_md(sstream, desc.dst_desc); - serialize_md(sstream, desc.diff_dst_desc); - // Strides, dilates, padding - sstream.write(desc.strides, DNNL_MAX_NDIMS); - sstream.write(desc.kernel, DNNL_MAX_NDIMS); - sstream.write(desc.padding[0], DNNL_MAX_NDIMS); - sstream.write(desc.padding[1], DNNL_MAX_NDIMS); - 
sstream.write(desc.dilation, DNNL_MAX_NDIMS); - // Accumulator type - sstream.write(&desc.accum_data_type); -} - -void serialize_desc(serialization_stream_t &sstream, const prelu_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.prop_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.weights_desc); - serialize_md(sstream, desc.dst_desc); - serialize_md(sstream, desc.diff_src_desc); - serialize_md(sstream, desc.diff_weights_desc); - serialize_md(sstream, desc.diff_dst_desc); -} - -void serialize_desc( - serialization_stream_t &sstream, const reduction_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.alg_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.dst_desc); - // P, eps - sstream.write(&desc.p); - sstream.write(&desc.eps); -} - -void serialize_desc( - serialization_stream_t &sstream, const reorder_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - // Memory descriptors - serialize_md(sstream, *desc.src_md); - serialize_md(sstream, *desc.dst_md); - // Kinds of source and destination engines - sstream.write(&desc.src_engine_kind); - sstream.write(&desc.dst_engine_kind); - sstream.write(&desc.is_cross_engine); -} - -void serialize_desc( - serialization_stream_t &sstream, const resampling_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.alg_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.diff_src_desc); - serialize_md(sstream, desc.dst_desc); - serialize_md(sstream, desc.diff_dst_desc); - // Factors - sstream.write(desc.factors, DNNL_MAX_NDIMS); -} - -void serialize_desc(serialization_stream_t &sstream, const rnn_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.prop_kind); - sstream.write(&desc.cell_kind); - sstream.write(&desc.direction); - // Memory descriptors - serialize_md(sstream, desc.src_layer_desc); - serialize_md(sstream, desc.src_iter_desc); - serialize_md(sstream, desc.src_iter_c_desc); - serialize_md(sstream, desc.weights_layer_desc); - serialize_md(sstream, desc.weights_iter_desc); - serialize_md(sstream, desc.bias_desc); - serialize_md(sstream, desc.dst_layer_desc); - serialize_md(sstream, desc.dst_iter_desc); - serialize_md(sstream, desc.dst_iter_c_desc); - serialize_md(sstream, desc.weights_peephole_desc); - serialize_md(sstream, desc.weights_projection_desc); - serialize_md(sstream, desc.diff_src_layer_desc); - serialize_md(sstream, desc.diff_src_iter_desc); - serialize_md(sstream, desc.diff_src_iter_c_desc); - serialize_md(sstream, desc.diff_weights_layer_desc); - serialize_md(sstream, desc.diff_weights_iter_desc); - serialize_md(sstream, desc.diff_bias_desc); - serialize_md(sstream, desc.diff_dst_layer_desc); - serialize_md(sstream, desc.diff_dst_iter_desc); - serialize_md(sstream, desc.diff_dst_iter_c_desc); - serialize_md(sstream, desc.diff_weights_peephole_desc); - serialize_md(sstream, desc.diff_weights_projection_desc); - // Flags - sstream.write(&desc.flags); - // Activation kind - sstream.write(&desc.activation_kind); - // Alpha, beta - sstream.write(&desc.alpha); - sstream.write(&desc.beta); -} - -// Shuffle -void serialize_desc( - serialization_stream_t &sstream, const shuffle_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.prop_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - 
serialize_md(sstream, desc.dst_desc); - // Axis - sstream.write(&desc.axis); - // Groupe size - sstream.write(&desc.group_size); -} - -void serialize_desc( - serialization_stream_t &sstream, const softmax_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - sstream.write(&desc.prop_kind); - sstream.write(&desc.alg_kind); - // Memory descriptors - serialize_md(sstream, desc.src_desc); - serialize_md(sstream, desc.diff_src_desc); - serialize_md(sstream, desc.dst_desc); - serialize_md(sstream, desc.diff_dst_desc); - // Axis - sstream.write(&desc.softmax_axis); -} - -void serialize_desc(serialization_stream_t &sstream, const sum_desc_t &desc) { - // Kinds - sstream.write(&desc.primitive_kind); - // Memory descriptors - serialize_md(sstream, *desc.dst_md); - // N - sstream.write(&desc.n); - // Scales - sstream.write(desc.scales, desc.n); - // Array of mds - for (int i = 0; i < desc.n; i++) - serialize_md(sstream, *desc.src_mds[i]); -} - -void serialize_desc(serialization_stream_t &sstream, const sdpa_desc_t &desc) { - // Kind - sstream.write(&desc.primitive_kind); - serialize_md(sstream, desc.q_desc); - serialize_md(sstream, desc.k_desc); - serialize_md(sstream, desc.v_desc); - serialize_md(sstream, desc.dst_desc); - serialize_md(sstream, desc.attn_mask_desc); - sstream.write(&desc.scale_dt); - sstream.write(&desc.invert_scale); -} - -} // namespace serialization -} // namespace impl -} // namespace dnnl diff --git a/src/common/serialization.hpp b/src/common/serialization.hpp index afd4ffba136..f575b140979 100644 --- a/src/common/serialization.hpp +++ b/src/common/serialization.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
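The hunk below replaces the old write-only serialization namespace with a self-contained byte stream: serialization_stream_t::append() packs trivially-serialized values (or defers to a type's own serialize() member), get_hash() hashes the accumulated bytes, and deserializer_t::pop() reads values back in order. A minimal usage sketch, illustrative only and built solely from members visible in this patch (the struct and function names here are hypothetical):

    #include <cassert>
    #include <cstdint>

    #include "common/serialization.hpp"

    using namespace dnnl::impl;

    // A POD config can opt into serialize()/deserialize() through the CRTP
    // helper trivially_serializable_t (also introduced below). Integer-only
    // members keep has_unique_object_representations satisfied in C++17.
    struct conf_t : trivially_serializable_t<conf_t> {
        int32_t simd = 16;
        int32_t unroll = 4;
    };

    void serialization_example() {
        // Pack two trivially-serialized values; append() is variadic.
        serialization_stream_t s;
        s.append(int32_t(42), 2.5f);

        // Order-sensitive hash of the raw bytes, usable as a cache key.
        size_t key = s.get_hash();
        (void)key;

        // Read the values back in append order.
        deserializer_t d(s);
        int32_t i = 0;
        float f = 0.f;
        d.pop(i);
        d.pop(f);
        assert(i == 42 && f == 2.5f);

        // Round-trip a whole struct through the CRTP helper.
        conf_t c0;
        conf_t c1 = conf_t::deserialize(c0.serialize());
        assert(c1.simd == c0.simd && c1.unroll == c0.unroll);
    }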
@@ -17,59 +17,272 @@
 #ifndef COMMON_SERIALIZATION_HPP
 #define COMMON_SERIALIZATION_HPP
 
-#include "common/c_types_map.hpp"
-#include "common/primitive_attr.hpp"
-#include "common/serialization_stream.hpp"
-#include "common/type_helpers.hpp"
-#include "oneapi/dnnl/dnnl.h"
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cstring>
+#include <iomanip>
+#include <sstream>
+#include <vector>
+
+#include "common/utils.hpp"
 
 namespace dnnl {
 namespace impl {
-namespace serialization {
-
-void serialize_post_ops(
-        serialization_stream_t &sstream, const post_ops_t &post_ops);
-void serialize_attr(
-        serialization_stream_t &sstream, const primitive_attr_t &attr);
-void serialize_md(serialization_stream_t &sstream, const memory_desc_t &md);
-void serialize_desc(serialization_stream_t &sstream, const concat_desc_t &desc);
-void serialize_desc(serialization_stream_t &sstream,
-        const batch_normalization_desc_t &desc);
-void serialize_desc(serialization_stream_t &sstream, const binary_desc_t &desc);
-void serialize_desc(
-        serialization_stream_t &sstream, const convolution_desc_t &desc);
-void serialize_desc(
-        serialization_stream_t &sstream, const eltwise_desc_t &desc);
-void serialize_desc(serialization_stream_t &sstream, const gemm_desc_t &desc);
-void serialize_desc(serialization_stream_t &sstream,
-        const group_normalization_desc_t &desc);
-void serialize_desc(
-        serialization_stream_t &sstream, const inner_product_desc_t &desc);
-void serialize_desc(serialization_stream_t &sstream,
-        const layer_normalization_desc_t &desc);
-void serialize_desc(serialization_stream_t &sstream, const lrn_desc_t &desc);
-void serialize_desc(serialization_stream_t &sstream, const matmul_desc_t &desc);
-void serialize_desc(
-        serialization_stream_t &sstream, const pooling_desc_t &desc);
-void serialize_desc(serialization_stream_t &sstream, const prelu_desc_t &desc);
-void serialize_desc(
-        serialization_stream_t &sstream, const reduction_desc_t &desc);
-void serialize_desc(
-        serialization_stream_t &sstream, const reorder_desc_t &desc);
-void serialize_desc(
-        serialization_stream_t &sstream, const resampling_desc_t &desc);
-void serialize_desc(serialization_stream_t &sstream, const rnn_desc_t &desc);
-void serialize_desc(serialization_stream_t &sstream, const sdpa_desc_t &desc);
-void serialize_desc(
-        serialization_stream_t &sstream, const shuffle_desc_t &desc);
-void serialize_desc(
-        serialization_stream_t &sstream, const softmax_desc_t &desc);
-void serialize_desc(serialization_stream_t &sstream, const sum_desc_t &desc);
-
-status_t serialize_desc(
-        serialization_stream_t &sstream, const op_desc_t *op_desc);
-
-} // namespace serialization
+
+#define DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(cls) \
+    static_assert(serialization_stream_t::is_trivially_serialized<cls>::value, \
+            #cls " must be trivially serializable.")
+
+struct serialization_stream_t {
+    serialization_stream_t() = default;
+
+    template <typename Arg1, typename... Args>
+    serialization_stream_t(const Arg1 &a1, const Args &...args) {
+        append(a1, args...);
+    }
+
+    static serialization_stream_t from_data(std::vector<uint8_t> data) {
+        serialization_stream_t s;
+        s.data_ = std::move(data);
+        return s;
+    }
+
+    bool operator==(const serialization_stream_t &other) const {
+        return data_ == other.data_;
+    }
+
+#if defined(__cpp_lib_has_unique_object_representations) \
+        && __cpp_lib_has_unique_object_representations >= 201606L
+    template <typename T>
+    struct is_trivially_serialized {
+        static const bool value
+                = (std::has_unique_object_representations<T>::value
+                          || std::is_floating_point<T>::value)
+                && !(std::is_pointer<T>::value);
+    };
+
+#else
+    // Fallback for backward compatibility. As the structure layout should not
+    // change between c++ versions, compiling with c++17 will already verify the
+    // structures are valid for this use case.
+    template <typename T>
+    struct is_trivially_serialized {
+        static const bool value = std::is_trivially_copyable<T>::value
+                && !(std::is_pointer<T>::value);
+    };
+#endif
+
+    template <typename T, typename = void>
+    struct has_serialize_t {
+        static const bool value = false;
+    };
+
+    template <typename T>
+    struct has_serialize_t<T,
+            decltype(std::declval<T>().serialize(
+                    std::declval<serialization_stream_t &>()))> {
+        static const bool value = true;
+    };
+
+    // Append helper function for structures with the member function
+    // void serialize(serialization_stream_t &) const
+    template <typename T,
+            utils::enable_if_t<has_serialize_t<T>::value, bool> = true>
+    void append(const T &t) {
+        t.serialize(*this);
+    }
+
+    // Append helper function for trivially serialized objects
+    template <typename T,
+            utils::enable_if_t<is_trivially_serialized<T>::value
+                            && !has_serialize_t<T>::value,
+                    bool> = true>
+    void append(const T &t) {
+        std::array<uint8_t, sizeof(T)> type_data;
+        std::memcpy(type_data.data(), &t, sizeof(T));
+        data_.insert(data_.end(), type_data.begin(), type_data.end());
+    }
+
+    template ::value, bool> = true>
+    void append(const T &v) {
+        append(v.size());
+        for (const typename T::value_type &d : v)
+            append(d);
+    }
+
+    template <typename Arg1, typename Arg2, typename... Args>
+    void append(const Arg1 &a1, const Arg2 &a2, const Args &...args) {
+        append(a1);
+        append(a2, args...);
+    }
+
+    template <typename T,
+            utils::enable_if_t<is_trivially_serialized<T>::value, bool> = true>
+    void append_array(size_t size, const T *ptr) {
+        append(size);
+        const auto *p = reinterpret_cast<const uint8_t *>(ptr);
+        data_.insert(data_.end(), p, p + sizeof(T) * size);
+    }
+
+    template <typename T,
+            utils::enable_if_t<is_trivially_serialized<T>::value, bool> = true>
+    T get(size_t idx) const {
+        T t {};
+        if (data_.size() < idx + sizeof(T)) {
+            assert(!"unexpected");
+            return t;
+        }
+        std::memcpy(&t, &data_[idx], sizeof(T));
+        return t;
+    }
+
+    void get(size_t idx, size_t size, uint8_t *ptr) const {
+        if (data_.size() < idx + size) {
+            assert(!"unexpected");
+            return;
+        }
+        std::memcpy(ptr, &data_[idx], size);
+    }
+
+    size_t get_hash() const { return hash_range(data_.data(), data_.size()); }
+
+    template <typename T>
+    static size_t get_hash(const T &t) {
+        return serialization_stream_t(t).get_hash();
+    }
+
+    std::string str() {
+        std::ostringstream oss;
+        oss << std::hex << std::setfill('0');
+        for (auto c : data_) {
+            oss << std::setw(2) << static_cast<int>(c);
+        }
+        return oss.str();
+    }
+
+    bool empty() const { return data_.empty(); }
+
+    const std::vector<uint8_t> &get_data() const { return data_; }
+
+private:
+    static size_t hash_range(const uint8_t *v, size_t size) {
+        size_t seed = 0;
+        const uint8_t *end = v + size;
+        for (; v < end; v += sizeof(seed)) {
+            size_t value = 0;
+            std::memcpy(&value, v,
+                    std::min(static_cast<size_t>(end - v), sizeof(seed)));
+            seed = hash_combine(seed, value);
+        }
+
+        return seed;
+    }
+
+    std::vector<uint8_t> data_;
+};
+
+struct deserializer_t {
+    deserializer_t(const serialization_stream_t &sstream)
+        : idx_(0), sstream_(sstream) {}
+
+    template <typename T>
+    struct has_deserialize_t {
+        using yes_t = uint8_t;
+        using no_t = uint16_t;
+
+        template <typename V>
+        static yes_t test(
+                utils::enable_if_t<std::is_same<decltype(V::deserialize(
+                                                        std::declval<deserializer_t &>())),
+                                           V>::value,
+                        bool>);
+        template <typename V>
+        static no_t test(...);
+
+        static const bool value = (sizeof(test<T>(0)) == sizeof(yes_t));
+    };
+
+    // Helper function for structures with the static member function
+    // void deserialize(deserializer_t&)
+    template <typename T,
+            utils::enable_if_t<has_deserialize_t<T>::value, bool> = true>
+    void pop(T &t) {
+        t = T::deserialize(*this);
+    }
+    template <typename T,
+            utils::enable_if_t<has_deserialize_t<T>::value, bool> = true>
+    T pop() {
+        return T::deserialize(*this);
+    }
+
+    template <typename T,
+            utils::enable_if_t<serialization_stream_t::is_trivially_serialized<
+                                       T>::value
+                            && !has_deserialize_t<T>::value,
+                    bool> = true>
+    void pop(T &t) {
+        t = sstream_.get<T>(idx_);
+        idx_ += sizeof(T);
+    }
+
+    template <typename T,
+            utils::enable_if_t<serialization_stream_t::is_trivially_serialized<
+                                       T>::value
+                            && !has_deserialize_t<T>::value,
+                    bool> = true>
+    T pop() {
+        auto idx_start = idx_;
+        idx_ += sizeof(T);
+        return sstream_.get<T>(idx_start);
+    }
+
+    // Helper for vector types
+    template ::value, bool> = true>
+    void pop(T &v) {
+        size_t size;
+        pop(size);
+        v.clear();
+        v.reserve(size);
+        for (size_t i = 0; i < size; i++) {
+            typename T::value_type t = {};
+            pop(t);
+            v.emplace_back(t);
+        }
+    }
+
+    template <typename T,
+            utils::enable_if_t<serialization_stream_t::is_trivially_serialized<
+                                       T>::value,
+                    bool> = true>
+    void pop_array(size_t &size, T *ptr) {
+        pop(size);
+        sstream_.get(idx_, sizeof(T) * size, reinterpret_cast<uint8_t *>(ptr));
+        idx_ += sizeof(T) * size;
+    }
+
+    bool empty() const { return idx_ >= sstream_.get_data().size(); }
+
+private:
+    size_t idx_ = 0;
+    const serialization_stream_t &sstream_;
+};
+
+template <typename T>
+struct trivially_serializable_t {
+    static constexpr bool is_trivially_validatable = true;
+
+    serialization_stream_t serialize() const {
+        DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(T);
+        return serialization_stream_t(*static_cast<const T *>(this));
+    }
+
+    static T deserialize(const serialization_stream_t &s) {
+        return deserializer_t(s).pop<T>();
+    }
+};
+
 } // namespace impl
 } // namespace dnnl
diff --git a/src/common/serialization_stream.hpp b/src/common/serialization_stream.hpp
deleted file mode 100644
index 28eb32aad61..00000000000
--- a/src/common/serialization_stream.hpp
+++ /dev/null
@@ -1,63 +0,0 @@
-/*******************************************************************************
-* Copyright 2021-2023 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef COMMON_SERIALIZATION_STREAM_HPP
-#define COMMON_SERIALIZATION_STREAM_HPP
-
-#include <cstdint>
-#include <type_traits>
-#include <vector>
-
-namespace dnnl {
-namespace impl {
-
-struct serialization_stream_t {
-    serialization_stream_t() = default;
-
-    template <typename T>
-    void write(const T ptr, size_t nelems = 1) {
-        using non_pointer_type = typename std::remove_pointer<T>::type;
-
-        static_assert(std::is_pointer<T>::value,
-                "T is expected to be a pointer type.");
-        static_assert(!std::is_pointer<non_pointer_type>::value,
-                "T cannot be a pointer to pointer.");
-        static_assert(!std::is_class<non_pointer_type>::value,
-                "non-pointer type is expected to be a trivial type to avoid "
-                "padding issues.");
-        static_assert(!std::is_array<non_pointer_type>::value,
-                "non-pointer type cannot be an array.");
-
-        write_impl((const void *)ptr, sizeof(non_pointer_type) * nelems);
-    }
-
-    bool empty() const { return data_.empty(); }
-
-    const std::vector<uint8_t> &get_data() const { return data_; }
-
-private:
-    void write_impl(const void *ptr, size_t size) {
-        const auto *p = reinterpret_cast<const uint8_t *>(ptr);
-        data_.insert(data_.end(), p, p + size);
-    }
-
-    std::vector<uint8_t> data_;
-};
-
-} // namespace impl
-} // namespace dnnl
-
-#endif
diff --git a/src/common/shuffle_pd.hpp b/src/common/shuffle_pd.hpp
index dec26b107f2..5a2886ee210 100644
--- a/src/common/shuffle_pd.hpp
+++ b/src/common/shuffle_pd.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2024 Intel Corporation
+* Copyright 2018-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -34,11 +34,12 @@
 namespace dnnl {
 namespace impl {
 
+// NOLINTBEGIN(google-default-arguments)
 struct shuffle_pd_t : public primitive_desc_t {
     static constexpr auto base_pkind = primitive_kind::shuffle;
 
-    typedef shuffle_pd_t base_class;
-    typedef shuffle_pd_t hint_class;
+    using base_class = shuffle_pd_t;
+    using hint_class = shuffle_pd_t;
 
     const shuffle_desc_t *desc() const { return &desc_; }
     const op_desc_t *op_desc() const override {
@@ -145,10 +146,10 @@
     memory_desc_t src_md_;
     memory_desc_t dst_md_;
 
-    shuffle_pd_t(const shuffle_desc_t *adesc, const primitive_attr_t *attr,
+    shuffle_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
             const shuffle_pd_t *hint_fwd_pd)
         : primitive_desc_t(attr, base_pkind)
-        , desc_(*adesc)
+        , desc_(*op_desc_t::to_desc(adesc))
         , hint_fwd_pd_(hint_fwd_pd)
         , src_md_(desc_.src_desc)
         , dst_md_(desc_.dst_desc) {
@@ -179,6 +180,7 @@
         return is_fwd() ? src_md() : diff_src_md();
     }
 };
+// NOLINTEND(google-default-arguments)
 
 } // namespace impl
 } // namespace dnnl
diff --git a/src/common/softmax.cpp b/src/common/softmax.cpp
index 94e6e9c4ca5..77abe54034b 100644
--- a/src/common/softmax.cpp
+++ b/src/common/softmax.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2016-2023 Intel Corporation
+* Copyright 2016-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
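The softmax hunk below moves the int8 scales validation to the argument-based scales API while keeping the old semantics: scales are accepted only for DNNL_ARG_SRC and DNNL_ARG_DST, and only with mask == 0 (a single common scale per tensor). On the user side that corresponds to a sketch like the following (illustrative only; the function name is hypothetical, and it relies on the public set_scales_mask() attribute API):

    #include "oneapi/dnnl/dnnl.hpp"

    // Per-tensor scales (mask = 0) on src and dst -- the only scales
    // configuration the check below accepts for int8 softmax.
    dnnl::primitive_attr make_int8_softmax_attr() {
        dnnl::primitive_attr attr;
        attr.set_scales_mask(DNNL_ARG_SRC, 0);
        attr.set_scales_mask(DNNL_ARG_DST, 0);
        return attr;
    }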
@@ -107,18 +107,26 @@ status_t softmax_attr_check(const softmax_desc_t &desc, const engine_t *engine,
 
     const bool is_int8 = utils::one_of(src_dt, data_type::s8, data_type::u8)
             || utils::one_of(dst_dt, data_type::s8, data_type::u8);
-    if (is_int8) fwd_attr_mask |= smask_t::scales_runtime;
+    if (is_int8) fwd_attr_mask |= smask_t::scales;
 
     VCHECK_SOFTMAX_UNIMPL(attr->has_default_values(fwd_attr_mask, dst_dt),
             VERBOSE_UNSUPPORTED_ATTR);
 
+    // Check scales
     if (!attr->scales_.has_default_values()) {
-        const auto &sc = attr->scales_;
-        const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
-        const int mask_dst = sc.get(DNNL_ARG_DST).mask_;
-
-        VCHECK_SOFTMAX_UNIMPL(utils::everyone_is(0, mask_src, mask_dst),
+        static const std::vector<int> supported_args {
+                DNNL_ARG_SRC, DNNL_ARG_DST};
+        VCHECK_SOFTMAX_UNIMPL(
+                attr->scales_.has_default_values(supported_args),
                 VERBOSE_UNSUPPORTED_SCALES_CFG);
+
+        for (int arg : supported_args) {
+            if (attr->scales_.has_default_values(arg)) continue;
+
+            const int mask = attr->scales_.get_mask(arg);
+            VCHECK_SOFTMAX_UNIMPL(
+                    mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
+        }
     }
 
     // Check post-ops
@@ -127,6 +135,9 @@ status_t softmax_attr_check(const softmax_desc_t &desc, const engine_t *engine,
         using namespace primitive_kind;
         VCHECK_SOFTMAX_UNIMPL(po.has_default_values({binary, eltwise}),
                 VERBOSE_UNSUPPORTED_POSTOP);
+
+        // Note: verbose support is inside the call.
+        CHECK(po.validate_binary_with_dst_consistency(&desc.dst_desc));
     }
     } else {
         VCHECK_SOFTMAX_UNIMPL(false, VERBOSE_UNSUPPORTED_ATTR);
diff --git a/src/common/softmax_pd.hpp b/src/common/softmax_pd.hpp
index e42ae59e1e2..623772a2c61 100644
--- a/src/common/softmax_pd.hpp
+++ b/src/common/softmax_pd.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2016-2024 Intel Corporation
+* Copyright 2016-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -108,10 +108,10 @@ struct softmax_pd_t : public primitive_desc_t {
 
     memory_desc_t dst_md_;
 
-    softmax_pd_t(const softmax_desc_t *adesc, const primitive_attr_t *attr,
+    softmax_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
            const softmax_fwd_pd_t *hint_fwd_pd)
         : primitive_desc_t(attr, base_pkind)
-        , desc_(*adesc)
+        , desc_(*op_desc_t::to_desc(adesc))
         , hint_fwd_pd_(hint_fwd_pd)
         , dst_md_(desc_.dst_desc) {}
 
@@ -119,17 +119,19 @@
     const memory_desc_t &dst_desc() const { return dst_md_; }
 };
 
+// NOLINTBEGIN(google-default-arguments)
 struct softmax_fwd_pd_t : public softmax_pd_t {
-    typedef softmax_fwd_pd_t base_class;
-    typedef softmax_fwd_pd_t hint_class;
+    using base_class = softmax_fwd_pd_t;
+    using hint_class = softmax_fwd_pd_t;
 
     arg_usage_t arg_usage(int arg) const override {
         if (arg == DNNL_ARG_SRC) return arg_usage_t::input;
 
         if (arg == DNNL_ARG_DST) return arg_usage_t::output;
 
-        if (arg == DNNL_ARG_WORKSPACE && (!types::is_zero_md(workspace_md())))
-            return arg_usage_t::output;
+        if (arg == DNNL_ARG_WORKSPACE)
+            return !types::is_zero_md(workspace_md()) ? arg_usage_t::output
+                                                      : arg_usage_t::unused;
 
         return primitive_desc_t::arg_usage(arg);
     }
@@ -162,7 +164,7 @@ struct softmax_fwd_pd_t : public softmax_pd_t {
 protected:
     memory_desc_t src_md_;
 
-    softmax_fwd_pd_t(const softmax_desc_t *adesc, const primitive_attr_t *attr,
+    softmax_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
             const softmax_fwd_pd_t *hint_fwd_pd)
         : softmax_pd_t(adesc, attr, hint_fwd_pd), src_md_(desc_.src_desc) {}
 
@@ -176,19 +178,28 @@ struct softmax_fwd_pd_t : public softmax_pd_t {
             dst_md_, src_md_.format_desc.blocking);
     }
 
-    bool attr_scales_ok() const {
+    bool attr_scales_ok(const std::vector<int> &supported_args
+            = {DNNL_ARG_SRC, DNNL_ARG_DST}) const {
         const auto &scales = attr()->scales_;
-        bool ok = true;
-        for (const auto &e : scales.scales_) {
-            ok = ok && e.second.mask_ == 0;
+        bool ok = scales.has_default_values(supported_args);
+
+        for (const auto &arg : supported_args) {
+            if (scales.has_default_values(arg)) continue;
+
+            // TODO: disallow non-int8 scales?
+            // const data_type_t dt = arg_md(arg)->data_type;
+            // ok = ok && utils::one_of(dt, s8, u8);
+            ok = ok && scales.get_mask(arg) == 0;
         }
         return ok;
     }
 };
+// NOLINTEND(google-default-arguments)
 
+// NOLINTBEGIN(google-default-arguments)
 struct softmax_bwd_pd_t : public softmax_pd_t {
-    typedef softmax_bwd_pd_t base_class;
-    typedef softmax_fwd_pd_t hint_class;
+    using base_class = softmax_bwd_pd_t;
+    using hint_class = softmax_fwd_pd_t;
 
     arg_usage_t arg_usage(int arg) const override {
         if (utils::one_of(arg, DNNL_ARG_DST, DNNL_ARG_DIFF_DST))
@@ -196,8 +207,9 @@
             return arg_usage_t::input;
 
         if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output;
 
-        if (arg == DNNL_ARG_WORKSPACE && (!types::is_zero_md(workspace_md())))
-            return arg_usage_t::input;
+        if (arg == DNNL_ARG_WORKSPACE)
+            return !types::is_zero_md(workspace_md()) ? arg_usage_t::input
+                                                      : arg_usage_t::unused;
 
         return primitive_desc_t::arg_usage(arg);
     }
@@ -239,7 +251,7 @@
     memory_desc_t diff_src_md_;
     memory_desc_t diff_dst_md_;
 
-    softmax_bwd_pd_t(const softmax_desc_t *adesc, const primitive_attr_t *attr,
+    softmax_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
            const softmax_fwd_pd_t *hint_fwd_pd)
         : softmax_pd_t(adesc, attr, hint_fwd_pd)
         , diff_src_md_(desc_.diff_src_desc)
@@ -260,6 +272,7 @@
         return status::success;
     }
 };
+// NOLINTEND(google-default-arguments)
 
 } // namespace impl
 } // namespace dnnl
diff --git a/src/common/spdlog/common-inl.h b/src/common/spdlog/common-inl.h
deleted file mode 100755
index 19817b2a702..00000000000
--- a/src/common/spdlog/common-inl.h
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright(c) 2015-present, Gabi Melman & spdlog contributors.
-// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -#include -#include - -namespace spdlog { -namespace level { - -#if __cplusplus >= 201703L -constexpr -#endif - static string_view_t level_string_views[] SPDLOG_LEVEL_NAMES; - -static const char *short_level_names[] SPDLOG_SHORT_LEVEL_NAMES; - -SPDLOG_INLINE const string_view_t &to_string_view( - spdlog::level::level_enum l) SPDLOG_NOEXCEPT { - return level_string_views[l]; -} - -SPDLOG_INLINE const char *to_short_c_str( - spdlog::level::level_enum l) SPDLOG_NOEXCEPT { - return short_level_names[l]; -} - -SPDLOG_INLINE spdlog::level::level_enum from_str( - const std::string &name) SPDLOG_NOEXCEPT { - auto it = std::find( - std::begin(level_string_views), std::end(level_string_views), name); - if (it != std::end(level_string_views)) - return static_cast( - std::distance(std::begin(level_string_views), it)); - - // check also for "warn" and "err" before giving up.. - if (name == "warn") { return level::warn; } - if (name == "err") { return level::err; } - return level::off; -} -} // namespace level - -SPDLOG_INLINE spdlog_ex::spdlog_ex(std::string msg) : msg_(std::move(msg)) {} - -SPDLOG_INLINE spdlog_ex::spdlog_ex(const std::string &msg, int last_errno) { -#ifdef SPDLOG_USE_STD_FORMAT - msg_ = std::system_error( - std::error_code(last_errno, std::generic_category()), msg) - .what(); -#else - memory_buf_t outbuf; - fmt::format_system_error(outbuf, last_errno, msg.c_str()); - msg_ = fmt::to_string(outbuf); -#endif -} - -SPDLOG_INLINE const char *spdlog_ex::what() const SPDLOG_NOEXCEPT { - return msg_.c_str(); -} - -SPDLOG_INLINE void throw_spdlog_ex(const std::string &msg, int last_errno) { - SPDLOG_THROW(spdlog_ex(msg, last_errno)); -} - -SPDLOG_INLINE void throw_spdlog_ex(std::string msg) { - SPDLOG_THROW(spdlog_ex(std::move(msg))); -} - -} // namespace spdlog diff --git a/src/common/spdlog/common.h b/src/common/spdlog/common.h deleted file mode 100755 index 69f4289b540..00000000000 --- a/src/common/spdlog/common.h +++ /dev/null @@ -1,424 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. 
-// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef SPDLOG_USE_STD_FORMAT -#include -#if __cpp_lib_format >= 202207L -#include -#else -#include -#endif -#endif - -#ifdef SPDLOG_COMPILED_LIB -#undef SPDLOG_HEADER_ONLY -#if defined(SPDLOG_SHARED_LIB) -#if defined(_WIN32) -#ifdef spdlog_EXPORTS -#define SPDLOG_API __declspec(dllexport) -#else // !spdlog_EXPORTS -#define SPDLOG_API __declspec(dllimport) -#endif -#else // !defined(_WIN32) -#define SPDLOG_API __attribute__((visibility("default"))) -#endif -#else // !defined(SPDLOG_SHARED_LIB) -#define SPDLOG_API -#endif -#define SPDLOG_INLINE -#else // !defined(SPDLOG_COMPILED_LIB) -#define SPDLOG_API -#define SPDLOG_HEADER_ONLY -#define SPDLOG_INLINE inline -#endif // #ifdef SPDLOG_COMPILED_LIB - -#include - -#if !defined(SPDLOG_USE_STD_FORMAT) \ - && FMT_VERSION \ - >= 80000 // backward compatibility with fmt versions older than 8 -#define SPDLOG_FMT_RUNTIME(format_string) fmt::runtime(format_string) -#define SPDLOG_FMT_STRING(format_string) FMT_STRING(format_string) -#if defined(SPDLOG_WCHAR_FILENAMES) || defined(SPDLOG_WCHAR_TO_UTF8_SUPPORT) -#include -#endif -#else -#define SPDLOG_FMT_RUNTIME(format_string) format_string -#define SPDLOG_FMT_STRING(format_string) format_string -#endif - -// visual studio up to 2013 does not support noexcept nor constexpr -#if defined(_MSC_VER) && (_MSC_VER < 1900) -#define SPDLOG_NOEXCEPT _NOEXCEPT -#define SPDLOG_CONSTEXPR -#else -#define SPDLOG_NOEXCEPT noexcept -#define SPDLOG_CONSTEXPR constexpr -#endif - -// If building with std::format, can just use constexpr, otherwise if building with fmt -// SPDLOG_CONSTEXPR_FUNC needs to be set the same as FMT_CONSTEXPR to avoid situations where -// a constexpr function in spdlog could end up calling a non-constexpr function in fmt -// depending on the compiler -// If fmt determines it can't use constexpr, we should inline the function instead -#ifdef SPDLOG_USE_STD_FORMAT -#define SPDLOG_CONSTEXPR_FUNC constexpr -#else // Being built with fmt -#if FMT_USE_CONSTEXPR -#define SPDLOG_CONSTEXPR_FUNC FMT_CONSTEXPR -#else -#define SPDLOG_CONSTEXPR_FUNC inline -#endif -#endif - -#if defined(__GNUC__) || defined(__clang__) -#define SPDLOG_DEPRECATED __attribute__((deprecated)) -#elif defined(_MSC_VER) -#define SPDLOG_DEPRECATED __declspec(deprecated) -#else -#define SPDLOG_DEPRECATED -#endif - -// disable thread local on msvc 2013 -#ifndef SPDLOG_NO_TLS -#if (defined(_MSC_VER) && (_MSC_VER < 1900)) || defined(__cplusplus_winrt) -#define SPDLOG_NO_TLS 1 -#endif -#endif - -#ifndef SPDLOG_FUNCTION -#define SPDLOG_FUNCTION static_cast(__FUNCTION__) -#endif - -#ifdef SPDLOG_NO_EXCEPTIONS -#define SPDLOG_TRY -#define SPDLOG_THROW(ex) \ - do { \ - printf("spdlog fatal error: %s\n", ex.what()); \ - std::abort(); \ - } while (0) -#define SPDLOG_CATCH_STD -#else -#define SPDLOG_TRY try -#define SPDLOG_THROW(ex) throw(ex) -#define SPDLOG_CATCH_STD \ - catch (const std::exception &) { \ - } -#endif - -namespace spdlog { - -class formatter; - -namespace sinks { -class sink; -} - -#if defined(_WIN32) && defined(SPDLOG_WCHAR_FILENAMES) -using filename_t = std::wstring; -// allow macro expansion to occur in SPDLOG_FILENAME_T -#define SPDLOG_FILENAME_T_INNER(s) L##s -#define SPDLOG_FILENAME_T(s) SPDLOG_FILENAME_T_INNER(s) -#else -using filename_t = std::string; -#define SPDLOG_FILENAME_T(s) s -#endif - -using log_clock = 
std::chrono::system_clock; -using sink_ptr = std::shared_ptr; -using sinks_init_list = std::initializer_list; -using err_handler = std::function; -#ifdef SPDLOG_USE_STD_FORMAT -namespace fmt_lib = std; - -using string_view_t = std::string_view; -using memory_buf_t = std::string; - -template -#if __cpp_lib_format >= 202207L -using format_string_t = std::format_string; -#else -using format_string_t = std::string_view; -#endif - -template -struct is_convertible_to_basic_format_string - : std::integral_constant>::value> {}; - -#if defined(SPDLOG_WCHAR_FILENAMES) || defined(SPDLOG_WCHAR_TO_UTF8_SUPPORT) -using wstring_view_t = std::wstring_view; -using wmemory_buf_t = std::wstring; - -template -#if __cpp_lib_format >= 202207L -using wformat_string_t = std::wformat_string; -#else -using wformat_string_t = std::wstring_view; -#endif -#endif -#define SPDLOG_BUF_TO_STRING(x) x -#else // use fmt lib instead of std::format -namespace fmt_lib = fmt; - -using string_view_t = fmt::basic_string_view; -using memory_buf_t = fmt::basic_memory_buffer; - -template -using format_string_t = fmt::format_string; - -template -using remove_cvref_t = - typename std::remove_cv::type>::type; - -template -#if FMT_VERSION >= 90101 -using fmt_runtime_string = fmt::runtime_format_string; -#else -using fmt_runtime_string = fmt::basic_runtime; -#endif - -// clang doesn't like SFINAE disabled constructor in std::is_convertible<> so have to repeat the -// condition from basic_format_string here, in addition, fmt::basic_runtime is only -// convertible to basic_format_string but not basic_string_view -template -struct is_convertible_to_basic_format_string - : std::integral_constant>::value - || std::is_same, - fmt_runtime_string>::value> {}; - -#if defined(SPDLOG_WCHAR_FILENAMES) || defined(SPDLOG_WCHAR_TO_UTF8_SUPPORT) -using wstring_view_t = fmt::basic_string_view; -using wmemory_buf_t = fmt::basic_memory_buffer; - -template -using wformat_string_t = fmt::wformat_string; -#endif -#define SPDLOG_BUF_TO_STRING(x) fmt::to_string(x) -#endif - -#ifdef SPDLOG_WCHAR_TO_UTF8_SUPPORT -#ifndef _WIN32 -#error SPDLOG_WCHAR_TO_UTF8_SUPPORT only supported on windows -#endif // _WIN32 -#endif // SPDLOG_WCHAR_TO_UTF8_SUPPORT - -template -struct is_convertible_to_any_format_string - : std::integral_constant::value - || is_convertible_to_basic_format_string::value> {}; - -#if defined(SPDLOG_NO_ATOMIC_LEVELS) -using level_t = details::null_atomic_int; -#else -using level_t = std::atomic; -#endif - -#define SPDLOG_LEVEL_TRACE 0 -#define SPDLOG_LEVEL_DEBUG 1 -#define SPDLOG_LEVEL_INFO 2 -#define SPDLOG_LEVEL_WARN 3 -#define SPDLOG_LEVEL_ERROR 4 -#define SPDLOG_LEVEL_CRITICAL 5 -#define SPDLOG_LEVEL_OFF 6 - -#if !defined(SPDLOG_ACTIVE_LEVEL) -#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO -#endif - -// Log level enum -namespace level { -enum level_enum : int { - trace = SPDLOG_LEVEL_TRACE, - debug = SPDLOG_LEVEL_DEBUG, - info = SPDLOG_LEVEL_INFO, - warn = SPDLOG_LEVEL_WARN, - err = SPDLOG_LEVEL_ERROR, - critical = SPDLOG_LEVEL_CRITICAL, - off = SPDLOG_LEVEL_OFF, - n_levels -}; - -#define SPDLOG_LEVEL_NAME_TRACE spdlog::string_view_t("trace", 5) -#define SPDLOG_LEVEL_NAME_DEBUG spdlog::string_view_t("debug", 5) -#define SPDLOG_LEVEL_NAME_INFO spdlog::string_view_t("info", 4) -#define SPDLOG_LEVEL_NAME_WARNING spdlog::string_view_t("warning", 7) -#define SPDLOG_LEVEL_NAME_ERROR spdlog::string_view_t("error", 5) -#define SPDLOG_LEVEL_NAME_CRITICAL spdlog::string_view_t("critical", 8) -#define SPDLOG_LEVEL_NAME_OFF spdlog::string_view_t("off", 3) - 
-#if !defined(SPDLOG_LEVEL_NAMES) -#define SPDLOG_LEVEL_NAMES \ - { \ - SPDLOG_LEVEL_NAME_TRACE, SPDLOG_LEVEL_NAME_DEBUG, \ - SPDLOG_LEVEL_NAME_INFO, SPDLOG_LEVEL_NAME_WARNING, \ - SPDLOG_LEVEL_NAME_ERROR, SPDLOG_LEVEL_NAME_CRITICAL, \ - SPDLOG_LEVEL_NAME_OFF \ - } -#endif - -#if !defined(SPDLOG_SHORT_LEVEL_NAMES) - -#define SPDLOG_SHORT_LEVEL_NAMES \ - { "T", "D", "I", "W", "E", "C", "O" } -#endif - -SPDLOG_API const string_view_t &to_string_view( - spdlog::level::level_enum l) SPDLOG_NOEXCEPT; -SPDLOG_API const char *to_short_c_str( - spdlog::level::level_enum l) SPDLOG_NOEXCEPT; -SPDLOG_API spdlog::level::level_enum from_str( - const std::string &name) SPDLOG_NOEXCEPT; - -} // namespace level - -// -// Color mode used by sinks with color support. -// -enum class color_mode { always, automatic, never }; - -// -// Pattern time - specific time getting to use for pattern_formatter. -// local time by default -// -enum class pattern_time_type { - local, // log localtime - utc // log utc -}; - -// -// Log exception -// -class SPDLOG_API spdlog_ex : public std::exception { -public: - explicit spdlog_ex(std::string msg); - spdlog_ex(const std::string &msg, int last_errno); - const char *what() const SPDLOG_NOEXCEPT override; - -private: - std::string msg_; -}; - -[[noreturn]] SPDLOG_API void throw_spdlog_ex( - const std::string &msg, int last_errno); -[[noreturn]] SPDLOG_API void throw_spdlog_ex(std::string msg); - -struct source_loc { - SPDLOG_CONSTEXPR source_loc() = default; - SPDLOG_CONSTEXPR source_loc( - const char *filename_in, int line_in, const char *funcname_in) - : filename {filename_in}, line {line_in}, funcname {funcname_in} {} - - SPDLOG_CONSTEXPR bool empty() const SPDLOG_NOEXCEPT { return line <= 0; } - const char *filename {nullptr}; - int line {0}; - const char *funcname {nullptr}; -}; - -struct file_event_handlers { - file_event_handlers() - : before_open(nullptr) - , after_open(nullptr) - , before_close(nullptr) - , after_close(nullptr) {} - - std::function before_open; - std::function - after_open; - std::function - before_close; - std::function after_close; -}; - -namespace details { - -// to_string_view - -SPDLOG_CONSTEXPR_FUNC spdlog::string_view_t to_string_view( - const memory_buf_t &buf) SPDLOG_NOEXCEPT { - return spdlog::string_view_t {buf.data(), buf.size()}; -} - -SPDLOG_CONSTEXPR_FUNC spdlog::string_view_t to_string_view( - spdlog::string_view_t str) SPDLOG_NOEXCEPT { - return str; -} - -#if defined(SPDLOG_WCHAR_FILENAMES) || defined(SPDLOG_WCHAR_TO_UTF8_SUPPORT) -SPDLOG_CONSTEXPR_FUNC spdlog::wstring_view_t to_string_view( - const wmemory_buf_t &buf) SPDLOG_NOEXCEPT { - return spdlog::wstring_view_t {buf.data(), buf.size()}; -} - -SPDLOG_CONSTEXPR_FUNC spdlog::wstring_view_t to_string_view( - spdlog::wstring_view_t str) SPDLOG_NOEXCEPT { - return str; -} -#endif - -#ifndef SPDLOG_USE_STD_FORMAT -template -inline fmt::basic_string_view to_string_view( - fmt::basic_format_string fmt) { - return fmt; -} -#elif __cpp_lib_format >= 202207L -template -SPDLOG_CONSTEXPR_FUNC std::basic_string_view to_string_view( - std::basic_format_string fmt) SPDLOG_NOEXCEPT { - return fmt.get(); -} -#endif - -// make_unique support for pre c++14 -#if __cplusplus >= 201402L // C++14 and beyond -using std::enable_if_t; -using std::make_unique; -#else -template -using enable_if_t = typename std::enable_if::type; - -template -std::unique_ptr make_unique(Args &&...args) { - static_assert(!std::is_array::value, "arrays not supported"); - return std::unique_ptr(new T(std::forward(args)...)); 
-} -#endif - -// to avoid useless casts (see https://github.com/nlohmann/json/issues/2893#issuecomment-889152324) -template ::value, int> = 0> -constexpr T conditional_static_cast(U value) { - return static_cast(value); -} - -template ::value, int> = 0> -constexpr T conditional_static_cast(U value) { - return value; -} - -} // namespace details -} // namespace spdlog - -#ifdef SPDLOG_HEADER_ONLY -#include "common-inl.h" -#endif diff --git a/src/common/spdlog/details/backtracer-inl.h b/src/common/spdlog/details/backtracer-inl.h deleted file mode 100755 index 14448d74c41..00000000000 --- a/src/common/spdlog/details/backtracer-inl.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif -namespace spdlog { -namespace details { -SPDLOG_INLINE backtracer::backtracer(const backtracer &other) { - std::lock_guard lock(other.mutex_); - enabled_ = other.enabled(); - messages_ = other.messages_; -} - -SPDLOG_INLINE backtracer::backtracer(backtracer &&other) SPDLOG_NOEXCEPT { - std::lock_guard lock(other.mutex_); - enabled_ = other.enabled(); - messages_ = std::move(other.messages_); -} - -SPDLOG_INLINE backtracer &backtracer::operator=(backtracer other) { - std::lock_guard lock(mutex_); - enabled_ = other.enabled(); - messages_ = std::move(other.messages_); - return *this; -} - -SPDLOG_INLINE void backtracer::enable(size_t size) { - std::lock_guard lock {mutex_}; - enabled_.store(true, std::memory_order_relaxed); - messages_ = circular_q {size}; -} - -SPDLOG_INLINE void backtracer::disable() { - std::lock_guard lock {mutex_}; - enabled_.store(false, std::memory_order_relaxed); -} - -SPDLOG_INLINE bool backtracer::enabled() const { - return enabled_.load(std::memory_order_relaxed); -} - -SPDLOG_INLINE void backtracer::push_back(const log_msg &msg) { - std::lock_guard lock {mutex_}; - messages_.push_back(log_msg_buffer {msg}); -} - -SPDLOG_INLINE bool backtracer::empty() const { - std::lock_guard lock {mutex_}; - return messages_.empty(); -} - -// pop all items in the q and apply the given fun on each of them. -SPDLOG_INLINE void backtracer::foreach_pop( - std::function fun) { - std::lock_guard lock {mutex_}; - while (!messages_.empty()) { - auto &front_msg = messages_.front(); - fun(front_msg); - messages_.pop_front(); - } -} -} // namespace details -} // namespace spdlog diff --git a/src/common/spdlog/details/file_helper-inl.h b/src/common/spdlog/details/file_helper-inl.h deleted file mode 100755 index 6d24b0ce257..00000000000 --- a/src/common/spdlog/details/file_helper-inl.h +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. 
-// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace spdlog { -namespace details { - -SPDLOG_INLINE file_helper::file_helper( - const file_event_handlers &event_handlers) - : event_handlers_(event_handlers) {} - -SPDLOG_INLINE file_helper::~file_helper() { - close(); -} - -SPDLOG_INLINE void file_helper::open(const filename_t &fname, bool truncate) { - close(); - filename_ = fname; - - auto *mode = SPDLOG_FILENAME_T("ab"); - auto *trunc_mode = SPDLOG_FILENAME_T("wb"); - - if (event_handlers_.before_open) { event_handlers_.before_open(filename_); } - for (int tries = 0; tries < open_tries_; ++tries) { - // create containing folder if not exists already. - os::create_dir(os::dir_name(fname)); - if (truncate) { - // Truncate by opening-and-closing a tmp file in "wb" mode, always - // opening the actual log-we-write-to in "ab" mode, since that - // interacts more politely with eternal processes that might - // rotate/truncate the file underneath us. - std::FILE *tmp; - if (os::fopen_s(&tmp, fname, trunc_mode)) { continue; } - std::fclose(tmp); - } - if (!os::fopen_s(&fd_, fname, mode)) { - if (event_handlers_.after_open) { - event_handlers_.after_open(filename_, fd_); - } - return; - } - - details::os::sleep_for_millis(open_interval_); - } - - throw_spdlog_ex("Failed opening file " + os::filename_to_str(filename_) - + " for writing", - errno); -} - -SPDLOG_INLINE void file_helper::reopen(bool truncate) { - if (filename_.empty()) { - throw_spdlog_ex("Failed re opening file - was not opened before"); - } - this->open(filename_, truncate); -} - -SPDLOG_INLINE void file_helper::flush() { - if (std::fflush(fd_) != 0) { - throw_spdlog_ex( - "Failed flush to file " + os::filename_to_str(filename_), - errno); - } -} - -SPDLOG_INLINE void file_helper::sync() { - if (!os::fsync(fd_)) { - throw_spdlog_ex( - "Failed to fsync file " + os::filename_to_str(filename_), - errno); - } -} - -SPDLOG_INLINE void file_helper::close() { - if (fd_ != nullptr) { - if (event_handlers_.before_close) { - event_handlers_.before_close(filename_, fd_); - } - - std::fclose(fd_); - fd_ = nullptr; - - if (event_handlers_.after_close) { - event_handlers_.after_close(filename_); - } - } -} - -SPDLOG_INLINE void file_helper::write(const memory_buf_t &buf) { - if (fd_ == nullptr) return; - size_t msg_size = buf.size(); - auto data = buf.data(); - if (std::fwrite(data, 1, msg_size, fd_) != msg_size) { - throw_spdlog_ex( - "Failed writing to file " + os::filename_to_str(filename_), - errno); - } -} - -SPDLOG_INLINE size_t file_helper::size() const { - if (fd_ == nullptr) { - throw_spdlog_ex("Cannot use size() on closed file " - + os::filename_to_str(filename_)); - } - return os::filesize(fd_); -} - -SPDLOG_INLINE const filename_t &file_helper::filename() const { - return filename_; -} - -// -// return file path and its extension: -// -// "mylog.txt" => ("mylog", ".txt") -// "mylog" => ("mylog", "") -// "mylog." => ("mylog.", "") -// "/dir1/dir2/mylog.txt" => ("/dir1/dir2/mylog", ".txt") -// -// the starting dot in filenames is ignored (hidden files): -// -// ".mylog" => (".mylog". 
"") -// "my_folder/.mylog" => ("my_folder/.mylog", "") -// "my_folder/.mylog.txt" => ("my_folder/.mylog", ".txt") -SPDLOG_INLINE std::tuple -file_helper::split_by_extension(const filename_t &fname) { - auto ext_index = fname.rfind('.'); - - // no valid extension found - return whole path and empty string as - // extension - if (ext_index == filename_t::npos || ext_index == 0 - || ext_index == fname.size() - 1) { - return std::make_tuple(fname, filename_t()); - } - - // treat cases like "/etc/rc.d/somelogfile or "/abc/.hiddenfile" - auto folder_index = fname.find_last_of(details::os::folder_seps_filename); - if (folder_index != filename_t::npos && folder_index >= ext_index - 1) { - return std::make_tuple(fname, filename_t()); - } - - // finally - return a valid base and extension tuple - return std::make_tuple(fname.substr(0, ext_index), fname.substr(ext_index)); -} - -} // namespace details -} // namespace spdlog diff --git a/src/common/spdlog/details/log_msg-inl.h b/src/common/spdlog/details/log_msg-inl.h deleted file mode 100755 index 027ce5102c0..00000000000 --- a/src/common/spdlog/details/log_msg-inl.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -#include - -namespace spdlog { -namespace details { - -SPDLOG_INLINE log_msg::log_msg(spdlog::log_clock::time_point log_time, - spdlog::source_loc loc, string_view_t a_logger_name, - spdlog::level::level_enum lvl, spdlog::string_view_t msg) - : logger_name(a_logger_name) - , level(lvl) - , time(log_time) -#ifndef SPDLOG_NO_THREAD_ID - , thread_id(os::thread_id()) -#endif - , source(loc) - , payload(msg) { -} - -SPDLOG_INLINE log_msg::log_msg(spdlog::source_loc loc, - string_view_t a_logger_name, spdlog::level::level_enum lvl, - spdlog::string_view_t msg) - : log_msg(os::now(), loc, a_logger_name, lvl, msg) {} - -SPDLOG_INLINE log_msg::log_msg(string_view_t a_logger_name, - spdlog::level::level_enum lvl, spdlog::string_view_t msg) - : log_msg(os::now(), source_loc {}, a_logger_name, lvl, msg) {} - -} // namespace details -} // namespace spdlog diff --git a/src/common/spdlog/details/log_msg.h b/src/common/spdlog/details/log_msg.h deleted file mode 100755 index c11aaf257b1..00000000000 --- a/src/common/spdlog/details/log_msg.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#include -#include - -namespace spdlog { -namespace details { -struct SPDLOG_API log_msg { - log_msg() = default; - log_msg(log_clock::time_point log_time, source_loc loc, - string_view_t logger_name, level::level_enum lvl, - string_view_t msg); - log_msg(source_loc loc, string_view_t logger_name, level::level_enum lvl, - string_view_t msg); - log_msg(string_view_t logger_name, level::level_enum lvl, - string_view_t msg); - log_msg(const log_msg &other) = default; - log_msg &operator=(const log_msg &other) = default; - - string_view_t logger_name; - level::level_enum level {level::off}; - log_clock::time_point time; - size_t thread_id {0}; - - // wrapping the formatted text with color (updated by pattern_formatter). 
- mutable size_t color_range_start {0}; - mutable size_t color_range_end {0}; - - source_loc source; - string_view_t payload; -}; -} // namespace details -} // namespace spdlog - -#ifdef SPDLOG_HEADER_ONLY -#include "log_msg-inl.h" -#endif diff --git a/src/common/spdlog/details/log_msg_buffer-inl.h b/src/common/spdlog/details/log_msg_buffer-inl.h deleted file mode 100755 index f3ef28f3708..00000000000 --- a/src/common/spdlog/details/log_msg_buffer-inl.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -namespace spdlog { -namespace details { - -SPDLOG_INLINE log_msg_buffer::log_msg_buffer(const log_msg &orig_msg) - : log_msg {orig_msg} { - buffer.append(logger_name.begin(), logger_name.end()); - buffer.append(payload.begin(), payload.end()); - update_string_views(); -} - -SPDLOG_INLINE log_msg_buffer::log_msg_buffer(const log_msg_buffer &other) - : log_msg {other} { - buffer.append(logger_name.begin(), logger_name.end()); - buffer.append(payload.begin(), payload.end()); - update_string_views(); -} - -SPDLOG_INLINE log_msg_buffer::log_msg_buffer( - log_msg_buffer &&other) SPDLOG_NOEXCEPT - : log_msg {other}, - buffer {std::move(other.buffer)} { - update_string_views(); -} - -SPDLOG_INLINE log_msg_buffer &log_msg_buffer::operator=( - const log_msg_buffer &other) { - log_msg::operator=(other); - buffer.clear(); - buffer.append( - other.buffer.data(), other.buffer.data() + other.buffer.size()); - update_string_views(); - return *this; -} - -SPDLOG_INLINE log_msg_buffer &log_msg_buffer::operator=( - log_msg_buffer &&other) SPDLOG_NOEXCEPT { - log_msg::operator=(other); - buffer = std::move(other.buffer); - update_string_views(); - return *this; -} - -SPDLOG_INLINE void log_msg_buffer::update_string_views() { - logger_name = string_view_t {buffer.data(), logger_name.size()}; - payload = string_view_t { - buffer.data() + logger_name.size(), payload.size()}; -} - -} // namespace details -} // namespace spdlog diff --git a/src/common/spdlog/details/null_mutex.h b/src/common/spdlog/details/null_mutex.h deleted file mode 100755 index 1aa188fe027..00000000000 --- a/src/common/spdlog/details/null_mutex.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#include -#include -// null, no cost dummy "mutex" and dummy "atomic" int - -namespace spdlog { -namespace details { -struct null_mutex { - void lock() const {} - void unlock() const {} -}; - -struct null_atomic_int { - int value; - null_atomic_int() = default; - - explicit null_atomic_int(int new_value) : value(new_value) {} - - int load(std::memory_order = std::memory_order_relaxed) const { - return value; - } - - void store(int new_value, std::memory_order = std::memory_order_relaxed) { - value = new_value; - } - - int exchange(int new_value, std::memory_order = std::memory_order_relaxed) { - std::swap(new_value, value); - return new_value; // return value before the call - } -}; - -} // namespace details -} // namespace spdlog diff --git a/src/common/spdlog/details/os-inl.h b/src/common/spdlog/details/os-inl.h deleted file mode 100755 index 3cf5ad7d6d7..00000000000 --- a/src/common/spdlog/details/os-inl.h +++ /dev/null @@ -1,589 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. 
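The null_mutex deleted above is what lets spdlog compile one sink template into both a thread-safe "_mt" flavor and a zero-cost single-threaded "_st" flavor. A minimal sketch of the pattern with an illustrative counter type (not spdlog's API):

```cpp
#include <iostream>
#include <mutex>

// A no-op mutex satisfying the BasicLockable requirements.
struct null_mutex {
    void lock() const {}
    void unlock() const {}
};

// One template, two thread-safety flavors: counter<std::mutex> locks for
// real, counter<null_mutex> compiles down to plain unsynchronized code.
template <typename Mutex>
class counter {
public:
    void increment() {
        std::lock_guard<Mutex> lock(mutex_); // no-op when Mutex = null_mutex
        ++value_;
    }
    int value() const { return value_; }

private:
    Mutex mutex_;
    int value_ = 0;
};

int main() {
    counter<std::mutex> mt; // "_mt" flavor: real locking
    counter<null_mutex> st; // "_st" flavor: zero-cost locking
    mt.increment();
    st.increment();
    std::cout << mt.value() + st.value() << '\n'; // 2
}
```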
-// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef _WIN32 -#include // for FlushFileBuffers -#include // for _get_osfhandle, _isatty, _fileno -#include // for _get_pid -#include - -#ifdef __MINGW32__ -#include -#endif - -#if defined(SPDLOG_WCHAR_TO_UTF8_SUPPORT) || defined(SPDLOG_WCHAR_FILENAMES) -#include -#include -#endif - -#include // for _mkdir/_wmkdir - -#else // unix - -#include -#include - -#ifdef __linux__ -#include //Use gettid() syscall under linux to get thread id - -#elif defined(_AIX) -#include // for pthread_getthrds_np - -#elif defined(__DragonFly__) || defined(__FreeBSD__) -#include // for pthread_getthreadid_np - -#elif defined(__NetBSD__) -#include // for _lwp_self - -#elif defined(__sun) -#include // for thr_self -#endif - -#endif // unix - -#if defined __APPLE__ -#include -#endif - -#ifndef __has_feature // Clang - feature checking macros. -#define __has_feature(x) 0 // Compatibility with non-clang compilers. -#endif - -namespace spdlog { -namespace details { -namespace os { - -SPDLOG_INLINE spdlog::log_clock::time_point now() SPDLOG_NOEXCEPT { -#if defined __linux__ && defined SPDLOG_CLOCK_COARSE - timespec ts; - ::clock_gettime(CLOCK_REALTIME_COARSE, &ts); - return std::chrono::time_point( - std::chrono::duration_cast( - std::chrono::seconds(ts.tv_sec) - + std::chrono::nanoseconds(ts.tv_nsec))); - -#else - return log_clock::now(); -#endif -} -SPDLOG_INLINE std::tm localtime(const std::time_t &time_tt) SPDLOG_NOEXCEPT { -#ifdef _WIN32 - std::tm tm; - ::localtime_s(&tm, &time_tt); -#else - std::tm tm; - ::localtime_r(&time_tt, &tm); -#endif - return tm; -} - -SPDLOG_INLINE std::tm localtime() SPDLOG_NOEXCEPT { - std::time_t now_t = ::time(nullptr); - return localtime(now_t); -} - -SPDLOG_INLINE std::tm gmtime(const std::time_t &time_tt) SPDLOG_NOEXCEPT { -#ifdef _WIN32 - std::tm tm; - ::gmtime_s(&tm, &time_tt); -#else - std::tm tm; - ::gmtime_r(&time_tt, &tm); -#endif - return tm; -} - -SPDLOG_INLINE std::tm gmtime() SPDLOG_NOEXCEPT { - std::time_t now_t = ::time(nullptr); - return gmtime(now_t); -} - -// fopen_s on non windows for writing -SPDLOG_INLINE bool fopen_s( - FILE **fp, const filename_t &filename, const filename_t &mode) { -#ifdef _WIN32 -#ifdef SPDLOG_WCHAR_FILENAMES - *fp = ::_wfsopen((filename.c_str()), mode.c_str(), _SH_DENYNO); -#else - *fp = ::_fsopen((filename.c_str()), mode.c_str(), _SH_DENYNO); -#endif -#if defined(SPDLOG_PREVENT_CHILD_FD) - if (*fp != nullptr) { - auto file_handle - = reinterpret_cast(_get_osfhandle(::_fileno(*fp))); - if (!::SetHandleInformation(file_handle, HANDLE_FLAG_INHERIT, 0)) { - ::fclose(*fp); - *fp = nullptr; - } - } -#endif -#else // unix -#if defined(SPDLOG_PREVENT_CHILD_FD) - const int mode_flag = mode == SPDLOG_FILENAME_T("ab") ? 
O_APPEND : O_TRUNC; - const int fd = ::open((filename.c_str()), - O_CREAT | O_WRONLY | O_CLOEXEC | mode_flag, mode_t(0644)); - if (fd == -1) { return true; } - *fp = ::fdopen(fd, mode.c_str()); - if (*fp == nullptr) { ::close(fd); } -#else - *fp = ::fopen((filename.c_str()), mode.c_str()); -#endif -#endif - - return *fp == nullptr; -} - -SPDLOG_INLINE int remove(const filename_t &filename) SPDLOG_NOEXCEPT { -#if defined(_WIN32) && defined(SPDLOG_WCHAR_FILENAMES) - return ::_wremove(filename.c_str()); -#else - return std::remove(filename.c_str()); -#endif -} - -SPDLOG_INLINE int remove_if_exists(const filename_t &filename) SPDLOG_NOEXCEPT { - return path_exists(filename) ? remove(filename) : 0; -} - -SPDLOG_INLINE int rename(const filename_t &filename1, - const filename_t &filename2) SPDLOG_NOEXCEPT { -#if defined(_WIN32) && defined(SPDLOG_WCHAR_FILENAMES) - return ::_wrename(filename1.c_str(), filename2.c_str()); -#else - return std::rename(filename1.c_str(), filename2.c_str()); -#endif -} - -// Return true if path exists (file or directory) -SPDLOG_INLINE bool path_exists(const filename_t &filename) SPDLOG_NOEXCEPT { -#ifdef _WIN32 - struct _stat buffer; -#ifdef SPDLOG_WCHAR_FILENAMES - return (::_wstat(filename.c_str(), &buffer) == 0); -#else - return (::_stat(filename.c_str(), &buffer) == 0); -#endif -#else // common linux/unix all have the stat system call - struct stat buffer; - return (::stat(filename.c_str(), &buffer) == 0); -#endif -} - -#ifdef _MSC_VER -// avoid warning about unreachable statement at the end of filesize() -#pragma warning(push) -#pragma warning(disable : 4702) -#endif - -// Return file size according to open FILE* object -SPDLOG_INLINE size_t filesize(FILE *f) { - if (f == nullptr) { - throw_spdlog_ex("Failed getting file size. fd is null"); - } -#if defined(_WIN32) && !defined(__CYGWIN__) - int fd = ::_fileno(f); -#if defined(_WIN64) // 64 bits - __int64 ret = ::_filelengthi64(fd); - if (ret >= 0) { return static_cast(ret); } - -#else // windows 32 bits - long ret = ::_filelength(fd); - if (ret >= 0) { return static_cast(ret); } -#endif - -#else // unix -// OpenBSD and AIX doesn't compile with :: before the fileno(..) -#if defined(__OpenBSD__) || defined(_AIX) - int fd = fileno(f); -#else - int fd = ::fileno(f); -#endif -// 64 bits(but not in osx, linux/musl or cygwin, where fstat64 is deprecated) -#if ((defined(__linux__) && defined(__GLIBC__)) || defined(__sun) \ - || defined(_AIX)) \ - && (defined(__LP64__) || defined(_LP64)) - struct stat64 st; - if (::fstat64(fd, &st) == 0) { return static_cast(st.st_size); } -#else // other unix or linux 32 bits or cygwin - struct stat st; - if (::fstat(fd, &st) == 0) { return static_cast(st.st_size); } -#endif -#endif - throw_spdlog_ex("Failed getting file size from fd", errno); - return 0; // will not be reached. -} - -#ifdef _MSC_VER -#pragma warning(pop) -#endif - -// Return utc offset in minutes or throw spdlog_ex on failure -SPDLOG_INLINE int utc_minutes_offset(const std::tm &tm) { -#ifdef _WIN32 -#if _WIN32_WINNT < _WIN32_WINNT_WS08 - TIME_ZONE_INFORMATION tzinfo; - auto rv = ::GetTimeZoneInformation(&tzinfo); -#else - DYNAMIC_TIME_ZONE_INFORMATION tzinfo; - auto rv = ::GetDynamicTimeZoneInformation(&tzinfo); -#endif - if (rv == TIME_ZONE_ID_INVALID) - throw_spdlog_ex("Failed getting timezone info. 
", errno); - - int offset = -tzinfo.Bias; - if (tm.tm_isdst) { - offset -= tzinfo.DaylightBias; - } else { - offset -= tzinfo.StandardBias; - } - return offset; -#else - -#if defined(sun) || defined(__sun) || defined(_AIX) \ - || (defined(__NEWLIB__) && !defined(__TM_GMTOFF)) \ - || (!defined(_BSD_SOURCE) && !defined(_GNU_SOURCE)) - // 'tm_gmtoff' field is BSD extension and it's missing on SunOS/Solaris - struct helper { - static long int calculate_gmt_offset( - const std::tm &localtm = details::os::localtime(), - const std::tm &gmtm = details::os::gmtime()) { - int local_year = localtm.tm_year + (1900 - 1); - int gmt_year = gmtm.tm_year + (1900 - 1); - - long int days = ( - // difference in day of year - localtm.tm_yday - - gmtm.tm_yday - - // + intervening leap days - + ((local_year >> 2) - (gmt_year >> 2)) - - (local_year / 100 - gmt_year / 100) - + ((local_year / 100 >> 2) - (gmt_year / 100 >> 2)) - - // + difference in years * 365 */ - + static_cast(local_year - gmt_year) * 365); - - long int hours = (24 * days) + (localtm.tm_hour - gmtm.tm_hour); - long int mins = (60 * hours) + (localtm.tm_min - gmtm.tm_min); - long int secs = (60 * mins) + (localtm.tm_sec - gmtm.tm_sec); - - return secs; - } - }; - - auto offset_seconds = helper::calculate_gmt_offset(tm); -#else - auto offset_seconds = tm.tm_gmtoff; -#endif - - return static_cast(offset_seconds / 60); -#endif -} - -// Return current thread id as size_t -// It exists because the std::this_thread::get_id() is much slower(especially -// under VS 2013) -SPDLOG_INLINE size_t _thread_id() SPDLOG_NOEXCEPT { -#ifdef _WIN32 - return static_cast(::GetCurrentThreadId()); -#elif defined(__linux__) -#if defined(__ANDROID__) && defined(__ANDROID_API__) && (__ANDROID_API__ < 21) -#define SYS_gettid __NR_gettid -#endif - return static_cast(::syscall(SYS_gettid)); -#elif defined(_AIX) - struct __pthrdsinfo buf; - int reg_size = 0; - pthread_t pt = pthread_self(); - int retval = pthread_getthrds_np( - &pt, PTHRDSINFO_QUERY_TID, &buf, sizeof(buf), NULL, ®_size); - int tid = (!retval) ? buf.__pi_tid : 0; - return static_cast(tid); -#elif defined(__DragonFly__) || defined(__FreeBSD__) - return static_cast(::pthread_getthreadid_np()); -#elif defined(__NetBSD__) - return static_cast(::_lwp_self()); -#elif defined(__OpenBSD__) - return static_cast(::getthrid()); -#elif defined(__sun) - return static_cast(::thr_self()); -#elif __APPLE__ - uint64_t tid; -// There is no pthread_threadid_np prior to Mac OS X 10.6, and it is not supported on any PPC, -// including 10.6.8 Rosetta. __POWERPC__ is Apple-specific define encompassing ppc and ppc64. 
-#ifdef MAC_OS_X_VERSION_MAX_ALLOWED - { -#if (MAC_OS_X_VERSION_MAX_ALLOWED < 1060) || defined(__POWERPC__) - tid = pthread_mach_thread_np(pthread_self()); -#elif MAC_OS_X_VERSION_MIN_REQUIRED < 1060 - if (&pthread_threadid_np) { - pthread_threadid_np(nullptr, &tid); - } else { - tid = pthread_mach_thread_np(pthread_self()); - } -#else - pthread_threadid_np(nullptr, &tid); -#endif - } -#else - pthread_threadid_np(nullptr, &tid); -#endif - return static_cast(tid); -#else // Default to standard C++11 (other Unix) - return static_cast( - std::hash()(std::this_thread::get_id())); -#endif -} - -// Return current thread id as size_t (from thread local storage) -SPDLOG_INLINE size_t thread_id() SPDLOG_NOEXCEPT { -#if defined(SPDLOG_NO_TLS) - return _thread_id(); -#else // cache thread id in tls - static thread_local const size_t tid = _thread_id(); - return tid; -#endif -} - -// This is avoid msvc issue in sleep_for that happens if the clock changes. -// See https://github.com/gabime/spdlog/issues/609 -SPDLOG_INLINE void sleep_for_millis(unsigned int milliseconds) SPDLOG_NOEXCEPT { -#if defined(_WIN32) - ::Sleep(milliseconds); -#else - std::this_thread::sleep_for(std::chrono::milliseconds(milliseconds)); -#endif -} - -// wchar support for windows file names (SPDLOG_WCHAR_FILENAMES must be defined) -#if defined(_WIN32) && defined(SPDLOG_WCHAR_FILENAMES) -SPDLOG_INLINE std::string filename_to_str(const filename_t &filename) { - memory_buf_t buf; - wstr_to_utf8buf(filename, buf); - return SPDLOG_BUF_TO_STRING(buf); -} -#else -SPDLOG_INLINE std::string filename_to_str(const filename_t &filename) { - return filename; -} -#endif - -SPDLOG_INLINE int pid() SPDLOG_NOEXCEPT { -#ifdef _WIN32 - return conditional_static_cast(::GetCurrentProcessId()); -#else - return conditional_static_cast(::getpid()); -#endif -} - -// Determine if the terminal supports colors -// Based on: https://github.com/agauniyal/rang/ -SPDLOG_INLINE bool is_color_terminal() SPDLOG_NOEXCEPT { -#ifdef _WIN32 - return true; -#else - - static const bool result = []() { - const char *env_colorterm_p = std::getenv("COLORTERM"); - if (env_colorterm_p != nullptr) { return true; } - - static constexpr std::array terms - = {{"ansi", "color", "console", "cygwin", "gnome", "konsole", - "kterm", "linux", "msys", "putty", "rxvt", "screen", - "vt100", "xterm", "alacritty", "vt102"}}; - - const char *env_term_p = std::getenv("TERM"); - if (env_term_p == nullptr) { return false; } - - return std::any_of(terms.begin(), terms.end(), [&](const char *term) { - return std::strstr(env_term_p, term) != nullptr; - }); - }(); - - return result; -#endif -} - -// Determine if the terminal attached -// Source: https://github.com/agauniyal/rang/ -SPDLOG_INLINE bool in_terminal(FILE *file) SPDLOG_NOEXCEPT { -#ifdef _WIN32 - return ::_isatty(_fileno(file)) != 0; -#else - return ::isatty(fileno(file)) != 0; -#endif -} - -#if (defined(SPDLOG_WCHAR_TO_UTF8_SUPPORT) || defined(SPDLOG_WCHAR_FILENAMES)) \ - && defined(_WIN32) -SPDLOG_INLINE void wstr_to_utf8buf(wstring_view_t wstr, memory_buf_t &target) { - if (wstr.size() - > static_cast((std::numeric_limits::max)()) / 4 - 1) { - throw_spdlog_ex("UTF-16 string is too big to be converted to UTF-8"); - } - - int wstr_size = static_cast(wstr.size()); - if (wstr_size == 0) { - target.resize(0); - return; - } - - int result_size = static_cast(target.capacity()); - if ((wstr_size + 1) * 4 > result_size) { - result_size = ::WideCharToMultiByte( - CP_UTF8, 0, wstr.data(), wstr_size, NULL, 0, NULL, NULL); - } - - if (result_size 
> 0) { - target.resize(result_size); - result_size = ::WideCharToMultiByte(CP_UTF8, 0, wstr.data(), wstr_size, - target.data(), result_size, NULL, NULL); - - if (result_size > 0) { - target.resize(result_size); - return; - } - } - - throw_spdlog_ex(fmt_lib::format( - "WideCharToMultiByte failed. Last error: {}", ::GetLastError())); -} - -SPDLOG_INLINE void utf8_to_wstrbuf(string_view_t str, wmemory_buf_t &target) { - if (str.size() - > static_cast((std::numeric_limits::max)()) - 1) { - throw_spdlog_ex("UTF-8 string is too big to be converted to UTF-16"); - } - - int str_size = static_cast(str.size()); - if (str_size == 0) { - target.resize(0); - return; - } - - // find the size to allocate for the result buffer - int result_size = ::MultiByteToWideChar( - CP_UTF8, MB_ERR_INVALID_CHARS, str.data(), str_size, NULL, 0); - - if (result_size > 0) { - target.resize(result_size); - result_size = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, - str.data(), str_size, target.data(), result_size); - if (result_size > 0) { - assert(result_size == target.size()); - return; - } - } - - throw_spdlog_ex(fmt_lib::format( - "MultiByteToWideChar failed. Last error: {}", ::GetLastError())); -} -#endif - -// return true on success -static SPDLOG_INLINE bool mkdir_(const filename_t &path) { -#ifdef _WIN32 -#ifdef SPDLOG_WCHAR_FILENAMES - return ::_wmkdir(path.c_str()) == 0; -#else - return ::_mkdir(path.c_str()) == 0; -#endif -#else - return ::mkdir(path.c_str(), mode_t(0755)) == 0; -#endif -} - -// create the given directory - and all directories leading to it return true on success or if the directory already exists -SPDLOG_INLINE bool create_dir(const filename_t &path) { - if (path_exists(path)) { return true; } - - if (path.empty()) { return false; } - - size_t search_offset = 0; - do { - auto token_pos - = path.find_first_of(folder_seps_filename, search_offset); - // treat the entire path as a folder if no folder separator not found - if (token_pos == filename_t::npos) { token_pos = path.size(); } - - auto subdir = path.substr(0, token_pos); -#ifdef _WIN32 - // if subdir is just a drive letter, add a slash e.g. "c:"=>"c:\", - // otherwise path_exists(subdir) returns false (issue #3079) - const bool is_drive = subdir.length() == 2 && subdir[1] == ':'; - if (is_drive) { - subdir += '\\'; - token_pos++; - } -#endif - - if (!subdir.empty() && !path_exists(subdir) && !mkdir_(subdir)) { - return false; // return error if failed creating dir - } - search_offset = token_pos + 1; - } while (search_offset < path.size()); - - return true; -} - -// Return directory name from given path or empty string -// "abc/file" => "abc" -// "abc/" => "abc" -// "abc" => "" -// "abc///" => "abc//" -SPDLOG_INLINE filename_t dir_name(const filename_t &path) { - auto pos = path.find_last_of(folder_seps_filename); - return pos != filename_t::npos ? path.substr(0, pos) : filename_t {}; -} - -std::string SPDLOG_INLINE getenv(const char *field) { -#if defined(_MSC_VER) -#if defined(__cplusplus_winrt) - return std::string {}; // not supported under uwp -#else - size_t len = 0; - char buf[128]; - bool ok = ::getenv_s(&len, buf, sizeof(buf), field) == 0; - return ok ? buf : std::string {}; -#endif -#else // revert to getenv - char *buf = ::getenv(field); - return buf ? 
buf : std::string {}; -#endif -} - -// Do fsync by FILE handlerpointer -// Return true on success -SPDLOG_INLINE bool fsync(FILE *fp) { -#ifdef _WIN32 - return FlushFileBuffers( - reinterpret_cast(_get_osfhandle(_fileno(fp)))) - != 0; -#else - return ::fsync(fileno(fp)) == 0; -#endif -} - -} // namespace os -} // namespace details -} // namespace spdlog diff --git a/src/common/spdlog/details/registry-inl.h b/src/common/spdlog/details/registry-inl.h deleted file mode 100755 index b34c9cf07cf..00000000000 --- a/src/common/spdlog/details/registry-inl.h +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -#include -#include -#include -#include - -#ifndef SPDLOG_DISABLE_DEFAULT_LOGGER -// support for the default stdout color logger -#ifdef _WIN32 -#include -#else -#include -#endif -#endif // SPDLOG_DISABLE_DEFAULT_LOGGER - -#include -#include -#include -#include -#include - -namespace spdlog { -namespace details { - -SPDLOG_INLINE registry::registry() : formatter_(new pattern_formatter()) { -#ifndef SPDLOG_DISABLE_DEFAULT_LOGGER -// create default logger (ansicolor_stdout_sink_mt or wincolor_stdout_sink_mt in windows). -#ifdef _WIN32 - auto color_sink = std::make_shared(); -#else - auto color_sink = std::make_shared(); -#endif - - const char *default_logger_name = ""; - default_logger_ = std::make_shared( - default_logger_name, std::move(color_sink)); - loggers_[default_logger_name] = default_logger_; - -#endif // SPDLOG_DISABLE_DEFAULT_LOGGER -} - -SPDLOG_INLINE registry::~registry() = default; - -SPDLOG_INLINE void registry::register_logger( - std::shared_ptr new_logger) { - std::lock_guard lock(logger_map_mutex_); - register_logger_(std::move(new_logger)); -} - -SPDLOG_INLINE void registry::initialize_logger( - std::shared_ptr new_logger) { - std::lock_guard lock(logger_map_mutex_); - new_logger->set_formatter(formatter_->clone()); - - if (err_handler_) { new_logger->set_error_handler(err_handler_); } - - // set new level according to previously configured level or default level - auto it = log_levels_.find(new_logger->name()); - auto new_level = it != log_levels_.end() ? it->second : global_log_level_; - new_logger->set_level(new_level); - - new_logger->flush_on(flush_level_); - - if (backtrace_n_messages_ > 0) { - new_logger->enable_backtrace(backtrace_n_messages_); - } - - if (automatic_registration_) { register_logger_(std::move(new_logger)); } -} - -SPDLOG_INLINE std::shared_ptr registry::get( - const std::string &logger_name) { - std::lock_guard lock(logger_map_mutex_); - auto found = loggers_.find(logger_name); - return found == loggers_.end() ? nullptr : found->second; -} - -SPDLOG_INLINE std::shared_ptr registry::default_logger() { - std::lock_guard lock(logger_map_mutex_); - return default_logger_; -} - -// Return raw ptr to the default logger. -// To be used directly by the spdlog default api (e.g. spdlog::info) -// This make the default API faster, but cannot be used concurrently with set_default_logger(). -// e.g do not call set_default_logger() from one thread while calling spdlog::info() from another. -SPDLOG_INLINE logger *registry::get_default_raw() { - return default_logger_.get(); -} - -// set default logger. -// default logger is stored in default_logger_ (for faster retrieval) and in the loggers_ map. 
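The registry being deleted here is a process-wide, mutex-guarded map of named loggers behind a function-local static. A minimal sketch of that pattern with an illustrative stand-in logger type (not spdlog's actual interface):

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>

struct logger {
    explicit logger(std::string n) : name(std::move(n)) {}
    std::string name;
};

class registry {
public:
    static registry &instance() {
        static registry s_instance; // constructed once; thread-safe since C++11
        return s_instance;
    }

    void register_logger(std::shared_ptr<logger> l) {
        std::lock_guard<std::mutex> lock(mutex_);
        loggers_[l->name] = std::move(l);
    }

    std::shared_ptr<logger> get(const std::string &name) {
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = loggers_.find(name);
        return it == loggers_.end() ? nullptr : it->second;
    }

private:
    registry() = default;
    std::mutex mutex_;
    std::map<std::string, std::shared_ptr<logger>> loggers_;
};

int main() {
    registry::instance().register_logger(std::make_shared<logger>("app"));
    std::cout << (registry::instance().get("app") != nullptr) << '\n'; // 1
}
```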
-SPDLOG_INLINE void registry::set_default_logger( - std::shared_ptr new_default_logger) { - std::lock_guard lock(logger_map_mutex_); - if (new_default_logger != nullptr) { - loggers_[new_default_logger->name()] = new_default_logger; - } - default_logger_ = std::move(new_default_logger); -} - -SPDLOG_INLINE void registry::set_tp(std::shared_ptr tp) { - std::lock_guard lock(tp_mutex_); - tp_ = std::move(tp); -} - -SPDLOG_INLINE std::shared_ptr registry::get_tp() { - std::lock_guard lock(tp_mutex_); - return tp_; -} - -// Set global formatter. Each sink in each logger will get a clone of this object -SPDLOG_INLINE void registry::set_formatter( - std::unique_ptr formatter) { - std::lock_guard lock(logger_map_mutex_); - formatter_ = std::move(formatter); - for (auto &l : loggers_) { - l.second->set_formatter(formatter_->clone()); - } -} - -SPDLOG_INLINE void registry::enable_backtrace(size_t n_messages) { - std::lock_guard lock(logger_map_mutex_); - backtrace_n_messages_ = n_messages; - - for (auto &l : loggers_) { - l.second->enable_backtrace(n_messages); - } -} - -SPDLOG_INLINE void registry::disable_backtrace() { - std::lock_guard lock(logger_map_mutex_); - backtrace_n_messages_ = 0; - for (auto &l : loggers_) { - l.second->disable_backtrace(); - } -} - -SPDLOG_INLINE void registry::set_level(level::level_enum log_level) { - std::lock_guard lock(logger_map_mutex_); - for (auto &l : loggers_) { - l.second->set_level(log_level); - } - global_log_level_ = log_level; -} - -SPDLOG_INLINE void registry::flush_on(level::level_enum log_level) { - std::lock_guard lock(logger_map_mutex_); - for (auto &l : loggers_) { - l.second->flush_on(log_level); - } - flush_level_ = log_level; -} - -SPDLOG_INLINE void registry::set_error_handler(err_handler handler) { - std::lock_guard lock(logger_map_mutex_); - for (auto &l : loggers_) { - l.second->set_error_handler(handler); - } - err_handler_ = std::move(handler); -} - -SPDLOG_INLINE void registry::apply_all( - const std::function)> &fun) { - std::lock_guard lock(logger_map_mutex_); - for (auto &l : loggers_) { - fun(l.second); - } -} - -SPDLOG_INLINE void registry::flush_all() { - std::lock_guard lock(logger_map_mutex_); - for (auto &l : loggers_) { - l.second->flush(); - } -} - -SPDLOG_INLINE void registry::drop(const std::string &logger_name) { - std::lock_guard lock(logger_map_mutex_); - auto is_default_logger - = default_logger_ && default_logger_->name() == logger_name; - loggers_.erase(logger_name); - if (is_default_logger) { default_logger_.reset(); } -} - -SPDLOG_INLINE void registry::drop_all() { - std::lock_guard lock(logger_map_mutex_); - loggers_.clear(); - default_logger_.reset(); -} - -// clean all resources and threads started by the registry -SPDLOG_INLINE void registry::shutdown() { - { - std::lock_guard lock(flusher_mutex_); - periodic_flusher_.reset(); - } - - drop_all(); - - { - std::lock_guard lock(tp_mutex_); - tp_.reset(); - } -} - -SPDLOG_INLINE std::recursive_mutex ®istry::tp_mutex() { - return tp_mutex_; -} - -SPDLOG_INLINE void registry::set_automatic_registration( - bool automatic_registration) { - std::lock_guard lock(logger_map_mutex_); - automatic_registration_ = automatic_registration; -} - -SPDLOG_INLINE void registry::set_levels( - log_levels levels, level::level_enum *global_level) { - std::lock_guard lock(logger_map_mutex_); - log_levels_ = std::move(levels); - auto global_level_requested = global_level != nullptr; - global_log_level_ - = global_level_requested ? 
*global_level : global_log_level_; - - for (auto &logger : loggers_) { - auto logger_entry = log_levels_.find(logger.first); - if (logger_entry != log_levels_.end()) { - logger.second->set_level(logger_entry->second); - } else if (global_level_requested) { - logger.second->set_level(*global_level); - } - } -} - -SPDLOG_INLINE registry ®istry::instance() { - static registry s_instance; - return s_instance; -} - -SPDLOG_INLINE void registry::apply_logger_env_levels( - std::shared_ptr new_logger) { - std::lock_guard lock(logger_map_mutex_); - auto it = log_levels_.find(new_logger->name()); - auto new_level = it != log_levels_.end() ? it->second : global_log_level_; - new_logger->set_level(new_level); -} - -SPDLOG_INLINE void registry::throw_if_exists_(const std::string &logger_name) { - if (loggers_.find(logger_name) != loggers_.end()) { - throw_spdlog_ex( - "logger with name '" + logger_name + "' already exists"); - } -} - -SPDLOG_INLINE void registry::register_logger_( - std::shared_ptr new_logger) { - auto logger_name = new_logger->name(); - throw_if_exists_(logger_name); - loggers_[logger_name] = std::move(new_logger); -} - -} // namespace details -} // namespace spdlog diff --git a/src/common/spdlog/details/synchronous_factory.h b/src/common/spdlog/details/synchronous_factory.h deleted file mode 100755 index dbe67d72d45..00000000000 --- a/src/common/spdlog/details/synchronous_factory.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#include "registry.h" - -namespace spdlog { - -// Default logger factory- creates synchronous loggers -class logger; - -struct synchronous_factory { - template - static std::shared_ptr create( - std::string logger_name, SinkArgs &&...args) { - auto sink = std::make_shared(std::forward(args)...); - auto new_logger = std::make_shared( - std::move(logger_name), std::move(sink)); - details::registry::instance().initialize_logger(new_logger); - return new_logger; - } -}; -} // namespace spdlog diff --git a/src/common/spdlog/details/windows_include.h b/src/common/spdlog/details/windows_include.h deleted file mode 100755 index 6a2f14f9c76..00000000000 --- a/src/common/spdlog/details/windows_include.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#ifndef NOMINMAX -#define NOMINMAX // prevent windows redefining min/max -#endif - -#ifndef WIN32_LEAN_AND_MEAN -#define WIN32_LEAN_AND_MEAN -#endif - -#include diff --git a/src/common/spdlog/fmt/bundled/core.h b/src/common/spdlog/fmt/bundled/core.h deleted file mode 100755 index 26686495b67..00000000000 --- a/src/common/spdlog/fmt/bundled/core.h +++ /dev/null @@ -1,3059 +0,0 @@ -// Formatting library for C++ - the core API for char/UTF-8 -// -// Copyright (c) 2012 - present, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_CORE_H_ -#define FMT_CORE_H_ - -#include // std::byte -#include // std::FILE -#include // std::strlen -#include -#include -#include // std::addressof -#include -#include - -// The fmt library version in the form major * 10000 + minor * 100 + patch. 
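The synchronous_factory deleted above is a small factory template: it perfect-forwards its arguments to a sink constructor and wraps the sink in a logger. A self-contained sketch of that shape, with hypothetical logger/console_sink stand-ins rather than spdlog's types:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <utility>

struct logger {
    std::string name;
    std::shared_ptr<void> sink; // type-erased for the sketch
};

struct console_sink {
    explicit console_sink(std::string prefix) : prefix(std::move(prefix)) {}
    std::string prefix;
};

struct synchronous_factory {
    // Forward arbitrary constructor arguments to the chosen sink type.
    template <typename Sink, typename... SinkArgs>
    static std::shared_ptr<logger> create(std::string logger_name,
            SinkArgs &&...args) {
        auto sink = std::make_shared<Sink>(std::forward<SinkArgs>(args)...);
        return std::make_shared<logger>(
                logger {std::move(logger_name), std::move(sink)});
    }
};

int main() {
    auto log = synchronous_factory::create<console_sink>("app", "[app] ");
    std::cout << log->name << '\n'; // "app"
}
```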
-#define FMT_VERSION 100201 - -#if defined(__clang__) && !defined(__ibmxl__) -#define FMT_CLANG_VERSION (__clang_major__ * 100 + __clang_minor__) -#else -#define FMT_CLANG_VERSION 0 -#endif - -#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER) \ - && !defined(__NVCOMPILER) -#define FMT_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) -#else -#define FMT_GCC_VERSION 0 -#endif - -#ifndef FMT_GCC_PRAGMA -// Workaround _Pragma bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=59884. -#if FMT_GCC_VERSION >= 504 -#define FMT_GCC_PRAGMA(arg) _Pragma(arg) -#else -#define FMT_GCC_PRAGMA(arg) -#endif -#endif - -#ifdef __ICL -#define FMT_ICC_VERSION __ICL -#elif defined(__INTEL_COMPILER) -#define FMT_ICC_VERSION __INTEL_COMPILER -#else -#define FMT_ICC_VERSION 0 -#endif - -#ifdef _MSC_VER -#define FMT_MSC_VERSION _MSC_VER -#define FMT_MSC_WARNING(...) __pragma(warning(__VA_ARGS__)) -#else -#define FMT_MSC_VERSION 0 -#define FMT_MSC_WARNING(...) -#endif - -#ifdef _MSVC_LANG -#define FMT_CPLUSPLUS _MSVC_LANG -#else -#define FMT_CPLUSPLUS __cplusplus -#endif - -#ifdef __has_feature -#define FMT_HAS_FEATURE(x) __has_feature(x) -#else -#define FMT_HAS_FEATURE(x) 0 -#endif - -#if defined(__has_include) || FMT_ICC_VERSION >= 1600 || FMT_MSC_VERSION > 1900 -#define FMT_HAS_INCLUDE(x) __has_include(x) -#else -#define FMT_HAS_INCLUDE(x) 0 -#endif - -#ifdef __has_cpp_attribute -#define FMT_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) -#else -#define FMT_HAS_CPP_ATTRIBUTE(x) 0 -#endif - -#define FMT_HAS_CPP14_ATTRIBUTE(attribute) \ - (FMT_CPLUSPLUS >= 201402L && FMT_HAS_CPP_ATTRIBUTE(attribute)) - -#define FMT_HAS_CPP17_ATTRIBUTE(attribute) \ - (FMT_CPLUSPLUS >= 201703L && FMT_HAS_CPP_ATTRIBUTE(attribute)) - -// Check if relaxed C++14 constexpr is supported. -// GCC doesn't allow throw in constexpr until version 6 (bug 67371). -#ifndef FMT_USE_CONSTEXPR -#if (FMT_HAS_FEATURE(cxx_relaxed_constexpr) || FMT_MSC_VERSION >= 1912 \ - || (FMT_GCC_VERSION >= 600 && FMT_CPLUSPLUS >= 201402L)) \ - && !FMT_ICC_VERSION \ - && (!defined(__NVCC__) || FMT_CPLUSPLUS >= 202002L) -#define FMT_USE_CONSTEXPR 1 -#else -#define FMT_USE_CONSTEXPR 0 -#endif -#endif -#if FMT_USE_CONSTEXPR -#define FMT_CONSTEXPR constexpr -#else -#define FMT_CONSTEXPR -#endif - -#if (FMT_CPLUSPLUS >= 202002L \ - || (FMT_CPLUSPLUS >= 201709L && FMT_GCC_VERSION >= 1002)) \ - && ((!defined(_GLIBCXX_RELEASE) || _GLIBCXX_RELEASE >= 10) \ - && (!defined(_LIBCPP_VERSION) || _LIBCPP_VERSION >= 10000) \ - && (!FMT_MSC_VERSION || FMT_MSC_VERSION >= 1928)) \ - && defined(__cpp_lib_is_constant_evaluated) -#define FMT_CONSTEXPR20 constexpr -#else -#define FMT_CONSTEXPR20 -#endif - -// Check if constexpr std::char_traits<>::{compare,length} are supported. -#if defined(__GLIBCXX__) -#if FMT_CPLUSPLUS >= 201703L && defined(_GLIBCXX_RELEASE) \ - && _GLIBCXX_RELEASE >= 7 // GCC 7+ libstdc++ has _GLIBCXX_RELEASE. -#define FMT_CONSTEXPR_CHAR_TRAITS constexpr -#endif -#elif defined(_LIBCPP_VERSION) && FMT_CPLUSPLUS >= 201703L \ - && _LIBCPP_VERSION >= 4000 -#define FMT_CONSTEXPR_CHAR_TRAITS constexpr -#elif FMT_MSC_VERSION >= 1914 && FMT_CPLUSPLUS >= 201703L -#define FMT_CONSTEXPR_CHAR_TRAITS constexpr -#endif -#ifndef FMT_CONSTEXPR_CHAR_TRAITS -#define FMT_CONSTEXPR_CHAR_TRAITS -#endif - -// Check if exceptions are disabled. 
-#ifndef FMT_EXCEPTIONS -#if (defined(__GNUC__) && !defined(__EXCEPTIONS)) \ - || (FMT_MSC_VERSION && !_HAS_EXCEPTIONS) -#define FMT_EXCEPTIONS 0 -#else -#define FMT_EXCEPTIONS 1 -#endif -#endif - -// Disable [[noreturn]] on MSVC/NVCC because of bogus unreachable code warnings. -#if FMT_EXCEPTIONS && FMT_HAS_CPP_ATTRIBUTE(noreturn) && !FMT_MSC_VERSION \ - && !defined(__NVCC__) -#define FMT_NORETURN [[noreturn]] -#else -#define FMT_NORETURN -#endif - -#ifndef FMT_NODISCARD -#if FMT_HAS_CPP17_ATTRIBUTE(nodiscard) -#define FMT_NODISCARD [[nodiscard]] -#else -#define FMT_NODISCARD -#endif -#endif - -#ifndef FMT_INLINE -#if FMT_GCC_VERSION || FMT_CLANG_VERSION -#define FMT_INLINE inline __attribute__((always_inline)) -#else -#define FMT_INLINE inline -#endif -#endif - -#ifdef _MSC_VER -#define FMT_UNCHECKED_ITERATOR(It) \ - using _Unchecked_type = It // Mark iterator as checked. -#else -#define FMT_UNCHECKED_ITERATOR(It) using unchecked_type = It -#endif - -#ifndef FMT_BEGIN_NAMESPACE -#define FMT_BEGIN_NAMESPACE \ - namespace fmt { \ - inline namespace v10 { -#define FMT_END_NAMESPACE \ - } \ - } -#endif - -#ifndef FMT_EXPORT -#define FMT_EXPORT -#define FMT_BEGIN_EXPORT -#define FMT_END_EXPORT -#endif - -#if FMT_GCC_VERSION || FMT_CLANG_VERSION -#define FMT_VISIBILITY(value) __attribute__((visibility(value))) -#else -#define FMT_VISIBILITY(value) -#endif - -#if !defined(FMT_HEADER_ONLY) && defined(_WIN32) -#if defined(FMT_LIB_EXPORT) -#define FMT_API __declspec(dllexport) -#elif defined(FMT_SHARED) -#define FMT_API __declspec(dllimport) -#endif -#elif defined(FMT_LIB_EXPORT) || defined(FMT_SHARED) -#define FMT_API FMT_VISIBILITY("default") -#endif -#ifndef FMT_API -#define FMT_API -#endif - -// libc++ supports string_view in pre-c++17. -#if FMT_HAS_INCLUDE() \ - && (FMT_CPLUSPLUS >= 201703L || defined(_LIBCPP_VERSION)) -#include -#define FMT_USE_STRING_VIEW -#elif FMT_HAS_INCLUDE("experimental/string_view") && FMT_CPLUSPLUS >= 201402L -#include -#define FMT_USE_EXPERIMENTAL_STRING_VIEW -#endif - -#ifndef FMT_UNICODE -#define FMT_UNICODE !FMT_MSC_VERSION -#endif - -#ifndef FMT_CONSTEVAL -#if ((FMT_GCC_VERSION >= 1000 || FMT_CLANG_VERSION >= 1101) \ - && (!defined(__apple_build_version__) \ - || __apple_build_version__ >= 14000029L) \ - && FMT_CPLUSPLUS >= 202002L) \ - || (defined(__cpp_consteval) \ - && (!FMT_MSC_VERSION || FMT_MSC_VERSION >= 1929)) -// consteval is broken in MSVC before VS2019 version 16.10 and Apple clang -// before 14. -#define FMT_CONSTEVAL consteval -#define FMT_HAS_CONSTEVAL -#else -#define FMT_CONSTEVAL -#endif -#endif - -#ifndef FMT_USE_NONTYPE_TEMPLATE_ARGS -#if defined(__cpp_nontype_template_args) \ - && ((FMT_GCC_VERSION >= 903 && FMT_CPLUSPLUS >= 201709L) \ - || __cpp_nontype_template_args >= 201911L) \ - && !defined(__NVCOMPILER) && !defined(__LCC__) -#define FMT_USE_NONTYPE_TEMPLATE_ARGS 1 -#else -#define FMT_USE_NONTYPE_TEMPLATE_ARGS 0 -#endif -#endif - -// GCC < 5 requires this-> in decltype -#ifndef FMT_DECLTYPE_THIS -#if FMT_GCC_VERSION && FMT_GCC_VERSION < 500 -#define FMT_DECLTYPE_THIS this-> -#else -#define FMT_DECLTYPE_THIS -#endif -#endif - -// Enable minimal optimizations for more compact code in debug mode. -FMT_GCC_PRAGMA("GCC push_options") -#if !defined(__OPTIMIZE__) && !defined(__NVCOMPILER) && !defined(__LCC__) \ - && !defined(__CUDACC__) -FMT_GCC_PRAGMA("GCC optimize(\"Og\")") -#endif - -FMT_BEGIN_NAMESPACE - -// Implementations of enable_if_t and other metafunctions for older systems. 
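The deleted fmt header applies one pattern throughout this block: probe the compiler or standard library first, then define a single project macro that the rest of the code uses unconditionally. A minimal standalone example of that detection pattern (macro names here are illustrative, not fmt's):

```cpp
#include <cstdio>

// Step 1: make the probe itself safe on compilers that lack it.
#if defined(__has_cpp_attribute)
#define MY_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x)
#else
#define MY_HAS_CPP_ATTRIBUTE(x) 0
#endif

// Step 2: define one macro the rest of the code can use everywhere.
#if MY_HAS_CPP_ATTRIBUTE(nodiscard)
#define MY_NODISCARD [[nodiscard]]
#else
#define MY_NODISCARD
#endif

MY_NODISCARD int answer() {
    return 42;
}

int main() {
    std::printf("%d\n", answer()); // using the result satisfies [[nodiscard]]
}
```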
-template -using enable_if_t = typename std::enable_if::type; -template -using conditional_t = typename std::conditional::type; -template -using bool_constant = std::integral_constant; -template -using remove_reference_t = typename std::remove_reference::type; -template -using remove_const_t = typename std::remove_const::type; -template -using remove_cvref_t = typename std::remove_cv>::type; -template -struct type_identity { - using type = T; -}; -template -using type_identity_t = typename type_identity::type; -template -using underlying_t = typename std::underlying_type::type; - -// Checks whether T is a container with contiguous storage. -template -struct is_contiguous : std::false_type {}; -template -struct is_contiguous> : std::true_type {}; - -struct monostate { - constexpr monostate() {} -}; - -// An enable_if helper to be used in template parameters which results in much -// shorter symbols: https://godbolt.org/z/sWw4vP. Extra parentheses are needed -// to workaround a bug in MSVC 2019 (see #1140 and #1186). -#ifdef FMT_DOC -#define FMT_ENABLE_IF(...) -#else -#define FMT_ENABLE_IF(...) fmt::enable_if_t<(__VA_ARGS__), int> = 0 -#endif - -// This is defined in core.h instead of format.h to avoid injecting in std. -// It is a template to avoid undesirable implicit conversions to std::byte. -#ifdef __cpp_lib_byte -template ::value)> -inline auto format_as(T b) -> unsigned char { - return static_cast(b); -} -#endif - -namespace detail { -// Suppresses "unused variable" warnings with the method described in -// https://herbsutter.com/2009/10/18/mailbag-shutting-up-compiler-warnings/. -// (void)var does not work on many Intel compilers. -template -FMT_CONSTEXPR void ignore_unused(const T &...) {} - -constexpr FMT_INLINE auto is_constant_evaluated( - bool default_value = false) noexcept -> bool { -// Workaround for incompatibility between libstdc++ consteval-based -// std::is_constant_evaluated() implementation and clang-14. -// https://github.com/fmtlib/fmt/issues/3247 -#if FMT_CPLUSPLUS >= 202002L && defined(_GLIBCXX_RELEASE) \ - && _GLIBCXX_RELEASE >= 12 \ - && (FMT_CLANG_VERSION >= 1400 && FMT_CLANG_VERSION < 1500) - ignore_unused(default_value); - return __builtin_is_constant_evaluated(); -#elif defined(__cpp_lib_is_constant_evaluated) - ignore_unused(default_value); - return std::is_constant_evaluated(); -#else - return default_value; -#endif -} - -// Suppresses "conditional expression is constant" warnings. -template -constexpr FMT_INLINE auto const_check(T value) -> T { - return value; -} - -FMT_NORETURN FMT_API void assert_fail( - const char *file, int line, const char *message); - -#ifndef FMT_ASSERT -#ifdef NDEBUG -// FMT_ASSERT is not empty to avoid -Wempty-body. -#define FMT_ASSERT(condition, message) \ - fmt::detail::ignore_unused((condition), (message)) -#else -#define FMT_ASSERT(condition, message) \ - ((condition) /* void() fails with -Winvalid-constexpr on clang 4.0.1 */ \ - ? (void)0 \ - : fmt::detail::assert_fail(__FILE__, __LINE__, (message))) -#endif -#endif - -#if defined(FMT_USE_STRING_VIEW) -template -using std_string_view = std::basic_string_view; -#elif defined(FMT_USE_EXPERIMENTAL_STRING_VIEW) -template -using std_string_view = std::experimental::basic_string_view; -#else -template -struct std_string_view {}; -#endif - -#ifdef FMT_USE_INT128 -// Do nothing. -#elif defined(__SIZEOF_INT128__) && !defined(__NVCC__) \ - && !(FMT_CLANG_VERSION && FMT_MSC_VERSION) -#define FMT_USE_INT128 1 -using int128_opt = __int128_t; // An optional native 128-bit integer. 
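The FMT_ENABLE_IF helper defined above hides enable_if behind a defaulted non-type template parameter, which keeps mangled symbol names short. A standalone sketch of how that idiom selects between overloads (macro and function names are illustrative):

```cpp
#include <iostream>
#include <type_traits>

// A defaulted non-type template parameter that only exists when the
// condition holds; substitution failure removes the overload.
#define MY_ENABLE_IF(...) \
    typename std::enable_if<(__VA_ARGS__), int>::type = 0

template <typename T, MY_ENABLE_IF(std::is_integral<T>::value)>
const char *describe(T) {
    return "integral";
}

template <typename T, MY_ENABLE_IF(std::is_floating_point<T>::value)>
const char *describe(T) {
    return "floating-point";
}

int main() {
    std::cout << describe(1) << '\n'; // integral
    std::cout << describe(1.5) << '\n'; // floating-point
}
```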
-using uint128_opt = __uint128_t; -template -inline auto convert_for_visit(T value) -> T { - return value; -} -#else -#define FMT_USE_INT128 0 -#endif -#if !FMT_USE_INT128 -enum class int128_opt {}; -enum class uint128_opt {}; -// Reduce template instantiations. -template -auto convert_for_visit(T) -> monostate { - return {}; -} -#endif - -// Casts a nonnegative integer to unsigned. -template -FMT_CONSTEXPR auto to_unsigned(Int value) -> - typename std::make_unsigned::type { - FMT_ASSERT(std::is_unsigned::value || value >= 0, "negative value"); - return static_cast::type>(value); -} - -FMT_CONSTEXPR inline auto is_utf8() -> bool { - FMT_MSC_WARNING(suppress : 4566) - constexpr unsigned char section[] = "\u00A7"; - - // Avoid buggy sign extensions in MSVC's constant evaluation mode (#2297). - using uchar = unsigned char; - return FMT_UNICODE - || (sizeof(section) == 3 && uchar(section[0]) == 0xC2 - && uchar(section[1]) == 0xA7); -} -} // namespace detail - -/** - An implementation of ``std::basic_string_view`` for pre-C++17. It provides a - subset of the API. ``fmt::basic_string_view`` is used for format strings even - if ``std::string_view`` is available to prevent issues when a library is - compiled with a different ``-std`` option than the client code (which is not - recommended). - */ -FMT_EXPORT -template -class basic_string_view { -private: - const Char *data_; - size_t size_; - -public: - using value_type = Char; - using iterator = const Char *; - - constexpr basic_string_view() noexcept : data_(nullptr), size_(0) {} - - /** Constructs a string reference object from a C string and a size. */ - constexpr basic_string_view(const Char *s, size_t count) noexcept - : data_(s), size_(count) {} - - /** - \rst - Constructs a string reference object from a C string computing - the size with ``std::char_traits::length``. - \endrst - */ - FMT_CONSTEXPR_CHAR_TRAITS - FMT_INLINE - basic_string_view(const Char *s) - : data_(s) - , size_(detail::const_check(std::is_same::value - && !detail::is_constant_evaluated(true)) - ? std::strlen(reinterpret_cast(s)) - : std::char_traits::length(s)) {} - - /** Constructs a string reference from a ``std::basic_string`` object. */ - template - FMT_CONSTEXPR basic_string_view( - const std::basic_string &s) noexcept - : data_(s.data()), size_(s.size()) {} - - template >::value)> - FMT_CONSTEXPR basic_string_view(S s) noexcept - : data_(s.data()), size_(s.size()) {} - - /** Returns a pointer to the string data. */ - constexpr auto data() const noexcept -> const Char * { return data_; } - - /** Returns the string size. */ - constexpr auto size() const noexcept -> size_t { return size_; } - - constexpr auto begin() const noexcept -> iterator { return data_; } - constexpr auto end() const noexcept -> iterator { return data_ + size_; } - - constexpr auto operator[](size_t pos) const noexcept -> const Char & { - return data_[pos]; - } - - FMT_CONSTEXPR void remove_prefix(size_t n) noexcept { - data_ += n; - size_ -= n; - } - - FMT_CONSTEXPR_CHAR_TRAITS auto starts_with( - basic_string_view sv) const noexcept -> bool { - return size_ >= sv.size_ - && std::char_traits::compare(data_, sv.data_, sv.size_) - == 0; - } - FMT_CONSTEXPR_CHAR_TRAITS auto starts_with(Char c) const noexcept -> bool { - return size_ >= 1 && std::char_traits::eq(*data_, c); - } - FMT_CONSTEXPR_CHAR_TRAITS auto starts_with(const Char *s) const -> bool { - return starts_with(basic_string_view(s)); - } - - // Lexicographically compare this string reference to other. 
- FMT_CONSTEXPR_CHAR_TRAITS auto compare(basic_string_view other) const - -> int { - size_t str_size = size_ < other.size_ ? size_ : other.size_; - int result - = std::char_traits::compare(data_, other.data_, str_size); - if (result == 0) - result = size_ == other.size_ ? 0 : (size_ < other.size_ ? -1 : 1); - return result; - } - - FMT_CONSTEXPR_CHAR_TRAITS friend auto operator==( - basic_string_view lhs, basic_string_view rhs) -> bool { - return lhs.compare(rhs) == 0; - } - friend auto operator!=(basic_string_view lhs, basic_string_view rhs) - -> bool { - return lhs.compare(rhs) != 0; - } - friend auto operator<(basic_string_view lhs, basic_string_view rhs) - -> bool { - return lhs.compare(rhs) < 0; - } - friend auto operator<=(basic_string_view lhs, basic_string_view rhs) - -> bool { - return lhs.compare(rhs) <= 0; - } - friend auto operator>(basic_string_view lhs, basic_string_view rhs) - -> bool { - return lhs.compare(rhs) > 0; - } - friend auto operator>=(basic_string_view lhs, basic_string_view rhs) - -> bool { - return lhs.compare(rhs) >= 0; - } -}; - -FMT_EXPORT -using string_view = basic_string_view; - -/** Specifies if ``T`` is a character type. Can be specialized by users. */ -FMT_EXPORT -template -struct is_char : std::false_type {}; -template <> -struct is_char : std::true_type {}; - -namespace detail { - -// A base class for compile-time strings. -struct compile_string {}; - -template -struct is_compile_string : std::is_base_of {}; - -template ::value)> -FMT_INLINE auto to_string_view(const Char *s) -> basic_string_view { - return s; -} -template -inline auto to_string_view(const std::basic_string &s) - -> basic_string_view { - return s; -} -template -constexpr auto to_string_view(basic_string_view s) - -> basic_string_view { - return s; -} -template >::value)> -inline auto to_string_view(std_string_view s) -> basic_string_view { - return s; -} -template ::value)> -constexpr auto to_string_view(const S &s) - -> basic_string_view { - return basic_string_view(s); -} -void to_string_view(...); - -// Specifies whether S is a string type convertible to fmt::basic_string_view. -// It should be a constexpr function but MSVC 2017 fails to compile it in -// enable_if and MSVC 2015 fails to compile it as an alias template. -// Arg Dep Lookup is intentionally disabled as to_string_view is not an -// extension point. -template -struct is_string - : std::is_class()))> {}; - -template -struct char_t_impl {}; -template -struct char_t_impl::value>> { - using result = decltype(to_string_view(std::declval())); - using type = typename result::value_type; -}; - -enum class type { - none_type, - // Integer types should go first, - int_type, - uint_type, - long_long_type, - ulong_long_type, - int128_type, - uint128_type, - bool_type, - char_type, - last_integer_type = char_type, - // followed by floating-point types. - float_type, - double_type, - long_double_type, - last_numeric_type = long_double_type, - cstring_type, - string_type, - pointer_type, - custom_type -}; - -// Maps core type T to the corresponding type enum constant. 
-template -struct type_constant : std::integral_constant {}; - -#define FMT_TYPE_CONSTANT(Type, constant) \ - template \ - struct type_constant \ - : std::integral_constant {} - -FMT_TYPE_CONSTANT(int, int_type); -FMT_TYPE_CONSTANT(unsigned, uint_type); -FMT_TYPE_CONSTANT(long long, long_long_type); -FMT_TYPE_CONSTANT(unsigned long long, ulong_long_type); -FMT_TYPE_CONSTANT(int128_opt, int128_type); -FMT_TYPE_CONSTANT(uint128_opt, uint128_type); -FMT_TYPE_CONSTANT(bool, bool_type); -FMT_TYPE_CONSTANT(Char, char_type); -FMT_TYPE_CONSTANT(float, float_type); -FMT_TYPE_CONSTANT(double, double_type); -FMT_TYPE_CONSTANT(long double, long_double_type); -FMT_TYPE_CONSTANT(const Char *, cstring_type); -FMT_TYPE_CONSTANT(basic_string_view, string_type); -FMT_TYPE_CONSTANT(const void *, pointer_type); - -constexpr auto is_integral_type(type t) -> bool { - return t > type::none_type && t <= type::last_integer_type; -} -constexpr auto is_arithmetic_type(type t) -> bool { - return t > type::none_type && t <= type::last_numeric_type; -} - -constexpr auto set(type rhs) -> int { - return 1 << static_cast(rhs); -} -constexpr auto in(type t, int set) -> bool { - return ((set >> static_cast(t)) & 1) != 0; -} - -// Bitsets of types. -enum { - sint_set - = set(type::int_type) | set(type::long_long_type) | set(type::int128_type), - uint_set = set(type::uint_type) | set(type::ulong_long_type) - | set(type::uint128_type), - bool_set = set(type::bool_type), - char_set = set(type::char_type), - float_set = set(type::float_type) | set(type::double_type) - | set(type::long_double_type), - string_set = set(type::string_type), - cstring_set = set(type::cstring_type), - pointer_set = set(type::pointer_type) -}; - -// DEPRECATED! -FMT_NORETURN FMT_API void throw_format_error(const char *message); - -struct error_handler { - constexpr error_handler() = default; - - // This function is intentionally not constexpr to give a compile-time error. - FMT_NORETURN void on_error(const char *message) { - throw_format_error(message); - } -}; -} // namespace detail - -/** Throws ``format_error`` with a given message. */ -using detail::throw_format_error; - -/** String's character type. */ -template -using char_t = typename detail::char_t_impl::type; - -/** - \rst - Parsing context consisting of a format string range being parsed and an - argument counter for automatic indexing. - You can use the ``format_parse_context`` type alias for ``char`` instead. - \endrst - */ -FMT_EXPORT -template -class basic_format_parse_context { -private: - basic_string_view format_str_; - int next_arg_id_; - - FMT_CONSTEXPR void do_check_arg_id(int id); - -public: - using char_type = Char; - using iterator = const Char *; - - explicit constexpr basic_format_parse_context( - basic_string_view format_str, int next_arg_id = 0) - : format_str_(format_str), next_arg_id_(next_arg_id) {} - - /** - Returns an iterator to the beginning of the format string range being - parsed. - */ - constexpr auto begin() const noexcept -> iterator { - return format_str_.begin(); - } - - /** - Returns an iterator past the end of the format string range being parsed. - */ - constexpr auto end() const noexcept -> iterator { - return format_str_.end(); - } - - /** Advances the begin iterator to ``it``. */ - FMT_CONSTEXPR void advance_to(iterator it) { - format_str_.remove_prefix(detail::to_unsigned(it - begin())); - } - - /** - Reports an error if using the manual argument indexing; otherwise returns - the next argument index and switches to the automatic indexing. 
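The set()/in() helpers and the "bitsets of types" enum deleted above encode membership of small enum values as bits of an int, tested with a shift and a mask. A minimal standalone sketch of the same trick with an illustrative enum:

```cpp
#include <iostream>

enum class kind { none, integer, floating, string };

// One bit per enumerator.
constexpr int set(kind k) {
    return 1 << static_cast<int>(k);
}
constexpr bool in(kind k, int bitset) {
    return ((bitset >> static_cast<int>(k)) & 1) != 0;
}

// Membership sets are just ORed bits, computable at compile time.
constexpr int numeric_set = set(kind::integer) | set(kind::floating);

int main() {
    std::cout << std::boolalpha
              << in(kind::integer, numeric_set) << '\n' // true
              << in(kind::string, numeric_set) << '\n'; // false
}
```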
- */ - FMT_CONSTEXPR auto next_arg_id() -> int { - if (next_arg_id_ < 0) { - detail::throw_format_error( - "cannot switch from manual to automatic argument indexing"); - return 0; - } - int id = next_arg_id_++; - do_check_arg_id(id); - return id; - } - - /** - Reports an error if using the automatic argument indexing; otherwise - switches to the manual indexing. - */ - FMT_CONSTEXPR void check_arg_id(int id) { - if (next_arg_id_ > 0) { - detail::throw_format_error( - "cannot switch from automatic to manual argument indexing"); - return; - } - next_arg_id_ = -1; - do_check_arg_id(id); - } - FMT_CONSTEXPR void check_arg_id(basic_string_view) {} - FMT_CONSTEXPR void check_dynamic_spec(int arg_id); -}; - -FMT_EXPORT -using format_parse_context = basic_format_parse_context; - -namespace detail { -// A parse context with extra data used only in compile-time checks. -template -class compile_parse_context : public basic_format_parse_context { -private: - int num_args_; - const type *types_; - using base = basic_format_parse_context; - -public: - explicit FMT_CONSTEXPR compile_parse_context( - basic_string_view format_str, int num_args, const type *types, - int next_arg_id = 0) - : base(format_str, next_arg_id), num_args_(num_args), types_(types) {} - - constexpr auto num_args() const -> int { return num_args_; } - constexpr auto arg_type(int id) const -> type { return types_[id]; } - - FMT_CONSTEXPR auto next_arg_id() -> int { - int id = base::next_arg_id(); - if (id >= num_args_) throw_format_error("argument not found"); - return id; - } - - FMT_CONSTEXPR void check_arg_id(int id) { - base::check_arg_id(id); - if (id >= num_args_) throw_format_error("argument not found"); - } - using base::check_arg_id; - - FMT_CONSTEXPR void check_dynamic_spec(int arg_id) { - detail::ignore_unused(arg_id); -#if !defined(__LCC__) - if (arg_id < num_args_ && types_ && !is_integral_type(types_[arg_id])) - throw_format_error("width/precision is not integer"); -#endif - } -}; - -// Extracts a reference to the container from back_insert_iterator. -template -inline auto get_container(std::back_insert_iterator it) - -> Container & { - using base = std::back_insert_iterator; - struct accessor : base { - accessor(base b) : base(b) {} - using base::container; - }; - return *accessor(it).container; -} - -template -FMT_CONSTEXPR auto copy_str(InputIt begin, InputIt end, OutputIt out) - -> OutputIt { - while (begin != end) - *out++ = static_cast(*begin++); - return out; -} - -template , U>::value &&is_char::value)> -FMT_CONSTEXPR auto copy_str(T *begin, T *end, U *out) -> U * { - if (is_constant_evaluated()) - return copy_str(begin, end, out); - auto size = to_unsigned(end - begin); - if (size > 0) memcpy(out, begin, size * sizeof(U)); - return out + size; -} - -/** - \rst - A contiguous memory buffer with an optional growing ability. It is an internal - class and shouldn't be used directly, only via `~fmt::basic_memory_buffer`. - \endrst - */ -template -class buffer { -private: - T *ptr_; - size_t size_; - size_t capacity_; - -protected: - // Don't initialize ptr_ since it is not accessed to save a few cycles. - FMT_MSC_WARNING(suppress : 26495) - FMT_CONSTEXPR buffer(size_t sz) noexcept : size_(sz), capacity_(sz) {} - - FMT_CONSTEXPR20 buffer( - T *p = nullptr, size_t sz = 0, size_t cap = 0) noexcept - : ptr_(p), size_(sz), capacity_(cap) {} - - FMT_CONSTEXPR20 ~buffer() = default; - buffer(buffer &&) = default; - - /** Sets the buffer data and capacity. 
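The next_arg_id()/check_arg_id() pair above enforces that a format string uses either automatic ("{}") or manual ("{0}") argument indexing, never both. A simplified standalone model of that state machine (it throws std::logic_error where fmt reports a format_error; names are illustrative):

```cpp
#include <iostream>
#include <stdexcept>

// next_arg_id_ > 0 means automatic mode is engaged, < 0 means manual mode.
class parse_context {
public:
    int next_arg_id() { // called for "{}"
        if (next_arg_id_ < 0)
            throw std::logic_error(
                    "cannot switch from manual to automatic argument indexing");
        return next_arg_id_++;
    }
    void check_arg_id(int) { // called for "{0}", "{1}", ...
        if (next_arg_id_ > 0)
            throw std::logic_error(
                    "cannot switch from automatic to manual argument indexing");
        next_arg_id_ = -1;
    }

private:
    int next_arg_id_ = 0;
};

int main() {
    parse_context ctx;
    ctx.next_arg_id(); // "{}" -> id 0, automatic mode engaged
    try {
        ctx.check_arg_id(0); // a later "{0}" is an error
    } catch (const std::logic_error &e) {
        std::cout << e.what() << '\n';
    }
}
```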
*/ - FMT_CONSTEXPR void set(T *buf_data, size_t buf_capacity) noexcept { - ptr_ = buf_data; - capacity_ = buf_capacity; - } - - /** Increases the buffer capacity to hold at least *capacity* elements. */ - // DEPRECATED! - virtual FMT_CONSTEXPR20 void grow(size_t capacity) = 0; - -public: - using value_type = T; - using const_reference = const T &; - - buffer(const buffer &) = delete; - void operator=(const buffer &) = delete; - - FMT_INLINE auto begin() noexcept -> T * { return ptr_; } - FMT_INLINE auto end() noexcept -> T * { return ptr_ + size_; } - - FMT_INLINE auto begin() const noexcept -> const T * { return ptr_; } - FMT_INLINE auto end() const noexcept -> const T * { return ptr_ + size_; } - - /** Returns the size of this buffer. */ - constexpr auto size() const noexcept -> size_t { return size_; } - - /** Returns the capacity of this buffer. */ - constexpr auto capacity() const noexcept -> size_t { return capacity_; } - - /** Returns a pointer to the buffer data (not null-terminated). */ - FMT_CONSTEXPR auto data() noexcept -> T * { return ptr_; } - FMT_CONSTEXPR auto data() const noexcept -> const T * { return ptr_; } - - /** Clears this buffer. */ - void clear() { size_ = 0; } - - // Tries resizing the buffer to contain *count* elements. If T is a POD type - // the new elements may not be initialized. - FMT_CONSTEXPR20 void try_resize(size_t count) { - try_reserve(count); - size_ = count <= capacity_ ? count : capacity_; - } - - // Tries increasing the buffer capacity to *new_capacity*. It can increase the - // capacity by a smaller amount than requested but guarantees there is space - // for at least one additional element either by increasing the capacity or by - // flushing the buffer if it is full. - FMT_CONSTEXPR20 void try_reserve(size_t new_capacity) { - if (new_capacity > capacity_) grow(new_capacity); - } - - FMT_CONSTEXPR20 void push_back(const T &value) { - try_reserve(size_ + 1); - ptr_[size_++] = value; - } - - /** Appends data to the end of the buffer. */ - template - void append(const U *begin, const U *end); - - template - FMT_CONSTEXPR auto operator[](Idx index) -> T & { - return ptr_[index]; - } - template - FMT_CONSTEXPR auto operator[](Idx index) const -> const T & { - return ptr_[index]; - } -}; - -struct buffer_traits { - explicit buffer_traits(size_t) {} - auto count() const -> size_t { return 0; } - auto limit(size_t size) -> size_t { return size; } -}; - -class fixed_buffer_traits { -private: - size_t count_ = 0; - size_t limit_; - -public: - explicit fixed_buffer_traits(size_t limit) : limit_(limit) {} - auto count() const -> size_t { return count_; } - auto limit(size_t size) -> size_t { - size_t n = limit_ > count_ ? limit_ - count_ : 0; - count_ += size; - return size < n ? size : n; - } -}; - -// A buffer that writes to an output iterator when flushed. 
-template -class iterator_buffer final : public Traits, public buffer { -private: - OutputIt out_; - enum { buffer_size = 256 }; - T data_[buffer_size]; - -protected: - FMT_CONSTEXPR20 void grow(size_t) override { - if (this->size() == buffer_size) flush(); - } - - void flush() { - auto size = this->size(); - this->clear(); - out_ = copy_str(data_, data_ + this->limit(size), out_); - } - -public: - explicit iterator_buffer(OutputIt out, size_t n = buffer_size) - : Traits(n), buffer(data_, 0, buffer_size), out_(out) {} - iterator_buffer(iterator_buffer &&other) - : Traits(other), buffer(data_, 0, buffer_size), out_(other.out_) {} - ~iterator_buffer() { flush(); } - - auto out() -> OutputIt { - flush(); - return out_; - } - auto count() const -> size_t { return Traits::count() + this->size(); } -}; - -template -class iterator_buffer final - : public fixed_buffer_traits, - public buffer { -private: - T *out_; - enum { buffer_size = 256 }; - T data_[buffer_size]; - -protected: - FMT_CONSTEXPR20 void grow(size_t) override { - if (this->size() == this->capacity()) flush(); - } - - void flush() { - size_t n = this->limit(this->size()); - if (this->data() == out_) { - out_ += n; - this->set(data_, buffer_size); - } - this->clear(); - } - -public: - explicit iterator_buffer(T *out, size_t n = buffer_size) - : fixed_buffer_traits(n), buffer(out, 0, n), out_(out) {} - iterator_buffer(iterator_buffer &&other) - : fixed_buffer_traits(other) - , buffer(std::move(other)) - , out_(other.out_) { - if (this->data() != out_) { - this->set(data_, buffer_size); - this->clear(); - } - } - ~iterator_buffer() { flush(); } - - auto out() -> T * { - flush(); - return out_; - } - auto count() const -> size_t { - return fixed_buffer_traits::count() + this->size(); - } -}; - -template -class iterator_buffer final : public buffer { -protected: - FMT_CONSTEXPR20 void grow(size_t) override {} - -public: - explicit iterator_buffer(T *out, size_t = 0) - : buffer(out, 0, ~size_t()) {} - - auto out() -> T * { return &*this->end(); } -}; - -// A buffer that writes to a container with the contiguous storage. -template -class iterator_buffer, - enable_if_t::value, - typename Container::value_type>> - final : public buffer { -private: - Container &container_; - -protected: - FMT_CONSTEXPR20 void grow(size_t capacity) override { - container_.resize(capacity); - this->set(&container_[0], capacity); - } - -public: - explicit iterator_buffer(Container &c) - : buffer(c.size()), container_(c) {} - explicit iterator_buffer( - std::back_insert_iterator out, size_t = 0) - : iterator_buffer(get_container(out)) {} - - auto out() -> std::back_insert_iterator { - return std::back_inserter(container_); - } -}; - -// A buffer that counts the number of code units written discarding the output. -template -class counting_buffer final : public buffer { -private: - enum { buffer_size = 256 }; - T data_[buffer_size]; - size_t count_ = 0; - -protected: - FMT_CONSTEXPR20 void grow(size_t) override { - if (this->size() != buffer_size) return; - count_ += this->size(); - this->clear(); - } - -public: - counting_buffer() : buffer(data_, 0, buffer_size) {} - - auto count() -> size_t { return count_ + this->size(); } -}; -} // namespace detail - -template -FMT_CONSTEXPR void basic_format_parse_context::do_check_arg_id(int id) { - // Argument id is only checked at compile-time during parsing because - // formatting has its own validation. 
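The container specialization of `iterator_buffer` above is why writing through `std::back_inserter` into contiguous storage avoids the 256-element staging array: the buffer grows the container and formats in place. Illustrative usage:

```cpp
#include <fmt/core.h>
#include <iterator>
#include <string>

int main() {
  std::string s;
  // Routed to the contiguous-container iterator_buffer: the string is
  // resized and written into directly, with no intermediate copy.
  fmt::format_to(std::back_inserter(s), "{}-{}", 1, 2);
  // s == "1-2"
}
```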
- if (detail::is_constant_evaluated() - && (!FMT_GCC_VERSION || FMT_GCC_VERSION >= 1200)) { - using context = detail::compile_parse_context; - if (id >= static_cast(this)->num_args()) - detail::throw_format_error("argument not found"); - } -} - -template -FMT_CONSTEXPR void basic_format_parse_context::check_dynamic_spec( - int arg_id) { - if (detail::is_constant_evaluated() - && (!FMT_GCC_VERSION || FMT_GCC_VERSION >= 1200)) { - using context = detail::compile_parse_context; - static_cast(this)->check_dynamic_spec(arg_id); - } -} - -FMT_EXPORT template -class basic_format_arg; -FMT_EXPORT template -class basic_format_args; -FMT_EXPORT template -class dynamic_format_arg_store; - -// A formatter for objects of type T. -FMT_EXPORT -template -struct formatter { - // A deleted default constructor indicates a disabled formatter. - formatter() = delete; -}; - -// Specifies if T has an enabled formatter specialization. A type can be -// formattable even if it doesn't have a formatter e.g. via a conversion. -template -using has_formatter - = std::is_constructible>; - -// An output iterator that appends to a buffer. -// It is used to reduce symbol sizes for the common case. -class appender : public std::back_insert_iterator> { - using base = std::back_insert_iterator>; - -public: - using std::back_insert_iterator>::back_insert_iterator; - appender(base it) noexcept : base(it) {} - FMT_UNCHECKED_ITERATOR(appender); - - auto operator++() noexcept -> appender & { return *this; } - auto operator++(int) noexcept -> appender { return *this; } -}; - -namespace detail { - -template -constexpr auto has_const_formatter_impl(T *) -> decltype( - typename Context::template formatter_type().format( - std::declval(), std::declval()), - true) { - return true; -} -template -constexpr auto has_const_formatter_impl(...) -> bool { - return false; -} -template -constexpr auto has_const_formatter() -> bool { - return has_const_formatter_impl(static_cast(nullptr)); -} - -template -using buffer_appender = conditional_t::value, appender, - std::back_insert_iterator>>; - -// Maps an output iterator to a buffer. -template -auto get_buffer(OutputIt out) -> iterator_buffer { - return iterator_buffer(out); -} -template , Buf>::value)> -auto get_buffer(std::back_insert_iterator out) -> buffer & { - return get_container(out); -} - -template -FMT_INLINE auto get_iterator(Buf &buf, OutputIt) -> decltype(buf.out()) { - return buf.out(); -} -template -auto get_iterator(buffer &, OutputIt out) -> OutputIt { - return out; -} - -struct view {}; - -template -struct named_arg : view { - const Char *name; - const T &value; - named_arg(const Char *n, const T &v) : name(n), value(v) {} -}; - -template -struct named_arg_info { - const Char *name; - int id; -}; - -template -struct arg_data { - // args_[0].named_args points to named_args_ to avoid bloating format_args. - // +1 to workaround a bug in gcc 7.5 that causes duplicated-branches warning. - T args_[1 + (NUM_ARGS != 0 ? NUM_ARGS : +1)]; - named_arg_info named_args_[NUM_NAMED_ARGS]; - - template - arg_data(const U &...init) - : args_ {T(named_args_, NUM_NAMED_ARGS), init...} {} - arg_data(const arg_data &other) = delete; - auto args() const -> const T * { return args_ + 1; } - auto named_args() -> named_arg_info * { return named_args_; } -}; - -template -struct arg_data { - // +1 to workaround a bug in gcc 7.5 that causes duplicated-branches warning. - T args_[NUM_ARGS != 0 ? 
NUM_ARGS : +1]; - - template - FMT_CONSTEXPR FMT_INLINE arg_data(const U &...init) : args_ {init...} {} - FMT_CONSTEXPR FMT_INLINE auto args() const -> const T * { return args_; } - FMT_CONSTEXPR FMT_INLINE auto named_args() -> std::nullptr_t { - return nullptr; - } -}; - -template -inline void init_named_args(named_arg_info *, int, int) {} - -template -struct is_named_arg : std::false_type {}; -template -struct is_statically_named_arg : std::false_type {}; - -template -struct is_named_arg> : std::true_type {}; - -template ::value)> -void init_named_args(named_arg_info *named_args, int arg_count, - int named_arg_count, const T &, const Tail &...args) { - init_named_args(named_args, arg_count + 1, named_arg_count, args...); -} - -template ::value)> -void init_named_args(named_arg_info *named_args, int arg_count, - int named_arg_count, const T &arg, const Tail &...args) { - named_args[named_arg_count++] = {arg.name, arg_count}; - init_named_args(named_args, arg_count + 1, named_arg_count, args...); -} - -template -FMT_CONSTEXPR FMT_INLINE void init_named_args( - std::nullptr_t, int, int, const Args &...) {} - -template -constexpr auto count() -> size_t { - return B ? 1 : 0; -} -template -constexpr auto count() -> size_t { - return (B1 ? 1 : 0) + count(); -} - -template -constexpr auto count_named_args() -> size_t { - return count::value...>(); -} - -template -constexpr auto count_statically_named_args() -> size_t { - return count::value...>(); -} - -struct unformattable {}; -struct unformattable_char : unformattable {}; -struct unformattable_pointer : unformattable {}; - -template -struct string_value { - const Char *data; - size_t size; -}; - -template -struct named_arg_value { - const named_arg_info *data; - size_t size; -}; - -template -struct custom_value { - using parse_context = typename Context::parse_context_type; - void *value; - void (*format)(void *arg, parse_context &parse_ctx, Context &ctx); -}; - -// A formatting argument value. 
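The `init_named_args`/`count_named_args` machinery above records each named argument's name and position so `{name}` fields can be resolved by name at format time. Usage through the public API:

```cpp
#include <fmt/core.h>

int main() {
  // Each fmt::arg(...) becomes a named_arg that init_named_args indexes.
  fmt::print("{name} is {age}\n",
             fmt::arg("name", "Ada"), fmt::arg("age", 36));
  // prints "Ada is 36"
}
```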
-template -class value { -public: - using char_type = typename Context::char_type; - - union { - monostate no_value; - int int_value; - unsigned uint_value; - long long long_long_value; - unsigned long long ulong_long_value; - int128_opt int128_value; - uint128_opt uint128_value; - bool bool_value; - char_type char_value; - float float_value; - double double_value; - long double long_double_value; - const void *pointer; - string_value string; - custom_value custom; - named_arg_value named_args; - }; - - constexpr FMT_INLINE value() : no_value() {} - constexpr FMT_INLINE value(int val) : int_value(val) {} - constexpr FMT_INLINE value(unsigned val) : uint_value(val) {} - constexpr FMT_INLINE value(long long val) : long_long_value(val) {} - constexpr FMT_INLINE value(unsigned long long val) - : ulong_long_value(val) {} - FMT_INLINE value(int128_opt val) : int128_value(val) {} - FMT_INLINE value(uint128_opt val) : uint128_value(val) {} - constexpr FMT_INLINE value(float val) : float_value(val) {} - constexpr FMT_INLINE value(double val) : double_value(val) {} - FMT_INLINE value(long double val) : long_double_value(val) {} - constexpr FMT_INLINE value(bool val) : bool_value(val) {} - constexpr FMT_INLINE value(char_type val) : char_value(val) {} - FMT_CONSTEXPR FMT_INLINE value(const char_type *val) { - string.data = val; - if (is_constant_evaluated()) string.size = {}; - } - FMT_CONSTEXPR FMT_INLINE value(basic_string_view val) { - string.data = val.data(); - string.size = val.size(); - } - FMT_INLINE value(const void *val) : pointer(val) {} - FMT_INLINE value(const named_arg_info *args, size_t size) - : named_args {args, size} {} - - template - FMT_CONSTEXPR20 FMT_INLINE value(T &val) { - using value_type = remove_const_t; - custom.value = const_cast(std::addressof(val)); - // Get the formatter type through the context to allow different contexts - // have different extension points, e.g. `formatter` for `format` and - // `printf_formatter` for `printf`. - custom.format = format_custom_arg>; - } - value(unformattable); - value(unformattable_char); - value(unformattable_pointer); - -private: - // Formats an argument of a custom type, such as a user-defined class. - template - static void format_custom_arg(void *arg, - typename Context::parse_context_type &parse_ctx, Context &ctx) { - auto f = Formatter(); - parse_ctx.advance_to(f.parse(parse_ctx)); - using qualified_type - = conditional_t(), const T, T>; - // Calling format through a mutable reference is deprecated. - ctx.advance_to(f.format(*static_cast(arg), ctx)); - } -}; - -// To minimize the number of types we need to deal with, long is translated -// either to int or to long long depending on its size. -enum { long_short = sizeof(long) == sizeof(int) }; -using long_type = conditional_t; -using ulong_type = conditional_t; - -template -struct format_as_result { - template ::value || std::is_class::value)> - static auto map(U *) - -> remove_cvref_t()))>; - static auto map(...) -> void; - - using type = decltype(map(static_cast(nullptr))); -}; -template -using format_as_t = typename format_as_result::type; - -template -struct has_format_as - : bool_constant, void>::value> {}; - -// Maps formatting arguments to core types. -// arg_mapper reports errors by returning unformattable instead of using -// static_assert because it's used in the is_formattable trait. 
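Two user-facing extension points meet in the code above: `format_as` (consumed by the `format_as_result` mapping) and `formatter` specializations (invoked through `value::format_custom_arg`, which calls `parse` and then `format`). A minimal sketch of both:

```cpp
#include <fmt/core.h>

// format_as: maps a type to an already-formattable one, found via ADL.
enum class level { info, warn };
auto format_as(level l) -> int { return static_cast<int>(l); }

// formatter specialization: the custom_type path above.
struct point { double x, y; };
template <> struct fmt::formatter<point> {
  // Accepts only empty specs in this sketch.
  constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); }
  auto format(const point& p, format_context& ctx) const {
    return fmt::format_to(ctx.out(), "({}, {})", p.x, p.y);
  }
};

int main() {
  fmt::print("{} {}\n", level::warn, point{1, 2});  // "1 (1, 2)"
}
```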
-template -struct arg_mapper { - using char_type = typename Context::char_type; - - FMT_CONSTEXPR FMT_INLINE auto map(signed char val) -> int { return val; } - FMT_CONSTEXPR FMT_INLINE auto map(unsigned char val) -> unsigned { - return val; - } - FMT_CONSTEXPR FMT_INLINE auto map(short val) -> int { return val; } - FMT_CONSTEXPR FMT_INLINE auto map(unsigned short val) -> unsigned { - return val; - } - FMT_CONSTEXPR FMT_INLINE auto map(int val) -> int { return val; } - FMT_CONSTEXPR FMT_INLINE auto map(unsigned val) -> unsigned { return val; } - FMT_CONSTEXPR FMT_INLINE auto map(long val) -> long_type { return val; } - FMT_CONSTEXPR FMT_INLINE auto map(unsigned long val) -> ulong_type { - return val; - } - FMT_CONSTEXPR FMT_INLINE auto map(long long val) -> long long { - return val; - } - FMT_CONSTEXPR FMT_INLINE auto map(unsigned long long val) - -> unsigned long long { - return val; - } - FMT_CONSTEXPR FMT_INLINE auto map(int128_opt val) -> int128_opt { - return val; - } - FMT_CONSTEXPR FMT_INLINE auto map(uint128_opt val) -> uint128_opt { - return val; - } - FMT_CONSTEXPR FMT_INLINE auto map(bool val) -> bool { return val; } - - template ::value - || std::is_same::value)> - FMT_CONSTEXPR FMT_INLINE auto map(T val) -> char_type { - return val; - } - template ::value || -#ifdef __cpp_char8_t - std::is_same::value || -#endif - std::is_same::value - || std::is_same::value) - && !std::is_same::value, - int> = 0> - FMT_CONSTEXPR FMT_INLINE auto map(T) -> unformattable_char { - return {}; - } - - FMT_CONSTEXPR FMT_INLINE auto map(float val) -> float { return val; } - FMT_CONSTEXPR FMT_INLINE auto map(double val) -> double { return val; } - FMT_CONSTEXPR FMT_INLINE auto map(long double val) -> long double { - return val; - } - - FMT_CONSTEXPR FMT_INLINE auto map(char_type *val) -> const char_type * { - return val; - } - FMT_CONSTEXPR FMT_INLINE auto map(const char_type *val) - -> const char_type * { - return val; - } - template ::value && !std::is_pointer::value - && std::is_same>::value)> - FMT_CONSTEXPR FMT_INLINE auto map(const T &val) - -> basic_string_view { - return to_string_view(val); - } - template ::value && !std::is_pointer::value - && !std::is_same>::value)> - FMT_CONSTEXPR FMT_INLINE auto map(const T &) -> unformattable_char { - return {}; - } - - FMT_CONSTEXPR FMT_INLINE auto map(void *val) -> const void * { return val; } - FMT_CONSTEXPR FMT_INLINE auto map(const void *val) -> const void * { - return val; - } - FMT_CONSTEXPR FMT_INLINE auto map(std::nullptr_t val) -> const void * { - return val; - } - - // Use SFINAE instead of a const T* parameter to avoid a conflict with the - // array overload. - template ::value - || std::is_member_pointer::value - || std::is_function< - typename std::remove_pointer::type>::value - || (std::is_array::value - && !std::is_convertible::value))> - FMT_CONSTEXPR auto map(const T &) -> unformattable_pointer { - return {}; - } - - template ::value)> - FMT_CONSTEXPR FMT_INLINE auto map(const T (&values)[N]) -> const T (&)[N] { - return values; - } - - // Only map owning types because mapping views can be unsafe. 
- template , - FMT_ENABLE_IF(std::is_arithmetic::value)> - FMT_CONSTEXPR FMT_INLINE auto map(const T &val) - -> decltype(FMT_DECLTYPE_THIS map(U())) { - return map(format_as(val)); - } - - template > - struct formattable : bool_constant() - || (has_formatter::value - && !std::is_const::value)> {}; - - template ::value)> - FMT_CONSTEXPR FMT_INLINE auto do_map(T &val) -> T & { - return val; - } - template ::value)> - FMT_CONSTEXPR FMT_INLINE auto do_map(T &) -> unformattable { - return {}; - } - - template , - FMT_ENABLE_IF((std::is_class::value || std::is_enum::value - || std::is_union::value) - && !is_string::value && !is_char::value - && !is_named_arg::value - && !std::is_arithmetic>::value)> - FMT_CONSTEXPR FMT_INLINE auto map(T &val) - -> decltype(FMT_DECLTYPE_THIS do_map(val)) { - return do_map(val); - } - - template ::value)> - FMT_CONSTEXPR FMT_INLINE auto map(const T &named_arg) - -> decltype(FMT_DECLTYPE_THIS map(named_arg.value)) { - return map(named_arg.value); - } - - auto map(...) -> unformattable { return {}; } -}; - -// A type constant after applying arg_mapper. -template -using mapped_type_constant = type_constant().map( - std::declval())), - typename Context::char_type>; - -enum { packed_arg_bits = 4 }; -// Maximum number of arguments with packed types. -enum { max_packed_args = 62 / packed_arg_bits }; -enum : unsigned long long { is_unpacked_bit = 1ULL << 63 }; -enum : unsigned long long { has_named_args_bit = 1ULL << 62 }; - -template -auto copy_str(InputIt begin, InputIt end, appender out) -> appender { - get_container(out).append(begin, end); - return out; -} -template -auto copy_str( - InputIt begin, InputIt end, std::back_insert_iterator out) - -> std::back_insert_iterator { - get_container(out).append(begin, end); - return out; -} - -template -FMT_CONSTEXPR auto copy_str(R &&rng, OutputIt out) -> OutputIt { - return detail::copy_str(rng.begin(), rng.end(), out); -} - -#if FMT_GCC_VERSION && FMT_GCC_VERSION < 500 -// A workaround for gcc 4.8 to make void_t work in a SFINAE context. -template -struct void_t_impl { - using type = void; -}; -template -using void_t = typename void_t_impl::type; -#else -template -using void_t = void; -#endif - -template -struct is_output_iterator : std::false_type {}; - -template -struct is_output_iterator::iterator_category, - decltype(*std::declval() = std::declval())>> - : std::true_type {}; - -template -struct is_back_insert_iterator : std::false_type {}; -template -struct is_back_insert_iterator> - : std::true_type {}; - -// A type-erased reference to an std::locale to avoid a heavy include. -class locale_ref { -private: - const void *locale_; // A type-erased pointer to std::locale. 
- -public: - constexpr FMT_INLINE locale_ref() : locale_(nullptr) {} - template - explicit locale_ref(const Locale &loc); - - explicit operator bool() const noexcept { return locale_ != nullptr; } - - template - auto get() const -> Locale; -}; - -template -constexpr auto encode_types() -> unsigned long long { - return 0; -} - -template -constexpr auto encode_types() -> unsigned long long { - return static_cast(mapped_type_constant::value) - | (encode_types() << packed_arg_bits); -} - -#if defined(__cpp_if_constexpr) -// This type is intentionally undefined, only used for errors -template -struct type_is_unformattable_for; -#endif - -template -FMT_CONSTEXPR FMT_INLINE auto make_arg(T &val) -> value { - using arg_type = remove_cvref_t().map(val))>; - - constexpr bool formattable_char - = !std::is_same::value; - static_assert(formattable_char, "Mixing character types is disallowed."); - - // Formatting of arbitrary pointers is disallowed. If you want to format a - // pointer cast it to `void*` or `const void*`. In particular, this forbids - // formatting of `[const] volatile char*` printed as bool by iostreams. - constexpr bool formattable_pointer - = !std::is_same::value; - static_assert(formattable_pointer, - "Formatting of non-void pointers is disallowed."); - - constexpr bool formattable = !std::is_same::value; -#if defined(__cpp_if_constexpr) - if constexpr (!formattable) { - type_is_unformattable_for _; - } -#endif - static_assert(formattable, - "Cannot format an argument. To make type T formattable provide a " - "formatter specialization: https://fmt.dev/latest/api.html#udt"); - return {arg_mapper().map(val)}; -} - -template -FMT_CONSTEXPR auto make_arg(T &val) -> basic_format_arg { - auto arg = basic_format_arg(); - arg.type_ = mapped_type_constant::value; - arg.value_ = make_arg(val); - return arg; -} - -template -FMT_CONSTEXPR inline auto make_arg(T &val) -> basic_format_arg { - return make_arg(val); -} -} // namespace detail -FMT_BEGIN_EXPORT - -// A formatting argument. Context is a template parameter for the compiled API -// where output can be unbuffered. 
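The static_asserts in `make_arg` above are the source of two well-known fmt diagnostics: character types cannot be mixed, and only `void` pointers are formattable. The documented workaround for the pointer case:

```cpp
#include <fmt/format.h>  // for fmt::ptr

int main() {
  int x = 42;
  int* p = &x;
  // fmt::print("{}", p);  // rejected: "Formatting of non-void pointers ..."
  fmt::print("{}\n", static_cast<const void*>(p));  // OK: prints the address
  fmt::print("{}\n", fmt::ptr(p));                  // equivalent helper
}
```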
-template -class basic_format_arg { -private: - detail::value value_; - detail::type type_; - - template - friend FMT_CONSTEXPR auto detail::make_arg(T &value) - -> basic_format_arg; - - template - friend FMT_CONSTEXPR auto visit_format_arg(Visitor &&vis, - const basic_format_arg &arg) -> decltype(vis(0)); - - friend class basic_format_args; - friend class dynamic_format_arg_store; - - using char_type = typename Context::char_type; - - template - friend struct detail::arg_data; - - basic_format_arg(const detail::named_arg_info *args, size_t size) - : value_(args, size) {} - -public: - class handle { - public: - explicit handle(detail::custom_value custom) - : custom_(custom) {} - - void format(typename Context::parse_context_type &parse_ctx, - Context &ctx) const { - custom_.format(custom_.value, parse_ctx, ctx); - } - - private: - detail::custom_value custom_; - }; - - constexpr basic_format_arg() : type_(detail::type::none_type) {} - - constexpr explicit operator bool() const noexcept { - return type_ != detail::type::none_type; - } - - auto type() const -> detail::type { return type_; } - - auto is_integral() const -> bool { return detail::is_integral_type(type_); } - auto is_arithmetic() const -> bool { - return detail::is_arithmetic_type(type_); - } - - FMT_INLINE auto format_custom(const char_type *parse_begin, - typename Context::parse_context_type &parse_ctx, Context &ctx) - -> bool { - if (type_ != detail::type::custom_type) return false; - parse_ctx.advance_to(parse_begin); - value_.custom.format(value_.custom.value, parse_ctx, ctx); - return true; - } -}; - -/** - \rst - Visits an argument dispatching to the appropriate visit method based on - the argument type. For example, if the argument type is ``double`` then - ``vis(value)`` will be called with the value of type ``double``. - \endrst - */ -// DEPRECATED! -template -FMT_CONSTEXPR FMT_INLINE auto visit_format_arg(Visitor &&vis, - const basic_format_arg &arg) -> decltype(vis(0)) { - switch (arg.type_) { - case detail::type::none_type: break; - case detail::type::int_type: return vis(arg.value_.int_value); - case detail::type::uint_type: return vis(arg.value_.uint_value); - case detail::type::long_long_type: - return vis(arg.value_.long_long_value); - case detail::type::ulong_long_type: - return vis(arg.value_.ulong_long_value); - case detail::type::int128_type: - return vis(detail::convert_for_visit(arg.value_.int128_value)); - case detail::type::uint128_type: - return vis(detail::convert_for_visit(arg.value_.uint128_value)); - case detail::type::bool_type: return vis(arg.value_.bool_value); - case detail::type::char_type: return vis(arg.value_.char_value); - case detail::type::float_type: return vis(arg.value_.float_value); - case detail::type::double_type: return vis(arg.value_.double_value); - case detail::type::long_double_type: - return vis(arg.value_.long_double_value); - case detail::type::cstring_type: return vis(arg.value_.string.data); - case detail::type::string_type: - using sv = basic_string_view; - return vis(sv(arg.value_.string.data, arg.value_.string.size)); - case detail::type::pointer_type: return vis(arg.value_.pointer); - case detail::type::custom_type: - return vis(typename basic_format_arg::handle( - arg.value_.custom)); - } - return vis(monostate()); -} - -// Formatting context. 
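The (deprecated, per the comment above) `visit_format_arg` dispatches on the erased `type_` tag and hands the visitor the concrete value. A sketch of inspecting an argument store with it; note that in this fmt version `make_format_args` takes lvalue references, so the arguments must be named variables:

```cpp
#include <fmt/core.h>

int main() {
  int x = 42;
  auto store = fmt::make_format_args(x);
  fmt::format_args args = store;
  fmt::visit_format_arg(
      [](auto value) {
        (void)value;  // concrete type here: int; monostate if absent;
                      // a basic_format_arg<...>::handle for custom types
      },
      args.get(0));
}
```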
-template -class basic_format_context { -private: - OutputIt out_; - basic_format_args args_; - detail::locale_ref loc_; - -public: - using iterator = OutputIt; - using format_arg = basic_format_arg; - using format_args = basic_format_args; - using parse_context_type = basic_format_parse_context; - template - using formatter_type = formatter; - - /** The character type for the output. */ - using char_type = Char; - - basic_format_context(basic_format_context &&) = default; - basic_format_context(const basic_format_context &) = delete; - void operator=(const basic_format_context &) = delete; - /** - Constructs a ``basic_format_context`` object. References to the arguments - are stored in the object so make sure they have appropriate lifetimes. - */ - constexpr basic_format_context( - OutputIt out, format_args ctx_args, detail::locale_ref loc = {}) - : out_(out), args_(ctx_args), loc_(loc) {} - - constexpr auto arg(int id) const -> format_arg { return args_.get(id); } - FMT_CONSTEXPR auto arg(basic_string_view name) -> format_arg { - return args_.get(name); - } - FMT_CONSTEXPR auto arg_id(basic_string_view name) -> int { - return args_.get_id(name); - } - auto args() const -> const format_args & { return args_; } - - // DEPRECATED! - FMT_CONSTEXPR auto error_handler() -> detail::error_handler { return {}; } - void on_error(const char *message) { error_handler().on_error(message); } - - // Returns an iterator to the beginning of the output range. - FMT_CONSTEXPR auto out() -> iterator { return out_; } - - // Advances the begin iterator to ``it``. - void advance_to(iterator it) { - if (!detail::is_back_insert_iterator()) out_ = it; - } - - FMT_CONSTEXPR auto locale() -> detail::locale_ref { return loc_; } -}; - -template -using buffer_context - = basic_format_context, Char>; -using format_context = buffer_context; - -template -using is_formattable = bool_constant>().map( - std::declval()))>::value>; - -/** - \rst - An array of references to arguments. It can be implicitly converted into - `~fmt::basic_format_args` for passing into type-erased formatting functions - such as `~fmt::vformat`. - \endrst - */ -template -class format_arg_store -#if FMT_GCC_VERSION && FMT_GCC_VERSION < 409 - // Workaround a GCC template argument substitution bug. - : public basic_format_args -#endif -{ -private: - static const size_t num_args = sizeof...(Args); - static constexpr size_t num_named_args - = detail::count_named_args(); - static const bool is_packed = num_args <= detail::max_packed_args; - - using value_type = conditional_t, - basic_format_arg>; - - detail::arg_data - data_; - - friend class basic_format_args; - - static constexpr unsigned long long desc - = (is_packed ? detail::encode_types() - : detail::is_unpacked_bit | num_args) - | (num_named_args != 0 ? static_cast( - detail::has_named_args_bit) - : 0); - -public: - template - FMT_CONSTEXPR FMT_INLINE format_arg_store(T &...args) - : -#if FMT_GCC_VERSION && FMT_GCC_VERSION < 409 - basic_format_args(*this) - , -#endif - data_ {detail::make_arg(args)...} { - if (detail::const_check(num_named_args != 0)) - detail::init_named_args(data_.named_args(), 0, 0, args...); - } -}; - -/** - \rst - Constructs a `~fmt::format_arg_store` object that contains references to - arguments and can be implicitly converted to `~fmt::format_args`. `Context` - can be omitted in which case it defaults to `~fmt::format_context`. - See `~fmt::arg` for lifetime considerations. - \endrst - */ -// Arguments are taken by lvalue references to avoid some lifetime issues. 
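`is_formattable` above is the public trait behind "can this type be formatted", built by probing `arg_mapper`. It is handy for static_asserts in generic code:

```cpp
#include <fmt/core.h>

struct opaque {};  // no formatter specialization, no format_as

static_assert(fmt::is_formattable<int>::value, "built-in types map directly");
static_assert(!fmt::is_formattable<opaque>::value,
              "arg_mapper returns unformattable, so the trait is false");

int main() {}
```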
-template -constexpr auto make_format_args(T &...args) - -> format_arg_store...> { - return {args...}; -} - -/** - \rst - Returns a named argument to be used in a formatting function. - It should only be used in a call to a formatting function or - `dynamic_format_arg_store::push_back`. - - **Example**:: - - fmt::print("Elapsed time: {s:.2f} seconds", fmt::arg("s", 1.23)); - \endrst - */ -template -inline auto arg(const Char *name, const T &arg) -> detail::named_arg { - static_assert(!detail::is_named_arg(), "nested named arguments"); - return {name, arg}; -} -FMT_END_EXPORT - -/** - \rst - A view of a collection of formatting arguments. To avoid lifetime issues it - should only be used as a parameter type in type-erased functions such as - ``vformat``:: - - void vlog(string_view format_str, format_args args); // OK - format_args args = make_format_args(); // Error: dangling reference - \endrst - */ -template -class basic_format_args { -public: - using size_type = int; - using format_arg = basic_format_arg; - -private: - // A descriptor that contains information about formatting arguments. - // If the number of arguments is less or equal to max_packed_args then - // argument types are passed in the descriptor. This reduces binary code size - // per formatting function call. - unsigned long long desc_; - union { - // If is_packed() returns true then argument values are stored in values_; - // otherwise they are stored in args_. This is done to improve cache - // locality and reduce compiled code size since storing larger objects - // may require more code (at least on x86-64) even if the same amount of - // data is actually copied to stack. It saves ~10% on the bloat test. - const detail::value *values_; - const format_arg *args_; - }; - - constexpr auto is_packed() const -> bool { - return (desc_ & detail::is_unpacked_bit) == 0; - } - auto has_named_args() const -> bool { - return (desc_ & detail::has_named_args_bit) != 0; - } - - FMT_CONSTEXPR auto type(int index) const -> detail::type { - int shift = index * detail::packed_arg_bits; - unsigned int mask = (1 << detail::packed_arg_bits) - 1; - return static_cast((desc_ >> shift) & mask); - } - - constexpr FMT_INLINE basic_format_args( - unsigned long long desc, const detail::value *values) - : desc_(desc), values_(values) {} - constexpr basic_format_args(unsigned long long desc, const format_arg *args) - : desc_(desc), args_(args) {} - -public: - constexpr basic_format_args() : desc_(0), args_(nullptr) {} - - /** - \rst - Constructs a `basic_format_args` object from `~fmt::format_arg_store`. - \endrst - */ - template - constexpr FMT_INLINE basic_format_args( - const format_arg_store &store) - : basic_format_args( - format_arg_store::desc, store.data_.args()) {} - - /** - \rst - Constructs a `basic_format_args` object from - `~fmt::dynamic_format_arg_store`. - \endrst - */ - constexpr FMT_INLINE basic_format_args( - const dynamic_format_arg_store &store) - : basic_format_args(store.get_types(), store.data()) {} - - /** - \rst - Constructs a `basic_format_args` object from a dynamic set of arguments. - \endrst - */ - constexpr basic_format_args(const format_arg *args, int count) - : basic_format_args( - detail::is_unpacked_bit | detail::to_unsigned(count), args) {} - - /** Returns the argument with the specified id. 
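The `vlog` pattern mentioned in the comment above is worth spelling out: `format_args` is a non-owning view, so it must be consumed in the same full expression that built it, typically by pairing a type-erased `v...` function with a thin templated wrapper:

```cpp
#include <fmt/core.h>
#include <cstdio>

// Type-erased core: one instantiation regardless of argument types. This is
// safe because the view is consumed before the call returns.
void vlog(fmt::string_view fmt, fmt::format_args args) {
  std::fputs(fmt::vformat(fmt, args).c_str(), stdout);
}

// Templated shim: builds the argument store at the call site.
template <typename... T>
void log(fmt::format_string<T...> fmt, T&&... args) {
  vlog(fmt, fmt::make_format_args(args...));
}

int main() { log("{} + {} = {}\n", 1, 2, 3); }
```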
*/ - FMT_CONSTEXPR auto get(int id) const -> format_arg { - format_arg arg; - if (!is_packed()) { - if (id < max_size()) arg = args_[id]; - return arg; - } - if (id >= detail::max_packed_args) return arg; - arg.type_ = type(id); - if (arg.type_ == detail::type::none_type) return arg; - arg.value_ = values_[id]; - return arg; - } - - template - auto get(basic_string_view name) const -> format_arg { - int id = get_id(name); - return id >= 0 ? get(id) : format_arg(); - } - - template - auto get_id(basic_string_view name) const -> int { - if (!has_named_args()) return -1; - const auto &named_args - = (is_packed() ? values_[-1] : args_[-1].value_).named_args; - for (size_t i = 0; i < named_args.size; ++i) { - if (named_args.data[i].name == name) return named_args.data[i].id; - } - return -1; - } - - auto max_size() const -> int { - unsigned long long max_packed = detail::max_packed_args; - return static_cast( - is_packed() ? max_packed : desc_ & ~detail::is_unpacked_bit); - } -}; - -/** An alias to ``basic_format_args``. */ -// A separate type would result in shorter symbols but break ABI compatibility -// between clang and gcc on ARM (#1919). -FMT_EXPORT using format_args = basic_format_args; - -// We cannot use enum classes as bit fields because of a gcc bug, so we put them -// in namespaces instead (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61414). -// Additionally, if an underlying type is specified, older gcc incorrectly warns -// that the type is too small. Both bugs are fixed in gcc 9.3. -#if FMT_GCC_VERSION && FMT_GCC_VERSION < 903 -#define FMT_ENUM_UNDERLYING_TYPE(type) -#else -#define FMT_ENUM_UNDERLYING_TYPE(type) : type -#endif -namespace align { -enum type FMT_ENUM_UNDERLYING_TYPE(unsigned char) { - none, left, right, center, numeric}; -} -using align_t = align::type; -namespace sign { -enum type FMT_ENUM_UNDERLYING_TYPE(unsigned char) {none, minus, plus, space}; -} -using sign_t = sign::type; - -namespace detail { - -// Workaround an array initialization issue in gcc 4.8. -template -struct fill_t { -private: - enum { max_size = 4 }; - Char data_[max_size] = {Char(' '), Char(0), Char(0), Char(0)}; - unsigned char size_ = 1; - -public: - FMT_CONSTEXPR void operator=(basic_string_view s) { - auto size = s.size(); - FMT_ASSERT(size <= max_size, "invalid fill"); - for (size_t i = 0; i < size; ++i) - data_[i] = s[i]; - size_ = static_cast(size); - } - - constexpr auto size() const -> size_t { return size_; } - constexpr auto data() const -> const Char * { return data_; } - - FMT_CONSTEXPR auto operator[](size_t index) -> Char & { - return data_[index]; - } - FMT_CONSTEXPR auto operator[](size_t index) const -> const Char & { - return data_[index]; - } -}; -} // namespace detail - -enum class presentation_type : unsigned char { - none, - dec, // 'd' - oct, // 'o' - hex_lower, // 'x' - hex_upper, // 'X' - bin_lower, // 'b' - bin_upper, // 'B' - hexfloat_lower, // 'a' - hexfloat_upper, // 'A' - exp_lower, // 'e' - exp_upper, // 'E' - fixed_lower, // 'f' - fixed_upper, // 'F' - general_lower, // 'g' - general_upper, // 'G' - chr, // 'c' - string, // 's' - pointer, // 'p' - debug // '?' -}; - -// Format specifiers for built-in and string types. -template -struct format_specs { - int width; - int precision; - presentation_type type; - align_t align : 4; - sign_t sign : 3; - bool alt : 1; // Alternate form ('#'). 
- bool localized : 1; - detail::fill_t fill; - - constexpr format_specs() - : width(0) - , precision(-1) - , type(presentation_type::none) - , align(align::none) - , sign(sign::none) - , alt(false) - , localized(false) {} -}; - -namespace detail { - -enum class arg_id_kind { none, index, name }; - -// An argument reference. -template -struct arg_ref { - FMT_CONSTEXPR arg_ref() : kind(arg_id_kind::none), val() {} - - FMT_CONSTEXPR explicit arg_ref(int index) - : kind(arg_id_kind::index), val(index) {} - FMT_CONSTEXPR explicit arg_ref(basic_string_view name) - : kind(arg_id_kind::name), val(name) {} - - FMT_CONSTEXPR auto operator=(int idx) -> arg_ref & { - kind = arg_id_kind::index; - val.index = idx; - return *this; - } - - arg_id_kind kind; - union value { - FMT_CONSTEXPR value(int idx = 0) : index(idx) {} - FMT_CONSTEXPR value(basic_string_view n) : name(n) {} - - int index; - basic_string_view name; - } val; -}; - -// Format specifiers with width and precision resolved at formatting rather -// than parsing time to allow reusing the same parsed specifiers with -// different sets of arguments (precompilation of format strings). -template -struct dynamic_format_specs : format_specs { - arg_ref width_ref; - arg_ref precision_ref; -}; - -// Converts a character to ASCII. Returns '\0' on conversion failure. -template ::value)> -constexpr auto to_ascii(Char c) -> char { - return c <= 0xff ? static_cast(c) : '\0'; -} -template ::value)> -constexpr auto to_ascii(Char c) -> char { - return c <= 0xff ? static_cast(c) : '\0'; -} - -// Returns the number of code units in a code point or 1 on error. -template -FMT_CONSTEXPR auto code_point_length(const Char *begin) -> int { - if (const_check(sizeof(Char) != 1)) return 1; - auto c = static_cast(*begin); - return static_cast((0x3a55000000000000ull >> (2 * (c >> 3))) & 0x3) - + 1; -} - -// Return the result via the out param to workaround gcc bug 77539. -template -FMT_CONSTEXPR auto find(Ptr first, Ptr last, T value, Ptr &out) -> bool { - for (out = first; out != last; ++out) { - if (*out == value) return true; - } - return false; -} - -template <> -inline auto find(const char *first, const char *last, char value, - const char *&out) -> bool { - out = static_cast( - std::memchr(first, value, to_unsigned(last - first))); - return out != nullptr; -} - -// Parses the range [begin, end) as an unsigned integer. This function assumes -// that the range is non-empty and the first character is a digit. -template -FMT_CONSTEXPR auto parse_nonnegative_int( - const Char *&begin, const Char *end, int error_value) noexcept -> int { - FMT_ASSERT(begin != end && '0' <= *begin && *begin <= '9', ""); - unsigned value = 0, prev = 0; - auto p = begin; - do { - prev = value; - value = value * 10 + unsigned(*p - '0'); - ++p; - } while (p != end && '0' <= *p && *p <= '9'); - auto num_digits = p - begin; - begin = p; - if (num_digits <= std::numeric_limits::digits10) - return static_cast(value); - // Check for overflow. - const unsigned max = to_unsigned((std::numeric_limits::max)()); - return num_digits == std::numeric_limits::digits10 + 1 - && prev * 10ull + unsigned(p[-1] - '0') <= max - ? 
static_cast(value) - : error_value; -} - -FMT_CONSTEXPR inline auto parse_align(char c) -> align_t { - switch (c) { - case '<': return align::left; - case '>': return align::right; - case '^': return align::center; - } - return align::none; -} - -template -constexpr auto is_name_start(Char c) -> bool { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '_'; -} - -template -FMT_CONSTEXPR auto do_parse_arg_id( - const Char *begin, const Char *end, Handler &&handler) -> const Char * { - Char c = *begin; - if (c >= '0' && c <= '9') { - int index = 0; - constexpr int max = (std::numeric_limits::max)(); - if (c != '0') - index = parse_nonnegative_int(begin, end, max); - else - ++begin; - if (begin == end || (*begin != '}' && *begin != ':')) - throw_format_error("invalid format string"); - else - handler.on_index(index); - return begin; - } - if (!is_name_start(c)) { - throw_format_error("invalid format string"); - return begin; - } - auto it = begin; - do { - ++it; - } while (it != end && (is_name_start(*it) || ('0' <= *it && *it <= '9'))); - handler.on_name({begin, to_unsigned(it - begin)}); - return it; -} - -template -FMT_CONSTEXPR FMT_INLINE auto parse_arg_id( - const Char *begin, const Char *end, Handler &&handler) -> const Char * { - FMT_ASSERT(begin != end, ""); - Char c = *begin; - if (c != '}' && c != ':') return do_parse_arg_id(begin, end, handler); - handler.on_auto(); - return begin; -} - -template -struct dynamic_spec_id_handler { - basic_format_parse_context &ctx; - arg_ref &ref; - - FMT_CONSTEXPR void on_auto() { - int id = ctx.next_arg_id(); - ref = arg_ref(id); - ctx.check_dynamic_spec(id); - } - FMT_CONSTEXPR void on_index(int id) { - ref = arg_ref(id); - ctx.check_arg_id(id); - ctx.check_dynamic_spec(id); - } - FMT_CONSTEXPR void on_name(basic_string_view id) { - ref = arg_ref(id); - ctx.check_arg_id(id); - } -}; - -// Parses [integer | "{" [arg_id] "}"]. -template -FMT_CONSTEXPR auto parse_dynamic_spec(const Char *begin, const Char *end, - int &value, arg_ref &ref, basic_format_parse_context &ctx) - -> const Char * { - FMT_ASSERT(begin != end, ""); - if ('0' <= *begin && *begin <= '9') { - int val = parse_nonnegative_int(begin, end, -1); - if (val != -1) - value = val; - else - throw_format_error("number is too big"); - } else if (*begin == '{') { - ++begin; - auto handler = dynamic_spec_id_handler {ctx, ref}; - if (begin != end) begin = parse_arg_id(begin, end, handler); - if (begin != end && *begin == '}') return ++begin; - throw_format_error("invalid format string"); - } - return begin; -} - -template -FMT_CONSTEXPR auto parse_precision(const Char *begin, const Char *end, - int &value, arg_ref &ref, basic_format_parse_context &ctx) - -> const Char * { - ++begin; - if (begin == end || *begin == '}') { - throw_format_error("invalid precision"); - return begin; - } - return parse_dynamic_spec(begin, end, value, ref, ctx); -} - -enum class state { start, align, sign, hash, zero, width, precision, locale }; - -// Parses standard format specifiers. -template -FMT_CONSTEXPR FMT_INLINE auto parse_format_specs(const Char *begin, - const Char *end, dynamic_format_specs &specs, - basic_format_parse_context &ctx, type arg_type) -> const Char * { - auto c = '\0'; - if (end - begin > 1) { - auto next = to_ascii(begin[1]); - c = parse_align(next) == align::none ? 
to_ascii(*begin) : '\0'; - } else { - if (begin == end) return begin; - c = to_ascii(*begin); - } - - struct { - state current_state = state::start; - FMT_CONSTEXPR void operator()(state s, bool valid = true) { - if (current_state >= s || !valid) - throw_format_error("invalid format specifier"); - current_state = s; - } - } enter_state; - - using pres = presentation_type; - constexpr auto integral_set = sint_set | uint_set | bool_set | char_set; - struct { - const Char *&begin; - dynamic_format_specs &specs; - type arg_type; - - FMT_CONSTEXPR auto operator()(pres pres_type, int set) -> const Char * { - if (!in(arg_type, set)) { - if (arg_type == type::none_type) return begin; - throw_format_error("invalid format specifier"); - } - specs.type = pres_type; - return begin + 1; - } - } parse_presentation_type {begin, specs, arg_type}; - - for (;;) { - switch (c) { - case '<': - case '>': - case '^': - enter_state(state::align); - specs.align = parse_align(c); - ++begin; - break; - case '+': - case '-': - case ' ': - if (arg_type == type::none_type) return begin; - enter_state(state::sign, in(arg_type, sint_set | float_set)); - switch (c) { - case '+': specs.sign = sign::plus; break; - case '-': specs.sign = sign::minus; break; - case ' ': specs.sign = sign::space; break; - } - ++begin; - break; - case '#': - if (arg_type == type::none_type) return begin; - enter_state(state::hash, is_arithmetic_type(arg_type)); - specs.alt = true; - ++begin; - break; - case '0': - enter_state(state::zero); - if (!is_arithmetic_type(arg_type)) { - if (arg_type == type::none_type) return begin; - throw_format_error( - "format specifier requires numeric argument"); - } - if (specs.align == align::none) { - // Ignore 0 if align is specified for compatibility with std::format. 
- specs.align = align::numeric; - specs.fill[0] = Char('0'); - } - ++begin; - break; - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - case '{': - enter_state(state::width); - begin = parse_dynamic_spec( - begin, end, specs.width, specs.width_ref, ctx); - break; - case '.': - if (arg_type == type::none_type) return begin; - enter_state(state::precision, - in(arg_type, float_set | string_set | cstring_set)); - begin = parse_precision( - begin, end, specs.precision, specs.precision_ref, ctx); - break; - case 'L': - if (arg_type == type::none_type) return begin; - enter_state(state::locale, is_arithmetic_type(arg_type)); - specs.localized = true; - ++begin; - break; - case 'd': return parse_presentation_type(pres::dec, integral_set); - case 'o': return parse_presentation_type(pres::oct, integral_set); - case 'x': - return parse_presentation_type(pres::hex_lower, integral_set); - case 'X': - return parse_presentation_type(pres::hex_upper, integral_set); - case 'b': - return parse_presentation_type(pres::bin_lower, integral_set); - case 'B': - return parse_presentation_type(pres::bin_upper, integral_set); - case 'a': - return parse_presentation_type(pres::hexfloat_lower, float_set); - case 'A': - return parse_presentation_type(pres::hexfloat_upper, float_set); - case 'e': - return parse_presentation_type(pres::exp_lower, float_set); - case 'E': - return parse_presentation_type(pres::exp_upper, float_set); - case 'f': - return parse_presentation_type(pres::fixed_lower, float_set); - case 'F': - return parse_presentation_type(pres::fixed_upper, float_set); - case 'g': - return parse_presentation_type(pres::general_lower, float_set); - case 'G': - return parse_presentation_type(pres::general_upper, float_set); - case 'c': - if (arg_type == type::bool_type) - throw_format_error("invalid format specifier"); - return parse_presentation_type(pres::chr, integral_set); - case 's': - return parse_presentation_type( - pres::string, bool_set | string_set | cstring_set); - case 'p': - return parse_presentation_type( - pres::pointer, pointer_set | cstring_set); - case '?': - return parse_presentation_type( - pres::debug, char_set | string_set | cstring_set); - case '}': return begin; - default: { - if (*begin == '}') return begin; - // Parse fill and alignment. 
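The state machine in `parse_format_specs` above accepts the standard spec grammar: fill/align, sign, `#`, `0`, width, precision, `L`, then a presentation type. A few examples exercising the branches just parsed, with outputs per the documented fmt behavior:

```cpp
#include <fmt/core.h>

int main() {
  fmt::print("[{:*^6}]\n", 42);            // [**42**]  '*' fill, center align
  fmt::print("[{:+}] [{: }]\n", 42, 42);   // [+42] [ 42]  sign handling
  fmt::print("[{:05}]\n", 42);             // [00042]  '0' implies numeric align
  fmt::print("[{:<05}]\n", 42);            // [42   ]  explicit align: the '0'
                                           //          is ignored, per the
                                           //          comment above
  fmt::print("[{:{}.{}f}]\n", 3.14159, 10, 3);   // [     3.142]  dynamic
                                           //   width/precision via arg_ref
  fmt::print("{:#x} {:b} {:e}\n", 255, 5, 0.5);  // 0xff 101 5.000000e-01
  fmt::print("{:?}\n", "a\tb");            // "a\tb"  debug presentation ('?')
}
```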
- auto fill_end = begin + code_point_length(begin); - if (end - fill_end <= 0) { - throw_format_error("invalid format specifier"); - return begin; - } - if (*begin == '{') { - throw_format_error("invalid fill character '{'"); - return begin; - } - auto align = parse_align(to_ascii(*fill_end)); - enter_state(state::align, align != align::none); - specs.fill = {begin, to_unsigned(fill_end - begin)}; - specs.align = align; - begin = fill_end + 1; - } - } - if (begin == end) return begin; - c = to_ascii(*begin); - } -} - -template -FMT_CONSTEXPR auto parse_replacement_field( - const Char *begin, const Char *end, Handler &&handler) -> const Char * { - struct id_adapter { - Handler &handler; - int arg_id; - - FMT_CONSTEXPR void on_auto() { arg_id = handler.on_arg_id(); } - FMT_CONSTEXPR void on_index(int id) { arg_id = handler.on_arg_id(id); } - FMT_CONSTEXPR void on_name(basic_string_view id) { - arg_id = handler.on_arg_id(id); - } - }; - - ++begin; - if (begin == end) return handler.on_error("invalid format string"), end; - if (*begin == '}') { - handler.on_replacement_field(handler.on_arg_id(), begin); - } else if (*begin == '{') { - handler.on_text(begin, begin + 1); - } else { - auto adapter = id_adapter {handler, 0}; - begin = parse_arg_id(begin, end, adapter); - Char c = begin != end ? *begin : Char(); - if (c == '}') { - handler.on_replacement_field(adapter.arg_id, begin); - } else if (c == ':') { - begin = handler.on_format_specs(adapter.arg_id, begin + 1, end); - if (begin == end || *begin != '}') - return handler.on_error("unknown format specifier"), end; - } else { - return handler.on_error("missing '}' in format string"), end; - } - } - return begin + 1; -} - -template -FMT_CONSTEXPR FMT_INLINE void parse_format_string( - basic_string_view format_str, Handler &&handler) { - auto begin = format_str.data(); - auto end = begin + format_str.size(); - if (end - begin < 32) { - // Use a simple loop instead of memchr for small strings. - const Char *p = begin; - while (p != end) { - auto c = *p++; - if (c == '{') { - handler.on_text(begin, p - 1); - begin = p = parse_replacement_field(p - 1, end, handler); - } else if (c == '}') { - if (p == end || *p != '}') - return handler.on_error("unmatched '}' in format string"); - handler.on_text(begin, p); - begin = ++p; - } - } - handler.on_text(begin, end); - return; - } - struct writer { - FMT_CONSTEXPR void operator()(const Char *from, const Char *to) { - if (from == to) return; - for (;;) { - const Char *p = nullptr; - if (!find(from, to, Char('}'), p)) - return handler_.on_text(from, to); - ++p; - if (p == to || *p != '}') - return handler_.on_error("unmatched '}' in format string"); - handler_.on_text(from, p); - from = p + 1; - } - } - Handler &handler_; - } write = {handler}; - while (begin != end) { - // Doing two passes with memchr (one for '{' and another for '}') is up to - // 2.5x faster than the naive one-pass implementation on big format strings. 
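`parse_format_string` above also defines the escape rules: `{{` and `}}` are literal braces, and an unmatched `}` is an error:

```cpp
#include <fmt/core.h>

int main() {
  fmt::print("{{}} looks like {}\n", "braces");  // "{} looks like braces"
  // fmt::format(fmt::runtime("}")) throws: "unmatched '}' in format string"
}
```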
- const Char *p = begin; - if (*begin != '{' && !find(begin + 1, end, Char('{'), p)) - return write(begin, end); - write(begin, p); - begin = parse_replacement_field(p, end, handler); - } -} - -template ::value> -struct strip_named_arg { - using type = T; -}; -template -struct strip_named_arg { - using type = remove_cvref_t; -}; - -template -FMT_CONSTEXPR auto parse_format_specs(ParseContext &ctx) - -> decltype(ctx.begin()) { - using char_type = typename ParseContext::char_type; - using context = buffer_context; - using mapped_type = conditional_t::value - != type::custom_type, - decltype(arg_mapper().map(std::declval())), - typename strip_named_arg::type>; -#if defined(__cpp_if_constexpr) - if constexpr (std::is_default_constructible< - formatter>::value) { - return formatter().parse(ctx); - } else { - type_is_unformattable_for _; - return ctx.begin(); - } -#else - return formatter().parse(ctx); -#endif -} - -// Checks char specs and returns true iff the presentation type is char-like. -template -FMT_CONSTEXPR auto check_char_specs(const format_specs &specs) -> bool { - if (specs.type != presentation_type::none - && specs.type != presentation_type::chr - && specs.type != presentation_type::debug) { - return false; - } - if (specs.align == align::numeric || specs.sign != sign::none || specs.alt) - throw_format_error("invalid format specifier for char"); - return true; -} - -#if FMT_USE_NONTYPE_TEMPLATE_ARGS -template -constexpr auto get_arg_index_by_name(basic_string_view name) -> int { - if constexpr (is_statically_named_arg()) { - if (name == T::name) return N; - } - if constexpr (sizeof...(Args) > 0) - return get_arg_index_by_name(name); - (void)name; // Workaround an MSVC bug about "unused" parameter. - return -1; -} -#endif - -template -FMT_CONSTEXPR auto get_arg_index_by_name(basic_string_view name) -> int { -#if FMT_USE_NONTYPE_TEMPLATE_ARGS - if constexpr (sizeof...(Args) > 0) - return get_arg_index_by_name<0, Args...>(name); -#endif - (void)name; - return -1; -} - -template -class format_string_checker { -private: - using parse_context_type = compile_parse_context; - static constexpr int num_args = sizeof...(Args); - - // Format specifier parsing function. - // In the future basic_format_parse_context will replace compile_parse_context - // here and will use is_constant_evaluated and downcasting to access the data - // needed for compile-time checks: https://godbolt.org/z/GvWzcTjh1. - using parse_func = const Char *(*)(parse_context_type &); - - type types_[num_args > 0 ? static_cast(num_args) : 1]; - parse_context_type context_; - parse_func parse_funcs_[num_args > 0 ? 
static_cast(num_args) : 1]; - -public: - explicit FMT_CONSTEXPR format_string_checker(basic_string_view fmt) - : types_ {mapped_type_constant>::value...} - , context_(fmt, num_args, types_) - , parse_funcs_ {&parse_format_specs...} {} - - FMT_CONSTEXPR void on_text(const Char *, const Char *) {} - - FMT_CONSTEXPR auto on_arg_id() -> int { return context_.next_arg_id(); } - FMT_CONSTEXPR auto on_arg_id(int id) -> int { - return context_.check_arg_id(id), id; - } - FMT_CONSTEXPR auto on_arg_id(basic_string_view id) -> int { -#if FMT_USE_NONTYPE_TEMPLATE_ARGS - auto index = get_arg_index_by_name(id); - if (index < 0) on_error("named argument is not found"); - return index; -#else - (void)id; - on_error( - "compile-time checks for named arguments require C++20 " - "support"); - return 0; -#endif - } - - FMT_CONSTEXPR void on_replacement_field(int id, const Char *begin) { - on_format_specs(id, begin, begin); // Call parse() on empty specs. - } - - FMT_CONSTEXPR auto on_format_specs(int id, const Char *begin, const Char *) - -> const Char * { - context_.advance_to(begin); - // id >= 0 check is a workaround for gcc 10 bug (#2065). - return id >= 0 && id < num_args ? parse_funcs_[id](context_) : begin; - } - - FMT_CONSTEXPR void on_error(const char *message) { - throw_format_error(message); - } -}; - -// Reports a compile-time error if S is not a valid format string. -template ::value)> -FMT_INLINE void check_format_string(const S &) { -#ifdef FMT_ENFORCE_COMPILE_STRING - static_assert(is_compile_string::value, - "FMT_ENFORCE_COMPILE_STRING requires all format strings to use " - "FMT_STRING."); -#endif -} -template ::value)> -void check_format_string(S format_str) { - using char_t = typename S::char_type; - FMT_CONSTEXPR auto s = basic_string_view(format_str); - using checker = format_string_checker...>; - FMT_CONSTEXPR bool error = (parse_format_string(s, checker(s)), true); - ignore_unused(error); -} - -template -struct vformat_args { - using type = basic_format_args>, Char>>; -}; -template <> -struct vformat_args { - using type = format_args; -}; - -// Use vformat_args and avoid type_identity to keep symbols short. -template -void vformat_to(buffer &buf, basic_string_view fmt, - typename vformat_args::type args, locale_ref loc = {}); - -FMT_API void vprint_mojibake(std::FILE *, string_view, format_args); -#ifndef _WIN32 -inline void vprint_mojibake(std::FILE *, string_view, format_args) {} -#endif -} // namespace detail - -FMT_BEGIN_EXPORT - -// A formatter specialization for natively supported types. -template -struct formatter::value - != detail::type::custom_type>> { -private: - detail::dynamic_format_specs specs_; - -public: - template - FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const Char * { - auto type = detail::type_constant::value; - auto end = detail::parse_format_specs( - ctx.begin(), ctx.end(), specs_, ctx, type); - if (type == detail::type::char_type) detail::check_char_specs(specs_); - return end; - } - - template ::value, - FMT_ENABLE_IF(U == detail::type::string_type - || U == detail::type::cstring_type - || U == detail::type::char_type)> - FMT_CONSTEXPR void set_debug_format(bool set = true) { - specs_.type = set ? presentation_type::debug : presentation_type::none; - } - - template - FMT_CONSTEXPR auto format(const T &val, FormatContext &ctx) const - -> decltype(ctx.out()); -}; - -template -struct runtime_format_string { - basic_string_view str; -}; - -/** A compile-time format string. 
*/ -template -class basic_format_string { -private: - basic_string_view str_; - -public: - template >::value)> - FMT_CONSTEVAL FMT_INLINE basic_format_string(const S &s) : str_(s) { - static_assert(detail::count<(std::is_base_of>::value - && std::is_reference::value)...>() - == 0, - "passing views as lvalues is disallowed"); -#ifdef FMT_HAS_CONSTEVAL - if constexpr (detail::count_named_args() - == detail::count_statically_named_args()) { - using checker = detail::format_string_checker...>; - detail::parse_format_string(str_, checker(s)); - } -#else - detail::check_format_string(s); -#endif - } - basic_format_string(runtime_format_string fmt) : str_(fmt.str) {} - - FMT_INLINE operator basic_string_view() const { return str_; } - FMT_INLINE auto get() const -> basic_string_view { return str_; } -}; - -#if FMT_GCC_VERSION && FMT_GCC_VERSION < 409 -// Workaround broken conversion on older gcc. -template -using format_string = string_view; -inline auto runtime(string_view s) -> string_view { - return s; -} -#else -template -using format_string = basic_format_string...>; -/** - \rst - Creates a runtime format string. - - **Example**:: - - // Check format string at runtime instead of compile-time. - fmt::print(fmt::runtime("{:d}"), "I am not a number"); - \endrst - */ -inline auto runtime(string_view s) -> runtime_format_string<> { - return {{s}}; -} -#endif - -FMT_API auto vformat(string_view fmt, format_args args) -> std::string; - -/** - \rst - Formats ``args`` according to specifications in ``fmt`` and returns the result - as a string. - - **Example**:: - - #include - std::string message = fmt::format("The answer is {}.", 42); - \endrst -*/ -template -FMT_NODISCARD FMT_INLINE auto format(format_string fmt, T &&...args) - -> std::string { - return vformat(fmt, fmt::make_format_args(args...)); -} - -/** Formats a string and writes the output to ``out``. */ -template ::value)> -auto vformat_to(OutputIt out, string_view fmt, format_args args) -> OutputIt { - auto &&buf = detail::get_buffer(out); - detail::vformat_to(buf, fmt, args, {}); - return detail::get_iterator(buf, out); -} - -/** - \rst - Formats ``args`` according to specifications in ``fmt``, writes the result to - the output iterator ``out`` and returns the iterator past the end of the output - range. `format_to` does not append a terminating null character. - - **Example**:: - - auto out = std::vector(); - fmt::format_to(std::back_inserter(out), "{}", 42); - \endrst - */ -template ::value)> -FMT_INLINE auto format_to(OutputIt out, format_string fmt, T &&...args) - -> OutputIt { - return vformat_to(out, fmt, fmt::make_format_args(args...)); -} - -template -struct format_to_n_result { - /** Iterator past the end of the output range. */ - OutputIt out; - /** Total (not truncated) output size. */ - size_t size; -}; - -template ::value)> -auto vformat_to_n(OutputIt out, size_t n, string_view fmt, format_args args) - -> format_to_n_result { - using traits = detail::fixed_buffer_traits; - auto buf = detail::iterator_buffer(out, n); - detail::vformat_to(buf, fmt, args, {}); - return {buf.out(), buf.count()}; -} - -/** - \rst - Formats ``args`` according to specifications in ``fmt``, writes up to ``n`` - characters of the result to the output iterator ``out`` and returns the total - (not truncated) output size and the iterator past the end of the output range. - `format_to_n` does not append a terminating null character. 
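Taken together, `basic_format_string`'s consteval constructor and `fmt::runtime` above give the compile-time/run-time checking split. An illustration (the first commented line would be rejected at compile time under C++20):

```cpp
#include <fmt/core.h>
#include <string>

int main() {
  // fmt::format("{:d}", "str");          // compile error: caught in the
                                          // FMT_CONSTEVAL constructor above
  std::string pattern = "{:d}";
  fmt::print(fmt::runtime(pattern), 42);  // checked at run time instead; a
                                          // mismatched argument would throw
                                          // fmt::format_error
}
```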
- \endrst - */ -template ::value)> -FMT_INLINE auto format_to_n(OutputIt out, size_t n, format_string fmt, - T &&...args) -> format_to_n_result { - return vformat_to_n(out, n, fmt, fmt::make_format_args(args...)); -} - -/** Returns the number of chars in the output of ``format(fmt, args...)``. */ -template -FMT_NODISCARD FMT_INLINE auto formatted_size( - format_string fmt, T &&...args) -> size_t { - auto buf = detail::counting_buffer<>(); - detail::vformat_to(buf, fmt, fmt::make_format_args(args...), {}); - return buf.count(); -} - -FMT_API void vprint(string_view fmt, format_args args); -FMT_API void vprint(std::FILE *f, string_view fmt, format_args args); - -/** - \rst - Formats ``args`` according to specifications in ``fmt`` and writes the output - to ``stdout``. - - **Example**:: - - fmt::print("Elapsed time: {0:.2f} seconds", 1.23); - \endrst - */ -template -FMT_INLINE void print(format_string fmt, T &&...args) { - const auto &vargs = fmt::make_format_args(args...); - return detail::is_utf8() ? vprint(fmt, vargs) - : detail::vprint_mojibake(stdout, fmt, vargs); -} - -/** - \rst - Formats ``args`` according to specifications in ``fmt`` and writes the - output to the file ``f``. - - **Example**:: - - fmt::print(stderr, "Don't {}!", "panic"); - \endrst - */ -template -FMT_INLINE void print(std::FILE *f, format_string fmt, T &&...args) { - const auto &vargs = fmt::make_format_args(args...); - return detail::is_utf8() ? vprint(f, fmt, vargs) - : detail::vprint_mojibake(f, fmt, vargs); -} - -/** - Formats ``args`` according to specifications in ``fmt`` and writes the - output to the file ``f`` followed by a newline. - */ -template -FMT_INLINE void println(std::FILE *f, format_string fmt, T &&...args) { - return fmt::print(f, "{}\n", fmt::format(fmt, std::forward(args)...)); -} - -/** - Formats ``args`` according to specifications in ``fmt`` and writes the output - to ``stdout`` followed by a newline. - */ -template -FMT_INLINE void println(format_string fmt, T &&...args) { - return fmt::println(stdout, fmt, std::forward(args)...); -} - -FMT_END_EXPORT -FMT_GCC_PRAGMA("GCC pop_options") -FMT_END_NAMESPACE - -#ifdef FMT_HEADER_ONLY -#include "common/spdlog/fmt/bundled/format.h" -#endif -#endif // FMT_CORE_H_ diff --git a/src/common/spdlog/fmt/bundled/format-inl.h b/src/common/spdlog/fmt/bundled/format-inl.h deleted file mode 100755 index bc912667ca0..00000000000 --- a/src/common/spdlog/fmt/bundled/format-inl.h +++ /dev/null @@ -1,2859 +0,0 @@ -// Formatting library for C++ - implementation -// -// Copyright (c) 2012 - 2016, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_FORMAT_INL_H_ -#define FMT_FORMAT_INL_H_ - -#include -#include // errno -#include -#include -#include - -#ifndef FMT_STATIC_THOUSANDS_SEPARATOR -#include -#endif - -#if defined(_WIN32) && !defined(FMT_WINDOWS_NO_WCHAR) -#include // _isatty -#endif - -#include "format.h" - -FMT_BEGIN_NAMESPACE -namespace detail { - -FMT_FUNC void assert_fail(const char *file, int line, const char *message) { - // Use unchecked std::fprintf to avoid triggering another assertion when - // writing to stderr fails - std::fprintf(stderr, "%s:%d: assertion failed: %s", file, line, message); - // Chosen instead of std::abort to satisfy Clang in CUDA mode during device - // code pass. 
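Two of the entry points above in action: `format_to_n` truncates but still reports the untruncated size (via `fixed_buffer_traits`), and `formatted_size` counts without producing output (via `counting_buffer`):

```cpp
#include <fmt/core.h>

int main() {
  char buf[8];
  auto res = fmt::format_to_n(buf, sizeof(buf) - 1, "{}", 123456789);
  *res.out = '\0';  // format_to_n does not null-terminate
  // buf == "1234567", res.size == 9 (the full size, not the truncated one)

  auto n = fmt::formatted_size("{}", 123456789);
  fmt::print("{} vs {}\n", n, res.size);  // "9 vs 9"
}
```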
diff --git a/src/common/spdlog/fmt/bundled/format-inl.h b/src/common/spdlog/fmt/bundled/format-inl.h
deleted file mode 100755
index bc912667ca0..00000000000
--- a/src/common/spdlog/fmt/bundled/format-inl.h
+++ /dev/null
@@ -1,2859 +0,0 @@
-// Formatting library for C++ - implementation
-//
-// Copyright (c) 2012 - 2016, Victor Zverovich
-// All rights reserved.
-//
-// For the license information refer to format.h.
-
-#ifndef FMT_FORMAT_INL_H_
-#define FMT_FORMAT_INL_H_
-
-#include <algorithm>
-#include <cerrno> // errno
-#include <climits>
-#include <cmath>
-#include <exception>
-
-#ifndef FMT_STATIC_THOUSANDS_SEPARATOR
-#include <locale>
-#endif
-
-#if defined(_WIN32) && !defined(FMT_WINDOWS_NO_WCHAR)
-#include <io.h> // _isatty
-#endif
-
-#include "format.h"
-
-FMT_BEGIN_NAMESPACE
-namespace detail {
-
-FMT_FUNC void assert_fail(const char *file, int line, const char *message) {
-    // Use unchecked std::fprintf to avoid triggering another assertion when
-    // writing to stderr fails
-    std::fprintf(stderr, "%s:%d: assertion failed: %s", file, line, message);
-    // Chosen instead of std::abort to satisfy Clang in CUDA mode during device
-    // code pass.
-    std::terminate();
-}
-
-FMT_FUNC void throw_format_error(const char *message) {
-    FMT_THROW(format_error(message));
-}
-
-FMT_FUNC void format_error_code(detail::buffer<char> &out, int error_code,
-        string_view message) noexcept {
-    // Report error code making sure that the output fits into
-    // inline_buffer_size to avoid dynamic memory allocation and potential
-    // bad_alloc.
-    out.try_resize(0);
-    static const char SEP[] = ": ";
-    static const char ERROR_STR[] = "error ";
-    // Subtract 2 to account for terminating null characters in SEP and ERROR_STR.
-    size_t error_code_size = sizeof(SEP) + sizeof(ERROR_STR) - 2;
-    auto abs_value = static_cast<uint32_or_64_or_128_t<int>>(error_code);
-    if (detail::is_negative(error_code)) {
-        abs_value = 0 - abs_value;
-        ++error_code_size;
-    }
-    error_code_size += detail::to_unsigned(detail::count_digits(abs_value));
-    auto it = buffer_appender<char>(out);
-    if (message.size() <= inline_buffer_size - error_code_size)
-        fmt::format_to(it, FMT_STRING("{}{}"), message, SEP);
-    fmt::format_to(it, FMT_STRING("{}{}"), ERROR_STR, error_code);
-    FMT_ASSERT(out.size() <= inline_buffer_size, "");
-}
-
-FMT_FUNC void report_error(
-        format_func func, int error_code, const char *message) noexcept {
-    memory_buffer full_message;
-    func(full_message, error_code, message);
-    // Don't use fwrite_fully because the latter may throw.
-    if (std::fwrite(full_message.data(), full_message.size(), 1, stderr) > 0)
-        std::fputc('\n', stderr);
-}
-
-// A wrapper around fwrite that throws on error.
-inline void fwrite_fully(const void *ptr, size_t count, FILE *stream) {
-    size_t written = std::fwrite(ptr, 1, count, stream);
-    if (written < count)
-        FMT_THROW(system_error(errno, FMT_STRING("cannot write to file")));
-}
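// Standalone sketch of the budget check format_error_code() performs above:
// the message is kept only if it fits a fixed buffer together with the
// "error <code>" suffix. The 32-byte buffer and helper name are hypothetical
// stand-ins for fmt's inline_buffer_size machinery:
#include <cstddef>
#include <cstdio>
#include <cstring>

static void write_error(char *buf, std::size_t budget, int code,
        const char *msg) {
    char suffix[32];
    int len = std::snprintf(suffix, sizeof(suffix), "error %d", code);
    if (std::strlen(msg) + 2 + static_cast<std::size_t>(len) <= budget)
        std::snprintf(buf, budget, "%s: %s", msg, suffix); // fits: keep msg
    else
        std::snprintf(buf, budget, "%s", suffix); // too long: code only
}

int main() {
    char buf[32];
    write_error(buf, sizeof(buf), 2, "cannot open file");
    std::puts(buf); // "cannot open file: error 2"
}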
-
-#ifndef FMT_STATIC_THOUSANDS_SEPARATOR
-template <typename Locale>
-locale_ref::locale_ref(const Locale &loc) : locale_(&loc) {
-    static_assert(std::is_same<Locale, std::locale>::value, "");
-}
-
-template <typename Locale>
-auto locale_ref::get() const -> Locale {
-    static_assert(std::is_same<Locale, std::locale>::value, "");
-    return locale_ ? *static_cast<const std::locale *>(locale_) : std::locale();
-}
-
-template <typename Char>
-FMT_FUNC auto thousands_sep_impl(locale_ref loc) -> thousands_sep_result<Char> {
-    auto &facet = std::use_facet<std::numpunct<Char>>(loc.get<std::locale>());
-    auto grouping = facet.grouping();
-    auto thousands_sep = grouping.empty() ? Char() : facet.thousands_sep();
-    return {std::move(grouping), thousands_sep};
-}
-template <typename Char>
-FMT_FUNC auto decimal_point_impl(locale_ref loc) -> Char {
-    return std::use_facet<std::numpunct<Char>>(loc.get<std::locale>())
-            .decimal_point();
-}
-#else
-template <typename Char>
-FMT_FUNC auto thousands_sep_impl(locale_ref) -> thousands_sep_result<Char> {
-    return {"\03", FMT_STATIC_THOUSANDS_SEPARATOR};
-}
-template <typename Char>
-FMT_FUNC Char decimal_point_impl(locale_ref) {
-    return '.';
-}
-#endif
-
-FMT_FUNC auto write_loc(appender out, loc_value value,
-        const format_specs<> &specs, locale_ref loc) -> bool {
-#ifndef FMT_STATIC_THOUSANDS_SEPARATOR
-    auto locale = loc.get<std::locale>();
-    // We cannot use the num_put facet because it may produce output in
-    // a wrong encoding.
-    using facet = format_facet<std::locale>;
-    if (std::has_facet<facet>(locale))
-        return std::use_facet<facet>(locale).put(out, value, specs);
-    return facet(locale).put(out, value, specs);
-#endif
-    return false;
-}
-} // namespace detail
-
-template <typename Locale>
-typename Locale::id format_facet<Locale>::id;
-
-#ifndef FMT_STATIC_THOUSANDS_SEPARATOR
-template <typename Locale>
-format_facet<Locale>::format_facet(Locale &loc) {
-    auto &numpunct = std::use_facet<std::numpunct<char>>(loc);
-    grouping_ = numpunct.grouping();
-    if (!grouping_.empty())
-        separator_ = std::string(1, numpunct.thousands_sep());
-}
-
-template <>
-FMT_API FMT_FUNC auto format_facet<std::locale>::do_put(appender out,
-        loc_value val, const format_specs<> &specs) const -> bool {
-    return val.visit(detail::loc_writer<> {
-            out, specs, separator_, grouping_, decimal_point_});
-}
-#endif
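// The numpunct facet consulted by thousands_sep_impl() above can be queried
// directly; a minimal sketch (the named locale is an assumption and may not
// be installed on a given system, hence the fallback):
#include <iostream>
#include <locale>

int main() {
    std::locale loc = std::locale::classic();
    try {
        loc = std::locale("en_US.UTF-8"); // hypothetical locale name
    } catch (const std::runtime_error &) { /* keep the classic locale */
    }
    const auto &np = std::use_facet<std::numpunct<char>>(loc);
    std::cout << "thousands_sep: '" << np.thousands_sep()
              << "' grouping bytes: " << np.grouping().size() << '\n';
}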
-
-FMT_FUNC auto vsystem_error(int error_code, string_view fmt, format_args args)
-        -> std::system_error {
-    auto ec = std::error_code(error_code, std::generic_category());
-    return std::system_error(ec, vformat(fmt, args));
-}
-
-namespace detail {
-
-template <typename F>
-inline auto operator==(basic_fp<F> x, basic_fp<F> y) -> bool {
-    return x.f == y.f && x.e == y.e;
-}
-
-// Compilers should be able to optimize this into the ror instruction.
-FMT_CONSTEXPR inline auto rotr(uint32_t n, uint32_t r) noexcept -> uint32_t {
-    r &= 31;
-    return (n >> r) | (n << (32 - r));
-}
-FMT_CONSTEXPR inline auto rotr(uint64_t n, uint32_t r) noexcept -> uint64_t {
-    r &= 63;
-    return (n >> r) | (n << (64 - r));
-}
-
-// Implementation of Dragonbox algorithm: https://github.com/jk-jeon/dragonbox.
-namespace dragonbox {
-// Computes upper 64 bits of multiplication of a 32-bit unsigned integer and a
-// 64-bit unsigned integer.
-inline auto umul96_upper64(uint32_t x, uint64_t y) noexcept -> uint64_t {
-    return umul128_upper64(static_cast<uint64_t>(x) << 32, y);
-}
-
-// Computes lower 128 bits of multiplication of a 64-bit unsigned integer and a
-// 128-bit unsigned integer.
-inline auto umul192_lower128(uint64_t x, uint128_fallback y) noexcept
-        -> uint128_fallback {
-    uint64_t high = x * y.high();
-    uint128_fallback high_low = umul128(x, y.low());
-    return {high + high_low.high(), high_low.low()};
-}
-
-// Computes lower 64 bits of multiplication of a 32-bit unsigned integer and a
-// 64-bit unsigned integer.
-inline auto umul96_lower64(uint32_t x, uint64_t y) noexcept -> uint64_t {
-    return x * y;
-}
-
-// Various fast log computations.
-inline auto floor_log10_pow2_minus_log10_4_over_3(int e) noexcept -> int {
-    FMT_ASSERT(e <= 2936 && e >= -2985, "too large exponent");
-    return (e * 631305 - 261663) >> 21;
-}
-
-FMT_INLINE_VARIABLE constexpr struct {
-    uint32_t divisor;
-    int shift_amount;
-} div_small_pow10_infos[] = {{10, 16}, {100, 16}};
-
-// Replaces n by floor(n / pow(10, N)) returning true if and only if n is
-// divisible by pow(10, N).
-// Precondition: n <= pow(10, N + 1).
-template <int N>
-auto check_divisibility_and_divide_by_pow10(uint32_t &n) noexcept -> bool {
-    // The numbers below are chosen such that:
-    // 1. floor(n/d) = floor(nm / 2^k) where d=10 or d=100,
-    // 2. nm mod 2^k < m if and only if n is divisible by d,
-    // where m is magic_number, k is shift_amount
-    // and d is divisor.
-    //
-    // Item 1 is a common technique of replacing division by a constant with
-    // multiplication, see e.g. "Division by Invariant Integers Using
-    // Multiplication" by Granlund and Montgomery (1994). magic_number (m) is set
-    // to ceil(2^k/d) for large enough k.
-    // The idea for item 2 originates from Schubfach.
-    constexpr auto info = div_small_pow10_infos[N - 1];
-    FMT_ASSERT(n <= info.divisor * 10, "n is too large");
-    constexpr uint32_t magic_number
-            = (1u << info.shift_amount) / info.divisor + 1;
-    n *= magic_number;
-    const uint32_t comparison_mask = (1u << info.shift_amount) - 1;
-    bool result = (n & comparison_mask) < magic_number;
-    n >>= info.shift_amount;
-    return result;
-}
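// Standalone check of the two properties claimed above for d = 10, k = 16:
// with m = ceil(2^16 / 10) = 6554 and n <= 100, one multiply yields both the
// quotient and the divisibility test. A minimal sketch:
#include <cassert>
#include <cstdint>

int main() {
    constexpr uint32_t m = (1u << 16) / 10 + 1; // 6554 = ceil(2^16 / 10)
    for (uint32_t n = 0; n <= 100; ++n) {
        uint32_t prod = n * m;
        assert((prod >> 16) == n / 10);                  // item 1: division
        assert(((prod & 0xFFFFu) < m) == (n % 10 == 0)); // item 2: divisibility
    }
}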
-
-// Computes floor(n / pow(10, N)) for small n and N.
-// Precondition: n <= pow(10, N + 1).
-template <int N>
-auto small_division_by_pow10(uint32_t n) noexcept -> uint32_t {
-    constexpr auto info = div_small_pow10_infos[N - 1];
-    FMT_ASSERT(n <= info.divisor * 10, "n is too large");
-    constexpr uint32_t magic_number
-            = (1u << info.shift_amount) / info.divisor + 1;
-    return (n * magic_number) >> info.shift_amount;
-}
-
-// Computes floor(n / 10^(kappa + 1)) (float)
-inline auto divide_by_10_to_kappa_plus_1(uint32_t n) noexcept -> uint32_t {
-    // 1374389535 = ceil(2^37/100)
-    return static_cast<uint32_t>((static_cast<uint64_t>(n) * 1374389535) >> 37);
-}
-// Computes floor(n / 10^(kappa + 1)) (double)
-inline auto divide_by_10_to_kappa_plus_1(uint64_t n) noexcept -> uint64_t {
-    // 2361183241434822607 = ceil(2^(64+7)/1000)
-    return umul128_upper64(n, 2361183241434822607ull) >> 7;
-}
-
-// Various subroutines using pow10 cache
-template <typename T>
-struct cache_accessor;
-
-template <>
-struct cache_accessor<float> {
-    using carrier_uint = float_info<float>::carrier_uint;
-    using cache_entry_type = uint64_t;
-
-    static auto get_cached_power(int k) noexcept -> uint64_t {
-        FMT_ASSERT(
-                k >= float_info<float>::min_k && k <= float_info<float>::max_k,
-                "k is out of range");
-        static constexpr const uint64_t pow10_significands[] = {
-                0x81ceb32c4b43fcf5, 0xa2425ff75e14fc32, 0xcad2f7f5359a3b3f,
-                0xfd87b5f28300ca0e, 0x9e74d1b791e07e49, 0xc612062576589ddb,
-                0xf79687aed3eec552, 0x9abe14cd44753b53, 0xc16d9a0095928a28,
-                0xf1c90080baf72cb2, 0x971da05074da7bef, 0xbce5086492111aeb,
-                0xec1e4a7db69561a6, 0x9392ee8e921d5d08, 0xb877aa3236a4b44a,
-                0xe69594bec44de15c, 0x901d7cf73ab0acda, 0xb424dc35095cd810,
-                0xe12e13424bb40e14, 0x8cbccc096f5088cc, 0xafebff0bcb24aaff,
-                0xdbe6fecebdedd5bf, 0x89705f4136b4a598, 0xabcc77118461cefd,
-                0xd6bf94d5e57a42bd, 0x8637bd05af6c69b6, 0xa7c5ac471b478424,
-                0xd1b71758e219652c, 0x83126e978d4fdf3c, 0xa3d70a3d70a3d70b,
-                0xcccccccccccccccd, 0x8000000000000000, 0xa000000000000000,
-                0xc800000000000000, 0xfa00000000000000, 0x9c40000000000000,
-                0xc350000000000000, 0xf424000000000000, 0x9896800000000000,
-                0xbebc200000000000, 0xee6b280000000000, 0x9502f90000000000,
-                0xba43b74000000000, 0xe8d4a51000000000, 0x9184e72a00000000,
-                0xb5e620f480000000, 0xe35fa931a0000000, 0x8e1bc9bf04000000,
-                0xb1a2bc2ec5000000, 0xde0b6b3a76400000, 0x8ac7230489e80000,
-                0xad78ebc5ac620000, 0xd8d726b7177a8000, 0x878678326eac9000,
-                0xa968163f0a57b400, 0xd3c21bcecceda100, 0x84595161401484a0,
-                0xa56fa5b99019a5c8, 0xcecb8f27f4200f3a, 0x813f3978f8940985,
-                0xa18f07d736b90be6, 0xc9f2c9cd04674edf, 0xfc6f7c4045812297,
-                0x9dc5ada82b70b59e, 0xc5371912364ce306, 0xf684df56c3e01bc7,
-                0x9a130b963a6c115d, 0xc097ce7bc90715b4, 0xf0bdc21abb48db21,
-                0x96769950b50d88f5, 0xbc143fa4e250eb32, 0xeb194f8e1ae525fe,
-                0x92efd1b8d0cf37bf, 0xb7abc627050305ae, 0xe596b7b0c643c71a,
-                0x8f7e32ce7bea5c70, 0xb35dbf821ae4f38c, 0xe0352f62a19e306f};
-        return pow10_significands[k - float_info<float>::min_k];
-    }
-
-    struct compute_mul_result {
-        carrier_uint result;
-        bool is_integer;
-    };
-    struct compute_mul_parity_result {
-        bool parity;
-        bool is_integer;
-    };
-
-    static auto compute_mul(carrier_uint u,
-            const cache_entry_type &cache) noexcept -> compute_mul_result {
-        auto r = umul96_upper64(u, cache);
-        return {static_cast<carrier_uint>(r >> 32),
-                static_cast<carrier_uint>(r) == 0};
-    }
-
-    static auto compute_delta(const cache_entry_type &cache, int beta) noexcept
-            -> uint32_t {
-        return static_cast<uint32_t>(cache >> (64 - 1 - beta));
-    }
-
-    static auto compute_mul_parity(carrier_uint two_f,
-            const cache_entry_type &cache, int beta) noexcept
-            -> compute_mul_parity_result {
-        FMT_ASSERT(beta >= 1, "");
-        FMT_ASSERT(beta < 64, "");
-
-        auto r = umul96_lower64(two_f, cache);
-        return {((r >> (64 - beta)) & 1) != 0,
-                static_cast<uint32_t>(r >> (32 - beta)) == 0};
-    }
-
-    static auto compute_left_endpoint_for_shorter_interval_case(
-            const cache_entry_type &cache, int beta) noexcept -> carrier_uint {
-        return static_cast<carrier_uint>(
-                (cache - (cache >> (num_significand_bits<float>() + 2)))
-                >> (64 - num_significand_bits<float>() - 1 - beta));
-    }
-
-    static auto compute_right_endpoint_for_shorter_interval_case(
-            const cache_entry_type &cache, int beta) noexcept -> carrier_uint {
-        return static_cast<carrier_uint>(
-                (cache + (cache >> (num_significand_bits<float>() + 1)))
-                >> (64 - num_significand_bits<float>() - 1 - beta));
-    }
-
-    static auto compute_round_up_for_shorter_interval_case(
-            const cache_entry_type &cache, int beta) noexcept -> carrier_uint {
-        return (static_cast<carrier_uint>(cache
-                        >> (64 - num_significand_bits<float>() - 2 - beta))
-                       + 1)
-                / 2;
-    }
-};
-
-template <>
-struct cache_accessor<double> {
-    using carrier_uint = float_info<double>::carrier_uint;
-    using cache_entry_type = uint128_fallback;
-
-    static auto get_cached_power(int k) noexcept -> uint128_fallback {
-        FMT_ASSERT(k >= float_info<double>::min_k
-                        && k <= float_info<double>::max_k,
-                "k is out of range");
-
-        static constexpr const uint128_fallback pow10_significands[] = {
-#if FMT_USE_FULL_CACHE_DRAGONBOX
-                {0xff77b1fcbebcdc4f, 0x25e8e89c13bb0f7b},
-                {0x9faacf3df73609b1, 0x77b191618c54e9ad},
-                {0xc795830d75038c1d, 0xd59df5b9ef6a2418},
-                {0xf97ae3d0d2446f25, 0x4b0573286b44ad1e},
-                {0x9becce62836ac577, 0x4ee367f9430aec33},
-                {0xc2e801fb244576d5, 0x229c41f793cda740},
-                {0xf3a20279ed56d48a, 0x6b43527578c11110},
-                {0x9845418c345644d6, 0x830a13896b78aaaa},
-                {0xbe5691ef416bd60c, 0x23cc986bc656d554},
-                {0xedec366b11c6cb8f, 0x2cbfbe86b7ec8aa9},
-                {0x94b3a202eb1c3f39, 0x7bf7d71432f3d6aa},
-                {0xb9e08a83a5e34f07, 0xdaf5ccd93fb0cc54},
-                {0xe858ad248f5c22c9, 0xd1b3400f8f9cff69},
-                {0x91376c36d99995be, 0x23100809b9c21fa2},
-                {0xb58547448ffffb2d, 0xabd40a0c2832a78b},
-                {0xe2e69915b3fff9f9, 0x16c90c8f323f516d},
-                {0x8dd01fad907ffc3b, 0xae3da7d97f6792e4},
-                {0xb1442798f49ffb4a, 0x99cd11cfdf41779d},
-                {0xdd95317f31c7fa1d, 0x40405643d711d584},
-                {0x8a7d3eef7f1cfc52, 0x482835ea666b2573},
-                {0xad1c8eab5ee43b66, 0xda3243650005eed0},
-                {0xd863b256369d4a40, 0x90bed43e40076a83},
-                {0x873e4f75e2224e68, 0x5a7744a6e804a292},
-                {0xa90de3535aaae202, 0x711515d0a205cb37},
-                {0xd3515c2831559a83, 0x0d5a5b44ca873e04},
-                {0x8412d9991ed58091, 0xe858790afe9486c3},
-                {0xa5178fff668ae0b6, 0x626e974dbe39a873},
-                {0xce5d73ff402d98e3, 0xfb0a3d212dc81290},
-                {0x80fa687f881c7f8e, 0x7ce66634bc9d0b9a},
-                {0xa139029f6a239f72, 0x1c1fffc1ebc44e81},
-                {0xc987434744ac874e, 0xa327ffb266b56221},
-                {0xfbe9141915d7a922, 0x4bf1ff9f0062baa9},
-                {0x9d71ac8fada6c9b5, 0x6f773fc3603db4aa},
-                {0xc4ce17b399107c22, 0xcb550fb4384d21d4},
-                {0xf6019da07f549b2b, 0x7e2a53a146606a49},
-                {0x99c102844f94e0fb, 0x2eda7444cbfc426e},
-                {0xc0314325637a1939, 0xfa911155fefb5309},
-                {0xf03d93eebc589f88, 0x793555ab7eba27cb},
-                {0x96267c7535b763b5, 0x4bc1558b2f3458df},
-                {0xbbb01b9283253ca2, 0x9eb1aaedfb016f17},
{0xea9c227723ee8bcb, 0x465e15a979c1cadd}, - {0x92a1958a7675175f, 0x0bfacd89ec191eca}, - {0xb749faed14125d36, 0xcef980ec671f667c}, - {0xe51c79a85916f484, 0x82b7e12780e7401b}, - {0x8f31cc0937ae58d2, 0xd1b2ecb8b0908811}, - {0xb2fe3f0b8599ef07, 0x861fa7e6dcb4aa16}, - {0xdfbdcece67006ac9, 0x67a791e093e1d49b}, - {0x8bd6a141006042bd, 0xe0c8bb2c5c6d24e1}, - {0xaecc49914078536d, 0x58fae9f773886e19}, - {0xda7f5bf590966848, 0xaf39a475506a899f}, - {0x888f99797a5e012d, 0x6d8406c952429604}, - {0xaab37fd7d8f58178, 0xc8e5087ba6d33b84}, - {0xd5605fcdcf32e1d6, 0xfb1e4a9a90880a65}, - {0x855c3be0a17fcd26, 0x5cf2eea09a550680}, - {0xa6b34ad8c9dfc06f, 0xf42faa48c0ea481f}, - {0xd0601d8efc57b08b, 0xf13b94daf124da27}, - {0x823c12795db6ce57, 0x76c53d08d6b70859}, - {0xa2cb1717b52481ed, 0x54768c4b0c64ca6f}, - {0xcb7ddcdda26da268, 0xa9942f5dcf7dfd0a}, - {0xfe5d54150b090b02, 0xd3f93b35435d7c4d}, - {0x9efa548d26e5a6e1, 0xc47bc5014a1a6db0}, - {0xc6b8e9b0709f109a, 0x359ab6419ca1091c}, - {0xf867241c8cc6d4c0, 0xc30163d203c94b63}, - {0x9b407691d7fc44f8, 0x79e0de63425dcf1e}, - {0xc21094364dfb5636, 0x985915fc12f542e5}, - {0xf294b943e17a2bc4, 0x3e6f5b7b17b2939e}, - {0x979cf3ca6cec5b5a, 0xa705992ceecf9c43}, - {0xbd8430bd08277231, 0x50c6ff782a838354}, - {0xece53cec4a314ebd, 0xa4f8bf5635246429}, - {0x940f4613ae5ed136, 0x871b7795e136be9a}, - {0xb913179899f68584, 0x28e2557b59846e40}, - {0xe757dd7ec07426e5, 0x331aeada2fe589d0}, - {0x9096ea6f3848984f, 0x3ff0d2c85def7622}, - {0xb4bca50b065abe63, 0x0fed077a756b53aa}, - {0xe1ebce4dc7f16dfb, 0xd3e8495912c62895}, - {0x8d3360f09cf6e4bd, 0x64712dd7abbbd95d}, - {0xb080392cc4349dec, 0xbd8d794d96aacfb4}, - {0xdca04777f541c567, 0xecf0d7a0fc5583a1}, - {0x89e42caaf9491b60, 0xf41686c49db57245}, - {0xac5d37d5b79b6239, 0x311c2875c522ced6}, - {0xd77485cb25823ac7, 0x7d633293366b828c}, - {0x86a8d39ef77164bc, 0xae5dff9c02033198}, - {0xa8530886b54dbdeb, 0xd9f57f830283fdfd}, - {0xd267caa862a12d66, 0xd072df63c324fd7c}, - {0x8380dea93da4bc60, 0x4247cb9e59f71e6e}, - {0xa46116538d0deb78, 0x52d9be85f074e609}, - {0xcd795be870516656, 0x67902e276c921f8c}, - {0x806bd9714632dff6, 0x00ba1cd8a3db53b7}, - {0xa086cfcd97bf97f3, 0x80e8a40eccd228a5}, - {0xc8a883c0fdaf7df0, 0x6122cd128006b2ce}, - {0xfad2a4b13d1b5d6c, 0x796b805720085f82}, - {0x9cc3a6eec6311a63, 0xcbe3303674053bb1}, - {0xc3f490aa77bd60fc, 0xbedbfc4411068a9d}, - {0xf4f1b4d515acb93b, 0xee92fb5515482d45}, - {0x991711052d8bf3c5, 0x751bdd152d4d1c4b}, - {0xbf5cd54678eef0b6, 0xd262d45a78a0635e}, - {0xef340a98172aace4, 0x86fb897116c87c35}, - {0x9580869f0e7aac0e, 0xd45d35e6ae3d4da1}, - {0xbae0a846d2195712, 0x8974836059cca10a}, - {0xe998d258869facd7, 0x2bd1a438703fc94c}, - {0x91ff83775423cc06, 0x7b6306a34627ddd0}, - {0xb67f6455292cbf08, 0x1a3bc84c17b1d543}, - {0xe41f3d6a7377eeca, 0x20caba5f1d9e4a94}, - {0x8e938662882af53e, 0x547eb47b7282ee9d}, - {0xb23867fb2a35b28d, 0xe99e619a4f23aa44}, - {0xdec681f9f4c31f31, 0x6405fa00e2ec94d5}, - {0x8b3c113c38f9f37e, 0xde83bc408dd3dd05}, - {0xae0b158b4738705e, 0x9624ab50b148d446}, - {0xd98ddaee19068c76, 0x3badd624dd9b0958}, - {0x87f8a8d4cfa417c9, 0xe54ca5d70a80e5d7}, - {0xa9f6d30a038d1dbc, 0x5e9fcf4ccd211f4d}, - {0xd47487cc8470652b, 0x7647c32000696720}, - {0x84c8d4dfd2c63f3b, 0x29ecd9f40041e074}, - {0xa5fb0a17c777cf09, 0xf468107100525891}, - {0xcf79cc9db955c2cc, 0x7182148d4066eeb5}, - {0x81ac1fe293d599bf, 0xc6f14cd848405531}, - {0xa21727db38cb002f, 0xb8ada00e5a506a7d}, - {0xca9cf1d206fdc03b, 0xa6d90811f0e4851d}, - {0xfd442e4688bd304a, 0x908f4a166d1da664}, - {0x9e4a9cec15763e2e, 0x9a598e4e043287ff}, - {0xc5dd44271ad3cdba, 
0x40eff1e1853f29fe}, - {0xf7549530e188c128, 0xd12bee59e68ef47d}, - {0x9a94dd3e8cf578b9, 0x82bb74f8301958cf}, - {0xc13a148e3032d6e7, 0xe36a52363c1faf02}, - {0xf18899b1bc3f8ca1, 0xdc44e6c3cb279ac2}, - {0x96f5600f15a7b7e5, 0x29ab103a5ef8c0ba}, - {0xbcb2b812db11a5de, 0x7415d448f6b6f0e8}, - {0xebdf661791d60f56, 0x111b495b3464ad22}, - {0x936b9fcebb25c995, 0xcab10dd900beec35}, - {0xb84687c269ef3bfb, 0x3d5d514f40eea743}, - {0xe65829b3046b0afa, 0x0cb4a5a3112a5113}, - {0x8ff71a0fe2c2e6dc, 0x47f0e785eaba72ac}, - {0xb3f4e093db73a093, 0x59ed216765690f57}, - {0xe0f218b8d25088b8, 0x306869c13ec3532d}, - {0x8c974f7383725573, 0x1e414218c73a13fc}, - {0xafbd2350644eeacf, 0xe5d1929ef90898fb}, - {0xdbac6c247d62a583, 0xdf45f746b74abf3a}, - {0x894bc396ce5da772, 0x6b8bba8c328eb784}, - {0xab9eb47c81f5114f, 0x066ea92f3f326565}, - {0xd686619ba27255a2, 0xc80a537b0efefebe}, - {0x8613fd0145877585, 0xbd06742ce95f5f37}, - {0xa798fc4196e952e7, 0x2c48113823b73705}, - {0xd17f3b51fca3a7a0, 0xf75a15862ca504c6}, - {0x82ef85133de648c4, 0x9a984d73dbe722fc}, - {0xa3ab66580d5fdaf5, 0xc13e60d0d2e0ebbb}, - {0xcc963fee10b7d1b3, 0x318df905079926a9}, - {0xffbbcfe994e5c61f, 0xfdf17746497f7053}, - {0x9fd561f1fd0f9bd3, 0xfeb6ea8bedefa634}, - {0xc7caba6e7c5382c8, 0xfe64a52ee96b8fc1}, - {0xf9bd690a1b68637b, 0x3dfdce7aa3c673b1}, - {0x9c1661a651213e2d, 0x06bea10ca65c084f}, - {0xc31bfa0fe5698db8, 0x486e494fcff30a63}, - {0xf3e2f893dec3f126, 0x5a89dba3c3efccfb}, - {0x986ddb5c6b3a76b7, 0xf89629465a75e01d}, - {0xbe89523386091465, 0xf6bbb397f1135824}, - {0xee2ba6c0678b597f, 0x746aa07ded582e2d}, - {0x94db483840b717ef, 0xa8c2a44eb4571cdd}, - {0xba121a4650e4ddeb, 0x92f34d62616ce414}, - {0xe896a0d7e51e1566, 0x77b020baf9c81d18}, - {0x915e2486ef32cd60, 0x0ace1474dc1d122f}, - {0xb5b5ada8aaff80b8, 0x0d819992132456bb}, - {0xe3231912d5bf60e6, 0x10e1fff697ed6c6a}, - {0x8df5efabc5979c8f, 0xca8d3ffa1ef463c2}, - {0xb1736b96b6fd83b3, 0xbd308ff8a6b17cb3}, - {0xddd0467c64bce4a0, 0xac7cb3f6d05ddbdf}, - {0x8aa22c0dbef60ee4, 0x6bcdf07a423aa96c}, - {0xad4ab7112eb3929d, 0x86c16c98d2c953c7}, - {0xd89d64d57a607744, 0xe871c7bf077ba8b8}, - {0x87625f056c7c4a8b, 0x11471cd764ad4973}, - {0xa93af6c6c79b5d2d, 0xd598e40d3dd89bd0}, - {0xd389b47879823479, 0x4aff1d108d4ec2c4}, - {0x843610cb4bf160cb, 0xcedf722a585139bb}, - {0xa54394fe1eedb8fe, 0xc2974eb4ee658829}, - {0xce947a3da6a9273e, 0x733d226229feea33}, - {0x811ccc668829b887, 0x0806357d5a3f5260}, - {0xa163ff802a3426a8, 0xca07c2dcb0cf26f8}, - {0xc9bcff6034c13052, 0xfc89b393dd02f0b6}, - {0xfc2c3f3841f17c67, 0xbbac2078d443ace3}, - {0x9d9ba7832936edc0, 0xd54b944b84aa4c0e}, - {0xc5029163f384a931, 0x0a9e795e65d4df12}, - {0xf64335bcf065d37d, 0x4d4617b5ff4a16d6}, - {0x99ea0196163fa42e, 0x504bced1bf8e4e46}, - {0xc06481fb9bcf8d39, 0xe45ec2862f71e1d7}, - {0xf07da27a82c37088, 0x5d767327bb4e5a4d}, - {0x964e858c91ba2655, 0x3a6a07f8d510f870}, - {0xbbe226efb628afea, 0x890489f70a55368c}, - {0xeadab0aba3b2dbe5, 0x2b45ac74ccea842f}, - {0x92c8ae6b464fc96f, 0x3b0b8bc90012929e}, - {0xb77ada0617e3bbcb, 0x09ce6ebb40173745}, - {0xe55990879ddcaabd, 0xcc420a6a101d0516}, - {0x8f57fa54c2a9eab6, 0x9fa946824a12232e}, - {0xb32df8e9f3546564, 0x47939822dc96abfa}, - {0xdff9772470297ebd, 0x59787e2b93bc56f8}, - {0x8bfbea76c619ef36, 0x57eb4edb3c55b65b}, - {0xaefae51477a06b03, 0xede622920b6b23f2}, - {0xdab99e59958885c4, 0xe95fab368e45ecee}, - {0x88b402f7fd75539b, 0x11dbcb0218ebb415}, - {0xaae103b5fcd2a881, 0xd652bdc29f26a11a}, - {0xd59944a37c0752a2, 0x4be76d3346f04960}, - {0x857fcae62d8493a5, 0x6f70a4400c562ddc}, - {0xa6dfbd9fb8e5b88e, 0xcb4ccd500f6bb953}, - 
{0xd097ad07a71f26b2, 0x7e2000a41346a7a8}, - {0x825ecc24c873782f, 0x8ed400668c0c28c9}, - {0xa2f67f2dfa90563b, 0x728900802f0f32fb}, - {0xcbb41ef979346bca, 0x4f2b40a03ad2ffba}, - {0xfea126b7d78186bc, 0xe2f610c84987bfa9}, - {0x9f24b832e6b0f436, 0x0dd9ca7d2df4d7ca}, - {0xc6ede63fa05d3143, 0x91503d1c79720dbc}, - {0xf8a95fcf88747d94, 0x75a44c6397ce912b}, - {0x9b69dbe1b548ce7c, 0xc986afbe3ee11abb}, - {0xc24452da229b021b, 0xfbe85badce996169}, - {0xf2d56790ab41c2a2, 0xfae27299423fb9c4}, - {0x97c560ba6b0919a5, 0xdccd879fc967d41b}, - {0xbdb6b8e905cb600f, 0x5400e987bbc1c921}, - {0xed246723473e3813, 0x290123e9aab23b69}, - {0x9436c0760c86e30b, 0xf9a0b6720aaf6522}, - {0xb94470938fa89bce, 0xf808e40e8d5b3e6a}, - {0xe7958cb87392c2c2, 0xb60b1d1230b20e05}, - {0x90bd77f3483bb9b9, 0xb1c6f22b5e6f48c3}, - {0xb4ecd5f01a4aa828, 0x1e38aeb6360b1af4}, - {0xe2280b6c20dd5232, 0x25c6da63c38de1b1}, - {0x8d590723948a535f, 0x579c487e5a38ad0f}, - {0xb0af48ec79ace837, 0x2d835a9df0c6d852}, - {0xdcdb1b2798182244, 0xf8e431456cf88e66}, - {0x8a08f0f8bf0f156b, 0x1b8e9ecb641b5900}, - {0xac8b2d36eed2dac5, 0xe272467e3d222f40}, - {0xd7adf884aa879177, 0x5b0ed81dcc6abb10}, - {0x86ccbb52ea94baea, 0x98e947129fc2b4ea}, - {0xa87fea27a539e9a5, 0x3f2398d747b36225}, - {0xd29fe4b18e88640e, 0x8eec7f0d19a03aae}, - {0x83a3eeeef9153e89, 0x1953cf68300424ad}, - {0xa48ceaaab75a8e2b, 0x5fa8c3423c052dd8}, - {0xcdb02555653131b6, 0x3792f412cb06794e}, - {0x808e17555f3ebf11, 0xe2bbd88bbee40bd1}, - {0xa0b19d2ab70e6ed6, 0x5b6aceaeae9d0ec5}, - {0xc8de047564d20a8b, 0xf245825a5a445276}, - {0xfb158592be068d2e, 0xeed6e2f0f0d56713}, - {0x9ced737bb6c4183d, 0x55464dd69685606c}, - {0xc428d05aa4751e4c, 0xaa97e14c3c26b887}, - {0xf53304714d9265df, 0xd53dd99f4b3066a9}, - {0x993fe2c6d07b7fab, 0xe546a8038efe402a}, - {0xbf8fdb78849a5f96, 0xde98520472bdd034}, - {0xef73d256a5c0f77c, 0x963e66858f6d4441}, - {0x95a8637627989aad, 0xdde7001379a44aa9}, - {0xbb127c53b17ec159, 0x5560c018580d5d53}, - {0xe9d71b689dde71af, 0xaab8f01e6e10b4a7}, - {0x9226712162ab070d, 0xcab3961304ca70e9}, - {0xb6b00d69bb55c8d1, 0x3d607b97c5fd0d23}, - {0xe45c10c42a2b3b05, 0x8cb89a7db77c506b}, - {0x8eb98a7a9a5b04e3, 0x77f3608e92adb243}, - {0xb267ed1940f1c61c, 0x55f038b237591ed4}, - {0xdf01e85f912e37a3, 0x6b6c46dec52f6689}, - {0x8b61313bbabce2c6, 0x2323ac4b3b3da016}, - {0xae397d8aa96c1b77, 0xabec975e0a0d081b}, - {0xd9c7dced53c72255, 0x96e7bd358c904a22}, - {0x881cea14545c7575, 0x7e50d64177da2e55}, - {0xaa242499697392d2, 0xdde50bd1d5d0b9ea}, - {0xd4ad2dbfc3d07787, 0x955e4ec64b44e865}, - {0x84ec3c97da624ab4, 0xbd5af13bef0b113f}, - {0xa6274bbdd0fadd61, 0xecb1ad8aeacdd58f}, - {0xcfb11ead453994ba, 0x67de18eda5814af3}, - {0x81ceb32c4b43fcf4, 0x80eacf948770ced8}, - {0xa2425ff75e14fc31, 0xa1258379a94d028e}, - {0xcad2f7f5359a3b3e, 0x096ee45813a04331}, - {0xfd87b5f28300ca0d, 0x8bca9d6e188853fd}, - {0x9e74d1b791e07e48, 0x775ea264cf55347e}, - {0xc612062576589dda, 0x95364afe032a819e}, - {0xf79687aed3eec551, 0x3a83ddbd83f52205}, - {0x9abe14cd44753b52, 0xc4926a9672793543}, - {0xc16d9a0095928a27, 0x75b7053c0f178294}, - {0xf1c90080baf72cb1, 0x5324c68b12dd6339}, - {0x971da05074da7bee, 0xd3f6fc16ebca5e04}, - {0xbce5086492111aea, 0x88f4bb1ca6bcf585}, - {0xec1e4a7db69561a5, 0x2b31e9e3d06c32e6}, - {0x9392ee8e921d5d07, 0x3aff322e62439fd0}, - {0xb877aa3236a4b449, 0x09befeb9fad487c3}, - {0xe69594bec44de15b, 0x4c2ebe687989a9b4}, - {0x901d7cf73ab0acd9, 0x0f9d37014bf60a11}, - {0xb424dc35095cd80f, 0x538484c19ef38c95}, - {0xe12e13424bb40e13, 0x2865a5f206b06fba}, - {0x8cbccc096f5088cb, 0xf93f87b7442e45d4}, - {0xafebff0bcb24aafe, 
0xf78f69a51539d749}, - {0xdbe6fecebdedd5be, 0xb573440e5a884d1c}, - {0x89705f4136b4a597, 0x31680a88f8953031}, - {0xabcc77118461cefc, 0xfdc20d2b36ba7c3e}, - {0xd6bf94d5e57a42bc, 0x3d32907604691b4d}, - {0x8637bd05af6c69b5, 0xa63f9a49c2c1b110}, - {0xa7c5ac471b478423, 0x0fcf80dc33721d54}, - {0xd1b71758e219652b, 0xd3c36113404ea4a9}, - {0x83126e978d4fdf3b, 0x645a1cac083126ea}, - {0xa3d70a3d70a3d70a, 0x3d70a3d70a3d70a4}, - {0xcccccccccccccccc, 0xcccccccccccccccd}, - {0x8000000000000000, 0x0000000000000000}, - {0xa000000000000000, 0x0000000000000000}, - {0xc800000000000000, 0x0000000000000000}, - {0xfa00000000000000, 0x0000000000000000}, - {0x9c40000000000000, 0x0000000000000000}, - {0xc350000000000000, 0x0000000000000000}, - {0xf424000000000000, 0x0000000000000000}, - {0x9896800000000000, 0x0000000000000000}, - {0xbebc200000000000, 0x0000000000000000}, - {0xee6b280000000000, 0x0000000000000000}, - {0x9502f90000000000, 0x0000000000000000}, - {0xba43b74000000000, 0x0000000000000000}, - {0xe8d4a51000000000, 0x0000000000000000}, - {0x9184e72a00000000, 0x0000000000000000}, - {0xb5e620f480000000, 0x0000000000000000}, - {0xe35fa931a0000000, 0x0000000000000000}, - {0x8e1bc9bf04000000, 0x0000000000000000}, - {0xb1a2bc2ec5000000, 0x0000000000000000}, - {0xde0b6b3a76400000, 0x0000000000000000}, - {0x8ac7230489e80000, 0x0000000000000000}, - {0xad78ebc5ac620000, 0x0000000000000000}, - {0xd8d726b7177a8000, 0x0000000000000000}, - {0x878678326eac9000, 0x0000000000000000}, - {0xa968163f0a57b400, 0x0000000000000000}, - {0xd3c21bcecceda100, 0x0000000000000000}, - {0x84595161401484a0, 0x0000000000000000}, - {0xa56fa5b99019a5c8, 0x0000000000000000}, - {0xcecb8f27f4200f3a, 0x0000000000000000}, - {0x813f3978f8940984, 0x4000000000000000}, - {0xa18f07d736b90be5, 0x5000000000000000}, - {0xc9f2c9cd04674ede, 0xa400000000000000}, - {0xfc6f7c4045812296, 0x4d00000000000000}, - {0x9dc5ada82b70b59d, 0xf020000000000000}, - {0xc5371912364ce305, 0x6c28000000000000}, - {0xf684df56c3e01bc6, 0xc732000000000000}, - {0x9a130b963a6c115c, 0x3c7f400000000000}, - {0xc097ce7bc90715b3, 0x4b9f100000000000}, - {0xf0bdc21abb48db20, 0x1e86d40000000000}, - {0x96769950b50d88f4, 0x1314448000000000}, - {0xbc143fa4e250eb31, 0x17d955a000000000}, - {0xeb194f8e1ae525fd, 0x5dcfab0800000000}, - {0x92efd1b8d0cf37be, 0x5aa1cae500000000}, - {0xb7abc627050305ad, 0xf14a3d9e40000000}, - {0xe596b7b0c643c719, 0x6d9ccd05d0000000}, - {0x8f7e32ce7bea5c6f, 0xe4820023a2000000}, - {0xb35dbf821ae4f38b, 0xdda2802c8a800000}, - {0xe0352f62a19e306e, 0xd50b2037ad200000}, - {0x8c213d9da502de45, 0x4526f422cc340000}, - {0xaf298d050e4395d6, 0x9670b12b7f410000}, - {0xdaf3f04651d47b4c, 0x3c0cdd765f114000}, - {0x88d8762bf324cd0f, 0xa5880a69fb6ac800}, - {0xab0e93b6efee0053, 0x8eea0d047a457a00}, - {0xd5d238a4abe98068, 0x72a4904598d6d880}, - {0x85a36366eb71f041, 0x47a6da2b7f864750}, - {0xa70c3c40a64e6c51, 0x999090b65f67d924}, - {0xd0cf4b50cfe20765, 0xfff4b4e3f741cf6d}, - {0x82818f1281ed449f, 0xbff8f10e7a8921a5}, - {0xa321f2d7226895c7, 0xaff72d52192b6a0e}, - {0xcbea6f8ceb02bb39, 0x9bf4f8a69f764491}, - {0xfee50b7025c36a08, 0x02f236d04753d5b5}, - {0x9f4f2726179a2245, 0x01d762422c946591}, - {0xc722f0ef9d80aad6, 0x424d3ad2b7b97ef6}, - {0xf8ebad2b84e0d58b, 0xd2e0898765a7deb3}, - {0x9b934c3b330c8577, 0x63cc55f49f88eb30}, - {0xc2781f49ffcfa6d5, 0x3cbf6b71c76b25fc}, - {0xf316271c7fc3908a, 0x8bef464e3945ef7b}, - {0x97edd871cfda3a56, 0x97758bf0e3cbb5ad}, - {0xbde94e8e43d0c8ec, 0x3d52eeed1cbea318}, - {0xed63a231d4c4fb27, 0x4ca7aaa863ee4bde}, - {0x945e455f24fb1cf8, 0x8fe8caa93e74ef6b}, - 
{0xb975d6b6ee39e436, 0xb3e2fd538e122b45}, - {0xe7d34c64a9c85d44, 0x60dbbca87196b617}, - {0x90e40fbeea1d3a4a, 0xbc8955e946fe31ce}, - {0xb51d13aea4a488dd, 0x6babab6398bdbe42}, - {0xe264589a4dcdab14, 0xc696963c7eed2dd2}, - {0x8d7eb76070a08aec, 0xfc1e1de5cf543ca3}, - {0xb0de65388cc8ada8, 0x3b25a55f43294bcc}, - {0xdd15fe86affad912, 0x49ef0eb713f39ebf}, - {0x8a2dbf142dfcc7ab, 0x6e3569326c784338}, - {0xacb92ed9397bf996, 0x49c2c37f07965405}, - {0xd7e77a8f87daf7fb, 0xdc33745ec97be907}, - {0x86f0ac99b4e8dafd, 0x69a028bb3ded71a4}, - {0xa8acd7c0222311bc, 0xc40832ea0d68ce0d}, - {0xd2d80db02aabd62b, 0xf50a3fa490c30191}, - {0x83c7088e1aab65db, 0x792667c6da79e0fb}, - {0xa4b8cab1a1563f52, 0x577001b891185939}, - {0xcde6fd5e09abcf26, 0xed4c0226b55e6f87}, - {0x80b05e5ac60b6178, 0x544f8158315b05b5}, - {0xa0dc75f1778e39d6, 0x696361ae3db1c722}, - {0xc913936dd571c84c, 0x03bc3a19cd1e38ea}, - {0xfb5878494ace3a5f, 0x04ab48a04065c724}, - {0x9d174b2dcec0e47b, 0x62eb0d64283f9c77}, - {0xc45d1df942711d9a, 0x3ba5d0bd324f8395}, - {0xf5746577930d6500, 0xca8f44ec7ee3647a}, - {0x9968bf6abbe85f20, 0x7e998b13cf4e1ecc}, - {0xbfc2ef456ae276e8, 0x9e3fedd8c321a67f}, - {0xefb3ab16c59b14a2, 0xc5cfe94ef3ea101f}, - {0x95d04aee3b80ece5, 0xbba1f1d158724a13}, - {0xbb445da9ca61281f, 0x2a8a6e45ae8edc98}, - {0xea1575143cf97226, 0xf52d09d71a3293be}, - {0x924d692ca61be758, 0x593c2626705f9c57}, - {0xb6e0c377cfa2e12e, 0x6f8b2fb00c77836d}, - {0xe498f455c38b997a, 0x0b6dfb9c0f956448}, - {0x8edf98b59a373fec, 0x4724bd4189bd5ead}, - {0xb2977ee300c50fe7, 0x58edec91ec2cb658}, - {0xdf3d5e9bc0f653e1, 0x2f2967b66737e3ee}, - {0x8b865b215899f46c, 0xbd79e0d20082ee75}, - {0xae67f1e9aec07187, 0xecd8590680a3aa12}, - {0xda01ee641a708de9, 0xe80e6f4820cc9496}, - {0x884134fe908658b2, 0x3109058d147fdcde}, - {0xaa51823e34a7eede, 0xbd4b46f0599fd416}, - {0xd4e5e2cdc1d1ea96, 0x6c9e18ac7007c91b}, - {0x850fadc09923329e, 0x03e2cf6bc604ddb1}, - {0xa6539930bf6bff45, 0x84db8346b786151d}, - {0xcfe87f7cef46ff16, 0xe612641865679a64}, - {0x81f14fae158c5f6e, 0x4fcb7e8f3f60c07f}, - {0xa26da3999aef7749, 0xe3be5e330f38f09e}, - {0xcb090c8001ab551c, 0x5cadf5bfd3072cc6}, - {0xfdcb4fa002162a63, 0x73d9732fc7c8f7f7}, - {0x9e9f11c4014dda7e, 0x2867e7fddcdd9afb}, - {0xc646d63501a1511d, 0xb281e1fd541501b9}, - {0xf7d88bc24209a565, 0x1f225a7ca91a4227}, - {0x9ae757596946075f, 0x3375788de9b06959}, - {0xc1a12d2fc3978937, 0x0052d6b1641c83af}, - {0xf209787bb47d6b84, 0xc0678c5dbd23a49b}, - {0x9745eb4d50ce6332, 0xf840b7ba963646e1}, - {0xbd176620a501fbff, 0xb650e5a93bc3d899}, - {0xec5d3fa8ce427aff, 0xa3e51f138ab4cebf}, - {0x93ba47c980e98cdf, 0xc66f336c36b10138}, - {0xb8a8d9bbe123f017, 0xb80b0047445d4185}, - {0xe6d3102ad96cec1d, 0xa60dc059157491e6}, - {0x9043ea1ac7e41392, 0x87c89837ad68db30}, - {0xb454e4a179dd1877, 0x29babe4598c311fc}, - {0xe16a1dc9d8545e94, 0xf4296dd6fef3d67b}, - {0x8ce2529e2734bb1d, 0x1899e4a65f58660d}, - {0xb01ae745b101e9e4, 0x5ec05dcff72e7f90}, - {0xdc21a1171d42645d, 0x76707543f4fa1f74}, - {0x899504ae72497eba, 0x6a06494a791c53a9}, - {0xabfa45da0edbde69, 0x0487db9d17636893}, - {0xd6f8d7509292d603, 0x45a9d2845d3c42b7}, - {0x865b86925b9bc5c2, 0x0b8a2392ba45a9b3}, - {0xa7f26836f282b732, 0x8e6cac7768d7141f}, - {0xd1ef0244af2364ff, 0x3207d795430cd927}, - {0x8335616aed761f1f, 0x7f44e6bd49e807b9}, - {0xa402b9c5a8d3a6e7, 0x5f16206c9c6209a7}, - {0xcd036837130890a1, 0x36dba887c37a8c10}, - {0x802221226be55a64, 0xc2494954da2c978a}, - {0xa02aa96b06deb0fd, 0xf2db9baa10b7bd6d}, - {0xc83553c5c8965d3d, 0x6f92829494e5acc8}, - {0xfa42a8b73abbf48c, 0xcb772339ba1f17fa}, - {0x9c69a97284b578d7, 
0xff2a760414536efc}, - {0xc38413cf25e2d70d, 0xfef5138519684abb}, - {0xf46518c2ef5b8cd1, 0x7eb258665fc25d6a}, - {0x98bf2f79d5993802, 0xef2f773ffbd97a62}, - {0xbeeefb584aff8603, 0xaafb550ffacfd8fb}, - {0xeeaaba2e5dbf6784, 0x95ba2a53f983cf39}, - {0x952ab45cfa97a0b2, 0xdd945a747bf26184}, - {0xba756174393d88df, 0x94f971119aeef9e5}, - {0xe912b9d1478ceb17, 0x7a37cd5601aab85e}, - {0x91abb422ccb812ee, 0xac62e055c10ab33b}, - {0xb616a12b7fe617aa, 0x577b986b314d600a}, - {0xe39c49765fdf9d94, 0xed5a7e85fda0b80c}, - {0x8e41ade9fbebc27d, 0x14588f13be847308}, - {0xb1d219647ae6b31c, 0x596eb2d8ae258fc9}, - {0xde469fbd99a05fe3, 0x6fca5f8ed9aef3bc}, - {0x8aec23d680043bee, 0x25de7bb9480d5855}, - {0xada72ccc20054ae9, 0xaf561aa79a10ae6b}, - {0xd910f7ff28069da4, 0x1b2ba1518094da05}, - {0x87aa9aff79042286, 0x90fb44d2f05d0843}, - {0xa99541bf57452b28, 0x353a1607ac744a54}, - {0xd3fa922f2d1675f2, 0x42889b8997915ce9}, - {0x847c9b5d7c2e09b7, 0x69956135febada12}, - {0xa59bc234db398c25, 0x43fab9837e699096}, - {0xcf02b2c21207ef2e, 0x94f967e45e03f4bc}, - {0x8161afb94b44f57d, 0x1d1be0eebac278f6}, - {0xa1ba1ba79e1632dc, 0x6462d92a69731733}, - {0xca28a291859bbf93, 0x7d7b8f7503cfdcff}, - {0xfcb2cb35e702af78, 0x5cda735244c3d43f}, - {0x9defbf01b061adab, 0x3a0888136afa64a8}, - {0xc56baec21c7a1916, 0x088aaa1845b8fdd1}, - {0xf6c69a72a3989f5b, 0x8aad549e57273d46}, - {0x9a3c2087a63f6399, 0x36ac54e2f678864c}, - {0xc0cb28a98fcf3c7f, 0x84576a1bb416a7de}, - {0xf0fdf2d3f3c30b9f, 0x656d44a2a11c51d6}, - {0x969eb7c47859e743, 0x9f644ae5a4b1b326}, - {0xbc4665b596706114, 0x873d5d9f0dde1fef}, - {0xeb57ff22fc0c7959, 0xa90cb506d155a7eb}, - {0x9316ff75dd87cbd8, 0x09a7f12442d588f3}, - {0xb7dcbf5354e9bece, 0x0c11ed6d538aeb30}, - {0xe5d3ef282a242e81, 0x8f1668c8a86da5fb}, - {0x8fa475791a569d10, 0xf96e017d694487bd}, - {0xb38d92d760ec4455, 0x37c981dcc395a9ad}, - {0xe070f78d3927556a, 0x85bbe253f47b1418}, - {0x8c469ab843b89562, 0x93956d7478ccec8f}, - {0xaf58416654a6babb, 0x387ac8d1970027b3}, - {0xdb2e51bfe9d0696a, 0x06997b05fcc0319f}, - {0x88fcf317f22241e2, 0x441fece3bdf81f04}, - {0xab3c2fddeeaad25a, 0xd527e81cad7626c4}, - {0xd60b3bd56a5586f1, 0x8a71e223d8d3b075}, - {0x85c7056562757456, 0xf6872d5667844e4a}, - {0xa738c6bebb12d16c, 0xb428f8ac016561dc}, - {0xd106f86e69d785c7, 0xe13336d701beba53}, - {0x82a45b450226b39c, 0xecc0024661173474}, - {0xa34d721642b06084, 0x27f002d7f95d0191}, - {0xcc20ce9bd35c78a5, 0x31ec038df7b441f5}, - {0xff290242c83396ce, 0x7e67047175a15272}, - {0x9f79a169bd203e41, 0x0f0062c6e984d387}, - {0xc75809c42c684dd1, 0x52c07b78a3e60869}, - {0xf92e0c3537826145, 0xa7709a56ccdf8a83}, - {0x9bbcc7a142b17ccb, 0x88a66076400bb692}, - {0xc2abf989935ddbfe, 0x6acff893d00ea436}, - {0xf356f7ebf83552fe, 0x0583f6b8c4124d44}, - {0x98165af37b2153de, 0xc3727a337a8b704b}, - {0xbe1bf1b059e9a8d6, 0x744f18c0592e4c5d}, - {0xeda2ee1c7064130c, 0x1162def06f79df74}, - {0x9485d4d1c63e8be7, 0x8addcb5645ac2ba9}, - {0xb9a74a0637ce2ee1, 0x6d953e2bd7173693}, - {0xe8111c87c5c1ba99, 0xc8fa8db6ccdd0438}, - {0x910ab1d4db9914a0, 0x1d9c9892400a22a3}, - {0xb54d5e4a127f59c8, 0x2503beb6d00cab4c}, - {0xe2a0b5dc971f303a, 0x2e44ae64840fd61e}, - {0x8da471a9de737e24, 0x5ceaecfed289e5d3}, - {0xb10d8e1456105dad, 0x7425a83e872c5f48}, - {0xdd50f1996b947518, 0xd12f124e28f7771a}, - {0x8a5296ffe33cc92f, 0x82bd6b70d99aaa70}, - {0xace73cbfdc0bfb7b, 0x636cc64d1001550c}, - {0xd8210befd30efa5a, 0x3c47f7e05401aa4f}, - {0x8714a775e3e95c78, 0x65acfaec34810a72}, - {0xa8d9d1535ce3b396, 0x7f1839a741a14d0e}, - {0xd31045a8341ca07c, 0x1ede48111209a051}, - {0x83ea2b892091e44d, 0x934aed0aab460433}, - 
{0xa4e4b66b68b65d60, 0xf81da84d56178540}, - {0xce1de40642e3f4b9, 0x36251260ab9d668f}, - {0x80d2ae83e9ce78f3, 0xc1d72b7c6b42601a}, - {0xa1075a24e4421730, 0xb24cf65b8612f820}, - {0xc94930ae1d529cfc, 0xdee033f26797b628}, - {0xfb9b7cd9a4a7443c, 0x169840ef017da3b2}, - {0x9d412e0806e88aa5, 0x8e1f289560ee864f}, - {0xc491798a08a2ad4e, 0xf1a6f2bab92a27e3}, - {0xf5b5d7ec8acb58a2, 0xae10af696774b1dc}, - {0x9991a6f3d6bf1765, 0xacca6da1e0a8ef2a}, - {0xbff610b0cc6edd3f, 0x17fd090a58d32af4}, - {0xeff394dcff8a948e, 0xddfc4b4cef07f5b1}, - {0x95f83d0a1fb69cd9, 0x4abdaf101564f98f}, - {0xbb764c4ca7a4440f, 0x9d6d1ad41abe37f2}, - {0xea53df5fd18d5513, 0x84c86189216dc5ee}, - {0x92746b9be2f8552c, 0x32fd3cf5b4e49bb5}, - {0xb7118682dbb66a77, 0x3fbc8c33221dc2a2}, - {0xe4d5e82392a40515, 0x0fabaf3feaa5334b}, - {0x8f05b1163ba6832d, 0x29cb4d87f2a7400f}, - {0xb2c71d5bca9023f8, 0x743e20e9ef511013}, - {0xdf78e4b2bd342cf6, 0x914da9246b255417}, - {0x8bab8eefb6409c1a, 0x1ad089b6c2f7548f}, - {0xae9672aba3d0c320, 0xa184ac2473b529b2}, - {0xda3c0f568cc4f3e8, 0xc9e5d72d90a2741f}, - {0x8865899617fb1871, 0x7e2fa67c7a658893}, - {0xaa7eebfb9df9de8d, 0xddbb901b98feeab8}, - {0xd51ea6fa85785631, 0x552a74227f3ea566}, - {0x8533285c936b35de, 0xd53a88958f872760}, - {0xa67ff273b8460356, 0x8a892abaf368f138}, - {0xd01fef10a657842c, 0x2d2b7569b0432d86}, - {0x8213f56a67f6b29b, 0x9c3b29620e29fc74}, - {0xa298f2c501f45f42, 0x8349f3ba91b47b90}, - {0xcb3f2f7642717713, 0x241c70a936219a74}, - {0xfe0efb53d30dd4d7, 0xed238cd383aa0111}, - {0x9ec95d1463e8a506, 0xf4363804324a40ab}, - {0xc67bb4597ce2ce48, 0xb143c6053edcd0d6}, - {0xf81aa16fdc1b81da, 0xdd94b7868e94050b}, - {0x9b10a4e5e9913128, 0xca7cf2b4191c8327}, - {0xc1d4ce1f63f57d72, 0xfd1c2f611f63a3f1}, - {0xf24a01a73cf2dccf, 0xbc633b39673c8ced}, - {0x976e41088617ca01, 0xd5be0503e085d814}, - {0xbd49d14aa79dbc82, 0x4b2d8644d8a74e19}, - {0xec9c459d51852ba2, 0xddf8e7d60ed1219f}, - {0x93e1ab8252f33b45, 0xcabb90e5c942b504}, - {0xb8da1662e7b00a17, 0x3d6a751f3b936244}, - {0xe7109bfba19c0c9d, 0x0cc512670a783ad5}, - {0x906a617d450187e2, 0x27fb2b80668b24c6}, - {0xb484f9dc9641e9da, 0xb1f9f660802dedf7}, - {0xe1a63853bbd26451, 0x5e7873f8a0396974}, - {0x8d07e33455637eb2, 0xdb0b487b6423e1e9}, - {0xb049dc016abc5e5f, 0x91ce1a9a3d2cda63}, - {0xdc5c5301c56b75f7, 0x7641a140cc7810fc}, - {0x89b9b3e11b6329ba, 0xa9e904c87fcb0a9e}, - {0xac2820d9623bf429, 0x546345fa9fbdcd45}, - {0xd732290fbacaf133, 0xa97c177947ad4096}, - {0x867f59a9d4bed6c0, 0x49ed8eabcccc485e}, - {0xa81f301449ee8c70, 0x5c68f256bfff5a75}, - {0xd226fc195c6a2f8c, 0x73832eec6fff3112}, - {0x83585d8fd9c25db7, 0xc831fd53c5ff7eac}, - {0xa42e74f3d032f525, 0xba3e7ca8b77f5e56}, - {0xcd3a1230c43fb26f, 0x28ce1bd2e55f35ec}, - {0x80444b5e7aa7cf85, 0x7980d163cf5b81b4}, - {0xa0555e361951c366, 0xd7e105bcc3326220}, - {0xc86ab5c39fa63440, 0x8dd9472bf3fefaa8}, - {0xfa856334878fc150, 0xb14f98f6f0feb952}, - {0x9c935e00d4b9d8d2, 0x6ed1bf9a569f33d4}, - {0xc3b8358109e84f07, 0x0a862f80ec4700c9}, - {0xf4a642e14c6262c8, 0xcd27bb612758c0fb}, - {0x98e7e9cccfbd7dbd, 0x8038d51cb897789d}, - {0xbf21e44003acdd2c, 0xe0470a63e6bd56c4}, - {0xeeea5d5004981478, 0x1858ccfce06cac75}, - {0x95527a5202df0ccb, 0x0f37801e0c43ebc9}, - {0xbaa718e68396cffd, 0xd30560258f54e6bb}, - {0xe950df20247c83fd, 0x47c6b82ef32a206a}, - {0x91d28b7416cdd27e, 0x4cdc331d57fa5442}, - {0xb6472e511c81471d, 0xe0133fe4adf8e953}, - {0xe3d8f9e563a198e5, 0x58180fddd97723a7}, - {0x8e679c2f5e44ff8f, 0x570f09eaa7ea7649}, - {0xb201833b35d63f73, 0x2cd2cc6551e513db}, - {0xde81e40a034bcf4f, 0xf8077f7ea65e58d2}, - {0x8b112e86420f6191, 
0xfb04afaf27faf783}, - {0xadd57a27d29339f6, 0x79c5db9af1f9b564}, - {0xd94ad8b1c7380874, 0x18375281ae7822bd}, - {0x87cec76f1c830548, 0x8f2293910d0b15b6}, - {0xa9c2794ae3a3c69a, 0xb2eb3875504ddb23}, - {0xd433179d9c8cb841, 0x5fa60692a46151ec}, - {0x849feec281d7f328, 0xdbc7c41ba6bcd334}, - {0xa5c7ea73224deff3, 0x12b9b522906c0801}, - {0xcf39e50feae16bef, 0xd768226b34870a01}, - {0x81842f29f2cce375, 0xe6a1158300d46641}, - {0xa1e53af46f801c53, 0x60495ae3c1097fd1}, - {0xca5e89b18b602368, 0x385bb19cb14bdfc5}, - {0xfcf62c1dee382c42, 0x46729e03dd9ed7b6}, - {0x9e19db92b4e31ba9, 0x6c07a2c26a8346d2}, - {0xc5a05277621be293, 0xc7098b7305241886}, - {0xf70867153aa2db38, 0xb8cbee4fc66d1ea8}, - {0x9a65406d44a5c903, 0x737f74f1dc043329}, - {0xc0fe908895cf3b44, 0x505f522e53053ff3}, - {0xf13e34aabb430a15, 0x647726b9e7c68ff0}, - {0x96c6e0eab509e64d, 0x5eca783430dc19f6}, - {0xbc789925624c5fe0, 0xb67d16413d132073}, - {0xeb96bf6ebadf77d8, 0xe41c5bd18c57e890}, - {0x933e37a534cbaae7, 0x8e91b962f7b6f15a}, - {0xb80dc58e81fe95a1, 0x723627bbb5a4adb1}, - {0xe61136f2227e3b09, 0xcec3b1aaa30dd91d}, - {0x8fcac257558ee4e6, 0x213a4f0aa5e8a7b2}, - {0xb3bd72ed2af29e1f, 0xa988e2cd4f62d19e}, - {0xe0accfa875af45a7, 0x93eb1b80a33b8606}, - {0x8c6c01c9498d8b88, 0xbc72f130660533c4}, - {0xaf87023b9bf0ee6a, 0xeb8fad7c7f8680b5}, - {0xdb68c2ca82ed2a05, 0xa67398db9f6820e2}, -#else - {0xff77b1fcbebcdc4f, 0x25e8e89c13bb0f7b}, - {0xce5d73ff402d98e3, 0xfb0a3d212dc81290}, - {0xa6b34ad8c9dfc06f, 0xf42faa48c0ea481f}, - {0x86a8d39ef77164bc, 0xae5dff9c02033198}, - {0xd98ddaee19068c76, 0x3badd624dd9b0958}, - {0xafbd2350644eeacf, 0xe5d1929ef90898fb}, - {0x8df5efabc5979c8f, 0xca8d3ffa1ef463c2}, - {0xe55990879ddcaabd, 0xcc420a6a101d0516}, - {0xb94470938fa89bce, 0xf808e40e8d5b3e6a}, - {0x95a8637627989aad, 0xdde7001379a44aa9}, - {0xf1c90080baf72cb1, 0x5324c68b12dd6339}, - {0xc350000000000000, 0x0000000000000000}, - {0x9dc5ada82b70b59d, 0xf020000000000000}, - {0xfee50b7025c36a08, 0x02f236d04753d5b5}, - {0xcde6fd5e09abcf26, 0xed4c0226b55e6f87}, - {0xa6539930bf6bff45, 0x84db8346b786151d}, - {0x865b86925b9bc5c2, 0x0b8a2392ba45a9b3}, - {0xd910f7ff28069da4, 0x1b2ba1518094da05}, - {0xaf58416654a6babb, 0x387ac8d1970027b3}, - {0x8da471a9de737e24, 0x5ceaecfed289e5d3}, - {0xe4d5e82392a40515, 0x0fabaf3feaa5334b}, - {0xb8da1662e7b00a17, 0x3d6a751f3b936244}, - {0x95527a5202df0ccb, 0x0f37801e0c43ebc9}, - {0xf13e34aabb430a15, 0x647726b9e7c68ff0} -#endif - }; - -#if FMT_USE_FULL_CACHE_DRAGONBOX - return pow10_significands[k - float_info::min_k]; -#else - static constexpr const uint64_t powers_of_5_64[] = {0x0000000000000001, - 0x0000000000000005, 0x0000000000000019, 0x000000000000007d, - 0x0000000000000271, 0x0000000000000c35, 0x0000000000003d09, - 0x000000000001312d, 0x000000000005f5e1, 0x00000000001dcd65, - 0x00000000009502f9, 0x0000000002e90edd, 0x000000000e8d4a51, - 0x0000000048c27395, 0x000000016bcc41e9, 0x000000071afd498d, - 0x0000002386f26fc1, 0x000000b1a2bc2ec5, 0x000003782dace9d9, - 0x00001158e460913d, 0x000056bc75e2d631, 0x0001b1ae4d6e2ef5, - 0x000878678326eac9, 0x002a5a058fc295ed, 0x00d3c21bcecceda1, - 0x0422ca8b0a00a425, 0x14adf4b7320334b9}; - - static const int compression_ratio = 27; - - // Compute base index. - int cache_index = (k - float_info::min_k) / compression_ratio; - int kb = cache_index * compression_ratio + float_info::min_k; - int offset = k - kb; - - // Get base cache. - uint128_fallback base_cache = pow10_significands[cache_index]; - if (offset == 0) return base_cache; - - // Compute the required amount of bit-shift. 
-    int alpha
-            = floor_log2_pow10(kb + offset) - floor_log2_pow10(kb) - offset;
-    FMT_ASSERT(alpha > 0 && alpha < 64, "shifting error detected");
-
-    // Try to recover the real cache.
-    uint64_t pow5 = powers_of_5_64[offset];
-    uint128_fallback recovered_cache = umul128(base_cache.high(), pow5);
-    uint128_fallback middle_low = umul128(base_cache.low(), pow5);
-
-    recovered_cache += middle_low.high();
-
-    uint64_t high_to_middle = recovered_cache.high() << (64 - alpha);
-    uint64_t middle_to_low = recovered_cache.low() << (64 - alpha);
-
-    recovered_cache = uint128_fallback {
-            (recovered_cache.low() >> alpha) | high_to_middle,
-            ((middle_low.low() >> alpha) | middle_to_low)};
-    FMT_ASSERT(recovered_cache.low() + 1 != 0, "");
-    return {recovered_cache.high(), recovered_cache.low() + 1};
-#endif
-    }
-
-    struct compute_mul_result {
-        carrier_uint result;
-        bool is_integer;
-    };
-    struct compute_mul_parity_result {
-        bool parity;
-        bool is_integer;
-    };
-
-    static auto compute_mul(carrier_uint u,
-            const cache_entry_type &cache) noexcept -> compute_mul_result {
-        auto r = umul192_upper128(u, cache);
-        return {r.high(), r.low() == 0};
-    }
-
-    static auto compute_delta(cache_entry_type const &cache, int beta) noexcept
-            -> uint32_t {
-        return static_cast<uint32_t>(cache.high() >> (64 - 1 - beta));
-    }
-
-    static auto compute_mul_parity(carrier_uint two_f,
-            const cache_entry_type &cache, int beta) noexcept
-            -> compute_mul_parity_result {
-        FMT_ASSERT(beta >= 1, "");
-        FMT_ASSERT(beta < 64, "");
-
-        auto r = umul192_lower128(two_f, cache);
-        return {((r.high() >> (64 - beta)) & 1) != 0,
-                ((r.high() << beta) | (r.low() >> (64 - beta))) == 0};
-    }
-
-    static auto compute_left_endpoint_for_shorter_interval_case(
-            const cache_entry_type &cache, int beta) noexcept -> carrier_uint {
-        return (cache.high()
-                       - (cache.high() >> (num_significand_bits<double>() + 2)))
-                >> (64 - num_significand_bits<double>() - 1 - beta);
-    }
-
-    static auto compute_right_endpoint_for_shorter_interval_case(
-            const cache_entry_type &cache, int beta) noexcept -> carrier_uint {
-        return (cache.high()
-                       + (cache.high() >> (num_significand_bits<double>() + 1)))
-                >> (64 - num_significand_bits<double>() - 1 - beta);
-    }
-
-    static auto compute_round_up_for_shorter_interval_case(
-            const cache_entry_type &cache, int beta) noexcept -> carrier_uint {
-        return ((cache.high()
-                        >> (64 - num_significand_bits<double>() - 2 - beta))
-                       + 1)
-                / 2;
-    }
-};
-
-FMT_FUNC auto get_cached_power(int k) noexcept -> uint128_fallback {
-    return cache_accessor<double>::get_cached_power(k);
-}
-
-// Various integer checks
-template <typename T>
-auto is_left_endpoint_integer_shorter_interval(int exponent) noexcept -> bool {
-    const int case_shorter_interval_left_endpoint_lower_threshold = 2;
-    const int case_shorter_interval_left_endpoint_upper_threshold = 3;
-    return exponent >= case_shorter_interval_left_endpoint_lower_threshold
-            && exponent <= case_shorter_interval_left_endpoint_upper_threshold;
-}
-
-// Remove trailing zeros from n and return the number of zeros removed (float)
-FMT_INLINE int remove_trailing_zeros(uint32_t &n, int s = 0) noexcept {
-    FMT_ASSERT(n != 0, "");
-    // Modular inverse of 5 (mod 2^32): (mod_inv_5 * 5) mod 2^32 = 1.
-    constexpr uint32_t mod_inv_5 = 0xcccccccd;
-    constexpr uint32_t mod_inv_25 = 0xc28f5c29; // = mod_inv_5 * mod_inv_5
-
-    while (true) {
-        auto q = rotr(n * mod_inv_25, 2);
-        if (q > max_value<uint32_t>() / 100) break;
-        n = q;
-        s += 2;
-    }
-    auto q = rotr(n * mod_inv_5, 1);
-    if (q <= max_value<uint32_t>() / 10) {
-        n = q;
-        s |= 1;
-    }
-    return s;
-}
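// Standalone sketch of the modular-inverse test used above: n is divisible
// by 10 exactly when rotr(n * inv5, 1) <= UINT32_MAX / 10, because the
// multiply undoes the factor 5 and the rotate undoes the factor 2. The loop
// bound below is an arbitrary spot check:
#include <cassert>
#include <cstdint>
#include <limits>

static uint32_t rotr32(uint32_t n, uint32_t r) {
    r &= 31;
    return (n >> r) | (n << (32 - r));
}

int main() {
    constexpr uint32_t inv5 = 0xcccccccd; // 5 * inv5 == 1 (mod 2^32)
    for (uint32_t n = 1; n < 100000; ++n) {
        uint32_t q = rotr32(n * inv5, 1);
        bool div10 = q <= std::numeric_limits<uint32_t>::max() / 10;
        assert(div10 == (n % 10 == 0));
        if (div10) assert(q == n / 10);
    }
}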
-
-// Removes trailing zeros and returns the number of zeros removed (double)
-FMT_INLINE int remove_trailing_zeros(uint64_t &n) noexcept {
-    FMT_ASSERT(n != 0, "");
-
-    // This magic number is ceil(2^90 / 10^8).
-    constexpr uint64_t magic_number = 12379400392853802749ull;
-    auto nm = umul128(n, magic_number);
-
-    // Is n divisible by 10^8?
-    if ((nm.high() & ((1ull << (90 - 64)) - 1)) == 0
-            && nm.low() < magic_number) {
-        // If yes, work with the quotient...
-        auto n32 = static_cast<uint32_t>(nm.high() >> (90 - 64));
-        // ... and use the 32 bit variant of the function
-        int s = remove_trailing_zeros(n32, 8);
-        n = n32;
-        return s;
-    }
-
-    // If n is not divisible by 10^8, work with n itself.
-    constexpr uint64_t mod_inv_5 = 0xcccccccccccccccd;
-    constexpr uint64_t mod_inv_25 = 0x8f5c28f5c28f5c29; // mod_inv_5 * mod_inv_5
-
-    int s = 0;
-    while (true) {
-        auto q = rotr(n * mod_inv_25, 2);
-        if (q > max_value<uint64_t>() / 100) break;
-        n = q;
-        s += 2;
-    }
-    auto q = rotr(n * mod_inv_5, 1);
-    if (q <= max_value<uint64_t>() / 10) {
-        n = q;
-        s |= 1;
-    }
-
-    return s;
-}
-
-// The main algorithm for shorter interval case
-template <typename T>
-FMT_INLINE decimal_fp<T> shorter_interval_case(int exponent) noexcept {
-    decimal_fp<T> ret_value;
-    // Compute k and beta
-    const int minus_k = floor_log10_pow2_minus_log10_4_over_3(exponent);
-    const int beta = exponent + floor_log2_pow10(-minus_k);
-
-    // Compute xi and zi
-    using cache_entry_type = typename cache_accessor<T>::cache_entry_type;
-    const cache_entry_type cache
-            = cache_accessor<T>::get_cached_power(-minus_k);
-
-    auto xi = cache_accessor<
-            T>::compute_left_endpoint_for_shorter_interval_case(cache, beta);
-    auto zi = cache_accessor<
-            T>::compute_right_endpoint_for_shorter_interval_case(cache, beta);
-
-    // If the left endpoint is not an integer, increase it
-    if (!is_left_endpoint_integer_shorter_interval<T>(exponent)) ++xi;
-
-    // Try bigger divisor
-    ret_value.significand = zi / 10;
-
-    // If succeed, remove trailing zeros if necessary and return
-    if (ret_value.significand * 10 >= xi) {
-        ret_value.exponent = minus_k + 1;
-        ret_value.exponent += remove_trailing_zeros(ret_value.significand);
-        return ret_value;
-    }
-
-    // Otherwise, compute the round-up of y
-    ret_value.significand
-            = cache_accessor<T>::compute_round_up_for_shorter_interval_case(
-                    cache, beta);
-    ret_value.exponent = minus_k;
-
-    // When tie occurs, choose one of them according to the rule
-    if (exponent >= float_info<T>::shorter_interval_tie_lower_threshold
-            && exponent
-                    <= float_info<T>::shorter_interval_tie_upper_threshold) {
-        ret_value.significand = ret_value.significand % 2 == 0
-                ? ret_value.significand
-                : ret_value.significand - 1;
-    } else if (ret_value.significand < xi) {
-        ++ret_value.significand;
-    }
-    return ret_value;
-}
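// Standalone sketch of the 10^8 split used by the 64-bit variant above,
// using the unsigned __int128 extension (GCC/Clang) in place of fmt's
// umul128; the sample values are arbitrary:
#include <cassert>
#include <cstdint>

int main() {
    // ceil(2^90 / 10^8), as in remove_trailing_zeros(uint64_t &).
    constexpr uint64_t magic = 12379400392853802749ull;
    for (uint64_t n : {100000000ull, 123400000000ull, 99999999ull,
                 170000001ull, 4300000000ull}) {
        unsigned __int128 nm = static_cast<unsigned __int128>(n) * magic;
        uint64_t hi = static_cast<uint64_t>(nm >> 64);
        uint64_t lo = static_cast<uint64_t>(nm);
        bool div = (hi & ((1ull << 26) - 1)) == 0 && lo < magic;
        assert(div == (n % 100000000ull == 0));
        if (div) assert((hi >> 26) == n / 100000000ull); // quotient = nm >> 90
    }
}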
-
-template <typename T>
-auto to_decimal(T x) noexcept -> decimal_fp<T> {
-    // Step 1: integer promotion & Schubfach multiplier calculation.
-
-    using carrier_uint = typename float_info<T>::carrier_uint;
-    using cache_entry_type = typename cache_accessor<T>::cache_entry_type;
-    auto br = bit_cast<carrier_uint>(x);
-
-    // Extract significand bits and exponent bits.
-    const carrier_uint significand_mask
-            = (static_cast<carrier_uint>(1) << num_significand_bits<T>()) - 1;
-    carrier_uint significand = (br & significand_mask);
-    int exponent = static_cast<int>(
-            (br & exponent_mask<T>()) >> num_significand_bits<T>());
-
-    if (exponent != 0) { // Check if normal.
-        exponent -= exponent_bias<T>() + num_significand_bits<T>();
-
-        // Shorter interval case; proceed like Schubfach.
-        // In fact, when exponent == 1 and significand == 0, the interval is
-        // regular. However, it can be shown that the end-results are anyway same.
-        if (significand == 0) return shorter_interval_case<T>(exponent);
-
-        significand
-                |= (static_cast<carrier_uint>(1) << num_significand_bits<T>());
-    } else {
-        // Subnormal case; the interval is always regular.
-        if (significand == 0) return {0, 0};
-        exponent = std::numeric_limits<T>::min_exponent
-                - num_significand_bits<T>() - 1;
-    }
-
-    const bool include_left_endpoint = (significand % 2 == 0);
-    const bool include_right_endpoint = include_left_endpoint;
-
-    // Compute k and beta.
-    const int minus_k = floor_log10_pow2(exponent) - float_info<T>::kappa;
-    const cache_entry_type cache
-            = cache_accessor<T>::get_cached_power(-minus_k);
-    const int beta = exponent + floor_log2_pow10(-minus_k);
-
-    // Compute zi and deltai.
-    // 10^kappa <= deltai < 10^(kappa + 1)
-    const uint32_t deltai = cache_accessor<T>::compute_delta(cache, beta);
-    const carrier_uint two_fc = significand << 1;
-
-    // For the case of binary32, the result of integer check is not correct for
-    // 29711844 * 2^-82
-    // = 6.1442653300000000008655037797566933477355632930994033813476... * 10^-18
-    // and 29711844 * 2^-81
-    // = 1.2288530660000000001731007559513386695471126586198806762695... * 10^-17,
-    // and they are the unique counterexamples. However, since 29711844 is even,
-    // this does not cause any problem for the endpoints calculations; it can only
-    // cause a problem when we need to perform integer check for the center.
-    // Fortunately, with these inputs, that branch is never executed, so we are
-    // fine.
-    const typename cache_accessor<T>::compute_mul_result z_mul
-            = cache_accessor<T>::compute_mul((two_fc | 1) << beta, cache);
-
-    // Step 2: Try larger divisor; remove trailing zeros if necessary.
-
-    // Using an upper bound on zi, we might be able to optimize the division
-    // better than the compiler; we are computing zi / big_divisor here.
-    decimal_fp<T> ret_value;
-    ret_value.significand = divide_by_10_to_kappa_plus_1(z_mul.result);
-    uint32_t r = static_cast<uint32_t>(
-            z_mul.result - float_info<T>::big_divisor * ret_value.significand);
-
-    if (r < deltai) {
-        // Exclude the right endpoint if necessary.
-        if (r == 0 && (z_mul.is_integer & !include_right_endpoint)) {
-            --ret_value.significand;
-            r = float_info<T>::big_divisor;
-            goto small_divisor_case_label;
-        }
-    } else if (r > deltai) {
-        goto small_divisor_case_label;
-    } else {
-        // r == deltai; compare fractional parts.
-        const typename cache_accessor<T>::compute_mul_parity_result x_mul
-                = cache_accessor<T>::compute_mul_parity(
-                        two_fc - 1, cache, beta);
-
-        if (!(x_mul.parity | (x_mul.is_integer & include_left_endpoint)))
-            goto small_divisor_case_label;
-    }
-    ret_value.exponent = minus_k + float_info<T>::kappa + 1;
-
-    // We may need to remove trailing zeros.
-    ret_value.exponent += remove_trailing_zeros(ret_value.significand);
-    return ret_value;
-
-    // Step 3: Find the significand with the smaller divisor.
-
-small_divisor_case_label:
-    ret_value.significand *= 10;
-    ret_value.exponent = minus_k + float_info<T>::kappa;
-
-    uint32_t dist = r - (deltai / 2) + (float_info<T>::small_divisor / 2);
-    const bool approx_y_parity
-            = ((dist ^ (float_info<T>::small_divisor / 2)) & 1) != 0;
-
-    // Is dist divisible by 10^kappa?
-    const bool divisible_by_small_divisor
-            = check_divisibility_and_divide_by_pow10<float_info<T>::kappa>(
-                    dist);
-
-    // Add dist / 10^kappa to the significand.
-    ret_value.significand += dist;
-
-    if (!divisible_by_small_divisor) return ret_value;
-
-    // Check z^(f) >= epsilon^(f).
-    // We have either yi == zi - epsiloni or yi == (zi - epsiloni) - 1,
-    // where yi == zi - epsiloni if and only if z^(f) >= epsilon^(f).
-    // Since there are only 2 possibilities, we only need to care about the
-    // parity. Also, zi and r should have the same parity since the divisor
-    // is an even number.
-    const auto y_mul
-            = cache_accessor<T>::compute_mul_parity(two_fc, cache, beta);
-
-    // If z^(f) >= epsilon^(f), we might have a tie when z^(f) == epsilon^(f),
-    // or equivalently, when y is an integer.
-    if (y_mul.parity != approx_y_parity)
-        --ret_value.significand;
-    else if (y_mul.is_integer & (ret_value.significand % 2 != 0))
-        --ret_value.significand;
-    return ret_value;
-}
-} // namespace dragonbox
-} // namespace detail
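// to_decimal() above yields the shortest significand/exponent pair that
// round-trips. The same contract is observable through std::to_chars
// (C++17); a sketch, assuming a standard library whose charconv supports
// floating-point (GCC 11+, MSVC):
#include <cassert>
#include <charconv>
#include <cstdio>

int main() {
    float x = 0.3f;
    char buf[32];
    auto res = std::to_chars(buf, buf + sizeof(buf), x);
    *res.ptr = '\0';
    std::printf("%s\n", buf); // "0.3": the shortest form that round-trips
    float back = 0;
    std::from_chars(buf, res.ptr, back);
    assert(back == x);
}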
-
-template <>
-struct formatter<detail::bigint> {
-    FMT_CONSTEXPR auto parse(format_parse_context &ctx)
-            -> format_parse_context::iterator {
-        return ctx.begin();
-    }
-
-    auto format(const detail::bigint &n, format_context &ctx) const
-            -> format_context::iterator {
-        auto out = ctx.out();
-        bool first = true;
-        for (auto i = n.bigits_.size(); i > 0; --i) {
-            auto value = n.bigits_[i - 1u];
-            if (first) {
-                out = fmt::format_to(out, FMT_STRING("{:x}"), value);
-                first = false;
-                continue;
-            }
-            out = fmt::format_to(out, FMT_STRING("{:08x}"), value);
-        }
-        if (n.exp_ > 0)
-            out = fmt::format_to(out, FMT_STRING("p{}"),
-                    n.exp_ * detail::bigint::bigit_bits);
-        return out;
-    }
-};
-
-FMT_FUNC detail::utf8_to_utf16::utf8_to_utf16(string_view s) {
-    for_each_codepoint(s, [this](uint32_t cp, string_view) {
-        if (cp == invalid_code_point)
-            FMT_THROW(std::runtime_error("invalid utf8"));
-        if (cp <= 0xFFFF) {
-            buffer_.push_back(static_cast<wchar_t>(cp));
-        } else {
-            cp -= 0x10000;
-            buffer_.push_back(static_cast<wchar_t>(0xD800 + (cp >> 10)));
-            buffer_.push_back(static_cast<wchar_t>(0xDC00 + (cp & 0x3FF)));
-        }
-        return true;
-    });
-    buffer_.push_back(0);
-}
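// The surrogate-pair arithmetic above, checked standalone for one code point
// beyond the Basic Multilingual Plane:
#include <cassert>
#include <cstdint>

int main() {
    uint32_t cp = 0x1F600; // a code point above the BMP
    cp -= 0x10000;
    uint16_t high = static_cast<uint16_t>(0xD800 + (cp >> 10));
    uint16_t low = static_cast<uint16_t>(0xDC00 + (cp & 0x3FF));
    assert(high == 0xD83D && low == 0xDE00); // UTF-16 for U+1F600
    // Decoding reverses it:
    uint32_t decoded = 0x10000 + ((high - 0xD800u) << 10) + (low - 0xDC00u);
    assert(decoded == 0x1F600);
}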
-
-FMT_FUNC void format_system_error(detail::buffer<char> &out, int error_code,
-        const char *message) noexcept {
-    FMT_TRY {
-        auto ec = std::error_code(error_code, std::generic_category());
-        write(std::back_inserter(out), std::system_error(ec, message).what());
-        return;
-    }
-    FMT_CATCH(...) {}
-    format_error_code(out, error_code, message);
-}
-
-FMT_FUNC void report_system_error(
-        int error_code, const char *message) noexcept {
-    report_error(format_system_error, error_code, message);
-}
-
-FMT_FUNC auto vformat(string_view fmt, format_args args) -> std::string {
-    // Don't optimize the "{}" case to keep the binary size small and because it
-    // can be better optimized in fmt::format anyway.
-    auto buffer = memory_buffer();
-    detail::vformat_to(buffer, fmt, args);
-    return to_string(buffer);
-}
-
-namespace detail {
-#if !defined(_WIN32) || defined(FMT_WINDOWS_NO_WCHAR)
-FMT_FUNC auto write_console(int, string_view) -> bool {
-    return false;
-}
-FMT_FUNC auto write_console(std::FILE *, string_view) -> bool {
-    return false;
-}
-#else
-using dword = conditional_t<sizeof(long) == 4, unsigned long, unsigned>;
-extern "C" __declspec(dllimport) int __stdcall WriteConsoleW( //
-        void *, const void *, dword, dword *, void *);
-
-FMT_FUNC bool write_console(int fd, string_view text) {
-    auto u16 = utf8_to_utf16(text);
-    return WriteConsoleW(reinterpret_cast<void *>(_get_osfhandle(fd)),
-                   u16.c_str(), static_cast<dword>(u16.size()), nullptr,
-                   nullptr)
-            != 0;
-}
-
-FMT_FUNC auto write_console(std::FILE *f, string_view text) -> bool {
-    return write_console(_fileno(f), text);
-}
-#endif
-
-#ifdef _WIN32
-// Print assuming legacy (non-Unicode) encoding.
-FMT_FUNC void vprint_mojibake(std::FILE *f, string_view fmt, format_args args) {
-    auto buffer = memory_buffer();
-    detail::vformat_to(buffer, fmt, args);
-    fwrite_fully(buffer.data(), buffer.size(), f);
-}
-#endif
-
-FMT_FUNC void print(std::FILE *f, string_view text) {
-#ifdef _WIN32
-    int fd = _fileno(f);
-    if (_isatty(fd)) {
-        std::fflush(f);
-        if (write_console(fd, text)) return;
-    }
-#endif
-    fwrite_fully(text.data(), text.size(), f);
-}
-} // namespace detail
-
-FMT_FUNC void vprint(std::FILE *f, string_view fmt, format_args args) {
-    auto buffer = memory_buffer();
-    detail::vformat_to(buffer, fmt, args);
-    detail::print(f, {buffer.data(), buffer.size()});
-}
-
-FMT_FUNC void vprint(string_view fmt, format_args args) {
-    vprint(stdout, fmt, args);
-}
-
-namespace detail {
-
-struct singleton {
-    unsigned char upper;
-    unsigned char lower_count;
-};
-
-inline auto is_printable(uint16_t x, const singleton *singletons,
-        size_t singletons_size, const unsigned char *singleton_lowers,
-        const unsigned char *normal, size_t normal_size) -> bool {
-    auto upper = x >> 8;
-    auto lower_start = 0;
-    for (size_t i = 0; i < singletons_size; ++i) {
-        auto s = singletons[i];
-        auto lower_end = lower_start + s.lower_count;
-        if (upper < s.upper) break;
-        if (upper == s.upper) {
-            for (auto j = lower_start; j < lower_end; ++j) {
-                if (singleton_lowers[j] == (x & 0xff)) return false;
-            }
-        }
-        lower_start = lower_end;
-    }
-
-    auto xsigned = static_cast<int>(x);
-    auto current = true;
-    for (size_t i = 0; i < normal_size; ++i) {
-        auto v = static_cast<int>(normal[i]);
-        auto len = (v & 0x80) != 0 ? (v & 0x7f) << 8 | normal[++i] : v;
-        xsigned -= len;
-        if (xsigned < 0) break;
-        current = !current;
-    }
-    return current;
-}
-
-// This code is generated by support/printable.py.
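// Sketch of the run-length scheme decoded by the second loop above: the
// "normal" table stores alternating lengths of in/out runs (true first;
// lengths >= 0x80 take two bytes). The two-byte table here is hypothetical:
#include <cassert>
#include <cstddef>

static bool decode(int x, const unsigned char *normal, std::size_t n) {
    bool current = true;
    for (std::size_t i = 0; i < n; ++i) {
        int v = normal[i];
        int len = (v & 0x80) != 0 ? (v & 0x7f) << 8 | normal[++i] : v;
        x -= len;
        if (x < 0) break;
        current = !current;
    }
    return current;
}

int main() {
    // 2 offsets "in", then 3 "out", then everything "in" again.
    const unsigned char table[] = {2, 3};
    assert(decode(0, table, 2) == true);
    assert(decode(2, table, 2) == false);
    assert(decode(4, table, 2) == false);
    assert(decode(5, table, 2) == true);
}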
-FMT_FUNC auto is_printable(uint32_t cp) -> bool { - static constexpr singleton singletons0[] = { - {0x00, 1}, - {0x03, 5}, - {0x05, 6}, - {0x06, 3}, - {0x07, 6}, - {0x08, 8}, - {0x09, 17}, - {0x0a, 28}, - {0x0b, 25}, - {0x0c, 20}, - {0x0d, 16}, - {0x0e, 13}, - {0x0f, 4}, - {0x10, 3}, - {0x12, 18}, - {0x13, 9}, - {0x16, 1}, - {0x17, 5}, - {0x18, 2}, - {0x19, 3}, - {0x1a, 7}, - {0x1c, 2}, - {0x1d, 1}, - {0x1f, 22}, - {0x20, 3}, - {0x2b, 3}, - {0x2c, 2}, - {0x2d, 11}, - {0x2e, 1}, - {0x30, 3}, - {0x31, 2}, - {0x32, 1}, - {0xa7, 2}, - {0xa9, 2}, - {0xaa, 4}, - {0xab, 8}, - {0xfa, 2}, - {0xfb, 5}, - {0xfd, 4}, - {0xfe, 3}, - {0xff, 9}, - }; - static constexpr unsigned char singletons0_lower[] = { - 0xad, - 0x78, - 0x79, - 0x8b, - 0x8d, - 0xa2, - 0x30, - 0x57, - 0x58, - 0x8b, - 0x8c, - 0x90, - 0x1c, - 0x1d, - 0xdd, - 0x0e, - 0x0f, - 0x4b, - 0x4c, - 0xfb, - 0xfc, - 0x2e, - 0x2f, - 0x3f, - 0x5c, - 0x5d, - 0x5f, - 0xb5, - 0xe2, - 0x84, - 0x8d, - 0x8e, - 0x91, - 0x92, - 0xa9, - 0xb1, - 0xba, - 0xbb, - 0xc5, - 0xc6, - 0xc9, - 0xca, - 0xde, - 0xe4, - 0xe5, - 0xff, - 0x00, - 0x04, - 0x11, - 0x12, - 0x29, - 0x31, - 0x34, - 0x37, - 0x3a, - 0x3b, - 0x3d, - 0x49, - 0x4a, - 0x5d, - 0x84, - 0x8e, - 0x92, - 0xa9, - 0xb1, - 0xb4, - 0xba, - 0xbb, - 0xc6, - 0xca, - 0xce, - 0xcf, - 0xe4, - 0xe5, - 0x00, - 0x04, - 0x0d, - 0x0e, - 0x11, - 0x12, - 0x29, - 0x31, - 0x34, - 0x3a, - 0x3b, - 0x45, - 0x46, - 0x49, - 0x4a, - 0x5e, - 0x64, - 0x65, - 0x84, - 0x91, - 0x9b, - 0x9d, - 0xc9, - 0xce, - 0xcf, - 0x0d, - 0x11, - 0x29, - 0x45, - 0x49, - 0x57, - 0x64, - 0x65, - 0x8d, - 0x91, - 0xa9, - 0xb4, - 0xba, - 0xbb, - 0xc5, - 0xc9, - 0xdf, - 0xe4, - 0xe5, - 0xf0, - 0x0d, - 0x11, - 0x45, - 0x49, - 0x64, - 0x65, - 0x80, - 0x84, - 0xb2, - 0xbc, - 0xbe, - 0xbf, - 0xd5, - 0xd7, - 0xf0, - 0xf1, - 0x83, - 0x85, - 0x8b, - 0xa4, - 0xa6, - 0xbe, - 0xbf, - 0xc5, - 0xc7, - 0xce, - 0xcf, - 0xda, - 0xdb, - 0x48, - 0x98, - 0xbd, - 0xcd, - 0xc6, - 0xce, - 0xcf, - 0x49, - 0x4e, - 0x4f, - 0x57, - 0x59, - 0x5e, - 0x5f, - 0x89, - 0x8e, - 0x8f, - 0xb1, - 0xb6, - 0xb7, - 0xbf, - 0xc1, - 0xc6, - 0xc7, - 0xd7, - 0x11, - 0x16, - 0x17, - 0x5b, - 0x5c, - 0xf6, - 0xf7, - 0xfe, - 0xff, - 0x80, - 0x0d, - 0x6d, - 0x71, - 0xde, - 0xdf, - 0x0e, - 0x0f, - 0x1f, - 0x6e, - 0x6f, - 0x1c, - 0x1d, - 0x5f, - 0x7d, - 0x7e, - 0xae, - 0xaf, - 0xbb, - 0xbc, - 0xfa, - 0x16, - 0x17, - 0x1e, - 0x1f, - 0x46, - 0x47, - 0x4e, - 0x4f, - 0x58, - 0x5a, - 0x5c, - 0x5e, - 0x7e, - 0x7f, - 0xb5, - 0xc5, - 0xd4, - 0xd5, - 0xdc, - 0xf0, - 0xf1, - 0xf5, - 0x72, - 0x73, - 0x8f, - 0x74, - 0x75, - 0x96, - 0x2f, - 0x5f, - 0x26, - 0x2e, - 0x2f, - 0xa7, - 0xaf, - 0xb7, - 0xbf, - 0xc7, - 0xcf, - 0xd7, - 0xdf, - 0x9a, - 0x40, - 0x97, - 0x98, - 0x30, - 0x8f, - 0x1f, - 0xc0, - 0xc1, - 0xce, - 0xff, - 0x4e, - 0x4f, - 0x5a, - 0x5b, - 0x07, - 0x08, - 0x0f, - 0x10, - 0x27, - 0x2f, - 0xee, - 0xef, - 0x6e, - 0x6f, - 0x37, - 0x3d, - 0x3f, - 0x42, - 0x45, - 0x90, - 0x91, - 0xfe, - 0xff, - 0x53, - 0x67, - 0x75, - 0xc8, - 0xc9, - 0xd0, - 0xd1, - 0xd8, - 0xd9, - 0xe7, - 0xfe, - 0xff, - }; - static constexpr singleton singletons1[] = { - {0x00, 6}, - {0x01, 1}, - {0x03, 1}, - {0x04, 2}, - {0x08, 8}, - {0x09, 2}, - {0x0a, 5}, - {0x0b, 2}, - {0x0e, 4}, - {0x10, 1}, - {0x11, 2}, - {0x12, 5}, - {0x13, 17}, - {0x14, 1}, - {0x15, 2}, - {0x17, 2}, - {0x19, 13}, - {0x1c, 5}, - {0x1d, 8}, - {0x24, 1}, - {0x6a, 3}, - {0x6b, 2}, - {0xbc, 2}, - {0xd1, 2}, - {0xd4, 12}, - {0xd5, 9}, - {0xd6, 2}, - {0xd7, 2}, - {0xda, 1}, - {0xe0, 5}, - {0xe1, 2}, - {0xe8, 2}, - {0xee, 32}, - {0xf0, 4}, - {0xf8, 2}, - {0xf9, 2}, - {0xfa, 
2}, - {0xfb, 1}, - }; - static constexpr unsigned char singletons1_lower[] = { - 0x0c, - 0x27, - 0x3b, - 0x3e, - 0x4e, - 0x4f, - 0x8f, - 0x9e, - 0x9e, - 0x9f, - 0x06, - 0x07, - 0x09, - 0x36, - 0x3d, - 0x3e, - 0x56, - 0xf3, - 0xd0, - 0xd1, - 0x04, - 0x14, - 0x18, - 0x36, - 0x37, - 0x56, - 0x57, - 0x7f, - 0xaa, - 0xae, - 0xaf, - 0xbd, - 0x35, - 0xe0, - 0x12, - 0x87, - 0x89, - 0x8e, - 0x9e, - 0x04, - 0x0d, - 0x0e, - 0x11, - 0x12, - 0x29, - 0x31, - 0x34, - 0x3a, - 0x45, - 0x46, - 0x49, - 0x4a, - 0x4e, - 0x4f, - 0x64, - 0x65, - 0x5c, - 0xb6, - 0xb7, - 0x1b, - 0x1c, - 0x07, - 0x08, - 0x0a, - 0x0b, - 0x14, - 0x17, - 0x36, - 0x39, - 0x3a, - 0xa8, - 0xa9, - 0xd8, - 0xd9, - 0x09, - 0x37, - 0x90, - 0x91, - 0xa8, - 0x07, - 0x0a, - 0x3b, - 0x3e, - 0x66, - 0x69, - 0x8f, - 0x92, - 0x6f, - 0x5f, - 0xee, - 0xef, - 0x5a, - 0x62, - 0x9a, - 0x9b, - 0x27, - 0x28, - 0x55, - 0x9d, - 0xa0, - 0xa1, - 0xa3, - 0xa4, - 0xa7, - 0xa8, - 0xad, - 0xba, - 0xbc, - 0xc4, - 0x06, - 0x0b, - 0x0c, - 0x15, - 0x1d, - 0x3a, - 0x3f, - 0x45, - 0x51, - 0xa6, - 0xa7, - 0xcc, - 0xcd, - 0xa0, - 0x07, - 0x19, - 0x1a, - 0x22, - 0x25, - 0x3e, - 0x3f, - 0xc5, - 0xc6, - 0x04, - 0x20, - 0x23, - 0x25, - 0x26, - 0x28, - 0x33, - 0x38, - 0x3a, - 0x48, - 0x4a, - 0x4c, - 0x50, - 0x53, - 0x55, - 0x56, - 0x58, - 0x5a, - 0x5c, - 0x5e, - 0x60, - 0x63, - 0x65, - 0x66, - 0x6b, - 0x73, - 0x78, - 0x7d, - 0x7f, - 0x8a, - 0xa4, - 0xaa, - 0xaf, - 0xb0, - 0xc0, - 0xd0, - 0xae, - 0xaf, - 0x79, - 0xcc, - 0x6e, - 0x6f, - 0x93, - }; - static constexpr unsigned char normal0[] = { - 0x00, - 0x20, - 0x5f, - 0x22, - 0x82, - 0xdf, - 0x04, - 0x82, - 0x44, - 0x08, - 0x1b, - 0x04, - 0x06, - 0x11, - 0x81, - 0xac, - 0x0e, - 0x80, - 0xab, - 0x35, - 0x28, - 0x0b, - 0x80, - 0xe0, - 0x03, - 0x19, - 0x08, - 0x01, - 0x04, - 0x2f, - 0x04, - 0x34, - 0x04, - 0x07, - 0x03, - 0x01, - 0x07, - 0x06, - 0x07, - 0x11, - 0x0a, - 0x50, - 0x0f, - 0x12, - 0x07, - 0x55, - 0x07, - 0x03, - 0x04, - 0x1c, - 0x0a, - 0x09, - 0x03, - 0x08, - 0x03, - 0x07, - 0x03, - 0x02, - 0x03, - 0x03, - 0x03, - 0x0c, - 0x04, - 0x05, - 0x03, - 0x0b, - 0x06, - 0x01, - 0x0e, - 0x15, - 0x05, - 0x3a, - 0x03, - 0x11, - 0x07, - 0x06, - 0x05, - 0x10, - 0x07, - 0x57, - 0x07, - 0x02, - 0x07, - 0x15, - 0x0d, - 0x50, - 0x04, - 0x43, - 0x03, - 0x2d, - 0x03, - 0x01, - 0x04, - 0x11, - 0x06, - 0x0f, - 0x0c, - 0x3a, - 0x04, - 0x1d, - 0x25, - 0x5f, - 0x20, - 0x6d, - 0x04, - 0x6a, - 0x25, - 0x80, - 0xc8, - 0x05, - 0x82, - 0xb0, - 0x03, - 0x1a, - 0x06, - 0x82, - 0xfd, - 0x03, - 0x59, - 0x07, - 0x15, - 0x0b, - 0x17, - 0x09, - 0x14, - 0x0c, - 0x14, - 0x0c, - 0x6a, - 0x06, - 0x0a, - 0x06, - 0x1a, - 0x06, - 0x59, - 0x07, - 0x2b, - 0x05, - 0x46, - 0x0a, - 0x2c, - 0x04, - 0x0c, - 0x04, - 0x01, - 0x03, - 0x31, - 0x0b, - 0x2c, - 0x04, - 0x1a, - 0x06, - 0x0b, - 0x03, - 0x80, - 0xac, - 0x06, - 0x0a, - 0x06, - 0x21, - 0x3f, - 0x4c, - 0x04, - 0x2d, - 0x03, - 0x74, - 0x08, - 0x3c, - 0x03, - 0x0f, - 0x03, - 0x3c, - 0x07, - 0x38, - 0x08, - 0x2b, - 0x05, - 0x82, - 0xff, - 0x11, - 0x18, - 0x08, - 0x2f, - 0x11, - 0x2d, - 0x03, - 0x20, - 0x10, - 0x21, - 0x0f, - 0x80, - 0x8c, - 0x04, - 0x82, - 0x97, - 0x19, - 0x0b, - 0x15, - 0x88, - 0x94, - 0x05, - 0x2f, - 0x05, - 0x3b, - 0x07, - 0x02, - 0x0e, - 0x18, - 0x09, - 0x80, - 0xb3, - 0x2d, - 0x74, - 0x0c, - 0x80, - 0xd6, - 0x1a, - 0x0c, - 0x05, - 0x80, - 0xff, - 0x05, - 0x80, - 0xdf, - 0x0c, - 0xee, - 0x0d, - 0x03, - 0x84, - 0x8d, - 0x03, - 0x37, - 0x09, - 0x81, - 0x5c, - 0x14, - 0x80, - 0xb8, - 0x08, - 0x80, - 0xcb, - 0x2a, - 0x38, - 0x03, - 0x0a, - 0x06, - 0x38, - 0x08, - 0x46, - 0x08, - 0x0c, - 0x06, - 0x74, 
- 0x0b, - 0x1e, - 0x03, - 0x5a, - 0x04, - 0x59, - 0x09, - 0x80, - 0x83, - 0x18, - 0x1c, - 0x0a, - 0x16, - 0x09, - 0x4c, - 0x04, - 0x80, - 0x8a, - 0x06, - 0xab, - 0xa4, - 0x0c, - 0x17, - 0x04, - 0x31, - 0xa1, - 0x04, - 0x81, - 0xda, - 0x26, - 0x07, - 0x0c, - 0x05, - 0x05, - 0x80, - 0xa5, - 0x11, - 0x81, - 0x6d, - 0x10, - 0x78, - 0x28, - 0x2a, - 0x06, - 0x4c, - 0x04, - 0x80, - 0x8d, - 0x04, - 0x80, - 0xbe, - 0x03, - 0x1b, - 0x03, - 0x0f, - 0x0d, - }; - static constexpr unsigned char normal1[] = { - 0x5e, - 0x22, - 0x7b, - 0x05, - 0x03, - 0x04, - 0x2d, - 0x03, - 0x66, - 0x03, - 0x01, - 0x2f, - 0x2e, - 0x80, - 0x82, - 0x1d, - 0x03, - 0x31, - 0x0f, - 0x1c, - 0x04, - 0x24, - 0x09, - 0x1e, - 0x05, - 0x2b, - 0x05, - 0x44, - 0x04, - 0x0e, - 0x2a, - 0x80, - 0xaa, - 0x06, - 0x24, - 0x04, - 0x24, - 0x04, - 0x28, - 0x08, - 0x34, - 0x0b, - 0x01, - 0x80, - 0x90, - 0x81, - 0x37, - 0x09, - 0x16, - 0x0a, - 0x08, - 0x80, - 0x98, - 0x39, - 0x03, - 0x63, - 0x08, - 0x09, - 0x30, - 0x16, - 0x05, - 0x21, - 0x03, - 0x1b, - 0x05, - 0x01, - 0x40, - 0x38, - 0x04, - 0x4b, - 0x05, - 0x2f, - 0x04, - 0x0a, - 0x07, - 0x09, - 0x07, - 0x40, - 0x20, - 0x27, - 0x04, - 0x0c, - 0x09, - 0x36, - 0x03, - 0x3a, - 0x05, - 0x1a, - 0x07, - 0x04, - 0x0c, - 0x07, - 0x50, - 0x49, - 0x37, - 0x33, - 0x0d, - 0x33, - 0x07, - 0x2e, - 0x08, - 0x0a, - 0x81, - 0x26, - 0x52, - 0x4e, - 0x28, - 0x08, - 0x2a, - 0x56, - 0x1c, - 0x14, - 0x17, - 0x09, - 0x4e, - 0x04, - 0x1e, - 0x0f, - 0x43, - 0x0e, - 0x19, - 0x07, - 0x0a, - 0x06, - 0x48, - 0x08, - 0x27, - 0x09, - 0x75, - 0x0b, - 0x3f, - 0x41, - 0x2a, - 0x06, - 0x3b, - 0x05, - 0x0a, - 0x06, - 0x51, - 0x06, - 0x01, - 0x05, - 0x10, - 0x03, - 0x05, - 0x80, - 0x8b, - 0x62, - 0x1e, - 0x48, - 0x08, - 0x0a, - 0x80, - 0xa6, - 0x5e, - 0x22, - 0x45, - 0x0b, - 0x0a, - 0x06, - 0x0d, - 0x13, - 0x39, - 0x07, - 0x0a, - 0x36, - 0x2c, - 0x04, - 0x10, - 0x80, - 0xc0, - 0x3c, - 0x64, - 0x53, - 0x0c, - 0x48, - 0x09, - 0x0a, - 0x46, - 0x45, - 0x1b, - 0x48, - 0x08, - 0x53, - 0x1d, - 0x39, - 0x81, - 0x07, - 0x46, - 0x0a, - 0x1d, - 0x03, - 0x47, - 0x49, - 0x37, - 0x03, - 0x0e, - 0x08, - 0x0a, - 0x06, - 0x39, - 0x07, - 0x0a, - 0x81, - 0x36, - 0x19, - 0x80, - 0xb7, - 0x01, - 0x0f, - 0x32, - 0x0d, - 0x83, - 0x9b, - 0x66, - 0x75, - 0x0b, - 0x80, - 0xc4, - 0x8a, - 0xbc, - 0x84, - 0x2f, - 0x8f, - 0xd1, - 0x82, - 0x47, - 0xa1, - 0xb9, - 0x82, - 0x39, - 0x07, - 0x2a, - 0x04, - 0x02, - 0x60, - 0x26, - 0x0a, - 0x46, - 0x0a, - 0x28, - 0x05, - 0x13, - 0x82, - 0xb0, - 0x5b, - 0x65, - 0x4b, - 0x04, - 0x39, - 0x07, - 0x11, - 0x40, - 0x05, - 0x0b, - 0x02, - 0x0e, - 0x97, - 0xf8, - 0x08, - 0x84, - 0xd6, - 0x2a, - 0x09, - 0xa2, - 0xf7, - 0x81, - 0x1f, - 0x31, - 0x03, - 0x11, - 0x04, - 0x08, - 0x81, - 0x8c, - 0x89, - 0x04, - 0x6b, - 0x05, - 0x0d, - 0x03, - 0x09, - 0x07, - 0x10, - 0x93, - 0x60, - 0x80, - 0xf6, - 0x0a, - 0x73, - 0x08, - 0x6e, - 0x17, - 0x46, - 0x80, - 0x9a, - 0x14, - 0x0c, - 0x57, - 0x09, - 0x19, - 0x80, - 0x87, - 0x81, - 0x47, - 0x03, - 0x85, - 0x42, - 0x0f, - 0x15, - 0x85, - 0x50, - 0x2b, - 0x80, - 0xd5, - 0x2d, - 0x03, - 0x1a, - 0x04, - 0x02, - 0x81, - 0x70, - 0x3a, - 0x05, - 0x01, - 0x85, - 0x00, - 0x80, - 0xd7, - 0x29, - 0x4c, - 0x04, - 0x0a, - 0x04, - 0x02, - 0x83, - 0x11, - 0x44, - 0x4c, - 0x3d, - 0x80, - 0xc2, - 0x3c, - 0x06, - 0x01, - 0x04, - 0x55, - 0x05, - 0x1b, - 0x34, - 0x02, - 0x81, - 0x0e, - 0x2c, - 0x04, - 0x64, - 0x0c, - 0x56, - 0x0a, - 0x80, - 0xae, - 0x38, - 0x1d, - 0x0d, - 0x2c, - 0x04, - 0x09, - 0x07, - 0x02, - 0x0e, - 0x06, - 0x80, - 0x9a, - 0x83, - 0xd8, - 0x08, - 0x0d, - 0x03, - 0x0d, - 0x03, - 0x74, - 
0x0c, - 0x59, - 0x07, - 0x0c, - 0x14, - 0x0c, - 0x04, - 0x38, - 0x08, - 0x0a, - 0x06, - 0x28, - 0x08, - 0x22, - 0x4e, - 0x81, - 0x54, - 0x0c, - 0x15, - 0x03, - 0x03, - 0x05, - 0x07, - 0x09, - 0x19, - 0x07, - 0x07, - 0x09, - 0x03, - 0x0d, - 0x07, - 0x29, - 0x80, - 0xcb, - 0x25, - 0x0a, - 0x84, - 0x06, - }; - auto lower = static_cast(cp); - if (cp < 0x10000) { - return is_printable(lower, singletons0, - sizeof(singletons0) / sizeof(*singletons0), singletons0_lower, - normal0, sizeof(normal0)); - } - if (cp < 0x20000) { - return is_printable(lower, singletons1, - sizeof(singletons1) / sizeof(*singletons1), singletons1_lower, - normal1, sizeof(normal1)); - } - if (0x2a6de <= cp && cp < 0x2a700) return false; - if (0x2b735 <= cp && cp < 0x2b740) return false; - if (0x2b81e <= cp && cp < 0x2b820) return false; - if (0x2cea2 <= cp && cp < 0x2ceb0) return false; - if (0x2ebe1 <= cp && cp < 0x2f800) return false; - if (0x2fa1e <= cp && cp < 0x30000) return false; - if (0x3134b <= cp && cp < 0xe0100) return false; - if (0xe01f0 <= cp && cp < 0x110000) return false; - return cp < 0x110000; -} - -} // namespace detail - -FMT_END_NAMESPACE - -#endif // FMT_FORMAT_INL_H_ diff --git a/src/common/spdlog/fmt/bundled/format.h b/src/common/spdlog/fmt/bundled/format.h deleted file mode 100755 index c8e36554fc2..00000000000 --- a/src/common/spdlog/fmt/bundled/format.h +++ /dev/null @@ -1,4664 +0,0 @@ -/******************************************************************************* -* Copyright 2024 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -/* - Formatting library for C++ - - Copyright (c) 2012 - present, Victor Zverovich - - Permission is hereby granted, free of charge, to any person obtaining - a copy of this software and associated documentation files (the - "Software"), to deal in the Software without restriction, including - without limitation the rights to use, copy, modify, merge, publish, - distribute, sublicense, and/or sell copies of the Software, and to - permit persons to whom the Software is furnished to do so, subject to - the following conditions: - - The above copyright notice and this permission notice shall be - included in all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE - LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION - OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION - WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- - --- Optional exception to the license --- - - As an exception, if, as a result of your compiling your source code, portions - of this Software are embedded into a machine-executable object form of such - source code, you may redistribute such embedded portions in such object form - without including the above copyright and permission notices. - */ - -#ifndef FMT_FORMAT_H_ -#define FMT_FORMAT_H_ - -#include // std::signbit -#include // uint32_t -#include // std::memcpy -#include // std::numeric_limits -#include // std::uninitialized_copy -#include // std::runtime_error -#include // std::initializer_list -#include // std::system_error - -#ifdef __cpp_lib_bit_cast -#include // std::bit_cast -#endif - -#include "common/spdlog/fmt/bundled/core.h" - -#if defined __cpp_inline_variables && __cpp_inline_variables >= 201606L -#define FMT_INLINE_VARIABLE inline -#else -#define FMT_INLINE_VARIABLE -#endif - -#if FMT_HAS_CPP17_ATTRIBUTE(fallthrough) -#define FMT_FALLTHROUGH [[fallthrough]] -#elif defined(__clang__) -#define FMT_FALLTHROUGH [[clang::fallthrough]] -#elif FMT_GCC_VERSION >= 700 \ - && (!defined(__EDG_VERSION__) || __EDG_VERSION__ >= 520) -#define FMT_FALLTHROUGH [[gnu::fallthrough]] -#else -#define FMT_FALLTHROUGH -#endif - -#ifndef FMT_DEPRECATED -#if FMT_HAS_CPP14_ATTRIBUTE(deprecated) || FMT_MSC_VERSION >= 1900 -#define FMT_DEPRECATED [[deprecated]] -#else -#if (defined(__GNUC__) && !defined(__LCC__)) || defined(__clang__) -#define FMT_DEPRECATED __attribute__((deprecated)) -#elif FMT_MSC_VERSION -#define FMT_DEPRECATED __declspec(deprecated) -#else -#define FMT_DEPRECATED /* deprecated */ -#endif -#endif -#endif - -#ifndef FMT_NO_UNIQUE_ADDRESS -#if FMT_CPLUSPLUS >= 202002L -#if FMT_HAS_CPP_ATTRIBUTE(no_unique_address) -#define FMT_NO_UNIQUE_ADDRESS [[no_unique_address]] -// VS2019 v16.10 and later except clang-cl (https://reviews.llvm.org/D110485) -#elif (FMT_MSC_VERSION >= 1929) && !FMT_CLANG_VERSION -#define FMT_NO_UNIQUE_ADDRESS [[msvc::no_unique_address]] -#endif -#endif -#endif -#ifndef FMT_NO_UNIQUE_ADDRESS -#define FMT_NO_UNIQUE_ADDRESS -#endif - -// Visibility when compiled as a shared library/object. -#if defined(FMT_LIB_EXPORT) || defined(FMT_SHARED) -#define FMT_SO_VISIBILITY(value) FMT_VISIBILITY(value) -#else -#define FMT_SO_VISIBILITY(value) -#endif - -#ifdef __has_builtin -#define FMT_HAS_BUILTIN(x) __has_builtin(x) -#else -#define FMT_HAS_BUILTIN(x) 0 -#endif - -#if FMT_GCC_VERSION || FMT_CLANG_VERSION -#define FMT_NOINLINE __attribute__((noinline)) -#else -#define FMT_NOINLINE -#endif - -#ifndef FMT_THROW -#if FMT_EXCEPTIONS -#if FMT_MSC_VERSION || defined(__NVCC__) -FMT_BEGIN_NAMESPACE -namespace detail { -template -inline void do_throw(const Exception &x) { - // Silence unreachable code warnings in MSVC and NVCC because these - // are nearly impossible to fix in a generic code. 
- volatile bool b = true; - if (b) throw x; -} -} // namespace detail -FMT_END_NAMESPACE -#define FMT_THROW(x) detail::do_throw(x) -#else -#define FMT_THROW(x) throw x -#endif -#else -#define FMT_THROW(x) ::fmt::detail::assert_fail(__FILE__, __LINE__, (x).what()) -#endif -#endif - -#if FMT_EXCEPTIONS -#define FMT_TRY try -#define FMT_CATCH(x) catch (x) -#else -#define FMT_TRY if (true) -#define FMT_CATCH(x) if (false) -#endif - -#ifndef FMT_MAYBE_UNUSED -#if FMT_HAS_CPP17_ATTRIBUTE(maybe_unused) -#define FMT_MAYBE_UNUSED [[maybe_unused]] -#else -#define FMT_MAYBE_UNUSED -#endif -#endif - -#ifndef FMT_USE_USER_DEFINED_LITERALS -// EDG based compilers (Intel, NVIDIA, Elbrus, etc), GCC and MSVC support UDLs. -// -// GCC before 4.9 requires a space in `operator"" _a` which is invalid in later -// compiler versions. -#if (FMT_HAS_FEATURE(cxx_user_literals) || FMT_GCC_VERSION >= 409 \ - || FMT_MSC_VERSION >= 1900) \ - && (!defined(__EDG_VERSION__) \ - || __EDG_VERSION__ >= /* UDL feature */ 480) -#define FMT_USE_USER_DEFINED_LITERALS 1 -#else -#define FMT_USE_USER_DEFINED_LITERALS 0 -#endif -#endif - -// Defining FMT_REDUCE_INT_INSTANTIATIONS to 1, will reduce the number of -// integer formatter template instantiations to just one by only using the -// largest integer type. This results in a reduction in binary size but will -// cause a decrease in integer formatting performance. -#if !defined(FMT_REDUCE_INT_INSTANTIATIONS) -#define FMT_REDUCE_INT_INSTANTIATIONS 0 -#endif - -// __builtin_clz is broken in clang with Microsoft CodeGen: -// https://github.com/fmtlib/fmt/issues/519. -#if !FMT_MSC_VERSION -#if FMT_HAS_BUILTIN(__builtin_clz) || FMT_GCC_VERSION || FMT_ICC_VERSION -#define FMT_BUILTIN_CLZ(n) __builtin_clz(n) -#endif -#if FMT_HAS_BUILTIN(__builtin_clzll) || FMT_GCC_VERSION || FMT_ICC_VERSION -#define FMT_BUILTIN_CLZLL(n) __builtin_clzll(n) -#endif -#endif - -// __builtin_ctz is broken in Intel Compiler Classic on Windows: -// https://github.com/fmtlib/fmt/issues/2510. -#ifndef __ICL -#if FMT_HAS_BUILTIN(__builtin_ctz) || FMT_GCC_VERSION || FMT_ICC_VERSION \ - || defined(__NVCOMPILER) -#define FMT_BUILTIN_CTZ(n) __builtin_ctz(n) -#endif -#if FMT_HAS_BUILTIN(__builtin_ctzll) || FMT_GCC_VERSION || FMT_ICC_VERSION \ - || defined(__NVCOMPILER) -#define FMT_BUILTIN_CTZLL(n) __builtin_ctzll(n) -#endif -#endif - -#if FMT_MSC_VERSION -#include // _BitScanReverse[64], _BitScanForward[64], _umul128 -#endif - -// Some compilers masquerade as both MSVC and GCC-likes or otherwise support -// __builtin_clz and __builtin_clzll, so only define FMT_BUILTIN_CLZ using the -// MSVC intrinsics if the clz and clzll builtins are not available. -#if FMT_MSC_VERSION && !defined(FMT_BUILTIN_CLZLL) \ - && !defined(FMT_BUILTIN_CTZLL) -FMT_BEGIN_NAMESPACE -namespace detail { -// Avoid Clang with Microsoft CodeGen's -Wunknown-pragmas warning. -#if !defined(__clang__) -#pragma intrinsic(_BitScanForward) -#pragma intrinsic(_BitScanReverse) -#if defined(_WIN64) -#pragma intrinsic(_BitScanForward64) -#pragma intrinsic(_BitScanReverse64) -#endif -#endif - -inline auto clz(uint32_t x) -> int { - unsigned long r = 0; - _BitScanReverse(&r, x); - FMT_ASSERT(x != 0, ""); - // Static analysis complains about using uninitialized data - // "r", but the only way that can happen is if "x" is 0, - // which the callers guarantee to not happen. 
- FMT_MSC_WARNING(suppress : 6102) - return 31 ^ static_cast(r); -} -#define FMT_BUILTIN_CLZ(n) detail::clz(n) - -inline auto clzll(uint64_t x) -> int { - unsigned long r = 0; -#ifdef _WIN64 - _BitScanReverse64(&r, x); -#else - // Scan the high 32 bits. - if (_BitScanReverse(&r, static_cast(x >> 32))) - return 63 ^ static_cast(r + 32); - // Scan the low 32 bits. - _BitScanReverse(&r, static_cast(x)); -#endif - FMT_ASSERT(x != 0, ""); - FMT_MSC_WARNING( - suppress : 6102) // Suppress a bogus static analysis warning. - return 63 ^ static_cast(r); -} -#define FMT_BUILTIN_CLZLL(n) detail::clzll(n) - -inline auto ctz(uint32_t x) -> int { - unsigned long r = 0; - _BitScanForward(&r, x); - FMT_ASSERT(x != 0, ""); - FMT_MSC_WARNING( - suppress : 6102) // Suppress a bogus static analysis warning. - return static_cast(r); -} -#define FMT_BUILTIN_CTZ(n) detail::ctz(n) - -inline auto ctzll(uint64_t x) -> int { - unsigned long r = 0; - FMT_ASSERT(x != 0, ""); - FMT_MSC_WARNING( - suppress : 6102) // Suppress a bogus static analysis warning. -#ifdef _WIN64 - _BitScanForward64(&r, x); -#else - // Scan the low 32 bits. - if (_BitScanForward(&r, static_cast(x))) - return static_cast(r); - // Scan the high 32 bits. - _BitScanForward(&r, static_cast(x >> 32)); - r += 32; -#endif - return static_cast(r); -} -#define FMT_BUILTIN_CTZLL(n) detail::ctzll(n) -} // namespace detail -FMT_END_NAMESPACE -#endif - -FMT_BEGIN_NAMESPACE -namespace detail { - -FMT_CONSTEXPR inline void abort_fuzzing_if(bool condition) { - ignore_unused(condition); -#ifdef FMT_FUZZ - if (condition) throw std::runtime_error("fuzzing limit reached"); -#endif -} - -template -struct string_literal { - static constexpr CharT value[sizeof...(C)] = {C...}; - constexpr operator basic_string_view() const { - return {value, sizeof...(C)}; - } -}; - -#if FMT_CPLUSPLUS < 201703L -template -constexpr CharT string_literal::value[sizeof...(C)]; -#endif - -// Implementation of std::bit_cast for pre-C++20. -template -FMT_CONSTEXPR20 auto bit_cast(const From &from) -> To { -#ifdef __cpp_lib_bit_cast - if (is_constant_evaluated()) return std::bit_cast(from); -#endif - auto to = To(); - // The cast suppresses a bogus -Wclass-memaccess on GCC. - std::memcpy(static_cast(&to), &from, sizeof(to)); - return to; -} - -inline auto is_big_endian() -> bool { -#ifdef _WIN32 - return false; -#elif defined(__BIG_ENDIAN__) - return true; -#elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) - return __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__; -#else - struct bytes { - char data[sizeof(int)]; - }; - return bit_cast(1).data[0] == 0; -#endif -} - -class uint128_fallback { -private: - uint64_t lo_, hi_; - -public: - constexpr uint128_fallback(uint64_t hi, uint64_t lo) : lo_(lo), hi_(hi) {} - constexpr uint128_fallback(uint64_t value = 0) : lo_(value), hi_(0) {} - - constexpr auto high() const noexcept -> uint64_t { return hi_; } - constexpr auto low() const noexcept -> uint64_t { return lo_; } - - template ::value)> - constexpr explicit operator T() const { - return static_cast(lo_); - } - - friend constexpr auto operator==( - const uint128_fallback &lhs, const uint128_fallback &rhs) -> bool { - return lhs.hi_ == rhs.hi_ && lhs.lo_ == rhs.lo_; - } - friend constexpr auto operator!=( - const uint128_fallback &lhs, const uint128_fallback &rhs) -> bool { - return !(lhs == rhs); - } - friend constexpr auto operator>( - const uint128_fallback &lhs, const uint128_fallback &rhs) -> bool { - return lhs.hi_ != rhs.hi_ ? 
lhs.hi_ > rhs.hi_ : lhs.lo_ > rhs.lo_; - } - friend constexpr auto operator|(const uint128_fallback &lhs, - const uint128_fallback &rhs) -> uint128_fallback { - return {lhs.hi_ | rhs.hi_, lhs.lo_ | rhs.lo_}; - } - friend constexpr auto operator&(const uint128_fallback &lhs, - const uint128_fallback &rhs) -> uint128_fallback { - return {lhs.hi_ & rhs.hi_, lhs.lo_ & rhs.lo_}; - } - friend constexpr auto operator~(const uint128_fallback &n) - -> uint128_fallback { - return {~n.hi_, ~n.lo_}; - } - friend auto operator+(const uint128_fallback &lhs, - const uint128_fallback &rhs) -> uint128_fallback { - auto result = uint128_fallback(lhs); - result += rhs; - return result; - } - friend auto operator*(const uint128_fallback &lhs, uint32_t rhs) - -> uint128_fallback { - FMT_ASSERT(lhs.hi_ == 0, ""); - uint64_t hi = (lhs.lo_ >> 32) * rhs; - uint64_t lo = (lhs.lo_ & ~uint32_t()) * rhs; - uint64_t new_lo = (hi << 32) + lo; - return {(hi >> 32) + (new_lo < lo ? 1 : 0), new_lo}; - } - friend auto operator-(const uint128_fallback &lhs, uint64_t rhs) - -> uint128_fallback { - return {lhs.hi_ - (lhs.lo_ < rhs ? 1 : 0), lhs.lo_ - rhs}; - } - FMT_CONSTEXPR auto operator>>(int shift) const -> uint128_fallback { - if (shift == 64) return {0, hi_}; - if (shift > 64) return uint128_fallback(0, hi_) >> (shift - 64); - return {hi_ >> shift, (hi_ << (64 - shift)) | (lo_ >> shift)}; - } - FMT_CONSTEXPR auto operator<<(int shift) const -> uint128_fallback { - if (shift == 64) return {lo_, 0}; - if (shift > 64) return uint128_fallback(lo_, 0) << (shift - 64); - return {hi_ << shift | (lo_ >> (64 - shift)), (lo_ << shift)}; - } - FMT_CONSTEXPR auto operator>>=(int shift) -> uint128_fallback & { - return *this = *this >> shift; - } - FMT_CONSTEXPR void operator+=(uint128_fallback n) { - uint64_t new_lo = lo_ + n.lo_; - uint64_t new_hi = hi_ + n.hi_ + (new_lo < lo_ ? 1 : 0); - FMT_ASSERT(new_hi >= hi_, ""); - lo_ = new_lo; - hi_ = new_hi; - } - FMT_CONSTEXPR void operator&=(uint128_fallback n) { - lo_ &= n.lo_; - hi_ &= n.hi_; - } - - FMT_CONSTEXPR20 auto operator+=(uint64_t n) noexcept -> uint128_fallback & { - if (is_constant_evaluated()) { - lo_ += n; - hi_ += (lo_ < n ? 1 : 0); - return *this; - } -#if FMT_HAS_BUILTIN(__builtin_addcll) && !defined(__ibmxl__) - unsigned long long carry; - lo_ = __builtin_addcll(lo_, n, 0, &carry); - hi_ += carry; -#elif FMT_HAS_BUILTIN(__builtin_ia32_addcarryx_u64) && !defined(__ibmxl__) - unsigned long long result; - auto carry = __builtin_ia32_addcarryx_u64(0, lo_, n, &result); - lo_ = result; - hi_ += carry; -#elif defined(_MSC_VER) && defined(_M_X64) - auto carry = _addcarry_u64(0, lo_, n, &lo_); - _addcarry_u64(carry, hi_, 0, &hi_); -#else - lo_ += n; - hi_ += (lo_ < n ? 1 : 0); -#endif - return *this; - } -}; - -using uint128_t = conditional_t; - -#ifdef UINTPTR_MAX -using uintptr_t = ::uintptr_t; -#else -using uintptr_t = uint128_t; -#endif - -// Returns the largest possible value for type T. Same as -// std::numeric_limits::max() but shorter and not affected by the max macro. -template -constexpr auto max_value() -> T { - return (std::numeric_limits::max)(); -} -template -constexpr auto num_bits() -> int { - return std::numeric_limits::digits; -} -// std::numeric_limits::digits may return 0 for 128-bit ints. -template <> -constexpr auto num_bits() -> int { - return 128; -} -template <> -constexpr auto num_bits() -> int { - return 128; -} - -// A heterogeneous bit_cast used for converting 96-bit long double to uint128_t -// and 128-bit pointers to uint128_fallback. 
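Concretely, that conversion copies the source object into an array of 32-bit words and folds the words together with shift-or, most significant word first. A standalone little-endian sketch with assumed names (the template below adds fmt's own helpers plus a big-endian branch):

    #include <cstdint>
    #include <cstring>

    // Rebuild a wide unsigned integer from the 32-bit words of a trivially
    // copyable source. Little-endian hosts store the least significant word
    // first, so fold the words in reverse order.
    template <typename To, typename From>
    To words_to_uint(const From &from) {
        static_assert(sizeof(From) % sizeof(uint32_t) == 0, "word multiple");
        static_assert(sizeof(To) >= sizeof(From) && sizeof(To) > 4,
                "need a wide destination");
        constexpr int n = sizeof(From) / sizeof(uint32_t);
        uint32_t words[n];
        std::memcpy(words, &from, sizeof(from));
        To result = 0;
        for (int i = n - 1; i >= 0; --i)
            result = (result << 32) | words[i];
        return result;
    }

For example, `words_to_uint<uint64_t>(1.0)` recovers the IEEE-754 bit pattern 0x3ff0000000000000 on a little-endian host.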
-template sizeof(From))> -inline auto bit_cast(const From &from) -> To { - constexpr auto size = static_cast(sizeof(From) / sizeof(unsigned)); - struct data_t { - unsigned value[static_cast(size)]; - } data = bit_cast(from); - auto result = To(); - if (const_check(is_big_endian())) { - for (int i = 0; i < size; ++i) - result = (result << num_bits()) | data.value[i]; - } else { - for (int i = size - 1; i >= 0; --i) - result = (result << num_bits()) | data.value[i]; - } - return result; -} - -template -FMT_CONSTEXPR20 inline auto countl_zero_fallback(UInt n) -> int { - int lz = 0; - constexpr UInt msb_mask = static_cast(1) << (num_bits() - 1); - for (; (n & msb_mask) == 0; n <<= 1) - lz++; - return lz; -} - -FMT_CONSTEXPR20 inline auto countl_zero(uint32_t n) -> int { -#ifdef FMT_BUILTIN_CLZ - if (!is_constant_evaluated()) return FMT_BUILTIN_CLZ(n); -#endif - return countl_zero_fallback(n); -} - -FMT_CONSTEXPR20 inline auto countl_zero(uint64_t n) -> int { -#ifdef FMT_BUILTIN_CLZLL - if (!is_constant_evaluated()) return FMT_BUILTIN_CLZLL(n); -#endif - return countl_zero_fallback(n); -} - -FMT_INLINE void assume(bool condition) { - (void)condition; -#if FMT_HAS_BUILTIN(__builtin_assume) && !FMT_ICC_VERSION - __builtin_assume(condition); -#elif FMT_GCC_VERSION - if (!condition) __builtin_unreachable(); -#endif -} - -// An approximation of iterator_t for pre-C++20 systems. -template -using iterator_t = decltype(std::begin(std::declval())); -template -using sentinel_t = decltype(std::end(std::declval())); - -// A workaround for std::string not having mutable data() until C++17. -template -inline auto get_data(std::basic_string &s) -> Char * { - return &s[0]; -} -template -inline auto get_data(Container &c) -> typename Container::value_type * { - return c.data(); -} - -// Attempts to reserve space for n extra characters in the output range. -// Returns a pointer to the reserved range or a reference to it. -template ::value)> -#if FMT_CLANG_VERSION >= 307 && !FMT_ICC_VERSION -__attribute__((no_sanitize("undefined"))) -#endif -inline auto -reserve(std::back_insert_iterator it, size_t n) -> - typename Container::value_type * { - Container &c = get_container(it); - size_t size = c.size(); - c.resize(size + n); - return get_data(c) + size; -} - -template -inline auto reserve(buffer_appender it, size_t n) -> buffer_appender { - buffer &buf = get_container(it); - buf.try_reserve(buf.size() + n); - return it; -} - -template -constexpr auto reserve(Iterator &it, size_t) -> Iterator & { - return it; -} - -template -using reserve_iterator - = remove_reference_t(), 0))>; - -template -constexpr auto to_pointer(OutputIt, size_t) -> T * { - return nullptr; -} -template -auto to_pointer(buffer_appender it, size_t n) -> T * { - buffer &buf = get_container(it); - auto size = buf.size(); - if (buf.capacity() < size + n) return nullptr; - buf.try_resize(size + n); - return buf.data() + size; -} - -template ::value)> -inline auto base_iterator(std::back_insert_iterator it, - typename Container::value_type *) - -> std::back_insert_iterator { - return it; -} - -template -constexpr auto base_iterator(Iterator, Iterator it) -> Iterator { - return it; -} - -// is spectacularly slow to compile in C++20 so use a simple fill_n -// instead (#1998). 
-template -FMT_CONSTEXPR auto fill_n(OutputIt out, Size count, const T &value) - -> OutputIt { - for (Size i = 0; i < count; ++i) - *out++ = value; - return out; -} -template -FMT_CONSTEXPR20 auto fill_n(T *out, Size count, char value) -> T * { - if (is_constant_evaluated()) { - return fill_n(out, count, value); - } - std::memset(out, value, to_unsigned(count)); - return out + count; -} - -#ifdef __cpp_char8_t -using char8_type = char8_t; -#else -enum char8_type : unsigned char {}; -#endif - -template -FMT_CONSTEXPR FMT_NOINLINE auto copy_str_noinline( - InputIt begin, InputIt end, OutputIt out) -> OutputIt { - return copy_str(begin, end, out); -} - -// A public domain branchless UTF-8 decoder by Christopher Wellons: -// https://github.com/skeeto/branchless-utf8 -/* Decode the next character, c, from s, reporting errors in e. - * - * Since this is a branchless decoder, four bytes will be read from the - * buffer regardless of the actual length of the next character. This - * means the buffer _must_ have at least three bytes of zero padding - * following the end of the data stream. - * - * Errors are reported in e, which will be non-zero if the parsed - * character was somehow invalid: invalid byte sequence, non-canonical - * encoding, or a surrogate half. - * - * The function returns a pointer to the next character. When an error - * occurs, this pointer will be a guess that depends on the particular - * error, but it will always advance at least one byte. - */ -FMT_CONSTEXPR inline auto utf8_decode(const char *s, uint32_t *c, int *e) - -> const char * { - constexpr const int masks[] = {0x00, 0x7f, 0x1f, 0x0f, 0x07}; - constexpr const uint32_t mins[] = {4194304, 0, 128, 2048, 65536}; - constexpr const int shiftc[] = {0, 18, 12, 6, 0}; - constexpr const int shifte[] = {0, 6, 4, 2, 0}; - - int len = "\1\1\1\1\1\1\1\1\1\1\1\1\1\1\1\1\0\0\0\0\0\0\0\0\2\2\2\2\3\3\4" - [static_cast(*s) >> 3]; - // Compute the pointer to the next character early so that the next - // iteration can start working on the next character. Neither Clang - // nor GCC figure out this reordering on their own. - const char *next = s + len + !len; - - using uchar = unsigned char; - - // Assume a four-byte character and load four bytes. Unused bits are - // shifted out. - *c = uint32_t(uchar(s[0]) & masks[len]) << 18; - *c |= uint32_t(uchar(s[1]) & 0x3f) << 12; - *c |= uint32_t(uchar(s[2]) & 0x3f) << 6; - *c |= uint32_t(uchar(s[3]) & 0x3f) << 0; - *c >>= shiftc[len]; - - // Accumulate the various error conditions. - *e = (*c < mins[len]) << 6; // non-canonical encoding - *e |= ((*c >> 11) == 0x1b) << 7; // surrogate half? - *e |= (*c > 0x10FFFF) << 8; // out of range? - *e |= (uchar(s[1]) & 0xc0) >> 2; - *e |= (uchar(s[2]) & 0xc0) >> 4; - *e |= uchar(s[3]) >> 6; - *e ^= 0x2a; // top two bits of each tail byte correct? - *e >>= shifte[len]; - - return next; -} - -constexpr FMT_INLINE_VARIABLE uint32_t invalid_code_point = ~uint32_t(); - -// Invokes f(cp, sv) for every code point cp in s with sv being the string view -// corresponding to the code point. cp is invalid_code_point on error. -template -FMT_CONSTEXPR void for_each_codepoint(string_view s, F f) { - auto decode = [f](const char *buf_ptr, const char *ptr) { - auto cp = uint32_t(); - auto error = 0; - auto end = utf8_decode(buf_ptr, &cp, &error); - bool result = f(error ? invalid_code_point : cp, - string_view(ptr, error ? 1 : to_unsigned(end - buf_ptr))); - return result ? (error ? 
buf_ptr + 1 : end) : nullptr; - }; - auto p = s.data(); - const size_t block_size = 4; // utf8_decode always reads blocks of 4 chars. - if (s.size() >= block_size) { - for (auto end = p + s.size() - block_size + 1; p < end;) { - p = decode(p, p); - if (!p) return; - } - } - if (auto num_chars_left = s.data() + s.size() - p) { - char buf[2 * block_size - 1] = {}; - copy_str(p, p + num_chars_left, buf); - const char *buf_ptr = buf; - do { - auto end = decode(buf_ptr, p); - if (!end) return; - p += end - buf_ptr; - buf_ptr = end; - } while (buf_ptr - buf < num_chars_left); - } -} - -template -inline auto compute_width(basic_string_view s) -> size_t { - return s.size(); -} - -// Computes approximate display width of a UTF-8 string. -FMT_CONSTEXPR inline auto compute_width(string_view s) -> size_t { - size_t num_code_points = 0; - // It is not a lambda for compatibility with C++14. - struct count_code_points { - size_t *count; - FMT_CONSTEXPR auto operator()(uint32_t cp, string_view) const -> bool { - *count += detail::to_unsigned(1 - + (cp >= 0x1100 - && (cp <= 0x115f || // Hangul Jamo init. consonants - cp == 0x2329 - || // LEFT-POINTING ANGLE BRACKET - cp == 0x232a - || // RIGHT-POINTING ANGLE BRACKET - // CJK ... Yi except IDEOGRAPHIC HALF FILL SPACE: - (cp >= 0x2e80 && cp <= 0xa4cf - && cp != 0x303f) - || (cp >= 0xac00 && cp <= 0xd7a3) - || // Hangul Syllables - (cp >= 0xf900 && cp <= 0xfaff) - || // CJK Compatibility Ideographs - (cp >= 0xfe10 && cp <= 0xfe19) - || // Vertical Forms - (cp >= 0xfe30 && cp <= 0xfe6f) - || // CJK Compatibility Forms - (cp >= 0xff00 && cp <= 0xff60) - || // Fullwidth Forms - (cp >= 0xffe0 && cp <= 0xffe6) - || // Fullwidth Forms - (cp >= 0x20000 && cp <= 0x2fffd) || // CJK - (cp >= 0x30000 && cp <= 0x3fffd) || - // Miscellaneous Symbols and Pictographs + Emoticons: - (cp >= 0x1f300 && cp <= 0x1f64f) || - // Supplemental Symbols and Pictographs: - (cp >= 0x1f900 && cp <= 0x1f9ff)))); - return true; - } - }; - // We could avoid branches by using utf8_decode directly. - for_each_codepoint(s, count_code_points {&num_code_points}); - return num_code_points; -} - -inline auto compute_width(basic_string_view s) -> size_t { - return compute_width( - string_view(reinterpret_cast(s.data()), s.size())); -} - -template -inline auto code_point_index(basic_string_view s, size_t n) -> size_t { - size_t size = s.size(); - return n < size ? n : size; -} - -// Calculates the index of the nth code point in a UTF-8 string. 
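UTF-8 makes that calculation cheap even without full decoding: continuation bytes always match the bit pattern 10xxxxxx, so every byte that does not match starts a new code point. A self-contained sketch with an assumed name:

    #include <cstddef>
    #include <string>

    // Byte index of the nth (0-based) code point; clamps to s.size() when
    // the string holds fewer than n + 1 code points.
    size_t nth_codepoint_index(const std::string &s, size_t n) {
        size_t count = 0;
        for (size_t i = 0; i < s.size(); ++i) {
            if ((static_cast<unsigned char>(s[i]) & 0xC0) != 0x80) {
                if (count == n) return i;  // s[i] starts code point #count
                ++count;
            }
        }
        return s.size();
    }

fmt's version below reuses `for_each_codepoint` instead, which also validates the encoding as it walks.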
-inline auto code_point_index(string_view s, size_t n) -> size_t { - size_t result = s.size(); - const char *begin = s.begin(); - for_each_codepoint(s, [begin, &n, &result](uint32_t, string_view sv) { - if (n != 0) { - --n; - return true; - } - result = to_unsigned(sv.begin() - begin); - return false; - }); - return result; -} - -inline auto code_point_index(basic_string_view s, size_t n) - -> size_t { - return code_point_index( - string_view(reinterpret_cast(s.data()), s.size()), n); -} - -template -struct is_integral : std::is_integral {}; -template <> -struct is_integral : std::true_type {}; -template <> -struct is_integral : std::true_type {}; - -template -using is_signed = std::integral_constant::is_signed - || std::is_same::value>; - -template -using is_integer = bool_constant::value - && !std::is_same::value && !std::is_same::value - && !std::is_same::value>; - -#ifndef FMT_USE_FLOAT -#define FMT_USE_FLOAT 1 -#endif -#ifndef FMT_USE_DOUBLE -#define FMT_USE_DOUBLE 1 -#endif -#ifndef FMT_USE_LONG_DOUBLE -#define FMT_USE_LONG_DOUBLE 1 -#endif - -#ifndef FMT_USE_FLOAT128 -#ifdef __clang__ -// Clang emulates GCC, so it has to appear early. -#if FMT_HAS_INCLUDE() -#define FMT_USE_FLOAT128 1 -#endif -#elif defined(__GNUC__) -// GNU C++: -#if defined(_GLIBCXX_USE_FLOAT128) && !defined(__STRICT_ANSI__) -#define FMT_USE_FLOAT128 1 -#endif -#endif -#ifndef FMT_USE_FLOAT128 -#define FMT_USE_FLOAT128 0 -#endif -#endif - -#if FMT_USE_FLOAT128 -using float128 = __float128; -#else -using float128 = void; -#endif -template -using is_float128 = std::is_same; - -template -using is_floating_point = bool_constant::value - || is_float128::value>; - -template ::value> -struct is_fast_float : bool_constant::is_iec559 - && sizeof(T) <= sizeof(double)> {}; -template -struct is_fast_float : std::false_type {}; - -template -using is_double_double = bool_constant::digits == 106>; - -#ifndef FMT_USE_FULL_CACHE_DRAGONBOX -#define FMT_USE_FULL_CACHE_DRAGONBOX 0 -#endif - -template -template -void buffer::append(const U *begin, const U *end) { - while (begin != end) { - auto count = to_unsigned(end - begin); - try_reserve(size_ + count); - auto free_cap = capacity_ - size_; - if (free_cap < count) count = free_cap; - std::uninitialized_copy_n(begin, count, ptr_ + size_); - size_ += count; - begin += count; - } -} - -template -struct is_locale : std::false_type {}; -template -struct is_locale> : std::true_type {}; -} // namespace detail - -FMT_BEGIN_EXPORT - -// The number of characters to store in the basic_memory_buffer object itself -// to avoid dynamic memory allocation. -enum { inline_buffer_size = 500 }; - -/** - \rst - A dynamically growing memory buffer for trivially copyable/constructible types - with the first ``SIZE`` elements stored in the object itself. - - You can use the ``memory_buffer`` type alias for ``char`` instead. - - **Example**:: - - auto out = fmt::memory_buffer(); - fmt::format_to(std::back_inserter(out), "The answer is {}.", 42); - - This will append the following output to the ``out`` object: - - .. code-block:: none - - The answer is 42. - - The output can be converted to an ``std::string`` with ``to_string(out)``. - \endrst - */ -template > -class basic_memory_buffer final : public detail::buffer { -private: - T store_[SIZE]; - - // Don't inherit from Allocator to avoid generating type_info for it. - FMT_NO_UNIQUE_ADDRESS Allocator alloc_; - - // Deallocate memory allocated by the buffer. 
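The next few members implement the ownership rule behind that comment: storage starts at the inline array and is released only after `grow()` has moved it to the heap, which a single pointer comparison decides. A reduced sketch of the same rule, with assumed names (the real `grow()` below adds a 1.5x growth policy, allocator support, and fuzzing limits):

    #include <cstddef>

    template <typename T, size_t N>
    class sbo_buffer {
        T store_[N];  // inline storage: no allocation while small
        T *data_ = store_;
        size_t size_ = 0, capacity_ = N;

    public:
        ~sbo_buffer() {
            if (data_ != store_) delete[] data_;  // heap only after growth
        }
        void push_back(const T &v) {
            if (size_ == capacity_) grow(2 * capacity_);
            data_[size_++] = v;
        }

    private:
        void grow(size_t new_cap) {
            T *p = new T[new_cap];
            for (size_t i = 0; i < size_; ++i) p[i] = data_[i];
            if (data_ != store_) delete[] data_;
            data_ = p;
            capacity_ = new_cap;
        }
    };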
- FMT_CONSTEXPR20 void deallocate() { - T *data = this->data(); - if (data != store_) alloc_.deallocate(data, this->capacity()); - } - -protected: - FMT_CONSTEXPR20 void grow(size_t size) override { - detail::abort_fuzzing_if(size > 5000); - const size_t max_size - = std::allocator_traits::max_size(alloc_); - size_t old_capacity = this->capacity(); - size_t new_capacity = old_capacity + old_capacity / 2; - if (size > new_capacity) - new_capacity = size; - else if (new_capacity > max_size) - new_capacity = size > max_size ? size : max_size; - T *old_data = this->data(); - T *new_data = std::allocator_traits::allocate( - alloc_, new_capacity); - // Suppress a bogus -Wstringop-overflow in gcc 13.1 (#3481). - detail::assume(this->size() <= new_capacity); - // The following code doesn't throw, so the raw pointer above doesn't leak. - std::uninitialized_copy_n(old_data, this->size(), new_data); - this->set(new_data, new_capacity); - // deallocate must not throw according to the standard, but even if it does, - // the buffer already uses the new storage and will deallocate it in - // destructor. - if (old_data != store_) alloc_.deallocate(old_data, old_capacity); - } - -public: - using value_type = T; - using const_reference = const T &; - - FMT_CONSTEXPR20 explicit basic_memory_buffer( - const Allocator &alloc = Allocator()) - : alloc_(alloc) { - this->set(store_, SIZE); - if (detail::is_constant_evaluated()) detail::fill_n(store_, SIZE, T()); - } - FMT_CONSTEXPR20 ~basic_memory_buffer() { deallocate(); } - -private: - // Move data from other to this buffer. - FMT_CONSTEXPR20 void move(basic_memory_buffer &other) { - alloc_ = std::move(other.alloc_); - T *data = other.data(); - size_t size = other.size(), capacity = other.capacity(); - if (data == other.store_) { - this->set(store_, capacity); - detail::copy_str(other.store_, other.store_ + size, store_); - } else { - this->set(data, capacity); - // Set pointer to the inline array so that delete is not called - // when deallocating. - other.set(other.store_, 0); - other.clear(); - } - this->resize(size); - } - -public: - /** - \rst - Constructs a :class:`fmt::basic_memory_buffer` object moving the content - of the other object to it. - \endrst - */ - FMT_CONSTEXPR20 basic_memory_buffer(basic_memory_buffer &&other) noexcept { - move(other); - } - - /** - \rst - Moves the content of the other ``basic_memory_buffer`` object to this one. - \endrst - */ - auto operator=(basic_memory_buffer &&other) noexcept - -> basic_memory_buffer & { - FMT_ASSERT(this != &other, ""); - deallocate(); - move(other); - return *this; - } - - // Returns a copy of the allocator associated with this buffer. - auto get_allocator() const -> Allocator { return alloc_; } - - /** - Resizes the buffer to contain *count* elements. If T is a POD type new - elements may not be initialized. - */ - FMT_CONSTEXPR20 void resize(size_t count) { this->try_resize(count); } - - /** Increases the buffer capacity to *new_capacity*. 
*/ - void reserve(size_t new_capacity) { this->try_reserve(new_capacity); } - - using detail::buffer::append; - template - void append(const ContiguousRange &range) { - append(range.data(), range.data() + range.size()); - } -}; - -using memory_buffer = basic_memory_buffer; - -template -struct is_contiguous> : std::true_type { -}; - -FMT_END_EXPORT -namespace detail { -FMT_API auto write_console(int fd, string_view text) -> bool; -FMT_API auto write_console(std::FILE *f, string_view text) -> bool; -FMT_API void print(std::FILE *, string_view); -} // namespace detail - -FMT_BEGIN_EXPORT - -// Suppress a misleading warning in older versions of clang. -#if FMT_CLANG_VERSION -#pragma clang diagnostic ignored "-Wweak-vtables" -#endif - -/** An error reported from a formatting function. */ -class FMT_SO_VISIBILITY("default") format_error : public std::runtime_error { -public: - using std::runtime_error::runtime_error; -}; - -namespace detail_exported { -#if FMT_USE_NONTYPE_TEMPLATE_ARGS -template -struct fixed_string { - constexpr fixed_string(const Char (&str)[N]) { - detail::copy_str( - static_cast(str), str + N, data); - } - Char data[N] = {}; -}; -#endif - -// Converts a compile-time string to basic_string_view. -template -constexpr auto compile_string_to_view(const Char (&s)[N]) - -> basic_string_view { - // Remove trailing NUL character if needed. Won't be present if this is used - // with a raw character array (i.e. not defined as a string). - return {s, - N - (std::char_traits::to_int_type(s[N - 1]) == 0 ? 1 : 0)}; -} -template -constexpr auto compile_string_to_view(detail::std_string_view s) - -> basic_string_view { - return {s.data(), s.size()}; -} -} // namespace detail_exported - -class loc_value { -private: - basic_format_arg value_; - -public: - template ::value)> - loc_value(T value) : value_(detail::make_arg(value)) {} - - template ::value)> - loc_value(T) {} - - template - auto visit(Visitor &&vis) -> decltype(vis(0)) { - return visit_format_arg(vis, value_); - } -}; - -// A locale facet that formats values in UTF-8. -// It is parameterized on the locale to avoid the heavy include. -template -class format_facet : public Locale::facet { -private: - std::string separator_; - std::string grouping_; - std::string decimal_point_; - -protected: - virtual auto do_put(appender out, loc_value val, - const format_specs<> &specs) const -> bool; - -public: - static FMT_API typename Locale::id id; - - explicit format_facet(Locale &loc); - explicit format_facet(string_view sep = "", - std::initializer_list g = {3}, - std::string decimal_point = ".") - : separator_(sep.data(), sep.size()) - , grouping_(g.begin(), g.end()) - , decimal_point_(decimal_point) {} - - auto put(appender out, loc_value val, const format_specs<> &specs) const - -> bool { - return do_put(out, val, specs); - } -}; - -namespace detail { - -// Returns true if value is negative, false otherwise. -// Same as `value < 0` but doesn't produce warnings if T is an unsigned type. -template ::value)> -constexpr auto is_negative(T value) -> bool { - return value < 0; -} -template ::value)> -constexpr auto is_negative(T) -> bool { - return false; -} - -template -FMT_CONSTEXPR auto is_supported_floating_point(T) -> bool { - if (std::is_same()) return FMT_USE_FLOAT; - if (std::is_same()) return FMT_USE_DOUBLE; - if (std::is_same()) return FMT_USE_LONG_DOUBLE; - return true; -} - -// Smallest of uint32_t, uint64_t, uint128_t that is large enough to -// represent all values of an integral type T. 
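The same selection can be expressed with the standard library; a reduced two-way sketch with an assumed name (the alias below adds the 128-bit tier and the FMT_REDUCE_INT_INSTANTIATIONS knob):

    #include <cstdint>
    #include <limits>
    #include <type_traits>

    // Narrowest unsigned carrier wide enough for all values of T, so one
    // formatting-routine instantiation can serve several integer types.
    template <typename T>
    using uint32_or_64_t =
        std::conditional_t<std::numeric_limits<T>::digits <= 32,
                           uint32_t, uint64_t>;

    static_assert(std::is_same<uint32_or_64_t<short>, uint32_t>::value, "");
    static_assert(std::is_same<uint32_or_64_t<long long>, uint64_t>::value, "");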
-template -using uint32_or_64_or_128_t = conditional_t() <= 32 - && !FMT_REDUCE_INT_INSTANTIATIONS, - uint32_t, conditional_t() <= 64, uint64_t, uint128_t>>; -template -using uint64_or_128_t = conditional_t() <= 64, uint64_t, uint128_t>; - -#define FMT_POWERS_OF_10(factor) \ - factor * 10, (factor)*100, (factor)*1000, (factor)*10000, (factor)*100000, \ - (factor)*1000000, (factor)*10000000, (factor)*100000000, \ - (factor)*1000000000 - -// Converts value in the range [0, 100) to a string. -constexpr auto digits2(size_t value) -> const char * { - // GCC generates slightly better code when value is pointer-size. - return &"0001020304050607080910111213141516171819" - "2021222324252627282930313233343536373839" - "4041424344454647484950515253545556575859" - "6061626364656667686970717273747576777879" - "8081828384858687888990919293949596979899"[value * 2]; -} - -// Sign is a template parameter to workaround a bug in gcc 4.8. -template -constexpr auto sign(Sign s) -> Char { -#if !FMT_GCC_VERSION || FMT_GCC_VERSION >= 604 - static_assert(std::is_same::value, ""); -#endif - return static_cast("\0-+ "[s]); -} - -template -FMT_CONSTEXPR auto count_digits_fallback(T n) -> int { - int count = 1; - for (;;) { - // Integer division is slow so do it for a group of four digits instead - // of for every digit. The idea comes from the talk by Alexandrescu - // "Three Optimization Tips for C++". See speed-test for a comparison. - if (n < 10) return count; - if (n < 100) return count + 1; - if (n < 1000) return count + 2; - if (n < 10000) return count + 3; - n /= 10000u; - count += 4; - } -} -#if FMT_USE_INT128 -FMT_CONSTEXPR inline auto count_digits(uint128_opt n) -> int { - return count_digits_fallback(n); -} -#endif - -#ifdef FMT_BUILTIN_CLZLL -// It is a separate function rather than a part of count_digits to workaround -// the lack of static constexpr in constexpr functions. -inline auto do_count_digits(uint64_t n) -> int { - // This has comparable performance to the version by Kendall Willets - // (https://github.com/fmtlib/format-benchmark/blob/master/digits10) - // but uses smaller tables. - // Maps bsr(n) to ceil(log10(pow(2, bsr(n) + 1) - 1)). - static constexpr uint8_t bsr2log10[] = {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, - 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 10, - 11, 11, 11, 12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, - 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 19, 20}; - auto t = bsr2log10[FMT_BUILTIN_CLZLL(n | 1) ^ 63]; - static constexpr const uint64_t zero_or_powers_of_10[] - = {0, 0, FMT_POWERS_OF_10(1U), FMT_POWERS_OF_10(1000000000ULL), - 10000000000000000000ULL}; - return t - (n < zero_or_powers_of_10[t]); -} -#endif - -// Returns the number of decimal digits in n. Leading zeros are not counted -// except for n == 0 in which case count_digits returns 1. -FMT_CONSTEXPR20 inline auto count_digits(uint64_t n) -> int { -#ifdef FMT_BUILTIN_CLZLL - if (!is_constant_evaluated()) { return do_count_digits(n); } -#endif - return count_digits_fallback(n); -} - -// Counts the number of digits in n. BITS = log2(radix). -template -FMT_CONSTEXPR auto count_digits(UInt n) -> int { -#ifdef FMT_BUILTIN_CLZ - if (!is_constant_evaluated() && num_bits() == 32) - return (FMT_BUILTIN_CLZ(static_cast(n) | 1) ^ 31) / BITS + 1; -#endif - // Lambda avoids unreachable code warnings from NVHPC. 
- return [](UInt m) { - int num_digits = 0; - do { - ++num_digits; - } while ((m >>= BITS) != 0); - return num_digits; - }(n); -} - -#ifdef FMT_BUILTIN_CLZ -// It is a separate function rather than a part of count_digits to workaround -// the lack of static constexpr in constexpr functions. -FMT_INLINE auto do_count_digits(uint32_t n) -> int { -// An optimization by Kendall Willets from https://bit.ly/3uOIQrB. -// This increments the upper 32 bits (log10(T) - 1) when >= T is added. -#define FMT_INC(T) (((sizeof(#T) - 1ull) << 32) - T) - static constexpr uint64_t table[] = { - FMT_INC(0), FMT_INC(0), FMT_INC(0), // 8 - FMT_INC(10), FMT_INC(10), FMT_INC(10), // 64 - FMT_INC(100), FMT_INC(100), FMT_INC(100), // 512 - FMT_INC(1000), FMT_INC(1000), FMT_INC(1000), // 4096 - FMT_INC(10000), FMT_INC(10000), FMT_INC(10000), // 32k - FMT_INC(100000), FMT_INC(100000), FMT_INC(100000), // 256k - FMT_INC(1000000), FMT_INC(1000000), FMT_INC(1000000), // 2048k - FMT_INC(10000000), FMT_INC(10000000), FMT_INC(10000000), // 16M - FMT_INC(100000000), FMT_INC(100000000), FMT_INC(100000000), // 128M - FMT_INC(1000000000), FMT_INC(1000000000), - FMT_INC(1000000000), // 1024M - FMT_INC(1000000000), FMT_INC(1000000000) // 4B - }; - auto inc = table[FMT_BUILTIN_CLZ(n | 1) ^ 31]; - return static_cast((n + inc) >> 32); -} -#endif - -// Optional version of count_digits for better performance on 32-bit platforms. -FMT_CONSTEXPR20 inline auto count_digits(uint32_t n) -> int { -#ifdef FMT_BUILTIN_CLZ - if (!is_constant_evaluated()) { return do_count_digits(n); } -#endif - return count_digits_fallback(n); -} - -template -constexpr auto digits10() noexcept -> int { - return std::numeric_limits::digits10; -} -template <> -constexpr auto digits10() noexcept -> int { - return 38; -} -template <> -constexpr auto digits10() noexcept -> int { - return 38; -} - -template -struct thousands_sep_result { - std::string grouping; - Char thousands_sep; -}; - -template -FMT_API auto thousands_sep_impl(locale_ref loc) -> thousands_sep_result; -template -inline auto thousands_sep(locale_ref loc) -> thousands_sep_result { - auto result = thousands_sep_impl(loc); - return {result.grouping, Char(result.thousands_sep)}; -} -template <> -inline auto thousands_sep(locale_ref loc) -> thousands_sep_result { - return thousands_sep_impl(loc); -} - -template -FMT_API auto decimal_point_impl(locale_ref loc) -> Char; -template -inline auto decimal_point(locale_ref loc) -> Char { - return Char(decimal_point_impl(loc)); -} -template <> -inline auto decimal_point(locale_ref loc) -> wchar_t { - return decimal_point_impl(loc); -} - -// Compares two characters for equality. -template -auto equal2(const Char *lhs, const char *rhs) -> bool { - return lhs[0] == Char(rhs[0]) && lhs[1] == Char(rhs[1]); -} -inline auto equal2(const char *lhs, const char *rhs) -> bool { - return memcmp(lhs, rhs, 2) == 0; -} - -// Copies two characters from src to dst. -template -FMT_CONSTEXPR20 FMT_INLINE void copy2(Char *dst, const char *src) { - if (!is_constant_evaluated() && sizeof(Char) == sizeof(char)) { - memcpy(dst, src, 2); - return; - } - *dst++ = static_cast(*src++); - *dst = static_cast(*src); -} - -template -struct format_decimal_result { - Iterator begin; - Iterator end; -}; - -// Formats a decimal unsigned integer value writing into out pointing to a -// buffer of specified size. The caller must ensure that the buffer is large -// enough. 
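That routine is where the `digits2()` table from earlier pays off: each integer division yields two output characters, filled right to left. An illustrative standalone version with an assumed name, not fmt's API:

    #include <cstdint>
    #include <cstring>
    #include <string>

    std::string u32_to_string(uint32_t value) {
        static const char table[] = "0001020304050607080910111213141516171819"
                                    "2021222324252627282930313233343536373839"
                                    "4041424344454647484950515253545556575859"
                                    "6061626364656667686970717273747576777879"
                                    "8081828384858687888990919293949596979899";
        char buf[10];                 // uint32_t has at most 10 decimal digits
        char *p = buf + sizeof(buf);  // fill right to left
        while (value >= 100) {
            p -= 2;
            std::memcpy(p, &table[(value % 100) * 2], 2);
            value /= 100;
        }
        if (value >= 10) {
            p -= 2;
            std::memcpy(p, &table[value * 2], 2);
        } else {
            *--p = static_cast<char>('0' + value);
        }
        return std::string(p, buf + sizeof(buf));
    }

`u32_to_string(4294967295)` emits all ten digits with roughly half the divisions of a digit-at-a-time loop.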
-template -FMT_CONSTEXPR20 auto format_decimal(Char *out, UInt value, int size) - -> format_decimal_result { - FMT_ASSERT(size >= count_digits(value), "invalid digit count"); - out += size; - Char *end = out; - while (value >= 100) { - // Integer division is slow so do it for a group of two digits instead - // of for every digit. The idea comes from the talk by Alexandrescu - // "Three Optimization Tips for C++". See speed-test for a comparison. - out -= 2; - copy2(out, digits2(static_cast(value % 100))); - value /= 100; - } - if (value < 10) { - *--out = static_cast('0' + value); - return {out, end}; - } - out -= 2; - copy2(out, digits2(static_cast(value))); - return {out, end}; -} - -template >::value)> -FMT_CONSTEXPR inline auto format_decimal(Iterator out, UInt value, int size) - -> format_decimal_result { - // Buffer is large enough to hold all digits (digits10 + 1). - Char buffer[digits10() + 1] = {}; - auto end = format_decimal(buffer, value, size).end; - return {out, detail::copy_str_noinline(buffer, end, out)}; -} - -template -FMT_CONSTEXPR auto format_uint(Char *buffer, UInt value, int num_digits, - bool upper = false) -> Char * { - buffer += num_digits; - Char *end = buffer; - do { - const char *digits = upper ? "0123456789ABCDEF" : "0123456789abcdef"; - unsigned digit = static_cast(value & ((1 << BASE_BITS) - 1)); - *--buffer = static_cast( - BASE_BITS < 4 ? static_cast('0' + digit) : digits[digit]); - } while ((value >>= BASE_BITS) != 0); - return end; -} - -template -FMT_CONSTEXPR inline auto format_uint( - It out, UInt value, int num_digits, bool upper = false) -> It { - if (auto ptr = to_pointer(out, to_unsigned(num_digits))) { - format_uint(ptr, value, num_digits, upper); - return out; - } - // Buffer should be large enough to hold all digits (digits / BASE_BITS + 1). - char buffer[num_bits() / BASE_BITS + 1] = {}; - format_uint(buffer, value, num_digits, upper); - return detail::copy_str_noinline(buffer, buffer + num_digits, out); -} - -// A converter from UTF-8 to UTF-16. -class utf8_to_utf16 { -private: - basic_memory_buffer buffer_; - -public: - FMT_API explicit utf8_to_utf16(string_view s); - operator basic_string_view() const { - return {&buffer_[0], size()}; - } - auto size() const -> size_t { return buffer_.size() - 1; } - auto c_str() const -> const wchar_t * { return &buffer_[0]; } - auto str() const -> std::wstring { return {&buffer_[0], size()}; } -}; - -enum class to_utf8_error_policy { abort, replace }; - -// A converter from UTF-16/UTF-32 (host endian) to UTF-8. -template -class to_utf8 { -private: - Buffer buffer_; - -public: - to_utf8() {} - explicit to_utf8(basic_string_view s, - to_utf8_error_policy policy = to_utf8_error_policy::abort) { - static_assert(sizeof(WChar) == 2 || sizeof(WChar) == 4, - "Expect utf16 or utf32"); - if (!convert(s, policy)) - FMT_THROW(std::runtime_error( - sizeof(WChar) == 2 ? "invalid utf16" : "invalid utf32")); - } - operator string_view() const { return string_view(&buffer_[0], size()); } - auto size() const -> size_t { return buffer_.size() - 1; } - auto c_str() const -> const char * { return &buffer_[0]; } - auto str() const -> std::string { return std::string(&buffer_[0], size()); } - - // Performs conversion returning a bool instead of throwing exception on - // conversion error. This method may still throw in case of memory allocation - // error. 
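Two details of `convert()` below are worth unpacking. First, the constant 0x35fdc00 it subtracts is ((0xd800 << 10) + 0xdc00 - 0x10000) folded together, so `c = (c << 10) + *p - 0x35fdc00` recovers the scalar value from a surrogate pair in one step. Second, the byte-emission cascade is ordinary UTF-8 encoding; a standalone sketch with an assumed name:

    #include <cstdint>
    #include <string>

    // Append the UTF-8 bytes for one scalar value; returns false for values
    // UTF-8 cannot represent (lone surrogates, anything past U+10FFFF).
    bool append_utf8(std::string &out, uint32_t c) {
        if (c < 0x80) {
            out += static_cast<char>(c);
        } else if (c < 0x800) {
            out += static_cast<char>(0xc0 | (c >> 6));
            out += static_cast<char>(0x80 | (c & 0x3f));
        } else if (c < 0x10000) {
            if (c >= 0xd800 && c <= 0xdfff) return false;
            out += static_cast<char>(0xe0 | (c >> 12));
            out += static_cast<char>(0x80 | ((c >> 6) & 0x3f));
            out += static_cast<char>(0x80 | (c & 0x3f));
        } else if (c <= 0x10ffff) {
            out += static_cast<char>(0xf0 | (c >> 18));
            out += static_cast<char>(0x80 | ((c >> 12) & 0x3f));
            out += static_cast<char>(0x80 | ((c >> 6) & 0x3f));
            out += static_cast<char>(0x80 | (c & 0x3f));
        } else {
            return false;
        }
        return true;
    }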
- auto convert(basic_string_view s, - to_utf8_error_policy policy = to_utf8_error_policy::abort) -> bool { - if (!convert(buffer_, s, policy)) return false; - buffer_.push_back(0); - return true; - } - static auto convert(Buffer &buf, basic_string_view s, - to_utf8_error_policy policy = to_utf8_error_policy::abort) -> bool { - for (auto p = s.begin(); p != s.end(); ++p) { - uint32_t c = static_cast(*p); - if (sizeof(WChar) == 2 && c >= 0xd800 && c <= 0xdfff) { - // Handle a surrogate pair. - ++p; - if (p == s.end() || (c & 0xfc00) != 0xd800 - || (*p & 0xfc00) != 0xdc00) { - if (policy == to_utf8_error_policy::abort) return false; - buf.append(string_view("\xEF\xBF\xBD")); - --p; - } else { - c = (c << 10) + static_cast(*p) - 0x35fdc00; - } - } else if (c < 0x80) { - buf.push_back(static_cast(c)); - } else if (c < 0x800) { - buf.push_back(static_cast(0xc0 | (c >> 6))); - buf.push_back(static_cast(0x80 | (c & 0x3f))); - } else if ((c >= 0x800 && c <= 0xd7ff) - || (c >= 0xe000 && c <= 0xffff)) { - buf.push_back(static_cast(0xe0 | (c >> 12))); - buf.push_back(static_cast(0x80 | ((c & 0xfff) >> 6))); - buf.push_back(static_cast(0x80 | (c & 0x3f))); - } else if (c >= 0x10000 && c <= 0x10ffff) { - buf.push_back(static_cast(0xf0 | (c >> 18))); - buf.push_back(static_cast(0x80 | ((c & 0x3ffff) >> 12))); - buf.push_back(static_cast(0x80 | ((c & 0xfff) >> 6))); - buf.push_back(static_cast(0x80 | (c & 0x3f))); - } else { - return false; - } - } - return true; - } -}; - -// Computes 128-bit result of multiplication of two 64-bit unsigned integers. -inline auto umul128(uint64_t x, uint64_t y) noexcept -> uint128_fallback { -#if FMT_USE_INT128 - auto p = static_cast(x) * static_cast(y); - return {static_cast(p >> 64), static_cast(p)}; -#elif defined(_MSC_VER) && defined(_M_X64) - auto hi = uint64_t(); - auto lo = _umul128(x, y, &hi); - return {hi, lo}; -#else - const uint64_t mask = static_cast(max_value()); - - uint64_t a = x >> 32; - uint64_t b = x & mask; - uint64_t c = y >> 32; - uint64_t d = y & mask; - - uint64_t ac = a * c; - uint64_t bc = b * c; - uint64_t ad = a * d; - uint64_t bd = b * d; - - uint64_t intermediate = (bd >> 32) + (ad & mask) + (bc & mask); - - return {ac + (intermediate >> 32) + (ad >> 32) + (bc >> 32), - (intermediate << 32) + (bd & mask)}; -#endif -} - -namespace dragonbox { -// Computes floor(log10(pow(2, e))) for e in [-2620, 2620] using the method from -// https://fmt.dev/papers/Dragonbox.pdf#page=28, section 6.1. -inline auto floor_log10_pow2(int e) noexcept -> int { - FMT_ASSERT(e <= 2620 && e >= -2620, "too large exponent"); - static_assert((-1 >> 1) == -1, "right shift is not arithmetic"); - return (e * 315653) >> 20; -} - -inline auto floor_log2_pow10(int e) noexcept -> int { - FMT_ASSERT(e <= 1233 && e >= -1233, "too large exponent"); - return (e * 1741647) >> 19; -} - -// Computes upper 64 bits of multiplication of two 64-bit unsigned integers. -inline auto umul128_upper64(uint64_t x, uint64_t y) noexcept -> uint64_t { -#if FMT_USE_INT128 - auto p = static_cast(x) * static_cast(y); - return static_cast(p >> 64); -#elif defined(_MSC_VER) && defined(_M_X64) - return __umulh(x, y); -#else - return umul128(x, y).high(); -#endif -} - -// Computes upper 128 bits of multiplication of a 64-bit unsigned integer and a -// 128-bit unsigned integer. 
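Before that helper, the two shift-based logarithms defined a few lines up deserve a note: 315653 / 2^20 ≈ 0.301030 and 1741647 / 2^19 ≈ 3.321929 are fixed-point approximations of log10(2) and log2(10), chosen so the truncated product matches the exact floor over the asserted exponent ranges. An illustrative self-check, not part of fmt (like the original, it relies on arithmetic right shift for negative e, which the original guards with its own static_assert):

    #include <cassert>
    #include <cmath>

    int main() {
        for (int e = -2620; e <= 2620; ++e) {
            int fast = (e * 315653) >> 20;  // floor_log10_pow2(e)
            int slow = static_cast<int>(std::floor(e * std::log10(2.0)));
            assert(fast == slow);
        }
        return 0;
    }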
-inline auto umul192_upper128(uint64_t x, uint128_fallback y) noexcept - -> uint128_fallback { - uint128_fallback r = umul128(x, y.high()); - r += umul128_upper64(x, y.low()); - return r; -} - -FMT_API auto get_cached_power(int k) noexcept -> uint128_fallback; - -// Type-specific information that Dragonbox uses. -template -struct float_info; - -template <> -struct float_info { - using carrier_uint = uint32_t; - static const int exponent_bits = 8; - static const int kappa = 1; - static const int big_divisor = 100; - static const int small_divisor = 10; - static const int min_k = -31; - static const int max_k = 46; - static const int shorter_interval_tie_lower_threshold = -35; - static const int shorter_interval_tie_upper_threshold = -35; -}; - -template <> -struct float_info { - using carrier_uint = uint64_t; - static const int exponent_bits = 11; - static const int kappa = 2; - static const int big_divisor = 1000; - static const int small_divisor = 100; - static const int min_k = -292; - static const int max_k = 341; - static const int shorter_interval_tie_lower_threshold = -77; - static const int shorter_interval_tie_upper_threshold = -77; -}; - -// An 80- or 128-bit floating point number. -template -struct float_info::digits == 64 - || std::numeric_limits::digits == 113 - || is_float128::value>> { - using carrier_uint = detail::uint128_t; - static const int exponent_bits = 15; -}; - -// A double-double floating point number. -template -struct float_info::value>> { - using carrier_uint = detail::uint128_t; -}; - -template -struct decimal_fp { - using significand_type = typename float_info::carrier_uint; - significand_type significand; - int exponent; -}; - -template -FMT_API auto to_decimal(T x) noexcept -> decimal_fp; -} // namespace dragonbox - -// Returns true iff Float has the implicit bit which is not stored. -template -constexpr auto has_implicit_bit() -> bool { - // An 80-bit FP number has a 64-bit significand an no implicit bit. - return std::numeric_limits::digits != 64; -} - -// Returns the number of significand bits stored in Float. The implicit bit is -// not counted since it is not stored. -template -constexpr auto num_significand_bits() -> int { - // std::numeric_limits may not support __float128. - return is_float128() ? 112 - : (std::numeric_limits::digits - - (has_implicit_bit() ? 1 : 0)); -} - -template -constexpr auto exponent_mask() -> - typename dragonbox::float_info::carrier_uint { - using float_uint = typename dragonbox::float_info::carrier_uint; - return ((float_uint(1) << dragonbox::float_info::exponent_bits) - 1) - << num_significand_bits(); -} -template -constexpr auto exponent_bias() -> int { - // std::numeric_limits may not support __float128. - return is_float128() ? 16383 - : std::numeric_limits::max_exponent - 1; -} - -// Writes the exponent exp in the form "[+-]d{2,3}" to buffer. -template -FMT_CONSTEXPR auto write_exponent(int exp, It it) -> It { - FMT_ASSERT(-10000 < exp && exp < 10000, "exponent out of range"); - if (exp < 0) { - *it++ = static_cast('-'); - exp = -exp; - } else { - *it++ = static_cast('+'); - } - if (exp >= 100) { - const char *top = digits2(to_unsigned(exp / 100)); - if (exp >= 1000) *it++ = static_cast(top[0]); - *it++ = static_cast(top[1]); - exp %= 100; - } - const char *d = digits2(to_unsigned(exp)); - *it++ = static_cast(d[0]); - *it++ = static_cast(d[1]); - return it; -} - -// A floating-point number f * pow(2, e) where F is an unsigned type. 
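The float_info traits and the exponent_mask/exponent_bias helpers above all describe the same IEEE-754 layout that basic_fp, defined next, unpacks: stored significand in the low bits, biased exponent above it, sign bit on top. Here is that decomposition for double as a standalone sketch (assumes C++20 std::bit_cast; fmt ships its own bit_cast for older standards):

#include <bit>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    double n = 6.5; // 1.101 (binary) * 2^2
    uint64_t u = std::bit_cast<uint64_t>(n);
    const int sig_bits = 52; // stored significand bits; implicit 1 not stored
    uint64_t frac = u & ((uint64_t(1) << sig_bits) - 1);
    int biased_e = int((u >> sig_bits) & 0x7ff); // 11 exponent bits
    int sign = int(u >> 63);
    uint64_t f = frac | (uint64_t(1) << sig_bits); // restore the implicit bit
    int e = biased_e - 1023 - sig_bits;            // so that n == f * 2^e
    std::printf("sign=%d f=0x%llx e=%d reconstructed=%g\n", sign,
            (unsigned long long)f, e, std::ldexp(double(f), e));
}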
-template -struct basic_fp { - F f; - int e; - - static constexpr const int num_significand_bits - = static_cast(sizeof(F) * num_bits()); - - constexpr basic_fp() : f(0), e(0) {} - constexpr basic_fp(uint64_t f_val, int e_val) : f(f_val), e(e_val) {} - - // Constructs fp from an IEEE754 floating-point number. - template - FMT_CONSTEXPR basic_fp(Float n) { - assign(n); - } - - // Assigns n to this and return true iff predecessor is closer than successor. - template ::value)> - FMT_CONSTEXPR auto assign(Float n) -> bool { - static_assert( - std::numeric_limits::digits <= 113, "unsupported FP"); - // Assume Float is in the format [sign][exponent][significand]. - using carrier_uint = - typename dragonbox::float_info::carrier_uint; - const auto num_float_significand_bits - = detail::num_significand_bits(); - const auto implicit_bit = carrier_uint(1) << num_float_significand_bits; - const auto significand_mask = implicit_bit - 1; - auto u = bit_cast(n); - f = static_cast(u & significand_mask); - auto biased_e = static_cast( - (u & exponent_mask()) >> num_float_significand_bits); - // The predecessor is closer if n is a normalized power of 2 (f == 0) - // other than the smallest normalized number (biased_e > 1). - auto is_predecessor_closer = f == 0 && biased_e > 1; - if (biased_e == 0) - biased_e = 1; // Subnormals use biased exponent 1 (min exponent). - else if (has_implicit_bit()) - f += static_cast(implicit_bit); - e = biased_e - exponent_bias() - num_float_significand_bits; - if (!has_implicit_bit()) ++e; - return is_predecessor_closer; - } - - template ::value)> - FMT_CONSTEXPR auto assign(Float n) -> bool { - static_assert(std::numeric_limits::is_iec559, "unsupported FP"); - return assign(static_cast(n)); - } -}; - -using fp = basic_fp; - -// Normalizes the value converted from double and multiplied by (1 << SHIFT). -template -FMT_CONSTEXPR auto normalize(basic_fp value) -> basic_fp { - // Handle subnormals. - const auto implicit_bit = F(1) << num_significand_bits(); - const auto shifted_implicit_bit = implicit_bit << SHIFT; - while ((value.f & shifted_implicit_bit) == 0) { - value.f <<= 1; - --value.e; - } - // Subtract 1 to account for hidden bit. - const auto offset = basic_fp::num_significand_bits - - num_significand_bits() - SHIFT - 1; - value.f <<= offset; - value.e -= offset; - return value; -} - -// Computes lhs * rhs / pow(2, 64) rounded to nearest with half-up tie breaking. -FMT_CONSTEXPR inline auto multiply(uint64_t lhs, uint64_t rhs) -> uint64_t { -#if FMT_USE_INT128 - auto product = static_cast<__uint128_t>(lhs) * rhs; - auto f = static_cast(product >> 64); - return (static_cast(product) & (1ULL << 63)) != 0 ? f + 1 : f; -#else - // Multiply 32-bit parts of significands. - uint64_t mask = (1ULL << 32) - 1; - uint64_t a = lhs >> 32, b = lhs & mask; - uint64_t c = rhs >> 32, d = rhs & mask; - uint64_t ac = a * c, bc = b * c, ad = a * d, bd = b * d; - // Compute mid 64-bit of result and round. 
- uint64_t mid = (bd >> 32) + (ad & mask) + (bc & mask) + (1U << 31); - return ac + (ad >> 32) + (bc >> 32) + (mid >> 32); -#endif -} - -FMT_CONSTEXPR inline auto operator*(fp x, fp y) -> fp { - return {multiply(x.f, y.f), x.e + y.e + 64}; -} - -template () == num_bits()> -using convert_float_result - = conditional_t::value || doublish, double, T>; - -template -constexpr auto convert_float(T value) -> convert_float_result { - return static_cast>(value); -} - -template -FMT_NOINLINE FMT_CONSTEXPR auto fill( - OutputIt it, size_t n, const fill_t &fill) -> OutputIt { - auto fill_size = fill.size(); - if (fill_size == 1) return detail::fill_n(it, n, fill[0]); - auto data = fill.data(); - for (size_t i = 0; i < n; ++i) - it = copy_str(data, data + fill_size, it); - return it; -} - -// Writes the output of f, padded according to format specifications in specs. -// size: output size in code units. -// width: output display width in (terminal) column positions. -template -FMT_CONSTEXPR auto write_padded(OutputIt out, const format_specs &specs, - size_t size, size_t width, F &&f) -> OutputIt { - static_assert(align == align::left || align == align::right, ""); - unsigned spec_width = to_unsigned(specs.width); - size_t padding = spec_width > width ? spec_width - width : 0; - // Shifts are encoded as string literals because static constexpr is not - // supported in constexpr functions. - auto *shifts - = align == align::left ? "\x1f\x1f\x00\x01" : "\x00\x1f\x00\x01"; - size_t left_padding = padding >> shifts[specs.align]; - size_t right_padding = padding - left_padding; - auto it = reserve(out, size + padding * specs.fill.size()); - if (left_padding != 0) it = fill(it, left_padding, specs.fill); - it = f(it); - if (right_padding != 0) it = fill(it, right_padding, specs.fill); - return base_iterator(out, it); -} - -template -constexpr auto write_padded(OutputIt out, const format_specs &specs, - size_t size, F &&f) -> OutputIt { - return write_padded(out, specs, size, size, f); -} - -template -FMT_CONSTEXPR auto write_bytes(OutputIt out, string_view bytes, - const format_specs &specs) -> OutputIt { - return write_padded( - out, specs, bytes.size(), [bytes](reserve_iterator it) { - const char *data = bytes.data(); - return copy_str(data, data + bytes.size(), it); - }); -} - -template -auto write_ptr(OutputIt out, UIntPtr value, const format_specs *specs) - -> OutputIt { - int num_digits = count_digits<4>(value); - auto size = to_unsigned(num_digits) + size_t(2); - auto write = [=](reserve_iterator it) { - *it++ = static_cast('0'); - *it++ = static_cast('x'); - return format_uint<4, Char>(it, value, num_digits); - }; - return specs ? write_padded(out, *specs, size, write) - : base_iterator(out, write(reserve(out, size))); -} - -// Returns true iff the code point cp is printable. 
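The string literal of shift amounts in write_padded above is a compact alignment dispatch: padding >> 31 is always 0 (left alignment puts all padding after the content), padding >> 0 keeps all of it in front (right alignment), and padding >> 1 splits it for centering, with the remainder going after. Assuming fmt's enum order none, left, right, center for specs.align (the second literal covers the right-aligned template default), the mapping looks like this as a standalone sketch (is_printable is declared next):

#include <cstdio>

enum align { none, left, right, center }; // assumed to mirror fmt's order

int main() {
    // One byte per alignment value: the right shift applied to `padding`
    // to obtain the share written before the content.
    const char shifts[] = "\x1f\x1f\x00\x01"; // table for template align==left
    unsigned padding = 7;
    for (int a = none; a <= center; ++a) {
        unsigned before = padding >> shifts[a];
        std::printf("align=%d: %u before, %u after\n", a, before,
                padding - before);
    }
}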
-FMT_API auto is_printable(uint32_t cp) -> bool;
-
-inline auto needs_escape(uint32_t cp) -> bool {
-    return cp < 0x20 || cp == 0x7f || cp == '"' || cp == '\\'
-            || !is_printable(cp);
-}
-
-template <typename Char>
-struct find_escape_result {
-    const Char *begin;
-    const Char *end;
-    uint32_t cp;
-};
-
-template <typename Char>
-using make_unsigned_char = typename conditional_t<std::is_integral<Char>::value,
-        std::make_unsigned<Char>, type_identity<uint32_t>>::type;
-
-template <typename Char>
-auto find_escape(const Char *begin, const Char *end)
-        -> find_escape_result<Char> {
-    for (; begin != end; ++begin) {
-        uint32_t cp = static_cast<make_unsigned_char<Char>>(*begin);
-        if (const_check(sizeof(Char) == 1) && cp >= 0x80) continue;
-        if (needs_escape(cp)) return {begin, begin + 1, cp};
-    }
-    return {begin, nullptr, 0};
-}
-
-inline auto find_escape(const char *begin, const char *end)
-        -> find_escape_result<char> {
-    if (!is_utf8()) return find_escape<char>(begin, end);
-    auto result = find_escape_result<char> {end, nullptr, 0};
-    for_each_codepoint(string_view(begin, to_unsigned(end - begin)),
-            [&](uint32_t cp, string_view sv) {
-                if (needs_escape(cp)) {
-                    result = {sv.begin(), sv.end(), cp};
-                    return false;
-                }
-                return true;
-            });
-    return result;
-}
-
-#define FMT_STRING_IMPL(s, base, explicit) \
-    [] { \
-        /* Use the hidden visibility as a workaround for a GCC bug (#1973). */ \
-        /* Use a macro-like name to avoid shadowing warnings. */ \
-        struct FMT_VISIBILITY("hidden") FMT_COMPILE_STRING : base { \
-            using char_type FMT_MAYBE_UNUSED \
-                    = fmt::remove_cvref_t<decltype(s[0])>; \
-            FMT_MAYBE_UNUSED FMT_CONSTEXPR explicit \
-            operator fmt::basic_string_view<char_type>() const { \
-                return fmt::detail_exported::compile_string_to_view< \
-                        char_type>(s); \
-            } \
-        }; \
-        return FMT_COMPILE_STRING(); \
-    }()
-
-/**
-  \rst
-  Constructs a compile-time format string from a string literal *s*.
-
-  **Example**::
-
-    // A compile-time error because 'd' is an invalid specifier for strings.
- std::string s = fmt::format(FMT_STRING("{:d}"), "foo"); - \endrst - */ -#define FMT_STRING(s) FMT_STRING_IMPL(s, fmt::detail::compile_string, ) - -template -auto write_codepoint(OutputIt out, char prefix, uint32_t cp) -> OutputIt { - *out++ = static_cast('\\'); - *out++ = static_cast(prefix); - Char buf[width]; - fill_n(buf, width, static_cast('0')); - format_uint<4>(buf, cp, width); - return copy_str(buf, buf + width, out); -} - -template -auto write_escaped_cp(OutputIt out, const find_escape_result &escape) - -> OutputIt { - auto c = static_cast(escape.cp); - switch (escape.cp) { - case '\n': - *out++ = static_cast('\\'); - c = static_cast('n'); - break; - case '\r': - *out++ = static_cast('\\'); - c = static_cast('r'); - break; - case '\t': - *out++ = static_cast('\\'); - c = static_cast('t'); - break; - case '"': FMT_FALLTHROUGH; - case '\'': FMT_FALLTHROUGH; - case '\\': *out++ = static_cast('\\'); break; - default: - if (escape.cp < 0x100) { - return write_codepoint<2, Char>(out, 'x', escape.cp); - } - if (escape.cp < 0x10000) { - return write_codepoint<4, Char>(out, 'u', escape.cp); - } - if (escape.cp < 0x110000) { - return write_codepoint<8, Char>(out, 'U', escape.cp); - } - for (Char escape_char : basic_string_view(escape.begin, - to_unsigned(escape.end - escape.begin))) { - out = write_codepoint<2, Char>( - out, 'x', static_cast(escape_char) & 0xFF); - } - return out; - } - *out++ = c; - return out; -} - -template -auto write_escaped_string(OutputIt out, basic_string_view str) - -> OutputIt { - *out++ = static_cast('"'); - auto begin = str.begin(), end = str.end(); - do { - auto escape = find_escape(begin, end); - out = copy_str(begin, escape.begin, out); - begin = escape.end; - if (!begin) break; - out = write_escaped_cp(out, escape); - } while (begin != end); - *out++ = static_cast('"'); - return out; -} - -template -auto write_escaped_char(OutputIt out, Char v) -> OutputIt { - Char v_array[1] = {v}; - *out++ = static_cast('\''); - if ((needs_escape(static_cast(v)) && v != static_cast('"')) - || v == static_cast('\'')) { - out = write_escaped_cp(out, - find_escape_result { - v_array, v_array + 1, static_cast(v)}); - } else { - *out++ = v; - } - *out++ = static_cast('\''); - return out; -} - -template -FMT_CONSTEXPR auto write_char( - OutputIt out, Char value, const format_specs &specs) -> OutputIt { - bool is_debug = specs.type == presentation_type::debug; - return write_padded(out, specs, 1, [=](reserve_iterator it) { - if (is_debug) return write_escaped_char(it, value); - *it++ = value; - return it; - }); -} -template -FMT_CONSTEXPR auto write(OutputIt out, Char value, - const format_specs &specs, locale_ref loc = {}) -> OutputIt { - // char is formatted as unsigned char for consistency across platforms. - using unsigned_type = conditional_t::value, - unsigned char, unsigned>; - return check_char_specs(specs) - ? write_char(out, value, specs) - : write(out, static_cast(value), specs, loc); -} - -// Data for write_int that doesn't depend on output iterator type. It is used to -// avoid template code bloat. 
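write_escaped_cp above picks the escape form by codepoint range: \xNN below U+0100, \uNNNN below U+10000, \UNNNNNNNN up to U+10FFFF, and a byte-wise \xNN fallback for invalid input. The same threshold logic as a standalone sketch (write_int_data follows below):

#include <cstdint>
#include <cstdio>
#include <string>

// Chooses the escape width the way the code above does: 2, 4, or 8 hex digits.
static std::string escape_cp(uint32_t cp) {
    char buf[16];
    if (cp < 0x100)
        std::snprintf(buf, sizeof buf, "\\x%02x", (unsigned)cp);
    else if (cp < 0x10000)
        std::snprintf(buf, sizeof buf, "\\u%04x", (unsigned)cp);
    else
        std::snprintf(buf, sizeof buf, "\\U%08x", (unsigned)cp);
    return buf;
}

int main() {
    // prints: \x7f \u2603 \U0001f600
    std::printf("%s %s %s\n", escape_cp(0x7f).c_str(),
            escape_cp(0x2603).c_str(), escape_cp(0x1f600).c_str());
}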
-template -struct write_int_data { - size_t size; - size_t padding; - - FMT_CONSTEXPR write_int_data( - int num_digits, unsigned prefix, const format_specs &specs) - : size((prefix >> 24) + to_unsigned(num_digits)), padding(0) { - if (specs.align == align::numeric) { - auto width = to_unsigned(specs.width); - if (width > size) { - padding = width - size; - size = width; - } - } else if (specs.precision > num_digits) { - size = (prefix >> 24) + to_unsigned(specs.precision); - padding = to_unsigned(specs.precision - num_digits); - } - } -}; - -// Writes an integer in the format -// -// where are written by write_digits(it). -// prefix contains chars in three lower bytes and the size in the fourth byte. -template -FMT_CONSTEXPR FMT_INLINE auto write_int(OutputIt out, int num_digits, - unsigned prefix, const format_specs &specs, W write_digits) - -> OutputIt { - // Slightly faster check for specs.width == 0 && specs.precision == -1. - if ((specs.width | (specs.precision + 1)) == 0) { - auto it = reserve(out, to_unsigned(num_digits) + (prefix >> 24)); - if (prefix != 0) { - for (unsigned p = prefix & 0xffffff; p != 0; p >>= 8) - *it++ = static_cast(p & 0xff); - } - return base_iterator(out, write_digits(it)); - } - auto data = write_int_data(num_digits, prefix, specs); - return write_padded( - out, specs, data.size, [=](reserve_iterator it) { - for (unsigned p = prefix & 0xffffff; p != 0; p >>= 8) - *it++ = static_cast(p & 0xff); - it = detail::fill_n(it, data.padding, static_cast('0')); - return write_digits(it); - }); -} - -template -class digit_grouping { -private: - std::string grouping_; - std::basic_string thousands_sep_; - - struct next_state { - std::string::const_iterator group; - int pos; - }; - auto initial_state() const -> next_state { return {grouping_.begin(), 0}; } - - // Returns the next digit group separator position. - auto next(next_state &state) const -> int { - if (thousands_sep_.empty()) return max_value(); - if (state.group == grouping_.end()) - return state.pos += grouping_.back(); - if (*state.group <= 0 || *state.group == max_value()) - return max_value(); - state.pos += *state.group++; - return state.pos; - } - -public: - explicit digit_grouping(locale_ref loc, bool localized = true) { - if (!localized) return; - auto sep = thousands_sep(loc); - grouping_ = sep.grouping; - if (sep.thousands_sep) thousands_sep_.assign(1, sep.thousands_sep); - } - digit_grouping(std::string grouping, std::basic_string sep) - : grouping_(std::move(grouping)), thousands_sep_(std::move(sep)) {} - - auto has_separator() const -> bool { return !thousands_sep_.empty(); } - - auto count_separators(int num_digits) const -> int { - int count = 0; - auto state = initial_state(); - while (num_digits > next(state)) - ++count; - return count; - } - - // Applies grouping to digits and write the output to out. 
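digit_grouping above follows the POSIX localeconv convention: each byte of the grouping string is a group size counted from the least significant digit, the last size repeats, and a non-positive or CHAR_MAX byte stops further grouping. A standalone sketch of that interpretation, not fmt's exact member functions (digit_grouping::apply, whose comment appears just above, is defined next):

#include <climits>
#include <cstdio>
#include <string>

static std::string group_digits(const std::string &digits,
        const std::string &grouping, char sep) {
    std::string out;
    size_t g = 0;
    int next = grouping.empty() ? INT_MAX : grouping[0];
    if (next <= 0 || next == CHAR_MAX) next = INT_MAX;
    for (int i = int(digits.size()) - 1; i >= 0; --i) {
        out.insert(out.begin(), digits[size_t(i)]);
        if (--next == 0 && i != 0) {
            out.insert(out.begin(), sep);
            if (g + 1 < grouping.size()) ++g; // last group size repeats
            next = grouping[g];
            if (next <= 0 || next == CHAR_MAX) next = INT_MAX;
        }
    }
    return out;
}

int main() {
    // en_US-style "\3":    1234567 -> 1,234,567
    std::printf("%s\n", group_digits("1234567", "\3", ',').c_str());
    // Indian-style "\3\2": 1234567 -> 12,34,567
    std::printf("%s\n", group_digits("1234567", "\3\2", ',').c_str());
}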
- template - auto apply(Out out, basic_string_view digits) const -> Out { - auto num_digits = static_cast(digits.size()); - auto separators = basic_memory_buffer(); - separators.push_back(0); - auto state = initial_state(); - while (int i = next(state)) { - if (i >= num_digits) break; - separators.push_back(i); - } - for (int i = 0, sep_index = static_cast(separators.size() - 1); - i < num_digits; ++i) { - if (num_digits - i == separators[sep_index]) { - out = copy_str(thousands_sep_.data(), - thousands_sep_.data() + thousands_sep_.size(), out); - --sep_index; - } - *out++ = static_cast(digits[to_unsigned(i)]); - } - return out; - } -}; - -FMT_CONSTEXPR inline void prefix_append(unsigned &prefix, unsigned value) { - prefix |= prefix != 0 ? value << 8 : value; - prefix += (1u + (value > 0xff ? 1 : 0)) << 24; -} - -// Writes a decimal integer with digit grouping. -template -auto write_int(OutputIt out, UInt value, unsigned prefix, - const format_specs &specs, const digit_grouping &grouping) - -> OutputIt { - static_assert(std::is_same, UInt>::value, ""); - int num_digits = 0; - auto buffer = memory_buffer(); - switch (specs.type) { - case presentation_type::none: - case presentation_type::dec: { - num_digits = count_digits(value); - format_decimal(appender(buffer), value, num_digits); - break; - } - case presentation_type::hex_lower: - case presentation_type::hex_upper: { - bool upper = specs.type == presentation_type::hex_upper; - if (specs.alt) - prefix_append(prefix, unsigned(upper ? 'X' : 'x') << 8 | '0'); - num_digits = count_digits<4>(value); - format_uint<4, char>(appender(buffer), value, num_digits, upper); - break; - } - case presentation_type::bin_lower: - case presentation_type::bin_upper: { - bool upper = specs.type == presentation_type::bin_upper; - if (specs.alt) - prefix_append(prefix, unsigned(upper ? 'B' : 'b') << 8 | '0'); - num_digits = count_digits<1>(value); - format_uint<1, char>(appender(buffer), value, num_digits); - break; - } - case presentation_type::oct: { - num_digits = count_digits<3>(value); - // Octal prefix '0' is counted as a digit, so only add it if precision - // is not greater than the number of digits. - if (specs.alt && specs.precision <= num_digits && value != 0) - prefix_append(prefix, '0'); - format_uint<3, char>(appender(buffer), value, num_digits); - break; - } - case presentation_type::chr: - return write_char(out, static_cast(value), specs); - default: throw_format_error("invalid format specifier"); - } - - unsigned size = (prefix != 0 ? prefix >> 24 : 0) + to_unsigned(num_digits) - + to_unsigned(grouping.count_separators(num_digits)); - return write_padded( - out, specs, size, size, [&](reserve_iterator it) { - for (unsigned p = prefix & 0xffffff; p != 0; p >>= 8) - *it++ = static_cast(p & 0xff); - return grouping.apply( - it, string_view(buffer.data(), buffer.size())); - }); -} - -// Writes a localized value. 
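write_int above carries the sign and radix prefix in a single unsigned: up to three characters packed into the low bytes, emitted least significant byte first, with the character count kept in the top byte. prefix_append maintains that encoding; a standalone sketch of how the pieces compose (write_loc follows below):

#include <cstdio>

// Same packing as above: chars in the three low bytes, count in byte 3.
static void prefix_append(unsigned &prefix, unsigned value) {
    prefix |= prefix != 0 ? value << 8 : value;
    prefix += (1u + (value > 0xff ? 1 : 0)) << 24;
}

static void print_prefix(unsigned prefix) {
    std::printf("size=%u chars=\"", prefix >> 24);
    for (unsigned p = prefix & 0xffffff; p != 0; p >>= 8) // low byte first
        std::putchar(int(p & 0xff));
    std::puts("\"");
}

int main() {
    unsigned prefix = 0x01000000 | '-';              // sign, one char
    prefix_append(prefix, unsigned('x') << 8 | '0'); // then "0x" for '#x'
    print_prefix(prefix);                            // size=3 chars="-0x"
}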
-FMT_API auto write_loc(appender out, loc_value value, - const format_specs<> &specs, locale_ref loc) -> bool; -template -inline auto write_loc( - OutputIt, loc_value, const format_specs &, locale_ref) -> bool { - return false; -} - -template -struct write_int_arg { - UInt abs_value; - unsigned prefix; -}; - -template -FMT_CONSTEXPR auto make_write_int_arg(T value, sign_t sign) - -> write_int_arg> { - auto prefix = 0u; - auto abs_value = static_cast>(value); - if (is_negative(value)) { - prefix = 0x01000000 | '-'; - abs_value = 0 - abs_value; - } else { - constexpr const unsigned prefixes[4] - = {0, 0, 0x1000000u | '+', 0x1000000u | ' '}; - prefix = prefixes[sign]; - } - return {abs_value, prefix}; -} - -template -struct loc_writer { - buffer_appender out; - const format_specs &specs; - std::basic_string sep; - std::string grouping; - std::basic_string decimal_point; - - template ::value)> - auto operator()(T value) -> bool { - auto arg = make_write_int_arg(value, specs.sign); - write_int(out, static_cast>(arg.abs_value), - arg.prefix, specs, digit_grouping(grouping, sep)); - return true; - } - - template ::value)> - auto operator()(T) -> bool { - return false; - } -}; - -template -FMT_CONSTEXPR FMT_INLINE auto write_int(OutputIt out, write_int_arg arg, - const format_specs &specs, locale_ref) -> OutputIt { - static_assert(std::is_same>::value, ""); - auto abs_value = arg.abs_value; - auto prefix = arg.prefix; - switch (specs.type) { - case presentation_type::none: - case presentation_type::dec: { - auto num_digits = count_digits(abs_value); - return write_int(out, num_digits, prefix, specs, - [=](reserve_iterator it) { - return format_decimal(it, abs_value, num_digits) - .end; - }); - } - case presentation_type::hex_lower: - case presentation_type::hex_upper: { - bool upper = specs.type == presentation_type::hex_upper; - if (specs.alt) - prefix_append(prefix, unsigned(upper ? 'X' : 'x') << 8 | '0'); - int num_digits = count_digits<4>(abs_value); - return write_int(out, num_digits, prefix, specs, - [=](reserve_iterator it) { - return format_uint<4, Char>( - it, abs_value, num_digits, upper); - }); - } - case presentation_type::bin_lower: - case presentation_type::bin_upper: { - bool upper = specs.type == presentation_type::bin_upper; - if (specs.alt) - prefix_append(prefix, unsigned(upper ? 'B' : 'b') << 8 | '0'); - int num_digits = count_digits<1>(abs_value); - return write_int(out, num_digits, prefix, specs, - [=](reserve_iterator it) { - return format_uint<1, Char>(it, abs_value, num_digits); - }); - } - case presentation_type::oct: { - int num_digits = count_digits<3>(abs_value); - // Octal prefix '0' is counted as a digit, so only add it if precision - // is not greater than the number of digits. 
- if (specs.alt && specs.precision <= num_digits && abs_value != 0) - prefix_append(prefix, '0'); - return write_int(out, num_digits, prefix, specs, - [=](reserve_iterator it) { - return format_uint<3, Char>(it, abs_value, num_digits); - }); - } - case presentation_type::chr: - return write_char(out, static_cast(abs_value), specs); - default: throw_format_error("invalid format specifier"); - } - return out; -} -template -FMT_CONSTEXPR FMT_NOINLINE auto write_int_noinline(OutputIt out, - write_int_arg arg, const format_specs &specs, locale_ref loc) - -> OutputIt { - return write_int(out, arg, specs, loc); -} -template ::value && !std::is_same::value - && std::is_same>::value)> -FMT_CONSTEXPR FMT_INLINE auto write(OutputIt out, T value, - const format_specs &specs, locale_ref loc) -> OutputIt { - if (specs.localized && write_loc(out, value, specs, loc)) return out; - return write_int_noinline( - out, make_write_int_arg(value, specs.sign), specs, loc); -} -// An inlined version of write used in format string compilation. -template ::value && !std::is_same::value - && !std::is_same>::value)> -FMT_CONSTEXPR FMT_INLINE auto write(OutputIt out, T value, - const format_specs &specs, locale_ref loc) -> OutputIt { - if (specs.localized && write_loc(out, value, specs, loc)) return out; - return write_int(out, make_write_int_arg(value, specs.sign), specs, loc); -} - -// An output iterator that counts the number of objects written to it and -// discards them. -class counting_iterator { -private: - size_t count_; - -public: - using iterator_category = std::output_iterator_tag; - using difference_type = std::ptrdiff_t; - using pointer = void; - using reference = void; - FMT_UNCHECKED_ITERATOR(counting_iterator); - - struct value_type { - template - FMT_CONSTEXPR void operator=(const T &) {} - }; - - FMT_CONSTEXPR counting_iterator() : count_(0) {} - - FMT_CONSTEXPR auto count() const -> size_t { return count_; } - - FMT_CONSTEXPR auto operator++() -> counting_iterator & { - ++count_; - return *this; - } - FMT_CONSTEXPR auto operator++(int) -> counting_iterator { - auto it = *this; - ++*this; - return it; - } - - FMT_CONSTEXPR friend auto operator+(counting_iterator it, difference_type n) - -> counting_iterator { - it.count_ += static_cast(n); - return it; - } - - FMT_CONSTEXPR auto operator*() const -> value_type { return {}; } -}; - -template -FMT_CONSTEXPR auto write(OutputIt out, basic_string_view s, - const format_specs &specs) -> OutputIt { - auto data = s.data(); - auto size = s.size(); - if (specs.precision >= 0 && to_unsigned(specs.precision) < size) - size = code_point_index(s, to_unsigned(specs.precision)); - bool is_debug = specs.type == presentation_type::debug; - size_t width = 0; - if (specs.width != 0) { - if (is_debug) - width = write_escaped_string(counting_iterator {}, s).count(); - else - width = compute_width(basic_string_view(data, size)); - } - return write_padded( - out, specs, size, width, [=](reserve_iterator it) { - if (is_debug) return write_escaped_string(it, s); - return copy_str(data, data + size, it); - }); -} -template -FMT_CONSTEXPR auto write(OutputIt out, - basic_string_view> s, - const format_specs &specs, locale_ref) -> OutputIt { - return write(out, s, specs); -} -template -FMT_CONSTEXPR auto write(OutputIt out, const Char *s, - const format_specs &specs, locale_ref) -> OutputIt { - if (specs.type == presentation_type::pointer) - return write_ptr(out, bit_cast(s), &specs); - if (!s) throw_format_error("string pointer is null"); - return write(out, 
basic_string_view(s), specs, {}); -} - -template ::value && !std::is_same::value - && !std::is_same::value)> -FMT_CONSTEXPR auto write(OutputIt out, T value) -> OutputIt { - auto abs_value = static_cast>(value); - bool negative = is_negative(value); - // Don't do -abs_value since it trips unsigned-integer-overflow sanitizer. - if (negative) abs_value = ~abs_value + 1; - int num_digits = count_digits(abs_value); - auto size = (negative ? 1 : 0) + static_cast(num_digits); - auto it = reserve(out, size); - if (auto ptr = to_pointer(it, size)) { - if (negative) *ptr++ = static_cast('-'); - format_decimal(ptr, abs_value, num_digits); - return out; - } - if (negative) *it++ = static_cast('-'); - it = format_decimal(it, abs_value, num_digits).end; - return base_iterator(out, it); -} - -// DEPRECATED! -template -FMT_CONSTEXPR auto parse_align(const Char *begin, const Char *end, - format_specs &specs) -> const Char * { - FMT_ASSERT(begin != end, ""); - auto align = align::none; - auto p = begin + code_point_length(begin); - if (end - p <= 0) p = begin; - for (;;) { - switch (to_ascii(*p)) { - case '<': align = align::left; break; - case '>': align = align::right; break; - case '^': align = align::center; break; - } - if (align != align::none) { - if (p != begin) { - auto c = *begin; - if (c == '}') return begin; - if (c == '{') { - throw_format_error("invalid fill character '{'"); - return begin; - } - specs.fill = {begin, to_unsigned(p - begin)}; - begin = p + 1; - } else { - ++begin; - } - break; - } else if (p == begin) { - break; - } - p = begin; - } - specs.align = align; - return begin; -} - -// A floating-point presentation format. -enum class float_format : unsigned char { - general, // General: exponent notation or fixed point based on magnitude. - exp, // Exponent notation with the default precision of 6, e.g. 1.2e-3. - fixed, // Fixed point with the default precision of 6, e.g. 0.0012. - hex -}; - -struct float_specs { - int precision; - float_format format : 8; - sign_t sign : 8; - bool upper : 1; - bool locale : 1; - bool binary32 : 1; - bool showpoint : 1; -}; - -template -FMT_CONSTEXPR auto parse_float_type_spec(const format_specs &specs) - -> float_specs { - auto result = float_specs(); - result.showpoint = specs.alt; - result.locale = specs.localized; - switch (specs.type) { - case presentation_type::none: - result.format = float_format::general; - break; - case presentation_type::general_upper: - result.upper = true; - FMT_FALLTHROUGH; - case presentation_type::general_lower: - result.format = float_format::general; - break; - case presentation_type::exp_upper: result.upper = true; FMT_FALLTHROUGH; - case presentation_type::exp_lower: - result.format = float_format::exp; - result.showpoint |= specs.precision != 0; - break; - case presentation_type::fixed_upper: - result.upper = true; - FMT_FALLTHROUGH; - case presentation_type::fixed_lower: - result.format = float_format::fixed; - result.showpoint |= specs.precision != 0; - break; - case presentation_type::hexfloat_upper: - result.upper = true; - FMT_FALLTHROUGH; - case presentation_type::hexfloat_lower: - result.format = float_format::hex; - break; - default: throw_format_error("invalid format specifier"); break; - } - return result; -} - -template -FMT_CONSTEXPR20 auto write_nonfinite(OutputIt out, bool isnan, - format_specs specs, const float_specs &fspecs) -> OutputIt { - auto str = isnan ? (fspecs.upper ? "NAN" : "nan") - : (fspecs.upper ? 
"INF" : "inf"); - constexpr size_t str_size = 3; - auto sign = fspecs.sign; - auto size = str_size + (sign ? 1 : 0); - // Replace '0'-padding with space for non-finite values. - const bool is_zero_fill = specs.fill.size() == 1 - && *specs.fill.data() == static_cast('0'); - if (is_zero_fill) specs.fill[0] = static_cast(' '); - return write_padded(out, specs, size, [=](reserve_iterator it) { - if (sign) *it++ = detail::sign(sign); - return copy_str(str, str + str_size, it); - }); -} - -// A decimal floating-point number significand * pow(10, exp). -struct big_decimal_fp { - const char *significand; - int significand_size; - int exponent; -}; - -constexpr auto get_significand_size(const big_decimal_fp &f) -> int { - return f.significand_size; -} -template -inline auto get_significand_size(const dragonbox::decimal_fp &f) -> int { - return count_digits(f.significand); -} - -template -constexpr auto write_significand(OutputIt out, const char *significand, - int significand_size) -> OutputIt { - return copy_str(significand, significand + significand_size, out); -} -template -inline auto write_significand( - OutputIt out, UInt significand, int significand_size) -> OutputIt { - return format_decimal(out, significand, significand_size).end; -} -template -FMT_CONSTEXPR20 auto write_significand(OutputIt out, T significand, - int significand_size, int exponent, const Grouping &grouping) - -> OutputIt { - if (!grouping.has_separator()) { - out = write_significand(out, significand, significand_size); - return detail::fill_n(out, exponent, static_cast('0')); - } - auto buffer = memory_buffer(); - write_significand(appender(buffer), significand, significand_size); - detail::fill_n(appender(buffer), exponent, '0'); - return grouping.apply(out, string_view(buffer.data(), buffer.size())); -} - -template ::value)> -inline auto write_significand(Char *out, UInt significand, int significand_size, - int integral_size, Char decimal_point) -> Char * { - if (!decimal_point) - return format_decimal(out, significand, significand_size).end; - out += significand_size + 1; - Char *end = out; - int floating_size = significand_size - integral_size; - for (int i = floating_size / 2; i > 0; --i) { - out -= 2; - copy2(out, digits2(static_cast(significand % 100))); - significand /= 100; - } - if (floating_size % 2 != 0) { - *--out = static_cast('0' + significand % 10); - significand /= 10; - } - *--out = decimal_point; - format_decimal(out - integral_size, significand, integral_size); - return end; -} - -template >::value)> -inline auto write_significand(OutputIt out, UInt significand, - int significand_size, int integral_size, Char decimal_point) - -> OutputIt { - // Buffer is large enough to hold digits (digits10 + 1) and a decimal point. 
- Char buffer[digits10() + 2]; - auto end = write_significand(buffer, significand, significand_size, - integral_size, decimal_point); - return detail::copy_str_noinline(buffer, end, out); -} - -template -FMT_CONSTEXPR auto write_significand(OutputIt out, const char *significand, - int significand_size, int integral_size, Char decimal_point) - -> OutputIt { - out = detail::copy_str_noinline( - significand, significand + integral_size, out); - if (!decimal_point) return out; - *out++ = decimal_point; - return detail::copy_str_noinline( - significand + integral_size, significand + significand_size, out); -} - -template -FMT_CONSTEXPR20 auto write_significand(OutputIt out, T significand, - int significand_size, int integral_size, Char decimal_point, - const Grouping &grouping) -> OutputIt { - if (!grouping.has_separator()) { - return write_significand(out, significand, significand_size, - integral_size, decimal_point); - } - auto buffer = basic_memory_buffer(); - write_significand(buffer_appender(buffer), significand, - significand_size, integral_size, decimal_point); - grouping.apply(out, - basic_string_view(buffer.data(), to_unsigned(integral_size))); - return detail::copy_str_noinline( - buffer.data() + integral_size, buffer.end(), out); -} - -template > -FMT_CONSTEXPR20 auto do_write_float(OutputIt out, const DecimalFP &f, - const format_specs &specs, float_specs fspecs, locale_ref loc) - -> OutputIt { - auto significand = f.significand; - int significand_size = get_significand_size(f); - const Char zero = static_cast('0'); - auto sign = fspecs.sign; - size_t size = to_unsigned(significand_size) + (sign ? 1 : 0); - using iterator = reserve_iterator; - - Char decimal_point = fspecs.locale ? detail::decimal_point(loc) - : static_cast('.'); - - int output_exp = f.exponent + significand_size - 1; - auto use_exp_format = [=]() { - if (fspecs.format == float_format::exp) return true; - if (fspecs.format != float_format::general) return false; - // Use the fixed notation if the exponent is in [exp_lower, exp_upper), - // e.g. 0.0001 instead of 1e-04. Otherwise use the exponent notation. - const int exp_lower = -4, exp_upper = 16; - return output_exp < exp_lower - || output_exp - >= (fspecs.precision > 0 ? fspecs.precision : exp_upper); - }; - if (use_exp_format()) { - int num_zeros = 0; - if (fspecs.showpoint) { - num_zeros = fspecs.precision - significand_size; - if (num_zeros < 0) num_zeros = 0; - size += to_unsigned(num_zeros); - } else if (significand_size == 1) { - decimal_point = Char(); - } - auto abs_output_exp = output_exp >= 0 ? output_exp : -output_exp; - int exp_digits = 2; - if (abs_output_exp >= 100) exp_digits = abs_output_exp >= 1000 ? 4 : 3; - - size += to_unsigned((decimal_point ? 1 : 0) + 2 + exp_digits); - char exp_char = fspecs.upper ? 'E' : 'e'; - auto write = [=](iterator it) { - if (sign) *it++ = detail::sign(sign); - // Insert a decimal point after the first digit and add an exponent. - it = write_significand( - it, significand, significand_size, 1, decimal_point); - if (num_zeros > 0) it = detail::fill_n(it, num_zeros, zero); - *it++ = static_cast(exp_char); - return write_exponent(output_exp, it); - }; - return specs.width > 0 - ? 
write_padded(out, specs, size, write) - : base_iterator(out, write(reserve(out, size))); - } - - int exp = f.exponent + significand_size; - if (f.exponent >= 0) { - // 1234e5 -> 123400000[.0+] - size += to_unsigned(f.exponent); - int num_zeros = fspecs.precision - exp; - abort_fuzzing_if(num_zeros > 5000); - if (fspecs.showpoint) { - ++size; - if (num_zeros <= 0 && fspecs.format != float_format::fixed) - num_zeros = 0; - if (num_zeros > 0) size += to_unsigned(num_zeros); - } - auto grouping = Grouping(loc, fspecs.locale); - size += to_unsigned(grouping.count_separators(exp)); - return write_padded(out, specs, size, [&](iterator it) { - if (sign) *it++ = detail::sign(sign); - it = write_significand( - it, significand, significand_size, f.exponent, grouping); - if (!fspecs.showpoint) return it; - *it++ = decimal_point; - return num_zeros > 0 ? detail::fill_n(it, num_zeros, zero) : it; - }); - } else if (exp > 0) { - // 1234e-2 -> 12.34[0+] - int num_zeros - = fspecs.showpoint ? fspecs.precision - significand_size : 0; - size += 1 + to_unsigned(num_zeros > 0 ? num_zeros : 0); - auto grouping = Grouping(loc, fspecs.locale); - size += to_unsigned(grouping.count_separators(exp)); - return write_padded(out, specs, size, [&](iterator it) { - if (sign) *it++ = detail::sign(sign); - it = write_significand(it, significand, significand_size, exp, - decimal_point, grouping); - return num_zeros > 0 ? detail::fill_n(it, num_zeros, zero) : it; - }); - } - // 1234e-6 -> 0.001234 - int num_zeros = -exp; - if (significand_size == 0 && fspecs.precision >= 0 - && fspecs.precision < num_zeros) { - num_zeros = fspecs.precision; - } - bool pointy = num_zeros != 0 || significand_size != 0 || fspecs.showpoint; - size += 1 + (pointy ? 1 : 0) + to_unsigned(num_zeros); - return write_padded(out, specs, size, [&](iterator it) { - if (sign) *it++ = detail::sign(sign); - *it++ = zero; - if (!pointy) return it; - *it++ = decimal_point; - it = detail::fill_n(it, num_zeros, zero); - return write_significand(it, significand, significand_size); - }); -} - -template -class fallback_digit_grouping { -public: - constexpr fallback_digit_grouping(locale_ref, bool) {} - - constexpr auto has_separator() const -> bool { return false; } - - constexpr auto count_separators(int) const -> int { return 0; } - - template - constexpr auto apply(Out out, basic_string_view) const -> Out { - return out; - } -}; - -template -FMT_CONSTEXPR20 auto write_float(OutputIt out, const DecimalFP &f, - const format_specs &specs, float_specs fspecs, locale_ref loc) - -> OutputIt { - if (is_constant_evaluated()) { - return do_write_float>(out, f, specs, fspecs, loc); - } else { - return do_write_float(out, f, specs, fspecs, loc); - } -} - -template -constexpr auto isnan(T value) -> bool { - return !(value >= value); // std::isnan doesn't support __float128. -} - -template -struct has_isfinite : std::false_type {}; - -template -struct has_isfinite> - : std::true_type {}; - -template ::value &&has_isfinite::value)> -FMT_CONSTEXPR20 auto isfinite(T value) -> bool { - constexpr T inf = T(std::numeric_limits::infinity()); - if (is_constant_evaluated()) - return !detail::isnan(value) && value < inf && value > -inf; - return std::isfinite(value); -} -template ::value)> -FMT_CONSTEXPR auto isfinite(T value) -> bool { - T inf = T(std::numeric_limits::infinity()); - // std::isfinite doesn't support __float128. 
-    return !detail::isnan(value) && value < inf && value > -inf;
-}
-
-template <typename T, FMT_ENABLE_IF(is_floating_point<T>::value)>
-FMT_INLINE FMT_CONSTEXPR bool signbit(T value) {
-    if (is_constant_evaluated()) {
-#ifdef __cpp_if_constexpr
-        if constexpr (std::numeric_limits<double>::is_iec559) {
-            auto bits = detail::bit_cast<uint64_t>(static_cast<double>(value));
-            return (bits >> (num_bits<uint64_t>() - 1)) != 0;
-        }
-#endif
-    }
-    return std::signbit(static_cast<double>(value));
-}
-
-inline FMT_CONSTEXPR20 void adjust_precision(int &precision, int exp10) {
-    // Adjust fixed precision by exponent because it is relative to decimal
-    // point.
-    if (exp10 > 0 && precision > max_value<int>() - exp10)
-        FMT_THROW(format_error("number is too big"));
-    precision += exp10;
-}
-
-class bigint {
-private:
-    // A bigint is stored as an array of bigits (big digits), with bigit at index
-    // 0 being the least significant one.
-    using bigit = uint32_t;
-    using double_bigit = uint64_t;
-    enum { bigits_capacity = 32 };
-    basic_memory_buffer<bigit, bigits_capacity> bigits_;
-    int exp_;
-
-    FMT_CONSTEXPR20 auto operator[](int index) const -> bigit {
-        return bigits_[to_unsigned(index)];
-    }
-    FMT_CONSTEXPR20 auto operator[](int index) -> bigit & {
-        return bigits_[to_unsigned(index)];
-    }
-
-    static constexpr const int bigit_bits = num_bits<bigit>();
-
-    friend struct formatter<bigint>;
-
-    FMT_CONSTEXPR20 void subtract_bigits(
-            int index, bigit other, bigit &borrow) {
-        auto result
-                = static_cast<double_bigit>((*this)[index]) - other - borrow;
-        (*this)[index] = static_cast<bigit>(result);
-        borrow = static_cast<bigit>(result >> (bigit_bits * 2 - 1));
-    }
-
-    FMT_CONSTEXPR20 void remove_leading_zeros() {
-        int num_bigits = static_cast<int>(bigits_.size()) - 1;
-        while (num_bigits > 0 && (*this)[num_bigits] == 0)
-            --num_bigits;
-        bigits_.resize(to_unsigned(num_bigits + 1));
-    }
-
-    // Computes *this -= other assuming aligned bigints and *this >= other.
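subtract_bigits above performs one limb of schoolbook subtraction: do the 32-bit difference in 64 bits and read the borrow out of the wide result's sign bit. subtract_aligned, defined next, simply runs that across the limbs. The same mechanic on plain vectors, as a standalone sketch:

#include <cstdint>
#include <cstdio>
#include <vector>

// Multiword a -= b (requires a >= b), least significant limb first.
static void subtract(std::vector<uint32_t> &a, const std::vector<uint32_t> &b) {
    uint32_t borrow = 0;
    for (size_t i = 0; i < a.size(); ++i) {
        uint64_t rhs = uint64_t(i < b.size() ? b[i] : 0) + borrow;
        uint64_t result = uint64_t(a[i]) - rhs;
        a[i] = uint32_t(result);
        borrow = uint32_t(result >> 63); // 1 iff the subtraction wrapped
    }
}

int main() {
    std::vector<uint32_t> a = {0x00000000, 0x00000001}; // 2^32
    std::vector<uint32_t> b = {0x00000001};             // 1
    subtract(a, b);
    std::printf("%08x %08x\n", a[1], a[0]); // 00000000 ffffffff == 2^32 - 1
}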
- FMT_CONSTEXPR20 void subtract_aligned(const bigint &other) { - FMT_ASSERT(other.exp_ >= exp_, "unaligned bigints"); - FMT_ASSERT(fmt_compare(*this, other) >= 0, ""); - bigit borrow = 0; - int i = other.exp_ - exp_; - for (size_t j = 0, n = other.bigits_.size(); j != n; ++i, ++j) - subtract_bigits(i, other.bigits_[j], borrow); - while (borrow > 0) - subtract_bigits(i, 0, borrow); - remove_leading_zeros(); - } - - FMT_CONSTEXPR20 void multiply(uint32_t value) { - const double_bigit wide_value = value; - bigit carry = 0; - for (size_t i = 0, n = bigits_.size(); i < n; ++i) { - double_bigit result = bigits_[i] * wide_value + carry; - bigits_[i] = static_cast(result); - carry = static_cast(result >> bigit_bits); - } - if (carry != 0) bigits_.push_back(carry); - } - - template ::value - || std::is_same::value)> - FMT_CONSTEXPR20 void multiply(UInt value) { - using half_uint = conditional_t::value, - uint64_t, uint32_t>; - const int shift = num_bits() - bigit_bits; - const UInt lower = static_cast(value); - const UInt upper = value >> num_bits(); - UInt carry = 0; - for (size_t i = 0, n = bigits_.size(); i < n; ++i) { - UInt result = lower * bigits_[i] + static_cast(carry); - carry = (upper * bigits_[i] << shift) + (result >> bigit_bits) - + (carry >> bigit_bits); - bigits_[i] = static_cast(result); - } - while (carry != 0) { - bigits_.push_back(static_cast(carry)); - carry >>= bigit_bits; - } - } - - template ::value - || std::is_same::value)> - FMT_CONSTEXPR20 void assign(UInt n) { - size_t num_bigits = 0; - do { - bigits_[num_bigits++] = static_cast(n); - n >>= bigit_bits; - } while (n != 0); - bigits_.resize(num_bigits); - exp_ = 0; - } - -public: - FMT_CONSTEXPR20 bigint() : exp_(0) {} - explicit bigint(uint64_t n) { assign(n); } - - bigint(const bigint &) = delete; - void operator=(const bigint &) = delete; - - FMT_CONSTEXPR20 void assign(const bigint &other) { - auto size = other.bigits_.size(); - bigits_.resize(size); - auto data = other.bigits_.data(); - copy_str(data, data + size, bigits_.data()); - exp_ = other.exp_; - } - - template - FMT_CONSTEXPR20 void operator=(Int n) { - FMT_ASSERT(n > 0, ""); - assign(uint64_or_128_t(n)); - } - - FMT_CONSTEXPR20 auto num_bigits() const -> int { - return static_cast(bigits_.size()) + exp_; - } - - FMT_NOINLINE FMT_CONSTEXPR20 auto operator<<=(int shift) -> bigint & { - FMT_ASSERT(shift >= 0, ""); - exp_ += shift / bigit_bits; - shift %= bigit_bits; - if (shift == 0) return *this; - bigit carry = 0; - for (size_t i = 0, n = bigits_.size(); i < n; ++i) { - bigit c = bigits_[i] >> (bigit_bits - shift); - bigits_[i] = (bigits_[i] << shift) + carry; - carry = c; - } - if (carry != 0) bigits_.push_back(carry); - return *this; - } - - template - FMT_CONSTEXPR20 auto operator*=(Int value) -> bigint & { - FMT_ASSERT(value > 0, ""); - multiply(uint32_or_64_or_128_t(value)); - return *this; - } - - // updated from compare to fmt_compare to avoid build conflicts with oneDNN - // primitive descriptors. - friend FMT_CONSTEXPR20 auto fmt_compare( - const bigint &lhs, const bigint &rhs) -> int { - int num_lhs_bigits = lhs.num_bigits(), - num_rhs_bigits = rhs.num_bigits(); - if (num_lhs_bigits != num_rhs_bigits) - return num_lhs_bigits > num_rhs_bigits ? 1 : -1; - int i = static_cast(lhs.bigits_.size()) - 1; - int j = static_cast(rhs.bigits_.size()) - 1; - int end = i - j; - if (end < 0) end = 0; - for (; i >= end; --i, --j) { - bigit lhs_bigit = lhs[i], rhs_bigit = rhs[j]; - if (lhs_bigit != rhs_bigit) return lhs_bigit > rhs_bigit ? 
1 : -1; - } - if (i != j) return i > j ? 1 : -1; - return 0; - } - - // Returns fmt_compare(lhs1 + lhs2, rhs). - friend FMT_CONSTEXPR20 auto add_compare( - const bigint &lhs1, const bigint &lhs2, const bigint &rhs) -> int { - auto minimum = [](int a, int b) { return a < b ? a : b; }; - auto maximum = [](int a, int b) { return a > b ? a : b; }; - int max_lhs_bigits = maximum(lhs1.num_bigits(), lhs2.num_bigits()); - int num_rhs_bigits = rhs.num_bigits(); - if (max_lhs_bigits + 1 < num_rhs_bigits) return -1; - if (max_lhs_bigits > num_rhs_bigits) return 1; - auto get_bigit = [](const bigint &n, int i) -> bigit { - return i >= n.exp_ && i < n.num_bigits() ? n[i - n.exp_] : 0; - }; - double_bigit borrow = 0; - int min_exp = minimum(minimum(lhs1.exp_, lhs2.exp_), rhs.exp_); - for (int i = num_rhs_bigits - 1; i >= min_exp; --i) { - double_bigit sum = static_cast(get_bigit(lhs1, i)) - + get_bigit(lhs2, i); - bigit rhs_bigit = get_bigit(rhs, i); - if (sum > rhs_bigit + borrow) return 1; - borrow = rhs_bigit + borrow - sum; - if (borrow > 1) return -1; - borrow <<= bigit_bits; - } - return borrow != 0 ? -1 : 0; - } - - // Assigns pow(10, exp) to this bigint. - FMT_CONSTEXPR20 void assign_pow10(int exp) { - FMT_ASSERT(exp >= 0, ""); - if (exp == 0) return *this = 1; - // Find the top bit. - int bitmask = 1; - while (exp >= bitmask) - bitmask <<= 1; - bitmask >>= 1; - // pow(10, exp) = pow(5, exp) * pow(2, exp). First compute pow(5, exp) by - // repeated squaring and multiplication. - *this = 5; - bitmask >>= 1; - while (bitmask != 0) { - square(); - if ((exp & bitmask) != 0) *this *= 5; - bitmask >>= 1; - } - *this <<= exp; // Multiply by pow(2, exp) by shifting. - } - - FMT_CONSTEXPR20 void square() { - int num_bigits = static_cast(bigits_.size()); - int num_result_bigits = 2 * num_bigits; - basic_memory_buffer n(std::move(bigits_)); - bigits_.resize(to_unsigned(num_result_bigits)); - auto sum = uint128_t(); - for (int bigit_index = 0; bigit_index < num_bigits; ++bigit_index) { - // Compute bigit at position bigit_index of the result by adding - // cross-product terms n[i] * n[j] such that i + j == bigit_index. - for (int i = 0, j = bigit_index; j >= 0; ++i, --j) { - // Most terms are multiplied twice which can be optimized in the future. - sum += static_cast(n[i]) * n[j]; - } - (*this)[bigit_index] = static_cast(sum); - sum >>= num_bits(); // Compute the carry. - } - // Do the same for the top half. - for (int bigit_index = num_bigits; bigit_index < num_result_bigits; - ++bigit_index) { - for (int j = num_bigits - 1, i = bigit_index - j; i < num_bigits;) - sum += static_cast(n[i++]) * n[j--]; - (*this)[bigit_index] = static_cast(sum); - sum >>= num_bits(); - } - remove_leading_zeros(); - exp_ *= 2; - } - - // If this bigint has a bigger exponent than other, adds trailing zero to make - // exponents equal. This simplifies some operations such as subtraction. - FMT_CONSTEXPR20 void align(const bigint &other) { - int exp_difference = exp_ - other.exp_; - if (exp_difference <= 0) return; - int num_bigits = static_cast(bigits_.size()); - bigits_.resize(to_unsigned(num_bigits + exp_difference)); - for (int i = num_bigits - 1, j = i + exp_difference; i >= 0; --i, --j) - bigits_[j] = bigits_[i]; - std::uninitialized_fill_n(bigits_.data(), exp_difference, 0u); - exp_ -= exp_difference; - } - - // Divides this bignum by divisor, assigning the remainder to this and - // returning the quotient. 
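assign_pow10 above factors 10^e as 5^e * 2^e: the 5^e part is built by binary exponentiation (square-and-multiply, from the most significant exponent bit down) and the 2^e part is one left shift, which is cheap on a bigint. The same scheme on plain uint64_t for exponents up to 19, as a standalone sketch (divmod_assign, whose comment appears just above, is defined next):

#include <cstdint>
#include <cstdio>

static uint64_t pow10(int e) { // valid for 0 <= e <= 19
    int bitmask = 1;
    while (e >= bitmask) // find the bit above e's top bit
        bitmask <<= 1;
    bitmask >>= 1;
    uint64_t r = e == 0 ? 1 : 5;
    for (bitmask >>= 1; bitmask != 0; bitmask >>= 1) {
        r *= r;                  // square
        if (e & bitmask) r *= 5; // multiply
    }
    return r << e; // * 2^e
}

int main() {
    for (int e = 0; e <= 19; ++e)
        std::printf("10^%d = %llu\n", e, (unsigned long long)pow10(e));
}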
- FMT_CONSTEXPR20 auto divmod_assign(const bigint &divisor) -> int { - FMT_ASSERT(this != &divisor, ""); - if (fmt_compare(*this, divisor) < 0) return 0; - FMT_ASSERT(divisor.bigits_[divisor.bigits_.size() - 1u] != 0, ""); - align(divisor); - int quotient = 0; - do { - subtract_aligned(divisor); - ++quotient; - } while (fmt_compare(*this, divisor) >= 0); - return quotient; - } -}; - -// format_dragon flags. -enum dragon { - predecessor_closer = 1, - fixup = 2, // Run fixup to correct exp10 which can be off by one. - fixed = 4, -}; - -// Formats a floating-point number using a variation of the Fixed-Precision -// Positive Floating-Point Printout ((FPP)^2) algorithm by Steele & White: -// https://fmt.dev/papers/p372-steele.pdf. -FMT_CONSTEXPR20 inline void format_dragon(basic_fp value, - unsigned flags, int num_digits, buffer &buf, int &exp10) { - bigint numerator; // 2 * R in (FPP)^2. - bigint denominator; // 2 * S in (FPP)^2. - // lower and upper are differences between value and corresponding boundaries. - bigint lower; // (M^- in (FPP)^2). - bigint upper_store; // upper's value if different from lower. - bigint *upper = nullptr; // (M^+ in (FPP)^2). - // Shift numerator and denominator by an extra bit or two (if lower boundary - // is closer) to make lower and upper integers. This eliminates multiplication - // by 2 during later computations. - bool is_predecessor_closer = (flags & dragon::predecessor_closer) != 0; - int shift = is_predecessor_closer ? 2 : 1; - if (value.e >= 0) { - numerator = value.f; - numerator <<= value.e + shift; - lower = 1; - lower <<= value.e; - if (is_predecessor_closer) { - upper_store = 1; - upper_store <<= value.e + 1; - upper = &upper_store; - } - denominator.assign_pow10(exp10); - denominator <<= shift; - } else if (exp10 < 0) { - numerator.assign_pow10(-exp10); - lower.assign(numerator); - if (is_predecessor_closer) { - upper_store.assign(numerator); - upper_store <<= 1; - upper = &upper_store; - } - numerator *= value.f; - numerator <<= shift; - denominator = 1; - denominator <<= shift - value.e; - } else { - numerator = value.f; - numerator <<= shift; - denominator.assign_pow10(exp10); - denominator <<= shift - value.e; - lower = 1; - if (is_predecessor_closer) { - upper_store = 1ULL << 1; - upper = &upper_store; - } - } - int even = static_cast((value.f & 1) == 0); - if (!upper) upper = &lower; - bool shortest = num_digits < 0; - if ((flags & dragon::fixup) != 0) { - if (add_compare(numerator, *upper, denominator) + even <= 0) { - --exp10; - numerator *= 10; - if (num_digits < 0) { - lower *= 10; - if (upper != &lower) *upper *= 10; - } - } - if ((flags & dragon::fixed) != 0) - adjust_precision(num_digits, exp10 + 1); - } - // Invariant: value == (numerator / denominator) * pow(10, exp10). - if (shortest) { - // Generate the shortest representation. - num_digits = 0; - char *data = buf.data(); - for (;;) { - int digit = numerator.divmod_assign(denominator); - bool low = fmt_compare(numerator, lower) - even - < 0; // numerator <[=] lower. - // numerator + upper >[=] pow10: - bool high = add_compare(numerator, *upper, denominator) + even > 0; - data[num_digits++] = static_cast('0' + digit); - if (low || high) { - if (!low) { - ++data[num_digits - 1]; - } else if (high) { - int result = add_compare(numerator, numerator, denominator); - // Round half to even. 
- if (result > 0 || (result == 0 && (digit % 2) != 0)) - ++data[num_digits - 1]; - } - buf.try_resize(to_unsigned(num_digits)); - exp10 -= num_digits - 1; - return; - } - numerator *= 10; - lower *= 10; - if (upper != &lower) *upper *= 10; - } - } - // Generate the given number of digits. - exp10 -= num_digits - 1; - if (num_digits <= 0) { - denominator *= 10; - auto digit = add_compare(numerator, numerator, denominator) > 0 ? '1' - : '0'; - buf.push_back(digit); - return; - } - buf.try_resize(to_unsigned(num_digits)); - for (int i = 0; i < num_digits - 1; ++i) { - int digit = numerator.divmod_assign(denominator); - buf[i] = static_cast('0' + digit); - numerator *= 10; - } - int digit = numerator.divmod_assign(denominator); - auto result = add_compare(numerator, numerator, denominator); - if (result > 0 || (result == 0 && (digit % 2) != 0)) { - if (digit == 9) { - const auto overflow = '0' + 10; - buf[num_digits - 1] = overflow; - // Propagate the carry. - for (int i = num_digits - 1; i > 0 && buf[i] == overflow; --i) { - buf[i] = '0'; - ++buf[i - 1]; - } - if (buf[0] == overflow) { - buf[0] = '1'; - if ((flags & dragon::fixed) != 0) - buf.push_back('0'); - else - ++exp10; - } - return; - } - ++digit; - } - buf[num_digits - 1] = static_cast('0' + digit); -} - -// Formats a floating-point number using the hexfloat format. -template ::value)> -FMT_CONSTEXPR20 void format_hexfloat( - Float value, int precision, float_specs specs, buffer &buf) { - // float is passed as double to reduce the number of instantiations and to - // simplify implementation. - static_assert(!std::is_same::value, ""); - - using info = dragonbox::float_info; - - // Assume Float is in the format [sign][exponent][significand]. - using carrier_uint = typename info::carrier_uint; - - constexpr auto num_float_significand_bits - = detail::num_significand_bits(); - - basic_fp f(value); - f.e += num_float_significand_bits; - if (!has_implicit_bit()) --f.e; - - constexpr auto num_fraction_bits - = num_float_significand_bits + (has_implicit_bit() ? 1 : 0); - constexpr auto num_xdigits = (num_fraction_bits + 3) / 4; - - constexpr auto leading_shift = ((num_xdigits - 1) * 4); - const auto leading_mask = carrier_uint(0xF) << leading_shift; - const auto leading_xdigit - = static_cast((f.f & leading_mask) >> leading_shift); - if (leading_xdigit > 1) f.e -= (32 - countl_zero(leading_xdigit) - 1); - - int print_xdigits = num_xdigits - 1; - if (precision >= 0 && print_xdigits > precision) { - const int shift = ((print_xdigits - precision - 1) * 4); - const auto mask = carrier_uint(0xF) << shift; - const auto v = static_cast((f.f & mask) >> shift); - - if (v >= 8) { - const auto inc = carrier_uint(1) << (shift + 4); - f.f += inc; - f.f &= ~(inc - 1); - } - - // Check long double overflow - if (!has_implicit_bit()) { - const auto implicit_bit = carrier_uint(1) - << num_float_significand_bits; - if ((f.f & implicit_bit) == implicit_bit) { - f.f >>= 4; - f.e += 4; - } - } - - print_xdigits = precision; - } - - char xdigits[num_bits() / 4]; - detail::fill_n(xdigits, sizeof(xdigits), '0'); - format_uint<4>(xdigits, f.f, num_xdigits, specs.upper); - - // Remove zero tail - while (print_xdigits > 0 && xdigits[print_xdigits] == '0') - --print_xdigits; - - buf.push_back('0'); - buf.push_back(specs.upper ? 
'X' : 'x'); - buf.push_back(xdigits[0]); - if (specs.showpoint || print_xdigits > 0 || print_xdigits < precision) - buf.push_back('.'); - buf.append(xdigits + 1, xdigits + 1 + print_xdigits); - for (; print_xdigits < precision; ++print_xdigits) - buf.push_back('0'); - - buf.push_back(specs.upper ? 'P' : 'p'); - - uint32_t abs_e; - if (f.e < 0) { - buf.push_back('-'); - abs_e = static_cast(-f.e); - } else { - buf.push_back('+'); - abs_e = static_cast(f.e); - } - format_decimal(appender(buf), abs_e, detail::count_digits(abs_e)); -} - -template ::value)> -FMT_CONSTEXPR20 void format_hexfloat( - Float value, int precision, float_specs specs, buffer &buf) { - format_hexfloat(static_cast(value), precision, specs, buf); -} - -constexpr auto fractional_part_rounding_thresholds(int index) -> uint32_t { - // For checking rounding thresholds. - // The kth entry is chosen to be the smallest integer such that the - // upper 32-bits of 10^(k+1) times it is strictly bigger than 5 * 10^k. - // It is equal to ceil(2^31 + 2^32/10^(k + 1)). - // These are stored in a string literal because we cannot have static arrays - // in constexpr functions and non-static ones are poorly optimized. - return U"\x9999999a\x828f5c29\x80418938\x80068db9\x8000a7c6\x800010c7" - U"\x800001ae\x8000002b"[index]; -} - -template -FMT_CONSTEXPR20 auto format_float(Float value, int precision, float_specs specs, - buffer &buf) -> int { - // float is passed as double to reduce the number of instantiations. - static_assert(!std::is_same::value, ""); - FMT_ASSERT(value >= 0, "value is negative"); - auto converted_value = convert_float(value); - - const bool fixed = specs.format == float_format::fixed; - if (value <= 0) { // <= instead of == to silence a warning. - if (precision <= 0 || !fixed) { - buf.push_back('0'); - return 0; - } - buf.try_resize(to_unsigned(precision)); - fill_n(buf.data(), precision, '0'); - return -precision; - } - - int exp = 0; - bool use_dragon = true; - unsigned dragon_flags = 0; - if (!is_fast_float() || is_constant_evaluated()) { - const auto inv_log2_10 = 0.3010299956639812; // 1 / log2(10) - using info = dragonbox::float_info; - const auto f = basic_fp(converted_value); - // Compute exp, an approximate power of 10, such that - // 10^(exp - 1) <= value < 10^exp or 10^exp <= value < 10^(exp + 1). - // This is based on log10(value) == log2(value) / log2(10) and approximation - // of log2(value) by e + num_fraction_bits idea from double-conversion. - auto e = (f.e + count_digits<1>(f.f) - 1) * inv_log2_10 - 1e-10; - exp = static_cast(e); - if (e > exp) ++exp; // Compute ceil. - dragon_flags = dragon::fixup; - } else if (precision < 0) { - // Use Dragonbox for the shortest format. - if (specs.binary32) { - auto dec = dragonbox::to_decimal(static_cast(value)); - write(buffer_appender(buf), dec.significand); - return dec.exponent; - } - auto dec = dragonbox::to_decimal(static_cast(value)); - write(buffer_appender(buf), dec.significand); - return dec.exponent; - } else { - // Extract significand bits and exponent bits. - using info = dragonbox::float_info; - auto br = bit_cast(static_cast(value)); - - const uint64_t significand_mask - = (static_cast(1) << num_significand_bits()) - - 1; - uint64_t significand = (br & significand_mask); - int exponent = static_cast((br & exponent_mask()) - >> num_significand_bits()); - - if (exponent != 0) { // Check if normal. 
- exponent - -= exponent_bias() + num_significand_bits(); - significand |= (static_cast(1) - << num_significand_bits()); - significand <<= 1; - } else { - // Normalize subnormal inputs. - FMT_ASSERT(significand != 0, "zeros should not appear here"); - int shift = countl_zero(significand); - FMT_ASSERT(shift >= num_bits() - - num_significand_bits(), - ""); - shift -= (num_bits() - num_significand_bits() - - 2); - exponent = (std::numeric_limits::min_exponent - - num_significand_bits()) - - shift; - significand <<= shift; - } - - // Compute the first several nonzero decimal significand digits. - // We call the number we get the first segment. - const int k = info::kappa - dragonbox::floor_log10_pow2(exponent); - exp = -k; - const int beta = exponent + dragonbox::floor_log2_pow10(k); - uint64_t first_segment; - bool has_more_segments; - int digits_in_the_first_segment; - { - const auto r = dragonbox::umul192_upper128( - significand << beta, dragonbox::get_cached_power(k)); - first_segment = r.high(); - has_more_segments = r.low() != 0; - - // The first segment can have 18 ~ 19 digits. - if (first_segment >= 1000000000000000000ULL) { - digits_in_the_first_segment = 19; - } else { - // When it is of 18-digits, we align it to 19-digits by adding a bogus - // zero at the end. - digits_in_the_first_segment = 18; - first_segment *= 10; - } - } - - // Compute the actual number of decimal digits to print. - if (fixed) - adjust_precision(precision, exp + digits_in_the_first_segment); - - // Use Dragon4 only when there might be not enough digits in the first - // segment. - if (digits_in_the_first_segment > precision) { - use_dragon = false; - - if (precision <= 0) { - exp += digits_in_the_first_segment; - - if (precision < 0) { - // Nothing to do, since all we have are just leading zeros. - buf.try_resize(0); - } else { - // We may need to round-up. - buf.try_resize(1); - if ((first_segment - | static_cast(has_more_segments)) - > 5000000000000000000ULL) { - buf[0] = '1'; - } else { - buf[0] = '0'; - } - } - } // precision <= 0 - else { - exp += digits_in_the_first_segment - precision; - - // When precision > 0, we divide the first segment into three - // subsegments, each with 9, 9, and 0 ~ 1 digits so that each fits - // in 32-bits which usually allows faster calculation than in - // 64-bits. Since some compiler (e.g. MSVC) doesn't know how to optimize - // division-by-constant for large 64-bit divisors, we do it here - // manually. The magic number 7922816251426433760 below is equal to - // ceil(2^(64+32) / 10^10). - const uint32_t first_subsegment = static_cast( - dragonbox::umul128_upper64( - first_segment, 7922816251426433760ULL) - >> 32); - const uint64_t second_third_subsegments - = first_segment - first_subsegment * 10000000000ULL; - - uint64_t prod; - uint32_t digits; - bool should_round_up; - int number_of_digits_to_print = precision > 9 ? 9 : precision; - - // Print a 9-digits subsegment, either the first or the second. - auto print_subsegment = [&](uint32_t subsegment, char *buffer) { - int number_of_digits_printed = 0; - - // If we want to print an odd number of digits from the subsegment, - if ((number_of_digits_to_print & 1) != 0) { - // Convert to 64-bit fixed-point fractional form with 1-digit - // integer part. The magic number 720575941 is a good enough - // approximation of 2^(32 + 24) / 10^8; see - // https://jk-jeon.github.io/posts/2022/12/fixed-precision-formatting/#fixed-length-case - // for details. 
- prod = ((subsegment * static_cast(720575941)) - >> 24) - + 1; - digits = static_cast(prod >> 32); - *buffer = static_cast('0' + digits); - number_of_digits_printed++; - } - // If we want to print an even number of digits from the - // first_subsegment, - else { - // Convert to 64-bit fixed-point fractional form with 2-digits - // integer part. The magic number 450359963 is a good enough - // approximation of 2^(32 + 20) / 10^7; see - // https://jk-jeon.github.io/posts/2022/12/fixed-precision-formatting/#fixed-length-case - // for details. - prod = ((subsegment * static_cast(450359963)) - >> 20) - + 1; - digits = static_cast(prod >> 32); - copy2(buffer, digits2(digits)); - number_of_digits_printed += 2; - } - - // Print all digit pairs. - while (number_of_digits_printed - < number_of_digits_to_print) { - prod = static_cast(prod) - * static_cast(100); - digits = static_cast(prod >> 32); - copy2(buffer + number_of_digits_printed, - digits2(digits)); - number_of_digits_printed += 2; - } - }; - - // Print first subsegment. - print_subsegment(first_subsegment, buf.data()); - - // Perform rounding if the first subsegment is the last subsegment to - // print. - if (precision <= 9) { - // Rounding inside the subsegment. - // We round-up if: - // - either the fractional part is strictly larger than 1/2, or - // - the fractional part is exactly 1/2 and the last digit is odd. - // We rely on the following observations: - // - If fractional_part >= threshold, then the fractional part is - // strictly larger than 1/2. - // - If the MSB of fractional_part is set, then the fractional part - // must be at least 1/2. - // - When the MSB of fractional_part is set, either - // second_third_subsegments being nonzero or has_more_segments - // being true means there are further digits not printed, so the - // fractional part is strictly larger than 1/2. - if (precision < 9) { - uint32_t fractional_part = static_cast(prod); - should_round_up = fractional_part - >= fractional_part_rounding_thresholds( - 8 - number_of_digits_to_print) - || ((fractional_part >> 31) - & ((digits & 1) - | (second_third_subsegments - != 0) - | has_more_segments)) - != 0; - } - // Rounding at the subsegment boundary. - // In this case, the fractional part is at least 1/2 if and only if - // second_third_subsegments >= 5000000000ULL, and is strictly larger - // than 1/2 if we further have either second_third_subsegments > - // 5000000000ULL or has_more_segments == true. - else { - should_round_up - = second_third_subsegments > 5000000000ULL - || (second_third_subsegments == 5000000000ULL - && ((digits & 1) != 0 - || has_more_segments)); - } - } - // Otherwise, print the second subsegment. - else { - // Compilers are not aware of how to leverage the maximum value of - // second_third_subsegments to find out a better magic number which - // allows us to eliminate an additional shift. 1844674407370955162 = - // ceil(2^64/10) < ceil(2^64*(10^9/(10^10 - 1))). - const uint32_t second_subsegment = static_cast( - dragonbox::umul128_upper64(second_third_subsegments, - 1844674407370955162ULL)); - const uint32_t third_subsegment - = static_cast(second_third_subsegments) - - second_subsegment * 10; - - number_of_digits_to_print = precision - 9; - print_subsegment(second_subsegment, buf.data() + 9); - - // Rounding inside the subsegment. - if (precision < 18) { - // The condition third_subsegment != 0 implies that the segment was - // of 19 digits, so in this case the third segment should be - // consisting of a genuine digit from the input. 
- uint32_t fractional_part = static_cast(prod); - should_round_up = fractional_part - >= fractional_part_rounding_thresholds( - 8 - number_of_digits_to_print) - || ((fractional_part >> 31) - & ((digits & 1) - | (third_subsegment != 0) - | has_more_segments)) - != 0; - } - // Rounding at the subsegment boundary. - else { - // In this case, the segment must be of 19 digits, thus - // the third subsegment should be consisting of a genuine digit from - // the input. - should_round_up = third_subsegment > 5 - || (third_subsegment == 5 - && ((digits & 1) != 0 - || has_more_segments)); - } - } - - // Round-up if necessary. - if (should_round_up) { - ++buf[precision - 1]; - for (int i = precision - 1; i > 0 && buf[i] > '9'; --i) { - buf[i] = '0'; - ++buf[i - 1]; - } - if (buf[0] > '9') { - buf[0] = '1'; - if (fixed) - buf[precision++] = '0'; - else - ++exp; - } - } - buf.try_resize(to_unsigned(precision)); - } - } // if (digits_in_the_first_segment > precision) - else { - // Adjust the exponent for its use in Dragon4. - exp += digits_in_the_first_segment - 1; - } - } - if (use_dragon) { - auto f = basic_fp(); - bool is_predecessor_closer = specs.binary32 - ? f.assign(static_cast(value)) - : f.assign(converted_value); - if (is_predecessor_closer) dragon_flags |= dragon::predecessor_closer; - if (fixed) dragon_flags |= dragon::fixed; - // Limit precision to the maximum possible number of significant digits in - // an IEEE754 double because we don't need to generate zeros. - const int max_double_digits = 767; - if (precision > max_double_digits) precision = max_double_digits; - format_dragon(f, dragon_flags, precision, buf, exp); - } - if (!fixed && !specs.showpoint) { - // Remove trailing zeros. - auto num_digits = buf.size(); - while (num_digits > 0 && buf[num_digits - 1] == '0') { - --num_digits; - ++exp; - } - buf.try_resize(num_digits); - } - return exp; -} -template -FMT_CONSTEXPR20 auto write_float(OutputIt out, T value, - format_specs specs, locale_ref loc) -> OutputIt { - float_specs fspecs = parse_float_type_spec(specs); - fspecs.sign = specs.sign; - if (detail::signbit(value)) { // value < 0 is false for NaN so use signbit. - fspecs.sign = sign::minus; - value = -value; - } else if (fspecs.sign == sign::minus) { - fspecs.sign = sign::none; - } - - if (!detail::isfinite(value)) - return write_nonfinite(out, detail::isnan(value), specs, fspecs); - - if (specs.align == align::numeric && fspecs.sign) { - auto it = reserve(out, 1); - *it++ = detail::sign(fspecs.sign); - out = base_iterator(out, it); - fspecs.sign = sign::none; - if (specs.width != 0) --specs.width; - } - - memory_buffer buffer; - if (fspecs.format == float_format::hex) { - if (fspecs.sign) buffer.push_back(detail::sign(fspecs.sign)); - format_hexfloat(convert_float(value), specs.precision, fspecs, buffer); - return write_bytes( - out, {buffer.data(), buffer.size()}, specs); - } - int precision - = specs.precision >= 0 || specs.type == presentation_type::none - ? 
specs.precision - : 6; - if (fspecs.format == float_format::exp) { - if (precision == max_value()) - throw_format_error("number is too big"); - else - ++precision; - } else if (fspecs.format != float_format::fixed && precision == 0) { - precision = 1; - } - if (const_check(std::is_same())) fspecs.binary32 = true; - int exp = format_float(convert_float(value), precision, fspecs, buffer); - fspecs.precision = precision; - auto f = big_decimal_fp { - buffer.data(), static_cast(buffer.size()), exp}; - return write_float(out, f, specs, fspecs, loc); -} - -template ::value)> -FMT_CONSTEXPR20 auto write(OutputIt out, T value, format_specs specs, - locale_ref loc = {}) -> OutputIt { - if (const_check(!is_supported_floating_point(value))) return out; - return specs.localized && write_loc(out, value, specs, loc) - ? out - : write_float(out, value, specs, loc); -} - -template ::value)> -FMT_CONSTEXPR20 auto write(OutputIt out, T value) -> OutputIt { - if (is_constant_evaluated()) return write(out, value, format_specs()); - if (const_check(!is_supported_floating_point(value))) return out; - - auto fspecs = float_specs(); - if (detail::signbit(value)) { - fspecs.sign = sign::minus; - value = -value; - } - - constexpr auto specs = format_specs(); - using floaty - = conditional_t::value, double, T>; - using floaty_uint = typename dragonbox::float_info::carrier_uint; - floaty_uint mask = exponent_mask(); - if ((bit_cast(value) & mask) == mask) - return write_nonfinite(out, std::isnan(value), specs, fspecs); - - auto dec = dragonbox::to_decimal(static_cast(value)); - return write_float(out, dec, specs, fspecs, {}); -} - -template ::value && !is_fast_float::value)> -inline auto write(OutputIt out, T value) -> OutputIt { - return write(out, value, format_specs()); -} - -template -auto write(OutputIt out, monostate, format_specs = {}, locale_ref = {}) - -> OutputIt { - FMT_ASSERT(false, ""); - return out; -} - -template -FMT_CONSTEXPR auto write(OutputIt out, basic_string_view value) - -> OutputIt { - auto it = reserve(out, value.size()); - it = copy_str_noinline(value.begin(), value.end(), it); - return base_iterator(out, it); -} - -template ::value)> -constexpr auto write(OutputIt out, const T &value) -> OutputIt { - return write(out, to_string_view(value)); -} - -// FMT_ENABLE_IF() condition separated to workaround an MSVC bug. -template ::value && !std::is_same::value - && mapped_type_constant>::value - != type::custom_type, - FMT_ENABLE_IF(check)> -FMT_CONSTEXPR auto write(OutputIt out, T value) -> OutputIt { - return write(out, static_cast>(value)); -} - -template ::value)> -FMT_CONSTEXPR auto write(OutputIt out, T value, - const format_specs &specs = {}, locale_ref = {}) -> OutputIt { - return specs.type != presentation_type::none - && specs.type != presentation_type::string - ? write(out, value ? 1 : 0, specs, {}) - : write_bytes(out, value ? "true" : "false", specs); -} - -template -FMT_CONSTEXPR auto write(OutputIt out, Char value) -> OutputIt { - auto it = reserve(out, 1); - *it++ = value; - return base_iterator(out, it); -} - -template -FMT_CONSTEXPR_CHAR_TRAITS auto write(OutputIt out, const Char *value) - -> OutputIt { - if (value) return write(out, basic_string_view(value)); - throw_format_error("string pointer is null"); - return out; -} - -template ::value)> -auto write(OutputIt out, const T *value, const format_specs &specs = {}, - locale_ref = {}) -> OutputIt { - return write_ptr(out, bit_cast(value), &specs); -} - -// A write overload that handles implicit conversions. 
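The deleted code above is fmt's float-formatting pipeline: a bare `{}` takes the Dragonbox shortest-round-trip branch, an explicit precision takes the fixed-precision branch (falling back to Dragon4 via format_dragon when the first digit segment runs out of digits), and the `a`/`A` presentation types route through format_hexfloat. A minimal sketch of how those branches are reached from the public API, assuming upstream fmtlib headers (this example is illustrative and not part of the diff):

#include <fmt/format.h>

int main() {
    auto shortest = fmt::format("{}", 0.1); // Dragonbox, shortest round-trip
    auto fixed = fmt::format("{:.10f}", 0.1); // fixed-precision branch
    auto hex = fmt::format("{:a}", 1.0); // format_hexfloat, yields "0x1p+0"
    fmt::print("{} {} {}\n", shortest, fixed, hex);
}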
-template > -FMT_CONSTEXPR auto write( - OutputIt out, const T &value) -> enable_if_t::value - && !is_string::value && !is_floating_point::value - && !std::is_same::value - && !std::is_same().map(value))>>::value, - OutputIt> { - return write(out, arg_mapper().map(value)); -} - -template > -FMT_CONSTEXPR auto write(OutputIt out, const T &value) - -> enable_if_t::value - == type::custom_type, - OutputIt> { - auto formatter = typename Context::template formatter_type(); - auto parse_ctx = typename Context::parse_context_type({}); - formatter.parse(parse_ctx); - auto ctx = Context(out, {}, {}); - return formatter.format(value, ctx); -} - -// An argument visitor that formats the argument and writes it via the output -// iterator. It's a class and not a generic lambda for compatibility with C++11. -template -struct default_arg_formatter { - using iterator = buffer_appender; - using context = buffer_context; - - iterator out; - basic_format_args args; - locale_ref loc; - - template - auto operator()(T value) -> iterator { - return write(out, value); - } - auto operator()(typename basic_format_arg::handle h) -> iterator { - basic_format_parse_context parse_ctx({}); - context format_ctx(out, args, loc); - h.format(parse_ctx, format_ctx); - return format_ctx.out(); - } -}; - -template -struct arg_formatter { - using iterator = buffer_appender; - using context = buffer_context; - - iterator out; - const format_specs &specs; - locale_ref locale; - - template - FMT_CONSTEXPR FMT_INLINE auto operator()(T value) -> iterator { - return detail::write(out, value, specs, locale); - } - auto operator()(typename basic_format_arg::handle) -> iterator { - // User-defined types are handled separately because they require access - // to the parse context. - return out; - } -}; - -struct width_checker { - template ::value)> - FMT_CONSTEXPR auto operator()(T value) -> unsigned long long { - if (is_negative(value)) throw_format_error("negative width"); - return static_cast(value); - } - - template ::value)> - FMT_CONSTEXPR auto operator()(T) -> unsigned long long { - throw_format_error("width is not integer"); - return 0; - } -}; - -struct precision_checker { - template ::value)> - FMT_CONSTEXPR auto operator()(T value) -> unsigned long long { - if (is_negative(value)) throw_format_error("negative precision"); - return static_cast(value); - } - - template ::value)> - FMT_CONSTEXPR auto operator()(T) -> unsigned long long { - throw_format_error("precision is not integer"); - return 0; - } -}; - -template -FMT_CONSTEXPR auto get_dynamic_spec(FormatArg arg) -> int { - unsigned long long value = visit_format_arg(Handler(), arg); - if (value > to_unsigned(max_value())) - throw_format_error("number is too big"); - return static_cast(value); -} - -template -FMT_CONSTEXPR auto get_arg(Context &ctx, ID id) -> decltype(ctx.arg(id)) { - auto arg = ctx.arg(id); - if (!arg) ctx.on_error("argument not found"); - return arg; -} - -template -FMT_CONSTEXPR void handle_dynamic_spec( - int &value, arg_ref ref, Context &ctx) { - switch (ref.kind) { - case arg_id_kind::none: break; - case arg_id_kind::index: - value = detail::get_dynamic_spec( - get_arg(ctx, ref.val.index)); - break; - case arg_id_kind::name: - value = detail::get_dynamic_spec( - get_arg(ctx, ref.val.name)); - break; - } -} - -#if FMT_USE_USER_DEFINED_LITERALS -#if FMT_USE_NONTYPE_TEMPLATE_ARGS -template Str> -struct statically_named_arg : view { - static constexpr auto name = Str.data; - - const T &value; - statically_named_arg(const T &v) : value(v) {} -}; - -template 
<typename T, typename Char, size_t N,
-        fmt::detail_exported::fixed_string<Char, N> Str>
-struct is_named_arg<statically_named_arg<T, Char, N, Str>> : std::true_type {};
-
-template <typename T, typename Char, size_t N,
-        fmt::detail_exported::fixed_string<Char, N> Str>
-struct is_statically_named_arg<statically_named_arg<T, Char, N, Str>>
-    : std::true_type {};
-
-template <typename Char, size_t N,
-        fmt::detail_exported::fixed_string<Char, N> Str>
-struct udl_arg {
-    template <typename T>
-    auto operator=(T &&value) const {
-        return statically_named_arg<T, Char, N, Str>(std::forward<T>(value));
-    }
-};
-#else
-template <typename Char>
-struct udl_arg {
-    const Char *str;
-
-    template <typename T>
-    auto operator=(T &&value) const -> named_arg<Char, T> {
-        return {str, std::forward<T>(value)};
-    }
-};
-#endif
-#endif // FMT_USE_USER_DEFINED_LITERALS
-
-template <typename Locale, typename Char>
-auto vformat(const Locale &loc, basic_string_view<Char> fmt,
-        basic_format_args<buffer_context<type_identity_t<Char>>> args)
-        -> std::basic_string<Char> {
-    auto buf = basic_memory_buffer<Char>();
-    detail::vformat_to(buf, fmt, args, detail::locale_ref(loc));
-    return {buf.data(), buf.size()};
-}
-
-using format_func = void (*)(detail::buffer<char> &, int, const char *);
-
-FMT_API void format_error_code(
-        buffer<char> &out, int error_code, string_view message) noexcept;
-
-FMT_API void report_error(
-        format_func func, int error_code, const char *message) noexcept;
-} // namespace detail
-
-FMT_API auto vsystem_error(int error_code, string_view format_str,
-        format_args args) -> std::system_error;
-
-/**
- \rst
- Constructs :class:`std::system_error` with a message formatted with
- ``fmt::format(fmt, args...)``.
- *error_code* is a system error code as given by ``errno``.
-
- **Example**::
-
-   // This throws std::system_error with the description
-   //   cannot open file 'madeup': No such file or directory
-   // or similar (system message may vary).
-   const char* filename = "madeup";
-   std::FILE* file = std::fopen(filename, "r");
-   if (!file)
-     throw fmt::system_error(errno, "cannot open file '{}'", filename);
- \endrst
- */
-template <typename... T>
-auto system_error(int error_code, format_string<T...> fmt, T &&...args)
-        -> std::system_error {
-    return vsystem_error(error_code, fmt, fmt::make_format_args(args...));
-}
-
-/**
- \rst
- Formats an error message for an error returned by an operating system or a
- language runtime, for example a file opening error, and writes it to *out*.
- The format is the same as the one used by ``std::system_error(ec, message)``
- where ``ec`` is ``std::error_code(error_code, std::generic_category())``.
- It is implementation-defined but normally looks like:
-
- .. parsed-literal::
-    *<message>*: *<system-message>*
-
- where *<message>* is the passed message and *<system-message>* is the system
- message corresponding to the error code.
- *error_code* is a system error code as given by ``errno``.
- \endrst
- */
-FMT_API void format_system_error(detail::buffer<char> &out, int error_code,
-        const char *message) noexcept;
-
-// Reports a system error without throwing an exception.
-// Can be used to report errors from destructors.
-FMT_API void report_system_error(int error_code, const char *message) noexcept;
-
-/** Fast integer formatter. */
-class format_int {
-private:
-    // Buffer should be large enough to hold all digits (digits10 + 1),
-    // a sign and a null character.
-    enum {
-        buffer_size = std::numeric_limits<unsigned long long>::digits10 + 3
-    };
-    mutable char buffer_[buffer_size];
-    char *str_;
-
-    template <typename UInt>
-    auto format_unsigned(UInt value) -> char * {
-        auto n = static_cast<detail::uint32_or_64_or_128_t<UInt>>(value);
-        return detail::format_decimal(buffer_, n, buffer_size - 1).begin;
-    }
-
-    template <typename Int>
-    auto format_signed(Int value) -> char * {
-        auto abs_value = static_cast<detail::uint32_or_64_or_128_t<Int>>(value);
-        bool negative = value < 0;
-        if (negative) abs_value = 0 - abs_value;
-        auto begin = format_unsigned(abs_value);
-        if (negative) *--begin = '-';
-        return begin;
-    }
-
-public:
-    explicit format_int(int value) : str_(format_signed(value)) {}
-    explicit format_int(long value) : str_(format_signed(value)) {}
-    explicit format_int(long long value) : str_(format_signed(value)) {}
-    explicit format_int(unsigned value) : str_(format_unsigned(value)) {}
-    explicit format_int(unsigned long value) : str_(format_unsigned(value)) {}
-    explicit format_int(unsigned long long value)
-        : str_(format_unsigned(value)) {}
-
-    /** Returns the number of characters written to the output buffer. */
-    auto size() const -> size_t {
-        return detail::to_unsigned(buffer_ - str_ + buffer_size - 1);
-    }
-
-    /**
-      Returns a pointer to the output buffer content. No terminating null
-      character is appended.
-     */
-    auto data() const -> const char * { return str_; }
-
-    /**
-      Returns a pointer to the output buffer content with terminating null
-      character appended.
-     */
-    auto c_str() const -> const char * {
-        buffer_[buffer_size - 1] = '\0';
-        return str_;
-    }
-
-    /**
-      \rst
-      Returns the content of the output buffer as an ``std::string``.
-      \endrst
-     */
-    auto str() const -> std::string { return std::string(str_, size()); }
-};
-
-template <typename T, typename Char>
-struct formatter<T, Char, enable_if_t<detail::has_format_as<T>::value>>
-    : formatter<detail::format_as_t<T>, Char> {
-    template <typename FormatContext>
-    auto format(const T &value, FormatContext &ctx) const
-            -> decltype(ctx.out()) {
-        using base = formatter<detail::format_as_t<T>, Char>;
-        return base::format(format_as(value), ctx);
-    }
-};
-
-#define FMT_FORMAT_AS(Type, Base) \
-    template <typename Char> \
-    struct formatter<Type, Char> : formatter<Base, Char> {}
-
-FMT_FORMAT_AS(signed char, int);
-FMT_FORMAT_AS(unsigned char, unsigned);
-FMT_FORMAT_AS(short, int);
-FMT_FORMAT_AS(unsigned short, unsigned);
-FMT_FORMAT_AS(long, detail::long_type);
-FMT_FORMAT_AS(unsigned long, detail::ulong_type);
-FMT_FORMAT_AS(Char *, const Char *);
-FMT_FORMAT_AS(std::basic_string<Char>, basic_string_view<Char>);
-FMT_FORMAT_AS(std::nullptr_t, const void *);
-FMT_FORMAT_AS(detail::std_string_view<Char>, basic_string_view<Char>);
-FMT_FORMAT_AS(void *, const void *);
-
-template <typename Char, size_t N>
-struct formatter<Char[N], Char> : formatter<basic_string_view<Char>, Char> {};
-
-/**
-  \rst
-  Converts ``p`` to ``const void*`` for pointer formatting.
-
-  **Example**::
-
-    auto s = fmt::format("{}", fmt::ptr(p));
-  \endrst
- */
-template <typename T>
-auto ptr(T p) -> const void * {
-    static_assert(std::is_pointer<T>::value, "");
-    return detail::bit_cast<const void *>(p);
-}
-template <typename T, typename Deleter>
-auto ptr(const std::unique_ptr<T, Deleter> &p) -> const void * {
-    return p.get();
-}
-template <typename T>
-auto ptr(const std::shared_ptr<T> &p) -> const void * {
-    return p.get();
-}
-
-/**
-  \rst
-  Converts ``e`` to the underlying type.
- - **Example**:: - - enum class color { red, green, blue }; - auto s = fmt::format("{}", fmt::underlying(color::red)); - \endrst - */ -template -constexpr auto underlying(Enum e) noexcept -> underlying_t { - return static_cast>(e); -} - -namespace enums { -template ::value)> -constexpr auto format_as(Enum e) noexcept -> underlying_t { - return static_cast>(e); -} -} // namespace enums - -class bytes { -private: - string_view data_; - friend struct formatter; - -public: - explicit bytes(string_view data) : data_(data) {} -}; - -template <> -struct formatter { -private: - detail::dynamic_format_specs<> specs_; - -public: - template - FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { - return parse_format_specs( - ctx.begin(), ctx.end(), specs_, ctx, detail::type::string_type); - } - - template - auto format(bytes b, FormatContext &ctx) -> decltype(ctx.out()) { - detail::handle_dynamic_spec( - specs_.width, specs_.width_ref, ctx); - detail::handle_dynamic_spec( - specs_.precision, specs_.precision_ref, ctx); - return detail::write_bytes(ctx.out(), b.data_, specs_); - } -}; - -// group_digits_view is not derived from view because it copies the argument. -template -struct group_digits_view { - T value; -}; - -/** - \rst - Returns a view that formats an integer value using ',' as a locale-independent - thousands separator. - - **Example**:: - - fmt::print("{}", fmt::group_digits(12345)); - // Output: "12,345" - \endrst - */ -template -auto group_digits(T value) -> group_digits_view { - return {value}; -} - -template -struct formatter> : formatter { -private: - detail::dynamic_format_specs<> specs_; - -public: - template - FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { - return parse_format_specs( - ctx.begin(), ctx.end(), specs_, ctx, detail::type::int_type); - } - - template - auto format(group_digits_view t, FormatContext &ctx) - -> decltype(ctx.out()) { - detail::handle_dynamic_spec( - specs_.width, specs_.width_ref, ctx); - detail::handle_dynamic_spec( - specs_.precision, specs_.precision_ref, ctx); - return detail::write_int(ctx.out(), - static_cast>(t.value), 0, specs_, - detail::digit_grouping("\3", ",")); - } -}; - -template -struct nested_view { - const formatter *fmt; - const T *value; -}; - -template -struct formatter> { - FMT_CONSTEXPR auto parse(format_parse_context &ctx) -> const char * { - return ctx.begin(); - } - auto format(nested_view view, format_context &ctx) const - -> decltype(ctx.out()) { - return view.fmt->format(*view.value, ctx); - } -}; - -template -struct nested_formatter { -private: - int width_; - detail::fill_t fill_; - align_t align_ : 4; - formatter formatter_; - -public: - constexpr nested_formatter() : width_(0), align_(align_t::none) {} - - FMT_CONSTEXPR auto parse(format_parse_context &ctx) -> const char * { - auto specs = detail::dynamic_format_specs(); - auto it = parse_format_specs( - ctx.begin(), ctx.end(), specs, ctx, detail::type::none_type); - width_ = specs.width; - fill_ = specs.fill; - align_ = specs.align; - ctx.advance_to(it); - return formatter_.parse(ctx); - } - - template - auto write_padded(format_context &ctx, F write) const - -> decltype(ctx.out()) { - if (width_ == 0) return write(ctx.out()); - auto buf = memory_buffer(); - write(std::back_inserter(buf)); - auto specs = format_specs<>(); - specs.width = width_; - specs.fill = fill_; - specs.align = align_; - return detail::write( - ctx.out(), string_view(buf.data(), buf.size()), specs); - } - - auto nested(const T &value) const -> nested_view { - return 
nested_view<T> {&formatter_, &value};
-    }
-};
-
-// DEPRECATED! join_view will be moved to ranges.h.
-template <typename It, typename Sentinel, typename Char = char>
-struct join_view : detail::view {
-    It begin;
-    Sentinel end;
-    basic_string_view<Char> sep;
-
-    join_view(It b, Sentinel e, basic_string_view<Char> s)
-        : begin(b), end(e), sep(s) {}
-};
-
-template <typename It, typename Sentinel, typename Char>
-struct formatter<join_view<It, Sentinel, Char>, Char> {
-private:
-    using value_type =
-#ifdef __cpp_lib_ranges
-            std::iter_value_t<It>;
-#else
-            typename std::iterator_traits<It>::value_type;
-#endif
-    formatter<remove_cvref_t<value_type>, Char> value_formatter_;
-
-public:
-    template <typename ParseContext>
-    FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const Char * {
-        return value_formatter_.parse(ctx);
-    }
-
-    template <typename FormatContext>
-    auto format(const join_view<It, Sentinel, Char> &value,
-            FormatContext &ctx) const -> decltype(ctx.out()) {
-        auto it = value.begin;
-        auto out = ctx.out();
-        if (it != value.end) {
-            out = value_formatter_.format(*it, ctx);
-            ++it;
-            while (it != value.end) {
-                out = detail::copy_str<Char>(
-                        value.sep.begin(), value.sep.end(), out);
-                ctx.advance_to(out);
-                out = value_formatter_.format(*it, ctx);
-                ++it;
-            }
-        }
-        return out;
-    }
-};
-
-/**
-  Returns a view that formats the iterator range `[begin, end)` with elements
-  separated by `sep`.
- */
-template <typename It, typename Sentinel>
-auto join(It begin, Sentinel end, string_view sep) -> join_view<It, Sentinel> {
-    return {begin, end, sep};
-}
-
-/**
-  \rst
-  Returns a view that formats `range` with elements separated by `sep`.
-
-  **Example**::
-
-    std::vector<int> v = {1, 2, 3};
-    fmt::print("{}", fmt::join(v, ", "));
-    // Output: "1, 2, 3"
-
-  ``fmt::join`` applies passed format specifiers to the range elements::
-
-    fmt::print("{:02}", fmt::join(v, ", "));
-    // Output: "01, 02, 03"
-  \endrst
- */
-template <typename Range>
-auto join(Range &&range, string_view sep)
-        -> join_view<detail::iterator_t<Range>, detail::sentinel_t<Range>> {
-    return join(std::begin(range), std::end(range), sep);
-}
-
-/**
-  \rst
-  Converts *value* to ``std::string`` using the default format for type *T*.
-
-  **Example**::
-
-    #include <fmt/format.h>
-
-    std::string answer = fmt::to_string(42);
-  \endrst
- */
-template <typename T,
-        FMT_ENABLE_IF(!std::is_integral<T>::value
-                && !detail::has_format_as<T>::value)>
-inline auto to_string(const T &value) -> std::string {
-    auto buffer = memory_buffer();
-    detail::write<char>(appender(buffer), value);
-    return {buffer.data(), buffer.size()};
-}
-
-template <typename T, FMT_ENABLE_IF(std::is_integral<T>::value)>
-FMT_NODISCARD inline auto to_string(T value) -> std::string {
-    // The buffer should be large enough to store the number including the sign
-    // or "false" for bool.
-    constexpr int max_size = detail::digits10<T>() + 2;
-    char buffer[max_size > 5 ?
static_cast(max_size) : 5]; - char *begin = buffer; - return std::string(begin, detail::write(begin, value)); -} - -template -FMT_NODISCARD auto to_string(const basic_memory_buffer &buf) - -> std::basic_string { - auto size = buf.size(); - detail::assume(size < std::basic_string().max_size()); - return std::basic_string(buf.data(), size); -} - -template ::value && detail::has_format_as::value)> -inline auto to_string(const T &value) -> std::string { - return to_string(format_as(value)); -} - -FMT_END_EXPORT - -namespace detail { - -template -void vformat_to(buffer &buf, basic_string_view fmt, - typename vformat_args::type args, locale_ref loc) { - auto out = buffer_appender(buf); - if (fmt.size() == 2 && equal2(fmt.data(), "{}")) { - auto arg = args.get(0); - if (!arg) throw_format_error("argument not found"); - visit_format_arg(default_arg_formatter {out, args, loc}, arg); - return; - } - - struct format_handler : error_handler { - basic_format_parse_context parse_context; - buffer_context context; - - format_handler(buffer_appender p_out, basic_string_view str, - basic_format_args> p_args, - locale_ref p_loc) - : parse_context(str), context(p_out, p_args, p_loc) {} - - void on_text(const Char *begin, const Char *end) { - auto text - = basic_string_view(begin, to_unsigned(end - begin)); - context.advance_to(write(context.out(), text)); - } - - FMT_CONSTEXPR auto on_arg_id() -> int { - return parse_context.next_arg_id(); - } - FMT_CONSTEXPR auto on_arg_id(int id) -> int { - return parse_context.check_arg_id(id), id; - } - FMT_CONSTEXPR auto on_arg_id(basic_string_view id) -> int { - int arg_id = context.arg_id(id); - if (arg_id < 0) throw_format_error("argument not found"); - return arg_id; - } - - FMT_INLINE void on_replacement_field(int id, const Char *) { - auto arg = get_arg(context, id); - context.advance_to( - visit_format_arg(default_arg_formatter {context.out(), - context.args(), context.locale()}, - arg)); - } - - auto on_format_specs(int id, const Char *begin, const Char *end) - -> const Char * { - auto arg = get_arg(context, id); - // Not using a visitor for custom types gives better codegen. - if (arg.format_custom(begin, parse_context, context)) - return parse_context.begin(); - auto specs = detail::dynamic_format_specs(); - begin = parse_format_specs( - begin, end, specs, parse_context, arg.type()); - detail::handle_dynamic_spec( - specs.width, specs.width_ref, context); - detail::handle_dynamic_spec( - specs.precision, specs.precision_ref, context); - if (begin == end || *begin != '}') - throw_format_error("missing '}' in format string"); - auto f = arg_formatter { - context.out(), specs, context.locale()}; - context.advance_to(visit_format_arg(f, arg)); - return begin; - } - }; - detail::parse_format_string( - fmt, format_handler(out, fmt, args, loc)); -} - -FMT_BEGIN_EXPORT - -#ifndef FMT_HEADER_ONLY -extern template FMT_API void vformat_to( - buffer &, string_view, typename vformat_args<>::type, locale_ref); -extern template FMT_API auto thousands_sep_impl(locale_ref) - -> thousands_sep_result; -extern template FMT_API auto thousands_sep_impl(locale_ref) - -> thousands_sep_result; -extern template FMT_API auto decimal_point_impl(locale_ref) -> char; -extern template FMT_API auto decimal_point_impl(locale_ref) -> wchar_t; -#endif // FMT_HEADER_ONLY - -} // namespace detail - -#if FMT_USE_USER_DEFINED_LITERALS -inline namespace literals { -/** - \rst - User-defined literal equivalent of :func:`fmt::arg`. 
-
- **Example**::
-
-   using namespace fmt::literals;
-   fmt::print("Elapsed time: {s:.2f} seconds", "s"_a=1.23);
- \endrst
- */
-#if FMT_USE_NONTYPE_TEMPLATE_ARGS
-template <detail_exported::fixed_string Str>
-constexpr auto operator""_a() {
-    using char_t = remove_cvref_t<decltype(Str.data[0])>;
-    return detail::udl_arg<char_t, sizeof(Str.data) / sizeof(char_t), Str>();
-}
-#else
-constexpr auto operator""_a(const char *s, size_t) -> detail::udl_arg<char> {
-    return {s};
-}
-#endif
-} // namespace literals
-#endif // FMT_USE_USER_DEFINED_LITERALS
-
-template <typename Locale, FMT_ENABLE_IF(detail::is_locale<Locale>::value)>
-inline auto vformat(const Locale &loc, string_view fmt, format_args args)
-        -> std::string {
-    return detail::vformat(loc, fmt, args);
-}
-
-template <typename Locale, typename... T,
-        FMT_ENABLE_IF(detail::is_locale<Locale>::value)>
-inline auto format(const Locale &loc, format_string<T...> fmt, T &&...args)
-        -> std::string {
-    return fmt::vformat(loc, string_view(fmt), fmt::make_format_args(args...));
-}
-
-template <typename OutputIt, typename Locale,
-        FMT_ENABLE_IF(detail::is_output_iterator<OutputIt, char>::value
-                &&detail::is_locale<Locale>::value)>
-auto vformat_to(OutputIt out, const Locale &loc, string_view fmt,
-        format_args args) -> OutputIt {
-    using detail::get_buffer;
-    auto &&buf = get_buffer<char>(out);
-    detail::vformat_to(buf, fmt, args, detail::locale_ref(loc));
-    return detail::get_iterator(buf, out);
-}
-
-template <typename OutputIt, typename Locale, typename... T,
-        FMT_ENABLE_IF(detail::is_output_iterator<OutputIt, char>::value
-                &&detail::is_locale<Locale>::value)>
-FMT_INLINE auto format_to(OutputIt out, const Locale &loc,
-        format_string<T...> fmt, T &&...args) -> OutputIt {
-    return vformat_to(out, loc, fmt, fmt::make_format_args(args...));
-}
-
-template <typename Locale, typename... T,
-        FMT_ENABLE_IF(detail::is_locale<Locale>::value)>
-FMT_NODISCARD FMT_INLINE auto formatted_size(
-        const Locale &loc, format_string<T...> fmt, T &&...args) -> size_t {
-    auto buf = detail::counting_buffer<>();
-    detail::vformat_to<char>(
-            buf, fmt, fmt::make_format_args(args...), detail::locale_ref(loc));
-    return buf.count();
-}
-
-FMT_END_EXPORT
-
-template <typename T, typename Char>
-template <typename FormatContext>
-FMT_CONSTEXPR FMT_INLINE auto formatter<T, Char,
-        enable_if_t<detail::type_constant<T, Char>::value
-                != detail::type::custom_type>>::format(const T &val,
-        FormatContext &ctx) const -> decltype(ctx.out()) {
-    if (specs_.width_ref.kind == detail::arg_id_kind::none
-            && specs_.precision_ref.kind == detail::arg_id_kind::none) {
-        return detail::write<Char>(ctx.out(), val, specs_, ctx.locale());
-    }
-    auto specs = specs_;
-    detail::handle_dynamic_spec<detail::width_checker>(
-            specs.width, specs.width_ref, ctx);
-    detail::handle_dynamic_spec<detail::precision_checker>(
-            specs.precision, specs.precision_ref, ctx);
-    return detail::write<Char>(ctx.out(), val, specs, ctx.locale());
-}
-
-FMT_END_NAMESPACE
-
-#ifdef FMT_HEADER_ONLY
-#define FMT_FUNC inline
-#include "format-inl.h"
-#else
-#define FMT_FUNC
-#endif
-
-#endif // FMT_FORMAT_H_
diff --git a/src/common/spdlog/fmt/fmt.h b/src/common/spdlog/fmt/fmt.h
deleted file mode 100755
index 426251ea4e1..00000000000
--- a/src/common/spdlog/fmt/fmt.h
+++ /dev/null
@@ -1,31 +0,0 @@
-//
-// Copyright(c) 2016-2018 Gabi Melman.
-// Distributed under the MIT License (http://opensource.org/licenses/MIT)
-//
-
-#pragma once
-
-//
-// Include a bundled header-only copy of fmtlib or an external one.
-// By default spdlog includes its own copy.
-//
-#include <spdlog/tweakme.h>
-
-#if defined( \
-        SPDLOG_USE_STD_FORMAT) // SPDLOG_USE_STD_FORMAT is defined - use std::format
-#include <format>
-#elif !defined(SPDLOG_FMT_EXTERNAL)
-#if !defined(SPDLOG_COMPILED_LIB) && !defined(FMT_HEADER_ONLY)
-#define FMT_HEADER_ONLY
-#endif
-#ifndef FMT_USE_WINDOWS_H
-#define FMT_USE_WINDOWS_H 0
-#endif
-
-#include <spdlog/fmt/bundled/core.h>
-#include <spdlog/fmt/bundled/format.h>
-
-#else // SPDLOG_FMT_EXTERNAL is defined - use external fmtlib
-#include <fmt/core.h>
-#include <fmt/format.h>
-#endif
diff --git a/src/common/spdlog/logger-inl.h b/src/common/spdlog/logger-inl.h
deleted file mode 100755
index 08e52ad0e27..00000000000
--- a/src/common/spdlog/logger-inl.h
+++ /dev/null
@@ -1,215 +0,0 @@
-// Copyright(c) 2015-present, Gabi Melman & spdlog contributors.
-// Distributed under the MIT License (http://opensource.org/licenses/MIT)
-
-#pragma once
-
-#ifndef SPDLOG_HEADER_ONLY
-#include <spdlog/logger.h>
-#endif
-
-#include <spdlog/sinks/sink.h>
-#include <spdlog/details/backtracer.h>
-#include <spdlog/pattern_formatter.h>
-
-#include <cstdio>
-
-namespace spdlog {
-
-// public methods
-SPDLOG_INLINE logger::logger(const logger &other)
-    : name_(other.name_)
-    , sinks_(other.sinks_)
-    , level_(other.level_.load(std::memory_order_relaxed))
-    , flush_level_(other.flush_level_.load(std::memory_order_relaxed))
-    , custom_err_handler_(other.custom_err_handler_)
-    , tracer_(other.tracer_) {}
-
-SPDLOG_INLINE logger::logger(logger &&other) SPDLOG_NOEXCEPT
-    : name_(std::move(other.name_)),
-      sinks_(std::move(other.sinks_)),
-      level_(other.level_.load(std::memory_order_relaxed)),
-      flush_level_(other.flush_level_.load(std::memory_order_relaxed)),
-      custom_err_handler_(std::move(other.custom_err_handler_)),
-      tracer_(std::move(other.tracer_))
-
-{}
-
-SPDLOG_INLINE logger &logger::operator=(logger other) SPDLOG_NOEXCEPT {
-    this->swap(other);
-    return *this;
-}
-
-SPDLOG_INLINE void logger::swap(spdlog::logger &other) SPDLOG_NOEXCEPT {
-    name_.swap(other.name_);
-    sinks_.swap(other.sinks_);
-
-    // swap level_
-    auto other_level = other.level_.load();
-    auto my_level = level_.exchange(other_level);
-    other.level_.store(my_level);
-
-    // swap flush level_
-    other_level = other.flush_level_.load();
-    my_level = flush_level_.exchange(other_level);
-    other.flush_level_.store(my_level);
-
-    custom_err_handler_.swap(other.custom_err_handler_);
-    std::swap(tracer_, other.tracer_);
-}
-
-SPDLOG_INLINE void swap(logger &a, logger &b) {
-    a.swap(b);
-}
-
-SPDLOG_INLINE void logger::set_level(level::level_enum log_level) {
-    level_.store(log_level);
-}
-
-SPDLOG_INLINE level::level_enum logger::level() const {
-    return static_cast<level::level_enum>(
-            level_.load(std::memory_order_relaxed));
-}
-
-SPDLOG_INLINE const std::string &logger::name() const {
-    return name_;
-}
-
-// set formatting for the sinks in this logger.
-// each sink will get a separate instance of the formatter object.
-SPDLOG_INLINE void logger::set_formatter(std::unique_ptr<formatter> f) {
-    for (auto it = sinks_.begin(); it != sinks_.end(); ++it) {
-        if (std::next(it) == sinks_.end()) {
-            // last element - we can move it.
- (*it)->set_formatter(std::move(f)); - break; // to prevent clang-tidy warning - } else { - (*it)->set_formatter(f->clone()); - } - } -} - -SPDLOG_INLINE void logger::set_pattern( - std::string pattern, pattern_time_type time_type) { - auto new_formatter = details::make_unique( - std::move(pattern), time_type); - set_formatter(std::move(new_formatter)); -} - -// create new backtrace sink and move to it all our child sinks -SPDLOG_INLINE void logger::enable_backtrace(size_t n_messages) { - tracer_.enable(n_messages); -} - -// restore orig sinks and level and delete the backtrace sink -SPDLOG_INLINE void logger::disable_backtrace() { - tracer_.disable(); -} - -SPDLOG_INLINE void logger::dump_backtrace() { - dump_backtrace_(); -} - -// flush functions -SPDLOG_INLINE void logger::flush() { - flush_(); -} - -SPDLOG_INLINE void logger::flush_on(level::level_enum log_level) { - flush_level_.store(log_level); -} - -SPDLOG_INLINE level::level_enum logger::flush_level() const { - return static_cast( - flush_level_.load(std::memory_order_relaxed)); -} - -// sinks -SPDLOG_INLINE const std::vector &logger::sinks() const { - return sinks_; -} - -SPDLOG_INLINE std::vector &logger::sinks() { - return sinks_; -} - -// error handler -SPDLOG_INLINE void logger::set_error_handler(err_handler handler) { - custom_err_handler_ = std::move(handler); -} - -// create new logger with same sinks and configuration. -SPDLOG_INLINE std::shared_ptr logger::clone(std::string logger_name) { - auto cloned = std::make_shared(*this); - cloned->name_ = std::move(logger_name); - return cloned; -} - -// protected methods -SPDLOG_INLINE void logger::log_it_(const spdlog::details::log_msg &log_msg, - bool log_enabled, bool traceback_enabled) { - if (log_enabled) { sink_it_(log_msg); } - if (traceback_enabled) { tracer_.push_back(log_msg); } -} - -SPDLOG_INLINE void logger::sink_it_(const details::log_msg &msg) { - for (auto &sink : sinks_) { - if (sink->should_log(msg.level)) { - SPDLOG_TRY { sink->log(msg); } - SPDLOG_LOGGER_CATCH(msg.source) - } - } - - if (should_flush_(msg)) { flush_(); } -} - -SPDLOG_INLINE void logger::flush_() { - for (auto &sink : sinks_) { - SPDLOG_TRY { sink->flush(); } - SPDLOG_LOGGER_CATCH(source_loc()) - } -} - -SPDLOG_INLINE void logger::dump_backtrace_() { - using details::log_msg; - if (tracer_.enabled() && !tracer_.empty()) { - sink_it_(log_msg {name(), level::info, - "****************** Backtrace Start ******************"}); - tracer_.foreach_pop( - [this](const log_msg &msg) { this->sink_it_(msg); }); - sink_it_(log_msg {name(), level::info, - "****************** Backtrace End ********************"}); - } -} - -SPDLOG_INLINE bool logger::should_flush_(const details::log_msg &msg) { - auto flush_level = flush_level_.load(std::memory_order_relaxed); - return (msg.level >= flush_level) && (msg.level != level::off); -} - -SPDLOG_INLINE void logger::err_handler_(const std::string &msg) { - if (custom_err_handler_) { - custom_err_handler_(msg); - } else { - using std::chrono::system_clock; - static std::mutex mutex; - static std::chrono::system_clock::time_point last_report_time; - static size_t err_counter = 0; - std::lock_guard lk {mutex}; - auto now = system_clock::now(); - err_counter++; - if (now - last_report_time < std::chrono::seconds(1)) { return; } - last_report_time = now; - auto tm_time = details::os::localtime(system_clock::to_time_t(now)); - char date_buf[64]; - std::strftime( - date_buf, sizeof(date_buf), "%Y-%m-%d %H:%M:%S", &tm_time); -#if defined(USING_R) && defined(R_R_H) // if in 
R environment - REprintf("[*** LOG ERROR #%04zu ***] [%s] [%s] %s\n", err_counter, - date_buf, name().c_str(), msg.c_str()); -#else - std::fprintf(stderr, "[*** LOG ERROR #%04zu ***] [%s] [%s] %s\n", - err_counter, date_buf, name().c_str(), msg.c_str()); -#endif - } -} -} // namespace spdlog diff --git a/src/common/spdlog/logger.h b/src/common/spdlog/logger.h deleted file mode 100755 index 4de596385fe..00000000000 --- a/src/common/spdlog/logger.h +++ /dev/null @@ -1,386 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -// Thread safe logger (except for set_error_handler()) -// Has name, log level, vector of std::shared sink pointers and formatter -// Upon each log write the logger: -// 1. Checks if its log level is enough to log the message and if yes: -// 2. Call the underlying sinks to do the job. -// 3. Each sink use its own private copy of a formatter to format the message -// and send to its destination. -// -// The use of private formatter per sink provides the opportunity to cache some -// formatted data, and support for different format per sink. - -#include -#include -#include - -#ifdef SPDLOG_WCHAR_TO_UTF8_SUPPORT -#ifndef _WIN32 -#error SPDLOG_WCHAR_TO_UTF8_SUPPORT only supported on windows -#endif -#include -#endif - -#include - -#ifndef SPDLOG_NO_EXCEPTIONS -#define SPDLOG_LOGGER_CATCH(location) \ - catch (const std::exception &ex) { \ - if (location.filename) { \ - err_handler_(fmt_lib::format(SPDLOG_FMT_STRING("{} [{}({})]"), \ - ex.what(), location.filename, location.line)); \ - } else { \ - err_handler_(ex.what()); \ - } \ - } \ - catch (...) { \ - err_handler_("Rethrowing unknown exception in logger"); \ - throw; \ - } -#else -#define SPDLOG_LOGGER_CATCH(location) -#endif - -namespace spdlog { - -class SPDLOG_API logger { -public: - // Empty logger - explicit logger(std::string name) : name_(std::move(name)), sinks_() {} - - // Logger with range on sinks - template - logger(std::string name, It begin, It end) - : name_(std::move(name)), sinks_(begin, end) {} - - // Logger with single sink - logger(std::string name, sink_ptr single_sink) - : logger(std::move(name), {std::move(single_sink)}) {} - - // Logger with sinks init list - logger(std::string name, sinks_init_list sinks) - : logger(std::move(name), sinks.begin(), sinks.end()) {} - - virtual ~logger() = default; - - logger(const logger &other); - logger(logger &&other) SPDLOG_NOEXCEPT; - logger &operator=(logger other) SPDLOG_NOEXCEPT; - void swap(spdlog::logger &other) SPDLOG_NOEXCEPT; - - template - void log(source_loc loc, level::level_enum lvl, - format_string_t fmt, Args &&...args) { - log_(loc, lvl, details::to_string_view(fmt), - std::forward(args)...); - } - - template - void log(level::level_enum lvl, format_string_t fmt, - Args &&...args) { - log(source_loc {}, lvl, fmt, std::forward(args)...); - } - - template - void log(level::level_enum lvl, const T &msg) { - log(source_loc {}, lvl, msg); - } - - // T cannot be statically converted to format string (including string_view/wstring_view) - template ::value, - int>::type - = 0> - void log(source_loc loc, level::level_enum lvl, const T &msg) { - log(loc, lvl, "{}", msg); - } - - void log(log_clock::time_point log_time, source_loc loc, - level::level_enum lvl, string_view_t msg) { - bool log_enabled = should_log(lvl); - bool traceback_enabled = tracer_.enabled(); - if (!log_enabled && !traceback_enabled) { return; } - - details::log_msg 
log_msg(log_time, loc, name_, lvl, msg); - log_it_(log_msg, log_enabled, traceback_enabled); - } - - void log(source_loc loc, level::level_enum lvl, string_view_t msg) { - bool log_enabled = should_log(lvl); - bool traceback_enabled = tracer_.enabled(); - if (!log_enabled && !traceback_enabled) { return; } - - details::log_msg log_msg(loc, name_, lvl, msg); - log_it_(log_msg, log_enabled, traceback_enabled); - } - - void log(level::level_enum lvl, string_view_t msg) { - log(source_loc {}, lvl, msg); - } - - template - void trace(format_string_t fmt, Args &&...args) { - log(level::trace, fmt, std::forward(args)...); - } - - template - void debug(format_string_t fmt, Args &&...args) { - log(level::debug, fmt, std::forward(args)...); - } - - template - void info(format_string_t fmt, Args &&...args) { - log(level::info, fmt, std::forward(args)...); - } - - template - void warn(format_string_t fmt, Args &&...args) { - log(level::warn, fmt, std::forward(args)...); - } - - template - void error(format_string_t fmt, Args &&...args) { - log(level::err, fmt, std::forward(args)...); - } - - template - void critical(format_string_t fmt, Args &&...args) { - log(level::critical, fmt, std::forward(args)...); - } - -#ifdef SPDLOG_WCHAR_TO_UTF8_SUPPORT - template - void log(source_loc loc, level::level_enum lvl, - wformat_string_t fmt, Args &&...args) { - log_(loc, lvl, details::to_string_view(fmt), - std::forward(args)...); - } - - template - void log(level::level_enum lvl, wformat_string_t fmt, - Args &&...args) { - log(source_loc {}, lvl, fmt, std::forward(args)...); - } - - void log(log_clock::time_point log_time, source_loc loc, - level::level_enum lvl, wstring_view_t msg) { - bool log_enabled = should_log(lvl); - bool traceback_enabled = tracer_.enabled(); - if (!log_enabled && !traceback_enabled) { return; } - - memory_buf_t buf; - details::os::wstr_to_utf8buf( - wstring_view_t(msg.data(), msg.size()), buf); - details::log_msg log_msg(log_time, loc, name_, lvl, - string_view_t(buf.data(), buf.size())); - log_it_(log_msg, log_enabled, traceback_enabled); - } - - void log(source_loc loc, level::level_enum lvl, wstring_view_t msg) { - bool log_enabled = should_log(lvl); - bool traceback_enabled = tracer_.enabled(); - if (!log_enabled && !traceback_enabled) { return; } - - memory_buf_t buf; - details::os::wstr_to_utf8buf( - wstring_view_t(msg.data(), msg.size()), buf); - details::log_msg log_msg( - loc, name_, lvl, string_view_t(buf.data(), buf.size())); - log_it_(log_msg, log_enabled, traceback_enabled); - } - - void log(level::level_enum lvl, wstring_view_t msg) { - log(source_loc {}, lvl, msg); - } - - template - void trace(wformat_string_t fmt, Args &&...args) { - log(level::trace, fmt, std::forward(args)...); - } - - template - void debug(wformat_string_t fmt, Args &&...args) { - log(level::debug, fmt, std::forward(args)...); - } - - template - void info(wformat_string_t fmt, Args &&...args) { - log(level::info, fmt, std::forward(args)...); - } - - template - void warn(wformat_string_t fmt, Args &&...args) { - log(level::warn, fmt, std::forward(args)...); - } - - template - void error(wformat_string_t fmt, Args &&...args) { - log(level::err, fmt, std::forward(args)...); - } - - template - void critical(wformat_string_t fmt, Args &&...args) { - log(level::critical, fmt, std::forward(args)...); - } -#endif - - template - void trace(const T &msg) { - log(level::trace, msg); - } - - template - void debug(const T &msg) { - log(level::debug, msg); - } - - template - void info(const T &msg) { - 
log(level::info, msg); - } - - template - void warn(const T &msg) { - log(level::warn, msg); - } - - template - void error(const T &msg) { - log(level::err, msg); - } - - template - void critical(const T &msg) { - log(level::critical, msg); - } - - // return true logging is enabled for the given level. - bool should_log(level::level_enum msg_level) const { - return msg_level >= level_.load(std::memory_order_relaxed); - } - - // return true if backtrace logging is enabled. - bool should_backtrace() const { return tracer_.enabled(); } - - void set_level(level::level_enum log_level); - - level::level_enum level() const; - - const std::string &name() const; - - // set formatting for the sinks in this logger. - // each sink will get a separate instance of the formatter object. - void set_formatter(std::unique_ptr f); - - // set formatting for the sinks in this logger. - // equivalent to - // set_formatter(make_unique(pattern, time_type)) - // Note: each sink will get a new instance of a formatter object, replacing the old one. - void set_pattern(std::string pattern, - pattern_time_type time_type = pattern_time_type::local); - - // backtrace support. - // efficiently store all debug/trace messages in a circular buffer until needed for debugging. - void enable_backtrace(size_t n_messages); - void disable_backtrace(); - void dump_backtrace(); - - // flush functions - void flush(); - void flush_on(level::level_enum log_level); - level::level_enum flush_level() const; - - // sinks - const std::vector &sinks() const; - - std::vector &sinks(); - - // error handler - void set_error_handler(err_handler); - - // create new logger with same sinks and configuration. - virtual std::shared_ptr clone(std::string logger_name); - -protected: - std::string name_; - std::vector sinks_; - spdlog::level_t level_ {level::info}; - spdlog::level_t flush_level_ {level::off}; - err_handler custom_err_handler_ {nullptr}; - details::backtracer tracer_; - - // common implementation for after templated public api has been resolved - template - void log_(source_loc loc, level::level_enum lvl, string_view_t fmt, - Args &&...args) { - bool log_enabled = should_log(lvl); - bool traceback_enabled = tracer_.enabled(); - if (!log_enabled && !traceback_enabled) { return; } - SPDLOG_TRY { - memory_buf_t buf; -#ifdef SPDLOG_USE_STD_FORMAT - fmt_lib::vformat_to(std::back_inserter(buf), fmt, - fmt_lib::make_format_args(args...)); -#else - fmt::vformat_to( - fmt::appender(buf), fmt, fmt::make_format_args(args...)); -#endif - - details::log_msg log_msg( - loc, name_, lvl, string_view_t(buf.data(), buf.size())); - log_it_(log_msg, log_enabled, traceback_enabled); - } - SPDLOG_LOGGER_CATCH(loc) - } - -#ifdef SPDLOG_WCHAR_TO_UTF8_SUPPORT - template - void log_(source_loc loc, level::level_enum lvl, wstring_view_t fmt, - Args &&...args) { - bool log_enabled = should_log(lvl); - bool traceback_enabled = tracer_.enabled(); - if (!log_enabled && !traceback_enabled) { return; } - SPDLOG_TRY { - // format to wmemory_buffer and convert to utf8 - wmemory_buf_t wbuf; - fmt_lib::vformat_to(std::back_inserter(wbuf), fmt, - fmt_lib::make_format_args( - args...)); - - memory_buf_t buf; - details::os::wstr_to_utf8buf( - wstring_view_t(wbuf.data(), wbuf.size()), buf); - details::log_msg log_msg( - loc, name_, lvl, string_view_t(buf.data(), buf.size())); - log_it_(log_msg, log_enabled, traceback_enabled); - } - SPDLOG_LOGGER_CATCH(loc) - } -#endif // SPDLOG_WCHAR_TO_UTF8_SUPPORT - - // log the given message (if the given log level is high enough), - // and 
save backtrace (if backtrace is enabled). - void log_it_(const details::log_msg &log_msg, bool log_enabled, - bool traceback_enabled); - virtual void sink_it_(const details::log_msg &msg); - virtual void flush_(); - void dump_backtrace_(); - bool should_flush_(const details::log_msg &msg); - - // handle errors during logging. - // default handler prints the error to stderr at max rate of 1 message/sec. - void err_handler_(const std::string &msg); -}; - -void swap(logger &a, logger &b); - -} // namespace spdlog - -#ifdef SPDLOG_HEADER_ONLY -#include "logger-inl.h" -#endif diff --git a/src/common/spdlog/pattern_formatter-inl.h b/src/common/spdlog/pattern_formatter-inl.h deleted file mode 100755 index 5f8b3d4a02e..00000000000 --- a/src/common/spdlog/pattern_formatter-inl.h +++ /dev/null @@ -1,1424 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace spdlog { -namespace details { - -/////////////////////////////////////////////////////////////////////// -// name & level pattern appender -/////////////////////////////////////////////////////////////////////// - -class scoped_padder { -public: - scoped_padder(size_t wrapped_size, const padding_info &padinfo, - memory_buf_t &dest) - : padinfo_(padinfo), dest_(dest) { - remaining_pad_ = static_cast(padinfo.width_) - - static_cast(wrapped_size); - if (remaining_pad_ <= 0) { return; } - - if (padinfo_.side_ == padding_info::pad_side::left) { - pad_it(remaining_pad_); - remaining_pad_ = 0; - } else if (padinfo_.side_ == padding_info::pad_side::center) { - auto half_pad = remaining_pad_ / 2; - auto reminder = remaining_pad_ & 1; - pad_it(half_pad); - remaining_pad_ = half_pad + reminder; // for the right side - } - } - - template - static unsigned int count_digits(T n) { - return fmt_helper::count_digits(n); - } - - ~scoped_padder() { - if (remaining_pad_ >= 0) { - pad_it(remaining_pad_); - } else if (padinfo_.truncate_) { - long new_size = static_cast(dest_.size()) + remaining_pad_; - dest_.resize(static_cast(new_size)); - } - } - -private: - void pad_it(long count) { - fmt_helper::append_string_view( - string_view_t(spaces_.data(), static_cast(count)), - dest_); - } - - const padding_info &padinfo_; - memory_buf_t &dest_; - long remaining_pad_; - string_view_t spaces_ { - " ", - 64}; -}; - -struct null_scoped_padder { - null_scoped_padder(size_t /*wrapped_size*/, - const padding_info & /*padinfo*/, memory_buf_t & /*dest*/) {} - - template - static unsigned int count_digits(T /* number */) { - return 0; - } -}; - -template -class name_formatter final : public flag_formatter { -public: - explicit name_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - ScopedPadder p(msg.logger_name.size(), padinfo_, dest); - fmt_helper::append_string_view(msg.logger_name, dest); - } -}; - -// log level appender -template -class level_formatter final : public flag_formatter { -public: - explicit level_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - const string_view_t &level_name = 
level::to_string_view(msg.level); - ScopedPadder p(level_name.size(), padinfo_, dest); - fmt_helper::append_string_view(level_name, dest); - } -}; - -// short log level appender -template -class short_level_formatter final : public flag_formatter { -public: - explicit short_level_formatter(padding_info padinfo) - : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - string_view_t level_name {level::to_short_c_str(msg.level)}; - ScopedPadder p(level_name.size(), padinfo_, dest); - fmt_helper::append_string_view(level_name, dest); - } -}; - -/////////////////////////////////////////////////////////////////////// -// Date time pattern appenders -/////////////////////////////////////////////////////////////////////// - -static const char *ampm(const tm &t) { - return t.tm_hour >= 12 ? "PM" : "AM"; -} - -static int to12h(const tm &t) { - return t.tm_hour > 12 ? t.tm_hour - 12 : t.tm_hour; -} - -// Abbreviated weekday name -static std::array days { - {"Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"}}; - -template -class a_formatter final : public flag_formatter { -public: - explicit a_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - string_view_t field_value {days[static_cast(tm_time.tm_wday)]}; - ScopedPadder p(field_value.size(), padinfo_, dest); - fmt_helper::append_string_view(field_value, dest); - } -}; - -// Full weekday name -static std::array full_days {{"Sunday", "Monday", "Tuesday", - "Wednesday", "Thursday", "Friday", "Saturday"}}; - -template -class A_formatter : public flag_formatter { -public: - explicit A_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - string_view_t field_value { - full_days[static_cast(tm_time.tm_wday)]}; - ScopedPadder p(field_value.size(), padinfo_, dest); - fmt_helper::append_string_view(field_value, dest); - } -}; - -// Abbreviated month -static const std::array months {{"Jan", "Feb", "Mar", "Apr", - "May", "Jun", "Jul", "Aug", "Sept", "Oct", "Nov", "Dec"}}; - -template -class b_formatter final : public flag_formatter { -public: - explicit b_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - string_view_t field_value {months[static_cast(tm_time.tm_mon)]}; - ScopedPadder p(field_value.size(), padinfo_, dest); - fmt_helper::append_string_view(field_value, dest); - } -}; - -// Full month name -static const std::array full_months { - {"January", "February", "March", "April", "May", "June", "July", - "August", "September", "October", "November", "December"}}; - -template -class B_formatter final : public flag_formatter { -public: - explicit B_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - string_view_t field_value { - full_months[static_cast(tm_time.tm_mon)]}; - ScopedPadder p(field_value.size(), padinfo_, dest); - fmt_helper::append_string_view(field_value, dest); - } -}; - -// Date and time representation (Thu Aug 23 15:35:46 2014) -template -class c_formatter final : public flag_formatter { -public: - explicit c_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - 
memory_buf_t &dest) override { - const size_t field_size = 24; - ScopedPadder p(field_size, padinfo_, dest); - - fmt_helper::append_string_view( - days[static_cast(tm_time.tm_wday)], dest); - dest.push_back(' '); - fmt_helper::append_string_view( - months[static_cast(tm_time.tm_mon)], dest); - dest.push_back(' '); - fmt_helper::append_int(tm_time.tm_mday, dest); - dest.push_back(' '); - // time - - fmt_helper::pad2(tm_time.tm_hour, dest); - dest.push_back(':'); - fmt_helper::pad2(tm_time.tm_min, dest); - dest.push_back(':'); - fmt_helper::pad2(tm_time.tm_sec, dest); - dest.push_back(' '); - fmt_helper::append_int(tm_time.tm_year + 1900, dest); - } -}; - -// year - 2 digit -template -class C_formatter final : public flag_formatter { -public: - explicit C_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 2; - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::pad2(tm_time.tm_year % 100, dest); - } -}; - -// Short MM/DD/YY date, equivalent to %m/%d/%y 08/23/01 -template -class D_formatter final : public flag_formatter { -public: - explicit D_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 10; - ScopedPadder p(field_size, padinfo_, dest); - - fmt_helper::pad2(tm_time.tm_mon + 1, dest); - dest.push_back('/'); - fmt_helper::pad2(tm_time.tm_mday, dest); - dest.push_back('/'); - fmt_helper::pad2(tm_time.tm_year % 100, dest); - } -}; - -// year - 4 digit -template -class Y_formatter final : public flag_formatter { -public: - explicit Y_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 4; - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::append_int(tm_time.tm_year + 1900, dest); - } -}; - -// month 1-12 -template -class m_formatter final : public flag_formatter { -public: - explicit m_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 2; - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::pad2(tm_time.tm_mon + 1, dest); - } -}; - -// day of month 1-31 -template -class d_formatter final : public flag_formatter { -public: - explicit d_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 2; - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::pad2(tm_time.tm_mday, dest); - } -}; - -// hours in 24 format 0-23 -template -class H_formatter final : public flag_formatter { -public: - explicit H_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 2; - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::pad2(tm_time.tm_hour, dest); - } -}; - -// hours in 12 format 1-12 -template -class I_formatter final : public flag_formatter { -public: - explicit I_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 2; - ScopedPadder 
p(field_size, padinfo_, dest); - fmt_helper::pad2(to12h(tm_time), dest); - } -}; - -// minutes 0-59 -template -class M_formatter final : public flag_formatter { -public: - explicit M_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 2; - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::pad2(tm_time.tm_min, dest); - } -}; - -// seconds 0-59 -template -class S_formatter final : public flag_formatter { -public: - explicit S_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 2; - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::pad2(tm_time.tm_sec, dest); - } -}; - -// milliseconds -template -class e_formatter final : public flag_formatter { -public: - explicit e_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - auto millis = fmt_helper::time_fraction( - msg.time); - const size_t field_size = 3; - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::pad3(static_cast(millis.count()), dest); - } -}; - -// microseconds -template -class f_formatter final : public flag_formatter { -public: - explicit f_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - auto micros = fmt_helper::time_fraction( - msg.time); - - const size_t field_size = 6; - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::pad6(static_cast(micros.count()), dest); - } -}; - -// nanoseconds -template -class F_formatter final : public flag_formatter { -public: - explicit F_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - auto ns = fmt_helper::time_fraction(msg.time); - const size_t field_size = 9; - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::pad9(static_cast(ns.count()), dest); - } -}; - -// seconds since epoch -template -class E_formatter final : public flag_formatter { -public: - explicit E_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - const size_t field_size = 10; - ScopedPadder p(field_size, padinfo_, dest); - auto duration = msg.time.time_since_epoch(); - auto seconds - = std::chrono::duration_cast(duration) - .count(); - fmt_helper::append_int(seconds, dest); - } -}; - -// AM/PM -template -class p_formatter final : public flag_formatter { -public: - explicit p_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 2; - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::append_string_view(ampm(tm_time), dest); - } -}; - -// 12 hour clock 02:55:02 pm -template -class r_formatter final : public flag_formatter { -public: - explicit r_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 11; - ScopedPadder p(field_size, padinfo_, dest); - - fmt_helper::pad2(to12h(tm_time), dest); - dest.push_back(':'); - 
fmt_helper::pad2(tm_time.tm_min, dest); - dest.push_back(':'); - fmt_helper::pad2(tm_time.tm_sec, dest); - dest.push_back(' '); - fmt_helper::append_string_view(ampm(tm_time), dest); - } -}; - -// 24-hour HH:MM time, equivalent to %H:%M -template -class R_formatter final : public flag_formatter { -public: - explicit R_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 5; - ScopedPadder p(field_size, padinfo_, dest); - - fmt_helper::pad2(tm_time.tm_hour, dest); - dest.push_back(':'); - fmt_helper::pad2(tm_time.tm_min, dest); - } -}; - -// ISO 8601 time format (HH:MM:SS), equivalent to %H:%M:%S -template -class T_formatter final : public flag_formatter { -public: - explicit T_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 8; - ScopedPadder p(field_size, padinfo_, dest); - - fmt_helper::pad2(tm_time.tm_hour, dest); - dest.push_back(':'); - fmt_helper::pad2(tm_time.tm_min, dest); - dest.push_back(':'); - fmt_helper::pad2(tm_time.tm_sec, dest); - } -}; - -// ISO 8601 offset from UTC in timezone (+-HH:MM) -template -class z_formatter final : public flag_formatter { -public: - explicit z_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - z_formatter() = default; - z_formatter(const z_formatter &) = delete; - z_formatter &operator=(const z_formatter &) = delete; - - void format(const details::log_msg &msg, const std::tm &tm_time, - memory_buf_t &dest) override { - const size_t field_size = 6; - ScopedPadder p(field_size, padinfo_, dest); - - auto total_minutes = get_cached_offset(msg, tm_time); - bool is_negative = total_minutes < 0; - if (is_negative) { - total_minutes = -total_minutes; - dest.push_back('-'); - } else { - dest.push_back('+'); - } - - fmt_helper::pad2(total_minutes / 60, dest); // hours - dest.push_back(':'); - fmt_helper::pad2(total_minutes % 60, dest); // minutes - } - -private: - log_clock::time_point last_update_ {std::chrono::seconds(0)}; - int offset_minutes_ {0}; - - int get_cached_offset(const log_msg &msg, const std::tm &tm_time) { - // refresh every 10 seconds - if (msg.time - last_update_ >= std::chrono::seconds(10)) { - offset_minutes_ = os::utc_minutes_offset(tm_time); - last_update_ = msg.time; - } - return offset_minutes_; - } -}; - -// Thread id -template -class t_formatter final : public flag_formatter { -public: - explicit t_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - const auto field_size = ScopedPadder::count_digits(msg.thread_id); - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::append_int(msg.thread_id, dest); - } -}; - -// Current pid -template -class pid_formatter final : public flag_formatter { -public: - explicit pid_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &, - memory_buf_t &dest) override { - const auto pid = static_cast(details::os::pid()); - auto field_size = ScopedPadder::count_digits(pid); - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::append_int(pid, dest); - } -}; - -template -class v_formatter final : public flag_formatter { -public: - explicit v_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, 
const std::tm &, - memory_buf_t &dest) override { - ScopedPadder p(msg.payload.size(), padinfo_, dest); - fmt_helper::append_string_view(msg.payload, dest); - } -}; - -class ch_formatter final : public flag_formatter { -public: - explicit ch_formatter(char ch) : ch_(ch) {} - - void format(const details::log_msg &, const std::tm &, - memory_buf_t &dest) override { - dest.push_back(ch_); - } - -private: - char ch_; -}; - -// aggregate user chars to display as is -class aggregate_formatter final : public flag_formatter { -public: - aggregate_formatter() = default; - - void add_ch(char ch) { str_ += ch; } - void format(const details::log_msg &, const std::tm &, - memory_buf_t &dest) override { - fmt_helper::append_string_view(str_, dest); - } - -private: - std::string str_; -}; - -// mark the color range. expect it to be in the form of "%^colored text%$" -class color_start_formatter final : public flag_formatter { -public: - explicit color_start_formatter(padding_info padinfo) - : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - msg.color_range_start = dest.size(); - } -}; - -class color_stop_formatter final : public flag_formatter { -public: - explicit color_stop_formatter(padding_info padinfo) - : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - msg.color_range_end = dest.size(); - } -}; - -// print source location -template -class source_location_formatter final : public flag_formatter { -public: - explicit source_location_formatter(padding_info padinfo) - : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - if (msg.source.empty()) { - ScopedPadder p(0, padinfo_, dest); - return; - } - - size_t text_size; - if (padinfo_.enabled()) { - // calc text size for padding based on "filename:line" - text_size = std::char_traits::length(msg.source.filename) - + ScopedPadder::count_digits(msg.source.line) + 1; - } else { - text_size = 0; - } - - ScopedPadder p(text_size, padinfo_, dest); - fmt_helper::append_string_view(msg.source.filename, dest); - dest.push_back(':'); - fmt_helper::append_int(msg.source.line, dest); - } -}; - -// print source filename -template -class source_filename_formatter final : public flag_formatter { -public: - explicit source_filename_formatter(padding_info padinfo) - : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - if (msg.source.empty()) { - ScopedPadder p(0, padinfo_, dest); - return; - } - size_t text_size = padinfo_.enabled() - ? std::char_traits::length(msg.source.filename) - : 0; - ScopedPadder p(text_size, padinfo_, dest); - fmt_helper::append_string_view(msg.source.filename, dest); - } -}; - -template -class short_filename_formatter final : public flag_formatter { -public: - explicit short_filename_formatter(padding_info padinfo) - : flag_formatter(padinfo) {} - -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 4127) // consider using 'if constexpr' instead -#endif // _MSC_VER - static const char *basename(const char *filename) { - // if the size is 2 (1 character + null terminator) we can use the more efficient strrchr - // the branch will be elided by optimizations - if (sizeof(os::folder_seps) == 2) { - const char *rv = std::strrchr(filename, os::folder_seps[0]); - return rv != nullptr ? 
rv + 1 : filename; - } else { - const std::reverse_iterator begin( - filename + std::strlen(filename)); - const std::reverse_iterator end(filename); - - const auto it = std::find_first_of(begin, end, - std::begin(os::folder_seps), std::end(os::folder_seps) - 1); - return it != end ? it.base() : filename; - } - } -#ifdef _MSC_VER -#pragma warning(pop) -#endif // _MSC_VER - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - if (msg.source.empty()) { - ScopedPadder p(0, padinfo_, dest); - return; - } - auto filename = basename(msg.source.filename); - size_t text_size = padinfo_.enabled() - ? std::char_traits::length(filename) - : 0; - ScopedPadder p(text_size, padinfo_, dest); - fmt_helper::append_string_view(filename, dest); - } -}; - -template -class source_linenum_formatter final : public flag_formatter { -public: - explicit source_linenum_formatter(padding_info padinfo) - : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - if (msg.source.empty()) { - ScopedPadder p(0, padinfo_, dest); - return; - } - - auto field_size = ScopedPadder::count_digits(msg.source.line); - ScopedPadder p(field_size, padinfo_, dest); - fmt_helper::append_int(msg.source.line, dest); - } -}; - -// print source funcname -template -class source_funcname_formatter final : public flag_formatter { -public: - explicit source_funcname_formatter(padding_info padinfo) - : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - if (msg.source.empty()) { - ScopedPadder p(0, padinfo_, dest); - return; - } - size_t text_size = padinfo_.enabled() - ? std::char_traits::length(msg.source.funcname) - : 0; - ScopedPadder p(text_size, padinfo_, dest); - fmt_helper::append_string_view(msg.source.funcname, dest); - } -}; - -// print elapsed time since last message -template -class elapsed_formatter final : public flag_formatter { -public: - using DurationUnits = Units; - - explicit elapsed_formatter(padding_info padinfo) - : flag_formatter(padinfo), last_message_time_(log_clock::now()) {} - - void format(const details::log_msg &msg, const std::tm &, - memory_buf_t &dest) override { - auto delta = (std::max)( - msg.time - last_message_time_, log_clock::duration::zero()); - auto delta_units = std::chrono::duration_cast(delta); - last_message_time_ = msg.time; - auto delta_count = static_cast(delta_units.count()); - auto n_digits - = static_cast(ScopedPadder::count_digits(delta_count)); - ScopedPadder p(n_digits, padinfo_, dest); - fmt_helper::append_int(delta_count, dest); - } - -private: - log_clock::time_point last_message_time_; -}; - -// Class for formatting Mapped Diagnostic Context (MDC) in log messages. 
-// Example: [logger-name] [info] [mdc_key_1:mdc_value_1 mdc_key_2:mdc_value_2] some message -template -class mdc_formatter : public flag_formatter { -public: - explicit mdc_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &, const std::tm &, - memory_buf_t &dest) override { - auto &mdc_map = mdc::get_context(); - if (mdc_map.empty()) { - ScopedPadder p(0, padinfo_, dest); - return; - } else { - format_mdc(mdc_map, dest); - } - } - - void format_mdc(const mdc::mdc_map_t &mdc_map, memory_buf_t &dest) { - auto last_element = --mdc_map.end(); - for (auto it = mdc_map.begin(); it != mdc_map.end(); ++it) { - auto &pair = *it; - const auto &key = pair.first; - const auto &value = pair.second; - size_t content_size = key.size() + value.size() + 1; // 1 for ':' - - if (it != last_element) { - content_size++; // 1 for ' ' - } - - ScopedPadder p(content_size, padinfo_, dest); - fmt_helper::append_string_view(key, dest); - fmt_helper::append_string_view(":", dest); - fmt_helper::append_string_view(value, dest); - if (it != last_element) { - fmt_helper::append_string_view(" ", dest); - } - } - } -}; - -// Full info formatter -// pattern: [%Y-%m-%d %H:%M:%S.%e] [%n] [%l] [%s:%#] %v -class full_formatter final : public flag_formatter { -public: - explicit full_formatter(padding_info padinfo) : flag_formatter(padinfo) {} - - void format(const details::log_msg &msg, const std::tm &tm_time, - memory_buf_t &dest) override { - using std::chrono::duration_cast; - using std::chrono::milliseconds; - using std::chrono::seconds; - - // cache the date/time part for the next second. - auto duration = msg.time.time_since_epoch(); - auto secs = duration_cast(duration); - - if (cache_timestamp_ != secs || cached_datetime_.size() == 0) { - cached_datetime_.clear(); - cached_datetime_.push_back('['); - fmt_helper::append_int(tm_time.tm_year + 1900, cached_datetime_); - cached_datetime_.push_back('-'); - - fmt_helper::pad2(tm_time.tm_mon + 1, cached_datetime_); - cached_datetime_.push_back('-'); - - fmt_helper::pad2(tm_time.tm_mday, cached_datetime_); - cached_datetime_.push_back(' '); - - fmt_helper::pad2(tm_time.tm_hour, cached_datetime_); - cached_datetime_.push_back(':'); - - fmt_helper::pad2(tm_time.tm_min, cached_datetime_); - cached_datetime_.push_back(':'); - - fmt_helper::pad2(tm_time.tm_sec, cached_datetime_); - cached_datetime_.push_back('.'); - - cache_timestamp_ = secs; - } - dest.append(cached_datetime_.begin(), cached_datetime_.end()); - - auto millis = fmt_helper::time_fraction(msg.time); - fmt_helper::pad3(static_cast(millis.count()), dest); - dest.push_back(']'); - dest.push_back(' '); - - // append logger name if exists - if (msg.logger_name.size() > 0) { - dest.push_back('['); - fmt_helper::append_string_view(msg.logger_name, dest); - dest.push_back(']'); - dest.push_back(' '); - } - - dest.push_back('['); - // wrap the level name with color - msg.color_range_start = dest.size(); - // fmt_helper::append_string_view(level::to_c_str(msg.level), dest); - fmt_helper::append_string_view(level::to_string_view(msg.level), dest); - msg.color_range_end = dest.size(); - dest.push_back(']'); - dest.push_back(' '); - - // add source location if present - if (!msg.source.empty()) { - dest.push_back('['); - const char *filename = details::short_filename_formatter< - details::null_scoped_padder>::basename(msg.source.filename); - fmt_helper::append_string_view(filename, dest); - dest.push_back(':'); - fmt_helper::append_int(msg.source.line, dest); - 
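The mdc_formatter above renders spdlog's Mapped Diagnostic Context wherever the %& flag appears in a pattern. For context, a minimal usage sketch, assuming a stock spdlog recent enough to ship the MDC API (v1.13+); the key and messages are illustrative:

    #include <spdlog/mdc.h>
    #include <spdlog/spdlog.h>

    int main() {
        // Per-thread key:value pairs; mdc_formatter prints them as
        // "key:value" entries separated by spaces wherever %& appears.
        spdlog::mdc::put("request_id", "42");
        spdlog::set_pattern("[%l] [%&] %v");
        spdlog::info("handling request"); // [info] [request_id:42] handling request
        spdlog::mdc::remove("request_id");
        return 0;
    }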
dest.push_back(']'); - dest.push_back(' '); - } - - // add mdc if present - auto &mdc_map = mdc::get_context(); - if (!mdc_map.empty()) { - dest.push_back('['); - mdc_formatter_.format_mdc(mdc_map, dest); - dest.push_back(']'); - dest.push_back(' '); - } - // fmt_helper::append_string_view(msg.msg(), dest); - fmt_helper::append_string_view(msg.payload, dest); - } - -private: - std::chrono::seconds cache_timestamp_ {0}; - memory_buf_t cached_datetime_; - mdc_formatter mdc_formatter_ {padding_info {}}; -}; - -} // namespace details - -SPDLOG_INLINE pattern_formatter::pattern_formatter(std::string pattern, - pattern_time_type time_type, std::string eol, - custom_flags custom_user_flags) - : pattern_(std::move(pattern)) - , eol_(std::move(eol)) - , pattern_time_type_(time_type) - , need_localtime_(false) - , last_log_secs_(0) - , custom_handlers_(std::move(custom_user_flags)) { - std::memset(&cached_tm_, 0, sizeof(cached_tm_)); - compile_pattern_(pattern_); -} - -// use by default full formatter for if pattern is not given -SPDLOG_INLINE pattern_formatter::pattern_formatter( - pattern_time_type time_type, std::string eol) - : pattern_("%+") - , eol_(std::move(eol)) - , pattern_time_type_(time_type) - , need_localtime_(true) - , last_log_secs_(0) { - std::memset(&cached_tm_, 0, sizeof(cached_tm_)); - formatters_.push_back(details::make_unique( - details::padding_info {})); -} - -SPDLOG_INLINE std::unique_ptr pattern_formatter::clone() const { - custom_flags cloned_custom_formatters; - for (auto &it : custom_handlers_) { - cloned_custom_formatters[it.first] = it.second->clone(); - } - auto cloned = details::make_unique(pattern_, - pattern_time_type_, eol_, std::move(cloned_custom_formatters)); - cloned->need_localtime(need_localtime_); -#if defined(__GNUC__) && __GNUC__ < 5 - return std::move(cloned); -#else - return cloned; -#endif -} - -SPDLOG_INLINE void pattern_formatter::format( - const details::log_msg &msg, memory_buf_t &dest) { - if (need_localtime_) { - const auto secs = std::chrono::duration_cast( - msg.time.time_since_epoch()); - if (secs != last_log_secs_) { - cached_tm_ = get_time_(msg); - last_log_secs_ = secs; - } - } - - for (auto &f : formatters_) { - f->format(msg, cached_tm_, dest); - } - // write eol - details::fmt_helper::append_string_view(eol_, dest); -} - -SPDLOG_INLINE void pattern_formatter::set_pattern(std::string pattern) { - pattern_ = std::move(pattern); - need_localtime_ = false; - compile_pattern_(pattern_); -} - -SPDLOG_INLINE void pattern_formatter::need_localtime(bool need) { - need_localtime_ = need; -} - -SPDLOG_INLINE std::tm pattern_formatter::get_time_( - const details::log_msg &msg) { - if (pattern_time_type_ == pattern_time_type::local) { - return details::os::localtime(log_clock::to_time_t(msg.time)); - } - return details::os::gmtime(log_clock::to_time_t(msg.time)); -} - -template -SPDLOG_INLINE void pattern_formatter::handle_flag_( - char flag, details::padding_info padding) { - // process custom flags - auto it = custom_handlers_.find(flag); - if (it != custom_handlers_.end()) { - auto custom_handler = it->second->clone(); - custom_handler->set_padding_info(padding); - formatters_.push_back(std::move(custom_handler)); - return; - } - - // process built-in flags - switch (flag) { - case ('+'): // default formatter - formatters_.push_back( - details::make_unique(padding)); - need_localtime_ = true; - break; - - case 'n': // logger name - formatters_.push_back( - details::make_unique>( - padding)); - break; - - case 'l': // level - formatters_.push_back( 
- details::make_unique>( - padding)); - break; - - case 'L': // short level - formatters_.push_back(details::make_unique< - details::short_level_formatter>(padding)); - break; - - case ('t'): // thread id - formatters_.push_back( - details::make_unique>( - padding)); - break; - - case ('v'): // the message text - formatters_.push_back( - details::make_unique>( - padding)); - break; - - case ('a'): // weekday - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('A'): // short weekday - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('b'): - case ('h'): // month - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('B'): // short month - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('c'): // datetime - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('C'): // year 2 digits - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('Y'): // year 4 digits - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('D'): - case ('x'): // datetime MM/DD/YY - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('m'): // month 1-12 - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('d'): // day of month 1-31 - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('H'): // hours 24 - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('I'): // hours 12 - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('M'): // minutes - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('S'): // seconds - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('e'): // milliseconds - formatters_.push_back( - details::make_unique>( - padding)); - break; - - case ('f'): // microseconds - formatters_.push_back( - details::make_unique>( - padding)); - break; - - case ('F'): // nanoseconds - formatters_.push_back( - details::make_unique>( - padding)); - break; - - case ('E'): // seconds since epoch - formatters_.push_back( - details::make_unique>( - padding)); - break; - - case ('p'): // am/pm - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('r'): // 12 hour clock 02:55:02 pm - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('R'): // 24-hour HH:MM time - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('T'): - case ('X'): // ISO 8601 time format (HH:MM:SS) - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('z'): // timezone - formatters_.push_back( - details::make_unique>( - padding)); - need_localtime_ = true; - break; - - case ('P'): // pid - formatters_.push_back( - details::make_unique>( - padding)); - break; - - case ('^'): // color range start - formatters_.push_back( - details::make_unique( - padding)); - break; - - case ('$'): // color 
range end - formatters_.push_back( - details::make_unique( - padding)); - break; - - case ('@'): // source location (filename:filenumber) - formatters_.push_back(details::make_unique< - details::source_location_formatter>(padding)); - break; - - case ('s'): // short source filename - without directory name - formatters_.push_back(details::make_unique< - details::short_filename_formatter>(padding)); - break; - - case ('g'): // full source filename - formatters_.push_back(details::make_unique< - details::source_filename_formatter>(padding)); - break; - - case ('#'): // source line number - formatters_.push_back(details::make_unique< - details::source_linenum_formatter>(padding)); - break; - - case ('!'): // source funcname - formatters_.push_back(details::make_unique< - details::source_funcname_formatter>(padding)); - break; - - case ('%'): // % char - formatters_.push_back( - details::make_unique('%')); - break; - - case ('u'): // elapsed time since last log message in nanos - formatters_.push_back( - details::make_unique>(padding)); - break; - - case ('i'): // elapsed time since last log message in micros - formatters_.push_back( - details::make_unique>(padding)); - break; - - case ('o'): // elapsed time since last log message in millis - formatters_.push_back( - details::make_unique>(padding)); - break; - - case ('O'): // elapsed time since last log message in seconds - formatters_.push_back(details::make_unique< - details::elapsed_formatter>( - padding)); - break; - - case ('&'): - formatters_.push_back( - details::make_unique>( - padding)); - break; - - default: // Unknown flag appears as is - auto unknown_flag - = details::make_unique(); - - if (!padding.truncate_) { - unknown_flag->add_ch('%'); - unknown_flag->add_ch(flag); - formatters_.push_back((std::move(unknown_flag))); - } - // fix issue #1617 (prev char was '!' and should have been treated as funcname flag - // instead of truncating flag) spdlog::set_pattern("[%10!] %v") => "[ main] some - // message" spdlog::set_pattern("[%3!!] %v") => "[mai] some message" - else { - padding.truncate_ = false; - formatters_.push_back(details::make_unique< - details::source_funcname_formatter>(padding)); - unknown_flag->add_ch(flag); - formatters_.push_back((std::move(unknown_flag))); - } - - break; - } -} - -// Extract given pad spec (e.g. %8X, %=8X, %-8!X, %8!X, %=8!X, %-8!X, %+8!X) -// Advance the given it pass the end of the padding spec found (if any) -// Return padding. -SPDLOG_INLINE details::padding_info pattern_formatter::handle_padspec_( - std::string::const_iterator &it, std::string::const_iterator end) { - using details::padding_info; - using details::scoped_padder; - const size_t max_width = 64; - if (it == end) { return padding_info {}; } - - padding_info::pad_side side; - switch (*it) { - case '-': - side = padding_info::pad_side::right; - ++it; - break; - case '=': - side = padding_info::pad_side::center; - ++it; - break; - default: side = details::padding_info::pad_side::left; break; - } - - if (it == end || !std::isdigit(static_cast(*it))) { - return padding_info {}; // no padding if no digit found here - } - - auto width = static_cast(*it) - '0'; - for (++it; it != end && std::isdigit(static_cast(*it)); - ++it) { - auto digit = static_cast(*it) - '0'; - width = width * 10 + digit; - } - - // search for the optional truncate marker '!' 
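For reference, the pad spec being parsed here is: an optional alignment ('-' pads on the right, '=' centers, left padding by default), a width capped at 64, then an optional '!' truncate marker. A short sketch of the resulting behavior against the plain spdlog API; the outputs in the comments are illustrative:

    #include <spdlog/spdlog.h>

    int main() {
        spdlog::set_pattern("[%8l] %v");  // pad level to 8 on the left: "[    info]"
        spdlog::info("padded");
        spdlog::set_pattern("[%-8l] %v"); // pad on the right: "[info    ]"
        spdlog::info("padded");
        spdlog::set_pattern("[%=8l] %v"); // centered: "[  info  ]"
        spdlog::info("padded");
        spdlog::set_pattern("[%3!l] %v"); // pad and truncate to exactly 3: "[inf]"
        spdlog::info("truncated");
        return 0;
    }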
- bool truncate; - if (it != end && *it == '!') { - truncate = true; - ++it; - } else { - truncate = false; - } - return details::padding_info { - std::min(width, max_width), side, truncate}; -} - -SPDLOG_INLINE void pattern_formatter::compile_pattern_( - const std::string &pattern) { - auto end = pattern.end(); - std::unique_ptr user_chars; - formatters_.clear(); - for (auto it = pattern.begin(); it != end; ++it) { - if (*it == '%') { - if (user_chars) // append user chars found so far - { - formatters_.push_back(std::move(user_chars)); - } - - auto padding = handle_padspec_(++it, end); - - if (it != end) { - if (padding.enabled()) { - handle_flag_(*it, padding); - } else { - handle_flag_(*it, padding); - } - } else { - break; - } - } else // chars not following the % sign should be displayed as is - { - if (!user_chars) { - user_chars - = details::make_unique(); - } - user_chars->add_ch(*it); - } - } - if (user_chars) // append raw chars found so far - { - formatters_.push_back(std::move(user_chars)); - } -} -} // namespace spdlog diff --git a/src/common/spdlog/pattern_formatter.h b/src/common/spdlog/pattern_formatter.h deleted file mode 100755 index 3f19b6e8973..00000000000 --- a/src/common/spdlog/pattern_formatter.h +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include - -namespace spdlog { -namespace details { - -// padding information. -struct padding_info { - enum class pad_side { left, right, center }; - - padding_info() = default; - padding_info(size_t width, padding_info::pad_side side, bool truncate) - : width_(width), side_(side), truncate_(truncate), enabled_(true) {} - - bool enabled() const { return enabled_; } - size_t width_ = 0; - pad_side side_ = pad_side::left; - bool truncate_ = false; - bool enabled_ = false; -}; - -class SPDLOG_API flag_formatter { -public: - explicit flag_formatter(padding_info padinfo) : padinfo_(padinfo) {} - flag_formatter() = default; - virtual ~flag_formatter() = default; - virtual void format(const details::log_msg &msg, const std::tm &tm_time, - memory_buf_t &dest) - = 0; - -protected: - padding_info padinfo_; -}; - -} // namespace details - -class SPDLOG_API custom_flag_formatter : public details::flag_formatter { -public: - virtual std::unique_ptr clone() const = 0; - - void set_padding_info(const details::padding_info &padding) { - flag_formatter::padinfo_ = padding; - } -}; - -class SPDLOG_API pattern_formatter final : public formatter { -public: - using custom_flags - = std::unordered_map>; - - explicit pattern_formatter(std::string pattern, - pattern_time_type time_type = pattern_time_type::local, - std::string eol = spdlog::details::os::default_eol, - custom_flags custom_user_flags = custom_flags()); - - // use default pattern is not given - explicit pattern_formatter( - pattern_time_type time_type = pattern_time_type::local, - std::string eol = spdlog::details::os::default_eol); - - pattern_formatter(const pattern_formatter &other) = delete; - pattern_formatter &operator=(const pattern_formatter &other) = delete; - - std::unique_ptr clone() const override; - void format(const details::log_msg &msg, memory_buf_t &dest) override; - - template - pattern_formatter &add_flag(char flag, Args &&...args) { - custom_handlers_[flag] - = details::make_unique(std::forward(args)...); - return *this; - } - void 
set_pattern(std::string pattern); - void need_localtime(bool need = true); - -private: - std::string pattern_; - std::string eol_; - pattern_time_type pattern_time_type_; - bool need_localtime_; - std::tm cached_tm_; - std::chrono::seconds last_log_secs_; - std::vector> formatters_; - custom_flags custom_handlers_; - - std::tm get_time_(const details::log_msg &msg); - template - void handle_flag_(char flag, details::padding_info padding); - - // Extract given pad spec (e.g. %8X) - // Advance the given it pass the end of the padding spec found (if any) - // Return padding. - static details::padding_info handle_padspec_( - std::string::const_iterator &it, std::string::const_iterator end); - - void compile_pattern_(const std::string &pattern); -}; -} // namespace spdlog - -#ifdef SPDLOG_HEADER_ONLY -#include "pattern_formatter-inl.h" -#endif diff --git a/src/common/spdlog/sinks/ansicolor_sink-inl.h b/src/common/spdlog/sinks/ansicolor_sink-inl.h deleted file mode 100755 index bfd5aa00784..00000000000 --- a/src/common/spdlog/sinks/ansicolor_sink-inl.h +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -#include -#include - -namespace spdlog { -namespace sinks { - -template -SPDLOG_INLINE ansicolor_sink::ansicolor_sink( - FILE *target_file, color_mode mode) - : target_file_(target_file) - , mutex_(ConsoleMutex::mutex()) - , formatter_(details::make_unique()) - -{ - set_color_mode(mode); - colors_.at(level::trace) = to_string_(white); - colors_.at(level::debug) = to_string_(cyan); - colors_.at(level::info) = to_string_(green); - colors_.at(level::warn) = to_string_(yellow_bold); - colors_.at(level::err) = to_string_(red_bold); - colors_.at(level::critical) = to_string_(bold_on_red); - colors_.at(level::off) = to_string_(reset); -} - -template -SPDLOG_INLINE void ansicolor_sink::set_color( - level::level_enum color_level, string_view_t color) { - std::lock_guard lock(mutex_); - colors_.at(static_cast(color_level)) = to_string_(color); -} - -template -SPDLOG_INLINE void ansicolor_sink::log( - const details::log_msg &msg) { - // Wrap the originally formatted message in color codes. - // If color is not supported in the terminal, log as is instead. 
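custom_flag_formatter and pattern_formatter::add_flag(), declared in the header above, are the public extension point for user-defined pattern flags. A sketch of how a flag plugs in, closely following spdlog's documented usage; the class name and the '*' flag are illustrative:

    #include <memory>
    #include <string>

    #include <spdlog/pattern_formatter.h>
    #include <spdlog/spdlog.h>

    class my_flag final : public spdlog::custom_flag_formatter {
    public:
        void format(const spdlog::details::log_msg &, const std::tm &,
                spdlog::memory_buf_t &dest) override {
            std::string txt = "custom-flag";
            dest.append(txt.data(), txt.data() + txt.size());
        }
        std::unique_ptr<custom_flag_formatter> clone() const override {
            return spdlog::details::make_unique<my_flag>();
        }
    };

    int main() {
        auto formatter = std::make_unique<spdlog::pattern_formatter>();
        formatter->add_flag<my_flag>('*').set_pattern("[%n] [%*] [%l] %v");
        spdlog::set_formatter(std::move(formatter));
        spdlog::info("hello"); // '*' now expands to "custom-flag"
        return 0;
    }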
- std::lock_guard lock(mutex_); - msg.color_range_start = 0; - msg.color_range_end = 0; - memory_buf_t formatted; - formatter_->format(msg, formatted); - if (should_do_colors_ && msg.color_range_end > msg.color_range_start) { - // before color range - print_range_(formatted, 0, msg.color_range_start); - // in color range - print_ccode_(colors_.at(static_cast(msg.level))); - print_range_(formatted, msg.color_range_start, msg.color_range_end); - print_ccode_(reset); - // after color range - print_range_(formatted, msg.color_range_end, formatted.size()); - } else // no color - { - print_range_(formatted, 0, formatted.size()); - } - fflush(target_file_); -} - -template -SPDLOG_INLINE void ansicolor_sink::flush() { - std::lock_guard lock(mutex_); - fflush(target_file_); -} - -template -SPDLOG_INLINE void ansicolor_sink::set_pattern( - const std::string &pattern) { - std::lock_guard lock(mutex_); - formatter_ = std::unique_ptr( - new pattern_formatter(pattern)); -} - -template -SPDLOG_INLINE void ansicolor_sink::set_formatter( - std::unique_ptr sink_formatter) { - std::lock_guard lock(mutex_); - formatter_ = std::move(sink_formatter); -} - -template -SPDLOG_INLINE bool ansicolor_sink::should_color() { - return should_do_colors_; -} - -template -SPDLOG_INLINE void ansicolor_sink::set_color_mode( - color_mode mode) { - switch (mode) { - case color_mode::always: should_do_colors_ = true; return; - case color_mode::automatic: - should_do_colors_ = details::os::in_terminal(target_file_) - && details::os::is_color_terminal(); - return; - case color_mode::never: should_do_colors_ = false; return; - default: should_do_colors_ = false; - } -} - -template -SPDLOG_INLINE void ansicolor_sink::print_ccode_( - const string_view_t &color_code) { - fwrite(color_code.data(), sizeof(char), color_code.size(), target_file_); -} - -template -SPDLOG_INLINE void ansicolor_sink::print_range_( - const memory_buf_t &formatted, size_t start, size_t end) { - fwrite(formatted.data() + start, sizeof(char), end - start, target_file_); -} - -template -SPDLOG_INLINE std::string ansicolor_sink::to_string_( - const string_view_t &sv) { - return std::string(sv.data(), sv.size()); -} - -// ansicolor_stdout_sink -template -SPDLOG_INLINE ansicolor_stdout_sink::ansicolor_stdout_sink( - color_mode mode) - : ansicolor_sink(stdout, mode) {} - -// ansicolor_stderr_sink -template -SPDLOG_INLINE ansicolor_stderr_sink::ansicolor_stderr_sink( - color_mode mode) - : ansicolor_sink(stderr, mode) {} - -} // namespace sinks -} // namespace spdlog diff --git a/src/common/spdlog/sinks/base_sink-inl.h b/src/common/spdlog/sinks/base_sink-inl.h deleted file mode 100755 index 8c94fa5d587..00000000000 --- a/src/common/spdlog/sinks/base_sink-inl.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. 
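The ansicolor sink wraps only the span that color_start_formatter and color_stop_formatter mark (%^ and %$ in a pattern) in the per-level ANSI code. A usage sketch with the stock color console factory; the logger name and pattern are illustrative:

    #include <spdlog/sinks/stdout_color_sinks.h>
    #include <spdlog/spdlog.h>

    int main() {
        auto console = spdlog::stdout_color_mt("console");
        console->set_pattern("%^[%l]%$ %v"); // only "[level]" is colorized
        console->info("level tag in green");
        console->error("level tag in bold red");
        return 0;
    }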
-// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -#include -#include - -#include -#include - -template -SPDLOG_INLINE spdlog::sinks::base_sink::base_sink() - : formatter_ {details::make_unique()} {} - -template -SPDLOG_INLINE spdlog::sinks::base_sink::base_sink( - std::unique_ptr formatter) - : formatter_ {std::move(formatter)} {} - -template -void SPDLOG_INLINE spdlog::sinks::base_sink::log( - const details::log_msg &msg) { - std::lock_guard lock(mutex_); - sink_it_(msg); -} - -template -void SPDLOG_INLINE spdlog::sinks::base_sink::flush() { - std::lock_guard lock(mutex_); - flush_(); -} - -template -void SPDLOG_INLINE spdlog::sinks::base_sink::set_pattern( - const std::string &pattern) { - std::lock_guard lock(mutex_); - set_pattern_(pattern); -} - -template -void SPDLOG_INLINE spdlog::sinks::base_sink::set_formatter( - std::unique_ptr sink_formatter) { - std::lock_guard lock(mutex_); - set_formatter_(std::move(sink_formatter)); -} - -template -void SPDLOG_INLINE spdlog::sinks::base_sink::set_pattern_( - const std::string &pattern) { - set_formatter_(details::make_unique(pattern)); -} - -template -void SPDLOG_INLINE spdlog::sinks::base_sink::set_formatter_( - std::unique_ptr sink_formatter) { - formatter_ = std::move(sink_formatter); -} diff --git a/src/common/spdlog/sinks/base_sink.h b/src/common/spdlog/sinks/base_sink.h deleted file mode 100755 index 2a37d9d1555..00000000000 --- a/src/common/spdlog/sinks/base_sink.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once -// -// base sink templated over a mutex (either dummy or real) -// concrete implementation should override the sink_it_() and flush_() methods. -// locking is taken care of in this class - no locking needed by the -// implementers.. -// - -#include -#include -#include - -namespace spdlog { -namespace sinks { -template -class SPDLOG_API base_sink : public sink { -public: - base_sink(); - explicit base_sink(std::unique_ptr formatter); - ~base_sink() override = default; - - base_sink(const base_sink &) = delete; - base_sink(base_sink &&) = delete; - - base_sink &operator=(const base_sink &) = delete; - base_sink &operator=(base_sink &&) = delete; - - void log(const details::log_msg &msg) final; - void flush() final; - void set_pattern(const std::string &pattern) final; - void set_formatter(std::unique_ptr sink_formatter) final; - -protected: - // sink formatter - std::unique_ptr formatter_; - Mutex mutex_; - - virtual void sink_it_(const details::log_msg &msg) = 0; - virtual void flush_() = 0; - virtual void set_pattern_(const std::string &pattern); - virtual void set_formatter_( - std::unique_ptr sink_formatter); -}; -} // namespace sinks -} // namespace spdlog - -#ifdef SPDLOG_HEADER_ONLY -#include "base_sink-inl.h" -#endif diff --git a/src/common/spdlog/sinks/basic_file_sink-inl.h b/src/common/spdlog/sinks/basic_file_sink-inl.h deleted file mode 100755 index ad2ee24eb42..00000000000 --- a/src/common/spdlog/sinks/basic_file_sink-inl.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. 
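base_sink centralizes locking and formatter storage, so a concrete sink only implements sink_it_() and flush_(). A minimal custom sink sketch under that contract; the class name is illustrative:

    #include <iostream>
    #include <mutex>

    #include <spdlog/details/null_mutex.h>
    #include <spdlog/sinks/base_sink.h>
    #include <spdlog/spdlog.h>

    template <typename Mutex>
    class echo_sink final : public spdlog::sinks::base_sink<Mutex> {
    protected:
        void sink_it_(const spdlog::details::log_msg &msg) override {
            // base_sink::log() already holds mutex_ when we get here.
            spdlog::memory_buf_t formatted;
            spdlog::sinks::base_sink<Mutex>::formatter_->format(msg, formatted);
            std::cout.write(formatted.data(),
                    static_cast<std::streamsize>(formatted.size()));
        }
        void flush_() override { std::cout << std::flush; }
    };

    using echo_sink_mt = echo_sink<std::mutex>;                  // thread safe
    using echo_sink_st = echo_sink<spdlog::details::null_mutex>; // single thread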
-// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -#include -#include - -namespace spdlog { -namespace sinks { - -template -SPDLOG_INLINE basic_file_sink::basic_file_sink( - const filename_t &filename, bool truncate, - const file_event_handlers &event_handlers) - : file_helper_ {event_handlers} { - file_helper_.open(filename, truncate); -} - -template -SPDLOG_INLINE const filename_t &basic_file_sink::filename() const { - return file_helper_.filename(); -} - -template -SPDLOG_INLINE void basic_file_sink::sink_it_( - const details::log_msg &msg) { - memory_buf_t formatted; - base_sink::formatter_->format(msg, formatted); - file_helper_.write(formatted); -} - -template -SPDLOG_INLINE void basic_file_sink::flush_() { - file_helper_.flush(); -} - -} // namespace sinks -} // namespace spdlog diff --git a/src/common/spdlog/sinks/basic_file_sink.h b/src/common/spdlog/sinks/basic_file_sink.h deleted file mode 100755 index 3d7742356e3..00000000000 --- a/src/common/spdlog/sinks/basic_file_sink.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#include -#include -#include -#include - -#include -#include - -namespace spdlog { -namespace sinks { -/* - * Trivial file sink with single file as target - */ -template -class basic_file_sink final : public base_sink { -public: - explicit basic_file_sink(const filename_t &filename, bool truncate = false, - const file_event_handlers &event_handlers = {}); - const filename_t &filename() const; - -protected: - void sink_it_(const details::log_msg &msg) override; - void flush_() override; - -private: - details::file_helper file_helper_; -}; - -using basic_file_sink_mt = basic_file_sink; -using basic_file_sink_st = basic_file_sink; - -} // namespace sinks - -// -// factory functions -// -template -inline std::shared_ptr basic_logger_mt(const std::string &logger_name, - const filename_t &filename, bool truncate = false, - const file_event_handlers &event_handlers = {}) { - return Factory::template create( - logger_name, filename, truncate, event_handlers); -} - -template -inline std::shared_ptr basic_logger_st(const std::string &logger_name, - const filename_t &filename, bool truncate = false, - const file_event_handlers &event_handlers = {}) { - return Factory::template create( - logger_name, filename, truncate, event_handlers); -} - -} // namespace spdlog - -#ifdef SPDLOG_HEADER_ONLY -#include "basic_file_sink-inl.h" -#endif diff --git a/src/common/spdlog/sinks/null_sink.h b/src/common/spdlog/sinks/null_sink.h deleted file mode 100755 index 628bdee33ac..00000000000 --- a/src/common/spdlog/sinks/null_sink.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. 
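For context, the factory helpers above are the usual way this sink gets instantiated; the logger name and file name are illustrative:

    #include <spdlog/sinks/basic_file_sink.h>
    #include <spdlog/spdlog.h>

    int main() {
        // One basic_file_sink_mt, registered globally under "file".
        auto logger = spdlog::basic_logger_mt("file", "app.log", /*truncate=*/false);
        logger->info("written to app.log");
        spdlog::get("file")->flush(); // same logger, fetched from the registry
        return 0;
    }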
-// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#include -#include -#include - -#include - -namespace spdlog { -namespace sinks { - -template -class null_sink : public base_sink { -protected: - void sink_it_(const details::log_msg &) override {} - void flush_() override {} -}; - -using null_sink_mt = null_sink; -using null_sink_st = null_sink; - -} // namespace sinks - -template -inline std::shared_ptr null_logger_mt(const std::string &logger_name) { - auto null_logger - = Factory::template create(logger_name); - null_logger->set_level(level::off); - return null_logger; -} - -template -inline std::shared_ptr null_logger_st(const std::string &logger_name) { - auto null_logger - = Factory::template create(logger_name); - null_logger->set_level(level::off); - return null_logger; -} - -} // namespace spdlog diff --git a/src/common/spdlog/sinks/ostream_sink.h b/src/common/spdlog/sinks/ostream_sink.h deleted file mode 100755 index 383ae4cb3aa..00000000000 --- a/src/common/spdlog/sinks/ostream_sink.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#include -#include - -#include -#include - -namespace spdlog { -namespace sinks { -template -class ostream_sink final : public base_sink { -public: - explicit ostream_sink(std::ostream &os, bool force_flush = false) - : ostream_(os), force_flush_(force_flush) {} - ostream_sink(const ostream_sink &) = delete; - ostream_sink &operator=(const ostream_sink &) = delete; - -protected: - void sink_it_(const details::log_msg &msg) override { - memory_buf_t formatted; - base_sink::formatter_->format(msg, formatted); - ostream_.write(formatted.data(), - static_cast(formatted.size())); - if (force_flush_) { ostream_.flush(); } - } - - void flush_() override { ostream_.flush(); } - - std::ostream &ostream_; - bool force_flush_; -}; - -using ostream_sink_mt = ostream_sink; -using ostream_sink_st = ostream_sink; - -} // namespace sinks -} // namespace spdlog diff --git a/src/common/spdlog/sinks/rotating_file_sink-inl.h b/src/common/spdlog/sinks/rotating_file_sink-inl.h deleted file mode 100755 index 8b491c9828b..00000000000 --- a/src/common/spdlog/sinks/rotating_file_sink-inl.h +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace spdlog { -namespace sinks { - -template -SPDLOG_INLINE rotating_file_sink::rotating_file_sink( - filename_t base_filename, std::size_t max_size, std::size_t max_files, - bool rotate_on_open, const file_event_handlers &event_handlers) - : base_filename_(std::move(base_filename)) - , max_size_(max_size) - , max_files_(max_files) - , file_helper_ {event_handlers} { - if (max_size == 0) { - throw_spdlog_ex( - "rotating sink constructor: max_size arg cannot be zero"); - } - - if (max_files > 200000) { - throw_spdlog_ex( - "rotating sink constructor: max_files arg cannot exceed " - "200000"); - } - file_helper_.open(calc_filename(base_filename_, 0)); - current_size_ = file_helper_.size(); // expensive. 
called only once - if (rotate_on_open && current_size_ > 0) { - rotate_(); - current_size_ = 0; - } -} - -// calc filename according to index and file extension if exists. -// e.g. calc_filename("logs/mylog.txt, 3) => "logs/mylog.3.txt". -template -SPDLOG_INLINE filename_t rotating_file_sink::calc_filename( - const filename_t &filename, std::size_t index) { - if (index == 0u) { return filename; } - - filename_t basename, ext; - std::tie(basename, ext) - = details::file_helper::split_by_extension(filename); - return fmt_lib::format(SPDLOG_FILENAME_T("{}.{}{}"), basename, index, ext); -} - -template -SPDLOG_INLINE filename_t rotating_file_sink::filename() { - std::lock_guard lock(base_sink::mutex_); - return file_helper_.filename(); -} - -template -SPDLOG_INLINE void rotating_file_sink::sink_it_( - const details::log_msg &msg) { - memory_buf_t formatted; - base_sink::formatter_->format(msg, formatted); - auto new_size = current_size_ + formatted.size(); - - // rotate if the new estimated file size exceeds max size. - // rotate only if the real size > 0 to better deal with full disk (see issue #2261). - // we only check the real size when new_size > max_size_ because it is relatively expensive. - if (new_size > max_size_) { - file_helper_.flush(); - if (file_helper_.size() > 0) { - rotate_(); - new_size = formatted.size(); - } - } - file_helper_.write(formatted); - current_size_ = new_size; -} - -template -SPDLOG_INLINE void rotating_file_sink::flush_() { - file_helper_.flush(); -} - -// Rotate files: -// log.txt -> log.1.txt -// log.1.txt -> log.2.txt -// log.2.txt -> log.3.txt -// log.3.txt -> delete -template -SPDLOG_INLINE void rotating_file_sink::rotate_() { - using details::os::filename_to_str; - using details::os::path_exists; - - file_helper_.close(); - for (auto i = max_files_; i > 0; --i) { - filename_t src = calc_filename(base_filename_, i - 1); - if (!path_exists(src)) { continue; } - filename_t target = calc_filename(base_filename_, i); - - if (!rename_file_(src, target)) { - // if failed try again after a small delay. - // this is a workaround to a windows issue, where very high rotation - // rates can cause the rename to fail with permission denied (because of antivirus?). - details::os::sleep_for_millis(100); - if (!rename_file_(src, target)) { - file_helper_.reopen( - true); // truncate the log file anyway to prevent it to grow beyond its limit! - current_size_ = 0; - throw_spdlog_ex("rotating_file_sink: failed renaming " - + filename_to_str(src) + " to " - + filename_to_str(target), - errno); - } - } - } - file_helper_.reopen(true); -} - -// delete the target if exists, and rename the src file to target -// return true on success, false otherwise. -template -SPDLOG_INLINE bool rotating_file_sink::rename_file_( - const filename_t &src_filename, const filename_t &target_filename) { - // try to delete the target file in case it already exists. - (void)details::os::remove(target_filename); - return details::os::rename(src_filename, target_filename) == 0; -} - -} // namespace sinks -} // namespace spdlog diff --git a/src/common/spdlog/sinks/rotating_file_sink.h b/src/common/spdlog/sinks/rotating_file_sink.h deleted file mode 100755 index 937c165e8ae..00000000000 --- a/src/common/spdlog/sinks/rotating_file_sink.h +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. 
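A usage sketch of the rotation scheme implemented above; the size, count, and file name are illustrative:

    #include <spdlog/sinks/rotating_file_sink.h>
    #include <spdlog/spdlog.h>

    int main() {
        // Up to 3 archives of ~5 MiB each. On overflow, rotate_() shifts
        // log.txt -> log.1.txt -> log.2.txt -> log.3.txt (oldest dropped);
        // calc_filename() puts the index before the extension.
        auto logger = spdlog::rotating_logger_mt("rot", "log.txt",
                1048576 * 5, 3, /*rotate_on_open=*/false);
        logger->info("rotating output");
        return 0;
    }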
-// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#include -#include -#include -#include - -#include -#include -#include - -namespace spdlog { -namespace sinks { - -// -// Rotating file sink based on size -// -template -class rotating_file_sink final : public base_sink { -public: - rotating_file_sink(filename_t base_filename, std::size_t max_size, - std::size_t max_files, bool rotate_on_open = false, - const file_event_handlers &event_handlers = {}); - static filename_t calc_filename( - const filename_t &filename, std::size_t index); - filename_t filename(); - -protected: - void sink_it_(const details::log_msg &msg) override; - void flush_() override; - -private: - // Rotate files: - // log.txt -> log.1.txt - // log.1.txt -> log.2.txt - // log.2.txt -> log.3.txt - // log.3.txt -> delete - void rotate_(); - - // delete the target if exists, and rename the src file to target - // return true on success, false otherwise. - bool rename_file_( - const filename_t &src_filename, const filename_t &target_filename); - - filename_t base_filename_; - std::size_t max_size_; - std::size_t max_files_; - std::size_t current_size_; - details::file_helper file_helper_; -}; - -using rotating_file_sink_mt = rotating_file_sink; -using rotating_file_sink_st = rotating_file_sink; - -} // namespace sinks - -// -// factory functions -// - -template -inline std::shared_ptr rotating_logger_mt( - const std::string &logger_name, const filename_t &filename, - size_t max_file_size, size_t max_files, bool rotate_on_open = false, - const file_event_handlers &event_handlers = {}) { - return Factory::template create(logger_name, - filename, max_file_size, max_files, rotate_on_open, event_handlers); -} - -template -inline std::shared_ptr rotating_logger_st( - const std::string &logger_name, const filename_t &filename, - size_t max_file_size, size_t max_files, bool rotate_on_open = false, - const file_event_handlers &event_handlers = {}) { - return Factory::template create(logger_name, - filename, max_file_size, max_files, rotate_on_open, event_handlers); -} -} // namespace spdlog - -#ifdef SPDLOG_HEADER_ONLY -#include "rotating_file_sink-inl.h" -#endif diff --git a/src/common/spdlog/sinks/sink-inl.h b/src/common/spdlog/sinks/sink-inl.h deleted file mode 100755 index a1ef129fba9..00000000000 --- a/src/common/spdlog/sinks/sink-inl.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -#include - -SPDLOG_INLINE bool spdlog::sinks::sink::should_log( - spdlog::level::level_enum msg_level) const { - return msg_level >= level_.load(std::memory_order_relaxed); -} - -SPDLOG_INLINE void spdlog::sinks::sink::set_level(level::level_enum log_level) { - level_.store(log_level, std::memory_order_relaxed); -} - -SPDLOG_INLINE spdlog::level::level_enum spdlog::sinks::sink::level() const { - return static_cast( - level_.load(std::memory_order_relaxed)); -} diff --git a/src/common/spdlog/sinks/sink.h b/src/common/spdlog/sinks/sink.h deleted file mode 100755 index 18e0d7cffba..00000000000 --- a/src/common/spdlog/sinks/sink.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. 
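The level stored on each sink (set_level()/should_log() above) filters independently of the owning logger's level, which is what makes split-severity setups work. A sketch; the names and thresholds are illustrative:

    #include <memory>

    #include <spdlog/sinks/basic_file_sink.h>
    #include <spdlog/sinks/stdout_color_sinks.h>
    #include <spdlog/spdlog.h>

    int main() {
        auto console = std::make_shared<spdlog::sinks::stdout_color_sink_mt>();
        console->set_level(spdlog::level::warn);  // console: warnings and up
        auto file = std::make_shared<spdlog::sinks::basic_file_sink_mt>("all.log");
        file->set_level(spdlog::level::trace);    // file: everything

        spdlog::logger logger("multi", {console, file});
        logger.set_level(spdlog::level::trace);   // logger-level gate runs first
        logger.warn("reaches console and file");
        logger.debug("file only");
        return 0;
    }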
-// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#include -#include - -namespace spdlog { - -namespace sinks { -class SPDLOG_API sink { -public: - virtual ~sink() = default; - virtual void log(const details::log_msg &msg) = 0; - virtual void flush() = 0; - virtual void set_pattern(const std::string &pattern) = 0; - virtual void set_formatter( - std::unique_ptr sink_formatter) - = 0; - - void set_level(level::level_enum log_level); - level::level_enum level() const; - bool should_log(level::level_enum msg_level) const; - -protected: - // sink log level - default is all - level_t level_ {level::trace}; -}; - -} // namespace sinks -} // namespace spdlog - -#ifdef SPDLOG_HEADER_ONLY -#include "sink-inl.h" -#endif diff --git a/src/common/spdlog/spdlog-inl.h b/src/common/spdlog/spdlog-inl.h deleted file mode 100755 index b0641663f5c..00000000000 --- a/src/common/spdlog/spdlog-inl.h +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#ifndef SPDLOG_HEADER_ONLY -#include -#endif - -#include -#include - -namespace spdlog { - -SPDLOG_INLINE void initialize_logger(std::shared_ptr logger) { - details::registry::instance().initialize_logger(std::move(logger)); -} - -SPDLOG_INLINE std::shared_ptr get(const std::string &name) { - return details::registry::instance().get(name); -} - -SPDLOG_INLINE void set_formatter(std::unique_ptr formatter) { - details::registry::instance().set_formatter(std::move(formatter)); -} - -SPDLOG_INLINE void set_pattern( - std::string pattern, pattern_time_type time_type) { - set_formatter(std::unique_ptr( - new pattern_formatter(std::move(pattern), time_type))); -} - -SPDLOG_INLINE void enable_backtrace(size_t n_messages) { - details::registry::instance().enable_backtrace(n_messages); -} - -SPDLOG_INLINE void disable_backtrace() { - details::registry::instance().disable_backtrace(); -} - -SPDLOG_INLINE void dump_backtrace() { - default_logger_raw()->dump_backtrace(); -} - -SPDLOG_INLINE level::level_enum get_level() { - return default_logger_raw()->level(); -} - -SPDLOG_INLINE bool should_log(level::level_enum log_level) { - return default_logger_raw()->should_log(log_level); -} - -SPDLOG_INLINE void set_level(level::level_enum log_level) { - details::registry::instance().set_level(log_level); -} - -SPDLOG_INLINE void flush_on(level::level_enum log_level) { - details::registry::instance().flush_on(log_level); -} - -SPDLOG_INLINE void set_error_handler(void (*handler)(const std::string &msg)) { - details::registry::instance().set_error_handler(handler); -} - -SPDLOG_INLINE void register_logger(std::shared_ptr logger) { - details::registry::instance().register_logger(std::move(logger)); -} - -SPDLOG_INLINE void apply_all( - const std::function)> &fun) { - details::registry::instance().apply_all(fun); -} - -SPDLOG_INLINE void drop(const std::string &name) { - details::registry::instance().drop(name); -} - -SPDLOG_INLINE void drop_all() { - details::registry::instance().drop_all(); -} - -SPDLOG_INLINE void shutdown() { - details::registry::instance().shutdown(); -} - -SPDLOG_INLINE void set_automatic_registration(bool automatic_registration) { - details::registry::instance().set_automatic_registration( - automatic_registration); -} - -SPDLOG_INLINE std::shared_ptr default_logger() { - return details::registry::instance().default_logger(); -} - -SPDLOG_INLINE spdlog::logger *default_logger_raw() { - 
return details::registry::instance().get_default_raw(); -} - -SPDLOG_INLINE void set_default_logger( - std::shared_ptr default_logger) { - details::registry::instance().set_default_logger(std::move(default_logger)); -} - -SPDLOG_INLINE void apply_logger_env_levels(std::shared_ptr logger) { - details::registry::instance().apply_logger_env_levels(std::move(logger)); -} - -} // namespace spdlog diff --git a/src/common/spdlog/spdlog.h b/src/common/spdlog/spdlog.h deleted file mode 100755 index ef7ac2d53ee..00000000000 --- a/src/common/spdlog/spdlog.h +++ /dev/null @@ -1,362 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. -// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -// spdlog main header file. -// see example.cpp for usage example - -#ifndef SPDLOG_H -#define SPDLOG_H - -#pragma once - -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace spdlog { - -using default_factory = synchronous_factory; - -// Create and register a logger with a templated sink type -// The logger's level, formatter and flush level will be set according the -// global settings. -// -// Example: -// spdlog::create("logger_name", "dailylog_filename", 11, 59); -template -inline std::shared_ptr create( - std::string logger_name, SinkArgs &&...sink_args) { - return default_factory::create( - std::move(logger_name), std::forward(sink_args)...); -} - -// Initialize and register a logger, -// formatter and flush level will be set according the global settings. -// -// Useful for initializing manually created loggers with the global settings. -// -// Example: -// auto mylogger = std::make_shared("mylogger", ...); -// spdlog::initialize_logger(mylogger); -SPDLOG_API void initialize_logger(std::shared_ptr logger); - -// Return an existing logger or nullptr if a logger with such name doesn't -// exist. -// example: spdlog::get("my_logger")->info("hello {}", "world"); -SPDLOG_API std::shared_ptr get(const std::string &name); - -// Set global formatter. Each sink in each logger will get a clone of this object -SPDLOG_API void set_formatter(std::unique_ptr formatter); - -// Set global format string. -// example: spdlog::set_pattern("%Y-%m-%d %H:%M:%S.%e %l : %v"); -SPDLOG_API void set_pattern(std::string pattern, - pattern_time_type time_type = pattern_time_type::local); - -// enable global backtrace support -SPDLOG_API void enable_backtrace(size_t n_messages); - -// disable global backtrace support -SPDLOG_API void disable_backtrace(); - -// call dump backtrace on default logger -SPDLOG_API void dump_backtrace(); - -// Get global logging level -SPDLOG_API level::level_enum get_level(); - -// Set global logging level -SPDLOG_API void set_level(level::level_enum log_level); - -// Determine whether the default logger should log messages with a certain level -SPDLOG_API bool should_log(level::level_enum lvl); - -// Set global flush level -SPDLOG_API void flush_on(level::level_enum log_level); - -// Start/Restart a periodic flusher thread -// Warning: Use only if all your loggers are thread safe! 
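For context, the periodic flusher declared just below is driven from application code with a single call. A minimal sketch, assuming spdlog is consumed as <spdlog/spdlog.h> and that every registered logger is a thread-safe *_mt variant, as the warning above requires:

    #include <chrono>
    #include <spdlog/spdlog.h>

    int main() {
        // Flush all registered loggers from a background thread every
        // three seconds; safe only when all loggers are thread safe (_mt).
        spdlog::flush_every(std::chrono::seconds(3));
        spdlog::info("buffered messages now flush periodically");
    }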
-template -inline void flush_every(std::chrono::duration interval) { - details::registry::instance().flush_every(interval); -} - -// Set global error handler -SPDLOG_API void set_error_handler(void (*handler)(const std::string &msg)); - -// Register the given logger with the given name -SPDLOG_API void register_logger(std::shared_ptr logger); - -// Apply a user defined function on all registered loggers -// Example: -// spdlog::apply_all([&](std::shared_ptr l) {l->flush();}); -SPDLOG_API void apply_all( - const std::function)> &fun); - -// Drop the reference to the given logger -SPDLOG_API void drop(const std::string &name); - -// Drop all references from the registry -SPDLOG_API void drop_all(); - -// stop any running threads started by spdlog and clean registry loggers -SPDLOG_API void shutdown(); - -// Automatic registration of loggers when using spdlog::create() or spdlog::create_async -SPDLOG_API void set_automatic_registration(bool automatic_registration); - -// API for using default logger (stdout_color_mt), -// e.g: spdlog::info("Message {}", 1); -// -// The default logger object can be accessed using the spdlog::default_logger(): -// For example, to add another sink to it: -// spdlog::default_logger()->sinks().push_back(some_sink); -// -// The default logger can replaced using spdlog::set_default_logger(new_logger). -// For example, to replace it with a file logger. -// -// IMPORTANT: -// The default API is thread safe (for _mt loggers), but: -// set_default_logger() *should not* be used concurrently with the default API. -// e.g do not call set_default_logger() from one thread while calling spdlog::info() from another. - -SPDLOG_API std::shared_ptr default_logger(); - -SPDLOG_API spdlog::logger *default_logger_raw(); - -SPDLOG_API void set_default_logger( - std::shared_ptr default_logger); - -// Initialize logger level based on environment configs. -// -// Useful for applying SPDLOG_LEVEL to manually created loggers. 
-// -// Example: -// auto mylogger = std::make_shared("mylogger", ...); -// spdlog::apply_logger_env_levels(mylogger); -SPDLOG_API void apply_logger_env_levels(std::shared_ptr logger); - -template -inline void log(source_loc source, level::level_enum lvl, - format_string_t fmt, Args &&...args) { - default_logger_raw()->log(source, lvl, fmt, std::forward(args)...); -} - -template -inline void log( - level::level_enum lvl, format_string_t fmt, Args &&...args) { - default_logger_raw()->log( - source_loc {}, lvl, fmt, std::forward(args)...); -} - -template -inline void trace(format_string_t fmt, Args &&...args) { - default_logger_raw()->trace(fmt, std::forward(args)...); -} - -template -inline void debug(format_string_t fmt, Args &&...args) { - default_logger_raw()->debug(fmt, std::forward(args)...); -} - -template -inline void info(format_string_t fmt, Args &&...args) { - default_logger_raw()->info(fmt, std::forward(args)...); -} - -template -inline void warn(format_string_t fmt, Args &&...args) { - default_logger_raw()->warn(fmt, std::forward(args)...); -} - -template -inline void error(format_string_t fmt, Args &&...args) { - default_logger_raw()->error(fmt, std::forward(args)...); -} - -template -inline void critical(format_string_t fmt, Args &&...args) { - default_logger_raw()->critical(fmt, std::forward(args)...); -} - -template -inline void log(source_loc source, level::level_enum lvl, const T &msg) { - default_logger_raw()->log(source, lvl, msg); -} - -template -inline void log(level::level_enum lvl, const T &msg) { - default_logger_raw()->log(lvl, msg); -} - -#ifdef SPDLOG_WCHAR_TO_UTF8_SUPPORT -template -inline void log(source_loc source, level::level_enum lvl, - wformat_string_t fmt, Args &&...args) { - default_logger_raw()->log(source, lvl, fmt, std::forward(args)...); -} - -template -inline void log( - level::level_enum lvl, wformat_string_t fmt, Args &&...args) { - default_logger_raw()->log( - source_loc {}, lvl, fmt, std::forward(args)...); -} - -template -inline void trace(wformat_string_t fmt, Args &&...args) { - default_logger_raw()->trace(fmt, std::forward(args)...); -} - -template -inline void debug(wformat_string_t fmt, Args &&...args) { - default_logger_raw()->debug(fmt, std::forward(args)...); -} - -template -inline void info(wformat_string_t fmt, Args &&...args) { - default_logger_raw()->info(fmt, std::forward(args)...); -} - -template -inline void warn(wformat_string_t fmt, Args &&...args) { - default_logger_raw()->warn(fmt, std::forward(args)...); -} - -template -inline void error(wformat_string_t fmt, Args &&...args) { - default_logger_raw()->error(fmt, std::forward(args)...); -} - -template -inline void critical(wformat_string_t fmt, Args &&...args) { - default_logger_raw()->critical(fmt, std::forward(args)...); -} -#endif - -template -inline void trace(const T &msg) { - default_logger_raw()->trace(msg); -} - -template -inline void debug(const T &msg) { - default_logger_raw()->debug(msg); -} - -template -inline void info(const T &msg) { - default_logger_raw()->info(msg); -} - -template -inline void warn(const T &msg) { - default_logger_raw()->warn(msg); -} - -template -inline void error(const T &msg) { - default_logger_raw()->error(msg); -} - -template -inline void critical(const T &msg) { - default_logger_raw()->critical(msg); -} - -} // namespace spdlog - -// -// enable/disable log calls at compile time according to global level. 
-// -// define SPDLOG_ACTIVE_LEVEL to one of those (before including spdlog.h): -// SPDLOG_LEVEL_TRACE, -// SPDLOG_LEVEL_DEBUG, -// SPDLOG_LEVEL_INFO, -// SPDLOG_LEVEL_WARN, -// SPDLOG_LEVEL_ERROR, -// SPDLOG_LEVEL_CRITICAL, -// SPDLOG_LEVEL_OFF -// - -#ifndef SPDLOG_NO_SOURCE_LOC -#define SPDLOG_LOGGER_CALL(logger, level, ...) \ - (logger)->log(spdlog::source_loc {__FILE__, __LINE__, SPDLOG_FUNCTION}, \ - level, __VA_ARGS__) -#else -#define SPDLOG_LOGGER_CALL(logger, level, ...) \ - (logger)->log(spdlog::source_loc {}, level, __VA_ARGS__) -#endif - -#if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_TRACE -#define SPDLOG_LOGGER_TRACE(logger, ...) \ - SPDLOG_LOGGER_CALL(logger, spdlog::level::trace, __VA_ARGS__) -#define SPDLOG_TRACE(...) \ - SPDLOG_LOGGER_TRACE(spdlog::default_logger_raw(), __VA_ARGS__) -#else -#define SPDLOG_LOGGER_TRACE(logger, ...) (void)0 -#define SPDLOG_TRACE(...) (void)0 -#endif - -#if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_DEBUG -#define SPDLOG_LOGGER_DEBUG(logger, ...) \ - SPDLOG_LOGGER_CALL(logger, spdlog::level::debug, __VA_ARGS__) -#define SPDLOG_DEBUG(...) \ - SPDLOG_LOGGER_DEBUG(spdlog::default_logger_raw(), __VA_ARGS__) -#else -#define SPDLOG_LOGGER_DEBUG(logger, ...) (void)0 -#define SPDLOG_DEBUG(...) (void)0 -#endif - -#if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_INFO -#define SPDLOG_LOGGER_INFO(logger, ...) \ - SPDLOG_LOGGER_CALL(logger, spdlog::level::info, __VA_ARGS__) -#define SPDLOG_INFO(...) \ - SPDLOG_LOGGER_INFO(spdlog::default_logger_raw(), __VA_ARGS__) -#else -#define SPDLOG_LOGGER_INFO(logger, ...) (void)0 -#define SPDLOG_INFO(...) (void)0 -#endif - -#if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_WARN -#define SPDLOG_LOGGER_WARN(logger, ...) \ - SPDLOG_LOGGER_CALL(logger, spdlog::level::warn, __VA_ARGS__) -#define SPDLOG_WARN(...) \ - SPDLOG_LOGGER_WARN(spdlog::default_logger_raw(), __VA_ARGS__) -#else -#define SPDLOG_LOGGER_WARN(logger, ...) (void)0 -#define SPDLOG_WARN(...) (void)0 -#endif - -#if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_ERROR -#define SPDLOG_LOGGER_ERROR(logger, ...) \ - SPDLOG_LOGGER_CALL(logger, spdlog::level::err, __VA_ARGS__) -#define SPDLOG_ERROR(...) \ - SPDLOG_LOGGER_ERROR(spdlog::default_logger_raw(), __VA_ARGS__) -#else -#define SPDLOG_LOGGER_ERROR(logger, ...) (void)0 -#define SPDLOG_ERROR(...) (void)0 -#endif - -#if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_CRITICAL -#define SPDLOG_LOGGER_CRITICAL(logger, ...) \ - SPDLOG_LOGGER_CALL(logger, spdlog::level::critical, __VA_ARGS__) -#define SPDLOG_CRITICAL(...) \ - SPDLOG_LOGGER_CRITICAL(spdlog::default_logger_raw(), __VA_ARGS__) -#else -#define SPDLOG_LOGGER_CRITICAL(logger, ...) (void)0 -#define SPDLOG_CRITICAL(...) (void)0 -#endif - -#ifdef SPDLOG_HEADER_ONLY -#include "spdlog-inl.h" -#endif - -#endif // SPDLOG_H diff --git a/src/common/spdlog/version.h b/src/common/spdlog/version.h deleted file mode 100755 index d3d49f42e1d..00000000000 --- a/src/common/spdlog/version.h +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright(c) 2015-present, Gabi Melman & spdlog contributors. 
-// Distributed under the MIT License (http://opensource.org/licenses/MIT) - -#pragma once - -#define SPDLOG_VER_MAJOR 1 -#define SPDLOG_VER_MINOR 14 -#define SPDLOG_VER_PATCH 1 - -#define SPDLOG_TO_VERSION(major, minor, patch) \ - (major * 10000 + minor * 100 + patch) -#define SPDLOG_VERSION \ - SPDLOG_TO_VERSION(SPDLOG_VER_MAJOR, SPDLOG_VER_MINOR, SPDLOG_VER_PATCH) diff --git a/src/common/stack_checker.hpp b/src/common/stack_checker.hpp index 013cdbcb58e..05cfa44bab1 100644 --- a/src/common/stack_checker.hpp +++ b/src/common/stack_checker.hpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright 2021-2023 Intel Corporation + * Copyright 2021-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -204,7 +204,7 @@ struct stack_checker_t { size_t soft_stack_limit_in_bytes = get_soft_stack_limit() * get_page_size(); if (stack_consumption > soft_stack_limit_in_bytes) { - VERROR(common, stack_checker, + VWARN(common, stack_checker, "'%s' consumed %lu bytes of " "stack while the limit is %lu bytes", context_.c_str(), stack_consumption, diff --git a/src/common/stream.hpp b/src/common/stream.hpp index a29627cdb1f..e9fa73295e6 100644 --- a/src/common/stream.hpp +++ b/src/common/stream.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ struct dnnl_stream : public dnnl::impl::c_compatible { dnnl_stream(dnnl::impl::engine_t *engine, dnnl::impl::stream_impl_t *impl) : engine_(engine), impl_(impl) {} - virtual ~dnnl_stream() {} + virtual ~dnnl_stream() = default; /** returns stream's engine */ dnnl::impl::engine_t *engine() const { return engine_; } diff --git a/src/common/sum_pd.hpp b/src/common/sum_pd.hpp index 38663af5515..bc50b1d6fa3 100644 --- a/src/common/sum_pd.hpp +++ b/src/common/sum_pd.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,6 +40,7 @@ namespace dnnl { namespace impl { +// NOLINTBEGIN(google-default-arguments) struct sum_pd_t : public primitive_desc_t { const sum_desc_t *desc() const { return &desc_; } const op_desc_t *op_desc() const override { @@ -115,14 +116,14 @@ struct sum_pd_t : public primitive_desc_t { init_desc(); } - sum_pd_t(const sum_pd_t &other) : primitive_desc_t(other) { - n_ = other.n_; - scales_ = other.scales_; - dst_md_ = other.dst_md_; - dst_acc_md_ = other.dst_acc_md_; - src_mds_ = other.src_mds_; - original_dst_md_ = other.original_dst_md_; - + sum_pd_t(const sum_pd_t &other) + : primitive_desc_t(other) + , n_(other.n_) + , scales_(other.scales_) + , dst_md_(other.dst_md_) + , dst_acc_md_(other.dst_acc_md_) + , src_mds_(other.src_mds_) + , original_dst_md_(other.original_dst_md_) { init_desc(); } sum_pd_t &operator=(const sum_pd_t &other) { @@ -195,6 +196,7 @@ struct sum_pd_t : public primitive_desc_t { desc_.src_mds.push_back(&md); } }; +// NOLINTEND(google-default-arguments) #define DECLARE_SUM_PD_t(impl_name, ...) 
\ static status_t create(sum_pd_t **sum_pd, dnnl::impl::engine_t *engine, \ diff --git a/src/common/tag_traits.hpp b/src/common/tag_traits.hpp index 487f4581c9e..ad34ec963fb 100644 --- a/src/common/tag_traits.hpp +++ b/src/common/tag_traits.hpp @@ -1,5 +1,6 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation +* Copyright 2024 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +26,35 @@ namespace dnnl { namespace impl { +inline format_tag_t get_abx_tag(int ndims) { + switch (ndims) { + case 1: return format_tag::a; + case 2: return format_tag::ab; + case 3: return format_tag::abc; + case 4: return format_tag::abcd; + case 5: return format_tag::abcde; + case 6: return format_tag::abcdef; + case 7: return format_tag::abcdefg; + case 8: return format_tag::abcdefgh; + case 9: return format_tag::abcdefghi; + case 10: return format_tag::abcdefghij; + case 11: return format_tag::abcdefghijk; + case 12: return format_tag::abcdefghijkl; + + default: assert(!"unexpected ndims"); return format_tag::undef; + } +} + +inline format_tag_t get_axb_tag(int ndims) { + switch (ndims) { + case 2: return format_tag::ab; + case 3: return format_tag::acb; + case 4: return format_tag::acdb; + case 5: return format_tag::acdeb; + default: assert(!"unexpected ndims"); return format_tag::undef; + } +} + enum class block_dim_t { _, _A, @@ -89,6 +119,7 @@ enum class inner_blk_t { _8b16c, _8b24c, _8b32a, + _8a32b, _8b8c, _8c2b, _8c4b, @@ -112,6 +143,7 @@ enum class inner_blk_t { _16b4c, _16c2b, _16c4b, + _16e4c, _24a2b, _24a4b, _24b2a, @@ -120,6 +152,7 @@ enum class inner_blk_t { _24b4c, _24c2b, _24c4b, + _16d4c, _32d4c, _32e2c, _32e4c, @@ -209,7 +242,7 @@ constexpr int AB_or_BC_blk_off(int x0, int x1) { utils::one_of(f, ib::_4a4b, ib::_4b4a, ib::_4b4c, ib::_4c4b, ib::_8a2b, ib::_8a4b, ib::_8b2a, ib::_8b4a, ib::_8b2c, ib::_8c2b, ib::_8c4b, ib::_8b4c, ib::_8a8b, ib::_8b8a, - ib::_8b16a, ib::_8b24a, ib::_8b32a, ib::_8b8c, ib::_8c8b, + ib::_8b16a, ib::_8b24a, ib::_8b32a, ib::_8a32b, ib::_8b8c, ib::_8c8b, ib::_16a16b, ib::_16b64a, ib::_16b48a, ib::_16b32a, ib::_16b16a, ib::_16b16c, ib::_16c16b, ib::_32a32b, ib::_16a2b, ib::_16a4b, ib::_16b2a, ib::_16b4a, ib::_16b2c, @@ -241,9 +274,10 @@ constexpr int AB_or_BC_blk_off(int x0, int x1) { : (utils::one_of(f, ib::_2a24b, ib::_2b24c, ib::_8a24b, ib::_8b24c)) ? 24 * x0 + x1 : (f == ib::_4a4b || f == ib::_4b4c) ? 4 * x0 + x1 : (f == ib::_4b4a || f == ib::_4c4b) ? 4 * x1 + x0 - : (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1 + : (f == ib::_8a8b || f == ib::_8a32b || f == ib::_8b8c) ? 8 * x0 + x1 : (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0 : (utils::one_of(f, ib::_16a16b, ib::_16b16c, ib::_8a16b, ib::_8b16c)) ? 16 * x0 + x1 + : (f == ib::_16a16b || f == ib::_16a32b || f == ib::_16b16c) ? 16 * x0 + x1 : (f == ib::_16b64a) ? 64 * x1 + x0 : (f == ib::_16b48a) ? 48 * x1 + x0 : (f == ib::_8b32a || f == ib::_16b32a) ? 
32 * x1 + x0 @@ -293,12 +327,12 @@ constexpr int AB_or_BC_blk_off(int x0, int x1) { } template -struct inner_blk_traits { +struct inner_blk_traits_t { using ib = inner_blk_t; }; template -struct tag_traits { +struct tag_traits_t { // block_dim_t block_dims; // inner_blk_t inner_blks; // int ndims; @@ -306,7 +340,7 @@ struct tag_traits { #define DECL_TRAITS(_tag, _blk_fmt, _inner_blk, _ndims) \ template <> \ - struct tag_traits { \ + struct tag_traits_t { \ static constexpr block_dim_t block_dims = block_dim_t::_blk_fmt; \ static constexpr inner_blk_t inner_blks = inner_blk_t::_inner_blk; \ static constexpr int ndims = _ndims; \ @@ -668,6 +702,8 @@ DECL_TRAITS(ABcde16b48a2b, _AB, _16b48a2b, 5); DECL_TRAITS(ABcde16b64a2b, _AB, _16b64a2b, 5); DECL_TRAITS(ABcd8a16b2a, _AB, _8a16b2a, 4); DECL_TRAITS(ABcd8a8b, _AB, _8a8b, 4); +DECL_TRAITS(ABcd8a32b, _AB, _8a32b, 4); +DECL_TRAITS(ABcd16a32b, _AB, _16a32b, 4); DECL_TRAITS(aBcd8b, _B, _8b, 4); DECL_TRAITS(ABcd8b16a2b, _AB, _8b16a2b, 4); DECL_TRAITS(AcdB8b16a2b, _AB, _8b16a2b, 4); @@ -767,7 +803,9 @@ DECL_TRAITS(AcB16a4b, _AB, _16a4b, 3); DECL_TRAITS(Acb8a, _A, _8a, 3); DECL_TRAITS(AcB8a2b, _AB, _8a2b, 3); DECL_TRAITS(AcB8a4b, _AB, _8a4b, 3); +DECL_TRAITS(aCBd8b8c, _BC, _8b8c, 4); DECL_TRAITS(aCBd16b16c, _BC, _16b16c, 4); +DECL_TRAITS(aCBde8b8c, _BC, _8b8c, 5); DECL_TRAITS(aCBde16b16c, _BC, _16b16c, 5); DECL_TRAITS(Acdb16a, _A, _16a, 4); DECL_TRAITS(AcdB16a2b, _AB, _16a2b, 4); @@ -783,8 +821,11 @@ DECL_TRAITS(AcdeB8a2b, _AB, _8a2b, 5); DECL_TRAITS(AcdeB8a4b, _AB, _8a4b, 5); DECL_TRAITS(Acedb16a, _A, _16a, 5); DECL_TRAITS(Adcb16a, _A, _16a, 4); +DECL_TRAITS(BAc8a8b, _AB, _8a8b, 3); DECL_TRAITS(BAc16a16b, _AB, _16a16b, 3); +DECL_TRAITS(BAcd8a8b, _AB, _8a8b, 4); DECL_TRAITS(BAcd16a16b, _AB, _16a16b, 4); +DECL_TRAITS(BAcde8a8b, _AB, _8a8b, 5); DECL_TRAITS(BAcde16a16b, _AB, _16a16b, 5); DECL_TRAITS(ABcd32a32b, _AB, _32a32b, 4); DECL_TRAITS(BAcde16b16a, _AB, _16b16a, 5); @@ -794,7 +835,10 @@ DECL_TRAITS(aBCde4b8c8b4c, _BC, _4b8c8b4c, 5); DECL_TRAITS(aBCde2b8c8b2c, _BC, _2b8c8b2c, 5); DECL_TRAITS(aBdec32b, _B, _32b, 5); DECL_TRAITS(aCBdef16c16b, _BC, _16c16b, 6); +DECL_TRAITS(aCBdef8b8c, _BC, _8b8c, 6); DECL_TRAITS(aCBdef16b16c, _BC, _16b16c, 6); +DECL_TRAITS(Abcdef4a, _A, _4a, 6); +DECL_TRAITS(Abcdef8a, _A, _8a, 6); DECL_TRAITS(Abcdef16a, _A, _16a, 6); DECL_TRAITS(aCBd16c16b, _BC, _16c16b, 4); DECL_TRAITS(aCBde16c16b, _BC, _16c16b, 4); @@ -815,6 +859,7 @@ DECL_TRAITS(aBCde4c8b2c, _BC, _4c8b2c, 5); DECL_TRAITS(aBCdef4c8b2c, _BC, _4c8b2c, 6); DECL_TRAITS(abDc16d, _D, _16d, 4); DECL_TRAITS(abDc32d, _D, _32d, 4); +DECL_TRAITS(abDC16d4c, _CD, _16d4c, 4); DECL_TRAITS(abDC32d4c, _CD, _32d4c, 4); DECL_TRAITS(abCd32c, _C, _32c, 4); DECL_TRAITS(abCde32c, _C, _32c, 5); @@ -824,6 +869,7 @@ DECL_TRAITS(abCde4c, _C, _4c, 5); DECL_TRAITS(abCdef4c, _C, _4c, 6); DECL_TRAITS(abdEc16e, _E, _16e, 5); DECL_TRAITS(abdEc32e, _E, _32e, 5); +DECL_TRAITS(abdEC16e4c, _CE, _16e4c, 5); DECL_TRAITS(abdEC32e2c, _CE, _32e2c, 5); DECL_TRAITS(abdEC32e4c, _CE, _32e4c, 5); DECL_TRAITS(abdEC64e2c, _CE, _64e2c, 5); diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp index d5ea2c38dda..ad00fadde16 100644 --- a/src/common/type_helpers.hpp +++ b/src/common/type_helpers.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the 
License. @@ -26,8 +26,11 @@ #include "bit_cast.hpp" #include "c_types_map.hpp" #include "dnnl_traits.hpp" +#include "gemm_types.hpp" #include "memory_desc.hpp" #include "nstl.hpp" +#include "opdesc.hpp" +#include "sdpa_types.hpp" #include "utils.hpp" namespace dnnl { @@ -50,8 +53,16 @@ status_t safe_ptr_assign(std::unique_ptr &lhs, derived_type *rhs) { return status::success; } +template +status_t safe_ptr_assign( + std::unique_ptr &lhs, derived_type *rhs) { + if (rhs == nullptr) return status::out_of_memory; + lhs.reset(rhs); + return status::success; +} + template -struct is_subset { +struct is_subset { // NOLINT(readability-identifier-naming) static constexpr bool value = false; }; template @@ -82,20 +93,24 @@ namespace types { inline size_t data_type_size(data_type_t data_type) { using namespace data_type; switch ((int)data_type) { - case e8m0: return sizeof(prec_traits::type); - case f8_e5m2: return sizeof(prec_traits::type); - case f8_e4m3: return sizeof(prec_traits::type); - case f16: return sizeof(prec_traits::type); - case bf16: return sizeof(prec_traits::type); + case f4_e3m0: return sizeof(prec_traits_t::type); + case f4_e2m1: return sizeof(prec_traits_t::type); + case e8m0: return sizeof(prec_traits_t::type); + case f8_e5m2: return sizeof(prec_traits_t::type); + case f8_e4m3: return sizeof(prec_traits_t::type); + case f16: return sizeof(prec_traits_t::type); + case bf16: return sizeof(prec_traits_t::type); case tf32: // the tf32 type is an f32 - case f32: return sizeof(prec_traits::type); - case f64: return sizeof(prec_traits::type); - case s32: return sizeof(prec_traits::type); - case s8: return sizeof(prec_traits::type); - case u8: return sizeof(prec_traits::type); - case s4: return sizeof(prec_traits::type); - case u4: return sizeof(prec_traits::type); - case boolean: return sizeof(prec_traits::type); + case f32: return sizeof(prec_traits_t::type); + case f64: return sizeof(prec_traits_t::type); + case s32: return sizeof(prec_traits_t::type); + case s8: return sizeof(prec_traits_t::type); + case u8: return sizeof(prec_traits_t::type); + case s4: return sizeof(prec_traits_t::type); + case u4: return sizeof(prec_traits_t::type); + case boolean: return sizeof(prec_traits_t::type); + case bin: return sizeof(prec_traits_t::type); + case nf4: return sizeof(prec_traits_t::type); case data_type::undef: default: assert(!"unknown data_type"); } @@ -105,6 +120,8 @@ inline size_t data_type_size(data_type_t data_type) { inline size_t elements_to_bytes(data_type_t data_type, size_t count) { using namespace data_type; switch ((int)data_type) { + case f4_e2m1: + case f4_e3m0: case s4: case u4: return (count + 1) >> 1; default: return data_type_size(data_type) * count; @@ -114,6 +131,8 @@ inline size_t elements_to_bytes(data_type_t data_type, size_t count) { inline size_t bytes_to_elements(data_type_t data_type, size_t bytes) { using namespace data_type; switch ((int)data_type) { + case f4_e2m1: + case f4_e3m0: case s4: case u4: return bytes * 2; default: return utils::div_up(bytes, data_type_size(data_type)); @@ -125,14 +144,18 @@ inline T min_value(data_type_t data_type) { using namespace data_type; #define CASE(x) \ case x: \ - return static_cast(nstl::numeric_limits::type>::min()) + return static_cast( \ + nstl::numeric_limits::type>::min()) switch (data_type) { + CASE(f4_e3m0); + CASE(f4_e2m1); CASE(e8m0); CASE(f8_e5m2); CASE(f8_e4m3); CASE(f16); CASE(bf16); CASE(f32); + CASE(f64); CASE(s32); CASE(s8); CASE(u8); @@ -150,19 +173,23 @@ inline T max_value(data_type_t data_type) { using 
namespace data_type; #define CASE(x) \ case x: \ - return static_cast(nstl::numeric_limits::type>::max()) + return static_cast( \ + nstl::numeric_limits::type>::max()) switch (data_type) { + CASE(f4_e3m0); + CASE(f4_e2m1); CASE(e8m0); CASE(f8_e5m2); CASE(f8_e4m3); CASE(f16); - CASE(f32); CASE(bf16); + CASE(f32); CASE(s32); CASE(s8); CASE(u8); CASE(s4); CASE(u4); + case f64: return nstl::numeric_limits::max(); case data_type::undef: default: assert(!"unknown data_type"); } @@ -177,8 +204,10 @@ inline float max_value(data_type_t data_type) { #define CASE(x) \ case x: \ return static_cast( \ - nstl::numeric_limits::type>::max()) + nstl::numeric_limits::type>::max()) switch (data_type) { + CASE(f4_e3m0); + CASE(f4_e2m1); CASE(e8m0); CASE(f8_e5m2); CASE(f8_e4m3); @@ -200,6 +229,7 @@ inline float max_value(data_type_t data_type) { // approach is saturating on some integer values before it should happen // in the reality. case s32: return 2147483520.f; + case f64: return nstl::numeric_limits::max(); case data_type::undef: default: assert(!"unknown data_type"); } @@ -213,8 +243,10 @@ inline T lowest_value(data_type_t data_type) { #define CASE(x) \ case x: \ return static_cast( \ - nstl::numeric_limits::type>::lowest()) + nstl::numeric_limits::type>::lowest()) switch (data_type) { + CASE(f4_e3m0); + CASE(f4_e2m1); CASE(e8m0); CASE(f8_e5m2); CASE(f8_e4m3); @@ -226,6 +258,7 @@ inline T lowest_value(data_type_t data_type) { CASE(u8); CASE(s4); CASE(u4); + case f64: return nstl::numeric_limits::lowest(); case data_type::undef: default: assert(!"unknown data_type"); } @@ -239,14 +272,17 @@ inline T digits(data_type_t data_type) { #define CASE(x) \ case x: \ return static_cast( \ - nstl::numeric_limits::type>::digits) + nstl::numeric_limits::type>::digits) switch (data_type) { + CASE(f4_e3m0); + CASE(f4_e2m1); CASE(e8m0); CASE(f8_e5m2); CASE(f8_e4m3); CASE(f16); CASE(bf16); CASE(f32); + CASE(f64); CASE(s32); CASE(s8); CASE(u8); @@ -271,31 +307,24 @@ inline format_kind_t format_tag_to_kind(format_tag_t tag) { return format_kind::undef; } -// Currently rnn_s8s8_compensation has common bits with rnn_u8s8_compensation -// and scale_adjust constants so we have to perform additional checks to -// separate these two cases -inline bool extra_flag_rnn_s8s8_compensation_is_set(uint64_t flags) { - return ((flags & memory_extra_flags::rnn_s8s8_compensation) - ^ memory_extra_flags::rnn_s8s8_compensation) - == 0; -} - inline bool memory_extra_desc_is_equal( const memory_extra_desc_t &lhs, const memory_extra_desc_t &rhs) { using namespace memory_extra_flags; - return true && lhs.flags == rhs.flags + return lhs.flags == rhs.flags && IMPLICATION(lhs.flags & compensation_conv_s8s8, lhs.compensation_mask == rhs.compensation_mask) - && IMPLICATION((lhs.flags & rnn_u8s8_compensation) - && !extra_flag_rnn_s8s8_compensation_is_set( - lhs.flags), + && IMPLICATION(lhs.flags & rnn_u8s8_compensation, lhs.compensation_mask == rhs.compensation_mask) - && IMPLICATION((lhs.flags & scale_adjust) - && !extra_flag_rnn_s8s8_compensation_is_set( - lhs.flags), + && IMPLICATION(lhs.flags & scale_adjust, lhs.scale_adjust == rhs.scale_adjust) && IMPLICATION(lhs.flags & compensation_conv_asymmetric_src, - lhs.asymm_compensation_mask == rhs.asymm_compensation_mask); + lhs.asymm_compensation_mask == rhs.asymm_compensation_mask) + && IMPLICATION(lhs.flags & compensation_gpu_conv_asymmetric_src, + (lhs.dst_size == rhs.dst_size) + && utils::array_cmp(lhs.idhw, rhs.idhw, 3) + && utils::array_cmp(lhs.odhw, rhs.odhw, 3) + && utils::array_cmp(lhs.pdhw, 
rhs.pdhw, 3) + && utils::array_cmp(lhs.ddhw, rhs.ddhw, 3)); } inline bool blocking_desc_is_equal(const memory_desc_t &lhs_md, @@ -327,12 +356,12 @@ inline bool blocking_desc_is_equal(const memory_desc_t &lhs_md, bool equal = lhs.inner_nblks == rhs.inner_nblks && array_cmp(lhs.inner_blks, rhs.inner_blks, lhs.inner_nblks) && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks); - if (ignore_strides) return equal; // Check the strides. - // Note: for dimensions of size `1` the stride doesn't really matter. + // Note: for dimensions of size `1` the stride doesn't really matter + if (ignore_strides) return equal; + for (int d = 0; d < lhs_md.ndims; ++d) { - if (lhs_md.dims[d] == 1 && lhs_md.padded_dims[d] == 1) continue; equal = equal && lhs.strides[d] == rhs.strides[d]; } @@ -346,6 +375,10 @@ inline bool wino_desc_is_equal(const wino_desc_t &lhs, const wino_desc_t &rhs) { && lhs.ic2_block == rhs.ic2_block && lhs.oc2_block == rhs.oc2_block && lhs.r == rhs.r; } +inline bool cublaslt_blocked_desc_is_equal(const cublaslt_blocked_desc_t &lhs, + const cublaslt_blocked_desc_t &rhs) { + return lhs.cublaslt_format == rhs.cublaslt_format && lhs.size == rhs.size; +} inline bool rnn_packed_desc_is_equal( const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) { @@ -364,6 +397,7 @@ inline bool rnn_packed_desc_is_equal( inline bool sparse_desc_is_equal( const sparse_desc_t &lhs, const sparse_desc_t &rhs) { +#if 0 bool ok = lhs.encoding == rhs.encoding && lhs.nnz == rhs.nnz; if (!ok) return false; @@ -371,6 +405,8 @@ inline bool sparse_desc_is_equal( ok = ok && lhs.metadata_types[i] == rhs.metadata_types[i]; return ok; +#endif + return lhs.encoding == rhs.encoding; } inline memory_desc_t zero_md() { @@ -392,6 +428,8 @@ inline data_type_t default_accum_data_type( // true if (one_of(src_dt, s8, u8, u4, s4) && (dst_dt != f32 || strict)) return s32; + if (one_of(f4_e3m0, src_dt, dst_dt)) return f32; + if (one_of(f4_e2m1, src_dt, dst_dt)) return f32; if (one_of(f8_e5m2, src_dt, dst_dt)) return f32; if (one_of(f8_e4m3, src_dt, dst_dt)) return f32; if (one_of(f16, src_dt, dst_dt)) return f32; @@ -415,6 +453,7 @@ inline data_type_t default_accum_data_type(data_type_t src_dt, /* prop_kind doesn't matter */ if (everyone_is(f32, src_dt, wei_dt)) return f32; + if (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4, f4_e2m1)) return f32; if (everyone_is(f64, src_dt, wei_dt)) return f64; if (one_of(prop_kind, forward_training, forward_inference)) { @@ -433,6 +472,8 @@ inline data_type_t default_accum_data_type(data_type_t src_dt, return f32; } + if (one_of(f4_e3m0, src_dt, wei_dt, dst_dt)) return f32; + if (one_of(f4_e2m1, src_dt, wei_dt, dst_dt)) return f32; if (one_of(f8_e5m2, src_dt, wei_dt, dst_dt)) return f32; if (one_of(f8_e4m3, src_dt, wei_dt, dst_dt)) return f32; if (one_of(bf16, src_dt, wei_dt, dst_dt)) return f32; @@ -594,7 +635,7 @@ inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) { #define DEREF_AND_COMPARE_DESC_MEMBERS(m) *lhs.m == *rhs.m #define COMPARE_FLOAT_DESC_MEMBERS(m) utils::equal_with_nan(lhs.m, rhs.m) #define COMPARE_FLOAT_DESC_ARRAY_MEMBERS(m, s) \ - !std::memcmp(lhs.m, rhs.m, sizeof(float) * s) + !std::memcmp(lhs.m, rhs.m, sizeof(float) * (s)) // clang-format off inline bool operator==(const batch_normalization_desc_t &lhs, @@ -619,6 +660,12 @@ inline bool operator==(const binary_desc_t &lhs, const binary_desc_t &rhs) { && COMPARE_DESC_MEMBERS(src_desc[0]) && COMPARE_DESC_MEMBERS(src_desc[1]) && COMPARE_DESC_MEMBERS(dst_desc); + + // For ternary 
operators like select, the additional input for conditional
+    // select must also be compared.
+    if (utils::one_of(alg_kind::binary_select, lhs.alg_kind, rhs.alg_kind))
+        ret = ret && COMPARE_DESC_MEMBERS(src_desc[2]);
+
+    return ret;
 }

@@ -637,11 +684,12 @@ inline bool operator==(const concat_desc_t &lhs, const concat_desc_t &rhs) {
     return ret;
 }

-inline bool operator==(
-        const convolution_desc_t &lhs, const convolution_desc_t &rhs) {
+// This function can only be used to compare the opdescs in the primitive
+// cache. For comparing opdescs outside the primitive cache, please use the
+// regular comparison operator (==).
+inline bool compare_conv_opdesc(const convolution_desc_t &lhs, const convolution_desc_t &rhs) {
     bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
             && COMPARE_DESC_MEMBERS(prop_kind)
-            && COMPARE_DESC_MEMBERS(alg_kind)
             && COMPARE_DESC_MEMBERS(src_desc)
             && COMPARE_DESC_MEMBERS(diff_src_desc)
             && COMPARE_DESC_MEMBERS(weights_desc)
@@ -656,9 +704,31 @@ inline bool operator==(
             && COMPARE_DESC_ARRAY_MEMBERS(padding[1], DNNL_MAX_NDIMS)
             && COMPARE_DESC_MEMBERS(accum_data_type)
             && COMPARE_DESC_MEMBERS(use_inversion);
+
+    // The `alg_kind` can be `auto` only if this function is called in the
+    // primitive descriptor cache scenario. In this case, we ignore `alg_kind`
+    // and rely on `pd_iterator_offset` to fetch the first suitable
+    // implementation.
+    //
+    // Background: when a convolution primitive descriptor is created for
+    // the algorithm `auto`, we overwrite the `alg_kind` field in `op_desc`
+    // when we store it in the primitive descriptor. Because of that, the
+    // `op_desc` stored in the primitive descriptor is different from the one
+    // the user passed to the oneDNN API, and the requested primitive
+    // descriptor cannot be found in the cache if we compare `alg_kind`.
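+    // Illustration (hypothetical values): a desc the user created with
+    // alg_kind = convolution_auto and the same desc stored after the library
+    // resolved it to, say, convolution_direct compare equal here, so the
+    // cache lookup still hits; the strict operator== below would treat them
+    // as different descriptors.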
+ if (!utils::one_of(alg_kind::convolution_auto, lhs.alg_kind, rhs.alg_kind)) + ret = ret && COMPARE_DESC_MEMBERS(alg_kind); + return ret; } +inline bool operator==( + const convolution_desc_t &lhs, const convolution_desc_t &rhs) { + if (!(COMPARE_DESC_MEMBERS(alg_kind))) return false; + return compare_conv_opdesc(lhs, rhs); +} + inline bool operator==(const eltwise_desc_t &lhs, const eltwise_desc_t &rhs) { bool ret = COMPARE_DESC_MEMBERS(primitive_kind) && COMPARE_DESC_MEMBERS(prop_kind) @@ -754,6 +824,8 @@ inline bool operator==(const matmul_desc_t &lhs, const matmul_desc_t &rhs) { && COMPARE_DESC_MEMBERS(weights_desc) && COMPARE_DESC_MEMBERS(bias_desc) && COMPARE_DESC_MEMBERS(dst_desc) + && COMPARE_DESC_MEMBERS(reduce_desc) + && COMPARE_DESC_MEMBERS(reduce_kind) && COMPARE_DESC_MEMBERS(accum_data_type); return ret; } @@ -908,10 +980,16 @@ inline bool operator==(const sdpa_desc_t &lhs, const sdpa_desc_t &rhs) { && COMPARE_DESC_MEMBERS(q_desc) && COMPARE_DESC_MEMBERS(k_desc) && COMPARE_DESC_MEMBERS(v_desc) + && COMPARE_DESC_MEMBERS(kq_scales) + && COMPARE_DESC_MEMBERS(kq_zero_points) + && COMPARE_DESC_MEMBERS(vs_scales) + && COMPARE_DESC_MEMBERS(vs_zero_points) && COMPARE_DESC_MEMBERS(dst_desc) && COMPARE_DESC_MEMBERS(attn_mask_desc) && COMPARE_DESC_MEMBERS(scale_dt) - && COMPARE_DESC_MEMBERS(invert_scale); + && COMPARE_DESC_MEMBERS(invert_scale) + && COMPARE_DESC_MEMBERS(kv_head_number) + && COMPARE_DESC_MEMBERS(mask_type); return ret; } @@ -923,7 +1001,8 @@ inline bool operator==(const sdpa_desc_t &lhs, const sdpa_desc_t &rhs) { #undef COMPARE_FLOAT_DESC_MEMBERS #undef COMPARE_FLOAT_DESC_ARRAY_MEMBERS -inline bool is_dense_format_kind(const std::vector mds) { +inline bool is_dense_format_kind( + const std::vector &mds) { #ifdef DNNL_EXPERIMENTAL_SPARSE for (const auto *md : mds) if (md->format_kind == format_kind::sparse) return false; @@ -1034,22 +1113,21 @@ inline status_t memory_desc_init_by_tag( const bool is_sparse = md.format_kind == format_kind::sparse; auto md_tmp = memory_desc_t(); - CHECK(memory_desc_init_by_tag( - md_tmp, md.ndims, md.dims, md.data_type, tag)); - - if (strides != nullptr && !memory_desc_strides_check(md_tmp, strides)) - return status::invalid_arguments; + status_t status = + memory_desc_init_by_tag(md_tmp, md.ndims, md.dims, md.data_type, tag); if (is_sparse) { - if (md.format_desc.sparse_desc.encoding != sparse_encoding::packed - || md.offset0 != 0) - return status::invalid_arguments; - md = cvt_blocked2sparse_packed(md_tmp, md.format_desc.sparse_desc.nnz); + const auto &bd = md_tmp.format_desc.blocking; + md.format_desc.sparse_desc.encoding = sparse_encoding::packed; + md.format_desc.sparse_desc.packed_desc = bd; } else { md = md_tmp; } - if (strides == nullptr) return status::success; + if (status != status::success || strides == nullptr) return status; + + if (!memory_desc_strides_check(md_tmp, strides)) + return status::invalid_arguments; for (int d = 0; d < md.ndims; ++d) { if (is_sparse) @@ -1057,7 +1135,6 @@ inline status_t memory_desc_init_by_tag( else md.format_desc.blocking.strides[d] = strides[d]; } - return status::success; } @@ -1135,9 +1212,19 @@ inline status_t memory_desc_init_by_md_and_dt(memory_desc_t &md, * Assumes a dense structure such as that returned by memory_desc_init_by_tag(). * Strides must match those returned by memory_desc_init_by_tag(), with one * exception: the strides of unit dimensions are ignored in order to align with - * memory descriptor equality comparisons and hashing. 
- */
-inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag) {
+ * memory descriptor equality comparisons and hashing.
+ * When strides are empty, the dense structure is assumed (e.g., the one that
+ * memory_desc_init_by_tag() returns).
+ * When strides are not empty, the standard strides check is overridden, and
+ * additional rules are applied:
+ * Strides might contain a `0` value, indicating the stride must match the
+ * one that memory_desc_init_by_tag() returns.
+ * Strides might contain `-1` values, which are ignored during the
+ * comparison. For instance, this can be used if a stride along minibatch
+ * doesn't matter. */
+inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag,
+        const dims_t strides = nullptr) {
     if (md.format_kind != format_kind::sparse) {
         if (md.format_kind != types::format_tag_to_kind(tag)) return false;
     }
@@ -1146,8 +1233,38 @@ inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag) {
     status_t status = memory_desc_init_by_tag(
             md_gold, md.ndims, md.dims, md.data_type, tag);
     if (status != status::success) return false;

+    const bool is_sparse_packed_desc = md.format_kind == format_kind::sparse
+            && md.format_desc.sparse_desc.encoding == sparse_encoding::packed;

-    return types::blocking_desc_is_equal(md, md_gold);
+    if (md.format_kind != format_kind::blocked && !is_sparse_packed_desc)
+        return false; // not implemented yet
+
+    const auto &blk = md.format_kind == format_kind::blocked
+            ? md.format_desc.blocking
+            : md.format_desc.sparse_desc.packed_desc;
+    const auto &blk_gold = md_gold.format_desc.blocking;
+
+    using utils::array_cmp;
+    bool same_blocks = true && blk.inner_nblks == blk_gold.inner_nblks
+            && array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks)
+            && array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks);
+
+    if (!same_blocks) return false;
+
+    if (strides == nullptr)
+        return array_cmp(blk.strides, blk_gold.strides, md.ndims);
+
+    for (int d = 0; d < md.ndims; ++d) {
+        dim_t stride = strides[d];
+        if (stride == -1) continue;
+        if (stride == 0) stride = blk_gold.strides[d];
+        if (blk.strides[d] != stride) return false;
+    }
+    return true;
 }

 /** returns matching tag (or undef if match is not found)
@@ -1183,8 +1300,8 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims,
     if (ndims == 0) return true;

     bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS
-            && utils::one_of(data_type, f8_e5m2, f8_e4m3, f16, bf16, f32, f64,
-                    s32, s8, u8, s4, u4);
+            && utils::one_of(data_type, f4_e3m0, f4_e2m1, e8m0, f8_e5m2,
+                    f8_e4m3, f16, bf16, f32, f64, s32, s8, u8, s4, u4, bin, nf4);
     if (!ok) return false;

     bool has_runtime_dims = false;
@@ -1206,38 +1323,6 @@ inline bool memory_desc_sanity_check(const memory_desc_t &md) {
             md.ndims, md.dims, md.data_type, format_kind::undef);
 }

-inline void copy_c_op_desc(op_desc_t *dst, const op_desc_t *src) {
-#define CASE_OP_DESC(pkind) \
-    case primitive_kind::pkind: dst->pkind = src->pkind; break;
-
-    switch ((int)src->kind) {
-        CASE_OP_DESC(batch_normalization);
-        CASE_OP_DESC(binary);
-        CASE_OP_DESC(convolution);
-        CASE_OP_DESC(deconvolution);
-        CASE_OP_DESC(eltwise);
-        CASE_OP_DESC(gemm);
-        CASE_OP_DESC(group_normalization);
-        CASE_OP_DESC(inner_product);
-        CASE_OP_DESC(layer_normalization);
-        CASE_OP_DESC(lrn);
-
CASE_OP_DESC(matmul); - CASE_OP_DESC(pooling); - CASE_OP_DESC(prelu); - CASE_OP_DESC(reduction); - CASE_OP_DESC(resampling); - CASE_OP_DESC(rnn); - CASE_OP_DESC(sdpa); - CASE_OP_DESC(shuffle); - CASE_OP_DESC(softmax); - - // Internal descs - CASE_OP_DESC(zero_pad); - default: assert(!"unknown C primitive kind"); - } -#undef CASE_OP_DESC -} - } // namespace impl } // namespace dnnl diff --git a/src/common/utils.cpp b/src/common/utils.cpp index a11d9104bfd..9ccb8531f96 100644 --- a/src/common/utils.cpp +++ b/src/common/utils.cpp @@ -19,12 +19,12 @@ #include #endif -#if defined __unix__ || defined __APPLE__ || defined __FreeBSD__ \ - || defined __Fuchsia__ +#if defined(__unix__) || defined(__APPLE__) || defined(__FreeBSD__) \ + || defined(__Fuchsia__) #include #endif -#ifdef __unix__ +#if defined(__unix__) || defined(__APPLE__) #include #include #endif @@ -41,6 +41,7 @@ #include "memory_debug.hpp" #include "utils.hpp" +#include "verbose.hpp" #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE #include "cpu/platform.hpp" @@ -124,7 +125,49 @@ std::string getenv_string_user(const char *name) { return value; } +status_t check_for_symlinks(const char *filename, bool *res) { +#ifdef _WIN32 + DWORD attr = GetFileAttributes(filename); + + // checking for ERROR_FILE_NOT_FOUND allows the application to open + // new files without raising an exception + if (attr == INVALID_FILE_ATTRIBUTES) + return (GetLastError() == ERROR_FILE_NOT_FOUND) + ? status::success + : status::invalid_arguments; + *res = (attr & FILE_ATTRIBUTE_REPARSE_POINT); + return status::success; +#else + struct stat finfo; + // checking for ENOENT allows the application to open new files without + // raising an exception + if (lstat(filename, &finfo) != 0) + return (errno == ENOENT) ? status::success : status::invalid_arguments; + *res = (finfo.st_mode & S_IFMT) == S_IFLNK; + return status::success; +#endif +} + FILE *fopen(const char *filename, const char *mode) { + bool is_symlink = false; + status_t fattr_status = check_for_symlinks(filename, &is_symlink); + + // For any return status other than status::success, the file IO operation + // is abandoned implying a major issue in retrieving the file + if (fattr_status != status::success) { + VERROR(common, common, "error reading file attributes for %s", + filename); + return nullptr; + } + + // The symlink flag is updated and checked only after the file attributes are + // successfully read, avoiding the use of an uninitialized variable. + if (is_symlink) { + VERROR(common, common, + "cannot open %s - specified file is a symbolic link", filename); + return nullptr; + } + #ifdef _WIN32 FILE *fp = NULL; return ::fopen_s(&fp, filename, mode) ? NULL : fp; @@ -187,7 +230,7 @@ bool get_jit_dump() { return jit_dump.get(); } -#if defined(DNNL_AARCH64) && (DNNL_AARCH64 == 1) +#if defined(DNNL_AARCH64) && (DNNL_AARCH64 == 1) || defined(DNNL_ARM) && (DNNL_ARM == 1) static setting_t jit_profiling_flags {DNNL_JIT_PROFILE_LINUX_PERFMAP}; #else static setting_t jit_profiling_flags {DNNL_JIT_PROFILE_VTUNE}; diff --git a/src/common/utils.hpp b/src/common/utils.hpp index cb9d681dccc..84957503436 100644 --- a/src/common/utils.hpp +++ b/src/common/utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
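On POSIX systems, the symlink refusal added above reduces to an lstat() mode test. A standalone sketch of that logic, with the helper name is_symlink invented for illustration (the library's actual entry point is check_for_symlinks()):

    #include <cstdio>
    #include <sys/stat.h>

    // Report whether `path` names a symbolic link. A missing file counts as
    // "not a symlink" so that new files can still be created, mirroring the
    // ENOENT handling in dnnl::impl::check_for_symlinks().
    static bool is_symlink(const char *path) {
        struct stat finfo;
        if (lstat(path, &finfo) != 0) return false;
        return (finfo.st_mode & S_IFMT) == S_IFLNK;
    }

    int main() {
        std::printf("/tmp: %s\n", is_symlink("/tmp") ? "symlink" : "regular");
    }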
@@ -52,21 +52,25 @@ namespace impl { #define DNNL_SHORT_CIRCUIT_SELF_ASSIGN(other) \ do { \ - if (this == &other) return *this; \ + if (this == &(other)) return *this; \ } while (0) #define DNNL_SHORT_CIRCUIT_SELF_COMPARISON(other) \ do { \ - if (this == &other) return true; \ + if (this == &(other)) return true; \ } while (0) #define DNNL_DISALLOW_COPY_AND_ASSIGN(T) \ T(const T &) = delete; \ - T &operator=(const T &) = delete; + void operator=(const T &) = delete; // Sanity check for 64 bits -static_assert(sizeof(void *) == 8, "oneDNN supports 64-bit architectures only"); +// static_assert(sizeof(void *) == 8, "oneDNN supports 64-bit architectures only"); +// Note: if `f` has any explicit templated arguments, e.g., func, then +// compiler returns `error: macro "CHECK" passed 2 arguments, but takes just 1`. +// The solution is to use an alias, e.g. `using func_alias = func;` and +// use `func_alias` in CHECK, then it compiles. #define CHECK(f) \ do { \ dnnl::impl::status_t _status_ = f; \ @@ -88,6 +92,15 @@ static_assert(sizeof(void *) == 8, "oneDNN supports 64-bit architectures only"); #define IMPLICATION(cause, effect) (!(cause) || !!(effect)) +#if defined(_MSC_VER) || defined(__INTEL_COMPILER) \ + || defined(__INTEL_LLVM_COMPILER) +#define FORCE_INLINE __forceinline +#elif defined(__clang__) || defined(__GNUC__) +#define FORCE_INLINE inline __attribute__((always_inline)) +#else +#define FORCE_INLINE inline +#endif + namespace utils { /* a bunch of std:: analogues to be compliant with any msvs version @@ -100,41 +113,50 @@ namespace utils { /* SFINAE helper -- analogue to std::enable_if */ template -struct enable_if {}; +struct enable_if {}; // NOLINT(readability-identifier-naming) + template struct enable_if { - typedef T type; + using type = T; }; +// Replacement implementation of std::enable_if_t from C++14, included here for +// interoperability with C++11 +template +using enable_if_t = typename enable_if::type; + +template +using is_vector = std::is_same>; + /* analogue std::conditional */ template -struct conditional {}; +struct conditional {}; // NOLINT(readability-identifier-naming) template struct conditional { - typedef T type; + using type = T; }; template struct conditional { - typedef F type; + using type = F; }; template -struct conditional3 {}; +struct conditional3 {}; // NOLINT(readability-identifier-naming) template struct conditional3 { - typedef T type; + using type = T; }; template struct conditional3 { - typedef FT type; + using type = FT; }; template struct conditional3 { - typedef FF type; + using type = FF; }; template -struct conditional_v {}; +struct conditional_v {}; // NOLINT(readability-identifier-naming) template struct conditional_v { static constexpr U value = t; @@ -145,16 +167,16 @@ struct conditional_v { }; template -struct remove_reference { - typedef T type; +struct remove_reference { // NOLINT(readability-identifier-naming) + using type = T; }; template struct remove_reference { - typedef T type; + using type = T; }; template struct remove_reference { - typedef T type; + using type = T; }; template @@ -177,6 +199,7 @@ std::unique_ptr make_unique(Args &&...args) { return std::unique_ptr(new T(std::forward(args)...)); } +// NOLINTBEGIN(performance-unnecessary-value-param) template constexpr bool everyone_is(T val, P item) { return val == item; @@ -185,7 +208,9 @@ template constexpr bool everyone_is(T val, P item, Args... 
item_others) { return val == item && everyone_is(val, item_others...); } +// NOLINTEND(performance-unnecessary-value-param) +// NOLINTBEGIN(performance-unnecessary-value-param) template constexpr bool one_of(T val, P item) { return val == item; @@ -194,6 +219,7 @@ template constexpr bool one_of(T val, P item, Args... item_others) { return val == item || one_of(val, item_others...); } +// NOLINTEND(performance-unnecessary-value-param) template constexpr P map(T pat, P def) { @@ -209,11 +235,30 @@ constexpr bool any_null(Args... ptrs) { return one_of(nullptr, ptrs...); } +// For some unknown reason, GCC 11.x and beyond can't compile specific places +// of the library that involve this routine. It's connected to the fact that +// this function is inline and defined in a header. +#if defined(__GNUC__) && __GNUC__ > 8 && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wrestrict" +// /usr/include/bits/string_fortified.h:29:33: warning: ‘void* __builtin_memcpy( +// void*, const void*, long unsigned int)’ accessing 18446744056529682432 or +// more bytes at offsets 320 and 0 overlaps 9223372002495037441 bytes at +// offset -9223372019674906625 [-Wrestrict] +#pragma GCC diagnostic ignored "-Wstringop-overflow" +// warning: ‘void* __builtin_memcpy(void*, const void*, long unsigned int)’ +// specified bound between 18446744056529682432 and 18446744073709551608 +// exceeds maximum object size 9223372036854775807 [-Wstringop-overflow=] +#endif template inline void array_copy(T *dst, const T *src, size_t size) { for (size_t i = 0; i < size; ++i) dst[i] = src[i]; } +#if defined(__GNUC__) && __GNUC__ > 8 && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + template inline bool array_cmp(const T *a1, const T *a2, size_t size) { for (size_t i = 0; i < size; ++i) @@ -226,9 +271,15 @@ inline void array_set(T *arr, const U &val, size_t size) { arr[i] = static_cast(val); } +inline bool array_cmp_weak(const dnnl_dim_t *a1, const dnnl_dim_t *a2, size_t size) { + for (size_t i = 0; i < size; ++i) + if (a1[i] != a2[i] && a1[i] != DNNL_RUNTIME_DIM_VAL && a2[i] != DNNL_RUNTIME_DIM_VAL) return false; + return true; +} + namespace product_impl { template -struct int2type {}; +struct int2type {}; // NOLINT(readability-identifier-naming) template constexpr int product_impl(const T *arr, int2type<0>) { @@ -455,11 +506,10 @@ T pick_by_prop_kind(prop_kind_t prop_kind, const T &val_fwd, const T &val_bwd_d, } template -struct array_offset_calculator { +struct array_offset_calculator { // NOLINT(readability-identifier-naming) template - array_offset_calculator(Telem *base, Targs... Fargs) : _dims {Fargs...} { - _base_ptr = base; - } + array_offset_calculator(Telem *base, Targs... Fargs) + : _base_ptr(base), _dims {Fargs...} {} template array_offset_calculator(std::nullptr_t, Targs... Fargs) = delete; @@ -515,7 +565,14 @@ const char *format_cvt_impl(T &&t) { template std::string format_impl(const char *fmt, Args... 
args) { + // volatile here is a workaround for GCC 8 format-truncation warning e.g.: + // ‘%d’ directive output truncated writing 1 byte into a region of size 0 + // triggered by overaggressive optmization in '-O3'; fixed in GCC 9+ +#if defined(__GNUC__) && __GNUC__ == 8 && !defined(__clang__) + volatile size_t sz = snprintf(nullptr, 0, fmt, args...); +#else size_t sz = snprintf(nullptr, 0, fmt, args...); +#endif std::string buf(sz + 1, '\0'); snprintf(&buf[0], sz + 1, fmt, args...); buf.resize(sz); @@ -528,15 +585,15 @@ std::string format(const char *fmt, Args &&...args) { } inline bool need_src_or_dst_check( - bool is_fwd, int o, int i, int k, int p, int s, int d) { + bool is_fwd, dim_t o, dim_t i, dim_t k, dim_t p, dim_t s, dim_t d) { if (is_fwd) { - int i_min = -p; - int i_max = (o - 1) * s - p + (k - 1) * (1 + d); + dim_t i_min = -p; + dim_t i_max = (o - 1) * s - p + (k - 1) * (1 + d); return (i_min < 0) || (i_max >= i); } // Backward. - int os_min = p - (k - 1) * (1 + d); - int os_max = (i - 1) + p; + dim_t os_min = p - (k - 1) * (1 + d); + dim_t os_max = (i - 1) + p; return (os_min < 0) || (os_max >= o * s); } @@ -568,15 +625,23 @@ inline int get_dims_mask(const dims_t dims1, const dims_t dims2, int ndims, return mask; }; -inline void copy_dims_with_mask( - dims_t ddims, const dims_t sdims, int ndims, int mask) { +// The function can be used to get dimensions for memory descriptors or +// dimensions for logical offset. First ones are happy to have ones when mask +// is not applied. This allows to initialize them with existing functions using +// tags/strides. Latter ones are not nappy with ones and must have zeros as +// logical offsets starts with 0. `fill_with_one` flag regulates the behavior +// between them. +inline void copy_dims_with_mask(dims_t ddims, const dims_t sdims, int ndims, + int mask, bool fill_with_one = false) { for (int d = 0; d < ndims; ++d) { - ddims[d] = (mask & (1 << d)) ? sdims[d] : 0; + ddims[d] = (mask & (1 << d)) ? sdims[d] + : static_cast(fill_with_one); } } -inline void apply_mask_on_dims(dims_t dims, int ndims, int mask) { - copy_dims_with_mask(dims, dims, ndims, mask); +inline void apply_mask_on_dims( + dims_t dims, int ndims, int mask, bool fill_with_one = false) { + copy_dims_with_mask(dims, dims, ndims, mask, fill_with_one); } inline void dim_iterator(const dims_t dims, dims_t indices, int ndims) { @@ -641,6 +706,9 @@ std::string getenv_string_user(const char *name); bool get_jit_dump(); unsigned get_jit_profiling_flags(); std::string get_jit_profiling_jitdumpdir(); +// Checks if the filepath is a valid path and not a symlink to ensure +// the application only processes secure files. +status_t check_for_symlinks(const char *filename, bool *res); FILE *fopen(const char *filename, const char *mode); int getpagesize(); @@ -680,7 +748,7 @@ struct setting_t { constexpr setting_t(const T init) : value_ {init}, initialized_ {false} {} bool initialized() { return initialized_; } T get() { return value_; } - void set(T new_value) { + void set(const T &new_value) { value_ = new_value; initialized_ = true; } @@ -691,11 +759,17 @@ struct setting_t { // Copyright 2005-2014 Daniel James. // Distributed under the Boost Software License, Version 1.0. 
(See accompanying // file LICENSE or copy at http://www.boost.org/LICENSE_1_0.txt) -template +template ::value , int>::type = 0> static size_t hash_combine(size_t seed, const T &v) { return seed ^= std::hash {}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +template ::value , int>::type = 0> +static size_t hash_combine(size_t seed, const T &v) { + using underlying_t = typename std::underlying_type::type; + return hash_combine(seed, static_cast(v)); +} + inline int float2int(float x) { return utils::bit_cast(x); } @@ -749,7 +823,7 @@ struct set_once_before_first_get_setting_t { inline bool is_native_runtime(runtime_kind_t kind) { return utils::one_of(kind, runtime_kind::seq, runtime_kind::omp, - runtime_kind::tbb, runtime_kind::threadpool); + runtime_kind::tbb, runtime_kind::tbb_auto, runtime_kind::threadpool); } // Convenience wrapper to choose at compile-time between std::unique_ptr's @@ -778,6 +852,79 @@ template using maybe_unique_ptr = std::unique_ptr; #endif // DNNL_MAYBE_UNIQUE_PTR_IS_UNIQUE +// Common abstraction to manipulate nibbles in memory as pairs +struct nibble2_t { + + // constructs a nibble pair from a pair of uint8_t values + nibble2_t(uint8_t low_, uint8_t high_) : low(low_), high(high_) {} + + // constructs a nibble pairs from an uin8_t, taking its low and high part + nibble2_t(uint8_t pack_) : low(pack_ & 0xf), high((pack_ >> 4) & 0xf) {} + + // sets low (idx=0) or high (idx=1) nibble. + inline void set(uint8_t val, int idx) { + switch (idx) { + case 0: low = val; return; + case 1: high = val; return; + default: assert(!"Out of range index"); return; + } + } + + // returns low (idx = 0) or high (idx = 1) nibble in a uint8_t + inline uint8_t get(int idx) const { + switch (idx) { + case 0: return low; + case 1: return high; + default: assert(!"out of range index"); return 0; + } + } + + // returns pair of nibbles as uint8t + inline uint8_t get() const { return static_cast(high << 4 | low); } + +private: + uint8_t low : 4; + uint8_t high : 4; +}; +static_assert(sizeof(nibble2_t) == 1, "nibble2_t must be 1 byte"); + +/// Iterates through a binary integer +/// usage: +/// +/// for(int idx : mask_iterator(13)) { // 13 == 1101 +/// printf("%d\t", idx); +/// } +/// output: 0 2 3 +class mask_iterator { // NOLINT(readability-identifier-naming) + int mask_; + int index_; + +public: + using iterator_category = std::input_iterator_tag; + using difference_type = int; + using value_type = int; + using pointer = value_type *; + using reference = value_type &; + mask_iterator() : mask_(0), index_(0) {} + mask_iterator(int mask) : mask_(mask), index_(0) { + if ((mask_ & 0x1) == 0) { ++(*this); } + } + mask_iterator &begin() { return *this; } + mask_iterator end() const { return 0; } + value_type operator*() const { return index_; } + mask_iterator &operator++() { + do { + index_++; + mask_ >>= 1; + } while ((mask_ & 0x1) == 0 && mask_ != 0); + if (mask_ == 0) { index_ = 0; } + return *this; + } + bool operator!=(const mask_iterator &other) const { + return mask_ != other.mask_ || index_ != other.index_; + } +}; + } // namespace impl } // namespace dnnl diff --git a/src/common/verbose.cpp b/src/common/verbose.cpp index 3ea9067130f..1d0d18c8b4a 100644 --- a/src/common/verbose.cpp +++ b/src/common/verbose.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * Copyright 2023 Arm Ltd. 
and affiliates * * Licensed under the Apache License, Version 2.0 (the "License");
@@ -53,6 +53,7 @@ #include "reorder_pd.hpp" #include "resampling_pd.hpp" #include "rnn_pd.hpp" +#include "sdpa_pd.hpp" #include "shuffle_pd.hpp" #include "softmax_pd.hpp" #include "sum_pd.hpp"
@@ -84,7 +85,15 @@ static constexpr char verbose_version[] = "v1"; static setting_t<uint32_t> verbose {0}; -void print_header(const filter_status_t &filter_status) noexcept { +// Component filters restrict verbose output to the components that match a +// user-provided filter. The filter status is recorded at verbose +// initialization so that the matched components can be queried later, during +// verbose printing. +filter_status_t &filter_status() { + static filter_status_t filter_status; + return filter_status; +} + +void print_header() noexcept { static std::atomic_flag version_printed = ATOMIC_FLAG_INIT; if (!version_printed.test_and_set()) { verbose_printf("info,oneDNN v%d.%d.%d (commit %s)\n",
@@ -119,6 +128,8 @@ void print_header(const filter_status_t &filter_status) noexcept { verbose_printf("info,use batch_normalization stats one pass is %s\n", experimental::use_bnorm_stats_one_pass() ? "enabled" : "disabled"); + verbose_printf("info,GPU convolution v2 is %s\n", + experimental::use_gpu_conv_v2() ? "enabled" : "disabled"); #endif #ifdef DNNL_EXPERIMENTAL_SPARSE
@@ -147,16 +158,16 @@ void print_header(const filter_status_t &filter_status) noexcept { "mode,implementation,backend,exec_time\n", get_verbose_timestamp() ? "timestamp," : ""); #endif - if (filter_status.status == filter_status_t::flags::valid) + if (filter_status().status == filter_status_t::flags::valid) verbose_printf( "common,info,filter format is enabled, hit components: " "%s\n", - filter_status.components.c_str()); - else if (filter_status.status == filter_status_t::flags::invalid) + filter_status().components.c_str()); + else if (filter_status().status == filter_status_t::flags::invalid) verbose_printf( "common,error,filter format is ill-formed and is not " "applied, error: %s\n", - filter_status.err_msg.c_str()); + filter_status().err_msg.c_str()); } }
@@ -165,11 +176,9 @@ uint32_t get_verbose(verbose_t::flag_kind verbosity_kind, component_t::flag_kind filter_kind) noexcept { #if defined(DISABLE_VERBOSE) return verbose_t::none; -#else +#endif // we print all verbose by default static int flags = component_t::all; - // record filter parsing result to instruct verbose printing - static filter_status_t filter_status; if (!verbose.initialized()) { // Assumes that all threads see the same environment
@@ -178,9 +187,8 @@ uint32_t get_verbose(verbose_t::flag_kind verbosity_kind, // Legacy: we accept values 0,1,2 // 0 and none erase previously set flags, including error if (s == "0" || s == "none") k = verbose_t::none; - if (s == "1") k |= verbose_t::exec_profile; - if (s == "2") - k |= verbose_t::exec_profile | verbose_t::create_profile; + if (s == "1") k |= verbose_t::level1; + if (s == "2") k |= verbose_t::level2; if (s == "all" || s == "-1") k |= verbose_t::all; if (s == "error") k |= verbose_t::error; if (s == "check")
@@ -192,62 +200,62 @@ uint32_t get_verbose(verbose_t::flag_kind verbosity_kind, if (s == "profile_exec") k |= verbose_t::exec_profile; // Enable profiling to external libraries if (s == "profile_externals") k |= verbose_t::profile_externals; + if (s == "warn") k |= verbose_t::warn; // we extract debug info debuginfo=XX. ignore if debuginfo is invalid.
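// The token scan above walks a comma-separated ONEDNN_VERBOSE spec, treating
// `filter=<regex>` entries specially. A minimal self-contained sketch of that
// tokenization, with a hypothetical helper name (parse_verbose_spec is not
// part of the library's API):
#include <iostream>
#include <string>
#include <vector>

static void parse_verbose_spec(const std::string &spec,
        std::vector<std::string> &flags, std::string &filter) {
    size_t pos_st = 0;
    while (pos_st <= spec.size()) {
        const size_t pos_en = spec.find(',', pos_st);
        const std::string tok = spec.substr(pos_st,
                pos_en == std::string::npos ? std::string::npos
                                            : pos_en - pos_st);
        if (tok.rfind("filter=", 0) == 0)
            filter = tok.substr(7); // keep only the regex part
        else if (!tok.empty())
            flags.push_back(tok); // e.g. "profile_exec", "warn", "check"
        if (pos_en == std::string::npos) break;
        pos_st = pos_en + 1;
    }
}

int main() {
    std::vector<std::string> flags;
    std::string filter;
    parse_verbose_spec("profile_exec,warn,filter=conv|matmul", flags, filter);
    for (const auto &f : flags) std::cout << "flag: " << f << '\n';
    std::cout << "filter: " << filter << '\n'; // prints "conv|matmul"
    return 0;
}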
if (s.rfind("debuginfo=", 0) == 0) k |= verbose_t::make_debuginfo( std::strtol(s.c_str() + 10, nullptr, 10)); }; - auto update_filter = [&](const std::string &s, - filter_status_t &filter_status) -> int { + auto update_filter = [&](const std::string &s) -> int { int k = component_t::none; try { std::regex regexp = std::regex(s); -#define REGEX_SEARCH(k, component, regexp, filter_status) \ +#define REGEX_SEARCH(k, component, regexp) \ if (std::regex_search("" #component "", regexp)) { \ (k) |= component_t::component; \ - (filter_status).components += "" #component ","; \ + filter_status().components += "" #component ","; \ } - REGEX_SEARCH(k, primitive, regexp, filter_status); - REGEX_SEARCH(k, reorder, regexp, filter_status); - REGEX_SEARCH(k, shuffle, regexp, filter_status); - REGEX_SEARCH(k, concat, regexp, filter_status); - REGEX_SEARCH(k, sum, regexp, filter_status); - REGEX_SEARCH(k, convolution, regexp, filter_status); - REGEX_SEARCH(k, deconvolution, regexp, filter_status); - REGEX_SEARCH(k, eltwise, regexp, filter_status); - REGEX_SEARCH(k, lrn, regexp, filter_status); - REGEX_SEARCH(k, batch_normalization, regexp, filter_status); - REGEX_SEARCH(k, inner_product, regexp, filter_status); - REGEX_SEARCH(k, rnn, regexp, filter_status); - REGEX_SEARCH(k, binary, regexp, filter_status); - REGEX_SEARCH(k, matmul, regexp, filter_status); - REGEX_SEARCH(k, resampling, regexp, filter_status); - REGEX_SEARCH(k, pooling, regexp, filter_status); - REGEX_SEARCH(k, reduction, regexp, filter_status); - REGEX_SEARCH(k, prelu, regexp, filter_status); - REGEX_SEARCH(k, softmax, regexp, filter_status); - REGEX_SEARCH(k, layer_normalization, regexp, filter_status); - REGEX_SEARCH(k, group_normalization, regexp, filter_status); - REGEX_SEARCH(k, graph, regexp, filter_status); - REGEX_SEARCH(k, gemm_api, regexp, filter_status); - REGEX_SEARCH(k, ukernel, regexp, filter_status); + REGEX_SEARCH(k, primitive, regexp); + REGEX_SEARCH(k, reorder, regexp); + REGEX_SEARCH(k, shuffle, regexp); + REGEX_SEARCH(k, concat, regexp); + REGEX_SEARCH(k, sum, regexp); + REGEX_SEARCH(k, convolution, regexp); + REGEX_SEARCH(k, deconvolution, regexp); + REGEX_SEARCH(k, eltwise, regexp); + REGEX_SEARCH(k, lrn, regexp); + REGEX_SEARCH(k, batch_normalization, regexp); + REGEX_SEARCH(k, inner_product, regexp); + REGEX_SEARCH(k, rnn, regexp); + REGEX_SEARCH(k, binary, regexp); + REGEX_SEARCH(k, matmul, regexp); + REGEX_SEARCH(k, resampling, regexp); + REGEX_SEARCH(k, pooling, regexp); + REGEX_SEARCH(k, reduction, regexp); + REGEX_SEARCH(k, prelu, regexp); + REGEX_SEARCH(k, softmax, regexp); + REGEX_SEARCH(k, layer_normalization, regexp); + REGEX_SEARCH(k, group_normalization, regexp); + REGEX_SEARCH(k, graph, regexp); + REGEX_SEARCH(k, gemm_api, regexp); + REGEX_SEARCH(k, ukernel, regexp); #undef REGEX_SEARCH } catch (const std::exception &e) { - filter_status.status = filter_status_t::flags::invalid; - filter_status.err_msg = e.what(); + filter_status().status = filter_status_t::flags::invalid; + filter_status().err_msg = e.what(); return component_t::all; } // filter enabled and at least one component is hit - if (filter_status.components.length() != 0) { + if (!filter_status().components.empty()) { // pop out the last comma - filter_status.components.pop_back(); - filter_status.status = filter_status_t::flags::valid; + filter_status().components.pop_back(); + filter_status().status = filter_status_t::flags::valid; } else { - filter_status.status = filter_status_t::flags::invalid; - filter_status.err_msg + 
filter_status().status = filter_status_t::flags::invalid; + filter_status().err_msg = "component with name \'" + s + "\' not found"; } return k; @@ -264,9 +272,7 @@ uint32_t get_verbose(verbose_t::flag_kind verbosity_kind, // update filter flags if (tok.rfind("filter=", 0) == 0) { auto filter_str = tok.substr(7); - if (!filter_str.empty()) { - flags = update_filter(filter_str, filter_status); - } + if (!filter_str.empty()) { flags = update_filter(filter_str); } } if (pos_en == std::string::npos) break; } @@ -284,17 +290,16 @@ uint32_t get_verbose(verbose_t::flag_kind verbosity_kind, int result = verbose.get() & verbosity_kind; if (verbosity_kind == verbose_t::debuginfo) result = verbose_t::get_debuginfo(verbose.get()); - if (result) print_header(filter_status); bool filter_result = flags & filter_kind; return filter_result ? result : 0; -#endif } - +#if !defined(DISABLE_VERBOSE) static setting_t verbose_timestamp {false}; +#endif bool get_verbose_timestamp() { #if defined(DISABLE_VERBOSE) return false; -#else +#endif if (verbose.get() == 0) return false; if (!verbose_timestamp.initialized()) { @@ -304,27 +309,8 @@ bool get_verbose_timestamp() { verbose_timestamp.set(val); } return verbose_timestamp.get(); -#endif } -#if defined(DISABLE_VERBOSE) -void pd_info_t::init( - dnnl::impl::engine_t *, const dnnl::impl::primitive_desc_t *) {} - -std::string rt_mds2str(primitive_kind_t prim_kind, const memory_desc_t *src_md, - const memory_desc_t *wei_md, const memory_desc_t *bia_md, - const memory_desc_t *dst_md) { - return std::string(); -} - -std::string rt_dims2fmt_str(primitive_kind_t prim_kind, - const memory_desc_t *src_md, const memory_desc_t *wei_md, - const memory_desc_t *dst_md) { - return std::string(); -} - -#else - std::ostream &operator<<(std::ostream &ss, engine_kind_t eng_kind) { ss << dnnl_engine_kind2str(eng_kind); return ss; @@ -392,6 +378,14 @@ std::string rnn_flags2str(unsigned flags) { return s; } +std::string cublasltfmt2str(const memory_desc_t *md) { + if (md->format_desc.cublaslt_blocked_desc.cublaslt_format + == cublaslt_memory_format_t::col32_2r_4r4) { + return ":col32_2r_4r4"; + } + return ""; +} + std::ostream &operator<<(std::ostream &ss, const memory_extra_desc_t &extra) { using namespace memory_extra_flags; @@ -400,6 +394,21 @@ std::ostream &operator<<(std::ostream &ss, const memory_extra_desc_t &extra) { ss << ":s8m" << extra.compensation_mask; if (extra.flags & compensation_conv_asymmetric_src) ss << ":zpm" << extra.asymm_compensation_mask; + if (extra.flags & compensation_gpu_conv_asymmetric_src) { + ss << ":zid" << extra.idhw[0]; + ss << ":zih" << extra.idhw[1]; + ss << ":ziw" << extra.idhw[2]; + ss << ":zod" << extra.odhw[0]; + ss << ":zoh" << extra.odhw[1]; + ss << ":zow" << extra.odhw[2]; + ss << ":zpd" << extra.pdhw[0]; + ss << ":zph" << extra.pdhw[1]; + ss << ":zpw" << extra.pdhw[2]; + ss << ":zdd" << extra.ddhw[0]; + ss << ":zdh" << extra.ddhw[1]; + ss << ":zdw" << extra.ddhw[2]; + ss << ":zs" << extra.dst_size; + } if (extra.flags & scale_adjust && extra.scale_adjust != 1.f) ss << ":sa" << extra.scale_adjust; return ss; @@ -408,28 +417,60 @@ std::ostream &operator<<(std::ostream &ss, const memory_extra_desc_t &extra) { std::string md2fmt_tag_str(const memory_desc_t *md) { memory_desc_wrapper mdw(md); - dims_t blocks = {0}; - mdw.compute_blocks(blocks); - - char dim_chars[DNNL_MAX_NDIMS + 1]; - dims_t ou_blocks = {0}; - utils::array_copy(ou_blocks, mdw.padded_dims(), mdw.ndims()); - - for (int d = 0; d < mdw.ndims(); ++d) { - dim_chars[d] = (blocks[d] == 1 ? 
'a' : 'A') + (char)d; - ou_blocks[d] /= blocks[d]; - } - // Can't report meaningful tag for runtime dimensions. if (mdw.has_runtime_strides()) return "*"; - dims_t strides; - const auto &blk = mdw.blocking_desc(); - utils::array_copy(strides, blk.strides, mdw.ndims()); + struct sort_key_t { + uint64_t stride_order; + dim_t outer_block; + int idx; + char dim_char; + }; + + dims_t blocks = {0}; + mdw.compute_blocks(blocks); - utils::simultaneous_sort(strides, ou_blocks, dim_chars, mdw.ndims(), - [](dim_t a, dim_t b) { return b - a; }); + std::vector<sort_key_t> sort_keys(mdw.ndims()); + const auto &pdims = mdw.padded_dims(); + const auto &blk = mdw.blocking_desc(); + for (int i = 0; i < mdw.ndims(); ++i) + // Assume that any dimension with stride 0 is outer relative to other + // dimensions. Use (uint64_t)(stride - 1) to sort a stride of 0 highest. + // Multiple dimensions with stride 0 is ambiguous. + sort_keys[i] = {(uint64_t)(blk.strides[i] - 1), pdims[i] / blocks[i], i, + (char)((blocks[i] == 1 ? 'a' : 'A') + i)}; + + // Old approach: utils::simultaneous_sort(strides, outer_blocks, dim_chars) + // input tag: acdb + // dims: 5x8x0x2 + // strides: 0x1x16x8 + // output tag: cdba + // + // New approach with std::sort and sort keys: + // input tag: acdb + // dims: 5x8x0x2 + // "stride orders": (BIG NUMBER)x0x15x7 + // output tag: acdb + std::sort(sort_keys.begin(), sort_keys.end(), + [](const sort_key_t &left, const sort_key_t &right) { + if (left.stride_order < right.stride_order) return false; + if (left.stride_order == right.stride_order) { + // WLOG, we can assume a dimension of size 1 has the same + // stride as the next outermost dimension. Sort the one with + // the non-unit outer block as the outer dimension. Multiple + // dimensions of size 1 with the same stride is ambiguous. + if (left.outer_block < right.outer_block) return false; + if (left.outer_block == right.outer_block) + // Sort 1x1x... outer blocks to (arbitrarily) list them + // in alphabetical order. + return left.idx < right.idx; + } + return true; + }); + char dim_chars[DNNL_MAX_NDIMS + 1]; + for (int i = 0; i < mdw.ndims(); ++i) + dim_chars[i] = sort_keys[i].dim_char; dim_chars[mdw.ndims()] = '\0'; std::string s(dim_chars);
@@ -512,6 +553,7 @@ std::string md2fmt_str( case format_kind::blocked: ss << ":" << md2fmt_tag_str(md) << ":" << md2fmt_strides_str(md); break; + case format_kind::cublaslt_blocked: ss << cublasltfmt2str(md); break; case format_kind::wino: case format_kind::rnn_packed: case format_kind::opaque: ss << "::"; break;
@@ -579,24 +621,13 @@ std::string md2desc_str(const memory_desc_t *md) { return s; } -std::ostream &operator<<(std::ostream &ss, const runtime_scales_t &scale) { - ss << scale.mask_; - ss << ":" << scale.data_type_; - if (scale.ndims_) { - ss << ":"; - for (int i = 0; i < scale.ndims_ - 1; ++i) - ss << scale.group_dims_[i] << 'x'; - ss << scale.group_dims_[scale.ndims_ - 1]; - } - return ss; -} - -std::ostream &operator<<(std::ostream &ss, const scales_t &oscale) { - ss << oscale.mask_; - const float val = oscale.scales_[0]; +std::ostream &operator<<( + std::ostream &ss, const rnn_create_time_scales_t &rnn_scales) { + ss << rnn_scales.mask_; + const float val = rnn_scales.scales_[0]; // Can't use scientific flags since it breaks parsing on converter and // benchdnn side.
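// The sort_key_t comparator above orders dimensions outermost-first by
// stride. A minimal standalone sketch of the same idea on plain strides
// (illustrative only; the real md2fmt_tag_str additionally handles blocking,
// padding, zero strides, and size-1 tie-breaking):
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

static std::string strides2tag(const std::vector<int64_t> &strides) {
    std::vector<int> idx(strides.size());
    for (size_t i = 0; i < idx.size(); ++i) idx[i] = (int)i;
    // A larger stride means an outer dimension, so sort descending by stride.
    std::sort(idx.begin(), idx.end(), [&](int l, int r) {
        if (strides[l] != strides[r]) return strides[l] > strides[r];
        return l < r; // tie-break by dimension index, like sort_key_t::idx
    });
    std::string tag;
    for (int i : idx) tag += (char)('a' + i);
    return tag;
}

int main() {
    // NHWC-like strides for an 8x16x8x8 (abcd) tensor: a is outermost, then
    // c and d, and b is the unit-stride (innermost) dimension.
    std::cout << strides2tag({8 * 16 * 8, 1, 16 * 8, 16}) << '\n'; // "acdb"
    return 0;
}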
- if (oscale.mask_ == 0 || is_runtime_value(val)) + if (rnn_scales.mask_ == 0 || is_runtime_value(val)) ss << ":" << get_val_str(val); return ss; } @@ -680,7 +711,8 @@ std::ostream &operator<<(std::ostream &ss, const primitive_attr_t *attr) { const accumulation_mode_t &am = attr->acc_mode_; if (am != accumulation_mode::strict) { - ss << field_delim() << "attr-acc:" << dnnl_accumulation_mode2str(am); + ss << field_delim() + << "attr-acc-mode:" << dnnl_accumulation_mode2str(am); } const auto &rm = attr->rounding_mode_; @@ -701,53 +733,25 @@ std::ostream &operator<<(std::ostream &ss, const primitive_attr_t *attr) { if (deterministic) { ss << field_delim() << "attr-deterministic:" << deterministic; } + + // Fast exit if rest attributes were not specified. if (attr->has_default_values()) return ss; - const runtime_scales_t &os = attr->output_scales_; - if (!os.has_default_values()) { - ss << field_delim() << "attr-oscale:" << os; + const scales_t &scales = attr->scales_; + if (!scales.has_default_values()) { + ss << field_delim() << "attr-scales:" << scales.get_verbose(); } - const arg_scales_t &as = attr->scales_; - if (!as.has_default_values()) { - std::string delim = empty_delim; - ss << field_delim() << "attr-scales:"; - for (const auto &map_entry : as.scales_) { - const auto &val = map_entry.second; - if (val.has_default_values()) continue; - - int arg = map_entry.first; - ss << delim << arg2str(arg) << ":" << val; - delim = attr_delim; - } + const zero_points_t &zero_points = attr->zero_points_; + if (!zero_points.has_default_values()) { + ss << field_delim() << "attr-zero-points:" << zero_points.get_verbose(); } - - const zero_points_t &zp = attr->zero_points_; - if (!zp.has_default_values()) { - std::string delim = empty_delim; - ss << field_delim() << "attr-zero-points:"; - for (const auto &arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { - if (zp.has_default_values(arg)) continue; - - int mask = 0; - zp.get(arg, &mask); - const auto dt = zp.get_data_type(arg); - - ss << delim << arg2str(arg) << ":" << mask << ":" << dt; - - const auto &g_ndim = zp.get_groups_ndims(arg); - if (g_ndim) { - const auto &g_dims = zp.get_groups(arg); - ss << ":"; - for (int i = 0; i < g_ndim - 1; ++i) - ss << g_dims[i] << 'x'; - ss << g_dims[g_ndim - 1]; - } - - delim = attr_delim; - } + const auto &legacy_input_zp = attr->input_zero_points_; + if (!legacy_input_zp.has_default_values()) { + ss << "attr-legacy-input-zero-points:"; + ss << ":" << get_val_str(legacy_input_zp.mask_) << ":" << get_val_str(legacy_input_zp.count_); + ss << " "; } - const post_ops_t &po = attr->post_ops_; if (!po.has_default_values()) { std::string delim = empty_delim; @@ -766,12 +770,14 @@ std::ostream &operator<<(std::ostream &ss, const primitive_attr_t *attr) { if (s.dt != data_type::undef) ss << ":" << s.dt; } break; case primitive_kind::convolution: { - using namespace data_type; - const auto &c = e.depthwise_conv; - ss << delim << "dw:k" << c.kernel << "s" << c.stride << "p" - << c.padding; - if (c.wei_dt == s8 || c.dst_dt != f32) - ss << ":" << c.dst_dt; + // using namespace data_type; + // const auto &c = e.depthwise_conv; + // ss << delim << "dw:k" << c.kernel << "s" << c.stride << "p" + // << c.padding; + // if (c.wei_dt == s8 || c.dst_dt != f32) + // ss << ":" << c.dst_dt; + const char *alg_str = "depthwise_conv_old"; + ss << delim << alg_str; } break; case primitive_kind::eltwise: { const post_ops_t::entry_t::eltwise_t &ew = e.eltwise; @@ -783,7 +789,7 @@ std::ostream &operator<<(std::ostream &ss, const 
primitive_attr_t *attr) { } break; case primitive_kind::binary: { const post_ops_t::entry_t::binary_t &eb = e.binary; - const auto &md = eb.src1_desc; + const auto &md = eb.user_src1_desc; int mask = 0; for (int d = 0; d < md.ndims; ++d) mask += md.dims[d] != 1 ? (1 << d) : 0; @@ -792,7 +798,7 @@ std::ostream &operator<<(std::ostream &ss, const primitive_attr_t *attr) { switch (mdw.format_kind()) { case format_kind::blocked: if (!mdw.count_non_unit_dims(1)) - ss << ":" << md2fmt_tag_str(&md); + ss << ":" << md2fmt_tag_str(&eb.src1_desc); break; case format_kind::any: ss << ":any"; break; default: assert(!"unsupported format_kind"); @@ -803,6 +809,14 @@ std::ostream &operator<<(std::ostream &ss, const primitive_attr_t *attr) { ss << delim << "prelu" << ":" << ep.mask; } break; + case primitive_kind::depthwise: { + const post_ops_t::entry_t::depthwise_t &dw = e.depthwise; + ss << delim << dw.alg; + } break; + case primitive_kind::quantization: { + const post_ops_t::entry_t::quantization_t &qt = e.quantization; + ss << delim << qt.alg; + } break; default: assert(!"unsupported post op primitive kind!"); break; } delim = attr_delim; @@ -813,10 +827,18 @@ std::ostream &operator<<(std::ostream &ss, const primitive_attr_t *attr) { if (!rnn_qp.has_default_values()) { ss << field_delim() << "rnn_data_qparams:" << rnn_qp.scale_ << ":" << rnn_qp.shift_ << ";"; + ss << "rnn_data_qparams:" << rnn_qp.scale_ << ":" << rnn_qp.shift_ + << " "; + } + + const src_dyn_quant_params_t &dyn_qp = attr->src_dyn_quant_params_; + if (!dyn_qp.has_default_values()) { + ss << "src_dyn_quant_group_size:" << dyn_qp.get() << ";"; } if (!attr->dropout_.has_default_values()) { - const memory_desc_wrapper mdw(attr->dropout_.dropout_desc_); + ss << field_delim() << "attr-dropout"; + const memory_desc_wrapper mdw(attr->dropout_.user_dropout_desc_); switch (mdw.format_kind()) { case format_kind::blocked: if (!mdw.count_non_unit_dims(1)) @@ -1525,8 +1547,9 @@ std::string init_info_softmax(const engine_t *e, const pd_t *pd) { << " "; ss << md2fmt_str("dst", dst_md, pd->dst_md(0, true)->format_kind); if (!types::is_zero_md(diff_dst_md)) { - ss << md2fmt_str( - "diff_dst", diff_dst_md, pd->diff_dst_md(0, true)->format_kind); + ss << " " + << md2fmt_str("diff_dst", diff_dst_md, + pd->diff_dst_md(0, true)->format_kind); } ss << "," << pd->attr() << ","; @@ -1556,6 +1579,51 @@ std::string init_info_sum(const engine_t *e, const pd_t *pd) { return ss.str(); } +template +std::string init_info_sdpa(const engine_t *e, const pd_t *pd) { + std::stringstream ss; + ss << e << "," << pd->kind() << "," << pd->name() << ","; + + const sdpa_desc_t *desc = pd->desc(); + + std::string delimiter; + if (!desc->kq_scales.has_default_values()) { + ss << delimiter << "kq_attr-scales:wei:" << desc->kq_scales; + delimiter = "+"; + } + if (!desc->kq_zero_points.has_default_values()) { + ss << delimiter + << "kq_attr-zero-points:" << desc->kq_zero_points.get_verbose(); + delimiter = "+"; + } + + if (!desc->vs_scales.has_default_values()) { + ss << delimiter << "vs_attr-scales:wei:" << desc->vs_scales; + delimiter = "+"; + } + if (!desc->vs_zero_points.has_default_values()) { + ss << delimiter + << "vs_attr-zero-points:" << desc->vs_zero_points.get_verbose(); + } + + ss << ",query:" << pd->qry_md()->data_type << ":" + << md2dim_str(pd->qry_md()); + ss << ",key:" << pd->key_md()->data_type << ":" << md2dim_str(pd->key_md()) + << ":" << md2fmt_tag_str(pd->key_md()); + ss << ",val:" << pd->val_md()->data_type << ":" << md2dim_str(pd->val_md()); + if 
(pd->with_attn_mask()) { + ss << ",msk:" << pd->attn_mask_md()->data_type << ":" + << md2dim_str(pd->attn_mask_md()); + } else if (pd->with_causal_mask()) { + if (desc->mask_type == attn_mask_type::top_left) + ss << ",msk:causal:top_left"; + else + ss << ",msk:causal:bottom_right"; + } + + return ss.str(); +} + } // namespace std::string rt_mds2str(primitive_kind_t prim_kind, const memory_desc_t *src_md, @@ -1564,6 +1632,10 @@ std::string rt_mds2str(primitive_kind_t prim_kind, const memory_desc_t *src_md, // Note: pass format_kind::undef since runtime dims-ed mds can't have // format_kind::any at any stage. std::string s; +#if defined(DISABLE_VERBOSE) + return s; +#endif + switch ((int)prim_kind) { case primitive_kind::matmul: s = mds2str_matmul(src_md, format_kind::undef, wei_md, @@ -1607,13 +1679,15 @@ std::string prepend_identifier_and_version(const char *fmt_str) { } void verbose_printf_impl(const char *raw_fmt_str, verbose_t::flag_kind kind) { +#if defined(DISABLE_VERBOSE) + return; +#endif + + if (get_verbose(kind)) print_header(); + const auto &fmt_str = prepend_identifier_and_version(raw_fmt_str); #ifdef DNNL_EXPERIMENTAL_LOGGING - // by default, verbose_t::create_check is passed to the logger - // so that it prints at spdlog log_level_t::info when no verbose flag - // is specified. This is useful for printing headers, format fields, etc. - // which do not correspond to a specific verbose kind. const log_manager_t &log_manager = log_manager_t::get_log_manager(); if (log_manager.is_logger_enabled()) @@ -1632,6 +1706,10 @@ std::string rt_dims2fmt_str(primitive_kind_t prim_kind, const memory_desc_t *src_md, const memory_desc_t *wei_md, const memory_desc_t *dst_md) { std::string s; +#if defined(DISABLE_VERBOSE) + return s; +#endif + switch ((int)prim_kind) { case primitive_kind::matmul: s = dims2fmt_str_matmul(src_md, wei_md); @@ -1661,6 +1739,7 @@ std::string rt_dims2fmt_str(primitive_kind_t prim_kind, } void pd_info_t::init(engine_t *engine, const primitive_desc_t *pd) { + // Handles VERBOSE_DISABLE since `is_initialized_` is set to `true`. 
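// The guard below pairs a cheap `is_initialized_` test with std::call_once so
// the info string is built at most once even when several threads race on
// first use. A generic standalone sketch of that pattern (hypothetical type,
// not the library's pd_info_t; an atomic flag keeps the fast path race-free):
#include <atomic>
#include <iostream>
#include <mutex>
#include <string>

struct lazy_info_t {
    const std::string &get() {
        if (!is_initialized_.load()) {
            std::call_once(initialization_flag_, [&] {
                str_ = "expensive-to-build description"; // built exactly once
                is_initialized_.store(true);
            });
        }
        return str_; // later calls skip the call_once machinery entirely
    }

private:
    std::once_flag initialization_flag_;
    std::atomic<bool> is_initialized_ {false};
    std::string str_;
};

int main() {
    lazy_info_t info;
    std::cout << info.get() << '\n';
    std::cout << info.get() << '\n'; // fast path on the second call
    return 0;
}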
if (is_initialized_) return; std::call_once(initialization_flag_, [&] { @@ -1692,9 +1771,7 @@ void pd_info_t::init(engine_t *engine, const primitive_desc_t *pd) { CASE(shuffle); CASE(softmax); CASE(sum); - case primitive_kind::sdpa: - str_ = "sdpa, unknown info"; - break; + CASE(sdpa); case primitive_kind::zero_pad: str_ = "zero_pad, unknown info"; break; @@ -1708,7 +1785,6 @@ void pd_info_t::init(engine_t *engine, const primitive_desc_t *pd) { is_initialized_ = true; }); } -#endif } // namespace impl } // namespace dnnl @@ -1719,10 +1795,8 @@ dnnl_status_t dnnl_set_verbose(int level) { if (level < 0 || level > 2) return invalid_arguments; uint32_t verbose_level = verbose_t::none; - if (level == 1) verbose_level = verbose_t::error | verbose_t::exec_profile; - if (level == 2) - verbose_level = verbose_t::error | verbose_t::exec_profile - | verbose_t::create_profile; + if (level == 1) verbose_level = verbose_t::level1; + if (level == 2) verbose_level = verbose_t::level2; // we put the lower byte of level as devinfo to preserve backward // compatibility with historical VERBOSE={1,2} if (level == 1 || level == 2) verbose_level |= (level << 24); diff --git a/src/common/verbose.hpp b/src/common/verbose.hpp index c2839c67a6e..b6315a8f5c1 100644 --- a/src/common/verbose.hpp +++ b/src/common/verbose.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * Copyright 2023 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -57,14 +57,14 @@ inline constexpr size_t get_file_name_offset(T (&str)[1]) { return 0; } template -struct const_expr_value { +struct const_expr_value_t { static constexpr const T value = v; }; } // namespace utility #define UTILITY_CONST_EXPR_VALUE(exp) \ - utility::const_expr_value::value + utility::const_expr_value_t::value #define __FILENAME__ (&__FILE__[utility::get_file_name_offset(__FILE__)]) @@ -73,12 +73,12 @@ struct const_expr_value { // The string can contain format specifiers which are provided in VA_ARGS // Note: using ##__VAR_ARGS__ is necessary to avoid trailing comma in printf call -#define VFORMAT(stamp, apitype, logtype, logsubtype, msg, ...) \ +#define VFORMAT(stamp, flagkind, apitype, logtype, logsubtype, msg, ...) \ do { \ std::string stamp_; \ if (dnnl::impl::get_verbose_timestamp()) \ stamp_ = std::to_string(stamp) + ","; \ - dnnl::impl::verbose_printf( \ + dnnl::impl::verbose_printf(flagkind, \ "%s" CONCAT2(VERBOSE_, apitype) "," CONCAT2( \ VERBOSE_, logtype) "%s," msg "\n", \ stamp_.c_str(), logsubtype, ##__VA_ARGS__); \ @@ -88,7 +88,8 @@ struct const_expr_value { #define VINFO(apitype, logtype, logsubtype, component, msg, ...) \ do { \ if (dnnl::impl::get_verbose(verbose_t::logtype##_##logsubtype)) \ - VFORMAT(get_msec(), apitype, logtype, VERBOSE_##logsubtype, \ + VFORMAT(get_msec(), verbose_t::logtype##_##logsubtype, apitype, \ + logtype, VERBOSE_##logsubtype, \ #component "," msg ",%s:%d", ##__VA_ARGS__, __FILENAME__, \ __LINE__); \ } while (0) @@ -116,8 +117,22 @@ struct const_expr_value { #define VERROR(apitype, component, msg, ...) 
\ do { \ if (dnnl::impl::get_verbose(verbose_t::error)) { \ - VFORMAT(get_msec(), apitype, error, "", #component "," msg, \ - ##__VA_ARGS__); \ + VFORMAT(get_msec(), verbose_t::error, apitype, error, "", \ + #component "," msg ",%s:%d", ##__VA_ARGS__, __FILENAME__, \ + __LINE__); \ + } \ + } while (0) + +// Special syntactic sugar for warnings, plus flush of the output stream +// The difference between the warn and error verbose modes is that the +// verbose error messages are only reserved for printing when an exception is +// thrown or when a status check fails. +#define VWARN(apitype, component, msg, ...) \ + do { \ + if (dnnl::impl::get_verbose(verbose_t::warn)) { \ + VFORMAT(get_msec(), verbose_t::warn, apitype, warn, "", \ + #component "," msg ",%s:%d", ##__VA_ARGS__, __FILENAME__, \ + __LINE__); \ } \ } while (0) @@ -127,17 +142,21 @@ struct const_expr_value { do { \ if (dnnl::impl::get_verbose_dev_mode(verbose_t::debuginfo) \ >= (level)) { \ - VFORMAT(get_msec(), apitype, debuginfo, "", #component "," msg, \ - ##__VA_ARGS__); \ + VFORMAT(get_msec(), verbose_t::debuginfo, apitype, debuginfo, "", \ + #component "," msg ",%s:%d", ##__VA_ARGS__, __FILENAME__, \ + __LINE__); \ } \ } while (0) // Special syntactic sugar for logging performance // NOTE: the VPROF macro does not check for verbose flags, it is the -// responsibility of the caller do check those (it should happen +// responsibility of the caller to check those (it should happen // anyway to condition collecting stamp/duration) #define VPROF(stamp, apitype, logtype, logsubtype, info, duration) \ - { VFORMAT(stamp, apitype, logtype, logsubtype, "%s,%g", info, duration); } + { \ + VFORMAT(stamp, dnnl::impl::verbose_t::exec_profile, apitype, logtype, \ + logsubtype, "%s,%g", info, duration); \ + } struct verbose_t { enum flag_kind : uint32_t { @@ -152,9 +171,13 @@ struct verbose_t { exec_check = 1 << 6, exec_profile = 1 << 7, profile_externals = 1 << 8, + warn = 1 << 9, // the upper 8 bits are reserved for devinfo levels debuginfo = 1 << 24, // + level1 = error | exec_profile | warn, + level2 = error | exec_profile | warn | create_profile, + all = (uint32_t)-1, }; @@ -234,6 +257,8 @@ get_verbose_to_log_level_map() { verbose_to_log_map { {verbose_t::all, log_manager_t::trace}, {verbose_t::debuginfo, log_manager_t::debug}, + {verbose_t::level1, log_manager_t::info}, + {verbose_t::level2, log_manager_t::info}, {verbose_t::create_dispatch, log_manager_t::info}, {verbose_t::create_check, log_manager_t::info}, {verbose_t::create_profile, log_manager_t::info}, @@ -241,6 +266,7 @@ get_verbose_to_log_level_map() { {verbose_t::exec_profile, log_manager_t::info}, {verbose_t::exec_check, log_manager_t::error}, {verbose_t::error, log_manager_t::critical}, + {verbose_t::warn, log_manager_t::warn}, {verbose_t::none, log_manager_t::off}, }; return verbose_to_log_map; @@ -279,6 +305,10 @@ inline std::string format_verbose_string( // processes fixed strings for logging and printing inline void verbose_printf(const char *fmt_str) { + // by default, verbose_t::create_check is passed to the logger + // so that it prints at spdlog log_level_t::info when no verbose flag + // is specified. This is useful for printing headers, format fields, etc. + // which do not correspond to a specific verbose kind. verbose_printf_impl(fmt_str, verbose_t::create_check); } @@ -293,6 +323,10 @@ inline void verbose_printf(verbose_t::flag_kind kind, const char *fmt_str) { template inline void verbose_printf(const char *fmt_str, str_args... 
args) { std::string msg = format_verbose_string(fmt_str, args...); + // by default, verbose_t::create_check is passed to the logger + // so that it prints at spdlog log_level_t::info when no verbose flag + // is specified. This is useful for printing headers, format fields, etc. + // which do not correspond to a specific verbose kind. verbose_printf_impl(msg.c_str(), verbose_t::create_check); } @@ -348,6 +382,7 @@ std::string md2fmt_str( const char *name, const memory_desc_t *md, format_kind_t user_format); std::string md2dim_str( const memory_desc_t *md, dims_type_t dims_type = dims_type_t::dims); +std::string arg2str(int arg); // Returns a verbose string of dimensions or descriptor from src, wei, and/or // dst memory descs. Can be called externally to provide info about actual // values of runtime dimensions. diff --git a/src/common/verbose_msg.hpp b/src/common/verbose_msg.hpp index 67e52c1f3c8..cf92ffbbfa4 100644 --- a/src/common/verbose_msg.hpp +++ b/src/common/verbose_msg.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023-2024 Intel Corporation +* Copyright 2023-2025 Intel Corporation * Copyright 2023 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,6 +26,7 @@ // log type strings #define VERBOSE_error "error" +#define VERBOSE_warn "warning" #define VERBOSE_create "create" #define VERBOSE_create_nested "create_nested" #define VERBOSE_exec "exec" @@ -42,9 +43,9 @@ // verbose messages #define VERBOSE_PROFILING_UNSUPPORTED "profiling capabilities are not supported" #define VERBOSE_INVALID_DEVICE_ENV "invalid %s device in environment: index %zu" -#define VERBOSE_INVALID_ENGINE_KIND "no %s device is available" +#define VERBOSE_INVALID_ENGINE_KIND "no %s %s device is available" #define VERBOSE_INVALID_ENGINE_IDX \ - "%zu %s devices are available but %zu was queried" + "%zu %s devices are available but device index %zu was queried" #define VERBOSE_INVALID_ACC_MODE "bad accumulation mode %s" #define VERBOSE_NULL_ARG "one of the mandatory arguments is nullptr" #define VERBOSE_BAD_ENGINE_KIND "bad engine kind" @@ -64,6 +65,9 @@ #define VERBOSE_INCONSISTENT_DIM "dimension %s:%d is inconsistent with %s:%d" #define VERBOSE_INCONSISTENT_NDIMS \ "tensors %s and %s have inconsistent number of dimensions" +// TODO: replace the version above with the version below. 
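// level1/level2 introduced above are plain bitmask unions over flag_kind, so
// ONEDNN_VERBOSE=1 now implies warnings as well. A standalone check of that
// composition (only the exec_profile and warn bit positions are taken from
// the header; the other positions are illustrative assumptions):
#include <cassert>
#include <cstdint>

enum flag_kind : uint32_t {
    none = 0,
    error = 1 << 0, // illustrative position
    create_profile = 1 << 2, // illustrative position
    exec_profile = 1 << 7,
    warn = 1 << 9,
    level1 = error | exec_profile | warn,
    level2 = error | exec_profile | warn | create_profile,
};

int main() {
    assert(level1 & warn); // ONEDNN_VERBOSE=1 surfaces warnings
    assert(!(level1 & create_profile)); // creation profiling is level2-only
    assert((level2 & level1) == level1); // level2 is a superset of level1
    return 0;
}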
+#define VERBOSE_INCONSISTENT_NDIMS_WITH_VALS \ + "tensors %s and %s have inconsistent number of dimensions (%d) and (%d)" #define VERBOSE_INCONSISTENT_DT "tensors %s and %s have inconsistent datatypes" #define VERBOSE_INCONSISTENT_MDS "inconsistent %s and %s mds" #define VERBOSE_INCONSISTENT_ALPHA_BETA \ @@ -71,6 +75,7 @@ #define VERBOSE_INCONSISTENT_PRB "problem is not mathematically consistent" #define VERBOSE_BAD_NDIMS "%s has a bad number of dimensions %d" #define VERBOSE_BAD_DIM "bad dimension %s:%d" +#define VERBOSE_OUT_OF_RANGE_DIMS "out-of-range dimensions for %s" #define VERBOSE_UNSUPPORTED_ISA "unsupported isa" #define VERBOSE_UNSUPPORTED_DT "unsupported datatype" @@ -103,6 +108,8 @@ #define VERBOSE_WS_MISMATCH \ "workspace mismatch between forward and backward primitive " \ "descriptors" +#define VERBOSE_TENSOR_FORMAT_MISMATCH \ + "memory formats for %s and %s tensors do not match" #define VERBOSE_WS_INIT "workspace initialization failed" #define VERBOSE_SCRATCHPAD_INIT "scratchpad initialization unsuccessful" @@ -116,14 +123,16 @@ #define VERBOSE_IMPL_HEURISTIC_FAIL "heuristic fail: %s" #define VERBOSE_1x1CONV_HEURISTIC_FAIL "heuristic fail for 1x1 convolution: %s" #define VERBOSE_SCRATCHPAD_LIMIT "scratchpad memory limit exceeded" -#define VERBOSE_PRIMITIVE_CREATION_FAIL "failed to create nested primitive %s" +#define VERBOSE_PRIMITIVE_CREATION_FAIL "failed to create nested %s primitive" #define VERBOSE_DESC_CREATION_FAIL "failed to create %s descriptor" #define VERBOSE_SHAPE_RESTRICTION "failed shape restrictions" #define VERBOSE_INCOMPATIBLE_GEMM_FMT "incompatible gemm format" #define VERBOSE_DEVICE_CTX_MISMATCH "device not found in the given context" +#define VERBOSE_MISSING_OCL_DEVICE "%s OpenCL device not found" #define VERBOSE_INVALID_PLATFORM "unsupported %s platform (expected %s got %s)" #define VERBOSE_ENGINE_CREATION_FAIL "failed to create %s engine with index %zu" +#define VERBOSE_KERNEL_CREATION_FAIL "failed to create %s kernel" #define VERBOSE_DETERMINISTIC_FAIL "failed to run kernel deterministically" #define VERBOSE_SKIP_PRIMITIVE_IMPL \ "skipping or dispatching to another implementation" diff --git a/src/common/z_magic.hpp b/src/common/z_magic.hpp index 9baae4c8bab..597954e2844 100644 --- a/src/common/z_magic.hpp +++ b/src/common/z_magic.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2022 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ #define PRAGMA_MACRO(x) PRAGMA_MACRo(x) #endif -#define UNUSED(x) ((void)x) +#define UNUSED(x) ((void)(x)) #define MAYBE_UNUSED(x) UNUSED(x) #if defined(_WIN32) && !defined(__GNUC__) diff --git a/src/cpu/CMakeLists.txt b/src/cpu/CMakeLists.txt index 17ad1e4a59e..ab791ee7b2c 100644 --- a/src/cpu/CMakeLists.txt +++ b/src/cpu/CMakeLists.txt @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2019-2024 Intel Corporation +# Copyright 2019-2025 Intel Corporation # Copyright 2020 Arm Ltd. 
and affiliates # Copyright 2021 FUJITSU LIMITED # @@ -22,6 +22,7 @@ file(GLOB_RECURSE SOURCES_EXTRA ${CMAKE_CURRENT_SOURCE_DIR}/matmul/*.[ch]pp ${CMAKE_CURRENT_SOURCE_DIR}/reorder/*.[ch]pp ${CMAKE_CURRENT_SOURCE_DIR}/rnn/*.[ch]pp + ${CMAKE_CURRENT_SOURCE_DIR}/ukernel/*.[ch]pp ) foreach(SOURCE_FILE ${SOURCES_EXTRA}) @@ -130,6 +131,7 @@ set(OBJ_LIB ${LIB_PACKAGE_NAME}_cpu) add_library(${OBJ_LIB} OBJECT ${SOURCES}) set_property(GLOBAL APPEND PROPERTY DNNL_LIB_DEPS $) +enable_conditional_compilation4(${OBJ_LIB}) if (DNNL_TARGET_ARCH STREQUAL "X64") add_subdirectory(x64) @@ -137,6 +139,9 @@ endif() if (DNNL_TARGET_ARCH STREQUAL "AARCH64") add_subdirectory(aarch64) endif() +if (DNNL_USE_ACL) + add_subdirectory(acl) +endif() if (DNNL_TARGET_ARCH STREQUAL "PPC64") add_subdirectory(ppc64) endif() diff --git a/src/cpu/README.md b/src/cpu/README.md index 75668c15c82..7641f9e825b 100644 --- a/src/cpu/README.md +++ b/src/cpu/README.md @@ -44,7 +44,9 @@ architecture. Hence, for portability reasons [`cpu/platform.hpp`](platform.hpp) header file provides a set of helpers macros that could help conditionally enable or disable parts of code. There the following macros defined: - `DNNL_X64` is 1 on x64 architecture; +- `DNNL_X86` is 1 on x86 architecture; - `DNNL_AARCH64` is 1 on Arm AArch64 architecture; +- `DNNL_ARM` is 1 on Arm 32 architecture; - `DNNL_PPC64` is 1 on OpenPOWER / IBM Power architecture; - `DNNL_S390X` is 1 on IBMz / s390x architecture; - `DNNL_RV64` is 1 on RISC-V architecture; diff --git a/src/cpu/aarch64/CMakeLists.txt b/src/cpu/aarch64/CMakeLists.txt index 432a00bc70a..32eec64988c 100644 --- a/src/cpu/aarch64/CMakeLists.txt +++ b/src/cpu/aarch64/CMakeLists.txt @@ -20,21 +20,6 @@ file(GLOB_RECURSE SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/*.[ch]pp ) -file(GLOB XBYAK_AARCH64_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/xbyak_aarch64/src/xbyak_aarch64_impl.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/xbyak_aarch64/src/util_impl.cpp - ) - -list(REMOVE_ITEM SOURCES ${XBYAK_AARCH64_FILES}) - -if(NOT DNNL_AARCH64_USE_ACL) - file(GLOB_RECURSE ACL_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/acl_*.[ch] - ${CMAKE_CURRENT_SOURCE_DIR}/acl_*.[ch]pp - ) - list(REMOVE_ITEM SOURCES ${ACL_FILES}) -endif() - # If the runtime is not THREADPOOL remove threadpool_scheduler sources. if(NOT DNNL_CPU_RUNTIME STREQUAL "THREADPOOL") list(APPEND ACL_THREADPOOL_FILES @@ -48,5 +33,6 @@ set(OBJ_LIB ${LIB_PACKAGE_NAME}_cpu_aarch64) add_library(${OBJ_LIB} OBJECT ${SOURCES}) set_property(GLOBAL APPEND PROPERTY DNNL_LIB_DEPS $) +enable_conditional_compilation4(${OBJ_LIB}) -add_subdirectory(xbyak_aarch64) +add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/xbyak_aarch64 xbyak_aarch64) diff --git a/src/cpu/aarch64/acl_convolution_utils.hpp b/src/cpu/aarch64/acl_convolution_utils.hpp deleted file mode 100644 index 37a3d6c3d98..00000000000 --- a/src/cpu/aarch64/acl_convolution_utils.hpp +++ /dev/null @@ -1,183 +0,0 @@ -/******************************************************************************* -* Copyright 2020-2024 Arm Ltd. and affiliates -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_AARCH64_ACL_CONVOLUTION_UTILS_HPP -#define CPU_AARCH64_ACL_CONVOLUTION_UTILS_HPP - -#include -#include "acl_post_ops.hpp" -#include "acl_utils.hpp" -#include "arm_compute/runtime/experimental/operators/CpuDepthwiseConv2d.h" -#include "cpu/cpu_convolution_pd.hpp" -#include -namespace dnnl { -namespace impl { -namespace cpu { -namespace aarch64 { - -template -struct acl_obj_t { - ConvOp conv; - arm_compute::experimental::MemoryRequirements aux_mem_req; -}; - -struct acl_conv_conf_t { - bool with_bias; - bool fast_math; - // If this is true, the result of the convolution goes into a temporarily - // allocated ACL tensor to be accumulated into the oneDNN dst during postops - bool use_dst_acc_for_sum; - // Tells that the selected algorithm is Winograd. This is needed because the - // algorithm can be set to algorithm::convolution_auto and later on we need to - // skip fixed-format protocol as ACL Winograd does not support it. - bool alg_winograd; - arm_compute::TensorInfo src_tensor_info; - arm_compute::TensorInfo wei_tensor_info; - arm_compute::TensorInfo bia_tensor_info; - arm_compute::TensorInfo dst_tensor_info; - - arm_compute::PadStrideInfo padstride_info; - arm_compute::Size2D dilation_info; - // Additional information about the weights not included in wei_tensor_info - arm_compute::WeightsInfo weights_info; - // Note: this will default to not enabled, and will do nothing - arm_compute::ActivationLayerInfo act_info; -}; - -namespace acl_convolution_utils { - -status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const convolution_desc_t &cd, - const primitive_attr_t &attr); - -} // namespace acl_convolution_utils - -// Keys are anonymous with local linkage. So deduce the type automagically. -using conv_key_t = decltype(memory_tracking::names::key_gemm_tmp_buffer); - -template -status_t init_scratchpad(op_t &conv, memory_tracking::registrar_t &scratchpad, - const std::map &conv_keys, engine_t *engine, - post_ops_t &post_ops, dnnl::impl::post_ops_t &attr_post_ops, - arm_compute::ActivationLayerInfo &act_info, bool &use_dst_acc_for_sum, - const dnnl::impl::memory_desc_t &dst_md) { - - // Book temp mem. - const auto aux_mem_req = conv.workspace(); - for (const auto &key : conv_keys) { - const auto id = key.first; - if (aux_mem_req[id].size > 0) { - scratchpad.book(key.second, aux_mem_req[id].size, 1, - aux_mem_req[id].alignment, aux_mem_req[id].alignment); - } - } - - CHECK(post_ops.init(engine, attr_post_ops, dst_md, act_info)); - use_dst_acc_for_sum = post_ops.has_sum(); - - if (use_dst_acc_for_sum) { - const memory_desc_wrapper dst_d(&dst_md); - scratchpad.book(memory_tracking::names::key_generic_acc, dst_d.nelems(), - dst_d.data_type_size()); - } - - return status::success; -} - -template -status_t execute_forward_conv_acl(const exec_ctx_t &ctx, - conv_obj_t *acl_conv_obj, const conv_pd_t *pd, - const std::map &conv_keys) { - - auto src_base = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); - auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); - - // import_memory() and free() methods do not allocate/free any additional - // memory, only acquire/release pointers. 
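// The comment above is the key to the deleted helper: ACL tensors can wrap
// externally owned memory with no copies. A minimal sketch of that zero-copy
// pattern, assuming Arm Compute Library headers and linkage are available:
#include <vector>

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/runtime/Tensor.h"

int main() {
    // Externally owned buffer, e.g. a oneDNN memory handle (2x3 f32).
    std::vector<float> buffer(2 * 3, 1.f);

    arm_compute::Tensor t;
    // init() records metadata only; no memory is allocated here.
    t.allocator()->init(arm_compute::TensorInfo(
            arm_compute::TensorShape(3U, 2U), 1, arm_compute::DataType::F32));
    // import_memory() acquires the external pointer without copying it.
    t.allocator()->import_memory(buffer.data());

    // ... run an ACL operator with `t` in its ITensorPack ...

    // free() releases the pointer; the external buffer stays alive.
    t.allocator()->free();
    return 0;
}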
- arm_compute::Tensor src_tensor; - arm_compute::Tensor wei_tensor; - arm_compute::Tensor bia_tensor = nullptr; - arm_compute::Tensor dst_tensor; - - auto const acp = pd->acp_; - src_tensor.allocator()->init(acp.src_tensor_info); - wei_tensor.allocator()->init(acp.wei_tensor_info); - dst_tensor.allocator()->init(acp.dst_tensor_info); - - src_tensor.allocator()->import_memory(const_cast(src_base)); - wei_tensor.allocator()->import_memory(const_cast(wei_base)); - - const auto scratchpad = ctx.get_scratchpad_grantor(); - - // If we have an unfused sum post op, put the result in a scratchpad tensor. - // Result will be summed to the dst during acl_post_ops.execute - auto dst_base = acp.use_dst_acc_for_sum - ? scratchpad.get(memory_tracking::names::key_generic_acc) - : CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST); - dst_tensor.allocator()->import_memory(dst_base); - - if (acp.with_bias) { - auto bia_base = CTX_IN_MEM(const bia_data_t *, DNNL_ARG_BIAS); - bia_tensor.allocator()->init(acp.bia_tensor_info); - bia_tensor.allocator()->import_memory( - const_cast(bia_base)); - } - - // Constness of the weight tensor matters for depthwise conv in ACL. - // Otherwise, it will package the weights more often than needed, as - // it will expect the weights to change within the duration of the run - // func. - arm_compute::ITensorPack pack; - pack.add_tensor(arm_compute::TensorType::ACL_SRC_0, &src_tensor); - pack.add_const_tensor(arm_compute::TensorType::ACL_SRC_1, &wei_tensor); - pack.add_const_tensor(arm_compute::TensorType::ACL_SRC_2, &bia_tensor); - pack.add_tensor(arm_compute::TensorType::ACL_DST, &dst_tensor); - - // Get temp workspaces. - const auto aux_mem = acl_conv_obj->aux_mem_req; - - // Hold onto tmp tensors while we need pack. - std::vector tmp_tensors(aux_mem.size()); - for (const auto &key : conv_keys) { - const auto id = key.first; - if (aux_mem[id].size > 0) { - const auto info = arm_compute::TensorInfo( - arm_compute::TensorShape(aux_mem[id].size), 1, - arm_compute::DataType::U8); - auto buffer = scratchpad.get(key.second); - tmp_tensors[id].allocator()->init(info, aux_mem[id].alignment); - tmp_tensors[id].allocator()->import_memory(buffer); - pack.add_tensor(aux_mem[id].slot, &tmp_tensors[id]); - } - } - - acl_conv_obj->conv.run(pack); - - void *dst = dst_tensor.buffer(); - pd->post_ops.execute(ctx, dst); - - return status::success; -} - -} // namespace aarch64 -} // namespace cpu -} // namespace impl -} // namespace dnnl - -#endif // CPU_AARCH64_ACL_CONVOLUTION_UTILS_HPP diff --git a/src/cpu/aarch64/acl_reorder.cpp b/src/cpu/aarch64/acl_reorder.cpp index 061751b5557..73e38c0c4bb 100644 --- a/src/cpu/aarch64/acl_reorder.cpp +++ b/src/cpu/aarch64/acl_reorder.cpp @@ -19,7 +19,7 @@ namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { status_t acl_reorder_fwd_t::execute_forward(const exec_ctx_t &ctx) const { // Lock here is needed because resource_mapper does not support @@ -46,7 +46,7 @@ status_t acl_reorder_fwd_t::execute_forward(const exec_ctx_t &ctx) const { return status::success; } -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_reorder.hpp b/src/cpu/aarch64/acl_reorder.hpp index e586ed4e304..617053841be 100644 --- a/src/cpu/aarch64/acl_reorder.hpp +++ b/src/cpu/aarch64/acl_reorder.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023-2024 Arm Ltd. and affiliates +* Copyright 2023-2025 Arm Ltd. 
and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,19 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ -#ifndef CPU_AARCH64_ACL_REORDER_HPP -#define CPU_AARCH64_ACL_REORDER_HPP +#ifndef CPU_ACL_REORDER_HPP +#define CPU_ACL_REORDER_HPP #include "arm_compute/core/Types.h" #include "common/utils.hpp" -#include "cpu/aarch64/acl_utils.hpp" +#include "cpu/acl/acl_utils.hpp" #include "cpu/aarch64/cpu_isa_traits.hpp" #include "cpu/reorder/cpu_reorder_pd.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { struct acl_reorder_obj_t { arm_compute::NEReorderLayer reorder; @@ -95,12 +95,12 @@ struct acl_reorder_fwd_t : public primitive_t { if (!ok) return status::unimplemented; - int mask = -1; - bool is_set = false; - CHECK(attr->scales_.get(DNNL_ARG_DST, &mask, &is_set)); - const memory_desc_wrapper input_d(src_md); - if (input_d.has_runtime_dims_or_strides() && is_set && mask > 0) - return status::unimplemented; + if (!attr->scales_.has_default_values(DNNL_ARG_DST)) { + int mask = attr->scales_.get_mask(DNNL_ARG_DST); + const memory_desc_wrapper input_d(src_md); + if (input_d.has_runtime_dims_or_strides() && mask > 0) + return status::unimplemented; + } // Create and check primitive descriptor auto _pd = make_unique_pd(attr, src_engine->kind(), src_md, @@ -131,7 +131,7 @@ struct acl_reorder_fwd_t : public primitive_t { if (dst_tag == format_tag::BA4b4a || dst_tag == format_tag::Acdb4a || dst_tag == format_tag::Ab4a) { _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo4; - } else if (mayiuse(sve_256) + } else if (aarch64::mayiuse(aarch64::sve_256) && (dst_tag == format_tag::BA8b4a || dst_tag == format_tag::Acdb8a || dst_tag == format_tag::Ab8a)) { @@ -147,13 +147,17 @@ struct acl_reorder_fwd_t : public primitive_t { switch (src_md->ndims) { case 2: { if (src_tag == format_tag::ab - && dst_md->data_type == data_type::bf16) { // bf16 + && dst_md->data_type == data_type::bf16 + && utils::one_of(dst_tag, format_tag::BA8b4a, + format_tag::BA4b4a)) { // bf16 acl_tensor_shape_in = arm_compute::TensorShape( src_md->dims[0], src_md->dims[1]); acl_tensor_shape_out = arm_compute::TensorShape( dst_md->padded_dims[0], dst_md->padded_dims[1]); } else if (src_tag == format_tag::ba - && dst_md->data_type == data_type::f32) { // f32 + && dst_md->data_type == data_type::f32 + && !utils::one_of(dst_tag, format_tag::BA8b4a, + format_tag::BA4b4a)) { // f32 acl_tensor_shape_in = arm_compute::TensorShape( src_md->dims[1], src_md->dims[0]); acl_tensor_shape_out = arm_compute::TensorShape( @@ -239,9 +243,9 @@ struct acl_reorder_fwd_t : public primitive_t { }; // acl_reorder_fwd_t -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl -#endif // CPU_AARCH64_ACL_REORDER_HPP +#endif // CPU_ACL_REORDER_HPP diff --git a/src/cpu/aarch64/acl_softmax.cpp b/src/cpu/aarch64/acl_softmax.cpp deleted file mode 100644 index 976b33665d2..00000000000 --- a/src/cpu/aarch64/acl_softmax.cpp +++ /dev/null @@ -1,52 +0,0 @@ -/******************************************************************************* -* Copyright 2021-2022 Arm Ltd. and affiliates -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "cpu/aarch64/acl_softmax.hpp" - -namespace dnnl { -namespace impl { -namespace cpu { -namespace aarch64 { - -status_t acl_softmax_fwd_t::execute_forward(const exec_ctx_t &ctx) const { - - // Lock here is needed because resource_mapper does not support - // concurrent multithreaded access. - std::lock_guard _lock {this->mtx}; - - auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); - auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST); - - // Retrieve primitive resource and configured Compute Library objects - auto *acl_resource - = ctx.get_resource_mapper()->get(this); - acl_softmax_obj_t &acl_obj = acl_resource->get_acl_obj(); - - acl_obj.src_tensor.allocator()->import_memory(const_cast(src)); - acl_obj.dst_tensor.allocator()->import_memory(dst); - - acl_obj.softmax->run(); - - acl_obj.src_tensor.allocator()->free(); - acl_obj.dst_tensor.allocator()->free(); - - return status::success; -} - -} // namespace aarch64 -} // namespace cpu -} // namespace impl -} // namespace dnnl diff --git a/src/cpu/aarch64/acl_softmax.hpp b/src/cpu/aarch64/acl_softmax.hpp deleted file mode 100644 index 020e6ca5ab0..00000000000 --- a/src/cpu/aarch64/acl_softmax.hpp +++ /dev/null @@ -1,240 +0,0 @@ -/******************************************************************************* -* Copyright 2021-2024 Arm Ltd. and affiliates -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_AARCH64_ACL_SOFTMAX_HPP -#define CPU_AARCH64_ACL_SOFTMAX_HPP - -#include "cpu/cpu_softmax_pd.hpp" - -#include "cpu/aarch64/acl_utils.hpp" - -namespace dnnl { -namespace impl { -namespace cpu { -namespace aarch64 { - -struct acl_softmax_obj_t { - std::unique_ptr softmax; - arm_compute::Tensor src_tensor; - arm_compute::Tensor dst_tensor; -}; - -struct acl_softmax_conf_t { - arm_compute::TensorInfo src_info; - arm_compute::TensorInfo dst_info; - float beta; - int32_t axis; - bool is_logsoftmax; -}; - -struct acl_softmax_resource_t : public resource_t { - acl_softmax_resource_t() - : acl_obj_(utils::make_unique()) {} - - status_t configure(const acl_softmax_conf_t &asp) { - if (!acl_obj_) return status::out_of_memory; - - // Init Compute Library tensors based on info from descriptor - acl_obj_->src_tensor.allocator()->init(asp.src_info); - acl_obj_->dst_tensor.allocator()->init(asp.dst_info); - - if (asp.is_logsoftmax) { - auto logsoftmax - = std::make_unique(); - // clang-format off - logsoftmax->configure( - &acl_obj_->src_tensor, - &acl_obj_->dst_tensor, - asp.beta, - asp.axis); - // clang-format on - acl_obj_->softmax = std::move(logsoftmax); - } else { - auto softmax = std::make_unique(); - // clang-format off - softmax->configure( - &acl_obj_->src_tensor, - &acl_obj_->dst_tensor, - asp.beta, - asp.axis); - // clang-format on - acl_obj_->softmax = std::move(softmax); - } - - return status::success; - } - - acl_softmax_obj_t &get_acl_obj() const { return *acl_obj_; } - - DNNL_DISALLOW_COPY_AND_ASSIGN(acl_softmax_resource_t); - -private: - std::unique_ptr acl_obj_; -}; // acl_softmax_resource_t - -struct acl_softmax_fwd_t : public primitive_t { - struct pd_t : public cpu_softmax_fwd_pd_t { - using cpu_softmax_fwd_pd_t::cpu_softmax_fwd_pd_t; - - DECLARE_COMMON_PD_T("acl", acl_softmax_fwd_t); - - status_t init(engine_t *engine) { - - bool ok = is_fwd() - && set_default_formats() == status::success - // ACL only supports matching src/dst (this must come after - // set_default_formats() to handle format_kind::any) - && *src_md() == *dst_md() - && utils::one_of( - src_md()->data_type, data_type::f32, data_type::f16) - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - // Get memory desc to find sizes and dims - const memory_desc_wrapper src_d(src_md()); - const data_type_t data_type = src_d.data_type(); - - // ACL only supports plain tensors, can be permuted but not blocked - if (!src_d.is_plain()) return status::unimplemented; - - // Guards against a 0-sized dimension - if (src_d.has_zero_dim()) return status::unimplemented; - - // No scaling - asp_.beta = 1; - - asp_.is_logsoftmax = is_logsoftmax(); - - // The strides give us the in memory inner size - dim_t inner_size_ = src_d.blocking_desc().strides[axis()]; - - dim_t axis_size_ = axis_size(); - - // The outer size is any left-over dimensions not inner or on the axis - dim_t outer_size_ = src_d.nelems() / (inner_size_ * axis_size_); - - // In this context, NHWC tells ACL that the logical and physical - // dimensions are the same - arm_compute::DataLayout acl_layout = arm_compute::DataLayout::NHWC; - - const arm_compute::DataType acl_data_t - = acl_utils::get_acl_data_t(data_type); - - const int threads = dnnl_get_max_threads(); - if (inner_size_ == 1) { - // A rough empirical heuristic created by fitting a polynomial - // of the tensor sizes and thread count to the run time of the - // ref and ACL softmax. 
This variable is greater than zero when - // ref is faster, and less than zero when ACL is faster. We can - // interpret the constant term as the constant overhead - // associated with calling the external library and the negative - // coefficient on total_size as ACL being faster at processing - // each element - double acl_ref_performance_diff = 1 + 0.005 * outer_size_ - - 0.0027 * axis_size_ - * std::ceil(double(outer_size_) / threads); - if (threads > 1 || outer_size_ > 1) { - // Using threads within ACL adds another constant overhead - acl_ref_performance_diff += 17; - } - if (acl_ref_performance_diff > 0) return status::unimplemented; - - // If the inner size is 1, we can get rid of the dimension. - // This stops ACL doing a unnecessary permute - arm_compute::TensorShape acl_tensor_shape - = arm_compute::TensorShape(axis_size_, outer_size_); - asp_.axis = 0; - - asp_.src_info = arm_compute::TensorInfo( - acl_tensor_shape, 1, acl_data_t, acl_layout); - asp_.dst_info = arm_compute::TensorInfo( - acl_tensor_shape, 1, acl_data_t, acl_layout); - } else { - // A rough empirical heuristic, see comment above - // The only difference here is that ACL does a reorder, and so - // is considerably better - double acl_ref_performance_diff = 1 + 0.005 * outer_size_ - - 0.01 * inner_size_ * axis_size_ - * std::ceil(double(outer_size_) / threads); - if (threads > 1 || outer_size_ > 1) { - // Using threads within ACL adds another constant overhead - acl_ref_performance_diff += 17; - } - - if (acl_ref_performance_diff > 0) return status::unimplemented; - - // Irrespective of the input dimensions, we construct a tensor - // with dimensions such that softmax can be applied over the - // middle axis (1), with the correct stride and vector length. - arm_compute::TensorShape acl_tensor_shape - = arm_compute::TensorShape( - inner_size_, axis_size_, outer_size_); - asp_.axis = 1; - - asp_.src_info = arm_compute::TensorInfo( - acl_tensor_shape, 1, acl_data_t, acl_layout); - asp_.dst_info = arm_compute::TensorInfo( - acl_tensor_shape, 1, acl_data_t, acl_layout); - } - - // Validate manually to check for return status - if (asp_.is_logsoftmax) { - ACL_CHECK_VALID(arm_compute::NELogSoftmaxLayer::validate( - &asp_.src_info, &asp_.dst_info, asp_.beta, asp_.axis)); - } else { - ACL_CHECK_VALID(arm_compute::NESoftmaxLayer::validate( - &asp_.src_info, &asp_.dst_info, asp_.beta, asp_.axis)); - } - - return status::success; - } - - acl_softmax_conf_t asp_; - }; // pd_t - - acl_softmax_fwd_t(const pd_t *apd) : primitive_t(apd) {} - - status_t create_resource( - engine_t *engine, resource_mapper_t &mapper) const override { - if (mapper.has_resource(this)) return status::success; - - auto r = utils::make_unique(); - if (!r) return status::out_of_memory; - - // Configure the resource based on information from primitive descriptor - auto st = r->configure(pd()->asp_); - if (st == status::success) { mapper.add(this, std::move(r)); } - - return st; - } - - status_t execute(const exec_ctx_t &ctx) const override { - return execute_forward(ctx); - } - -private: - // To guard the const execute_forward, the mutex must be 'mutable' - mutable std::mutex mtx; - status_t execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } -}; // acl_softmax_fwd_t - -} // namespace aarch64 -} // namespace cpu -} // namespace impl -} // namespace dnnl - -#endif diff --git a/src/cpu/aarch64/acl_thread.cpp b/src/cpu/aarch64/acl_thread.cpp deleted file mode 100644 index 
1b098629ab5..00000000000 --- a/src/cpu/aarch64/acl_thread.cpp +++ /dev/null @@ -1,120 +0,0 @@ -/******************************************************************************* -* Copyright 2022-2023 Arm Ltd. and affiliates -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "cpu/aarch64/acl_thread.hpp" -#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL -#include "cpu/aarch64/acl_threadpool_scheduler.hpp" -#endif -#include "cpu/aarch64/acl_benchmark_scheduler.hpp" - -namespace dnnl { -namespace impl { -namespace cpu { -namespace aarch64 { - -namespace acl_thread_utils { - -#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP -void acl_thread_bind() { - static std::once_flag flag_once; - // The threads in Compute Library are bound for the cores 0..max_threads-1 - // dnnl_get_max_threads() returns OMP_NUM_THREADS - const int max_threads = dnnl_get_max_threads(); - // arm_compute::Scheduler does not support concurrent access thus a - // workaround here restricts it to only one call - std::call_once(flag_once, [&]() { - arm_compute::Scheduler::get().set_num_threads(max_threads); - }); -} -// Swap BenchmarkScheduler for default ACL scheduler builds (i.e. CPPScheduler, OMPScheduler) -void acl_set_benchmark_scheduler_default() { - static std::once_flag flag_once; - arm_compute::IScheduler *_real_scheduler = &arm_compute::Scheduler::get(); - std::shared_ptr benchmark_scheduler - = std::make_unique(*_real_scheduler); - // set Benchmark scheduler in ACL - std::call_once(flag_once, [&]() { - arm_compute::Scheduler::set( - std::static_pointer_cast( - benchmark_scheduler)); - }); -} -#endif - -#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL -void acl_set_tp_scheduler() { - static std::once_flag flag_once; - // Create threadpool scheduler - std::shared_ptr threadpool_scheduler - = std::make_unique(); - // set CUSTOM scheduler in ACL - std::call_once(flag_once, - [&]() { arm_compute::Scheduler::set(threadpool_scheduler); }); -} - -void acl_set_threadpool_num_threads() { - using namespace dnnl::impl::threadpool_utils; - static std::once_flag flag_once; - threadpool_interop::threadpool_iface *tp = get_active_threadpool(); - // Check active threadpool - bool is_main = get_active_threadpool() == tp; - if (is_main) { - // Set num threads based on threadpool size - const int num_threads = (tp) ? dnnl_get_max_threads() : 1; - std::call_once(flag_once, [&]() { - arm_compute::Scheduler::get().set_num_threads(num_threads); - }); - } -} -// Swap BenchmarkScheduler for custom scheduler builds (i.e. 
ThreadPoolScheduler) -void acl_set_tp_benchmark_scheduler() { - static std::once_flag flag_once; - // Create threadpool scheduler - std::unique_ptr threadpool_scheduler - = std::make_unique(); - arm_compute::IScheduler *_real_scheduler = nullptr; - _real_scheduler = threadpool_scheduler.release(); - // Create benchmark scheduler and set TP as real scheduler - std::shared_ptr benchmark_scheduler - = std::make_unique(*_real_scheduler); - std::call_once(flag_once, - [&]() { arm_compute::Scheduler::set(benchmark_scheduler); }); -} -#endif - -void set_acl_threading() { -#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP - acl_thread_bind(); - if (get_verbose(verbose_t::profile_externals)) { - acl_set_benchmark_scheduler_default(); - } -#endif -#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL - if (get_verbose(verbose_t::profile_externals)) { - acl_set_tp_benchmark_scheduler(); - } else { - acl_set_tp_scheduler(); - } - -#endif -} - -} // namespace acl_thread_utils - -} // namespace aarch64 -} // namespace cpu -} // namespace impl -} // namespace dnnl diff --git a/src/cpu/aarch64/acl_winograd_convolution.cpp b/src/cpu/aarch64/acl_winograd_convolution.cpp deleted file mode 100644 index da015388d64..00000000000 --- a/src/cpu/aarch64/acl_winograd_convolution.cpp +++ /dev/null @@ -1,139 +0,0 @@ -/******************************************************************************* -* Copyright 2020-2024 Arm Ltd. and affiliates -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "acl_winograd_convolution.hpp" -#include "common/memory_tracking.hpp" -#include "common/utils.hpp" - -namespace dnnl { -namespace impl { -namespace cpu { -namespace aarch64 { - -namespace { -using data_t = prec_traits::type; - -// Keys are anonymous. So deduce the type automagically. 
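A minimal illustration of the decltype trick referred to in this comment and used by the `using conv_key_t = decltype(...)` line that follows; the names here are hypothetical:

    #include <type_traits>

    namespace names {
    enum { key_a, key_b }; // anonymous enum: its type cannot be spelled
    } // namespace names

    // decltype recovers the unnameable enum type from one of its values.
    using key_t = decltype(names::key_a);
    static_assert(std::is_same<key_t, decltype(names::key_b)>::value,
            "both enumerators share the anonymous type");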
-using conv_key_t = decltype(memory_tracking::names::key_gemm_tmp_buffer); - -// Map: [slot , key] -const std::map wino_conv_keys - = {{0, conv_key_t::key_gemm_asm_tmp_buffer}, - {1, conv_key_t::key_gemm_pretranspose_b}, - {2, conv_key_t::key_gemm_pretranspose}, - {3, conv_key_t::key_gemm_interleaved_lhs}, - {4, conv_key_t::key_gemm_pretransposed_rhs}, - {5, conv_key_t::key_gemm_transposed_1xwrhs}, - {6, conv_key_t::key_gemm_tmp_buffer}, - {7, conv_key_t::key_conv_permuted_outputs}, - {8, conv_key_t::key_conv_permuted_inputs}, - {9, conv_key_t::key_wino_workspace}, - {10, conv_key_t::key_wino_transformed_weights}, - {11, conv_key_t::key_conv_permuted_weights}}; -} // namespace - -status_t acl_wino_convolution_fwd_t::pd_t::init(engine_t *engine) { - using namespace data_type; - const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) - && attr()->has_default_values( - primitive_attr_t::skip_mask_t::post_ops, f16); - const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) - && attr()->has_default_values( - primitive_attr_t::skip_mask_t::post_ops, f32); - bool ok = is_fwd() - && utils::one_of(desc()->alg_kind, alg_kind::convolution_auto, - alg_kind::convolution_winograd) - && utils::one_of(true, is_fp16_ok, is_fp32_ok) - && !has_zero_dim_memory(); - - ok = ok && DNNL_CPU_THREADING_RUNTIME != DNNL_RUNTIME_THREADPOOL; - if (!ok) return status::unimplemented; - - CHECK(init_conf()); - - set_default_alg_kind(alg_kind::convolution_winograd); - - Op conv; - conv.configure(&acp_.src_tensor_info, &acp_.wei_tensor_info, - acp_.with_bias ? &acp_.bia_tensor_info : nullptr, - &acp_.dst_tensor_info, acp_.padstride_info, acp_.act_info, - true); // to support 5x5, 7x7 filter shapes in addition to 3x3 - - auto scratchpad = scratchpad_registry().registrar(); - const auto aux_mem = conv.workspace(); - return init_scratchpad(conv, scratchpad, wino_conv_keys, engine, post_ops, - attr_.post_ops_, acp_.act_info, acp_.use_dst_acc_for_sum, dst_md_); -} - -status_t acl_wino_convolution_fwd_t::init(engine_t *engine) { - auto acp = pd()->acp_; - acl_obj_->conv.configure(&acp.src_tensor_info, &acp.wei_tensor_info, - acp.with_bias ? 
&acp.bia_tensor_info : nullptr, - &acp.dst_tensor_info, acp.padstride_info, acp.act_info, - true); // to support 5x5, 7x7 filter shapes in addition to 3x3 - - acl_obj_->aux_mem_req = acl_obj_->conv.workspace(); - return status::success; -} - -status_t acl_wino_convolution_fwd_t::pd_t::init_conf() { - - // Under these conditions, fallback to faster GEMM-based convolution - // unless the user explicitly specifies Winograd algorithm - if (utils::one_of(true, src_md_.dims[2] > 112, // ih - src_md_.dims[3] > 112, // iw - src_md_.dims[1] < 64, // ic - dst_md_.dims[1]<64, // oc - dnnl_get_max_threads()> 28) - && desc()->alg_kind == alg_kind::convolution_auto) { - return status::unimplemented; - } - - // General Compute Library checks, memory tags are also set there - acp_.alg_winograd = true; - CHECK(acl_convolution_utils::acl_init_conf( - acp_, src_md_, weights_md_, dst_md_, bias_md_, *desc(), *attr())); - - const bool shape_ok - // only unit strides allowed - = (acp_.padstride_info.stride() == std::pair {1, 1}) - // Note: Compute Library supports arbitrary padding for wino kernels - // but we only allow small padding to be consistent with oneDNN - && (acp_.padstride_info.pad().first <= 1) // padding left/right - && (acp_.padstride_info.pad().second <= 1) // padding top/bottom - // only non-dilated convolutions allowed - && (acp_.dilation_info == arm_compute::Size2D(1, 1)); - - ACL_CHECK_SUPPORT(!shape_ok, "shape not supported by winograd kernels"); - - // Validate convolution manually to check for return status - ACL_CHECK_VALID(Op::validate(&acp_.src_tensor_info, &acp_.wei_tensor_info, - acp_.with_bias ? &acp_.bia_tensor_info : nullptr, - &acp_.dst_tensor_info, acp_.padstride_info, acp_.act_info, - true)); // enable_fast_math flag in ACL Winograd - - return status::success; -} - -status_t acl_wino_convolution_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - return execute_forward_conv_acl, pd_t, data_t>( - ctx, acl_obj_.get(), pd(), wino_conv_keys); -} -} // namespace aarch64 -} // namespace cpu -} // namespace impl -} // namespace dnnl diff --git a/src/cpu/aarch64/acl_winograd_convolution.hpp b/src/cpu/aarch64/acl_winograd_convolution.hpp deleted file mode 100644 index 15b015757ea..00000000000 --- a/src/cpu/aarch64/acl_winograd_convolution.hpp +++ /dev/null @@ -1,70 +0,0 @@ -/******************************************************************************* -* Copyright 2020-2024 Arm Ltd. and affiliates -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
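A scalar sketch of the shape heuristic in init_conf above, assuming a hypothetical helper name: for convolution_auto, fall back to the faster GEMM-based convolution when the shape or thread count makes Winograd unattractive.

    // Mirrors the fallback condition in
    // acl_wino_convolution_fwd_t::pd_t::init_conf above.
    bool skip_winograd_for_auto(int ih, int iw, int ic, int oc, int threads) {
        return ih > 112 || iw > 112 || ic < 64 || oc < 64 || threads > 28;
    }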
-*******************************************************************************/ - -#ifndef CPU_AARCH64_ACL_WINOGRAD_CONVOLUTION_HPP -#define CPU_AARCH64_ACL_WINOGRAD_CONVOLUTION_HPP - -#include "cpu/cpu_convolution_pd.hpp" - -#include "acl_convolution_utils.hpp" -#include "arm_compute/runtime/experimental/operators/CpuWinogradConv2d.h" - -namespace dnnl { -namespace impl { -namespace cpu { -namespace aarch64 { - -struct acl_wino_convolution_fwd_t : public primitive_t { - using Op = arm_compute::experimental::op::CpuWinogradConv2d; - - struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), acp_() {} - - DECLARE_COMMON_PD_T( - "wino:acl", acl_wino_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD); - - status_t init(engine_t *engine); - - acl_conv_conf_t acp_; - acl_post_ops_t post_ops; - - private: - status_t init_conf(); - }; - - acl_wino_convolution_fwd_t(const pd_t *apd) - : primitive_t(apd), acl_obj_(std::make_unique>()) {} - - status_t init(engine_t *engine) override; - - status_t execute(const exec_ctx_t &ctx) const override { - return execute_forward(ctx); - } - -private: - status_t execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - std::unique_ptr> acl_obj_; -}; // acl_wino_convolution_fwd_t - -} // namespace aarch64 -} // namespace cpu -} // namespace impl -} // namespace dnnl - -#endif // CPU_AARCH64_ACL_WINOGRAD_CONVOLUTION_HPP diff --git a/src/cpu/aarch64/brgemm/brgemm.cpp b/src/cpu/aarch64/brgemm/brgemm.cpp index 6ed6cc59597..94e1c73fd3b 100644 --- a/src/cpu/aarch64/brgemm/brgemm.cpp +++ b/src/cpu/aarch64/brgemm/brgemm.cpp @@ -1,6 +1,7 @@ /******************************************************************************* -* Copyright 2020-2023 Intel Corporation +* Copyright 2020-2025 Intel Corporation * Copyright 2023-2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -170,8 +171,8 @@ status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa, if (brg == nullptr) return status::invalid_arguments; if (transA || transB) return status::unimplemented; - brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout, alpha, - beta, LDA, LDB, LDC, M, N, K, strides); + CHECK(brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout, + alpha, beta, LDA, LDB, LDC, M, N, K, strides)); if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments; bool ldx_check = (brg->is_row_major()) ? 
(LDA < K)
@@ -197,8 +198,8 @@ status_t brdgmm_desc_init(brgemm_t *brg, cpu_isa_t isa,
     if (transA || layout != brgemm_row_major || alpha != 1.0f || beta != 0.f)
         return status::unimplemented;
-    brgemm_utils::init_brdgmm_conf(brg, isa, type, dt_a, dt_b, layout, alpha,
-            beta, LDA, LDC, M, N, strides);
+    CHECK(brgemm_utils::init_brdgmm_conf(brg, isa, type, dt_a, dt_b, layout,
+            alpha, beta, LDA, LDC, M, N, strides));
     const bool ldx_check = (LDA < N || LDC < N);
     if (ldx_check) return status::invalid_arguments;
@@ -290,41 +291,52 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
     const auto &src_scales = attr->scales_.get(DNNL_ARG_SRC);
     const auto &wei_scales = attr->scales_.get(DNNL_ARG_WEIGHTS);
-    brg->with_scales = !src_scales.has_default_values()
-            || !wei_scales.has_default_values()
+    const bool has_src_scales = !src_scales.has_default_values();
+    const bool has_wei_scales = !wei_scales.has_default_values();
+    brg->with_scales = has_src_scales || has_wei_scales
             || brg->with_weights_scale_adjust;
     if (brg->with_scales) {
         // Note: the current version supports only two different output scale
         // types:
-        //     1) common (mask_ = 0)
+        //     1) common (mask = 0)
         //     2) per_n_dim_scale - broadcast across n dimension;
         //        for convolution and inner product primitives it corresponds
-        //        to "per_oc" mask_ = 1 << 1; for matmul - to
-        //        mask_ = (1 << (ndims - 1))), where ndims is number of
+        //        to "per_oc" mask = 1 << 1; for matmul - to
+        //        mask = (1 << (ndims - 1))), where ndims is number of
         //        dimensions for original matmul problem
-        // So if wei_scales.mask_ != 0 (not common) it's assumed here that scale
-        // type is per_n_dim_scale and driver which calls brgemm kernel checked
-        // that mask has correct value for this case
-        brg->is_oc_scale = wei_scales.mask_ != 0;
+        // So if wei_scales.get_mask() > 0 (not common) it's assumed here that
+        // scale type is per_n_dim_scale and driver which calls brgemm kernel
+        // checked that mask has correct value for this case
+        brg->is_oc_scale = wei_scales.get_mask() > 0;
     }
     const auto &dst_scales = attr->scales_.get(DNNL_ARG_DST);
-    brg->with_dst_scales = !dst_scales.has_default_values();
-    const bool scales_ok = src_scales.mask_ == 0 && dst_scales.mask_ == 0
+    const bool has_dst_scales = !dst_scales.has_default_values();
+    brg->with_dst_scales = has_dst_scales;
+    const bool scales_ok
+            = IMPLICATION(has_src_scales, src_scales.get_mask() == 0)
+            && IMPLICATION(has_dst_scales, dst_scales.get_mask() == 0)
             && attr->scales_.has_default_values(
                     {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST});
     if (!scales_ok) return status::unimplemented;
     auto init_zp_type
             = [&](brgemm_broadcast_t &zp_type, int mem_arg) -> status_t {
-        auto zero_points = attr->zero_points_;
-
-        // common zero point type is supported for now
-        if (!zero_points.common(mem_arg)) return status::unimplemented;
+        const auto &zp = attr->zero_points_;
+        // Always init a default value.
+        zp_type = brgemm_broadcast_t::none;
+
+        if (!zp.has_default_values(mem_arg)) {
+            int mask = zp.get_mask(mem_arg);
+            if (mask == 0) {
+                zp_type = brgemm_broadcast_t::per_tensor;
+            } else if (mask == (1 << 1)) {
+                zp_type = brgemm_broadcast_t::per_n;
+            } else {
+                return status::unimplemented;
+            }
+        }
-        zp_type = zero_points.has_default_values(mem_arg)
-                ?
brgemm_broadcast_t::none - : brgemm_broadcast_t::per_tensor; return status::success; }; @@ -416,6 +428,11 @@ status_t brgemm_desc_set_attr(brgemm_t *brg, const brgemm_attr_t &brgattr) { return status::success; } +status_t brgemm_desc_finalize(brgemm_t *brg) { + // TODO: implement functionality here similar to corresponding one in x64 + return status::success; +} + status_t brgemm_kernel_create( brgemm_kernel_t **brg_kernel, const brgemm_t &brg) { if (!brg_kernel) return status::invalid_arguments; @@ -512,11 +529,13 @@ int brgemm_cmp(const brgemm_t &lhs, const brgemm_t &rhs) { CMP_BRGEMM_FIELD(brgattr.hint_prfB.dist2); CMP_BRGEMM_FIELD(brgattr.hint_prfC.dist1); CMP_BRGEMM_FIELD(brgattr.hint_prfC.dist2); - CMP_BRGEMM_FIELD(brgattr.wary_tail_read); + CMP_BRGEMM_FIELD(brgattr.wary_A_k_tail_read); + CMP_BRGEMM_FIELD(brgattr.extendable_k); CMP_BRGEMM_FIELD(brgattr.generate_skip_accumulation); CMP_BRGEMM_FIELD(brgattr.bd_mask_level); CMP_BRGEMM_FIELD(brgattr.use_uker); CMP_BRGEMM_FIELD(brgattr.use_interleave_stores); + CMP_BRGEMM_FIELD(brgattr.b_is_vnni); CMP_BRGEMM_FIELD(brgattr.fpmath_mode); CMP_BRGEMM_FIELD(brgattr.LDA2); CMP_BRGEMM_FIELD(brgattr.LDB2); diff --git a/src/cpu/aarch64/brgemm/brgemm.hpp b/src/cpu/aarch64/brgemm/brgemm.hpp index f6531f5ff64..64ae821a1c5 100644 --- a/src/cpu/aarch64/brgemm/brgemm.hpp +++ b/src/cpu/aarch64/brgemm/brgemm.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2023 Intel Corporation +* Copyright 2020-2025 Intel Corporation * Copyright 2023 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -121,6 +121,11 @@ status_t DNNL_API brgemm_desc_set_postops(brgemm_t *brg, status_t DNNL_API brgemm_desc_set_attr( brgemm_t *brg, const brgemm_attr_t &brgattr); +/// Finalize BRGEMM descriptor. +/// +/// @param brg Output BRGEMM descriptor +status_t DNNL_API brgemm_desc_finalize(brgemm_t *brg); + /// Generates a BRGEMM kernel based on descriptor /// /// @param brg_kernel Output BRGEMM kernel diff --git a/src/cpu/aarch64/brgemm/brgemm_types.hpp b/src/cpu/aarch64/brgemm/brgemm_types.hpp index d6eb16cd6ff..0c5485ce8c7 100644 --- a/src/cpu/aarch64/brgemm/brgemm_types.hpp +++ b/src/cpu/aarch64/brgemm/brgemm_types.hpp @@ -133,7 +133,8 @@ struct DNNL_API brgemm_attr_t { = brgemm_kernel_prefetching_t::brgemm_prf_default; brgemm_prf_t hint_prfA, hint_prfB, hint_prfC; - bool wary_tail_read; + bool wary_A_k_tail_read {false}; + bool extendable_k {false}; bool generate_skip_accumulation; // Value of bd_mask_level specifies how bd_mask is used in brgemm kernel // 0 – bd_mask is not used @@ -147,6 +148,7 @@ struct DNNL_API brgemm_attr_t { // interleave stores or not bool use_interleave_stores; impl::fpmath_mode_t fpmath_mode = fpmath_mode::strict; + bool b_is_vnni {false}; // Second level leading dimension describing distance between 16-line // blocks in case of blocked layout. Used to calculate address of next // bd block. By default are equal to regular leading dimension parameters diff --git a/src/cpu/aarch64/brgemm/brgemm_utils.cpp b/src/cpu/aarch64/brgemm/brgemm_utils.cpp index 109436db6bf..c517d9a0856 100644 --- a/src/cpu/aarch64/brgemm/brgemm_utils.cpp +++ b/src/cpu/aarch64/brgemm/brgemm_utils.cpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2022-2023 Intel Corporation * Copyright 2023-2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. 
and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,15 +48,18 @@ impl::data_type_t get_accum_datatype(brgemm_t *brg) { return brg->is_int8 ? data_type::s32 : data_type::f32; } -void init_kernel_datatype( +status_t init_kernel_datatype( brgemm_t *brg, impl::data_type_t dt_a, impl::data_type_t dt_b) { - assert(dt_a != data_type::undef && dt_b != data_type::undef); + if (!(dt_a != data_type::undef && dt_b != data_type::undef)) + return status::unimplemented; brg->is_int8 = utils::one_of(dt_a, data_type::u8, data_type::s8) && utils::one_of(dt_b, data_type::u8, data_type::s8); brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16); brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32); brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b); - assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16); + if (!(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16)) + return status::unimplemented; + return status::success; } void init_common_conf(brgemm_t *brg, brgemm_batch_kind_t type, float alpha, @@ -88,7 +92,7 @@ void maybe_try_bf32(brgemm_t *brg) { // } -void set_isa_impl(brgemm_t *brg) { +status_t set_isa_impl(brgemm_t *brg) { auto is_isa_ok = [&](cpu_isa_t isa) { return mayiuse(isa) && // maybe IMPLICATION(brg->isa_user != isa_undef, @@ -96,19 +100,14 @@ void set_isa_impl(brgemm_t *brg) { one_of(brg->isa_user, isa_undef, isa); }; - if (brg->is_bf32) { - assert(!"unsupported case"); - } else if (brg->is_f32) { - brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(sve_512), sve_512, - is_isa_ok(sve_256), sve_256); - } else if (brg->is_bf16) { - assert(!"unsupported case"); - } else if (brg->is_f16) { - assert(!"unsupported case"); - } else if (brg->is_int8) { + if (brg->is_bf32 || brg->is_bf16 || brg->is_f16) { + return status::unimplemented; + } else if (brg->is_f32 || brg->is_int8) { brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(sve_512), sve_512, is_isa_ok(sve_256), sve_256); + return status::success; } + return status::success; } void set_brg_vmm(brgemm_t *brg) { @@ -187,7 +186,7 @@ inline size_t data_type_vnni_granularity(data_type_t data_type) { } status_t brgemm_blocking(brgemm_t *brg) { - set_isa_impl(brg); + CHECK(set_isa_impl(brg)); if (brg->isa_impl == isa_undef) return status::unimplemented; assert(!brg->is_dgmm); // should not be called from brdgmm set_brg_vmm(brg); @@ -296,10 +295,11 @@ status_t brdgmm_blocking(brgemm_t *brg) { return status::success; } -void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, - impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout, - float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, - dim_t N, dim_t K, const brgemm_strides_t *strides, bool is_bf32) { +status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, + brgemm_batch_kind_t type, impl::data_type_t dt_a, + impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta, + dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, dim_t N, dim_t K, + const brgemm_strides_t *strides, bool is_bf32) { init_common_conf(brg, type, alpha, beta, strides); @@ -307,7 +307,7 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, brg->dt_a = brg->is_row_major() ? dt_a : dt_b; brg->dt_b = brg->is_row_major() ? 
dt_b : dt_a; - init_kernel_datatype(brg, brg->dt_a, brg->dt_b); + CHECK(init_kernel_datatype(brg, brg->dt_a, brg->dt_b)); brg->dt_c = get_accum_datatype(brg); brg->dt_d = brg->dt_c; @@ -319,7 +319,7 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, brg->typesize_D = types::data_type_size(brg->dt_d); brg->isa_user = isa; - set_isa_impl(brg); + CHECK(set_isa_impl(brg)); brg->is_bf32 = false; brg->has_int8_vnni = true; @@ -352,11 +352,13 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, brg->rd_step = has_no_vnni_compute_instruction ? 1 : data_type_vnni_granularity(brg->dt_b); + return status::success; } -void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, - impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout, - float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N, +status_t init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, + brgemm_batch_kind_t type, impl::data_type_t dt_a, + impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta, + dim_t LDA, dim_t LDC, dim_t M, dim_t N, const brgemm_strides_t *strides) { init_common_conf(brg, type, alpha, beta, strides); @@ -365,7 +367,7 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, brg->dt_a = dt_a; brg->dt_b = dt_b; - init_kernel_datatype(brg, brg->dt_a, brg->dt_b); + CHECK(init_kernel_datatype(brg, brg->dt_a, brg->dt_b)); brg->dt_c = get_accum_datatype(brg); brg->dt_d = brg->dt_c; @@ -394,6 +396,7 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, brg->bcast_dim = M; brg->load_dim = N; + return status::success; } } // namespace brgemm_utils @@ -402,4 +405,4 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, } // namespace impl } // namespace dnnl -//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s \ No newline at end of file +//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s diff --git a/src/cpu/aarch64/brgemm/brgemm_utils.hpp b/src/cpu/aarch64/brgemm/brgemm_utils.hpp index 485b5fde961..563a5d734ac 100644 --- a/src/cpu/aarch64/brgemm/brgemm_utils.hpp +++ b/src/cpu/aarch64/brgemm/brgemm_utils.hpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2022 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,20 +45,21 @@ status_t brdgmm_blocking(brgemm_t *brg); * having to depend on BRGeMM's API. An additional feature is that this * function can be modified depending on needs without requiring changes * at the API level. 
*/ -void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, - impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout, - float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, - dim_t N, dim_t K, const brgemm_strides_t *strides = nullptr, - bool is_bf32 = false); +status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, + brgemm_batch_kind_t type, impl::data_type_t dt_a, + impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta, + dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, dim_t N, dim_t K, + const brgemm_strides_t *strides = nullptr, bool is_bf32 = false); /* The purpose of this function is to enable initialization of brgemm values * and then call additional functions like blocking heuristics without * having to depend on BRDGeMM's API. An additional feature is that this * function can be modified depending on needs without requiring changes * at the API level. */ -void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, - impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout, - float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N, +status_t init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, + brgemm_batch_kind_t type, impl::data_type_t dt_a, + impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta, + dim_t LDA, dim_t LDC, dim_t M, dim_t N, const brgemm_strides_t *strides = nullptr); } // namespace brgemm_utils diff --git a/src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp index 087fb52935a..b3f02816761 100644 --- a/src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp +++ b/src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2021-2023 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -766,10 +767,38 @@ void jit_brgemm_kernel_t::read_params() { void jit_brgemm_kernel_t::zero_accumulators(int bd_block2, bool is_bdb_tail, int ld_block2, bool is_ld_tail, bool skip_accumulation) { int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block; + const bool need_to_apply_beta = brg.beta != 0.f; for_(int bd = 0; bd < bd_block; bd++) for (int ld = 0; ld < ld_block2; ld++) { auto zmm = accm(ld_block2, bd, ld); - eor(zmm.d, zmm.d, zmm.d); + // This part is moved here from apply_alpha_beta function so that fadd instruction can be avoided. + // This is also required only when K is blocked. + if (need_to_apply_beta) { + const bool is_tail = is_ld_tail && ld + 1 == ld_block2; + const auto k_mask = is_tail ? 
ld_tail_mask : ld_full_mask; + + const int offset = C_offset(bd, ld); + + int base_offset = 0; + auto x_addr = reg_aux_C; + + if ((unsigned)(offset - base_offset) > cpu_sveLen * 7) { + add_imm(reg_tmp_, reg_aux_C, offset, X_TMP_0); + base_offset = offset; + x_addr = reg_tmp_; + } + LD_MUL_VL(ld1w, zmm.s, k_mask, x_addr, offset - base_offset, 4); + + const bool need_init_beta_vmm = brg.beta != 1.f; + auto vmm_beta = z_tail_mask(); + if (need_init_beta_vmm) { + auto wreg_tmp = WReg(reg_tmp_gpr.getIdx()); + mov_imm(wreg_tmp, float2int(static_cast(brg.beta))); + dup(vmm_beta.s, wreg_tmp); + fmul(zmm.s, zmm.s, vmm_beta.s); + } + } else + eor(zmm.d, zmm.d, zmm.d); } } @@ -790,58 +819,6 @@ void jit_brgemm_kernel_t::apply_alpha_beta( if (dq2ps_required) { scvtf(vmm.s, P_ALL_ONE / T_m, vmm.s); } if (apply_alpha) { fmul(vmm.s, vmm.s, vmm_alpha.s); } } - - if (brg.beta == 0.f) return; - const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required; - const bool need_init_beta_vmm = brg.beta != 1.f; - auto vmm_prev_dst = z_tmp_1(); - auto vmm_beta = z_tail_mask(); - if (need_init_beta_vmm) { - auto wreg_tmp = WReg(reg_tmp_gpr.getIdx()); - mov_imm(wreg_tmp, float2int(static_cast(brg.beta))); - dup(vmm_beta.s, wreg_tmp); - } - - int base_offset = 0; - auto x_addr = reg_aux_C; - for_(int bd = 0; bd < bd_block; bd++) - for (int ld = 0; ld < ld_block2; ld++) { - const bool is_tail = is_ld_tail && ld + 1 == ld_block2; - const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask; - auto vmm = accm(ld_block2, bd, ld); - if (use_vadd_for_beta) { - if (brg.is_int8) { - assert(!"unsupported\n"); - } else { - ZRegS z_masked = vmm.s; - ZRegS z(vmm.getIdx()); - - const int offset = C_offset(bd, ld); - - if ((unsigned)(offset - base_offset) > cpu_sveLen * 7) { - add_imm(reg_tmp_, reg_aux_C, offset, X_TMP_0); - base_offset = offset; - x_addr = reg_tmp_; - } - LD_MUL_VL(ld1w, vmm_prev_dst.s, k_mask, x_addr, - offset - base_offset, 4); - if (is_ld_tail) { - movprfx(z_masked, k_mask / T_z, z); - fadd(z_masked, k_mask / T_m, vmm_prev_dst.s); - } else { - fadd(z_masked, z_masked, vmm_prev_dst.s); - } - } - } else { - add_imm(X_DEFAULT_ADDR, reg_aux_C, C_offset(bd, ld), X_TMP_0); - ld1w(vmm_prev_dst.s, k_mask / T_z, ptr(X_DEFAULT_ADDR)); - if (brg.beta == 1.f) { - fadd(vmm.s, vmm.s, vmm_prev_dst.s); - } else { - fmla(vmm.s, P_ALL_ONE / T_m, vmm_prev_dst.s, vmm_beta.s); - } - } - } } void jit_brgemm_kernel_t::apply_post_ops( @@ -1414,7 +1391,8 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2, || brg.zp_type_a != brgemm_broadcast_t::none); if (brg.req_cal_comp_pads || comp_vpad) assert(!"unsupported\n"); - bool maybe_load_bytes = (rows_for_rd_tail > 0 || brg.brgattr.wary_tail_read) + bool maybe_load_bytes + = (rows_for_rd_tail > 0 || brg.brgattr.wary_A_k_tail_read) && is_rd_tail && rd_tail_size != 0 && (brg.is_bf16 || brg.is_int8); if (n_bcast_1_load) { for (int rd = 0; rd < rd_loop; rd += brg.rd_step) { @@ -1424,7 +1402,7 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2, auto rows_by_load_bytes = have_to_load_bytes ? 
rows_for_rd_tail : 0;
         for (int bd = bd_b; bd < bd_e && !is_emdbd; bd++) {
             const auto bd_by_load_bytes = (bd >= bd_e - rows_by_load_bytes
-                    || brg.brgattr.wary_tail_read);
+                    || brg.brgattr.wary_A_k_tail_read);
             broadcast(bcst(bd), A_offset(bd, rd),
                     have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
         }
@@ -1464,7 +1442,6 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
     int base_offset = 0;
     for (int rd = 0; rd < rd_loop; rd += brg.rd_step) {
-        int prefetch_count_B = 0;
         for (int ld = 0; ld < ld_block2; ld++) {
             const auto mask = is_ld_tail ? ld_tail_mask : P_ALL_ONE;
             if (brg.dt_b == data_type::f16) {
@@ -1492,17 +1469,11 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
             if (!is_emdbd) {
                 const auto bd_by_load_bytes = (bd >= bd_e - rows_by_load_bytes
-                        || brg.brgattr.wary_tail_read);
+                        || brg.brgattr.wary_A_k_tail_read);
                 broadcast(bcst(), A_offset(bd, rd),
                         have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
             }
-            if (prefetch_count_B < ld_block2) {
-                add_imm(X_DEFAULT_ADDR, reg_aux_B,
-                        B_offset(prefetch_count_B++, rd)
-                                + brg.LDB * brg.rd_block * brg.typesize_B,
-                        X_TMP_0);
-                prfm(PLDL1KEEP, ptr(X_DEFAULT_ADDR));
-            }
+            // The current implementation of prefetch is not giving any gain
+            // in performance but is rather introducing some latency.
+            // Therefore it is removed until a new useful implementation is
+            // devised.
             for (int ld = 0; ld < ld_block2; ld++) {
                 auto zmm = accm(ld_block2, bd, ld);
                 if (is_emdbd) {
@@ -1876,7 +1847,7 @@ void jit_brgemm_kernel_t::bdb_loop() {
 }
 void jit_brgemm_kernel_t::generate() {
-    size_t simd_w_;
+    size_t simd_w_ = 0;
     switch (brg.isa_impl) {
         case sve_512:
             simd_w_ = cpu_isa_traits<sve_512>::vlen / sizeof(float);
             break;
         case sve_256:
             simd_w_ = cpu_isa_traits<sve_256>::vlen / sizeof(float);
             break;
-        default: assert(!"unsupported isa");
+        default: {
+            assert(!"unsupported isa");
+            return;
+        }
     }
     preamble();
     if (simd_w_ != cpu_sveLen / sizeof(float)) {
@@ -1935,7 +1909,8 @@ brgemm_attr_t::brgemm_attr_t()
     , hint_innermost_loop(brgemm_ld_loop_innermost)
     , hint_loop_order(brgemm_kernel_loop_order_t::brgemm_lo_default)
     , hint_prefetching(brgemm_kernel_prefetching_t::brgemm_prf_default)
-    , wary_tail_read(true)
+    , wary_A_k_tail_read(true)
+    , extendable_k(false)
     , generate_skip_accumulation(false)
     , bd_mask_level(0)
     , use_uker(false)
diff --git a/src/cpu/aarch64/cpu_isa_traits.hpp b/src/cpu/aarch64/cpu_isa_traits.hpp
index 64a6368b654..8bf338cf08b 100644
--- a/src/cpu/aarch64/cpu_isa_traits.hpp
+++ b/src/cpu/aarch64/cpu_isa_traits.hpp
@@ -31,8 +31,8 @@
 #define XBYAK_USE_MMAP_ALLOCATOR
 #endif
-#include "cpu/aarch64/xbyak_aarch64/xbyak_aarch64/xbyak_aarch64.h"
-#include "cpu/aarch64/xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_util.h"
+#include "xbyak_aarch64/xbyak_aarch64/xbyak_aarch64.h"
+#include "xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_util.h"
 namespace dnnl {
 namespace impl {
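The zero_accumulators change above folds the beta scaling of C into accumulator initialization, so apply_alpha_beta no longer needs a separate load-and-fadd pass over C. A minimal scalar sketch of the idea, with a hypothetical function name:

    // When beta == 0 the accumulator is simply zeroed; otherwise it starts
    // as beta * C, with the multiply skipped when beta == 1.
    void init_accumulators(float *acc, const float *C, int n, float beta) {
        for (int i = 0; i < n; ++i) {
            if (beta == 0.f)
                acc[i] = 0.f;
            else
                acc[i] = (beta == 1.f) ? C[i] : beta * C[i];
        }
    }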
diff --git a/src/cpu/aarch64/cpu_reducer.cpp b/src/cpu/aarch64/cpu_reducer.cpp
index 1e1c947dc96..4361e3c0c21 100644
--- a/src/cpu/aarch64/cpu_reducer.cpp
+++ b/src/cpu/aarch64/cpu_reducer.cpp
@@ -99,7 +99,7 @@ using namespace Xbyak_aarch64;
 template <impl::data_type_t data_type>
 struct reducer_2d_driver_t : public jit_generator {
-    using data_t = typename prec_traits<data_type>::type;
+    using data_t = typename prec_traits_t<data_type>::type;
     reducer_2d_driver_t(int n_src, size_t src_ld, size_t src_step,
             size_t dst_step, bool nullify_dst)
@@ -122,7 +122,7 @@ template <impl::data_type_t data_type, cpu_isa_t isa>
 struct reducer_2d_driver_f_s_32_t : public reducer_2d_driver_t<data_type> {
     DECLARE_CPU_JIT_AUX_FUNCTIONS(reducer_2d_driver_f_s_32_t)
-    using data_t = typename prec_traits<data_type>::type;
+    using data_t = typename prec_traits_t<data_type>::type;
     void operator()(
             data_t *dst, const data_t *srcs, size_t ny, size_t nx) override {
@@ -134,7 +134,7 @@ struct reducer_2d_driver_f_s_32_t : public reducer_2d_driver_t<data_type> {
     const int vlen = cpu_isa_traits<isa>::vlen;
     const int typesize
-            = sizeof(typename dnnl::impl::prec_traits<data_type>::type);
+            = sizeof(typename dnnl::impl::prec_traits_t<data_type>::type);
     XReg reg_dst = abi_param1;
     XReg reg_src = abi_param2;
     XReg reg_ny = abi_param3;
diff --git a/src/cpu/aarch64/cpu_reducer.hpp b/src/cpu/aarch64/cpu_reducer.hpp
index 0ccbd446948..7e6566c32cc 100644
--- a/src/cpu/aarch64/cpu_reducer.hpp
+++ b/src/cpu/aarch64/cpu_reducer.hpp
@@ -169,7 +169,7 @@ struct reducer_2d_driver_t;
  */
 template <impl::data_type_t data_type>
 struct cpu_reducer_t {
-    typedef typename prec_traits<data_type>::type data_t;
+    typedef typename prec_traits_t<data_type>::type data_t;
     struct conf_t {
         conf_t() = default;
@@ -249,7 +249,7 @@ struct cpu_reducer_t {
 template <impl::data_type_t data_type>
 struct cpu_reducer_2d_t {
-    typedef typename prec_traits<data_type>::type data_t;
+    typedef typename prec_traits_t<data_type>::type data_t;
     struct conf_t {
         conf_t() = default;
@@ -334,7 +334,7 @@ struct cpu_reducer_2d_t {
 /** simple 1d accumulator: y[:] += x[:] */
 template <impl::data_type_t data_type>
 struct cpu_accumulator_1d_t {
-    typedef typename prec_traits<data_type>::type data_t;
+    typedef typename prec_traits_t<data_type>::type data_t;
     cpu_accumulator_1d_t();
     ~cpu_accumulator_1d_t();
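The cpu_reducer changes above are a mechanical rename of prec_traits to prec_traits_t. A simplified sketch of what such a trait does, assuming the usual tag-to-storage-type mapping; this is not the library's actual definition:

    #include <cstdint>

    enum class data_type { f32, s32 };

    template <data_type dt> struct prec_traits_t; // primary template
    template <> struct prec_traits_t<data_type::f32> { using type = float; };
    template <> struct prec_traits_t<data_type::s32> { using type = int32_t; };

    // e.g. a reducer instantiated for f32 stores its data as float:
    using data_t = typename prec_traits_t<data_type::f32>::type;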
diff --git a/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp b/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp
index 435f12b16f1..00163cbecaa 100644
--- a/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp
+++ b/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp
@@ -1,6 +1,6 @@
 /*******************************************************************************
 * Copyright 2019-2023 Intel Corporation
-* Copyright 2021-2023 FUJITSU LIMITED
+* Copyright 2021-2024 FUJITSU LIMITED
 * Copyright 2022 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -475,9 +475,87 @@ void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector_fwd(
     h->mov(vmm_src, p_mask / T_m, vmm_aux3);
 }
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<
+        isa>::tanh_polynomial_approx_compute_vector_fwd(const TRegS &vmm_src) {
+
+    if (!utils::one_of(isa, sve_512)) return;
+
+    using namespace Xbyak_aarch64::util;
+
+    const int tanh_n_polynomials = 32;
+
+    // Register mapping
+    TRegS vmm_dst = vmm_aux1, vmm_src_shift = vmm_aux1, vmm_coeff = vmm_aux1,
+          vmm_pol = vmm_aux2, vmm_indices = vmm_aux3, vmm_tmp = vmm_aux3,
+          vmm_src_pos = vmm_aux4, vmm_sign = vmm_aux4;
+
+    const auto &mask = PReg(6); // avoid pred regs used in *conv_kernel*
+
+    // Helper function to gather polynomial coefficients
+    auto gather_coefficient = [&](TRegS vmm_coeff, int coeff_idx,
+                                      TRegS vmm_pol_idx) {
+        h->add_imm(h->X_TMP_1, x_table,
+                table_off(tanh_pol_table, coeff_idx * tanh_n_polynomials),
+                h->X_TMP_0);
+        h->ld1w(ZRegS(IDX(vmm_coeff)), p_all,
+                ptr(h->X_TMP_1, ZRegS(IDX(vmm_pol_idx)), SXTW));
+    };
+
+    // because tanh(x) = -tanh(-x), we extract sign to make x positive
+    // and reapply sign at the end
+    h->fabs(vmm_src_pos, p_all / T_z, vmm_src);
+
+    // Compute indices for the table lookup
+    h->sub(ZRegS(IDX(vmm_indices)), ZRegS(IDX(vmm_src_pos)),
+            ZRegS(IDX(table_val(tanh_idx_bias, z_tmp))));
+    h->and_(ZRegD(IDX(vmm_indices)), ZRegD(IDX(vmm_indices)),
+            ZRegD(IDX(table_val(tanh_idx_mask, z_tmp))));
+    h->lsr(ZRegD(IDX(vmm_indices)), ZRegD(IDX(vmm_indices)), 20);
+
+    // Argument reduction
+    h->and_(ZRegD(IDX(vmm_src_shift)), ZRegD(IDX(vmm_src_pos)),
+            ZRegD(IDX(table_val(tanh_idx_mask, z_tmp))));
+    h->fsub(vmm_src_pos, vmm_src_pos, ZRegS(IDX(vmm_src_shift)));
+
+    gather_coefficient(vmm_pol, 6, vmm_indices);
+    for (int deg = 5; deg >= 0; --deg) {
+        gather_coefficient(vmm_coeff, deg, vmm_indices);
+        h->fmad(vmm_pol, p_all / T_m, vmm_src_pos, vmm_coeff);
+    }
+
+    // Restore src_pos
+    h->fabs(vmm_src_pos, p_all / T_z, vmm_src);
+
+    // Now blend the results
+    // [saturation_ubound; +inf] : return +/- 1
+    table_val(one, vmm_dst);
+
+    // [linear_ubound; saturation_lbound] : return +/- P(x)
+    table_val(tanh_saturation_lbound, vmm_tmp);
+    h->fcmgt(PRegS(IDX(mask)), p_all / T_z, vmm_tmp, vmm_src_pos);
+    h->sel(vmm_dst, mask / T_m, vmm_pol, vmm_dst);
+
+    // [0; linear_ubound] : return x
+    table_val(tanh_linear_ubound, vmm_tmp);
+    h->fcmgt(PRegS(IDX(mask)), p_all / T_z, vmm_tmp, vmm_src_pos);
+    h->sel(vmm_dst, mask / T_m, vmm_src_pos, vmm_dst);
+
+    // Reapply sign and return
+    h->and_(ZRegD(IDX(vmm_sign)), ZRegD(IDX(vmm_src)),
+            ZRegD(IDX(table_val(sign_mask, z_tmp))));
+    h->eor(ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_dst)), ZRegD(IDX(vmm_sign)));
+}
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector_fwd(
         const TRegS &vmm_src) {
+
+    if (utils::one_of(isa, sve_512)) {
+        tanh_polynomial_approx_compute_vector_fwd(vmm_src);
+        return;
+    }
+
     // tanh(x) = x(1 + (-1/3)x^2) for |x| < tanh_range
     // tanh(x) = 1 - 2/(1 + exp(2 x)) for otherwise
@@ -918,10 +996,87 @@ void jit_uni_eltwise_injector_f32<isa>::log_compute_vector_fwd(
     }
     h->L(exitL);
 }
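A scalar sketch of the table-driven tanh approximation above, assuming the same layout (32 polynomials of degree 6, indexed by the leading bits of |x|); the names are hypothetical, the coefficient table is elided, and the saturation/linear blending is only noted in a comment:

    #include <cmath>
    #include <cstdint>
    #include <cstring>

    // Hypothetical coefficient table: 7 coefficients (degree 0..6) for each
    // of the 32 polynomials, mirroring tanh_pol_table's layout.
    extern const float tanh_pol[7][32];

    float tanh_poly_approx(float x) {
        float ax = std::fabs(x); // tanh(-x) = -tanh(x), so work on |x|
        uint32_t bits;
        std::memcpy(&bits, &ax, sizeof(bits));
        // Interval index from the leading bits of |x| (the kernel keeps a
        // byte offset instead, i.e. shifts by 20 rather than 22).
        uint32_t idx = ((bits - 0x39800000u) & 0xffc00000u) >> 22; // 0..31
        // Argument reduction: subtract the interval's left endpoint.
        uint32_t lo_bits = bits & 0xffc00000u;
        float lo;
        std::memcpy(&lo, &lo_bits, sizeof(lo));
        float r = ax - lo;
        // Degree-6 polynomial in Horner form, matching the fmad loop above.
        float p = tanh_pol[6][idx];
        for (int deg = 5; deg >= 0; --deg)
            p = p * r + tanh_pol[deg][idx];
        // (The kernel additionally blends: |x| above the saturation bound
        // returns 1, and |x| below the linear bound returns x itself.)
        return std::copysign(p, x);
    }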
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<
+        isa>::gelu_erf_minimax_approx_compute_vector_fwd(const TRegS &vmm_src) {
+    if (isa != sve_512) { // TODO: change this condition based on cpu id.
+        return;
+    }
+
+    // register mapping
+    TRegS vmm_pol = vmm_aux0;
+    TRegS vmm_src_pos = vmm_aux1;
+    TRegS vmm_indices = vmm_aux2;
+    TRegS vmm_tmp = vmm_aux3; // this is for immediate read after write
+
+    auto gather_coefficient
+            = [&](TRegS vmm_coeff, int coeff_idx, TRegS vmm_pol_idx) {
+                  // we actually have 25 polynomials but pad to avoid
+                  // unaligned accesses.
+                  int gelu_erf_n_polynomials = 32;
+                  h->add_imm(h->X_TMP_1, x_table,
+                          table_off(gelu_erf_minimax_pol,
+                                  coeff_idx * gelu_erf_n_polynomials),
+                          h->X_TMP_0);
+                  h->ld1w(ZRegS(IDX(vmm_coeff)), p_all / T_z,
+                          ptr(h->X_TMP_1, ZRegS(IDX(vmm_pol_idx)), SXTW));
+              };
+
+    // we use the erf function symmetry erf(-x) = -erf(x)
+    // So we make x positive, we will reapply the sign after erf evaluation
+    h->fabs(vmm_src_pos, p_all / T_z, vmm_src);
+
+    // Compute indices for table lookup
+    h->add(vmm_indices, vmm_src_pos,
+            ZRegS(IDX(table_val(gelu_erf_idx_bias, z_tmp, 0))));
+
+    // An arithmetic shift is needed to properly map denormals to
+    // their polynomial. we shift by 21 as we use 2 bits of mantissa
+    // for indexing.
+    h->asr(ZRegS(IDX(vmm_indices)), ZRegS(IDX(vmm_indices)), 21);
+
+    // Apply special rules
+    h->smax(vmm_indices, p_all / T_z,
+            ZRegS(IDX(table_val(gelu_erf_one, z_tmp))));
+    h->smin(vmm_indices, p_all / T_z,
+            ZRegS(IDX(table_val(gelu_erf_twenty_four, z_tmp))));
+
+    // We have to check
+    //     index = x_pos > rbound ? 23 : index;
+    // for erf to return -1/1 when we should.
+    h->fcmlt(p_mask.s, p_all / T_z, vmm_src_pos,
+            ZRegS(IDX(table_val(gelu_erf_rbound, z_tmp))));
+    h->sel(vmm_indices, p_mask, vmm_indices,
+            ZRegS(IDX(table_val(gelu_erf_twenty_three, z_tmp))));
+
+    // Adjusting indices
+    h->mul(ZRegS(IDX(vmm_indices)), sizeof(float));
+
+    // Evaluate the polynomial
+    gather_coefficient(vmm_pol, 5, vmm_indices);
+    for (int deg = 4; deg >= 0; --deg) {
+        gather_coefficient(vmm_tmp, deg, vmm_indices);
+        h->fmad(vmm_pol, p_all / T_z, vmm_src_pos, vmm_tmp);
+    }
+
+    // Set the sign of vmm_pol properly
+    h->mov(ZRegD(IDX(vmm_tmp)), ZRegD(IDX(vmm_src)));
+    h->and_(ZRegD(IDX(vmm_tmp)), ZRegD(IDX(vmm_tmp)),
+            ZRegD(IDX(table_val(sign_mask, z_tmp))));
+    h->eor(ZRegD(IDX(vmm_pol)), p_all / T_z, ZRegD(IDX(vmm_tmp)));
+
+    // Compute the final output
+    h->fadd(vmm_pol, vmm_pol, ZRegS(IDX(table_val(one, z_tmp))));
+    h->fmul(vmm_src, p_all / T_z, vmm_pol);
+    h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(half, z_tmp))));
+}
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::gelu_erf_compute_vector_fwd(
         const TRegS &vmm_src) {
+
+    if (isa == sve_512) { // TODO: consider performance improvement for lower ISA
+        gelu_erf_minimax_approx_compute_vector_fwd(vmm_src);
+        return;
+    }
     // Here we approximate erf(x) using the expression by
     // Abramowitz and Stegun from ``Handbook of Mathematical
     // Functions''
@@ -1657,9 +1812,248 @@ void jit_uni_eltwise_injector_f32<isa>::register_table_entries() {
             {bwd_mish_max_x_for_equation_f, {0x41b17217, true}}};
     // tanh(x) constants for four interval approximation
-    static const table_t tanh_consts {
-            {tanh_range, {0x3d4ccccd, true}},
+    // and for polynomial approximation
+    static const table_t tanh_consts {{tanh_range, {0x3d4ccccd, true}},
             {tanh_m1d3, {0xbeaaaaab, true}},
+            {tanh_idx_bias, {0x39800000, true}},
+            {tanh_idx_mask, {0xffc00000, true}},
+            {tanh_linear_ubound, {0x39ddb3d7, true}},
+            {tanh_saturation_lbound, {0x41102cb3, true}}};
+
+    // tanh(x) polynomial approximation
+    // For each coefficient, there are 32 entries
+    static const table_t tanh_polynomial_table {
+            // coefficients of degree 0
+            {tanh_pol_table,
{0x00000000, false}}, + {tanh_pol_table, {0x39bfffff, false}}, + {tanh_pol_table, {0x39ffffff, false}}, + {tanh_pol_table, {0x3a3ffffe, false}}, + {tanh_pol_table, {0x3a7ffffb, false}}, + {tanh_pol_table, {0x3abffff7, false}}, + {tanh_pol_table, {0x3affffeb, false}}, + {tanh_pol_table, {0x3b3fffdc, false}}, + {tanh_pol_table, {0x3b7fffab, false}}, + {tanh_pol_table, {0x3bbfff70, false}}, + {tanh_pol_table, {0x3bfffeab, false}}, + {tanh_pol_table, {0x3c3ffdc0, false}}, + {tanh_pol_table, {0x3c7ffaab, false}}, + {tanh_pol_table, {0x3cbff701, false}}, + {tanh_pol_table, {0x3cffeaad, false}}, + {tanh_pol_table, {0x3d3fdc08, false}}, + {tanh_pol_table, {0x3d7faacd, false}}, + {tanh_pol_table, {0x3dbf7081, false}}, + {tanh_pol_table, {0x3dfeacc9, false}}, + {tanh_pol_table, {0x3e3dc7fd, false}}, + {tanh_pol_table, {0x3e7acbf5, false}}, + {tanh_pol_table, {0x3eb77a9f, false}}, + {tanh_pol_table, {0x3eec9a9f, false}}, + {tanh_pol_table, {0x3f22991f, false}}, + {tanh_pol_table, {0x3f42f7d6, false}}, + {tanh_pol_table, {0x3f67b7cc, false}}, + {tanh_pol_table, {0x3f76ca83, false}}, + {tanh_pol_table, {0x3f7ebbe9, false}}, + {tanh_pol_table, {0x3f7fd40c, false}}, + {tanh_pol_table, {0x3f7fff32, false}}, + {tanh_pol_table, {0x3f7ffffc, false}}, + {tanh_pol_table, {0x3f800000, false}}, + // coefficients of degree 1 + {tanh_pol_table, {0x3f800000, false}}, + {tanh_pol_table, {0x3f800018, false}}, + {tanh_pol_table, {0x3f7fffe8, false}}, + {tanh_pol_table, {0x3f7fffda, false}}, + {tanh_pol_table, {0x3f7fffdc, false}}, + {tanh_pol_table, {0x3f7fffdc, false}}, + {tanh_pol_table, {0x3f7fffac, false}}, + {tanh_pol_table, {0x3f7fff70, false}}, + {tanh_pol_table, {0x3f7ffeec, false}}, + {tanh_pol_table, {0x3f7ffdc0, false}}, + {tanh_pol_table, {0x3f7ffbed, false}}, + {tanh_pol_table, {0x3f7ff704, false}}, + {tanh_pol_table, {0x3f7feff5, false}}, + {tanh_pol_table, {0x3f7fdbca, false}}, + {tanh_pol_table, {0x3f7fbfff, false}}, + {tanh_pol_table, {0x3f7f7041, false}}, + {tanh_pol_table, {0x3f7f009b, false}}, + {tanh_pol_table, {0x3f7dc36c, false}}, + {tanh_pol_table, {0x3f7c0aa8, false}}, + {tanh_pol_table, {0x3f7734b8, false}}, + {tanh_pol_table, {0x3f70a4de, false}}, + {tanh_pol_table, {0x3f5f1fd8, false}}, + {tanh_pol_table, {0x3f495493, false}}, + {tanh_pol_table, {0x3f18b9ec, false}}, + {tanh_pol_table, {0x3ed706cb, false}}, + {tanh_pol_table, {0x3e390b06, false}}, + {tanh_pol_table, {0x3d90b11f, false}}, + {tanh_pol_table, {0x3c21a053, false}}, + {tanh_pol_table, {0x3aaf7fdb, false}}, + {tanh_pol_table, {0x37ccc1a3, false}}, + {tanh_pol_table, {0x355c6733, false}}, + {tanh_pol_table, {0x00000000, false}}, + // coefficients of degree 2 + {tanh_pol_table, {0x00000000, false}}, + {tanh_pol_table, {0xbe4e0ff1, false}}, + {tanh_pol_table, {0x3d25b1b1, false}}, + {tanh_pol_table, {0x3d6b6dab, false}}, + {tanh_pol_table, {0x3c9fb1d5, false}}, + {tanh_pol_table, {0xbabff06f, false}}, + {tanh_pol_table, {0x3c07b3f6, false}}, + {tanh_pol_table, {0xbb3fc1bc, false}}, + {tanh_pol_table, {0x3a9f5921, false}}, + {tanh_pol_table, {0xbbbf06f2, false}}, + {tanh_pol_table, {0xbbb0f402, false}}, + {tanh_pol_table, {0xbc47db9e, false}}, + {tanh_pol_table, {0xbc73d5e7, false}}, + {tanh_pol_table, {0xbca25bda, false}}, + {tanh_pol_table, {0xbcfca780, false}}, + {tanh_pol_table, {0xbd40e07c, false}}, + {tanh_pol_table, {0xbd7dab03, false}}, + {tanh_pol_table, {0xbdbe4a0f, false}}, + {tanh_pol_table, {0xbdfb14a5, false}}, + {tanh_pol_table, {0xbe36cc8d, false}}, + {tanh_pol_table, {0xbe6bd102, false}}, + {tanh_pol_table, 
{0xbe9fe7c5, false}}, + {tanh_pol_table, {0xbeba0f10, false}}, + {tanh_pol_table, {0xbec206a8, false}}, + {tanh_pol_table, {0xbea3c388, false}}, + {tanh_pol_table, {0xbe277d62, false}}, + {tanh_pol_table, {0xbd8b7960, false}}, + {tanh_pol_table, {0xbc209f49, false}}, + {tanh_pol_table, {0xbaad44ca, false}}, + {tanh_pol_table, {0xb7c6eeac, false}}, + {tanh_pol_table, {0xb663aa41, false}}, + {tanh_pol_table, {0x00000000, false}}, + // coefficients of degree 3 + {tanh_pol_table, {0x00000000, false}}, + {tanh_pol_table, {0x45b3ae96, false}}, + {tanh_pol_table, {0xc414eb20, false}}, + {tanh_pol_table, {0xc450e02e, false}}, + {tanh_pol_table, {0xc3152b4e, false}}, + {tanh_pol_table, {0xbead2f56, false}}, + {tanh_pol_table, {0xc2162e02, false}}, + {tanh_pol_table, {0xbeb4bd5a, false}}, + {tanh_pol_table, {0xc11a59a4, false}}, + {tanh_pol_table, {0xbed2f507, false}}, + {tanh_pol_table, {0xc020d32c, false}}, + {tanh_pol_table, {0x3dd0f506, false}}, + {tanh_pol_table, {0xbf2a75e2, false}}, + {tanh_pol_table, {0xbff950e3, false}}, + {tanh_pol_table, {0xbed47334, false}}, + {tanh_pol_table, {0xbe809b8c, false}}, + {tanh_pol_table, {0xbeb64532, false}}, + {tanh_pol_table, {0xbe961a5b, false}}, + {tanh_pol_table, {0xbe9b63ac, false}}, + {tanh_pol_table, {0xbea0d4b2, false}}, + {tanh_pol_table, {0xbe828a77, false}}, + {tanh_pol_table, {0xbe378612, false}}, + {tanh_pol_table, {0xbdc20908, false}}, + {tanh_pol_table, {0x3d2d3957, false}}, + {tanh_pol_table, {0x3dd46e89, false}}, + {tanh_pol_table, {0x3db3f629, false}}, + {tanh_pol_table, {0x3d2c5e7b, false}}, + {tanh_pol_table, {0x3bd20403, false}}, + {tanh_pol_table, {0x3a59dfae, false}}, + {tanh_pol_table, {0x3770af45, false}}, + {tanh_pol_table, {0x372cc014, false}}, + {tanh_pol_table, {0x00000000, false}}, + // coefficients of degree 4 + {tanh_pol_table, {0x00000000, false}}, + {tanh_pol_table, {0xcc981a1b, false}}, + {tanh_pol_table, {0x4a7edd3d, false}}, + {tanh_pol_table, {0x4ab1007c, false}}, + {tanh_pol_table, {0x48fedd9c, false}}, + {tanh_pol_table, {0x41a557b5, false}}, + {tanh_pol_table, {0x477ee32a, false}}, + {tanh_pol_table, {0x422557f5, false}}, + {tanh_pol_table, {0x45ff3ce4, false}}, + {tanh_pol_table, {0x42a55641, false}}, + {tanh_pol_table, {0x446e0867, false}}, + {tanh_pol_table, {0xc33dc19a, false}}, + {tanh_pol_table, {0x42915214, false}}, + {tanh_pol_table, {0x43af4fad, false}}, + {tanh_pol_table, {0x4110fe88, false}}, + {tanh_pol_table, {0xc1099b75, false}}, + {tanh_pol_table, {0x3fc8a8dc, false}}, + {tanh_pol_table, {0xbfbeaef5, false}}, + {tanh_pol_table, {0xbe365aad, false}}, + {tanh_pol_table, {0x3f4d9652, false}}, + {tanh_pol_table, {0x3ddfa08f, false}}, + {tanh_pol_table, {0x3e34e9b8, false}}, + {tanh_pol_table, {0x3e2d07a6, false}}, + {tanh_pol_table, {0x3dc63567, false}}, + {tanh_pol_table, {0x3cdaeb78, false}}, + {tanh_pol_table, {0xbcd17537, false}}, + {tanh_pol_table, {0xbc92829c, false}}, + {tanh_pol_table, {0xbb43ab99, false}}, + {tanh_pol_table, {0xb9b471dd, false}}, + {tanh_pol_table, {0xb6baad5a, false}}, + {tanh_pol_table, {0xb78bafc7, false}}, + {tanh_pol_table, {0x00000000, false}}, + // coefficients of degree 5 + {tanh_pol_table, {0x00000000, false}}, + {tanh_pol_table, {0x52f688d5, false}}, + {tanh_pol_table, {0xd0505c72, false}}, + {tanh_pol_table, {0xd08f98e3, false}}, + {tanh_pol_table, {0xce505cc9, false}}, + {tanh_pol_table, {0xc7162b8a, false}}, + {tanh_pol_table, {0xcc5061d6, false}}, + {tanh_pol_table, {0xc7162bdf, false}}, + {tanh_pol_table, {0xca50b37f, false}}, + {tanh_pol_table, {0xc7162a3a, 
false}}, + {tanh_pol_table, {0xc8422086, false}}, + {tanh_pol_table, {0x471a714e, false}}, + {tanh_pol_table, {0xc5ece1f1, false}}, + {tanh_pol_table, {0xc70e3d90, false}}, + {tanh_pol_table, {0xc3eba94a, false}}, + {tanh_pol_table, {0x43e0c424, false}}, + {tanh_pol_table, {0xc21f4552, false}}, + {tanh_pol_table, {0x42217cc8, false}}, + {tanh_pol_table, {0x405e7dc4, false}}, + {tanh_pol_table, {0xc10dd401, false}}, + {tanh_pol_table, {0x3e96b602, false}}, + {tanh_pol_table, {0xbd1a6d2f, false}}, + {tanh_pol_table, {0xbd393883, false}}, + {tanh_pol_table, {0xbd674682, false}}, + {tanh_pol_table, {0xbd310016, false}}, + {tanh_pol_table, {0xb961e269, false}}, + {tanh_pol_table, {0x3ba32495, false}}, + {tanh_pol_table, {0x3a7680d5, false}}, + {tanh_pol_table, {0x38b3173c, false}}, + {tanh_pol_table, {0x35a9deea, false}}, + {tanh_pol_table, {0x375c3f2a, false}}, + {tanh_pol_table, {0x00000000, false}}, + // coefficients of degree 6 + {tanh_pol_table, {0x00000000, false}}, + {tanh_pol_table, {0xd8995ed1, false}}, + {tanh_pol_table, {0x558285ea, false}}, + {tanh_pol_table, {0x55b2cd69, false}}, + {tanh_pol_table, {0x53028625, false}}, + {tanh_pol_table, {0x4bc9991f, false}}, + {tanh_pol_table, {0x5082898a, false}}, + {tanh_pol_table, {0x4b4999b3, false}}, + {tanh_pol_table, {0x4e02c07c, false}}, + {tanh_pol_table, {0x4ac99764, false}}, + {tanh_pol_table, {0x4b72c822, false}}, + {tanh_pol_table, {0xca40c0e1, false}}, + {tanh_pol_table, {0x489413e4, false}}, + {tanh_pol_table, {0x49b12224, false}}, + {tanh_pol_table, {0x46134c4e, false}}, + {tanh_pol_table, {0xc60c2d57, false}}, + {tanh_pol_table, {0x43c83910, false}}, + {tanh_pol_table, {0xc3c872d1, false}}, + {tanh_pol_table, {0xc186bc9e, false}}, + {tanh_pol_table, {0x42325bc3, false}}, + {tanh_pol_table, {0xbf2ffa4a, false}}, + {tanh_pol_table, {0x3d9a203c, false}}, + {tanh_pol_table, {0xbc545a43, false}}, + {tanh_pol_table, {0xbae08fee, false}}, + {tanh_pol_table, {0x3c80225d, false}}, + {tanh_pol_table, {0x3b1fd1df, false}}, + {tanh_pol_table, {0xba36b9d1, false}}, + {tanh_pol_table, {0xb91de544, false}}, + {tanh_pol_table, {0xb71f100f, false}}, + {tanh_pol_table, {0xb408e2ed, false}}, + {tanh_pol_table, {0xb685fec8, false}}, + {tanh_pol_table, {0x00000000, false}}, }; // soft_relu(x) constants @@ -1703,6 +2097,215 @@ void jit_uni_eltwise_injector_f32::register_table_entries() { {gelu_erf_pol, {0xbfba00e3, true}}, // p4 = -1.453152027f {gelu_erf_pol, {0x3f87dc22, true}}, // p5 = 1.061405429f }; + // gelu_erf(x) constants for direct erf approximation (formula defined) + static const table_t gelu_erf_minimax_consts { + {gelu_erf_idx_bias, {0xc21fffff, true}}, + {gelu_erf_rbound, {0x40b15cee, true}}, + {gelu_erf_one, {0x00000001, true}}, + {gelu_erf_twenty_three, {0x00000017, true}}, + {gelu_erf_twenty_four, {0x00000018, true}}, + }; + // gelu_erf(x) minimax polynomials for piecewise approximaxtion + static const table_t gelu_erf_minimax_polynomial { + // coefficients of degree 0 + {gelu_erf_minimax_pol, {0xa6f2cb94, false}}, // -0x1.e59728p-50 + {gelu_erf_minimax_pol, {0x32827792, false}}, // 0x1.04ef24p-26 + {gelu_erf_minimax_pol, {0x3381cc0c, false}}, // 0x1.039818p-24 + {gelu_erf_minimax_pol, {0x34523d4a, false}}, // 0x1.a47a94p-23 + {gelu_erf_minimax_pol, {0x351ac44d, false}}, // 0x1.35889ap-21 + {gelu_erf_minimax_pol, {0x35f36d88, false}}, // 0x1.e6db1p-20 + {gelu_erf_minimax_pol, {0x36ee8229, false}}, // 0x1.dd0452p-18 + {gelu_erf_minimax_pol, {0x37b8a3bb, false}}, // 0x1.714776p-16 + {gelu_erf_minimax_pol, {0x3867a213, false}}, // 
0x1.cf4426p-15 + {gelu_erf_minimax_pol, {0x3940033b, false}}, // 0x1.800676p-13 + {gelu_erf_minimax_pol, {0x3a2a5a1d, false}}, // 0x1.54b43ap-11 + {gelu_erf_minimax_pol, {0x3ae35863, false}}, // 0x1.c6b0c6p-10 + {gelu_erf_minimax_pol, {0x3b7828f2, false}}, // 0x1.f051e4p-9 + {gelu_erf_minimax_pol, {0x3c08b14b, false}}, // 0x1.116296p-7 + {gelu_erf_minimax_pol, {0x3c515ed3, false}}, // 0x1.a2bda6p-7 + {gelu_erf_minimax_pol, {0xbb503236, false}}, // -0x1.a0646cp-9 + {gelu_erf_minimax_pol, {0xbd8d8e5e, false}}, // -0x1.1b1cbcp-4 + {gelu_erf_minimax_pol, {0xbe8abcd9, false}}, // -0x1.1579b2p-2 + {gelu_erf_minimax_pol, {0xbf0c19a2, false}}, // -0x1.183344p-1 + {gelu_erf_minimax_pol, {0xbeccb328, false}}, // -0x1.99665p-2 + {gelu_erf_minimax_pol, {0x3e176ced, false}}, // 0x1.2ed9dap-3 + {gelu_erf_minimax_pol, {0x3f470d99, false}}, // 0x1.8e1b32p-1 + {gelu_erf_minimax_pol, {0x3f7abb28, false}}, // 0x1.f5765p-1 + {gelu_erf_minimax_pol, {0x3f800000, false}}, // 0x1p0 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + // coefficients of degree 1 + {gelu_erf_minimax_pol, {0x3f4c422a, false}}, // 0x1.988454p-1 + {gelu_erf_minimax_pol, {0x3f4c421f, false}}, // 0x1.98843ep-1 + {gelu_erf_minimax_pol, {0x3f4c4207, false}}, // 0x1.98840ep-1 + {gelu_erf_minimax_pol, {0x3f4c41cb, false}}, // 0x1.988396p-1 + {gelu_erf_minimax_pol, {0x3f4c413b, false}}, // 0x1.988276p-1 + {gelu_erf_minimax_pol, {0x3f4c3fad, false}}, // 0x1.987f5ap-1 + {gelu_erf_minimax_pol, {0x3f4c3a2f, false}}, // 0x1.98745ep-1 + {gelu_erf_minimax_pol, {0x3f4c2d40, false}}, // 0x1.985a8p-1 + {gelu_erf_minimax_pol, {0x3f4c146a, false}}, // 0x1.9828d4p-1 + {gelu_erf_minimax_pol, {0x3f4bc341, false}}, // 0x1.978682p-1 + {gelu_erf_minimax_pol, {0x3f4ad08c, false}}, // 0x1.95a118p-1 + {gelu_erf_minimax_pol, {0x3f48f8cf, false}}, // 0x1.91f19ep-1 + {gelu_erf_minimax_pol, {0x3f45fac7, false}}, // 0x1.8bf58ep-1 + {gelu_erf_minimax_pol, {0x3f404e07, false}}, // 0x1.809c0ep-1 + {gelu_erf_minimax_pol, {0x3f3b980f, false}}, // 0x1.77301ep-1 + {gelu_erf_minimax_pol, {0x3f48dff3, false}}, // 0x1.91bfe6p-1 + {gelu_erf_minimax_pol, {0x3f78b21b, false}}, // 0x1.f16436p-1 + {gelu_erf_minimax_pol, {0x3fbb0704, false}}, // 0x1.760e08p0 + {gelu_erf_minimax_pol, {0x40019c32, false}}, // 0x1.033864p1 + {gelu_erf_minimax_pol, {0x3fe536d6, false}}, // 0x1.ca6dacp0 + {gelu_erf_minimax_pol, {0x3f81331e, false}}, // 0x1.02663cp0 + {gelu_erf_minimax_pol, {0x3e6c8684, false}}, // 0x1.d90d08p-3 + {gelu_erf_minimax_pol, {0x3c98f936, false}}, // 0x1.31f26cp-6 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 + {gelu_erf_minimax_pol, {0x3f800000, false}}, // 0x1p0 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + // coefficients of degree 2 + {gelu_erf_minimax_pol, {0xb62173f4, false}}, // -0x1.42e7e8p-19 + {gelu_erf_minimax_pol, 
{0x3735e4cf, false}}, // 0x1.6bc99ep-17 + {gelu_erf_minimax_pol, {0x37f2ff89, false}}, // 0x1.e5ff12p-16 + {gelu_erf_minimax_pol, {0x388c23be, false}}, // 0x1.18477cp-14 + {gelu_erf_minimax_pol, {0x3917535c, false}}, // 0x1.2ea6b8p-13 + {gelu_erf_minimax_pol, {0x39ab2ab0, false}}, // 0x1.56556p-12 + {gelu_erf_minimax_pol, {0x3a60fadb, false}}, // 0x1.c1f5b6p-11 + {gelu_erf_minimax_pol, {0x3af9b960, false}}, // 0x1.f372cp-10 + {gelu_erf_minimax_pol, {0x3b6e5491, false}}, // 0x1.dca922p-9 + {gelu_erf_minimax_pol, {0x3c0a4ec5, false}}, // 0x1.149d8ap-7 + {gelu_erf_minimax_pol, {0x3ca5aa8c, false}}, // 0x1.4b5518p-6 + {gelu_erf_minimax_pol, {0x3d2138d9, false}}, // 0x1.4271b2p-5 + {gelu_erf_minimax_pol, {0x3d8737d4, false}}, // 0x1.0e6fa8p-4 + {gelu_erf_minimax_pol, {0x3ddfb660, false}}, // 0x1.bf6ccp-4 + {gelu_erf_minimax_pol, {0x3e0f27ab, false}}, // 0x1.1e4f56p-3 + {gelu_erf_minimax_pol, {0x3d94004b, false}}, // 0x1.280096p-4 + {gelu_erf_minimax_pol, {0xbe0efdeb, false}}, // -0x1.1dfbd6p-3 + {gelu_erf_minimax_pol, {0xbf1d96c3, false}}, // -0x1.3b2d86p-1 + {gelu_erf_minimax_pol, {0xbf89db58, false}}, // -0x1.13b6bp0 + {gelu_erf_minimax_pol, {0xbf6d9897, false}}, // -0x1.db312ep-1 + {gelu_erf_minimax_pol, {0xbef69fb8, false}}, // -0x1.ed3f7p-2 + {gelu_erf_minimax_pol, {0xbdc4f8a8, false}}, // -0x1.89f15p-4 + {gelu_erf_minimax_pol, {0xbbde6422, false}}, // -0x1.bcc844p-8 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + // coefficients of degree 3 + {gelu_erf_minimax_pol, {0xbe081a19, false}}, // -0x1.103432p-3 + {gelu_erf_minimax_pol, {0xbe084570, false}}, // -0x1.108aep-3 + {gelu_erf_minimax_pol, {0xbe08639b, false}}, // -0x1.10c736p-3 + {gelu_erf_minimax_pol, {0xbe089837, false}}, // -0x1.11306ep-3 + {gelu_erf_minimax_pol, {0xbe08f409, false}}, // -0x1.11e812p-3 + {gelu_erf_minimax_pol, {0xbe09ab95, false}}, // -0x1.13572ap-3 + {gelu_erf_minimax_pol, {0xbe0b66d0, false}}, // -0x1.16cdap-3 + {gelu_erf_minimax_pol, {0xbe0e400a, false}}, // -0x1.1c8014p-3 + {gelu_erf_minimax_pol, {0xbe124df8, false}}, // -0x1.249bfp-3 + {gelu_erf_minimax_pol, {0xbe1bde02, false}}, // -0x1.37bc04p-3 + {gelu_erf_minimax_pol, {0xbe2f19c9, false}}, // -0x1.5e3392p-3 + {gelu_erf_minimax_pol, {0xbe4931bf, false}}, // -0x1.92637ep-3 + {gelu_erf_minimax_pol, {0xbe685fbc, false}}, // -0x1.d0bf78p-3 + {gelu_erf_minimax_pol, {0xbe89c95f, false}}, // -0x1.1392bep-2 + {gelu_erf_minimax_pol, {0xbe96cbca, false}}, // -0x1.2d9794p-2 + {gelu_erf_minimax_pol, {0xbe8044aa, false}}, // -0x1.008954p-2 + {gelu_erf_minimax_pol, {0xbe0550f2, false}}, // -0x1.0aa1e4p-3 + {gelu_erf_minimax_pol, {0x3dcfd6a1, false}}, // 0x1.9fad42p-4 + {gelu_erf_minimax_pol, {0x3e94c826, false}}, // 0x1.29904cp-2 + {gelu_erf_minimax_pol, {0x3e79345f, false}}, // 0x1.f268bep-3 + {gelu_erf_minimax_pol, {0x3decec91, false}}, // 0x1.d9d922p-4 + {gelu_erf_minimax_pol, {0x3ca46568, false}}, // 0x1.48cadp-6 + {gelu_erf_minimax_pol, {0x3aa1e00a, false}}, // 0x1.43c014p-10 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 + {gelu_erf_minimax_pol, {0x00000000, 
false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + // coefficients of degree 4 + {gelu_erf_minimax_pol, {0xba3d61db, false}}, // -0x1.7ac3b6p-11 + {gelu_erf_minimax_pol, {0x39f097a3, false}}, // 0x1.e12f46p-12 + {gelu_erf_minimax_pol, {0x3a5845dc, false}}, // 0x1.b08bb8p-11 + {gelu_erf_minimax_pol, {0x3ab1fa35, false}}, // 0x1.63f46ap-10 + {gelu_erf_minimax_pol, {0x3b0cefb8, false}}, // 0x1.19df7p-9 + {gelu_erf_minimax_pol, {0x3b653ab6, false}}, // 0x1.ca756cp-9 + {gelu_erf_minimax_pol, {0x3bcae527, false}}, // 0x1.95ca4ep-8 + {gelu_erf_minimax_pol, {0x3c221712, false}}, // 0x1.442e24p-7 + {gelu_erf_minimax_pol, {0x3c6c5840, false}}, // 0x1.d8b08p-7 + {gelu_erf_minimax_pol, {0x3cc0a703, false}}, // 0x1.814e06p-6 + {gelu_erf_minimax_pol, {0x3d1dcc19, false}}, // 0x1.3b9832p-5 + {gelu_erf_minimax_pol, {0x3d63656d, false}}, // 0x1.c6cadap-5 + {gelu_erf_minimax_pol, {0x3d955907, false}}, // 0x1.2ab20ep-4 + {gelu_erf_minimax_pol, {0x3dbf9910, false}}, // 0x1.7f322p-4 + {gelu_erf_minimax_pol, {0x3dd53f69, false}}, // 0x1.aa7ed2p-4 + {gelu_erf_minimax_pol, {0x3db7dcef, false}}, // 0x1.6fb9dep-4 + {gelu_erf_minimax_pol, {0x3d639ebe, false}}, // 0x1.c73d7cp-5 + {gelu_erf_minimax_pol, {0xba6ede48, false}}, // -0x1.ddbc9p-11 + {gelu_erf_minimax_pol, {0xbd22be69, false}}, // -0x1.457cd2p-5 + {gelu_erf_minimax_pol, {0xbd041cf1, false}}, // -0x1.0839e2p-5 + {gelu_erf_minimax_pol, {0xbc64f5ab, false}}, // -0x1.c9eb56p-7 + {gelu_erf_minimax_pol, {0xbb097a32, false}}, // -0x1.12f464p-9 + {gelu_erf_minimax_pol, {0xb8ebf380, false}}, // -0x1.d7e7p-14 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + // coefficients of degree 5 + {gelu_erf_minimax_pol, {0x3cb7d80c, false}}, // 0x1.6fb018p-6 + {gelu_erf_minimax_pol, {0x3c9b6050, false}}, // 0x1.36c0ap-6 + {gelu_erf_minimax_pol, {0x3c978d11, false}}, // 0x1.2f1a22p-6 + {gelu_erf_minimax_pol, {0x3c92e850, false}}, // 0x1.25d0ap-6 + {gelu_erf_minimax_pol, {0x3c8d058b, false}}, // 0x1.1a0b16p-6 + {gelu_erf_minimax_pol, {0x3c848454, false}}, // 0x1.0908a8p-6 + {gelu_erf_minimax_pol, {0x3c6cd623, false}}, // 0x1.d9ac46p-7 + {gelu_erf_minimax_pol, {0x3c4c824b, false}}, // 0x1.990496p-7 + {gelu_erf_minimax_pol, {0x3c2a7935, false}}, // 0x1.54f26ap-7 + {gelu_erf_minimax_pol, {0x3be0b390, false}}, // 0x1.c1672p-8 + {gelu_erf_minimax_pol, {0x3b0651ac, false}}, // 0x1.0ca358p-9 + {gelu_erf_minimax_pol, {0xbb232f53, false}}, // -0x1.465ea6p-9 + {gelu_erf_minimax_pol, {0xbbd42fa0, false}}, // -0x1.a85f4p-8 + {gelu_erf_minimax_pol, {0xbc2c5366, false}}, // -0x1.58a6ccp-7 + {gelu_erf_minimax_pol, {0xbc492c9e, false}}, // -0x1.92593cp-7 + {gelu_erf_minimax_pol, {0xbc2a7aa6, false}}, // -0x1.54f54cp-7 + {gelu_erf_minimax_pol, {0xbbd55d04, false}}, // -0x1.aaba08p-8 + {gelu_erf_minimax_pol, {0xba823a76, false}}, // 
-0x1.0474ecp-10 + {gelu_erf_minimax_pol, {0x3b102aa8, false}}, // 0x1.20555p-9 + {gelu_erf_minimax_pol, {0x3ae25a7e, false}}, // 0x1.c4b4fcp-10 + {gelu_erf_minimax_pol, {0x3a31f792, false}}, // 0x1.63ef24p-11 + {gelu_erf_minimax_pol, {0x38b84375, false}}, // 0x1.7086eap-14 + {gelu_erf_minimax_pol, {0x3689bb5a, false}}, // 0x1.1376b4p-18 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + {gelu_erf_minimax_pol, {0x00000000, false}}, // 0 padd + }; // This object takes care about which constants and polynomials to include. struct need_t { @@ -1775,12 +2378,14 @@ void jit_uni_eltwise_injector_f32::register_table_entries() { if (need.exp()) push_entries_of(exp_consts2); if (need.mish()) push_entries_of(mish_consts); if (need.tanh()) push_entries_of(tanh_consts); + if (need.tanh()) push_entries_of(tanh_polynomial_table); if (need.soft_relu()) push_entries_of(soft_relu_consts); if (need.soft_relu()) push_entries_of(soft_relu_polynomial); if (need.gelu_tanh()) push_entries_of(gelu_tanh_consts); if (need.gelu_erf()) push_entries_of(gelu_erf_consts); if (need.gelu_erf()) push_entries_of(gelu_erf_polynomial); - + if (need.gelu_erf()) push_entries_of(gelu_erf_minimax_consts); + if (need.gelu_erf()) push_entries_of(gelu_erf_minimax_polynomial); // Now that we registered the entries, we set the offsets. No // entries should be registered after this point. This allows to // expect the same order when injecting the table entries in diff --git a/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp b/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp index 7301d99d567..355f877ccb2 100644 --- a/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp +++ b/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp @@ -1,6 +1,6 @@ /******************************************************************************* * Copyright 2019-2023 Intel Corporation -* Copyright 2021-2023 FUJITSU LIMITED +* Copyright 2021-2024 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
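Each degree row of the gelu_erf_minimax_polynomial table above holds 32 entries: 25 coefficients plus seven zero "padd" slots that keep the per-degree stride vector-friendly. Per the header comments below, the kernel computes a table index from the input bits using gelu_erf_idx_bias, clamps it with the gelu_erf_one/twenty_three/twenty_four integer constants, and evaluates the degree-0..5 polynomial selected by that index. The following is only a minimal scalar sketch of the evaluation step: eval_minimax_poly is a hypothetical name, idx is assumed to be already derived, coef[d][i] is assumed to hold the table's bit patterns reinterpreted as floats, and the index derivation, the clamp at gelu_erf_rbound, and sign handling are all omitted.

    // Hypothetical helper, not part of the patch: Horner evaluation of the
    // per-interval minimax polynomial (degrees 5 down to 0).
    static float eval_minimax_poly(float t, int idx, const float coef[6][32]) {
        float p = coef[5][idx]; // highest-degree coefficient first
        for (int d = 4; d >= 0; --d)
            p = p * t + coef[d][idx]; // Horner step: p = p * t + c_d
        return p;
    }

The degree-1 row starts at 0x3f4c422a (about 0.797884, i.e. sqrt(2/pi), the slope of erf(t/sqrt(2)) at zero), which suggests the table approximates erf(x/sqrt(2)) directly; gelu_erf(x) = 0.5 * x * (1 + erf(x/sqrt(2))) then costs one extra multiply-add.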
@@ -215,6 +215,7 @@ struct jit_uni_eltwise_injector_f32 { void relu_zero_ns_compute_vector_fwd(const TRegS &vmm_src); void elu_compute_vector_fwd(const TRegS &vmm_src); void tanh_compute_vector_fwd(const TRegS &vmm_src); + void tanh_polynomial_approx_compute_vector_fwd(const TRegS &vmm_src); void square_compute_vector_fwd(const TRegS &vmm_src); void abs_compute_vector_fwd(const TRegS &vmm_src); void sqrt_compute_vector_fwd(const TRegS &vmm_src); @@ -277,12 +278,23 @@ struct jit_uni_eltwise_injector_f32 { bwd_mish_max_x_for_equation_f, tanh_range, // tanh(x) = x - x^3/3 for |x| < tanh_range tanh_m1d3, // -1/3 + tanh_idx_bias, // bias applied during index computation + tanh_idx_mask, // mask applied to extract index + tanh_linear_ubound, // arg below which tanh(x) = x + tanh_saturation_lbound, // arg after which tanh(x) = 1.f + tanh_pol_table, // table of polynomial coefficients soft_relu_one_twenty_six, // 126.f soft_relu_mantissa_sign_mask, // mask for mantissa bits and sign soft_relu_pol, // see correspondent table for float values gelu_tanh_fitting_const, // 0.044715f gelu_tanh_fitting_const_times_three, // 0.134145f gelu_tanh_sqrt_two_over_pi, // sqrtf(2.f/pi) = 0.797884f + gelu_erf_idx_bias, // bias applied to compute table index + gelu_erf_rbound, // upper bound at which we clamp erf at 1 + gelu_erf_one, // just the integer value 1, used for index clamping + gelu_erf_twenty_three, // just the integer value 23, used for index clamping + gelu_erf_twenty_four, // just the integer value 24, used for index clamping + gelu_erf_minimax_pol, // see correspondent table for float values gelu_erf_approx_const, // 0.3275911f - implementation based for approx gelu_erf_one_over_sqrt_two, // 1.f / sqrtf(2.f) gelu_erf_one_over_sqrt_pi, // 1.f / sqrtf(pi) = 0.564190f diff --git a/src/cpu/aarch64/jit_brdgmm_dw_conv.cpp b/src/cpu/aarch64/jit_brdgmm_dw_conv.cpp index 24e018aef02..226864baad2 100644 --- a/src/cpu/aarch64/jit_brdgmm_dw_conv.cpp +++ b/src/cpu/aarch64/jit_brdgmm_dw_conv.cpp @@ -108,7 +108,7 @@ status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) { // const auto isa = sve_512; auto skip_mask = skip_mask_t::post_ops; - if (is_int8) skip_mask |= skip_mask_t::scales_runtime; + if (is_int8) skip_mask |= skip_mask_t::scales; bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct) && one_of(true, is_f32, is_int8) && (isa != isa_undef) @@ -200,7 +200,7 @@ status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) { const auto &wei_scales = attr_.scales_.get(DNNL_ARG_WEIGHTS); jcp.with_scale = !src_scales.has_default_values() || !wei_scales.has_default_values(); - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; const bool scales_ok = attr_scales_ok({DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}); diff --git a/src/cpu/aarch64/jit_brdgmm_dw_conv.hpp b/src/cpu/aarch64/jit_brdgmm_dw_conv.hpp index 830c9e56bfd..61d6a726fcf 100644 --- a/src/cpu/aarch64/jit_brdgmm_dw_conv.hpp +++ b/src/cpu/aarch64/jit_brdgmm_dw_conv.hpp @@ -34,15 +34,13 @@ namespace aarch64 { template struct brdgmm_dw_convolution_fwd_t : public primitive_t { struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brdgmm_dw:", jcp_.isa, ""), brdgmm_dw_convolution_fwd_t); status_t 
init(engine_t *engine); - jit_brdgmm_conv_conf_t jcp_; + jit_brdgmm_conv_conf_t jcp_ = utils::zero(); std::vector bcps_; std::vector batches_; std::vector bs_; diff --git a/src/cpu/aarch64/jit_brgemm_1x1_conv.cpp b/src/cpu/aarch64/jit_brgemm_1x1_conv.cpp index d9e8e49d3d0..808b6685b19 100644 --- a/src/cpu/aarch64/jit_brgemm_1x1_conv.cpp +++ b/src/cpu/aarch64/jit_brgemm_1x1_conv.cpp @@ -54,8 +54,8 @@ status_t brgemm_1x1_convolution_fwd_t::pd_t::init(engine_t *engine) { using skip_mask_t = primitive_attr_t::skip_mask_t; auto skip_mask = skip_mask_t::post_ops | skip_mask_t::sum_dt - | skip_mask_t::zero_points_runtime; - if (one_of(src_type, u8, s8)) skip_mask |= skip_mask_t::scales_runtime; + | skip_mask_t::zero_points; + if (one_of(src_type, u8, s8)) skip_mask |= skip_mask_t::scales; bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct) && expect_data_types(src_type, wei_type, data_type::undef, dst_type, @@ -115,6 +115,7 @@ status_t brgemm_1x1_convolution_fwd_t::pd_t::init(engine_t *engine) { brg.with_weights_scale_adjust = jcp_.scale_adjust_factor != 1.0f; CHECK(brgemm_desc_set_postops( &brg, attr(), &dst_md_, LDD, jcp_.bia_dt)); + CHECK(brgemm_desc_finalize(&brg)); brgs_->insert(brg_idx, brg); } diff --git a/src/cpu/aarch64/jit_brgemm_1x1_conv.hpp b/src/cpu/aarch64/jit_brgemm_1x1_conv.hpp index 7843d14d7a0..20e698c4c61 100644 --- a/src/cpu/aarch64/jit_brgemm_1x1_conv.hpp +++ b/src/cpu/aarch64/jit_brgemm_1x1_conv.hpp @@ -43,11 +43,7 @@ namespace aarch64 { template struct brgemm_1x1_convolution_fwd_t : public primitive_t { struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) - , with_sum(false) - , sum_scale(0) {} + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brgconv_1x1:", isa, ""), brgemm_1x1_convolution_fwd_t); @@ -55,13 +51,13 @@ struct brgemm_1x1_convolution_fwd_t : public primitive_t { status_t init(engine_t *engine); std::shared_ptr brgs_; - bool with_sum; - float sum_scale; + bool with_sum = false; + float sum_scale = 0.f; bool need_postwork; int ic_chunks; - jit_brgemm_conv_conf_t jcp_; + jit_brgemm_conv_conf_t jcp_ = utils::zero(); protected: bool arg_scales_ok() const { @@ -70,12 +66,20 @@ struct brgemm_1x1_convolution_fwd_t : public primitive_t { return attr_scales_ok(supported_args); } bool zero_points_ok() const { - // Only common zero points are supported -> mask should only be 0 - int mask_src = 0, mask_dst = 0; - attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src); - attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst); - return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS) - && mask_src == 0 && mask_dst == 0; + const auto &zp = attr()->zero_points_; + + if (!zp.has_default_values(DNNL_ARG_SRC)) { + int mask_src = zp.get_mask(DNNL_ARG_SRC); + const bool ok = mask_src == 0; + if (!ok) return false; + } + if (!zp.has_default_values(DNNL_ARG_DST)) { + int mask_dst = zp.get_mask(DNNL_ARG_DST); + const bool ok = mask_dst == 0; + if (!ok) return false; + } + + return zp.has_default_values(DNNL_ARG_WEIGHTS); } }; diff --git a/src/cpu/aarch64/jit_brgemm_conv.cpp b/src/cpu/aarch64/jit_brgemm_conv.cpp index c649e0cb690..36b126c9cd3 100644 --- a/src/cpu/aarch64/jit_brgemm_conv.cpp +++ b/src/cpu/aarch64/jit_brgemm_conv.cpp @@ -1,6 +1,6 @@ /******************************************************************************* * Copyright 2021-2023 
Intel Corporation -* Copyright 2024 FUJITSU LIMITED +* Copyright 2024-2025 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,8 +43,8 @@ using namespace jit_uni_brgemm_conv_comp_pad_kernel; #define ndims_pick(v5, v4, v3) \ ((ndims == 5) ? (v5) : (ndims == 4) ? (v4) : (ndims == 3) ? (v3) : 0) -template -void brgemm_convolution_fwd_t::pd_t::init_batch(int icc, +template +void brgemm_convolution_fwd_t::pd_t::init_batch(int icc, const char *src_base, const char *wei_base, int n_ic_blocks, int ic_block_s, int iid_b, int iih_b, int iiw_b, const dim_t *const __restrict kw_top_vpads, @@ -117,8 +117,8 @@ void brgemm_convolution_fwd_t::pd_t::init_batch(int icc, } } -template -inline void brgemm_convolution_fwd_t::pd_t::get_A_B(int icc, +template +inline void brgemm_convolution_fwd_t::pd_t::get_A_B(int icc, const char *src_base, const char *wei_base, int ic_block_s, int iid_b, int iih_b, int iiw_b, int kd_b, int kh_b, const void *&ptrA, const void *&ptrB) const { @@ -147,10 +147,9 @@ inline void brgemm_convolution_fwd_t::pd_t::get_A_B(int icc, ptrB = wei_base_kh + wei_kw * wei_kw_offset; } -template -status_t brgemm_convolution_fwd_t::pd_t::add_brg_descriptor( - int vM, int i_N, int i_K, int i_init, int kd_b, int kd_e, int kh_b, - int kh_e) { +template +status_t brgemm_convolution_fwd_t::pd_t::add_brg_descriptor(int vM, + int i_N, int i_K, int i_init, int kd_b, int kd_e, int kh_b, int kh_e) { const auto src_type = src_md(0)->data_type; const auto wei_type = weights_md(0)->data_type; @@ -265,7 +264,7 @@ status_t brgemm_convolution_fwd_t::pd_t::add_brg_descriptor( brgattr.hint_expected_B_size = 0; brgattr.hint_expected_C_size = 0; - brgattr.wary_tail_read = false; + brgattr.wary_A_k_tail_read = false; brgattr.bd_mask_level = jcp_.use_M_mask; brgattr.max_top_vpad = jcp_.max_vpad; @@ -280,14 +279,15 @@ status_t brgemm_convolution_fwd_t::pd_t::add_brg_descriptor( brg.with_weights_scale_adjust = jcp_.scale_adjust_factor != 1.0f; CHECK(brgemm_desc_set_postops(&brg, attr(), &dst_md_, LDD, jcp_.bia_dt)); + CHECK(brgemm_desc_finalize(&brg)); + brgemm_descriptors_->insert(brg_idx, brg, bd_mask, stoffs); return status::success; } -template -status_t brgemm_convolution_fwd_t::pd_t::init( - engine_t *engine) { +template +status_t brgemm_convolution_fwd_t::pd_t::init(engine_t *engine) { using namespace data_type; using namespace utils; brgemm_descriptors_ @@ -304,15 +304,15 @@ status_t brgemm_convolution_fwd_t::pd_t::init( // executing 'use_inversion == true' as FWD. This can only work if the // diff_src_desc and diff_dst_desc are defined in the aforementioned. const convolution_desc_t &cd = *desc(); - if (use_inversion + if (cd.use_inversion && one_of(true, types::is_zero_md(&cd.diff_src_desc), types::is_zero_md(&cd.diff_dst_desc))) return status::unimplemented; using skip_mask_t = primitive_attr_t::skip_mask_t; auto skip_mask = skip_mask_t::post_ops | skip_mask_t::sum_dt - | skip_mask_t::zero_points_runtime; - if (is_int8) skip_mask |= skip_mask_t::scales_runtime; + | skip_mask_t::zero_points; + if (is_int8) skip_mask |= skip_mask_t::scales; bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct) && IMPLICATION(is_int8, @@ -334,6 +334,8 @@ status_t brgemm_convolution_fwd_t::pd_t::init( // For exec_base it makes sense to use unrolled kernel only if // there is no padding by width. // 2. For exec_trans block by kw is always KW + // 3. 
'false' is used intentionally to disable the condition, ensuring that + // the assert fails only when jcp_.use_uker is true, regardless of exec_type. assert(IMPLICATION(jcp_.use_uker, false && one_of(jcp_.exec_type, exec_base, exec_trans))); assert(IMPLICATION(jcp_.use_interleave_stores, jcp_.use_uker)); @@ -533,13 +535,12 @@ status_t brgemm_convolution_fwd_t::pd_t::init( return status::success; } -template -brgemm_convolution_fwd_t::brgemm_convolution_fwd_t( - const pd_t *apd) +template +brgemm_convolution_fwd_t::brgemm_convolution_fwd_t(const pd_t *apd) : primitive_t(apd), bias_d(pd()->weights_md(1)) {} -template -void brgemm_convolution_fwd_t::get_kw_range( +template +void brgemm_convolution_fwd_t::get_kw_range( int ow, int &kw_s, int &kw_full_s, int &kw_full_f, int &kw_f) const { // This function needed for exec_base only const auto _pd = pd(); @@ -568,8 +569,8 @@ void brgemm_convolution_fwd_t::get_kw_range( if (kw_full_f == -1) kw_full_s = kw_full_f = kw_f; } -template -inline void brgemm_convolution_fwd_t::get_ow_range( +template +inline void brgemm_convolution_fwd_t::get_ow_range( int ow, int kw, int &ow_s, int &ow_f) const { // This function needed for exec_base only const auto _pd = pd(); @@ -600,9 +601,9 @@ inline void brgemm_convolution_fwd_t::get_ow_range( ow_f = nstl::min(nstl::max(ow_f, ow_s), ow + M); } -template -status_t brgemm_convolution_fwd_t::add_brg_kernel(int M, - int i_N, int i_K, int i_init, int kd_b, int kd_e, int kh_b, int kh_e) { +template +status_t brgemm_convolution_fwd_t::add_brg_kernel(int M, int i_N, int i_K, + int i_init, int kd_b, int kd_e, int kh_b, int kh_e) { if (M <= 0) return status::success; const auto _pd = pd(); const auto &jcp = _pd->jcp_; @@ -621,8 +622,8 @@ status_t brgemm_convolution_fwd_t::add_brg_kernel(int M, return status::success; } -template -status_t brgemm_convolution_fwd_t::add_po_kernel( +template +status_t brgemm_convolution_fwd_t::add_po_kernel( brgemm_t *bcfg, int ker_idx, bool is_init) { if (!bcfg) return status::success; const auto _pd = pd(); @@ -639,8 +640,8 @@ status_t brgemm_convolution_fwd_t::add_po_kernel( return status::success; } -template -void brgemm_convolution_fwd_t::add_po_kernels( +template +void brgemm_convolution_fwd_t::add_po_kernels( int i_N, int init_bcast_dim, int po_bcast_dim) { const auto _pd = pd(); const auto &jcp = _pd->jcp_; @@ -674,10 +675,10 @@ void brgemm_convolution_fwd_t::add_po_kernels( } } } -template -int brgemm_convolution_fwd_t::get_comp_ker_idx( - const int kd_b, const int kd_e, const int kh_b, const int kh_e, - const int kw_b, const int kw_e) const { +template +int brgemm_convolution_fwd_t::get_comp_ker_idx(const int kd_b, + const int kd_e, const int kh_b, const int kh_e, const int kw_b, + const int kw_e) const { const auto _pd = pd(); const auto &jcp = _pd->jcp_; @@ -694,11 +695,10 @@ int brgemm_convolution_fwd_t::get_comp_ker_idx( return -1; } -template -inline int brgemm_convolution_fwd_t::get_comp_offset( - const int g, const int ocb, const int ow, const int kd_b, - const int kd_e, const int kh_b, const int kh_e, const int kw_b, - const int kw_e) const { +template +inline int brgemm_convolution_fwd_t::get_comp_offset(const int g, + const int ocb, const int ow, const int kd_b, const int kd_e, + const int kh_b, const int kh_e, const int kw_b, const int kw_e) const { const auto _pd = pd(); const auto &jcp = _pd->jcp_; @@ -712,8 +712,8 @@ inline int brgemm_convolution_fwd_t::get_comp_offset( : (g * jcp.nb_oc + ocb) * jcp.oc_block; } -template -status_t brgemm_convolution_fwd_t::init(engine_t 
*engine) { +template +status_t brgemm_convolution_fwd_t::init(engine_t *engine) { const auto _pd = pd(); const auto &jcp = _pd->jcp_; @@ -1052,8 +1052,8 @@ status_t brgemm_convolution_fwd_t::init(engine_t *engine) { return status::success; } -template -struct brgemm_convolution_fwd_t::brgemm_thread_ctx_t { +template +struct brgemm_convolution_fwd_t::brgemm_thread_ctx_t { brgemm_thread_ctx_t(brgemm_exec_ctx_t &brgemm_ctx_, int ithr_, brgemm_batch_element_t *__restrict brg_batch_, char *c_buffer_, char *wsp_tile_) @@ -1080,9 +1080,8 @@ struct brgemm_convolution_fwd_t::brgemm_thread_ctx_t { const float *dst_scales {nullptr}; }; -template -status_t brgemm_convolution_fwd_t::execute( - const exec_ctx_t &ctx) const { +template +status_t brgemm_convolution_fwd_t::execute(const exec_ctx_t &ctx) const { const auto _pd = pd(); const auto &jcp = _pd->jcp_; @@ -1264,8 +1263,8 @@ status_t brgemm_convolution_fwd_t::execute( return status::success; } -template -status_t brgemm_convolution_fwd_t::cal_compensation( +template +status_t brgemm_convolution_fwd_t::cal_compensation( const char *__restrict weights, int32_t *src_zp_buffer, int32_t *s8s8_comp_buffer) const { const auto _pd = pd(); @@ -1330,8 +1329,8 @@ status_t brgemm_convolution_fwd_t::cal_compensation( return status::success; } -template -void brgemm_convolution_fwd_t::perform_outwork( +template +void brgemm_convolution_fwd_t::perform_outwork( const brgemm_thread_ctx_t &btc, char *dst_base, const char *bias_w, int ow, int g_oc, bool is_oc_tail, int ker_ow_s, int ker_ow_f, int kd_l, int kh_l, bool maybe_do_init, bool do_postwork, @@ -1415,8 +1414,8 @@ void brgemm_convolution_fwd_t::perform_outwork( } } -template -inline void brgemm_convolution_fwd_t::call_brgemm_kernel( +template +inline void brgemm_convolution_fwd_t::call_brgemm_kernel( const brgemm_thread_ctx_t &btc, const brgemm_kernel_t *brg_ker, int batch_size, char *ptr_C, char *ptr_D, const char *bias_w, int g_oc, bool do_postops, int comp_ker_offs, bool do_only_comp) const { @@ -1465,8 +1464,8 @@ inline void brgemm_convolution_fwd_t::call_brgemm_kernel( ptr_C, static_cast(btc.wsp_tile)); } -template -void brgemm_convolution_fwd_t::maybe_conv_inp(int ithr, +template +void brgemm_convolution_fwd_t::maybe_conv_inp(int ithr, const char *__restrict src, char *__restrict inp_buffer, uint8_t *__restrict inp_buffer_mask, int g, int n, int icc, int odb, int ohb, int owb, int last_g, int last_n, int last_icc, int last_odb, @@ -1646,9 +1645,8 @@ void brgemm_convolution_fwd_t::maybe_conv_inp(int ithr, char *ptr_D; \ int kd_b(0), kd_e(0), kh_b(0), kh_e(0), k_l(0), iiw_b(0); -template -void brgemm_convolution_fwd_t::ker_base( - brgemm_thread_ctx_t &btc) const { +template +void brgemm_convolution_fwd_t::ker_base(brgemm_thread_ctx_t &btc) const { const auto _pd = pd(); const auto &jcp = _pd->jcp_; @@ -1797,8 +1795,8 @@ void brgemm_convolution_fwd_t::ker_base( } } -template -void brgemm_convolution_fwd_t::ker_trans( +template +void brgemm_convolution_fwd_t::ker_trans( brgemm_thread_ctx_t &btc, char *inp_buffer) const { const auto _pd = pd(); @@ -1922,9 +1920,8 @@ void brgemm_convolution_fwd_t::ker_trans( } } -template -void brgemm_convolution_fwd_t::ker_vpad( - brgemm_thread_ctx_t &btc) const { +template +void brgemm_convolution_fwd_t::ker_vpad(brgemm_thread_ctx_t &btc) const { const auto _pd = pd(); const auto &jcp = _pd->jcp_; diff --git a/src/cpu/aarch64/jit_brgemm_conv.hpp b/src/cpu/aarch64/jit_brgemm_conv.hpp index 2f476a2552a..dedcf753be2 100644 --- a/src/cpu/aarch64/jit_brgemm_conv.hpp +++ 
b/src/cpu/aarch64/jit_brgemm_conv.hpp @@ -1,6 +1,6 @@ /******************************************************************************* * Copyright 2021-2023 Intel Corporation -* Copyright 2024 FUJITSU LIMITED +* Copyright 2024-2025 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,18 +41,13 @@ namespace impl { namespace cpu { namespace aarch64 { -template +template struct brgemm_convolution_fwd_t : public primitive_t { struct brgemm_thread_ctx_t; struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::hint_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) - , with_sum(false) {} - - ~pd_t() = default; + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; // ------- DECLARE_COMMON_PD_t ----- DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brgconv:", jcp_.isa, ""), @@ -63,8 +58,8 @@ struct brgemm_convolution_fwd_t : public primitive_t { int brgs_sz_; std::shared_ptr brgemm_descriptors_; - bool with_sum; - jit_brgemm_conv_conf_t jcp_; + bool with_sum = false; + jit_brgemm_conv_conf_t jcp_ = utils::zero(); int ic_chunks; bool need_postwork; @@ -122,7 +117,7 @@ struct brgemm_convolution_fwd_t : public primitive_t { } inline int maybe_invert(int k, int K) const { - return use_inversion ? K - 1 - k : k; + return desc()->use_inversion ? K - 1 - k : k; }; void init_batch(int icc, const char *src_base, const char *wei_base, @@ -149,12 +144,20 @@ struct brgemm_convolution_fwd_t : public primitive_t { } bool zero_points_ok() const { - // Only common zero points are supported -> mask should only be 0 - int mask_src = 0, mask_dst = 0; - attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src); - attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst); - return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS) - && mask_src == 0 && mask_dst == 0; + const auto &zp = attr()->zero_points_; + + if (!zp.has_default_values(DNNL_ARG_SRC)) { + int mask_src = zp.get_mask(DNNL_ARG_SRC); + const bool ok = mask_src == 0; + if (!ok) return false; + } + if (!zp.has_default_values(DNNL_ARG_DST)) { + int mask_dst = zp.get_mask(DNNL_ARG_DST); + const bool ok = mask_dst == 0; + if (!ok) return false; + } + + return zp.has_default_values(DNNL_ARG_WEIGHTS); } int KD, KH, KW, EXT_KD, EXT_KH, EXT_KW, KS, KD_BLOCK, KH_BLOCK, @@ -207,7 +210,7 @@ struct brgemm_convolution_fwd_t : public primitive_t { } inline int maybe_invert_range(int k, int k_inv, int K) const { - return use_inversion ? K - k_inv : k; + return pd()->desc()->use_inversion ? K - k_inv : k; }; void get_kw_range( diff --git a/src/cpu/aarch64/jit_brgemm_conv_bwd.cpp b/src/cpu/aarch64/jit_brgemm_conv_bwd.cpp new file mode 100644 index 00000000000..79210d804b0 --- /dev/null +++ b/src/cpu/aarch64/jit_brgemm_conv_bwd.cpp @@ -0,0 +1,185 @@ +/******************************************************************************* +* Copyright 2025 FUJITSU LIMITED +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common/dnnl_thread.hpp" +#include "common/nstl.hpp" +#include "common/primitive_desc_iterator.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "cpu/aarch64/jit_brgemm_1x1_conv.hpp" +#include "cpu/aarch64/jit_brgemm_conv_bwd.hpp" +#include "cpu/cpu_convolution_pd.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +namespace { +status_t weights_axes_permutation( + memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) { + int perm[DNNL_MAX_NDIMS] {}; // bwd conv to fwd conv weight permutation + for (int d = 0; d < DNNL_MAX_NDIMS; ++d) + perm[d] = d; + nstl::swap(perm[0 + with_groups], perm[1 + with_groups]); + + return memory_desc_permute_axes(*o_md, *i_md, perm); +} + +status_t fwd_conv_desc_create( + convolution_desc_t *fwd_conv_d, const convolution_desc_t *bwd_conv_d) { + // create a new weights descriptor with OC and IC transposed; + // spatial inversion is handled by inverting indices on-the-fly + memory_desc_t fwd_weights_md; + const memory_desc_t &bwd_weights_md = bwd_conv_d->weights_desc; + const bool with_groups + = bwd_weights_md.ndims == bwd_conv_d->diff_src_desc.ndims + 1; + CHECK(weights_axes_permutation( + &fwd_weights_md, &bwd_weights_md, with_groups)); + + // create a fwd convolution descriptor with padding adjusted + // to the perspective of backward propagation, namely: + // - left padding replaced by left overflow + // - right padding replaced by right overflow + const int ndims_spatial = bwd_conv_d->diff_src_desc.ndims - 2; + dims_t overflow_l; + dims_t overflow_r; + dim_t ks = 1; + for (int i = 0; i < ndims_spatial; i++) { + VDISPATCH_CONV_IC(bwd_conv_d->strides[i] == 1, + VERBOSE_UNSUPPORTED_FEATURE, + "only unit strides are allowed for bwd-to-fwd conversion"); + const dim_t K + = bwd_weights_md.dims[bwd_weights_md.ndims - ndims_spatial + i]; + ks *= K; + const dim_t D = bwd_conv_d->dilates[i]; + const dim_t PL = bwd_conv_d->padding[0][i]; // left padding + const dim_t PR = bwd_conv_d->padding[1][i]; // right padding + constexpr dim_t S = 1; + // the following relations hold for unit stride only + overflow_l[i] = ((K - 1) * (D + 1) - PL) / S; + overflow_r[i] = ((K - 1) * (D + 1) - PR) / S; + } + CHECK(conv_desc_init(fwd_conv_d, prop_kind::forward_training, + alg_kind::convolution_direct, &bwd_conv_d->diff_dst_desc, + &fwd_weights_md, &bwd_conv_d->bias_desc, &bwd_conv_d->diff_src_desc, + bwd_conv_d->strides, bwd_conv_d->dilates, overflow_l, overflow_r)); + + // HACK: Set diff_src_desc and diff_dst_desc as a signal to the primitive + // descriptor cache that we are using the bwd-via-fwd version of + // fwd conv and thus need a separate cache entry. Only needed for + // non-1x1 convs due to spatial inversion of weights. This assumes + // that external users only use the API to create conv descs, and + // relies on common/convolution.cpp only setting the expected mem descs. + // TODO: Pass this information via attributes or integrate the bwd-via-fwd + // method directly into fwd conv implementations. + const bool with_spatial_inversion = ks > 1; + if (with_spatial_inversion) { + fwd_conv_d->diff_src_desc = fwd_conv_d->src_desc; + fwd_conv_d->diff_dst_desc = fwd_conv_d->dst_desc; + } + // Note: internal field to hint this conv is created from deconv. 
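+    // (On the fwd side this flag is read back as desc()->use_inversion in
+    // maybe_invert() / maybe_invert_range(), see jit_brgemm_conv.hpp, which
+    // invert the kernel indices, e.g. k -> K - 1 - k, instead of physically
+    // reordering the weight tensor.)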
+ fwd_conv_d->use_inversion = true; + return status::success; +} +} // namespace + +template +status_t brgemm_convolution_bwd_t::pd_t::init(engine_t *engine) { + using namespace data_type; + using namespace utils; + + VDISPATCH_CONV(is_bwd_d(), VERBOSE_BAD_PROPKIND); + VDISPATCH_CONV(set_default_alg_kind(alg_kind::convolution_direct), + VERBOSE_BAD_ALGORITHM); + VDISPATCH_CONV(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); + VDISPATCH_CONV(attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR); + + convolution_desc_t fwd_conv_d = convolution_desc_t(); + CHECK(fwd_conv_desc_create(&fwd_conv_d, desc())); + + primitive_desc_iterator_t it(engine, + reinterpret_cast(&fwd_conv_d), attr(), nullptr); + if (!it.is_initialized()) return status::out_of_memory; + + while (++it != it.end()) { + fwd_pd_ = *it; + using fwd_1x1_conv_pd_t = + typename brgemm_1x1_convolution_fwd_t::pd_t; + const auto pd_1x1 = dynamic_cast((*it).get()); + if (pd_1x1 != nullptr) { + break; // 1x1 implementation found + } + + using fwd_conv_pd_t = typename brgemm_convolution_fwd_t::pd_t; + + const auto pd = dynamic_cast((*it).get()); + if (pd != nullptr) { + break; // non-1x1 implementation found + } + } + + VDISPATCH_CONV(it != it.end(), "Implementation wasn't found"); + + if (weights_md_.format_kind == format_kind::any) + CHECK(weights_axes_permutation( + &weights_md_, fwd_pd_->weights_md(), with_groups())); + if (diff_src_md_.format_kind == format_kind::any) + diff_src_md_ = *fwd_pd_->dst_md(); + if (diff_dst_md_.format_kind == format_kind::any) + diff_dst_md_ = *fwd_pd_->src_md(); + if (bias_md_.format_kind == format_kind::any) + bias_md_ = *fwd_pd_->weights_md(1); + + init_name(); + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book( + memory_tracking::names::key_nested, fwd_pd_->scratchpad_registry()); + + return status::success; +} + +template +status_t brgemm_convolution_bwd_t::init(engine_t *engine) { + return pd()->fwd_pd_->create_primitive(fwd_p_, engine); +} + +template +status_t brgemm_convolution_bwd_t::execute(const exec_ctx_t &ctx) const { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC); + conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST); + conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS); + if (pd()->with_bias()) conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS); + + exec_ctx_t fwd_ctx(ctx, std::move(conv_args)); + + nested_scratchpad_t ns(ctx, memory_tracking::names::key_nested, fwd_p_); + fwd_ctx.set_scratchpad_grantor(ns.grantor()); + return fwd_p_->execute(fwd_ctx); +} + +template struct brgemm_convolution_bwd_t; +template struct brgemm_convolution_bwd_t; + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s diff --git a/src/cpu/aarch64/jit_brgemm_conv_bwd.hpp b/src/cpu/aarch64/jit_brgemm_conv_bwd.hpp new file mode 100644 index 00000000000..01498b291a6 --- /dev/null +++ b/src/cpu/aarch64/jit_brgemm_conv_bwd.hpp @@ -0,0 +1,76 @@ +/******************************************************************************* +* Copyright 2025 FUJITSU LIMITED +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_AARCH64_JIT_BRGEMM_CONV_BWD_HPP +#define CPU_AARCH64_JIT_BRGEMM_CONV_BWD_HPP + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/memory_tracking.hpp" +#include "common/primitive.hpp" +#include "common/utils.hpp" + +#include "cpu/aarch64/jit_brgemm_conv.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +template +struct brgemm_convolution_bwd_t : public primitive_t { + + struct pd_t : public cpu_convolution_bwd_data_pd_t { + using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t; + + DECLARE_COMMON_PD_T(name_.c_str(), brgemm_convolution_bwd_t); + + status_t init(engine_t *engine); + + std::shared_ptr fwd_pd_; + + private: + std::string name_ = JIT_IMPL_NAME_HELPER("brg_conv_bwd:", isa, ""); + + void init_name() { + name_.append("+"); + name_.append(fwd_pd_->name()); + } + }; + + brgemm_convolution_bwd_t(const pd_t *apd) : primitive_t(apd) {}; + + ~brgemm_convolution_bwd_t() override = default; + + status_t init(engine_t *engine) override; + + status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { + return static_cast(primitive_t::pd().get()); + } + std::shared_ptr fwd_p_; +}; + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif + +// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s diff --git a/src/cpu/aarch64/jit_brgemm_conv_comp_pad_kernel.hpp b/src/cpu/aarch64/jit_brgemm_conv_comp_pad_kernel.hpp index 0472aafb91a..96f86c2084a 100644 --- a/src/cpu/aarch64/jit_brgemm_conv_comp_pad_kernel.hpp +++ b/src/cpu/aarch64/jit_brgemm_conv_comp_pad_kernel.hpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2022-2023 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,15 +44,14 @@ struct jit_uni_brgemm_conv_comp_pad_kernel_t : public jit_generator { using XReg = const Xbyak_aarch64::XReg; - jit_uni_brgemm_conv_comp_pad_kernel_t( - const jit_brgemm_conv_conf_t &ajcp); + jit_uni_brgemm_conv_comp_pad_kernel_t(const jit_brgemm_conv_conf_t &ajcp); ~jit_uni_brgemm_conv_comp_pad_kernel_t() = default; protected: static constexpr bool is_ymm_ = true; - jit_brgemm_conv_conf_t jcp_; + jit_brgemm_conv_conf_t jcp_ = utils::zero(); const int inp_dsz_; const int out_dsz_; const size_t nb_ic_; diff --git a/src/cpu/aarch64/jit_brgemm_conv_utils.cpp b/src/cpu/aarch64/jit_brgemm_conv_utils.cpp index b93db5c423d..d10662b96ce 100644 --- a/src/cpu/aarch64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/aarch64/jit_brgemm_conv_utils.cpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2021-2023 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. 
and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -725,9 +726,9 @@ status_t brg_blocking_t::estimate_brgemm_ur() { const float alpha = 1.0; const float beta = 0.0; brgemm_t brg; - brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt, + CHECK(brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt, brgemm_row_major, alpha, beta, LDA, LDB, LDC, vM, vN, vK, nullptr, - is_bf32); + is_bf32)); CHECK(brgemm_utils::brgemm_blocking(&brg)); ur = brg.bd_block; ur_block = brg.bd_block; @@ -771,9 +772,9 @@ status_t brg_blocking_t::get_brgemm_ur( * rnd_up(oc, oc_block) * wei_dsz; const auto strides_ptr = (brg_type == brgemm_strd) ? &brg_strides : nullptr; - brgemm_utils::init_brgemm_conf(&brg, isa, brg_type, src_dt, - wei_dt, brgemm_row_major, alpha, vbeta, LDA, LDB, LDC, - vM, vN, vK, strides_ptr, is_bf32); + CHECK(brgemm_utils::init_brgemm_conf(&brg, isa, brg_type, + src_dt, wei_dt, brgemm_row_major, alpha, vbeta, LDA, + LDB, LDC, vM, vN, vK, strides_ptr, is_bf32)); CHECK(brgemm_utils::brgemm_blocking(&brg)); brgemm_attr_t brgattr; @@ -1758,19 +1759,23 @@ status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, const int prelu_ind = p.find(primitive_kind::prelu); jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); + const auto &zp = attr.zero_points_; jcp.src_zero_point = get_zp_type(attr, DNNL_ARG_SRC) != brgemm_broadcast_t::none; jcp.dst_zero_point = get_zp_type(attr, DNNL_ARG_DST) != brgemm_broadcast_t::none; - const bool has_zero_points = jcp.src_zero_point || jcp.dst_zero_point; - const bool params_ok - = IMPLICATION(has_zero_points, utils::one_of(jcp.src_dt, u8, s8)) - && IMPLICATION( - jcp.src_zero_point, attr.zero_points_.common(DNNL_ARG_SRC)) - && IMPLICATION( - jcp.dst_zero_point, attr.zero_points_.common(DNNL_ARG_DST)); - if (!params_ok) return status::unimplemented; + VDISPATCH_CONV_IC(IMPLICATION(jcp.src_zero_point || jcp.dst_zero_point, + utils::one_of(jcp.src_dt, s8, u8)), + VERBOSE_UNSUPPORTED_ZP_CFG); + + VDISPATCH_CONV_IC( + IMPLICATION(jcp.src_zero_point, zp.get_mask(DNNL_ARG_SRC) == 0), + VERBOSE_UNSUPPORTED_ZP_CFG); + + VDISPATCH_CONV_IC( + IMPLICATION(jcp.dst_zero_point, zp.get_mask(DNNL_ARG_DST) == 0), + VERBOSE_UNSUPPORTED_ZP_CFG); jcp.nthr = nthreads; jcp.kh_sets = 1; @@ -1992,7 +1997,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, jcp.with_scales = !src_scales.has_default_values() || !wei_scales.has_default_values() || jcp.scale_adjust_factor != 1.0f; - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; // disables the shape with small ic but large spatial // or specific large spatial shapes for int8 conv @@ -2189,7 +2194,7 @@ status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, jcp.with_scales = !src_scales.has_default_values() || !wei_scales.has_default_values() || jcp.scale_adjust_factor != 1.0f; - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; // enable ununroll_bd_loop for big shapes to reduce kernel sizes jcp.ununroll_bd_loop diff --git a/src/cpu/aarch64/jit_brgemm_post_ops.hpp b/src/cpu/aarch64/jit_brgemm_post_ops.hpp index 2809e1813b6..5aed828a582 100644 --- a/src/cpu/aarch64/jit_brgemm_post_ops.hpp +++ b/src/cpu/aarch64/jit_brgemm_post_ops.hpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2020-2023 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* 
Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -196,7 +197,7 @@ struct jit_brgemm_kernel_diff_bias_t : public jit_generator { } void generate() override { - size_t simd_w_; + size_t simd_w_ = 0; switch (brg_.isa_impl) { case sve_512: simd_w_ = cpu_isa_traits::vlen / sizeof(float); @@ -204,7 +205,10 @@ struct jit_brgemm_kernel_diff_bias_t : public jit_generator { case sve_256: simd_w_ = cpu_isa_traits::vlen / sizeof(float); break; - default: assert(!"unsupported isa"); + default: { + assert(!"unsupported isa"); + return; + } } preamble(); if (simd_w_ != cpu_sveLen / sizeof(float)) { @@ -321,8 +325,8 @@ struct jit_brgemm_kernel_post_ops : public jit_generator { const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); // per_oc: conv: 1 << 0, (1 << 1) + (1 << 0) (with groups) // per_oc: ip: 1 << 0 - is_oc_scale_ - = utils::one_of(wei_scales.mask_, 1 << 0, (1 << 1) + (1 << 0)); + is_oc_scale_ = utils::one_of( + wei_scales.get_mask(), 1 << 0, (1 << 1) + (1 << 0)); LDD_ = brg.LDD; inp_dt_ = brg.dt_c; @@ -850,7 +854,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator { } void generate() override { - size_t simd_w_; + size_t simd_w_ = 0; switch (brg.isa_impl) { case sve_512: simd_w_ = cpu_isa_traits::vlen / sizeof(float); @@ -858,7 +862,10 @@ struct jit_brgemm_kernel_post_ops : public jit_generator { case sve_256: simd_w_ = cpu_isa_traits::vlen / sizeof(float); break; - default: assert(!"unsupported isa"); + default: { + assert(!"unsupported isa"); + return; + } } preamble(); if (simd_w_ != cpu_sveLen / sizeof(float)) { diff --git a/src/cpu/aarch64/jit_primitive_conf.hpp b/src/cpu/aarch64/jit_primitive_conf.hpp index ef223f20aab..22af4a66fa9 100644 --- a/src/cpu/aarch64/jit_primitive_conf.hpp +++ b/src/cpu/aarch64/jit_primitive_conf.hpp @@ -36,6 +36,7 @@ enum conv_version_t { ver_unused, ver_fma, ver_sve_512, + ver_sve_256, }; enum conv_loop_order_t { diff --git a/src/cpu/aarch64/jit_sve_1x1_conv_kernel.cpp b/src/cpu/aarch64/jit_sve_1x1_conv_kernel.cpp new file mode 100644 index 00000000000..5bb7da464ea --- /dev/null +++ b/src/cpu/aarch64/jit_sve_1x1_conv_kernel.cpp @@ -0,0 +1,1398 @@ +/******************************************************************************* +* Copyright 2021-2023 Intel Corporation +* Copyright 2021-2024 FUJITSU LIMITED +* Copyright 2024-2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ +#include +#include + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/memory.hpp" +#include "common/memory_tracking.hpp" +#include "common/nstl.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "cpu/aarch64/cpu_barrier.hpp" +#include "cpu/platform.hpp" + +#include "cpu/aarch64/injectors/injector_utils.hpp" +#include "cpu/aarch64/injectors/jit_uni_binary_injector.hpp" +#include "cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp" +#include "cpu/aarch64/jit_sve_1x1_conv_kernel.hpp" +#include "cpu/aarch64/jit_uni_1x1_conv_utils.hpp" + +#define GET_OFF(field) \ + static_cast(offsetof(jit_1x1_conv_call_s, field)) + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +using namespace dnnl::impl::format_tag; +using namespace dnnl::impl::prop_kind; +using namespace dnnl::impl::utils; + +template +jit_sve_1x1_conv_kernel::jit_sve_1x1_conv_kernel( + const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, + const memory_desc_t &dst_md) + : jcp(ajcp), attr_(attr) { + if (jcp.with_eltwise || jcp.with_binary) { + using namespace binary_injector; + static constexpr bool preserve_gpr = true; + static constexpr bool preserve_vmm = false; + static constexpr size_t helper_vmm_idx = 31; + const size_t tail_size = jcp.oc_without_padding % isa_simd_width_; + static constexpr bool use_exact_tail_scalar_bcast = true; + + const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, + x14, x15, x13, preserve_gpr, preserve_vmm, + GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), + memory_desc_wrapper(dst_md), tail_size, k_load_dim_mask, + use_exact_tail_scalar_bcast}; + const static_params_t static_params { + this->param1, rhs_arg_static_params}; + + postops_injector_ = utils::make_unique< + injector::jit_uni_postops_injector_t>( + this, jcp.post_ops, static_params); + } +} + +template +void jit_sve_1x1_conv_kernel::bcast_loop(int load_loop_blk) { + + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_bcast_data, reg_bcast_data); + mov(aux_reg_output_data, reg_output_data); + ldr(reg_bcast_loop_iter, ptr(X_SP, reg_bcast_loop_work_offt)); + + Label bcast_loop; + Label bcast_loop_tail; + Label large_tail; + + cmp_imm(reg_bcast_loop_iter, jcp.bcast_block, reg_tmp_imm); + b(LT, bcast_loop_tail); + + L(bcast_loop); + { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + if (i + 1 == num_substeps) L(large_tail); + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add_imm(aux1_reg_bcast_data, aux1_reg_bcast_data, + jcp.bcast_loop_bcast_substep, reg_tmp_imm); + add_imm(aux_reg_output_data, aux_reg_output_data, + jcp.bcast_loop_output_substep, reg_tmp_imm); + } else { + add_imm(aux1_reg_bcast_data, aux1_reg_bcast_data, + jcp.bcast_loop_bcast_step + - (num_substeps - 1) + * jcp.bcast_loop_bcast_substep, + reg_tmp_imm); + add_imm(aux_reg_output_data, aux_reg_output_data, + jcp.bcast_loop_output_step + - (num_substeps - 1) + * jcp.bcast_loop_output_substep, + reg_tmp_imm); + } + subs_imm(reg_bcast_loop_iter, reg_bcast_loop_iter, jcp.ur, + reg_tmp_imm); + } + cmp_imm(reg_bcast_loop_iter, jcp.bcast_block, reg_tmp_imm); + b(GE, bcast_loop); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + if (jcp.ur_tail >= jcp.ur) { + 
cmp_imm(reg_bcast_loop_iter, jcp.ur, reg_tmp_imm); + b(GE, large_tail); + } + if (jcp.ur_tail % jcp.ur) { + cmp(reg_bcast_loop_iter, 0); + b(LE, bcast_loop_tail_out); + reduce_loop(load_loop_blk, jcp.ur_tail % jcp.ur, 0, true); + L(bcast_loop_tail_out); + } + } +} + +template +Xbyak_aarch64::XReg jit_sve_1x1_conv_kernel::output_ptr( + const bool is_out_layout_nxc, const int i_load, const int i_ur, + Xbyak_aarch64::XReg addr) { + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) { + int i_load_shift = is_out_layout_nxc + ? jcp.load_block + : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block; + int i_ur_shift = is_out_layout_nxc ? jcp.load_dim : jcp.load_block; + int offset = (i_load * i_load_shift + i_ur * i_ur_shift) + * jcp.typesize_out; + EVEX_compress_addr(addr, X_TMP_0, aux_reg_output_data, offset); + } else { + int offset = jcp.typesize_out * jcp.load_block * i_ur; + mov(X_TMP_0, i_load); + mul(X_TMP_0, reg_output_stride, X_TMP_0); + add_imm(X_TMP_1, X_TMP_0, offset, X_TMP_2); + add(addr, aux_reg_output_data, X_TMP_1); + } + return addr; +} + +static int vreg_accum_idx( + const int load_loop_blk, const int i_load, const int i_ur) { + return (i_ur * load_loop_blk + i_load); +} + +template +static void iterate(const int load_loop_blk, const int ur, const bool mask_tail, + const F &fun) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_tail && i_load + 1 == load_loop_blk; + for (int i_ur = 0; i_ur < ur; ++i_ur) + fun(mask_flag, i_load, i_ur); + } +} +template +static void iterate(const int load_loop_blk, const int ur, const F &fun) { + iterate(load_loop_blk, ur, false, fun); +} + +template +void jit_sve_1x1_conv_kernel::apply_postops( + const bool is_out_layout_nxc, const int load_loop_blk, const int ur) { + injector_utils::vmm_index_set_t vmm_idxs; + if (jcp.with_binary) { + binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; + const auto mask_tail = jcp.oc_without_padding % jcp.load_block; + iterate(load_loop_blk, ur, mask_tail, + [&](const bool mask_flag, const int i_load, const int i_ur) { + const auto vmm_idx + = vreg_accum_idx(load_loop_blk, i_load, i_ur); + vmm_idxs.emplace(vmm_idx); + + rhs_arg_params.vmm_idx_to_out_reg.emplace( + vmm_idx, aux_reg_output_data); + rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, + get_output_offset(is_out_layout_nxc, i_load, i_ur)); + if (mask_flag) + rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); + }); + + ldr(abi_param1, ptr(X_SP, reg_abi_param1_backup)); + + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + } else { + iterate(load_loop_blk, ur, + [&](const bool, const int i_load, const int i_ur) { + vmm_idxs.emplace( + vreg_accum_idx(load_loop_blk, i_load, i_ur)); + }); + postops_injector_->compute_vector_range(vmm_idxs); + } +} + +template +void jit_sve_1x1_conv_kernel::reduce_loop( + int load_loop_blk, int ur, int substep, bool wraparound) { + + const bool out_layout_nxc = is_out_layout_nxc(jcp); + const bool load_layout_nxc = is_load_layout_nxc(jcp); + const bool bcast_layout_nxc = is_bcast_layout_nxc(jcp); + const int reduce_dim_tail = jcp.reduce_dim % jcp.reduce_block; + const int load_dim_tail = jcp.load_dim % jcp.load_block; + + auto vreg_load + = [=](int i_load) { return ZReg(ur * load_loop_blk + i_load); }; + + auto vreg_accum = [=](int i_load, int i_ur) { + return ZReg(vreg_accum_idx(load_loop_blk, i_load, i_ur)); + }; + + auto bias_ptr = [=](int i_load) { + return EVEX_compress_addr(X_DEFAULT_ADDR, X_TMP_0, 
reg_bias_data, + jcp.typesize_out * jcp.oc_block * i_load); + }; + + auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast, + const Xbyak_aarch64::XReg addr, + const Xbyak_aarch64::XReg tmp) { + assert(i_ur < jcp.ur); + assert(i_reduce <= jcp.reduce_loop_unroll); + int offt; + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) { + assert(jcp.reduce_loop_unroll == jcp.reduce_block); + const int reduce_mul = bcast_layout_nxc ? jcp.reduce_dim + : jcp.reduce_loop_unroll; + offt = (i_reduce == jcp.reduce_loop_unroll) + ? (jcp.bcast_dim + i_ur) * reduce_mul + : i_ur * reduce_mul + i_reduce; + } else { + int rmul = bcast_layout_nxc ? jcp.ic : jcp.ic_block; + offt = i_reduce * rmul + i_ur; + } + return EVEX_compress_addr( + addr, tmp, aux_reg_bcast_data, jcp.typesize_in * offt, bcast); + }; + + auto load_ptr = [=](int i_reduce, int i_load, + const Xbyak_aarch64::XReg addr, + const Xbyak_aarch64::XReg tmp) { + int offt; + int u0 = i_reduce % jcp.reduce_loop_unroll; + int u1 = i_reduce / jcp.reduce_loop_unroll; + int lmul = jcp.load_block + * (load_layout_nxc ? 1 + : utils::rnd_up( + jcp.reduce_dim, jcp.reduce_block)); + int rmul = load_layout_nxc ? jcp.load_dim : jcp.load_block; + offt = i_load * lmul + u0 * rmul; + return EVEX_compress_addr(addr, tmp, aux_reg_load_data, + u1 * jcp.reduce_loop_load_step + jcp.typesize_in * offt); + }; + + auto init = [=]() { + Label init_done; + Label init_zero; + + if (jcp.with_bias + && one_of(jcp.prop_kind, forward_training, forward_inference)) { + tst(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + b(EQ, init_zero); + + for (int i_load = 0; i_load < load_loop_blk; i_load++) + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto vreg_acc = vreg_accum(i_load, i_ur); + if (i_load + 1 == load_loop_blk && load_dim_tail) + ld1w(vreg_acc.s, k_load_dim_mask / T_z, + ptr(bias_ptr(i_load))); + else + ld1w(vreg_acc.s, P_ALL_ONE / T_z, + ptr(bias_ptr(i_load))); + } + b(init_done); + } + + L(init_zero); + + /* Zero clear */ + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + eor(r.d, r.d, r.d); + } + L(init_done); + }; + + auto store = [=]() { + Label store_noadd; + if (!jcp.with_sum) { + tst(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + b(NE, store_noadd); + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + auto r = vreg_accum(i_load, i_ur).s; + if (i_load + 1 == load_loop_blk && load_dim_tail) + ld1w(zreg_tmp.s, k_load_dim_mask / T_z, + ptr(output_ptr(out_layout_nxc, i_load, i_ur, + X_DEFAULT_ADDR))); + else + ld1w(zreg_tmp.s, P_ALL_ONE / T_z, + ptr(output_ptr(out_layout_nxc, i_load, i_ur, + X_DEFAULT_ADDR))); + fadd(r, r, zreg_tmp.s); + } + + L(store_noadd); + if (jcp.with_eltwise || jcp.with_binary) { + Label store_nopostops; + tst(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + b(EQ, store_nopostops); + + apply_postops(out_layout_nxc, load_loop_blk, ur); + + L(store_nopostops); + } + + auto store_output = [=](bool output_is_aligned) { + const auto mask_flag = load_dim_tail; + for (int i_ur = 0; i_ur < ur; ++i_ur) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + auto vreg_acc = vreg_accum(i_load, i_ur); + // for nxc_layout-bwd_w, weights are still padded and the + // output_ptr here can be uninitialized scratchpad. + // To ensure final output (after reduction) is zero-padded, + // here we zero-pad output by omitting the mask. 
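+                // (Hence the branch below: the tail-masked st1w is taken
+                // only for non-bwd_w propagation kinds; bwd_w always uses
+                // the P_ALL_ONE store and writes the full vector, padded
+                // lanes included.)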
+ if (jcp.prop_kind != backward_weights + && (i_load + 1 == load_loop_blk && mask_flag)) { + st1w(vreg_acc.s, k_load_dim_mask / T_z, + ptr(output_ptr(out_layout_nxc, i_load, i_ur, + X_DEFAULT_ADDR))); + } else { + st1w(vreg_acc.s, P_ALL_ONE / T_z, + ptr(output_ptr(out_layout_nxc, i_load, i_ur, + X_DEFAULT_ADDR))); + } + } + } + }; + + Label unaligned_store, end_store; + tst(aux_reg_output_data, cpu_isa_traits::vlen - 1); + b(NE, unaligned_store); + store_output(true); + b(end_store); + L(unaligned_store); + { store_output(false); } + L(end_store); + }; + + auto fma_block = [=](bool last_block) { + const int i_reduce_end = reduce_dim_tail && last_block + ? reduce_dim_tail + : jcp.reduce_loop_unroll; + + for (int i_reduce = 0; i_reduce < i_reduce_end; i_reduce++) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + auto vreg = vreg_load(i_load); + if (i_load + 1 == load_loop_blk && load_dim_tail) + ld1w(vreg.s, k_load_dim_mask / T_z, + ptr(load_ptr(i_reduce, i_load, X_DEFAULT_ADDR, + X_TMP_0))); + else + ld1w(vreg.s, P_ALL_ONE / T_z, + ptr(load_ptr(i_reduce, i_load, X_DEFAULT_ADDR, + X_TMP_0))); + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) { + if (jcp.expl_bcast && load_loop_blk > 1) { + ldr(W_TMP_0, + ptr(bcast_ptr(i_reduce, i_ur, false, X_DEFAULT_ADDR, + X_TMP_1))); + dup(vreg_bcast.s, W_TMP_0); + } + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + auto vreg_acc = vreg_accum(i_load, i_ur); + if (i_load + 1 == load_loop_blk && load_dim_tail) { + ld1rw(zreg_tmp.s, P_ALL_ONE, + ptr(bcast_ptr(i_reduce, i_ur, true, + X_DEFAULT_ADDR, X_TMP_0))); + fmla(vreg_acc.s, k_load_dim_mask / T_m, + vreg_load(i_load).s, zreg_tmp.s); + } else if (jcp.expl_bcast && load_loop_blk > 1) { + fmla(vreg_acc.s, P_ALL_ONE / T_m, vreg_load(i_load).s, + vreg_bcast.s); + } else { + ld1rw(zreg_tmp.s, P_ALL_ONE, + ptr(bcast_ptr(i_reduce, i_ur, true, + X_DEFAULT_ADDR, X_TMP_0))); + fmla(vreg_acc.s, P_ALL_ONE / T_m, vreg_load(i_load).s, + zreg_tmp.s); + } + } + } + } + }; + + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + subs_imm(reduce_loop_iter, reduce_loop_iter, jcp.reduce_loop_unroll, + reg_tmp_imm); + b(LE, reduce_loop_tail); + + L(reduce_loop); + { + fma_block(false); + add_imm(aux_reg_bcast_data, aux_reg_bcast_data, + jcp.reduce_loop_bcast_step, reg_tmp_imm); + add_imm(aux_reg_load_data, aux_reg_load_data, jcp.reduce_loop_load_step, + reg_tmp_imm); + subs_imm(reduce_loop_iter, reduce_loop_iter, jcp.reduce_loop_unroll, + reg_tmp_imm); + b(GT, reduce_loop); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} + +template +void jit_sve_1x1_conv_kernel::generate() { + preamble(); + + sub_imm(X_SP, X_SP, stack_space_needed, X_TMP_0); + if (jcp.with_binary) { + const auto zeroed_reg = x15; + eor(zeroed_reg, zeroed_reg, zeroed_reg); + str(zeroed_reg, ptr(X_SP, reg_binary_post_op_acc_off)); + str(param1, ptr(X_SP, reg_abi_param1_backup)); + } + + /* Pointers indicate weight, input, and output data */ + ldr(reg_bcast_data, ptr(abi_param1, GET_OFF(bcast_data))); // Input + ldr(reg_load_data, ptr(abi_param1, GET_OFF(load_data))); // Weight + ldr(reg_output_data, ptr(abi_param1, GET_OFF(output_data))); // Output + + /* Pointer indicates bias data if the layer has bias option */ + if (jcp.with_bias) ldr(reg_bias_data, ptr(abi_param1, GET_OFF(bias_data))); + + /* Get workloads of each loop */ + ldr(reg_load_loop_work, ptr(abi_param1, 
GET_OFF(load_dim))); + ldr(reg_bcast_loop_work, ptr(abi_param1, GET_OFF(bcast_dim))); + str(reg_bcast_loop_work, ptr(X_SP, reg_bcast_loop_work_offt)); + ldr(reg_reduce_loop_work, ptr(abi_param1, GET_OFF(reduce_dim))); + + /* A flag for controlling reduce loop */ + ldr(reg_reduce_pos_flag, ptr(abi_param1, GET_OFF(first_last_flag))); + if (jcp.prop_kind == backward_weights) + ldr(reg_output_stride, ptr(param1, GET_OFF(output_stride))); + + const int load_dim_tail + = (one_of(jcp.prop_kind, forward_training, forward_inference) + ? jcp.oc_without_padding + : jcp.load_dim) + % jcp.load_block; + if (load_dim_tail) { + const WReg w_tmp(reg_load_dim_tail_mask.getIdx()); + mov_imm(w_tmp, (1 << load_dim_tail) - 1); + st1w(zreg_tmp1.s, P_ALL_ONE / T_z, ptr(X_TRANSLATOR_STACK, -1, MUL_VL)); + index(zreg_tmp.s, 0, 1); + mov(zreg_tmp1.s, 1); + lsl(zreg_tmp1.s, P_ALL_ONE / T_m, zreg_tmp.s); + dup(zreg_tmp.s, w_tmp); + and_(zreg_tmp.d, zreg_tmp.d, zreg_tmp1.d); + cmpne(k_load_dim_tail_mask.s, P_ALL_ONE, zreg_tmp.s, 0); + ldr(zreg_tmp1, ptr(X_TRANSLATOR_STACK, -1, MUL_VL)); + } + + auto load_loop_body = [=](int load_loop_blk) { + if (load_dim_tail) { + eor(k_load_dim_mask.b, P_ALL_ONE / T_z, k_load_dim_mask.b, + k_load_dim_mask.b); + not_(k_load_dim_mask.b, P_ALL_ONE / T_z, k_load_dim_mask.b); + } + subs_imm(reg_load_loop_work, reg_load_loop_work, + load_loop_blk * jcp.load_loop_iter_step, reg_tmp_imm); + if (load_dim_tail) { + Label no_update_mask; + b(GE, no_update_mask); + mov(k_load_dim_mask.b, k_load_dim_tail_mask.b); + L(no_update_mask); + } + bcast_loop(load_loop_blk); + add_imm(reg_load_data, reg_load_data, + load_loop_blk * jcp.load_loop_load_step, reg_tmp_imm); + switch (jcp.prop_kind) { + case forward_training: + case forward_inference: + add_imm(reg_bias_data, reg_bias_data, + load_loop_blk * jcp.load_block * jcp.typesize_out, + reg_tmp_imm); + add_imm(reg_output_data, reg_output_data, + load_loop_blk * jcp.load_block * jcp.typesize_out + * (is_out_layout_nxc(jcp) + ? 1 + : (jcp.with_dw_conv + ? jcp.ow + : jcp.bcast_dim)), + reg_tmp_imm); + if (jcp.with_binary) { + const auto oc_off_oprnd = aux_reg_load_data; + ldr(oc_off_oprnd, ptr(X_SP, reg_binary_post_op_acc_off)); + add_imm(oc_off_oprnd, oc_off_oprnd, + jcp.load_block * load_loop_blk, X_TMP_0); + str(oc_off_oprnd, ptr(X_SP, reg_binary_post_op_acc_off)); + } + break; + case backward_data: + add_imm(reg_output_data, reg_output_data, + load_loop_blk * jcp.load_block * jcp.typesize_out + * (is_out_layout_nxc(jcp) ? 1 : jcp.bcast_dim), + reg_tmp_imm); + break; + case backward_weights: + for (int i_load = 0; i_load < load_loop_blk; i_load++) + add(reg_output_data, reg_output_data, reg_output_stride); + break; + default: assert(!"invalid prop_kind"); + } + }; + + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + + Label load_loop_blk[7]; + + // with an implicit load_loop_block {6, 5, 4, 3, 2, 1} + static const int ur_cases_fma_embd_bcast[] = {2, 4, 5, 8, 14, 32}; + static const int ur_cases_fma_expl_bcast[] = {2, 5, 6, 9, 14, 32}; + + const int size_ur_cases_fma = jcp.expl_bcast + ? sizeof(ur_cases_fma_expl_bcast) + : sizeof(ur_cases_fma_embd_bcast); + + const int *ur_cases_fma = jcp.expl_bcast ? 
ur_cases_fma_expl_bcast + : ur_cases_fma_embd_bcast; + const int *ur_cases = ur_cases_fma; + const int num_ur_cases = size_ur_cases_fma / sizeof(*ur_cases); + + for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { + int label_idx = num_ur_cases - ur_idx - 1; + if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) { + cmp_imm(reg_load_loop_work, simd_w * (label_idx + 1), reg_tmp_imm); + b(LE, load_loop_blk[label_idx]); + } + } + + for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { + int label_idx = num_ur_cases - ur_idx - 1; + if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) { + L(load_loop_blk[label_idx]); + { + if (label_idx == 0) { + cmp(reg_load_loop_work, 0); + b(LE, load_loop_blk[num_ur_cases]); + } + load_loop_body(label_idx + 1); + if (label_idx - 1 > 0) { + cmp_imm(reg_load_loop_work, 2 * label_idx * simd_w, + reg_tmp_imm); + b(EQ, load_loop_blk[label_idx - 1]); + } + cmp_imm(reg_load_loop_work, label_idx * simd_w, reg_tmp_imm); + b(GT, load_loop_blk[label_idx]); + } + for (int idx = label_idx - 1; idx >= 0; --idx) { + cmp_imm(reg_load_loop_work, simd_w * (idx + 1), reg_tmp_imm); + b(GE, load_loop_blk[idx]); + } + if (ur_idx < num_ur_cases - 2) { + cmp_imm(reg_load_loop_work, simd_w, reg_tmp_imm); + b(LE, load_loop_blk[0]); + } + } + } + L(load_loop_blk[num_ur_cases]); + + add_imm(X_SP, X_SP, stack_space_needed, X_TMP_0); + + postamble(); + if (jcp.with_eltwise) postops_injector_->prepare_table(); +} + +template +status_t jit_sve_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr, int nthreads, bool reduce_src) { + + /* arch check */ + if (!mayiuse(isa_)) { return status::unimplemented; } + jcp.isa = isa_; + + if (!everyone_is(data_type::f32, src_d.data_type(), weights_d.data_type(), + dst_d.data_type())) { + return status::unimplemented; + } + + jcp.nthr = nthreads; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int ndims = src_d.ndims(); + /* Forward_[training, inference], backward_[data, weight] */ + jcp.prop_kind = cd.prop_kind; + + /* Check group option */ + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + /* Batchsize */ + jcp.mb = src_d.dims()[0]; + /* Channel */ + jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; + jcp.oc = jcp.oc_without_padding; + jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; + jcp.ic = jcp.ic_without_padding; + /* D, H, W */ + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + /* Kernel size */ + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + /* padding params */ + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + /* stride params */ + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 
1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + /* bias info */ + jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind, + format_kind::undef, cd.diff_bias_desc.format_kind) + != format_kind::undef; + + /* Spatials */ + jcp.os = jcp.od * jcp.oh * jcp.ow; + jcp.is = jcp.id * jcp.ih * jcp.iw; + + /* Depthwise conv check */ + const auto &post_ops = attr.post_ops_; + const int dw_conv_ind = post_ops.find(primitive_kind::convolution); + jcp.with_dw_conv = dw_conv_ind != -1; + if (jcp.with_dw_conv) { return status::unimplemented; } + + /* Post operation check */ + // Using dw_conv_ind as upper-bound below, as post-ops after it will be + // handled in depthwise convolution. + const int eltwise_ind + = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + if (dst_d.data_type() == data_type::s32) { + return status::unimplemented; + } + } + + const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind); + jcp.with_sum = sum_ind != -1; + + const int binary_ind + = post_ops.find(primitive_kind::binary, 0, dw_conv_ind); + jcp.with_binary = binary_ind != -1; + + if (dw_conv_ind >= 0) { + // dw_conv and post_ops after it are handled externally, so skip them + jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(), + post_ops.entry_.cbegin() + dw_conv_ind); + } else { + jcp.post_ops = post_ops; + } + + /* Data format check */ + const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc); + bool is_data_layout_nxc; + format_tag_t required_dat_tag; + + switch (isa_) { + case sve_512: { + const auto dat_tag_nCx16c + = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); + is_data_layout_nxc + = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); + required_dat_tag + = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; + break; + } + case sve_256: { + const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); + jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); + is_data_layout_nxc + = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); + required_dat_tag = is_data_layout_nxc ? 
dat_tag_nxc : dat_tag_nCx8c;
+            break;
+        }
+        default: break;
+    }
+    /* Channel padding check */
+    bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1
+            && src_d.data_type() == data_type::f32;
+
+    /* Input and output channels must be multiples of simd_w */
+    if (ok_to_pad_channels) {
+        jcp.oc = rnd_up(jcp.oc, simd_w);
+        jcp.ic = rnd_up(jcp.ic, simd_w);
+    }
+
+    using namespace injector;
+
+    static constexpr bool sum_at_pos_0_only = true;
+    static constexpr bool sum_requires_scale_one = true;
+    static constexpr bool sum_requires_zp_zero = true;
+    const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(jcp.isa,
+            {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only,
+            sum_requires_scale_one, sum_requires_zp_zero));
+    if (!post_ops_ok_) { return status::unimplemented; }
+
+    bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == required_dat_tag
+            && jcp.dst_tag == required_dat_tag
+            && IMPLICATION(!is_data_layout_nxc,
+                    jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0)
+            && jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0
+            && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1
+            && jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1 && jcp.ow == jcp.iw
+            && jcp.oh == jcp.ih && jcp.od == jcp.id; // enforce rpad=0
+    if (!args_ok) { return status::unimplemented; }
+
+    /* Channel blocking size is simd_w */
+    jcp.ic_block = jcp.oc_block = simd_w;
+
+    switch (isa_) {
+        case sve_512: {
+            jcp.ver = ver_sve_512;
+            break;
+        }
+        case sve_256: {
+            jcp.ver = ver_sve_256;
+            break;
+        }
+        default: break;
+    }
+
+    if (everyone_is(data_type::f32, src_d.data_type(), weights_d.data_type(),
+                dst_d.data_type())) {
+        const int is_bwd_d = jcp.prop_kind == backward_data;
+
+        /* Set weight data layout tag */
+        format_tag_t wei_tag;
+        switch (isa_) {
+            case sve_512: {
+                wei_tag = with_groups
+                        ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i,
+                                gOIhw16i16o, gIOhw16o16i, gOIdhw16i16o,
+                                gIOdhw16o16i)
+                        : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i,
+                                OIhw16i16o, IOhw16o16i, OIdhw16i16o,
+                                IOdhw16o16i);
+                break;
+            }
+            case sve_256: {
+                wei_tag = with_groups
+                        ? pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gIOw8o8i,
+                                gOIhw8i8o, gIOhw8o8i, gOIdhw8i8o, gIOdhw8o8i)
+                        : pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, IOw8o8i,
+                                OIhw8i8o, IOhw8o8i, OIdhw8i8o, IOdhw8o8i);
+                break;
+            }
+            default: break;
+        }
+
+        jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
+
+        if (jcp.wei_tag != wei_tag) return status::unimplemented;
+
+        // jcp.fma_step = 1;
+        jcp.typesize_in = sizeof(prec_traits_t<data_type::f32>::type);
+        jcp.typesize_out = sizeof(prec_traits_t<data_type::f32>::type);
+    } else {
+        // TODO: currently only f32 is supported
+        return status::unimplemented;
+    }
+
+    /* Once all the formats are set, check the padding consistency */
+
+    if (!is_data_layout_nxc) {
+        args_ok = true && jcp.ic <= src_d.padded_dims()[1]
+                && jcp.oc <= dst_d.padded_dims()[1]
+                && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
+                && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
+        if (!args_ok) { return status::unimplemented; }
+    }
+
+    // TODO: Optimize the params below
+    const int SMALL_SPATIAL = 10;
+    const int BIG_SPATIAL = 65;
+    const int BIG_REDUCE_DIM = 1024;
+    const int BIG_LOAD_DIM = (jcp.reduce_dim >= 512) ?
256 : 512; + + int load_blocking {0}; + int load_blocking_max {0}; + int bcast_blocking {0}; + int bcast_blocking_max {0}; + int reduce_blocking {0}; + int reduce_blocking_max {0}; + + jcp.load_grp_count = 1; + + // TODO: mov check funcs into platform files + const int L1_capacity + = platform::get_per_core_cache_size(1) / sizeof(float); + const int L2_size = platform::get_per_core_cache_size(2) / sizeof(float); + const int L2_capacity = (L2_size * 3) / 4; + + /* FWD, BWD data */ + + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) { + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + /* Forward */ + if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur); + jcp.reduce_dim = jcp.ic; // src channel + jcp.reduce_block = jcp.ic_block; // src simd_w + + jcp.load_dim = jcp.oc; // dst channel + jcp.load_block = jcp.oc_block; // dst simd_W + + jcp.bcast_dim = jcp.is; // src H*W + } else { + /* Backward data */ + jcp.reduce_dim = jcp.oc; // src channel + jcp.reduce_block = jcp.oc_block; // src simd_w + + jcp.load_dim = jcp.ic; // dst channel + jcp.load_block = jcp.ic_block; // dst simd_w + + jcp.bcast_dim = jcp.os; // src H*W + } + /* # of consecutive channel elements */ + jcp.reduce_loop_unroll = jcp.reduce_block; + + /* Offset to move to the next 16 input channel elements with the same H*W position */ + jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll + * (is_data_layout_nxc ? 1 : jcp.bcast_dim) * jcp.typesize_in; + + /* Offset: 16o*16i (filter) */ + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; + + /* Offset: I/16 * 16o */ + jcp.load_loop_load_step + = (utils::rnd_up(jcp.reduce_dim, jcp.reduce_block)) + * jcp.load_block * jcp.typesize_in; + + /* adjusting registry blocking */ + int max_regs, min_regs, size_threshold; + + /* spatial : H*D of dst */ + const int spatial + = (one_of(jcp.prop_kind, forward_training, forward_inference)) + ? jcp.od * jcp.oh // forward + : jcp.id * jcp.ih; // backward + + if ((8 * jcp.mb) / jcp.nthr >= 1 + // NHWC perf: RN50 mb=1 + || (is_data_layout_nxc && jcp.mb == 1)) { + max_regs = 9; // max # of ur_w + min_regs = 6; // min # of ur_w + size_threshold = 14; + jcp.expl_bcast = true; + + /* + * H*D of dst > SMALL_SPATIAL + */ + if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM + && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL + && jcp.reduce_dim < 256) { + max_regs = 6; + min_regs = 5; + } + } else { + max_regs = 30; + min_regs = 9; + size_threshold = 14; + jcp.expl_bcast = false; + jcp.use_vmovntps = true; + } + jcp.ur = 1; + + for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) { + /* + * H*D of dst >= size_threshold, (H*D of dst) % ur_w == 0 + * or + * H*D of dst < size_threshold, (H*W of dst) % ur_w == 0 + */ + if ((spatial >= size_threshold && spatial % ur_w == 0) + || (spatial < size_threshold && jcp.os % ur_w == 0)) { + jcp.ur = ur_w; + break; + } + } + + if (jcp.ur == 1) { + // If ur = 1, then min(max_regs, H*W of dst) + jcp.ur = nstl::min(max_regs, jcp.os); + int os_tail = jcp.os % max_regs; + for (int i = max_regs; i >= min_regs; i--) { + int i_tail = jcp.os % i; + if (i_tail > os_tail || i_tail == 0) { + jcp.ur = i; + os_tail = i_tail; + if (i_tail == 0) break; + } + } + } + jcp.bcast_block = jcp.ur; // block size of bcast (input data) + /* Number of steps for the dst address to output, used in bcast_loop() */ + jcp.bcast_loop_output_step = jcp.ur * jcp.typesize_out + * (is_data_layout_nxc ? 
jcp.load_dim : jcp.load_block);
+            jcp.bcast_loop_output_substep = -1; // unused
+
+            /* Number of steps for the src address to be broadcast in bcast_loop() */
+            jcp.bcast_loop_bcast_step = jcp.ur * jcp.typesize_in
+                    * (is_data_layout_nxc ? jcp.reduce_dim : jcp.reduce_block);
+            jcp.bcast_loop_bcast_substep = -1; // unused
+
+            jcp.load_loop_iter_step = jcp.load_block;
+
+            if (jcp.prop_kind == backward_data)
+                jcp.loop_order = loop_lbr;
+            else
+                jcp.loop_order = reduce_src ? loop_blr : loop_lbr;
+
+            int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
+            int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
+            int nb_load = div_up(jcp.load_dim, jcp.load_block);
+            if (is_data_layout_nxc) {
+                reduce_blocking = jcp.reduce_dim;
+            } else if (jcp.expl_bcast) {
+                if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL
+                        && spatial < BIG_SPATIAL) {
+                    reduce_blocking = nstl::min(jcp.reduce_dim, 80);
+                } else if (spatial > SMALL_SPATIAL)
+                    reduce_blocking = nstl::min(jcp.reduce_dim, 512);
+                else
+                    reduce_blocking = nstl::min(jcp.reduce_dim, 256);
+            } else {
+                reduce_blocking = nb_reduce;
+                if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
+                    reduce_blocking = 16;
+                else if (spatial > SMALL_SPATIAL
+                        && jcp.reduce_dim >= BIG_REDUCE_DIM)
+                    reduce_blocking = 8;
+                reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
+                reduce_blocking *= jcp.reduce_block;
+            }
+
+            // Check input data cache aliasing.
+            // The constants below are empirical and may need updating for
+            // other ISAs: way_size models the stride at which accesses land
+            // in the same cache set, and max_hits (7, about half of 16)
+            // leaves roughly half of the set for other data - weights, dst.
+            int way_size = (16 * 1024) / jcp.typesize_in;
+            int max_hits = 7;
+            if (!is_data_layout_nxc
+                    && jcp.bcast_dim * reduce_blocking > way_size * max_hits) {
+                int nrb = reduce_blocking / simd_w;
+                int sp = jcp.bcast_dim;
+                int wl = way_size / simd_w;
+                for (int start_off = 0; start_off < jcp.ur; start_off++) {
+                    for (int off = start_off, hits = 0; off < sp * nrb; off += wl) {
+                        if (off % sp >= jcp.ur || ++hits < max_hits) continue;
+                        int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp);
+                        reduce_blocking
+                                = nstl::min(reduce_blocking, max_r_blocking);
+                        break;
+                    }
+                }
+            }
+
+            if (reduce_blocking < jcp.reduce_dim) {
+                if (jcp.prop_kind == backward_data)
+                    jcp.loop_order = reduce_src ? loop_lbr : loop_rlb;
+                else
+                    jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
+            }
+            load_blocking = jcp.load_dim;
+
+            /* Number of weight elements to be loaded */
+            int load_size = jcp.load_dim * jcp.reduce_dim;
+            /* Number of elements to be broadcast from src */
+            auto bcast_size
+                    = (dim_t)jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim;
+
+            /* 12 cores per CMG */
+            if (jcp.nthr <= 12 && jcp.mb < jcp.nthr
+                    && nb_load * nb_bcast > jcp.nthr) {
+                // Some heuristic here
+                float calc_koef = 0.01, best_cost = FLT_MAX;
+                int n_lgc = jcp.nthr;
+                float ratio = (float)load_size / (float)bcast_size;
+                int best_lgc = ratio > 1 ? n_lgc : 1;
+                auto calc_job_cost = [&](int lb, int tg, float mem_k) {
+                    int bb_size = jcp.mb * div_up(nb_bcast, tg);
+                    float calc_size = (float)(bb_size * jcp.ur)
+                            * (lb * jcp.load_block) * jcp.reduce_dim;
+                    float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block)
+                            * jcp.reduce_dim;
+                    return calc_koef * calc_size + mem_k * mem_size;
+                };
+                for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) {
+                    lgc = ratio > 1 ?
n_lgc - ilgc : ilgc + 1; + int min_lb = nb_load / lgc; + int max_lb = div_up(nb_load, lgc); + int min_tg = jcp.nthr / lgc; + int max_tg = div_up(jcp.nthr, lgc); + // Some heuristic here + float mem_koef = (max_tg == 1) ? 1.f : 1.3f; + float job_cost = 0.; + if (jcp.nthr % lgc < nb_load % lgc) { + job_cost = calc_job_cost(max_lb, min_tg, mem_koef); + } else { + auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef); + auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef); + job_cost = nstl::max(job_cost1, job_cost2); + } + + if (job_cost < best_cost) { + best_lgc = lgc; + best_cost = job_cost; + } + } + jcp.load_grp_count = best_lgc; + load_blocking + = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; + } else { + jcp.load_grp_count + = div_up(jcp.nthr, jcp.mb * jcp.ngroups * nb_bcast); + jcp.load_grp_count = best_divider(jcp.nthr, jcp.load_grp_count, + 2 * jcp.load_grp_count, false); + } + if (jcp.expl_bcast && jcp.bcast_dim <= 64 && load_size >= L2_size) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); + } else if (jcp.bcast_dim <= 49 && jcp.mb <= jcp.nthr + && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); + load_blocking = jcp.load_block; + } + + auto get_thr_eff = [=](int load_chunk, int nthr) { + int lgc = div_up(nb_load, load_chunk); + int thr_per_grp = div_up(nthr, lgc); + int bcast_per_thr + = div_up(jcp.mb * nb_bcast, thr_per_grp) * jcp.bcast_block; + int load_per_thr = load_chunk * simd_w; + float data_norm = (bcast_per_thr + load_per_thr) / 2.f; + float data_eff + = (bcast_per_thr * load_per_thr) / (data_norm * data_norm); + float thr_eff_over_grp + = (float)nstl::max(1, nthr / lgc) / div_up(nthr, lgc); + float thr_eff_in_grp = ((float)jcp.mb * nb_bcast) + / rnd_up(jcp.mb * nb_bcast, thr_per_grp); + float thr_eff = thr_eff_over_grp * thr_eff_in_grp; + float load_eff = (float)nb_load / rnd_up(nb_load, lgc); + float overall_eff = data_eff + thr_eff + load_eff; + return overall_eff; + }; + + auto get_load_chunk = [=](int nthr) { + float best_eff = -1.0f; + int best_lgc = 1; + float eff; + + for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) { + int lgc = div_up(nb_load, load_chunk); + if (lgc > nthr) continue; + eff = get_thr_eff(load_chunk, nthr); + if (eff > best_eff) { + best_eff = eff; + best_lgc = lgc; + } + } + return best_lgc; + }; + + /* adjust the thread decomposition + * to improve the thr_eff for small problem size + * the threshold 8192 is empirical + * TODO: Threshold can be increase for init stride > 1*/ + if (sizeof(float) * bcast_size < 8192 && jcp.mb < jcp.nthr + && nb_load * nb_bcast < jcp.nthr) { + float best_thr_eff = -1.0f; + float thr_eff = -1.0f; + int overall_lgc = jcp.load_grp_count; + int lgc = 1; + int best_nthr = jcp.nthr; + int end_nthr = with_groups ? 
jcp.ngroups : 1; + for (int nthr = jcp.nthr / 2; nthr >= end_nthr; nthr--) { + lgc = get_load_chunk(nthr); + thr_eff = get_thr_eff(lgc, nthr); + if (best_thr_eff < thr_eff) { + best_thr_eff = thr_eff; + overall_lgc = lgc; + best_nthr = nthr; + } + } + jcp.nthr = best_nthr; + jcp.load_grp_count = overall_lgc; + load_blocking + = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; + } + + bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, + div_up(jcp.nthr, jcp.load_grp_count)) + * jcp.bcast_block; + bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking); + bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); + + int space_for_bcast = (L2_capacity - /* kernel_size - */ + 2 * jcp.load_block * reduce_blocking - jcp.ur * reduce_blocking + - 3 * 1024); + if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) space_for_bcast /= 2; + + int bcast_in_cache + = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); + bcast_blocking = nstl::min( + bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); + // NHWC perf + if (is_data_layout_nxc) bcast_blocking = jcp.bcast_block; + + load_blocking_max = load_blocking; + bcast_blocking_max = bcast_blocking * 3 / 2; + reduce_blocking_max = reduce_blocking; + + jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur; + + } else if (jcp.prop_kind == backward_weights) { /* BWD weight */ + + jcp.reduce_dim = jcp.is; + + jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true); + if (jcp.reduce_dim % jcp.reduce_block != 0) + jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false); + if (jcp.reduce_block > 256) { jcp.reduce_block = 1; } + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + if (jcp.reduce_block <= 19 && + // maskrcnn optimization for nxc; don't reduce ur when ocb<=1 + !(is_data_layout_nxc && jcp.load_dim <= jcp.load_block)) { + // if reduce_block is big then generated JIT code may be big + // for small values of ur because reduce_loop_unroll = reduce_block + jcp.ur = jcp.bcast_block / 2; + jcp.expl_bcast = true; + } else { + jcp.ur = jcp.bcast_block; + jcp.expl_bcast = false; + } + + jcp.ur_tail = jcp.bcast_dim % jcp.bcast_block; + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step = jcp.typesize_in * jcp.reduce_loop_unroll + * (is_data_layout_nxc ? jcp.ic : jcp.ic_block); + jcp.reduce_loop_load_step = jcp.typesize_in * jcp.reduce_loop_unroll + * (is_data_layout_nxc ? jcp.oc : jcp.oc_block); + + jcp.bcast_loop_output_step + = jcp.oc_block * jcp.ic_block * jcp.typesize_out; + jcp.bcast_loop_output_substep + = jcp.oc_block * jcp.ur * jcp.typesize_out; + jcp.bcast_loop_bcast_step = jcp.ic_block + * (is_data_layout_nxc ? 1 + : utils::rnd_up( + jcp.reduce_dim, jcp.reduce_block)) + * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in; + + jcp.load_loop_load_step = jcp.typesize_in * jcp.oc_block + * (is_data_layout_nxc ? 
1 : jcp.os);
+        jcp.load_loop_iter_step = jcp.oc_block;
+
+        /* --- */
+        balance(jcp);
+
+        load_blocking = div_up(jcp.load_dim, jcp.load_block);
+        load_blocking = best_divider(load_blocking, 16, load_blocking, false);
+        load_blocking *= jcp.load_block;
+
+        load_blocking_max = load_blocking;
+        assert(IMPLICATION(
+                !is_data_layout_nxc, jcp.load_dim % load_blocking == 0));
+
+        int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
+        int min_bcast_blocking = 5;
+
+        bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
+        bcast_blocking = best_divider(
+                bcast_blocking, min_bcast_blocking, max_bcast_blocking, false);
+        bcast_blocking *= jcp.bcast_block;
+        bcast_blocking_max = bcast_blocking;
+        assert(IMPLICATION(
+                !is_data_layout_nxc, jcp.bcast_dim % bcast_blocking == 0));
+
+        // for reduction balance
+        if (is_data_layout_nxc && jcp.reduce_dim >= BIG_SPATIAL * BIG_SPATIAL
+                && jcp.load_dim >= BIG_LOAD_DIM / 2) {
+            reduce_blocking = rnd_up(nstl::min(jcp.ow, 256), jcp.reduce_block);
+        } else {
+            int max_reduce_blocking
+                    = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim);
+            int min_reduce_blocking = nstl::min(
+                    L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih));
+            reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking,
+                    max_reduce_blocking, true);
+            reduce_blocking
+                    = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block),
+                            jcp.reduce_block);
+        }
+
+        reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block);
+    } else {
+        return status::unimplemented;
+    }
+
+    assert(load_blocking);
+    assert(load_blocking_max);
+    assert(bcast_blocking);
+    assert(bcast_blocking_max);
+    assert(reduce_blocking);
+    assert(reduce_blocking_max);
+
+    if (!is_data_layout_nxc) {
+        assert(load_blocking % jcp.load_block == 0);
+        assert(reduce_blocking % jcp.reduce_block == 0);
+        assert(load_blocking_max % jcp.load_block == 0);
+        assert(reduce_blocking_max % jcp.reduce_block == 0);
+        assert(jcp.reduce_dim % jcp.reduce_block == 0);
+    }
+
+    assert(jcp.bcast_block % jcp.ur == 0);
+
+    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
+    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
+    jcp.nb_load_blocking = utils::div_up(load_blocking, jcp.load_block);
+    jcp.nb_load_blocking_max = utils::div_up(load_blocking_max, jcp.load_block);
+    jcp.nb_reduce_blocking = utils::div_up(reduce_blocking, jcp.reduce_block);
+    jcp.nb_reduce_blocking_max
+            = utils::div_up(reduce_blocking_max, jcp.reduce_block);
+
+    jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
+    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
+    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
+    return status::success;
+}
+
+template <cpu_isa_t isa_>
+void jit_sve_1x1_conv_kernel<isa_>::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad,
+        const jit_1x1_conv_conf_t &jcp) {
+
+    using namespace dnnl::impl::memory_tracking::names;
+
+    // For nxc layout bias is padded only for the bwd_w direction, as bias
+    // reduction kernels can't handle tails yet.
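+    // Sizing example (illustrative values, not taken from this code): with
+    // ngroups = 1, oc = 20 and oc_block = 16, rnd_up(oc, oc_block) = 32,
+    // so 32 bias elements are booked and the 12-element tail can be
+    // zero-filled before the reduction writes whole blocks.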
+    if (jcp.with_bias && jcp.prop_kind != backward_data
+            && (jcp.oc != jcp.oc_without_padding // blocked layout
+                    || (jcp.prop_kind == backward_weights // nxc layout
+                            && jcp.oc % jcp.oc_block != 0))) {
+
+        const size_t nelems_padded_bias
+                = jcp.ngroups * utils::rnd_up(jcp.oc, jcp.oc_block);
+        scratchpad.book(
+                key_conv_padded_bias, nelems_padded_bias, jcp.typesize_out);
+    }
+
+    if (jcp.prop_kind == backward_weights) {
+        const size_t wei_size = (size_t)jcp.ngroups
+                * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block);
+        scratchpad.book(key_conv_wei_reduction, wei_size * (jcp.nthr_mb - 1),
+                jcp.typesize_out);
+    }
+}
+
+/* BWD W */
+template <cpu_isa_t isa_>
+void jit_sve_1x1_conv_kernel<isa_>::balance(jit_1x1_conv_conf_t &jcp) {
+    int nthreads = jcp.nthr;
+    // initialize jcp reduction threading properties
+    jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1;
+    if (nthreads < jcp.ngroups) {
+        /* simplification... fortunately it doesn't hurt much */
+        return;
+    }
+    const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
+    const int nb_load = div_up(jcp.load_dim, jcp.load_block);
+    const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
+
+    jcp.nthr_g = jcp.ngroups;
+    const int nthr = nthreads / jcp.nthr_g;
+
+    auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
+        /* calculate per thread memory cost (read/write). high level
+         * optimizer tries to minimize memory consumption. few notes: (n1)
+         * unclear why, but that essentially helps first convolution...
+         * (n2) assuming the reduction over minibatch is always there:
+         * - instead of 8 it should be 5 here (write ~= 2 read):
+         *   kernel: temporal workspace 1 write
+         *   reduction: 1 read from workspace and 1 write to the diff_wei
+         * - but experiments showed 8 works better than 5 or 6...
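+         * Reading of the terms below (an interpretation added here, not
+         * part of the original notes): bcast_koeff scales the per-thread
+         * src reads, load_koeff the diff_dst reads, and output_koeff the
+         * diff_weights traffic; output_koeff = 12 makes the write term
+         * dominate, so the search prefers splits that keep each thread's
+         * weight tile small.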
+         */
+        int bcast_koeff = 1;
+        int load_koeff = 1;
+        int output_koeff = 12;
+        return 0
+                + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
+                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_bcast, nthr_ic_b)
+                * jcp.ic_block * jcp.reduce_block / jcp.stride_h
+                / jcp.stride_w /* (n1) */
+                + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
+                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
+                * jcp.oc_block * jcp.reduce_block
+                + (size_t)output_koeff /* (n2) */
+                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
+                * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.oc_block;
+    };
+
+    int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1;
+    auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+
+    /* step 1: find the best thread distribution with lowest memory cost */
+    const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce);
+    for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
+        const int nthr_par = nthr / nthr_mb;
+        const int nthr_oc_b_max = nstl::min(nthr_par, nb_load);
+        for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
+            nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast);
+            auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+            if (mem_cost <= best_mem_cost) {
+                best_mem_cost = mem_cost;
+                jcp.nthr_mb = nthr_mb;
+                jcp.nthr_oc_b = nthr_oc_b;
+                jcp.nthr_ic_b = nthr_ic_b;
+            }
+        }
+    }
+    if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads)
+        jcp.nthr_mb = nstl::min(jcp.mb, nthreads);
+
+    jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b;
+    assert(jcp.nthr <= nthreads);
+}
+
+template struct jit_sve_1x1_conv_kernel<sve_512>;
+template struct jit_sve_1x1_conv_kernel<sve_256>;
+} // namespace aarch64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/cpu/aarch64/jit_sve_1x1_conv_kernel.hpp b/src/cpu/aarch64/jit_sve_1x1_conv_kernel.hpp
new file mode 100644
index 00000000000..5bfd5db50de
--- /dev/null
+++ b/src/cpu/aarch64/jit_sve_1x1_conv_kernel.hpp
@@ -0,0 +1,205 @@
+/*******************************************************************************
+* Copyright 2021-2023 Intel Corporation
+* Copyright 2021-2024 FUJITSU LIMITED
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_AARCH64_JIT_SVE_1x1_CONV_KERNEL_HPP
+#define CPU_AARCH64_JIT_SVE_1x1_CONV_KERNEL_HPP
+
+#include "common/c_types_map.hpp"
+#include "common/memory_tracking.hpp"
+
+#include "cpu/aarch64/injectors/jit_uni_postops_injector.hpp"
+#include "cpu/aarch64/jit_generator.hpp"
+#include "cpu/aarch64/jit_op_imm_check.hpp"
+#include "cpu/aarch64/jit_primitive_conf.hpp"
+
+using namespace Xbyak_aarch64;
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace aarch64 {
+
+/* Get vector offsets: ofs / VL (e.g. VL: 512 bits = 64 bytes) */
+#define VL64_OFS(ofs) (ofs >> cpu_isa_traits<isa_>::vlen_shift)
+
+template <cpu_isa_t isa_>
+struct jit_sve_1x1_conv_kernel : public jit_generator {
+    jit_sve_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp,
+            const primitive_attr_t &attr, const memory_desc_t &dst_md);
+
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sve_1x1_conv_kernel)
+
+    static status_t init_conf(jit_1x1_conv_conf_t &jcp,
+            const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+            const memory_desc_wrapper &weights_d,
+            const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
+            int nthreads, bool reduce_src);
+
+    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+            const jit_1x1_conv_conf_t &jcp);
+
+    jit_1x1_conv_conf_t jcp;
+    const primitive_attr_t &attr_;
+
+private:
+    using reg64_t = const XReg;
+
+    /* Flags and loop variables */
+    reg64_t reg_reduce_pos_flag = x1;
+    reg64_t reduce_loop_iter = x2;
+    reg64_t reg_bcast_loop_iter = x3;
+    reg64_t reg_relu_ns = x20; // For forward
+    reg64_t reg_output_stride = x20; // For backward
+
+    /* Pointers */
+    reg64_t reg_bcast_data = x5; // Input
+    reg64_t reg_load_data = x6; // Weight
+    reg64_t reg_output_data = x7; // Output
+    reg64_t reg_bias_data = x8; // bias
+    reg64_t aux1_reg_bcast_data = x9;
+    reg64_t aux_reg_output_data = x10;
+    reg64_t aux_reg_bcast_data = x11;
+    reg64_t aux_reg_load_data = x12;
+    reg64_t reg_prev_bcast_addr
+            = x13; // Input: keeps the address accessed by the previous ldr inst
+    reg64_t reg_prev_out_addr
+            = x14; // Output: keeps the address accessed by the previous ldr or str inst
+
+    /* Workload */
+    reg64_t reg_load_loop_work = x15;
+    reg64_t reg_reduce_loop_work = x16;
+    reg64_t reg_bcast_loop_work = x17;
+
+    /* Temporary registers */
+    reg64_t reg_tmp_imm = x27; // tmp for add_imm
+    reg64_t reg_tmp_ofs = x19; // tmp reg to calc bwd wei offset in out_load
+
+    reg64_t reg_load_dim_tail_mask = aux_reg_load_data;
+
+    std::unique_ptr<injector::jit_uni_postops_injector_t<isa_>>
+            postops_injector_;
+
+    constexpr static int isa_simd_width_
+            = cpu_isa_traits<isa_>::vlen / sizeof(float);
+
+    ZReg vreg_bcast = ZReg(31);
+    PReg k_load_dim_mask = p2;
+    PReg k_load_dim_tail_mask = p3;
+    ZReg zreg_tmp = ZReg(31);
+    ZReg zreg_tmp1 = ZReg(30);
+
+    constexpr static int reg64_size_ = sizeof(int64_t);
+    constexpr static int reg_bcast_loop_work_offt = 0;
+    constexpr static int reg_binary_post_op_acc_off = 1 * reg64_size_;
+    constexpr static int reg_abi_param1_backup = 2 * reg64_size_;
+    constexpr static int stack_space_needed = 3 * reg64_size_;
+
+    template <typename T>
+    Xbyak_aarch64::XReg EVEX_compress_addr(const Xbyak_aarch64::XReg &addr,
+            const Xbyak_aarch64::XReg &x_tmp, Xbyak_aarch64::XReg base,
+            T raw_offt, bool bcast = false) {
+
+        assert(raw_offt <= INT_MAX);
+        auto offt = static_cast<int>(raw_offt);
+
+        add_imm(addr, base, offt, x_tmp);
+        if (bcast) {
+            // The returned address is the same whether or not bcast is set.
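+            // Note (assumption, added for clarity): unlike the x64 EVEX
+            // encoding, SVE has no embedded-broadcast addressing mode;
+            // broadcast loads are instead issued at the use site with
+            // ld1rw, so no address adjustment is needed here.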
+ } + return addr; + } + + void prefetch( + const std::string prfop, int level, reg64_t in, long long int ofs) { + bool for_load = false; + if (prfop == "LD") { + for_load = true; + } else if (prfop == "ST") { + for_load = false; + } else { + assert(!"invalid prfop"); + } + + bool cacheline_aligned = ((ofs & 0xFF) == 0) ? true : false; + if (cacheline_aligned == true) { + Prfop op; + switch (level) { + case 1: op = (for_load == true) ? PLDL1KEEP : PSTL1KEEP; break; + case 2: op = (for_load == true) ? PLDL2KEEP : PSTL2KEEP; break; + case 3: op = (for_load == true) ? PLDL3KEEP : PSTL3KEEP; break; + default: assert(!"invalid prfop"); break; + } + + if (prfm_imm_check(ofs)) { + prfm(op, ptr(in, static_cast(ofs))); + } else { + add_imm(reg_tmp_ofs, in, ofs, reg_tmp_imm); + prfm(op, ptr(reg_tmp_ofs)); + } + } else { + PrfopSve op_sve; + switch (level) { + case 1: + op_sve = (for_load == true) ? PLDL1KEEP_SVE : PSTL1KEEP_SVE; + break; + case 2: + op_sve = (for_load == true) ? PLDL2KEEP_SVE : PSTL2KEEP_SVE; + break; + case 3: + op_sve = (for_load == true) ? PLDL3KEEP_SVE : PSTL3KEEP_SVE; + break; + default: assert(!"invalid prfop"); break; + } + + if (prfw_imm_check(ofs)) { + prfw(op_sve, P_ALL_ONE, + ptr(in, static_cast(VL64_OFS(ofs)))); + } else { + add_imm(reg_tmp_ofs, in, ofs, reg_tmp_imm); + prfw(op_sve, P_ALL_ONE, ptr(reg_tmp_ofs)); + } + } + } + + void bcast_loop(int load_loop_blk); + void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); + + void generate() override; + static void balance(jit_1x1_conv_conf_t &jcp); + + inline size_t get_output_offset( + const bool is_out_layout_nxc, const int i_load, const int i_ur) { + const size_t i_load_shift = is_out_layout_nxc + ? jcp.load_block + : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block; + const size_t i_ur_shift + = is_out_layout_nxc ? jcp.load_dim : jcp.load_block; + return jcp.typesize_out * (i_load * i_load_shift + i_ur * i_ur_shift); + } + + Xbyak_aarch64::XReg output_ptr(const bool out_layout_nxc, const int i_load, + const int i_ur, Xbyak_aarch64::XReg addr); + void apply_postops(const bool is_out_layout_nxc, const int load_loop_blk, + const int ur); +}; + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/aarch64/jit_sve_1x1_convolution.cpp b/src/cpu/aarch64/jit_sve_1x1_convolution.cpp new file mode 100644 index 00000000000..065863d93c8 --- /dev/null +++ b/src/cpu/aarch64/jit_sve_1x1_convolution.cpp @@ -0,0 +1,1057 @@ +/******************************************************************************* +* Copyright 2021-2023 Intel Corporation +* Copyright 2021-2024 FUJITSU LIMITED +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "cpu/aarch64/jit_generator.hpp" + +#include "cpu/aarch64/jit_sve_1x1_convolution.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +using namespace dnnl::impl::status; +using namespace dnnl::impl::memory_tracking::names; +using namespace dnnl::impl::utils; + +#define data_blk_off(f, n, c, d, h, w) \ + ((ndims == 3) ? (f).blk_off(n, c, w) \ + : ((ndims == 4) ? (f).blk_off(n, c, h, w) \ + : (f).blk_off(n, c, d, h, w))) +/* convolution forward */ + +template +void jit_sve_1x1_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + const auto &jcp = kernel_->jcp; + auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, DNNL_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST); + auto weights_dw = CTX_IN_MEM( + const wei_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); + auto bias_dw = CTX_IN_MEM( + const dst_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS); + const auto post_ops_binary_rhs_arg_vec + = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); + const auto post_ops_binary_rhs_arg_vec_dw = pd()->dw_conv_pd_ + ? binary_injector::prepare_binary_args( + pd()->dw_conv_pd_->jcp_.post_ops, ctx, + pd()->jcp_.post_ops.entry_.size() + 1) + : std::vector {}; + + auto scratchpad = ctx.get_scratchpad_grantor(); + + if (pd()->wants_padded_bias()) { + auto padded_bias + = scratchpad.template get(key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw, + dst, scratchpad, post_ops_binary_rhs_arg_vec.data(), + post_ops_binary_rhs_arg_vec_dw.data()); + }); + + if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST); +} + +template +void jit_sve_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src, const wei_data_t *weights, + const dst_data_t *bias, const wei_data_t *weights_dw, + const dst_data_t *bias_dw, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad, + const void *post_ops_binary_rhs_arg_vec, + const void *post_ops_binary_rhs_arg_vec_dw) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper dw_weights_d( + pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)); + const memory_desc_wrapper dw_bias_d( + pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = pd()->rtus_.reduce_src_ + ? scratchpad.get(key_conv_rtus_space) + : nullptr; + + const int ndims = src_d.ndims(); + const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1; + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? 
remaining : default_step; + }; + + auto p = jit_1x1_conv_call_s(); + auto rp = typename rtus_driver_t::call_params_t(); + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + + // override some constants for fused dw_conv + const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block; + const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast; + const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking; + const int nb_bcast_blocking_max + = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max; + const int nb_load_blocking = jcp.nb_load_blocking; + const int nb_load_blocking_max = jcp.with_dw_conv + ? jcp.nb_load_blocking + : jcp.nb_load_blocking_max; + const bool is_dst_layout_nxc = utils::one_of( + jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); + const bool is_src_layout_nxc = utils::one_of( + jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); + + // Begin: declare Variables needed for dw conv. + memory_tracking::grantor_t dw_scratchpad( + scratchpad, memory_tracking::names::prefix_fusion); + dst_data_t *pbuf; + size_t row_offset; + const int nb_buffer = jcp.nb_load_blocking; + std::vector addrs; + // End + + auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g, + int &bcast_step, int &od, int &oh, int &ow, + int &id, int &ih, int &iw) { + int osb {0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast); + bcast_step = step( + nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + od = os / (jcp.oh * jcp.ow); + int os_2d = os % (jcp.oh * jcp.ow); + oh = os_2d / jcp.ow; + ow = os_2d % jcp.ow; + + id = od * stride_d; + ih = oh * stride_h; + iw = ow * stride_w; + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block); + rp.os = p.bcast_dim; + }; + + auto init_load = [&](int ocb, int ocb_end, int &load_step) { + load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max); + const auto max_oc + = nstl::min(ocb_end * jcp.oc_block, jcp.oc_without_padding); + p.load_dim = this_block_size( + ocb * jcp.oc_block, max_oc, load_step * jcp.oc_block); + }; + + auto init_reduce = [&](int icb) { + const int nb_ic_blocking_step + = nstl::min(icb + nb_ic_blocking, nb_ic) - icb; + p.first_last_flag = 0 | (icb == 0 ? FLAG_REDUCE_FIRST : 0) + | (icb + nb_ic_blocking_step >= nb_ic ? FLAG_REDUCE_LAST : 0); + + p.reduce_dim = this_block_size( + icb * jcp.ic_block, jcp.ic, nb_ic_blocking_step * jcp.ic_block); + rp.icb = p.reduce_dim; + }; + + auto ker_1x1 = [&](int ocb, int ocb_start, int icb, int n, int g, int od, + int oh, int ow, int id, int ih, int iw) { + const int oc_off_idx = is_dst_layout_nxc + ? g * jcp.oc + ocb * jcp.oc_block + : g * nb_oc + ocb; + const size_t dst_off = data_blk_off(dst_d, n, oc_off_idx, od, oh, ow); + + p.output_data = jcp.with_dw_conv + ? pbuf + (oh % pd()->dw_conv_pd_->jcp_.kh) * row_offset + : &dst[dst_off]; + p.bias_data = bias + ? &bias[oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block)] + : nullptr; + + p.load_data + = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + const int ic_off_idx = is_src_layout_nxc + ? g * jcp.ic + icb * jcp.ic_block + : g * nb_ic + icb; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ + + (is_src_layout_nxc ? 
ic_off_idx + : jcp.is * ic_off_idx * jcp.ic_block); + if (ocb == ocb_start) { + rp.src = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw); + (*rtus_driver_)(&rp); + } + p.bcast_data = rp.ws; + } else + p.bcast_data = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw); + + p.oc_l_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block); + p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; + p.dst_orig = dst; + + (*kernel_)(&p); + }; + auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start, + int ocb_end) { + if (bcast_start >= bcast_end || ocb_start >= ocb_end) return; + + if (jcp.loop_order == loop_rlb) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, ocb_end, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, + ow {0}, id {0}, ih {0}, iw {0}; + init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, + ow, id, ih, iw); + ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, + iw); + iwork += bcast_step; + } + ocb += load_step; + } + } + } else if (jcp.loop_order == loop_lbr) { + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, ocb_end, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0}, + id {0}, ih {0}, iw {0}; + init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, + id, ih, iw); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, + iw); + } + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_rbl) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0}, + id {0}, ih {0}, iw {0}; + init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, + id, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, ocb_end, load_step); + ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, + iw); + ocb += load_step; + } + iwork += bcast_step; + } + } + } else if (jcp.loop_order == loop_blr) { + int iwork = bcast_start; + while (iwork < bcast_end) { + int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0}, + id {0}, ih {0}, iw {0}; + init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id, + ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, ocb_end, load_step); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, + iw); + } + ocb += load_step; + } + iwork += bcast_step; + } + } else { + assert(!"unsupported loop order"); + } + }; + + auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) { + auto &jcp_dw = pd()->dw_conv_pd_->jcp_; + int oh_1x1 = nstl::max(dw_oh * jcp_dw.stride_h - jcp_dw.t_pad, 0); + + for (int i = 0; i < jcp_dw.kh; ++i) + addrs[i] = pbuf + ((oh_1x1++) % jcp_dw.kh) * row_offset; + + const auto ocb_end = ocb_start + load_step; + const auto wch_stride = (is_src_layout_nxc ? 
1 : jcp_dw.iw) + * jcp_dw.nb_ch_blocking * jcp_dw.ch_block; + const int dil_h = jcp_dw.dilate_h + 1; + const int str_h = jcp_dw.stride_h; + const int ch_num = jcp_dw.nb_ch_blocking; + const int ow = 0; + const int kw = 0; + + for (int ch = ocb_start; ch < ocb_end; ch += jcp_dw.nb_ch_blocking) { + + const int i_t_overflow + = nstl::max(0, (int)(jcp_dw.t_pad - dw_oh * str_h)); + const int i_b_overflow + = nstl::max(jcp_dw.ih, + (int)(dw_oh * str_h + (jcp_dw.kh - 1) * dil_h + - jcp_dw.t_pad + 1)) + - jcp_dw.ih; + + const int kh = div_up(i_t_overflow, dil_h); + const int kh_padding = jcp_dw.kh - div_up(i_t_overflow, dil_h) + - div_up(i_b_overflow, dil_h); + + jit_conv_call_s par_conv_dw; + + par_conv_dw.src = addrs.data(); + + const size_t ch_step = is_dst_layout_nxc + ? jcp_dw.ch_block + : dst_d.blk_off(0, 1, 0, 0); + par_conv_dw.dst + = &dst[dst_d.blk_off(n, 0, dw_oh, ow) + ch * ch_step]; + + par_conv_dw.filt + = &weights_dw[dw_weights_d.blk_off(ch, 0, 0, kh, kw)]; + if (bias) + par_conv_dw.bias + = &bias_dw[dw_bias_d.blk_off(ch * jcp_dw.ch_block)]; + + par_conv_dw.kh_padding = (size_t)nstl::max(0, kh_padding); + + par_conv_dw.load_work = (nstl::min(ch + ch_num, jcp_dw.nb_ch) - ch) + * jcp_dw.ch_block; + + par_conv_dw.oc_l_off = ch * jcp_dw.ch_block; + par_conv_dw.post_ops_binary_rhs_arg_vec + = post_ops_binary_rhs_arg_vec_dw; + par_conv_dw.dst_orig = dst; + + (*kernel_dw_)(&par_conv_dw); + + for (int i = 0; i < jcp_dw.kh; ++i) + addrs[i] += wch_stride; + } + }; + + auto conv_dw = [&]() { + // Set variables + auto dw_conv_buffer + = dw_scratchpad.get(key_fusion_inout_buffer); + auto &jcp_dw = pd()->dw_conv_pd_->jcp_; + + const auto dw_conv_buffer_size_ + = (size_t)jcp_dw.kh * jcp.ow * nb_buffer * jcp.oc_block; + pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_; + row_offset = dw_conv_buffer_size_ / jcp_dw.kh; + addrs.resize(jcp_dw.kh); + + int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0}; + balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw.oh, bcast_start, + bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count); + + while (ocb_start < ocb_end) { + int load_step; + init_load(ocb_start, ocb_end, load_step); + + int oh_1x1 = 0; + auto bcast_iter = bcast_start; + while (bcast_iter < bcast_end) { + int n {0}, g {0}, oh_dw {0}; + nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw, + jcp_dw.oh); + if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary + const int oh_1x1_range = oh_dw * jcp_dw.stride_h - jcp_dw.t_pad; + const int oh_1x1_begin = nstl::max(oh_1x1_range, 0); + const int oh_1x1_end + = nstl::min(oh_1x1_range + jcp_dw.kh, jcp.oh); + oh_1x1 = nstl::max( + oh_1x1_begin, oh_1x1); // Skip rows computed previously + + // dw_spatial to 1x1 spatial conversion. 
if jcp.oh != jcp_dw.oh + const int bcast_start_1x1 + = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1; + const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end; + + conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start, + ocb_start + load_step); + oh_1x1 = oh_1x1_end; + ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw); + + bcast_iter += nb_bcast_blocking; + } + ocb_start += load_step; + } + }; + + if (jcp.with_dw_conv) { + conv_dw(); + } else { + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load, + ocb_start, ocb_end, jcp.load_grp_count); + + conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end); + } +} + +template struct jit_sve_1x1_convolution_fwd_t; +template struct jit_sve_1x1_convolution_fwd_t; + +/* convolution backward wtr data */ +template +void jit_sve_1x1_convolution_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const auto &jcp = kernel_->jcp; + auto rtus_space = pd()->rtus_.reduce_src_ + ? ctx.get_scratchpad_grantor().template get( + key_conv_rtus_space) + : nullptr; + + const int ndims = diff_src_d.ndims(); + + assert(jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1); + + const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1; + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + + const int nb_ic = jcp.nb_load; + const int nb_oc = jcp.nb_reduce; + const int os_block = jcp.bcast_block; + const int nb_oc_blocking = jcp.nb_reduce_blocking; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + auto p = jit_1x1_conv_call_s(); + auto rp = typename rtus_driver_t::call_params_t(); + + int bcast_start {0}, bcast_end {0}, icb_start {0}, icb_end {0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load, + icb_start, icb_end, jcp.load_grp_count); + + bool reduce_outer + = (jcp.loop_order == loop_rbl || jcp.loop_order == loop_rlb); + int nboc_outer = reduce_outer ? nb_oc : 1; + int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1; + + int nboc_inner = reduce_outer ? 1 : nb_oc; + int ocb_inner_step = reduce_outer ? 
1 : nb_oc_blocking; + const int max_ic = nstl::min(icb_end * jcp.ic_block, jcp.ic); + + for (int ocb_outer = 0; ocb_outer < nboc_outer; + ocb_outer += ocb_outer_step) { + size_t cur_ocb_outer + = nstl::min(ocb_outer + ocb_outer_step, nboc_outer) + - ocb_outer; + + int load_step = 0; + for (int icb = icb_start; icb < icb_end; icb += load_step) { + load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb, + jcp.nb_load_blocking_max); + + p.load_dim = this_block_size( + icb * jcp.ic_block, max_ic, load_step * jcp.ic_block); + rp.icb = p.load_dim; + int bcast_step; + for (int iwork = bcast_start; iwork < bcast_end; + iwork += bcast_step) { + int n {0}, g {0}, osb {0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + p.bcast_dim = this_block_size( + os, jcp.os, bcast_step * os_block); + rp.os = p.bcast_dim; + const int od = os / (jcp.oh * jcp.ow); + const int os_2d = os % (jcp.oh * jcp.ow); + const int oh = os_2d / jcp.ow; + const int ow = os_2d % jcp.ow; + const int id = od * stride_d; + const int ih = oh * stride_h; + const int iw = ow * stride_w; + rp.iw_start = iw; + const bool is_dsrc_layout_nxc + = utils::one_of(jcp.src_tag, format_tag::nwc, + format_tag::nhwc, format_tag::ndhwc); + const int ic_off_idx = is_dsrc_layout_nxc + ? g * jcp.ic + icb * jcp.ic_block + : g * nb_ic + icb; + rp.src = diff_src + + data_blk_off( + diff_src_d, n, ic_off_idx, id, ih, iw); + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_; + p.output_data = rp.ws; + } else + p.output_data = rp.src; + + for (int ocb_inner = 0; ocb_inner < nboc_inner; + ocb_inner += ocb_inner_step) { + int cur_ocb_inner + = nstl::min(ocb_inner + ocb_inner_step, + nboc_inner) + - ocb_inner; + + int ocb = reduce_outer ? ocb_outer : ocb_inner; + int nb_oc_blocking_step + = reduce_outer ? cur_ocb_outer : cur_ocb_inner; + const bool is_ddst_layout_nxc + = utils::one_of(jcp.dst_tag, format_tag::nwc, + format_tag::nhwc, format_tag::ndhwc); + const int oc_off_idx = is_ddst_layout_nxc + ? g * jcp.oc + ocb * jcp.oc_block + : g * nb_oc + ocb; + size_t diff_dst_off = data_blk_off( + diff_dst_d, n, oc_off_idx, od, oh, ow); + p.bcast_data = &diff_dst[diff_dst_off]; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; + + p.reduce_dim = this_block_size(ocb * jcp.oc_block, + jcp.oc, nb_oc_blocking_step * jcp.oc_block); + + (*kernel_)(&p); + } + if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp); + } + } + } + }); +} + +template struct jit_sve_1x1_convolution_bwd_data_t; +template struct jit_sve_1x1_convolution_bwd_data_t; + +/* convolution backward wrt weights */ + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() ? 
(d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +template +status_t jit_sve_1x1_convolution_bwd_weights_t::init(engine_t *engine) { + + CHECK(safe_ptr_assign(kernel_, + new jit_sve_1x1_conv_kernel( + pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); + CHECK(safe_ptr_assign( + acc_ker_, new cpu_accumulator_1d_t())); + CHECK(safe_ptr_assign(reducer_bias_, + new cpu_reducer_t(pd()->reducer_bia_conf_))); + CHECK(kernel_->create_kernel()); + CHECK(acc_ker_->create_kernel()); + CHECK(reducer_bias_->create_kernel()); + + CHECK(init_rtus_driver(this)); + return status::success; +} +template +void jit_sve_1x1_convolution_bwd_weights_t::execute_backward_weights(const exec_ctx_t &ctx) + const { + auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + + const auto scratchpad = ctx.get_scratchpad_grantor(); + auto rtus_space = pd()->rtus_.reduce_src_ + ? scratchpad.get(key_conv_rtus_space) + : NULL; + const bool is_bias_padded + = pd()->with_bias() && jcp.oc_without_padding % jcp.oc_block != 0; + + data_t *diff_bias = is_bias_padded + ? scratchpad.get(key_conv_padded_bias) + : diff_bias_in; + auto wei_reduction = scratchpad.get(key_conv_wei_reduction); + + const int ndims = src_d.ndims(); + const int wei_size = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block) + * rnd_up(jcp.ic, jcp.ic_block); + + simple_barrier::ctx_t reduction_barrier; + simple_barrier::ctx_init(&reduction_barrier); + + const auto reducer_bia_scratchpad + = memory_tracking::grantor_t(scratchpad, prefix_reducer_bia); + auto rb = this->reducer_bias_.get(); + rb->init(reducer_bia_scratchpad); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int nb_ic = jcp.nb_bcast; + const int nb_ic_blocking = jcp.nb_bcast_blocking; + + const int nb_oc = jcp.nb_load; + const int nb_oc_blocking = jcp.nb_load_blocking; + + const int sp_nb = jcp.nb_reduce; + const int mb_sp_work = jcp.mb * sp_nb; + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + const bool is_src_layout_nxc = utils::one_of( + jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); + + const bool is_ddst_layout_nxc = utils::one_of( + jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); + + auto maybe_zero_icpad = [&](const int g_start, const int g_end, + const int ocb_start, const int ocb_end) { + // write zeros to IC padded region. 
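+        // Editor's note - a worked example of the zeroing below, using
+        // purely hypothetical sizes (not taken from any real shape):
+        //   assume ic_block = oc_block = 16, ic_without_padding = 13,
+        //   nb_ic = 1  =>  ic_tail = 13 % 16 = 13.
+        // Only the last input-channel block (z_icb = nb_ic - 1) carries
+        // padding; the zeroed span starts ic_tail * oc_block = 208 floats
+        // into that [ic_block x oc_block] weight tile and covers
+        //   zero_work = (1 * 16 - 13) * 16 = 48
+        // values, so the padded rows of diff_weights end up as real zeros
+        // after the reduction instead of stale scratchpad contents.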
+ const int ic_tail = jcp.ic_without_padding % jcp.ic_block; + if (is_ddst_layout_nxc && ic_tail != 0) { + for_(int g = g_start; g < g_end; ++g) + for (int z_ocb = ocb_start; z_ocb < ocb_end; ++z_ocb) { + const int z_icb = nb_ic - 1; + const size_t off = wht_blk_off(diff_weights_d, g, z_ocb, z_icb) + + ic_tail * jcp.oc_block; + data_t *z_wei = diff_weights + off; + const int zero_work + = (nb_ic * jcp.ic_block - jcp.ic_without_padding) + * jcp.oc_block; + PRAGMA_OMP_SIMD() + for (int o = 0; o < zero_work; ++o) { + z_wei[o] = 0; + } + } + } + }; + + auto ker = [&](const int ithr, const int nthr) { + assert(nthr == jcp.nthr); + + const int ithr_ic_b = ithr % jcp.nthr_ic_b; + const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; + const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; + const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g; + + /* reduction dimension */ + int mb_sp_b_start {0}, mb_sp_b_end {0}; + balance211( + mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start, mb_sp_b_end); + + /* independent dimensions */ + int g_start {0}, oc_b_start {0}, ic_b_start {0}; + int g_end {0}, oc_b_end {0}, ic_b_end {0}; + + balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); + balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, oc_b_end); + balance211( + jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, ic_b_end); + + const int g_work = g_end - g_start; + const int oc_b_work = oc_b_end - oc_b_start; + const int ic_b_work = ic_b_end - ic_b_start; + const bool cache_aliasing + = (jcp.ic * jcp.ngroups * sizeof(float)) % 1024 == 0; + int reduce_step = jcp.nb_reduce_blocking; + int reduce_step_max = jcp.nb_reduce_blocking_max; + if (is_src_layout_nxc && cache_aliasing) { + // Experiments show 4 is a magic number with the tested shapes. + // TODO: maybe tune for shapes with sp_dim%4 != 0 + reduce_step = nstl::min(4, reduce_step); + reduce_step_max = reduce_step; + } + + data_t *diff_wei = ithr_mb == 0 + ? diff_weights + : wei_reduction + (ithr_mb - 1) * wei_size; + + int sp_b_step = 0; + for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end; + mb_sp_b += sp_b_step) { + int img {0}, sp_b {0}; + nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb); + sp_b_step = step(reduce_step, + nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b), + reduce_step_max); + + for (int g = g_start; g < g_end; ++g) { + int load_step = 0; + int bcast_step = 0; + for (int ic_b = ic_b_start; ic_b < ic_b_end; + ic_b += bcast_step) { + if (is_src_layout_nxc && cache_aliasing) { + bcast_step = ic_b_work; + } else { + bcast_step = step(nb_ic_blocking, ic_b_end - ic_b, + jcp.nb_bcast_blocking_max); + } + + for (int oc_b = oc_b_start; oc_b < oc_b_end; + oc_b += load_step) { + load_step = step(nb_oc_blocking, oc_b_end - oc_b, + jcp.nb_load_blocking_max); + const int _ic_b = g * nb_ic + ic_b; + const int oc_off_idx = is_ddst_layout_nxc + ? g * jcp.oc + oc_b * jcp.oc_block + : g * nb_oc + oc_b; + + data_t *store_to; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b); + store_to = diff_wei + off; + + const int ic_off_idx + = (is_src_layout_nxc ? 
jcp.ic_block : 1) + * _ic_b; + const data_t *diff_src + = &src[src_d.blk_off(img, ic_off_idx)]; + + int sp_b_end = sp_b + sp_b_step; + const data_t *pdiff_dst = &diff_dst[diff_dst_d.blk_off( + img, oc_off_idx)]; + const data_t *local_src = diff_src; + + auto p = jit_1x1_conv_call_s(); + auto rp = typename rtus_driver_t::call_params_t(); + p.output_stride = utils::rnd_up(jcp.ic, jcp.ic_block) + * jcp.oc_block * jcp.typesize_out; + + p.load_dim = this_block_size(oc_b * jcp.oc_block, + jcp.oc, load_step * jcp.oc_block); + + p.bcast_dim = this_block_size(ic_b * jcp.ic_block, + jcp.ic, bcast_step * jcp.ic_block); + rp.icb = p.bcast_dim; + p.output_data = store_to; + + p.reduce_dim = sp_b_step * jcp.reduce_block; + rp.os = p.reduce_dim; + p.first_last_flag = 0 + | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST + : 0) + | (sp_b_end == sp_nb ? FLAG_SP_LAST : 0); + + int sp = sp_b * jcp.reduce_block; + int oc_mult + = is_ddst_layout_nxc ? jcp.oc : jcp.oc_block; + p.load_data = pdiff_dst + sp * oc_mult; + + if (pd()->rtus_.reduce_src_) { + const int oh = sp / jcp.ow; + const int ow = sp % jcp.ow; + + const int ih = oh * stride_h; + const int iw = ow * stride_w; + rp.iw_start = iw; + + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + sp * jcp.ic_block; + + if (ndims == 3) + rp.src = local_src + + iw * src_d.blocking_desc().strides[2]; + else + rp.src = local_src + + ih * src_d.blocking_desc().strides[2] + + iw * src_d.blocking_desc().strides[3]; + (*rtus_driver_)(&rp); + + p.bcast_data = rp.ws; + } else { + int ic_mult + = is_src_layout_nxc ? jcp.ic : jcp.ic_block; + p.bcast_data = local_src + sp * ic_mult; + } + + (*kernel_)(&p); + } + } + } + } + + if (ithr_mb == 0 && ic_b_end >= jcp.nb_bcast) { + maybe_zero_icpad(g_start, g_end, oc_b_start, oc_b_end); + } + + /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */ + if (dnnl_thr_syncable() && jcp.nthr_mb > 1) { + simple_barrier::barrier(&reduction_barrier, jcp.nthr); + const int work = g_work * oc_b_work * ic_b_work; + int start {0}, end {0}; + balance211(work, jcp.nthr_mb, ithr_mb, start, end); + if (start == end) return; + + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + int w = start; + int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_start {0}; + nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start, + oc_b_work, sub_ic_b_start, ic_b_work); + while (w < end) { + const int g = g_start + sub_g_start; + const int oc_b = oc_b_start + sub_oc_b_start; + const int ic_b = ic_b_start + sub_ic_b_start; + const int ic_to_accumulate + = nstl::min(end - w, ic_b_work - sub_ic_b_start) + * jcp.ic_block; + const int acc_size + = this_block_size(ic_b * jcp.ic_block, + jcp.ic_without_padding, ic_to_accumulate) + * jcp.oc_block; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b); + data_t *d = diff_weights + off; + data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off; + + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, g_work, + sub_oc_b_start, oc_b_work, sub_ic_b_start, + ic_b_work); + } + } + } + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start {0}, img_end {0}; + + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start {0}, ocb_start {0}; + nd_iterator_init( + 
b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_load); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const int oc_off_idx = is_ddst_layout_nxc + ? g * jcp.oc + ocb * jcp.oc_block + : g * jcp.nb_load + ocb; + const data_t *d_dst + = &diff_dst[diff_dst_d.blk_off(img, oc_off_idx)]; + + data_t *d_bias = rb->get_local_ptr(ithr, diff_bias, + reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + const int sp_shift = is_ddst_layout_nxc ? jcp.ngroups * jcp.oc + : jcp.oc_block; + const auto max_oc = this_block_size( + ocb * jcp.oc_block, jcp.oc, jcp.oc_block); + if (img == img_start) + for (int o = 0; o < jcp.oc_block; ++o) + d_bias[o] = 0.; + + for (int os = 0; os < jcp.os; ++os) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < max_oc; ++o) + d_bias[o] += d_dst[o]; + d_dst += sp_shift; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load); + } + } + + if (dnnl_thr_syncable()) + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + if (dnnl_thr_syncable()) { + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + ker(ithr, jcp.nthr); + if (pd()->with_bias()) ker_bias(ithr, jcp.nthr); + }); + } else { + parallel(jcp.nthr, [&](int ithr, int nthr) { ker(ithr, nthr); }); + if (jcp.nthr_mb > 1) + parallel(jcp.nthr, [&](int ithr, int nthr) { + assert(nthr == jcp.nthr); + + const int ithr_ic_b = ithr % jcp.nthr_ic_b; + const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; + const int ithr_g + = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; + const int ithr_mb + = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g; + + /* independent dimensions */ + int g_start {0}, oc_b_start {0}, ic_b_start {0}; + int g_end {0}, oc_b_end {0}, ic_b_end {0}; + + balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); + balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, + oc_b_end); + balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, + ic_b_end); + + const int g_work = g_end - g_start; + const int oc_b_work = oc_b_end - oc_b_start; + const int ic_b_work = ic_b_end - ic_b_start; + + const int work = g_work * oc_b_work * ic_b_work; + int start {0}, end {0}; + balance211(work, jcp.nthr_mb, ithr_mb, start, end); + if (start == end) return; + + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + int w = start; + int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_start {0}; + nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start, + oc_b_work, sub_ic_b_start, ic_b_work); + while (w < end) { + const int g = g_start + sub_g_start; + const int oc_b = oc_b_start + sub_oc_b_start; + const int ic_b = ic_b_start + sub_ic_b_start; + const int ic_to_accumulate + = nstl::min(end - w, ic_b_work - sub_ic_b_start) + * jcp.ic_block; + const int acc_size + = this_block_size(ic_b * jcp.ic_block, + jcp.ic_without_padding, + ic_to_accumulate) + * jcp.oc_block; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b); + data_t *d = diff_weights + off; + data_t *s + = wei_reduction + (thr_mb - 1) * wei_size + off; + + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, g_work, + sub_oc_b_start, oc_b_work, sub_ic_b_start, + ic_b_work); + } + } + }); + if (pd()->with_bias()) { + parallel(jcp.nthr, + [&](int ithr, int nthr) { ker_bias(ithr, nthr); }); + parallel(jcp.nthr, [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + MAYBE_UNUSED(nthr); + if (rb->balancer().ithr_njobs(ithr) == 0) return; + rb->reduce_nolock(ithr, 
diff_bias, reducer_bia_scratchpad); + }); + } + } + + /* TODO: put this in ker_bias */ + if (is_bias_padded) { + assert(IMPLICATION(!is_ddst_layout_nxc, jcp.ngroups == 1)); + const int padded_stride = rnd_up(jcp.oc, jcp.oc_block); + const int stride = jcp.oc_without_padding; + for (int g = 0; g < jcp.ngroups; ++g) { + utils::array_copy(diff_bias_in + g * stride, + diff_bias + g * padded_stride, stride); + } + } +} + +template struct jit_sve_1x1_convolution_bwd_weights_t; +template struct jit_sve_1x1_convolution_bwd_weights_t; + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/aarch64/jit_sve_1x1_convolution.hpp b/src/cpu/aarch64/jit_sve_1x1_convolution.hpp new file mode 100644 index 00000000000..fd0a19d94c5 --- /dev/null +++ b/src/cpu/aarch64/jit_sve_1x1_convolution.hpp @@ -0,0 +1,664 @@ +/******************************************************************************* +* Copyright 2021-2023 Intel Corporation +* Copyright 2021-2024 FUJITSU LIMITED +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_AARCH64_JIT_SVE_1X1_CONVOLUTION_HPP +#define CPU_AARCH64_JIT_SVE_1X1_CONVOLUTION_HPP + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/memory_tracking.hpp" +#include "common/primitive.hpp" +#include "common/primitive_hashing.hpp" +#include "common/utils.hpp" + +#include "cpu/cpu_convolution_pd.hpp" +#include "cpu/dw_convolution_utils.hpp" +#include "cpu/platform.hpp" + +#include "cpu/aarch64/cpu_reducer.hpp" +#include "cpu/aarch64/jit_sve_1x1_conv_kernel.hpp" +#include "cpu/aarch64/jit_uni_1x1_conv_utils.hpp" +#include "cpu/aarch64/jit_uni_dw_convolution.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +template +struct jit_sve_1x1_convolution_fwd_t : public primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; + + pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) { + if (copy(other) != status::success) is_initialized_ = false; + } + + DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", isa_, ""), + jit_sve_1x1_convolution_fwd_t); + + status_t init(engine_t *engine) { + using namespace utils; + + bool ok = true && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, dst_type, dst_type, + data_type::undef) + && attr()->has_default_values( + primitive_attr_t::skip_mask_t::post_ops, dst_type) + && !has_zero_dim_memory() && set_default_formats() + && attr_.set_default_formats(dst_md(0)) == status::success; + if (!ok) { return status::unimplemented; } + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + CHECK(jit_sve_1x1_conv_kernel::init_conf(jcp_, *conv_d, + *src_d, *weights_md(), *dst_md(), *attr(), + dnnl_get_max_threads(), rtus_.reduce_src_)); + if 
(jcp_.with_dw_conv) CHECK(depthwise_po_init(engine)); + + auto scratchpad = scratchpad_registry().registrar(); + jit_sve_1x1_conv_kernel::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad, jcp_.nthr); + + return status::success; + } + + const memory_desc_t *dst_md( + int index = 0, bool user_input = false) const override { + return jcp_.with_dw_conv + ? dw_conv_pd_->dst_md(index, user_input) + : cpu_convolution_fwd_pd_t::dst_md(index, user_input); + } + + const memory_desc_t *arg_md( + int arg, bool user_input = false) const override { + if (jcp_.with_dw_conv) { + switch (arg) { + case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_SRC: + return cpu_convolution_fwd_pd_t::dst_md(0, user_input); + case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS: + return dw_conv_pd_->weights_md(0); + case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS: + return dw_conv_pd_->weights_md(1); + default: break; + } + } + return convolution_fwd_pd_t::arg_md(arg, user_input); + } + + arg_usage_t arg_usage(int arg) const override { + if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS) + && attr_post_op_dw_inputs() > 1) + return arg_usage_t::input; + + return convolution_fwd_pd_t::arg_usage(arg); + } + + jit_1x1_conv_conf_t jcp_ = utils::zero(); + reduce_to_unit_stride_t rtus_ = utils::zero(); + using dw_pd_t = jit_sve_512_dw_convolution_fwd_t::pd_t; + std::unique_ptr dw_conv_pd_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + const memory_desc_wrapper src_d(&src_md_); + const memory_desc_wrapper dst_d(&dst_md_); + + const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc); + format_tag_t dat_tag, wei_tag; + + switch (isa_) { + case sve_512: { + const auto dat_tag_nCx16c = utils::pick( + ndims() - 3, nCw16c, nChw16c, nCdhw16c); + const auto curr_src_tag = src_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx16c); + const auto curr_dst_tag = dst_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx16c); + const auto is_data_layout_nxc + = IMPLICATION(curr_src_tag != dat_tag_nxc, + src_d.format_kind() == format_kind::any) + && IMPLICATION(curr_dst_tag != dat_tag_nxc, + dst_d.format_kind() == format_kind::any) + && utils::one_of( + dat_tag_nxc, curr_src_tag, curr_dst_tag); + dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; + wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, + OIdhw16i16o, gOIdhw16i16o); + break; + } + case sve_256: { + const auto dat_tag_nCx8c + = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + const auto curr_src_tag = src_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx8c); + const auto curr_dst_tag = dst_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx8c); + const auto is_data_layout_nxc + = IMPLICATION(curr_src_tag != dat_tag_nxc, + src_d.format_kind() == format_kind::any) + && IMPLICATION(curr_dst_tag != dat_tag_nxc, + dst_d.format_kind() == format_kind::any) + && utils::one_of( + dat_tag_nxc, curr_src_tag, curr_dst_tag); + dat_tag = is_data_layout_nxc ? 
dat_tag_nxc : dat_tag_nCx8c; + wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o, OIdhw8i8o, + gOIdhw8i8o); + break; + } + default: break; + } + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + status_t copy(const pd_t &other) { + jcp_ = other.jcp_; + rtus_ = other.rtus_; + if (other.dw_conv_pd_) { + dw_conv_pd_.reset(other.dw_conv_pd_->clone()); + if (!dw_conv_pd_) return status::out_of_memory; + } + return status::success; + } + + status_t depthwise_po_init(engine_t *engine) { + + using namespace memory_tracking; + auto &jcp_1x1 = jcp_; + primitive_attr_t attr_1x1(*attr()); + if (!attr_1x1.is_initialized()) return status::out_of_memory; + const auto &src_md = dst_md_; + const memory_desc_wrapper src_d(src_md); + const auto nthr = dnnl_get_max_threads(); + auto l2_cache = platform::get_per_core_cache_size(2) * nthr; + + // Note: A robust fusion implementation would check whether both the + // 1x1 conv and the dw conv considered here for fusion are optimal + // independently. This would require creating a new primitive_desc + // through the primitive iterator and checking that they match. + // Since these creations and/or checks could be heavy, + // for 1x1: check that no better ISA is available; + // for dw: always fuse with the same ISA. + // Caveat: a better dw conv may exist. + + // TODO: Add a check for whether a better ISA exists, following the + // note above. + bool ok = true + && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1) + // TODO: the threshold below may be further tuned. + && (l2_cache * 2 < src_d.size()) + // The load_grp_count check may be redundant due to the l2 check + // above; it is kept explicit because the current driver doesn't + // work if this condition fails. + && (jcp_1x1.load_grp_count < 2); + if (!ok) return status::unimplemented; + + int dw_po_index + = attr_1x1.post_ops_.find(primitive_kind::convolution); + convolution_desc_t cd_dw; + primitive_attr_t attr_dw; + CHECK(get_depthwise_conv_desc( + cd_dw, src_md, attr_1x1, attr_dw, dw_po_index)); + + // The code below doesn't work because it currently requires the + // `jcp_` member, which is not available from the common interface. + // In turn, this means the common pd creation interface through an + // iterator can't be used and a specific convolution implementation's + // pd is required here. It restricts the usage of the inherited + // `convolution_pd_t` constructor. + // ANCHOR: USING_INHERITED_IS_IMPOSSIBLE. + // + // ```cpp + // primitive_desc_iterator_t it( + // engine, (op_desc_t *)&cd_dw, &attr_dw, nullptr); + // if (!it.is_initialized()) return status::out_of_memory; + // while (++it != it.end()) { + // dw_conv_pd_ = *it; + // break; + // } + // VDISPATCH_CONV_IC(dw_conv_pd_, "dw_conv_pd hasn't been created"); + // ``` + // + // ```compiler output + // error: ‘using element_type = struct dnnl::impl::primitive_desc_t’ + // {aka ‘struct dnnl::impl::primitive_desc_t’} has no member named + // ‘jcp_’ + // auto &jcp_dw = dw_conv_pd_->jcp_; + // ^~~~ + // ``` + // + // TODO: figure out a way to initialize the fused conv through a + // normal interface without hacks that access specific members. 
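+            // Editor's note - the workaround below, reduced to a minimal
+            // sketch (dw_pd_t is the concrete
+            // jit_sve_512_dw_convolution_fwd_t::pd_t aliased above, not a
+            // generic primitive_desc_t):
+            //
+            //   std::unique_ptr<dw_pd_t> pd(
+            //           new dw_pd_t(&cd_dw, &attr_dw, nullptr));
+            //   CHECK(pd->init(engine)); // may return unimplemented
+            //   auto &jcp_dw = pd->jcp_; // legal only on the concrete type
+            //
+            // i.e. the specific implementation's pd is constructed directly
+            // so that its jcp_ member stays visible, at the cost of
+            // bypassing the implementation iterator.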
+ CHECK(safe_ptr_assign( + dw_conv_pd_, new dw_pd_t(&cd_dw, &attr_dw, nullptr))); + CHECK(dw_conv_pd_->init(engine)); + auto &jcp_dw = dw_conv_pd_->jcp_; + + ok = true + && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0))) + && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0) + && IMPLICATION( + jcp_dw.ow_block, jcp_dw.ow_block == jcp_dw.ow); + if (!ok) return status::unimplemented; + + assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any); + assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any); + assert(IMPLICATION( + dw_conv_pd_->weights_md(1)->data_type != data_type::undef, + dw_conv_pd_->weights_md(1)->format_kind + != format_kind::any)); + + jcp_dw.is_fused_conv = true; + // TODO: Support/experiment with arbitrary oc_work in dw conv. + // Until then, keep oc_work perfectly divisible. + while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0) + --jcp_1x1.nb_load_blocking; + jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking; + + while (jcp_1x1.nb_load_blocking % jcp_dw.nb_ch_blocking != 0) + --jcp_dw.nb_ch_blocking; + + jcp_dw.dw_conv_buffer_oc + = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block; + + const auto dat_tag_nxc = utils::pick(ndims() - 3, format_tag::nwc, + format_tag::nhwc, format_tag::ndhwc); + const bool is_data_nxc = utils::everyone_is( + dat_tag_nxc, jcp_1x1.src_tag, jcp_1x1.dst_tag); + if (!is_data_nxc) + jcp_1x1.bcast_loop_output_step = jcp_1x1.ur * jcp_1x1.load_block + * jcp_1x1.typesize_out; + + registrar_t scratchpad(scratchpad_registry_); + registrar_t dw_scratchpad(scratchpad, names::prefix_fusion); + + size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw.kh * jcp_dw.iw + * jcp_dw.dw_conv_buffer_oc; + assert(dw_conv_buffer_size_); + dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer, + dw_conv_buffer_size_, + types::data_type_size(dw_conv_pd_->src_md()->data_type)); + + jit_uni_dw_conv_fwd_kernel::init_scratchpad( + dw_scratchpad, jcp_dw); + + return status::success; + } + }; + + template + friend status_t init_rtus_driver(conv_t *self); + + jit_sve_1x1_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} + + typedef typename prec_traits_t::type src_data_t; + typedef typename prec_traits_t::type wei_data_t; + typedef typename prec_traits_t::type dst_data_t; + + status_t init(engine_t *engine) override { + CHECK(safe_ptr_assign(kernel_, + new jit_sve_1x1_conv_kernel( + pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); + CHECK(kernel_->create_kernel()); + + if (pd()->jcp_.with_dw_conv) { + CHECK(safe_ptr_assign( + kernel_dw_, new dw_conv_kernel_t(pd()->dw_conv_pd_->jcp_))); + CHECK(kernel_dw_->create_kernel()); + } + CHECK(init_rtus_driver(this)); + return status::success; + } + + status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src, const wei_data_t *weights, + const dst_data_t *bias, const wei_data_t *weights_dw, + const dst_data_t *bias_dw, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad, + const void *post_ops_binary_rhs_arg_vec, + const void *post_ops_binary_rhs_arg_vec_dw) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + + std::unique_ptr> kernel_; + std::unique_ptr> rtus_driver_; + using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel_f32; + std::unique_ptr kernel_dw_; +}; + +using jit_sve_256_1x1_convolution_fwd_f32_t + = jit_sve_1x1_convolution_fwd_t; 
+using jit_sve_512_1x1_convolution_fwd_f32_t + = jit_sve_1x1_convolution_fwd_t; + +template +struct jit_sve_1x1_convolution_bwd_data_t : public primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t; + + DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", isa_, ""), + jit_sve_1x1_convolution_bwd_data_t); + + status_t init(engine_t *engine) { + bool ok = true && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, + data_type::undef, diff_dst_type, data_type::undef) + && attr()->has_default_values() && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *diff_src_d = diff_src_md(); + rtus_prepare(this, conv_d, diff_src_d, diff_dst_md()); + + status_t status = jit_sve_1x1_conv_kernel::init_conf(jcp_, + *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), + *attr(), dnnl_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_sve_1x1_conv_kernel::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad, jcp_.nthr); + + return status::success; + } + + // TODO (Roma): structs conf header cleanup + jit_1x1_conv_conf_t jcp_ = utils::zero(); + reduce_to_unit_stride_t rtus_ = utils::zero(); + + protected: + bool set_default_formats() { + using namespace format_tag; + + const memory_desc_wrapper diff_src_d(&diff_src_md_); + const memory_desc_wrapper diff_dst_d(&diff_dst_md_); + + const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc); + format_tag_t dat_tag, wei_tag; + + switch (isa_) { + case sve_512: { + const auto dat_tag_nCx16c = utils::pick( + ndims() - 3, nCw16c, nChw16c, nCdhw16c); + const auto curr_src_tag = diff_src_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx16c); + const auto curr_dst_tag = diff_dst_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx16c); + const auto is_data_layout_nxc + = IMPLICATION(curr_src_tag != dat_tag_nxc, + diff_src_d.format_kind() + == format_kind::any) + && IMPLICATION(curr_dst_tag != dat_tag_nxc, + diff_dst_d.format_kind() + == format_kind::any) + && utils::one_of( + dat_tag_nxc, curr_src_tag, curr_dst_tag); + dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; + wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i, + IOdhw16o16i, gIOdhw16o16i); + break; + } + case sve_256: { + const auto dat_tag_nCx8c + = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + const auto curr_src_tag = diff_src_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx8c); + const auto curr_dst_tag = diff_dst_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx8c); + const auto is_data_layout_nxc + = IMPLICATION(curr_src_tag != dat_tag_nxc, + diff_src_d.format_kind() + == format_kind::any) + && IMPLICATION(curr_dst_tag != dat_tag_nxc, + diff_dst_d.format_kind() + == format_kind::any) + && utils::one_of( + dat_tag_nxc, curr_src_tag, curr_dst_tag); + dat_tag = is_data_layout_nxc ? 
dat_tag_nxc : dat_tag_nCx8c; + wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + IOw8o8i, gIOw8o8i, IOhw8o8i, gIOhw8o8i, IOdhw8o8i, + gIOdhw8o8i); + break; + } + default: break; + } + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend status_t init_rtus_driver(conv_t *self); + + jit_sve_1x1_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} + + typedef typename prec_traits_t::type diff_dst_data_t; + typedef typename prec_traits_t::type wei_data_t; + typedef typename prec_traits_t::type diff_src_data_t; + + status_t init(engine_t *engine) override { + CHECK(safe_ptr_assign(kernel_, + new jit_sve_1x1_conv_kernel( + pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); + CHECK(kernel_->create_kernel()); + CHECK(init_rtus_driver(this)); + return status::success; + } + + status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + std::unique_ptr> kernel_; + std::unique_ptr> rtus_driver_; +}; +using jit_sve_256_1x1_convolution_bwd_data_f32_t + = jit_sve_1x1_convolution_bwd_data_t; + +using jit_sve_512_1x1_convolution_bwd_data_f32_t + = jit_sve_1x1_convolution_bwd_data_t; + +/* Backward weight */ +template +struct jit_sve_1x1_convolution_bwd_weights_t : public primitive_t { + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + using cpu_convolution_bwd_weights_pd_t:: + cpu_convolution_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", isa_, ""), + jit_sve_1x1_convolution_bwd_weights_t); + + status_t init(engine_t *engine) { + bool ok = true && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && attr()->has_default_values() && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) { return status::unimplemented; } + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, diff_dst_md()); + + status_t status = jit_sve_1x1_conv_kernel::init_conf(jcp_, + *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), + *attr(), dnnl_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_sve_1x1_conv_kernel::init_scratchpad(scratchpad, jcp_); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + rtus_prepare_space_info(this, scratchpad, jcp_.nthr); + + return status::success; + } + + // TODO (Roma): structs conf header cleanup + jit_1x1_conv_conf_t jcp_ = utils::zero(); + typename cpu_reducer_t::conf_t reducer_bia_conf_; + reduce_to_unit_stride_t rtus_ = utils::zero(); + + protected: + bool set_default_formats() { + using namespace format_tag; + + const memory_desc_wrapper src_d(&src_md_); + const memory_desc_wrapper diff_dst_d(&diff_dst_md_); + + const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc); + + format_tag_t dat_tag, wei_tag; + + switch (isa_) { + case sve_512: { + auto dat_tag_nCx16c = utils::pick( + ndims() - 3, nCw16c, nChw16c, nCdhw16c); + const auto curr_src_tag = src_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx16c); + 
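+                    // Editor's note on the layout selection below: nxc is
+                    // chosen only when no memory descriptor contradicts it,
+                    // i.e. every md that does not already match nxc was left
+                    // as format_kind::any and at least one md really matches
+                    // nxc. For example, src = nhwc with diff_dst = any picks
+                    // nxc, while src = nChw16c forces the blocked layout.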
const auto curr_dst_tag = diff_dst_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx16c); + const auto is_data_layout_nxc + = IMPLICATION(curr_src_tag != dat_tag_nxc, + src_d.format_kind() == format_kind::any) + && IMPLICATION(curr_dst_tag != dat_tag_nxc, + diff_dst_d.format_kind() + == format_kind::any) + && utils::one_of( + dat_tag_nxc, curr_src_tag, curr_dst_tag); + + dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; + wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, + OIdhw16i16o, gOIdhw16i16o); + break; + } + case sve_256: { + const auto dat_tag_nCx8c + = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + const auto curr_src_tag = src_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx8c); + const auto curr_dst_tag = diff_dst_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx8c); + const auto is_data_layout_nxc + = IMPLICATION(curr_src_tag != dat_tag_nxc, + src_d.format_kind() == format_kind::any) + && IMPLICATION(curr_dst_tag != dat_tag_nxc, + diff_dst_d.format_kind() + == format_kind::any) + && utils::one_of( + dat_tag_nxc, curr_src_tag, curr_dst_tag); + + dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c; + wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o, OIdhw8i8o, + gOIdhw8i8o); + break; + } + default: break; + } + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + + private: + void init_balancers() { + const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_load, jcp_.mb, + max_buffer_size, true)); + } + } + }; + + template + friend status_t init_rtus_driver(conv_t *self); + + jit_sve_1x1_convolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {} + + typedef typename prec_traits_t::type data_t; + + status_t init(engine_t *engine) override; + + status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + + std::unique_ptr> kernel_; + std::unique_ptr> acc_ker_; + std::unique_ptr> reducer_bias_; + // std::unique_ptr trans_kernel_; + std::unique_ptr> rtus_driver_; +}; + +using jit_sve_256_1x1_convolution_bwd_weights_t + = jit_sve_1x1_convolution_bwd_weights_t; + +using jit_sve_512_1x1_convolution_bwd_weights_t + = jit_sve_1x1_convolution_bwd_weights_t; + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/aarch64/jit_sve_512_1x1_conv_kernel.cpp b/src/cpu/aarch64/jit_sve_512_1x1_conv_kernel.cpp deleted file mode 100644 index 827b8904633..00000000000 --- a/src/cpu/aarch64/jit_sve_512_1x1_conv_kernel.cpp +++ /dev/null @@ -1,1333 +0,0 @@ -/******************************************************************************* -* Copyright 2021-2023 Intel Corporation -* Copyright 2021-2023 FUJITSU LIMITED -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include - -#include "common/c_types_map.hpp" -#include "common/dnnl_thread.hpp" -#include "common/memory.hpp" -#include "common/memory_tracking.hpp" -#include "common/nstl.hpp" -#include "common/type_helpers.hpp" -#include "common/utils.hpp" - -#include "cpu/aarch64/cpu_barrier.hpp" -#include "cpu/platform.hpp" - -#include "cpu/aarch64/injectors/injector_utils.hpp" -#include "cpu/aarch64/injectors/jit_uni_binary_injector.hpp" -#include "cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp" -#include "cpu/aarch64/jit_sve_512_1x1_conv_kernel.hpp" -#include "cpu/aarch64/jit_uni_1x1_conv_utils.hpp" - -#define GET_OFF(field) \ - static_cast(offsetof(jit_1x1_conv_call_s, field)) - -namespace dnnl { -namespace impl { -namespace cpu { -namespace aarch64 { - -using namespace dnnl::impl::format_tag; -using namespace dnnl::impl::prop_kind; -using namespace dnnl::impl::utils; - -jit_sve_512_1x1_conv_kernel::jit_sve_512_1x1_conv_kernel( - const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, - const memory_desc_t &dst_md) - : jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { - using namespace binary_injector; - static constexpr bool preserve_gpr = true; - static constexpr bool preserve_vmm = false; - static constexpr size_t helper_vmm_idx = 31; - const size_t tail_size = jcp.oc_without_padding % isa_simd_width_; - static constexpr bool use_exact_tail_scalar_bcast = true; - - const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, - x14, x15, x13, preserve_gpr, preserve_vmm, - GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), - memory_desc_wrapper(dst_md), tail_size, k_load_dim_mask, - use_exact_tail_scalar_bcast}; - const static_params_t static_params { - this->param1, rhs_arg_static_params}; - - postops_injector_ = utils::make_unique< - injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); - } -} - -void jit_sve_512_1x1_conv_kernel::bcast_loop(int load_loop_blk) { - - mov(aux1_reg_bcast_data, reg_bcast_data); - mov(aux_reg_bcast_data, reg_bcast_data); - mov(aux_reg_output_data, reg_output_data); - ldr(reg_bcast_loop_iter, ptr(X_SP, reg_bcast_loop_work_offt)); - - Label bcast_loop; - Label bcast_loop_tail; - Label large_tail; - - cmp_imm(reg_bcast_loop_iter, jcp.bcast_block, reg_tmp_imm); - b(LT, bcast_loop_tail); - - L(bcast_loop); - { - assert(jcp.bcast_block % jcp.ur == 0); - int num_substeps = jcp.bcast_block / jcp.ur; - assert(num_substeps > 0 && num_substeps < 10); - for (int i = 0; i < num_substeps; i++) { - if (i + 1 == num_substeps) L(large_tail); - reduce_loop(load_loop_blk, jcp.ur, i, false); - if (i < num_substeps - 1) { - add_imm(aux1_reg_bcast_data, aux1_reg_bcast_data, - jcp.bcast_loop_bcast_substep, reg_tmp_imm); - add_imm(aux_reg_output_data, aux_reg_output_data, - jcp.bcast_loop_output_substep, reg_tmp_imm); - } else { - add_imm(aux1_reg_bcast_data, aux1_reg_bcast_data, - jcp.bcast_loop_bcast_step - - (num_substeps - 1) - * jcp.bcast_loop_bcast_substep, - reg_tmp_imm); - add_imm(aux_reg_output_data, aux_reg_output_data, - jcp.bcast_loop_output_step - - (num_substeps - 1) - * jcp.bcast_loop_output_substep, - reg_tmp_imm); - } - subs_imm(reg_bcast_loop_iter, reg_bcast_loop_iter, jcp.ur, - reg_tmp_imm); - } - cmp_imm(reg_bcast_loop_iter, jcp.bcast_block, reg_tmp_imm); - b(GE, bcast_loop); - } - - L(bcast_loop_tail); - if 
(jcp.ur_tail) { - Label bcast_loop_tail_out; - if (jcp.ur_tail >= jcp.ur) { - cmp_imm(reg_bcast_loop_iter, jcp.ur, reg_tmp_imm); - b(GE, large_tail); - } - if (jcp.ur_tail % jcp.ur) { - cmp(reg_bcast_loop_iter, 0); - b(LE, bcast_loop_tail_out); - reduce_loop(load_loop_blk, jcp.ur_tail % jcp.ur, 0, true); - L(bcast_loop_tail_out); - } - } -} - -Xbyak_aarch64::XReg jit_sve_512_1x1_conv_kernel::output_ptr( - const bool is_out_layout_nxc, const int i_load, const int i_ur, - Xbyak_aarch64::XReg addr) { - if (one_of(jcp.prop_kind, forward_training, forward_inference, - backward_data)) { - int i_load_shift = is_out_layout_nxc - ? jcp.load_block - : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block; - int i_ur_shift = is_out_layout_nxc ? jcp.load_dim : jcp.load_block; - int offset = (i_load * i_load_shift + i_ur * i_ur_shift) - * jcp.typesize_out; - EVEX_compress_addr(addr, X_TMP_0, aux_reg_output_data, offset); - } else { - int offset = jcp.typesize_out * jcp.load_block * i_ur; - mov(X_TMP_0, i_load); - mul(X_TMP_0, reg_output_stride, X_TMP_0); - add_imm(X_TMP_1, X_TMP_0, offset, X_TMP_2); - add(addr, aux_reg_output_data, X_TMP_1); - } - return addr; -} - -static int vreg_accum_idx( - const int load_loop_blk, const int i_load, const int i_ur) { - return (i_ur * load_loop_blk + i_load); -} - -template -static void iterate(const int load_loop_blk, const int ur, const bool mask_tail, - const F &fun) { - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - const bool mask_flag = mask_tail && i_load + 1 == load_loop_blk; - for (int i_ur = 0; i_ur < ur; ++i_ur) - fun(mask_flag, i_load, i_ur); - } -} -template -static void iterate(const int load_loop_blk, const int ur, const F &fun) { - iterate(load_loop_blk, ur, false, fun); -} - -void jit_sve_512_1x1_conv_kernel::apply_postops( - const bool is_out_layout_nxc, const int load_loop_blk, const int ur) { - injector_utils::vmm_index_set_t vmm_idxs; - if (jcp.with_binary) { - binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; - const auto mask_tail = jcp.oc_without_padding % jcp.load_block; - iterate(load_loop_blk, ur, mask_tail, - [&](const bool mask_flag, const int i_load, const int i_ur) { - const auto vmm_idx - = vreg_accum_idx(load_loop_blk, i_load, i_ur); - vmm_idxs.emplace(vmm_idx); - - rhs_arg_params.vmm_idx_to_out_reg.emplace( - vmm_idx, aux_reg_output_data); - rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, - get_output_offset(is_out_layout_nxc, i_load, i_ur)); - if (mask_flag) - rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); - }); - - ldr(abi_param1, ptr(X_SP, reg_abi_param1_backup)); - - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); - } else { - iterate(load_loop_blk, ur, - [&](const bool, const int i_load, const int i_ur) { - vmm_idxs.emplace( - vreg_accum_idx(load_loop_blk, i_load, i_ur)); - }); - postops_injector_->compute_vector_range(vmm_idxs); - } -} - -void jit_sve_512_1x1_conv_kernel::reduce_loop( - int load_loop_blk, int ur, int substep, bool wraparound) { - - const bool out_layout_nxc = is_out_layout_nxc(jcp); - const bool load_layout_nxc = is_load_layout_nxc(jcp); - const bool bcast_layout_nxc = is_bcast_layout_nxc(jcp); - const int reduce_dim_tail = jcp.reduce_dim % jcp.reduce_block; - const int load_dim_tail = jcp.load_dim % jcp.load_block; - - auto vreg_load - = [=](int i_load) { return ZReg(ur * load_loop_blk + i_load); }; - - auto vreg_accum = [=](int i_load, int i_ur) { - return ZReg(vreg_accum_idx(load_loop_blk, i_load, i_ur)); - }; - - auto bias_ptr = [=](int i_load) { - 
return EVEX_compress_addr(X_DEFAULT_ADDR, X_TMP_0, reg_bias_data, - jcp.typesize_out * jcp.oc_block * i_load); - }; - - auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast, - const Xbyak_aarch64::XReg addr, - const Xbyak_aarch64::XReg tmp) { - assert(i_ur < jcp.ur); - assert(i_reduce <= jcp.reduce_loop_unroll); - int offt; - if (one_of(jcp.prop_kind, forward_training, forward_inference, - backward_data)) { - assert(jcp.reduce_loop_unroll == jcp.reduce_block); - const int reduce_mul = bcast_layout_nxc ? jcp.reduce_dim - : jcp.reduce_loop_unroll; - offt = (i_reduce == jcp.reduce_loop_unroll) - ? (jcp.bcast_dim + i_ur) * reduce_mul - : i_ur * reduce_mul + i_reduce; - } else { - int rmul = bcast_layout_nxc ? jcp.ic : jcp.ic_block; - offt = i_reduce * rmul + i_ur; - } - return EVEX_compress_addr( - addr, tmp, aux_reg_bcast_data, jcp.typesize_in * offt, bcast); - }; - - auto load_ptr = [=](int i_reduce, int i_load, - const Xbyak_aarch64::XReg addr, - const Xbyak_aarch64::XReg tmp) { - int offt; - int u0 = i_reduce % jcp.reduce_loop_unroll; - int u1 = i_reduce / jcp.reduce_loop_unroll; - int lmul = jcp.load_block - * (load_layout_nxc ? 1 - : utils::rnd_up( - jcp.reduce_dim, jcp.reduce_block)); - int rmul = load_layout_nxc ? jcp.load_dim : jcp.load_block; - offt = i_load * lmul + u0 * rmul; - return EVEX_compress_addr(addr, tmp, aux_reg_load_data, - u1 * jcp.reduce_loop_load_step + jcp.typesize_in * offt); - }; - - auto init = [=]() { - Label init_done; - Label init_zero; - - if (jcp.with_bias - && one_of(jcp.prop_kind, forward_training, forward_inference)) { - tst(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); - b(EQ, init_zero); - - for (int i_load = 0; i_load < load_loop_blk; i_load++) - for (int i_ur = 0; i_ur < ur; ++i_ur) { - auto vreg_acc = vreg_accum(i_load, i_ur); - if (i_load + 1 == load_loop_blk && load_dim_tail) - ld1w(vreg_acc.s, k_load_dim_mask / T_z, - ptr(bias_ptr(i_load))); - else - ldr(vreg_acc, ptr(bias_ptr(i_load))); - } - b(init_done); - } - - L(init_zero); - - /* Zero clear */ - for (int i_load = 0; i_load < load_loop_blk; ++i_load) - for (int i_ur = 0; i_ur < ur; ++i_ur) { - auto r = vreg_accum(i_load, i_ur); - eor(r.d, r.d, r.d); - } - L(init_done); - }; - - auto store = [=]() { - Label store_noadd; - if (!jcp.with_sum) { - tst(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); - b(NE, store_noadd); - } - - for (int i_ur = 0; i_ur < ur; ++i_ur) - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - auto r = vreg_accum(i_load, i_ur).s; - if (i_load + 1 == load_loop_blk && load_dim_tail) - ld1w(zreg_tmp.s, k_load_dim_mask / T_z, - ptr(output_ptr(out_layout_nxc, i_load, i_ur, - X_DEFAULT_ADDR))); - else - ldr(zreg_tmp, - ptr(output_ptr(out_layout_nxc, i_load, i_ur, - X_DEFAULT_ADDR))); - fadd(r, r, zreg_tmp.s); - } - - L(store_noadd); - if (jcp.with_eltwise || jcp.with_binary) { - Label store_nopostops; - tst(reg_reduce_pos_flag, FLAG_REDUCE_LAST); - b(EQ, store_nopostops); - - apply_postops(out_layout_nxc, load_loop_blk, ur); - - L(store_nopostops); - } - - auto store_output = [=](bool output_is_aligned) { - const auto mask_flag = load_dim_tail; - for (int i_ur = 0; i_ur < ur; ++i_ur) { - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - auto vreg_acc = vreg_accum(i_load, i_ur); - // for nxc_layout-bwd_w, weights are still padded and the - // output_ptr here can be uninitialized scratchpad. - // To ensure final output (after reduction) is zero-padded, - // here we zero-pad output by omitting the mask. 
- if (jcp.prop_kind != backward_weights - && (i_load + 1 == load_loop_blk && mask_flag)) { - st1w(vreg_acc.s, k_load_dim_mask / T_z, - ptr(output_ptr(out_layout_nxc, i_load, i_ur, - X_DEFAULT_ADDR))); - } else { - str(vreg_acc, - ptr(output_ptr(out_layout_nxc, i_load, i_ur, - X_DEFAULT_ADDR))); - } - } - } - }; - - Label unaligned_store, end_store; - tst(aux_reg_output_data, cpu_isa_traits::vlen - 1); - b(NE, unaligned_store); - store_output(true); - b(end_store); - L(unaligned_store); - { store_output(false); } - L(end_store); - }; - - auto fma_block = [=](bool last_block) { - const int i_reduce_end = reduce_dim_tail && last_block - ? reduce_dim_tail - : jcp.reduce_loop_unroll; - - for (int i_reduce = 0; i_reduce < i_reduce_end; i_reduce++) { - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - auto vreg = vreg_load(i_load); - if (i_load + 1 == load_loop_blk && load_dim_tail) - ld1w(vreg.s, k_load_dim_mask / T_z, - ptr(load_ptr(i_reduce, i_load, X_DEFAULT_ADDR, - X_TMP_0))); - else - ldr(vreg, - ptr(load_ptr(i_reduce, i_load, X_DEFAULT_ADDR, - X_TMP_0))); - } - - for (int i_ur = 0; i_ur < ur; ++i_ur) { - if (jcp.expl_bcast && load_loop_blk > 1) { - ldr(W_TMP_0, - ptr(bcast_ptr(i_reduce, i_ur, false, X_DEFAULT_ADDR, - X_TMP_1))); - dup(vreg_bcast.s, W_TMP_0); - } - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - auto vreg_acc = vreg_accum(i_load, i_ur); - if (i_load + 1 == load_loop_blk && load_dim_tail) { - ld1rw(zreg_tmp.s, P_ALL_ONE, - ptr(bcast_ptr(i_reduce, i_ur, true, - X_DEFAULT_ADDR, X_TMP_0))); - fmla(vreg_acc.s, k_load_dim_mask / T_m, - vreg_load(i_load).s, zreg_tmp.s); - } else if (jcp.expl_bcast && load_loop_blk > 1) { - fmla(vreg_acc.s, P_ALL_ONE / T_m, vreg_load(i_load).s, - vreg_bcast.s); - } else { - ld1rw(zreg_tmp.s, P_ALL_ONE, - ptr(bcast_ptr(i_reduce, i_ur, true, - X_DEFAULT_ADDR, X_TMP_0))); - fmla(vreg_acc.s, P_ALL_ONE / T_m, vreg_load(i_load).s, - zreg_tmp.s); - } - } - } - } - }; - - Label reduce_loop; - Label reduce_loop_tail; - - mov(aux_reg_load_data, reg_load_data); - - mov(aux_reg_bcast_data, aux1_reg_bcast_data); - init(); - - mov(reduce_loop_iter, reg_reduce_loop_work); - subs_imm(reduce_loop_iter, reduce_loop_iter, jcp.reduce_loop_unroll, - reg_tmp_imm); - b(LE, reduce_loop_tail); - - L(reduce_loop); - { - fma_block(false); - add_imm(aux_reg_bcast_data, aux_reg_bcast_data, - jcp.reduce_loop_bcast_step, reg_tmp_imm); - add_imm(aux_reg_load_data, aux_reg_load_data, jcp.reduce_loop_load_step, - reg_tmp_imm); - subs_imm(reduce_loop_iter, reduce_loop_iter, jcp.reduce_loop_unroll, - reg_tmp_imm); - b(GT, reduce_loop); - } - - L(reduce_loop_tail); - fma_block(true); - - store(); -} - -void jit_sve_512_1x1_conv_kernel::generate() { - preamble(); - - sub_imm(X_SP, X_SP, stack_space_needed, X_TMP_0); - if (jcp.with_binary) { - const auto zeroed_reg = x15; - eor(zeroed_reg, zeroed_reg, zeroed_reg); - str(zeroed_reg, ptr(X_SP, reg_binary_post_op_acc_off)); - str(param1, ptr(X_SP, reg_abi_param1_backup)); - } - - /* Pointers indicate weight, input, and output data */ - ldr(reg_bcast_data, ptr(abi_param1, GET_OFF(bcast_data))); // Input - ldr(reg_load_data, ptr(abi_param1, GET_OFF(load_data))); // Weight - ldr(reg_output_data, ptr(abi_param1, GET_OFF(output_data))); // Output - - /* Pointer indicates bias data if the layer has bias option */ - if (jcp.with_bias) ldr(reg_bias_data, ptr(abi_param1, GET_OFF(bias_data))); - - /* Get workloads of each loop */ - ldr(reg_load_loop_work, ptr(abi_param1, GET_OFF(load_dim))); - ldr(reg_bcast_loop_work, 
ptr(abi_param1, GET_OFF(bcast_dim))); - str(reg_bcast_loop_work, ptr(X_SP, reg_bcast_loop_work_offt)); - ldr(reg_reduce_loop_work, ptr(abi_param1, GET_OFF(reduce_dim))); - - /* A flag for controlling reduce loop */ - ldr(reg_reduce_pos_flag, ptr(abi_param1, GET_OFF(first_last_flag))); - if (jcp.prop_kind == backward_weights) - ldr(reg_output_stride, ptr(param1, GET_OFF(output_stride))); - - const int load_dim_tail - = (one_of(jcp.prop_kind, forward_training, forward_inference) - ? jcp.oc_without_padding - : jcp.load_dim) - % jcp.load_block; - if (load_dim_tail) { - const WReg w_tmp(reg_load_dim_tail_mask.getIdx()); - mov_imm(w_tmp, (1 << load_dim_tail) - 1); - str(zreg_tmp1, ptr(X_TRANSLATOR_STACK, -1, MUL_VL)); - index(zreg_tmp.s, 0, 1); - mov(zreg_tmp1.s, 1); - lsl(zreg_tmp1.s, P_ALL_ONE / T_m, zreg_tmp.s); - dup(zreg_tmp.s, w_tmp); - and_(zreg_tmp.d, zreg_tmp.d, zreg_tmp1.d); - cmpne(k_load_dim_tail_mask.s, P_ALL_ONE, zreg_tmp.s, 0); - ldr(zreg_tmp1, ptr(X_TRANSLATOR_STACK, -1, MUL_VL)); - } - - auto load_loop_body = [=](int load_loop_blk) { - if (load_dim_tail) { - eor(k_load_dim_mask.b, P_ALL_ONE / T_z, k_load_dim_mask.b, - k_load_dim_mask.b); - not_(k_load_dim_mask.b, P_ALL_ONE / T_z, k_load_dim_mask.b); - } - subs_imm(reg_load_loop_work, reg_load_loop_work, - load_loop_blk * jcp.load_loop_iter_step, reg_tmp_imm); - if (load_dim_tail) { - Label no_update_mask; - b(GE, no_update_mask); - mov(k_load_dim_mask.b, k_load_dim_tail_mask.b); - L(no_update_mask); - } - bcast_loop(load_loop_blk); - add_imm(reg_load_data, reg_load_data, - load_loop_blk * jcp.load_loop_load_step, reg_tmp_imm); - switch (jcp.prop_kind) { - case forward_training: - case forward_inference: - add_imm(reg_bias_data, reg_bias_data, - load_loop_blk * jcp.load_block * jcp.typesize_out, - reg_tmp_imm); - add_imm(reg_output_data, reg_output_data, - load_loop_blk * jcp.load_block * jcp.typesize_out - * (is_out_layout_nxc(jcp) - ? 1 - : (jcp.with_dw_conv - ? jcp.ow - : jcp.bcast_dim)), - reg_tmp_imm); - if (jcp.with_binary) { - const auto oc_off_oprnd = aux_reg_load_data; - ldr(oc_off_oprnd, ptr(X_SP, reg_binary_post_op_acc_off)); - add_imm(oc_off_oprnd, oc_off_oprnd, - jcp.load_block * load_loop_blk, X_TMP_0); - str(oc_off_oprnd, ptr(X_SP, reg_binary_post_op_acc_off)); - } - break; - case backward_data: - add_imm(reg_output_data, reg_output_data, - load_loop_blk * jcp.load_block * jcp.typesize_out - * (is_out_layout_nxc(jcp) ? 1 : jcp.bcast_dim), - reg_tmp_imm); - break; - case backward_weights: - for (int i_load = 0; i_load < load_loop_blk; i_load++) - add(reg_output_data, reg_output_data, reg_output_stride); - break; - default: assert(!"invalid prop_kind"); - } - }; - - const int simd_w = cpu_isa_traits::vlen / sizeof(float); - - Label load_loop_blk[7]; - - // with an implicit load_loop_block {6, 5, 4, 3, 2, 1} - static const int ur_cases_fma_embd_bcast[] = {2, 4, 5, 8, 14, 32}; - static const int ur_cases_fma_expl_bcast[] = {2, 5, 6, 9, 14, 32}; - - const int size_ur_cases_fma = jcp.expl_bcast - ? sizeof(ur_cases_fma_expl_bcast) - : sizeof(ur_cases_fma_embd_bcast); - - const int *ur_cases_fma = jcp.expl_bcast ? 
ur_cases_fma_expl_bcast - : ur_cases_fma_embd_bcast; - const int *ur_cases = ur_cases_fma; - const int num_ur_cases = size_ur_cases_fma / sizeof(*ur_cases); - - for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { - int label_idx = num_ur_cases - ur_idx - 1; - if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) { - cmp_imm(reg_load_loop_work, simd_w * (label_idx + 1), reg_tmp_imm); - b(LE, load_loop_blk[label_idx]); - } - } - - for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { - int label_idx = num_ur_cases - ur_idx - 1; - if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) { - L(load_loop_blk[label_idx]); - { - if (label_idx == 0) { - cmp(reg_load_loop_work, 0); - b(LE, load_loop_blk[num_ur_cases]); - } - load_loop_body(label_idx + 1); - if (label_idx - 1 > 0) { - cmp_imm(reg_load_loop_work, 2 * label_idx * simd_w, - reg_tmp_imm); - b(EQ, load_loop_blk[label_idx - 1]); - } - cmp_imm(reg_load_loop_work, label_idx * simd_w, reg_tmp_imm); - b(GT, load_loop_blk[label_idx]); - } - for (int idx = label_idx - 1; idx >= 0; --idx) { - cmp_imm(reg_load_loop_work, simd_w * (idx + 1), reg_tmp_imm); - b(GE, load_loop_blk[idx]); - } - if (ur_idx < num_ur_cases - 2) { - cmp_imm(reg_load_loop_work, simd_w, reg_tmp_imm); - b(LE, load_loop_blk[0]); - } - } - } - L(load_loop_blk[num_ur_cases]); - - add_imm(X_SP, X_SP, stack_space_needed, X_TMP_0); - - postamble(); - if (jcp.with_eltwise) postops_injector_->prepare_table(); -} - -status_t jit_sve_512_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, - const primitive_attr_t &attr, int nthreads, bool reduce_src) { - - /* arch check */ - if (!mayiuse(sve_512)) return status::unimplemented; - - if (!everyone_is(data_type::f32, src_d.data_type(), weights_d.data_type(), - dst_d.data_type())) - return status::unimplemented; - - jcp.nthr = nthreads; - - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - const int simd_w = cpu_isa_traits::vlen / sizeof(float); - const int ndims = src_d.ndims(); - /* Forward_[training, inference], backward_[data, weight] */ - jcp.prop_kind = cd.prop_kind; - - /* Check group option */ - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - /* Batchsize */ - jcp.mb = src_d.dims()[0]; - /* Channel */ - jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; - jcp.oc = jcp.oc_without_padding; - jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; - jcp.ic = jcp.ic_without_padding; - /* D, H, W */ - jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; - jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2]; - jcp.iw = src_d.dims()[ndims - 1]; - jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; - jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2]; - jcp.ow = dst_d.dims()[ndims - 1]; - /* Kernel size */ - jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; - jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; - jcp.kw = weights_d.dims()[with_groups + ndims - 1]; - /* padding params */ - jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; - jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4]; - jcp.l_pad = cd.padding[0][ndims - 3]; - /* stride params */ - jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; - jcp.stride_h = (ndims == 3) ? 
1 : cd.strides[ndims - 4]; - jcp.stride_w = cd.strides[ndims - 3]; - /* bias info */ - jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind, - format_kind::undef, cd.diff_bias_desc.format_kind) - != format_kind::undef; - - /* Spatials */ - jcp.os = jcp.od * jcp.oh * jcp.ow; - jcp.is = jcp.id * jcp.ih * jcp.iw; - - /* Depthwise conv check */ - const auto &post_ops = attr.post_ops_; - const int dw_conv_ind = post_ops.find(primitive_kind::convolution); - jcp.with_dw_conv = dw_conv_ind != -1; - if (jcp.with_dw_conv) return status::unimplemented; - - /* Post operation check */ - // Using dw_conv_ind as upper-bound below, as post-ops after it will be - // handled in depthwise convolution. - const int eltwise_ind - = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) { - if (dst_d.data_type() == data_type::s32) return status::unimplemented; - } - - const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind); - jcp.with_sum = sum_ind != -1; - - const int binary_ind - = post_ops.find(primitive_kind::binary, 0, dw_conv_ind); - jcp.with_binary = binary_ind != -1; - - if (dw_conv_ind >= 0) { - // dw_conv and post_ops after it are handled externally, so skip them - jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(), - post_ops.entry_.cbegin() + dw_conv_ind); - } else { - jcp.post_ops = post_ops; - } - - /* Data format check */ - const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc); - const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); - jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); - bool is_data_layout_nxc - = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); - auto required_dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; - - /* Channel padding check */ - bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1 - && src_d.data_type() == data_type::f32; - - /* Input and output must be multiple of simd_w */ - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, simd_w); - jcp.ic = rnd_up(jcp.ic, simd_w); - } - - using namespace injector; - - static constexpr bool sum_at_pos_0_only = true; - static constexpr bool sum_requires_scale_one = true; - static constexpr bool sum_requires_zp_zero = true; - const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(jcp.isa, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, - sum_requires_scale_one, sum_requires_zp_zero)); - if (!post_ops_ok_) return status::unimplemented; - - bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == required_dat_tag - && jcp.dst_tag == required_dat_tag - && IMPLICATION(!is_data_layout_nxc, - jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0) - && jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0 - && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1 - && jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1 && jcp.ow == jcp.iw - && jcp.oh == jcp.ih && jcp.od == jcp.id; // enforce rpad=0 - if (!args_ok) return status::unimplemented; - - /* Channel blocking size is simd_w */ - jcp.ic_block = jcp.oc_block = simd_w; - - jcp.ver = ver_sve_512; - if (everyone_is(data_type::f32, src_d.data_type(), weights_d.data_type(), - dst_d.data_type())) { - const int is_bwd_d = jcp.prop_kind == backward_data; - /* Set weight data layout tag */ - format_tag_t wei_tag = with_groups - ? 
pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i,
- gOIhw16i16o, gIOhw16o16i, gOIdhw16i16o, gIOdhw16o16i)
- : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i,
- OIhw16i16o, IOhw16o16i, OIdhw16i16o, IOdhw16o16i);
-
- jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
- if (jcp.wei_tag != wei_tag) return status::unimplemented;
-
- // jcp.fma_step = 1;
- jcp.typesize_in = sizeof(prec_traits::type);
- jcp.typesize_out = sizeof(prec_traits::type);
- } else {
- // TODO: currently, only fp32 is supported
- return status::unimplemented;
- }
-
- /* once all the formats are set, check the padding consistency */
- if (!is_data_layout_nxc) {
- args_ok = true && jcp.ic <= src_d.padded_dims()[1]
- && jcp.oc <= dst_d.padded_dims()[1]
- && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
- && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
- if (!args_ok) return status::unimplemented;
- }
- // TODO: Optimize below params
- const int SMALL_SPATIAL = 10;
- const int BIG_SPATIAL = 65;
- const int BIG_REDUCE_DIM = 1024;
- const int BIG_LOAD_DIM = (jcp.reduce_dim >= 512) ? 256 : 512;
-
- int load_blocking {0};
- int load_blocking_max {0};
- int bcast_blocking {0};
- int bcast_blocking_max {0};
- int reduce_blocking {0};
- int reduce_blocking_max {0};
-
- jcp.load_grp_count = 1;
-
- // TODO: move check funcs into platform files
- const int L1_capacity
- = platform::get_per_core_cache_size(1) / sizeof(float);
- const int L2_size = platform::get_per_core_cache_size(2) / sizeof(float);
- const int L2_capacity = (L2_size * 3) / 4;
-
- /* FWD, BWD data */
- if (one_of(jcp.prop_kind, forward_training, forward_inference,
- backward_data)) {
-
- if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
- /* Forward */
- if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur);
- jcp.reduce_dim = jcp.ic; // src channel
- jcp.reduce_block = jcp.ic_block; // src simd_w
-
- jcp.load_dim = jcp.oc; // dst channel
- jcp.load_block = jcp.oc_block; // dst simd_w
-
- jcp.bcast_dim = jcp.is; // src H*W
- } else {
- /* Backward data */
- jcp.reduce_dim = jcp.oc; // src channel
- jcp.reduce_block = jcp.oc_block; // src simd_w
-
- jcp.load_dim = jcp.ic; // dst channel
- jcp.load_block = jcp.ic_block; // dst simd_w
-
- jcp.bcast_dim = jcp.os; // src H*W
- }
-
- /* # of consecutive channel elements */
- jcp.reduce_loop_unroll = jcp.reduce_block;
-
- /* Offset to move to the next 16 input channel elements with the same H*W position */
- jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll
- * (is_data_layout_nxc ? 1 : jcp.bcast_dim) * jcp.typesize_in;
-
- /* Offset: 16o*16i (filter) */
- jcp.reduce_loop_load_step
- = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
-
- /* Offset: I/16 * 16o */
- jcp.load_loop_load_step
- = (utils::rnd_up(jcp.reduce_dim, jcp.reduce_block))
- * jcp.load_block * jcp.typesize_in;
-
- /* adjusting register blocking */
- int max_regs, min_regs, size_threshold;
-
- /* spatial : H*D of dst */
- const int spatial
- = (one_of(jcp.prop_kind, forward_training, forward_inference))
- ?
jcp.od * jcp.oh // forward - : jcp.id * jcp.ih; // backward - - if ((8 * jcp.mb) / jcp.nthr >= 1 - // NHWC perf: RN50 mb=1 - || (is_data_layout_nxc && jcp.mb == 1)) { - max_regs = 9; // max # of ur_w - min_regs = 6; // min # of ur_w - size_threshold = 14; - jcp.expl_bcast = true; - - /* - * H*D of dst > SMALL_SPATIAL - */ - if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM - && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL - && jcp.reduce_dim < 256) { - max_regs = 6; - min_regs = 5; - } - } else { - max_regs = 30; - min_regs = 9; - size_threshold = 14; - jcp.expl_bcast = false; - jcp.use_vmovntps = true; - } - jcp.ur = 1; - - for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) { - /* - * H*D of dst >= size_threshold, (H*D of dst) % ur_w == 0 - * or - * H*D of dst < size_threshold, (H*W of dst) % ur_w == 0 - */ - if ((spatial >= size_threshold && spatial % ur_w == 0) - || (spatial < size_threshold && jcp.os % ur_w == 0)) { - jcp.ur = ur_w; - break; - } - } - - if (jcp.ur == 1) { - // If ur = 1, then min(max_regs, H*W of dst) - jcp.ur = nstl::min(max_regs, jcp.os); - int os_tail = jcp.os % max_regs; - for (int i = max_regs; i >= min_regs; i--) { - int i_tail = jcp.os % i; - if (i_tail > os_tail || i_tail == 0) { - jcp.ur = i; - os_tail = i_tail; - if (i_tail == 0) break; - } - } - } - jcp.bcast_block = jcp.ur; // block size of bcast (input data) - /* Number of steps for the dst address to output, used in bcast_loop() */ - jcp.bcast_loop_output_step = jcp.ur * jcp.typesize_out - * (is_data_layout_nxc ? jcp.load_dim : jcp.load_block); - jcp.bcast_loop_output_substep = -1; // unused - - /* Number of steps for the src address to be broadcasted in bcast_loop() */ - jcp.bcast_loop_bcast_step = jcp.ur * jcp.typesize_in - * (is_data_layout_nxc ? jcp.reduce_dim : jcp.reduce_block); - jcp.bcast_loop_bcast_substep = -1; // unused - - jcp.load_loop_iter_step = jcp.load_block; - - if (jcp.prop_kind == backward_data) - jcp.loop_order = loop_lbr; - else - jcp.loop_order = reduce_src ? loop_blr : loop_lbr; - - int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); - int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); - int nb_load = div_up(jcp.load_dim, jcp.load_block); - if (is_data_layout_nxc) { - reduce_blocking = jcp.reduce_dim; - } else if (jcp.expl_bcast) { - if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL - && spatial < BIG_SPATIAL) { - reduce_blocking = nstl::min(jcp.reduce_dim, 80); - } else if (spatial > SMALL_SPATIAL) - reduce_blocking = nstl::min(jcp.reduce_dim, 512); - else - reduce_blocking = nstl::min(jcp.reduce_dim, 256); - } else { - reduce_blocking = nb_reduce; - if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) - reduce_blocking = 16; - else if (spatial > SMALL_SPATIAL - && jcp.reduce_dim >= BIG_REDUCE_DIM) - reduce_blocking = 8; - reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); - reduce_blocking *= jcp.reduce_block; - } - - // Check input data cache aliasing. - // For other ISA constants may be updated. - // 64 * 1024 is chosen due to 1MB L2 16-way cache. - // 7 is empirical value. It is about half of 16. 
- // So we leave about half of the set for other data - weights, dst - int way_size = (16 * 1024) / jcp.typesize_in; - int max_hits = 7; - if (!is_data_layout_nxc - && jcp.bcast_dim * reduce_blocking > way_size * max_hits) { - int nrb = reduce_blocking / simd_w; - int sp = jcp.bcast_dim; - int wl = way_size / simd_w; - for (int start_off = 0; start_off < jcp.ur; start_off++) { - for (int off = start_off, hits = 0; off < sp * nrb; off += wl) { - if (off % sp >= jcp.ur || ++hits < max_hits) continue; - int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp); - reduce_blocking - = nstl::min(reduce_blocking, max_r_blocking); - break; - } - } - } - - if (reduce_blocking < jcp.reduce_dim) { - if (jcp.prop_kind == backward_data) - jcp.loop_order = reduce_src ? loop_lbr : loop_rlb; - else - jcp.loop_order = reduce_src ? loop_rbl : loop_rlb; - } - load_blocking = jcp.load_dim; - - /* Number of weight elements to be loaded for dest */ - int load_size = jcp.load_dim * jcp.reduce_dim; - /* Number of elements to be broadcasted from src */ - auto bcast_size - = (dim_t)jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim; - - /* 12 cores per CMG */ - if (jcp.nthr <= 12 && jcp.mb < jcp.nthr - && nb_load * nb_bcast > jcp.nthr) { - // Some heuristic here - float calc_koef = 0.01, best_cost = FLT_MAX; - int n_lgc = jcp.nthr; - float ratio = (float)load_size / (float)bcast_size; - int best_lgc = ratio > 1 ? n_lgc : 1; - auto calc_job_cost = [&](int lb, int tg, float mem_k) { - int bb_size = jcp.mb * div_up(nb_bcast, tg); - float calc_size = (float)(bb_size * jcp.ur) - * (lb * jcp.load_block) * jcp.reduce_dim; - float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block) - * jcp.reduce_dim; - return calc_koef * calc_size + mem_k * mem_size; - }; - for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) { - lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1; - int min_lb = nb_load / lgc; - int max_lb = div_up(nb_load, lgc); - int min_tg = jcp.nthr / lgc; - int max_tg = div_up(jcp.nthr, lgc); - // Some heuristic here - float mem_koef = (max_tg == 1) ? 
1.f : 1.3f;
- float job_cost = 0.;
- if (jcp.nthr % lgc < nb_load % lgc) {
- job_cost = calc_job_cost(max_lb, min_tg, mem_koef);
- } else {
- auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef);
- auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef);
- job_cost = nstl::max(job_cost1, job_cost2);
- }
-
- if (job_cost < best_cost) {
- best_lgc = lgc;
- best_cost = job_cost;
- }
- }
- jcp.load_grp_count = best_lgc;
- load_blocking
- = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
- } else {
- jcp.load_grp_count
- = div_up(jcp.nthr, jcp.mb * jcp.ngroups * nb_bcast);
- jcp.load_grp_count = best_divider(jcp.nthr, jcp.load_grp_count,
- 2 * jcp.load_grp_count, false);
- }
- if (jcp.expl_bcast && jcp.bcast_dim <= 64 && load_size >= L2_size) {
- jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
- } else if (jcp.bcast_dim <= 49 && jcp.mb <= jcp.nthr
- && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
- jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
- load_blocking = jcp.load_block;
- }
-
- auto get_thr_eff = [=](int load_chunk, int nthr) {
- int lgc = div_up(nb_load, load_chunk);
- int thr_per_grp = div_up(nthr, lgc);
- int bcast_per_thr
- = div_up(jcp.mb * nb_bcast, thr_per_grp) * jcp.bcast_block;
- int load_per_thr = load_chunk * simd_w;
- float data_norm = (bcast_per_thr + load_per_thr) / 2.f;
- float data_eff
- = (bcast_per_thr * load_per_thr) / (data_norm * data_norm);
- float thr_eff_over_grp
- = (float)nstl::max(1, nthr / lgc) / div_up(nthr, lgc);
- float thr_eff_in_grp = ((float)jcp.mb * nb_bcast)
- / rnd_up(jcp.mb * nb_bcast, thr_per_grp);
- float thr_eff = thr_eff_over_grp * thr_eff_in_grp;
- float load_eff = (float)nb_load / rnd_up(nb_load, lgc);
- float overall_eff = data_eff + thr_eff + load_eff;
- return overall_eff;
- };
-
- auto get_load_chunk = [=](int nthr) {
- float best_eff = -1.0f;
- int best_lgc = 1;
- float eff;
-
- for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) {
- int lgc = div_up(nb_load, load_chunk);
- if (lgc > nthr) continue;
- eff = get_thr_eff(load_chunk, nthr);
- if (eff > best_eff) {
- best_eff = eff;
- best_lgc = lgc;
- }
- }
- return best_lgc;
- };
-
- /* adjust the thread decomposition
- * to improve the thr_eff for small problem size
- * the threshold 8192 is empirical
- * TODO: Threshold can be increased for init stride > 1*/
- if (sizeof(float) * bcast_size < 8192 && jcp.mb < jcp.nthr
- && nb_load * nb_bcast < jcp.nthr) {
- float best_thr_eff = -1.0f;
- float thr_eff = -1.0f;
- int overall_lgc = jcp.load_grp_count;
- int lgc = 1;
- int best_nthr = jcp.nthr;
- int end_nthr = with_groups ?
jcp.ngroups : 1; - for (int nthr = jcp.nthr / 2; nthr >= end_nthr; nthr--) { - lgc = get_load_chunk(nthr); - thr_eff = get_thr_eff(lgc, nthr); - if (best_thr_eff < thr_eff) { - best_thr_eff = thr_eff; - overall_lgc = lgc; - best_nthr = nthr; - } - } - jcp.nthr = best_nthr; - jcp.load_grp_count = overall_lgc; - load_blocking - = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; - } - - bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, - div_up(jcp.nthr, jcp.load_grp_count)) - * jcp.bcast_block; - bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking); - bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); - - int space_for_bcast = (L2_capacity - /* kernel_size - */ - 2 * jcp.load_block * reduce_blocking - jcp.ur * reduce_blocking - - 3 * 1024); - if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) space_for_bcast /= 2; - - int bcast_in_cache - = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); - bcast_blocking = nstl::min( - bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); - // NHWC perf - if (is_data_layout_nxc) bcast_blocking = jcp.bcast_block; - - load_blocking_max = load_blocking; - bcast_blocking_max = bcast_blocking * 3 / 2; - reduce_blocking_max = reduce_blocking; - - jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur; - - } else if (jcp.prop_kind == backward_weights) { /* BWD weight */ - - jcp.reduce_dim = jcp.is; - - jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true); - if (jcp.reduce_dim % jcp.reduce_block != 0) - jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false); - if (jcp.reduce_block > 256) { jcp.reduce_block = 1; } - - jcp.load_dim = jcp.oc; - jcp.load_block = jcp.oc_block; - - jcp.bcast_dim = jcp.ic; - jcp.bcast_block = jcp.ic_block; - - if (jcp.reduce_block <= 19 && - // maskrcnn optimization for nxc; don't reduce ur when ocb<=1 - !(is_data_layout_nxc && jcp.load_dim <= jcp.load_block)) { - // if reduce_block is big then generated JIT code may be big - // for small values of ur because reduce_loop_unroll = reduce_block - jcp.ur = jcp.bcast_block / 2; - jcp.expl_bcast = true; - } else { - jcp.ur = jcp.bcast_block; - jcp.expl_bcast = false; - } - - jcp.ur_tail = jcp.bcast_dim % jcp.bcast_block; - jcp.reduce_loop_unroll = jcp.reduce_block; - jcp.reduce_loop_bcast_step = jcp.typesize_in * jcp.reduce_loop_unroll - * (is_data_layout_nxc ? jcp.ic : jcp.ic_block); - jcp.reduce_loop_load_step = jcp.typesize_in * jcp.reduce_loop_unroll - * (is_data_layout_nxc ? jcp.oc : jcp.oc_block); - - jcp.bcast_loop_output_step - = jcp.oc_block * jcp.ic_block * jcp.typesize_out; - jcp.bcast_loop_output_substep - = jcp.oc_block * jcp.ur * jcp.typesize_out; - jcp.bcast_loop_bcast_step = jcp.ic_block - * (is_data_layout_nxc ? 1 - : utils::rnd_up( - jcp.reduce_dim, jcp.reduce_block)) - * jcp.typesize_in; - jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in; - - jcp.load_loop_load_step = jcp.typesize_in * jcp.oc_block - * (is_data_layout_nxc ? 
1 : jcp.os);
- jcp.load_loop_iter_step = jcp.oc_block;
-
- /* --- */
- balance(jcp);
-
- load_blocking = div_up(jcp.load_dim, jcp.load_block);
- load_blocking = best_divider(load_blocking, 16, load_blocking, false);
- load_blocking *= jcp.load_block;
-
- load_blocking_max = load_blocking;
- assert(IMPLICATION(
- !is_data_layout_nxc, jcp.load_dim % load_blocking == 0));
-
- int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
- int min_bcast_blocking = 5;
-
- bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
- bcast_blocking = best_divider(
- bcast_blocking, min_bcast_blocking, max_bcast_blocking, false);
- bcast_blocking *= jcp.bcast_block;
- bcast_blocking_max = bcast_blocking;
- assert(IMPLICATION(
- !is_data_layout_nxc, jcp.bcast_dim % bcast_blocking == 0));
-
- // for reduction balance
- if (is_data_layout_nxc && jcp.reduce_dim >= BIG_SPATIAL * BIG_SPATIAL
- && jcp.load_dim >= BIG_LOAD_DIM / 2) {
- reduce_blocking = rnd_up(nstl::min(jcp.ow, 256), jcp.reduce_block);
- } else {
- int max_reduce_blocking
- = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim);
- int min_reduce_blocking = nstl::min(
- L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih));
- reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking,
- max_reduce_blocking, true);
- reduce_blocking
- = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block),
- jcp.reduce_block);
- }
-
- reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block);
- } else
- return status::unimplemented;
-
- assert(load_blocking);
- assert(load_blocking_max);
- assert(bcast_blocking);
- assert(bcast_blocking_max);
- assert(reduce_blocking);
- assert(reduce_blocking_max);
-
- if (!is_data_layout_nxc) {
- assert(load_blocking % jcp.load_block == 0);
- assert(reduce_blocking % jcp.reduce_block == 0);
- assert(load_blocking_max % jcp.load_block == 0);
- assert(reduce_blocking_max % jcp.reduce_block == 0);
- assert(jcp.reduce_dim % jcp.reduce_block == 0);
- }
-
- assert(jcp.bcast_block % jcp.ur == 0);
-
- jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
- jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
- jcp.nb_load_blocking = utils::div_up(load_blocking, jcp.load_block);
- jcp.nb_load_blocking_max = utils::div_up(load_blocking_max, jcp.load_block);
- jcp.nb_reduce_blocking = utils::div_up(reduce_blocking, jcp.reduce_block);
- jcp.nb_reduce_blocking_max
- = utils::div_up(reduce_blocking_max, jcp.reduce_block);
-
- jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
- jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
- jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
-
- return status::success;
-}
-
-void jit_sve_512_1x1_conv_kernel::init_scratchpad(
- memory_tracking::registrar_t &scratchpad,
- const jit_1x1_conv_conf_t &jcp) {
-
- using namespace dnnl::impl::memory_tracking::names;
-
- // For nxc layout bias is padded only for bwd_wb direction, as bias
- reduction kernels can't handle tails yet.
- if (jcp.with_bias && jcp.prop_kind != backward_data - && (jcp.oc != jcp.oc_without_padding // blocked layout - || (jcp.prop_kind == backward_weights // nxc layout - && jcp.oc % jcp.oc_block != 0))) { - - const size_t nelems_padded_bias - = jcp.ngroups * utils::rnd_up(jcp.oc, jcp.oc_block); - scratchpad.book( - key_conv_padded_bias, nelems_padded_bias, jcp.typesize_out); - } - - if (jcp.prop_kind == backward_weights) { - const size_t wei_size = (size_t)jcp.ngroups - * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block); - scratchpad.book(key_conv_wei_reduction, wei_size * (jcp.nthr_mb - 1), - jcp.typesize_out); - } -} - -/* BWD W*/ -void jit_sve_512_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp) { - int nthreads = jcp.nthr; - // initialize jcp reduction threading properties - jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1; - if (nthreads < jcp.ngroups) { - /* simplification... fortunately it doesn't hurt much */ - return; - } - const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); - const int nb_load = div_up(jcp.load_dim, jcp.load_block); - const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); - - jcp.nthr_g = jcp.ngroups; - const int nthr = nthreads / jcp.nthr_g; - - auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { - /* calculate per thread memory cost (read/write). high level - * optimizer tries to minimize memory consumption. few notes: (n1) - * unclear why, but that essentially helps first convolution... - * (n2) assuming the reduction over minibatch is always there: - * - instead of 8 it should be 5 here (write ~= 2 read): - * kernel: temporal workspace 1 write - * reduction: 1 read from workspace and 1 write to the diff_wei - * - but experiments showed 8 works better than 5 or 6... 
*/ - int bcast_koeff = 1; - int load_koeff = 1; - int output_koeff = 12; - return 0 - + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) - * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_bcast, nthr_ic_b) - * jcp.ic_block * jcp.reduce_block / jcp.stride_h - / jcp.stride_w /* (n1) */ - + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) - * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b) - * jcp.oc_block * jcp.reduce_block - + (size_t)output_koeff /* (n2) */ - * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b) - * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.oc_block; - }; - - int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1; - auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); - - /* step 1: find the best thread distribution with lowest memory cost */ - const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce); - for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { - const int nthr_par = nthr / nthr_mb; - const int nthr_oc_b_max = nstl::min(nthr_par, nb_load); - for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { - nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast); - auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); - if (mem_cost <= best_mem_cost) { - best_mem_cost = mem_cost; - jcp.nthr_mb = nthr_mb; - jcp.nthr_oc_b = nthr_oc_b; - jcp.nthr_ic_b = nthr_ic_b; - } - } - } - if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads) - jcp.nthr_mb = nstl::min(jcp.mb, nthreads); - - jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b; - assert(jcp.nthr <= nthreads); -} - -} // namespace aarch64 -} // namespace cpu -} // namespace impl -} // namespace dnnl diff --git a/src/cpu/aarch64/jit_sve_512_1x1_conv_kernel.hpp b/src/cpu/aarch64/jit_sve_512_1x1_conv_kernel.hpp deleted file mode 100644 index 2d41be54911..00000000000 --- a/src/cpu/aarch64/jit_sve_512_1x1_conv_kernel.hpp +++ /dev/null @@ -1,204 +0,0 @@ -/******************************************************************************* -* Copyright 2021-2023 Intel Corporation -* Copyright 2021-2023 FUJITSU LIMITED -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/
-
-#ifndef CPU_AARCH64_JIT_SVE_1x1_CONV_KERNEL_HPP
-#define CPU_AARCH64_JIT_SVE_1x1_CONV_KERNEL_HPP
-
-#include "common/c_types_map.hpp"
-#include "common/memory_tracking.hpp"
-
-#include "cpu/aarch64/injectors/jit_uni_postops_injector.hpp"
-#include "cpu/aarch64/jit_generator.hpp"
-#include "cpu/aarch64/jit_op_imm_check.hpp"
-#include "cpu/aarch64/jit_primitive_conf.hpp"
-
-using namespace Xbyak_aarch64;
-
-namespace dnnl {
-namespace impl {
-namespace cpu {
-namespace aarch64 {
-
-/* Get vector offsets, ofs / VL(VL: 512bits = 64Bytes) */
-#define VL64_OFS(ofs) ((ofs) >> 6)
-
-struct jit_sve_512_1x1_conv_kernel : public jit_generator {
- jit_sve_512_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp,
- const primitive_attr_t &attr, const memory_desc_t &dst_md);
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sve_512_1x1_conv_kernel)
-
- static status_t init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
- int nthreads, bool reduce_src);
-
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_1x1_conv_conf_t &jcp);
-
- jit_1x1_conv_conf_t jcp;
- const primitive_attr_t &attr_;
-
-private:
- using reg64_t = const XReg;
-
- /* Flags and loop variables */
- reg64_t reg_reduce_pos_flag = x1;
- reg64_t reduce_loop_iter = x2;
- reg64_t reg_bcast_loop_iter = x3;
- reg64_t reg_relu_ns = x20; // For forward
- reg64_t reg_output_stride = x20; // For backward
-
- /* Pointer */
- reg64_t reg_bcast_data = x5; // Input
- reg64_t reg_load_data = x6; // Weight
- reg64_t reg_output_data = x7; // Output
- reg64_t reg_bias_data = x8; // bias
- reg64_t aux1_reg_bcast_data = x9;
- reg64_t aux_reg_output_data = x10;
- reg64_t aux_reg_bcast_data = x11;
- reg64_t aux_reg_load_data = x12;
- reg64_t reg_prev_bcast_addr
- = x13; // Input: The reg keeps addr accessed by previous ldr inst
- reg64_t reg_prev_out_addr
- = x14; // Output: The reg keeps addr accessed by previous ldr or str inst
-
- /* Workload */
- reg64_t reg_load_loop_work = x15;
- reg64_t reg_reduce_loop_work = x16;
- reg64_t reg_bcast_loop_work = x17;
-
- /* Temporary registers */
- reg64_t reg_tmp_imm = x27; // tmp for add_imm
- reg64_t reg_tmp_ofs = x19; // tmp reg to calc bwd wei offset in out_load
-
- reg64_t reg_load_dim_tail_mask = aux_reg_load_data;
-
- std::unique_ptr>
- postops_injector_;
-
- constexpr static int isa_simd_width_
- = cpu_isa_traits::vlen / sizeof(float);
-
- ZReg vreg_bcast = ZReg(31);
- PReg k_load_dim_mask = p2;
- PReg k_load_dim_tail_mask = p3;
- ZReg zreg_tmp = ZReg(31);
- ZReg zreg_tmp1 = ZReg(30);
-
- constexpr static int reg64_size_ = sizeof(int64_t);
- constexpr static int reg_bcast_loop_work_offt = 0;
- constexpr static int reg_binary_post_op_acc_off = 1 * reg64_size_;
- constexpr static int reg_abi_param1_backup = 2 * reg64_size_;
- constexpr static int stack_space_needed = 3 * reg64_size_;
-
- template
- Xbyak_aarch64::XReg EVEX_compress_addr(const Xbyak_aarch64::XReg &addr,
- const Xbyak_aarch64::XReg &x_tmp, Xbyak_aarch64::XReg base,
- T raw_offt, bool bcast = false) {
-
- assert(raw_offt <= INT_MAX);
- auto offt = static_cast(raw_offt);
-
- add_imm(addr, base, offt, x_tmp);
- if (bcast) {
- // addr is the same as when bcast is false.
- } - return addr; - } - - void prefetch( - const std::string prfop, int level, reg64_t in, long long int ofs) { - bool for_load = false; - if (prfop == "LD") { - for_load = true; - } else if (prfop == "ST") { - for_load = false; - } else { - assert(!"invalid prfop"); - } - - bool cacheline_aligned = ((ofs & 0xFF) == 0) ? true : false; - if (cacheline_aligned == true) { - Prfop op; - switch (level) { - case 1: op = (for_load == true) ? PLDL1KEEP : PSTL1KEEP; break; - case 2: op = (for_load == true) ? PLDL2KEEP : PSTL2KEEP; break; - case 3: op = (for_load == true) ? PLDL3KEEP : PSTL3KEEP; break; - default: assert(!"invalid prfop"); break; - } - - if (prfm_imm_check(ofs)) { - prfm(op, ptr(in, static_cast(ofs))); - } else { - add_imm(reg_tmp_ofs, in, ofs, reg_tmp_imm); - prfm(op, ptr(reg_tmp_ofs)); - } - } else { - PrfopSve op_sve; - switch (level) { - case 1: - op_sve = (for_load == true) ? PLDL1KEEP_SVE : PSTL1KEEP_SVE; - break; - case 2: - op_sve = (for_load == true) ? PLDL2KEEP_SVE : PSTL2KEEP_SVE; - break; - case 3: - op_sve = (for_load == true) ? PLDL3KEEP_SVE : PSTL3KEEP_SVE; - break; - default: assert(!"invalid prfop"); break; - } - - if (prfw_imm_check(ofs)) { - prfw(op_sve, P_ALL_ONE, - ptr(in, static_cast(VL64_OFS(ofs)))); - } else { - add_imm(reg_tmp_ofs, in, ofs, reg_tmp_imm); - prfw(op_sve, P_ALL_ONE, ptr(reg_tmp_ofs)); - } - } - } - - void bcast_loop(int load_loop_blk); - void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); - - void generate() override; - static void balance(jit_1x1_conv_conf_t &jcp); - - inline size_t get_output_offset( - const bool is_out_layout_nxc, const int i_load, const int i_ur) { - const size_t i_load_shift = is_out_layout_nxc - ? jcp.load_block - : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block; - const size_t i_ur_shift - = is_out_layout_nxc ? jcp.load_dim : jcp.load_block; - return jcp.typesize_out * (i_load * i_load_shift + i_ur * i_ur_shift); - } - - Xbyak_aarch64::XReg output_ptr(const bool out_layout_nxc, const int i_load, - const int i_ur, Xbyak_aarch64::XReg addr); - void apply_postops(const bool is_out_layout_nxc, const int load_loop_blk, - const int ur); -}; - -} // namespace aarch64 -} // namespace cpu -} // namespace impl -} // namespace dnnl - -#endif diff --git a/src/cpu/aarch64/jit_sve_512_1x1_convolution.cpp b/src/cpu/aarch64/jit_sve_512_1x1_convolution.cpp deleted file mode 100644 index 4d311bdb611..00000000000 --- a/src/cpu/aarch64/jit_sve_512_1x1_convolution.cpp +++ /dev/null @@ -1,1040 +0,0 @@ -/******************************************************************************* -* Copyright 2021-2023 Intel Corporation -* Copyright 2021-2023 FUJITSU LIMITED -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "common/c_types_map.hpp" -#include "common/dnnl_thread.hpp" -#include "common/type_helpers.hpp" -#include "common/utils.hpp" - -#include "cpu/aarch64/jit_generator.hpp" - -#include "cpu/aarch64/jit_sve_512_1x1_convolution.hpp" - -namespace dnnl { -namespace impl { -namespace cpu { -namespace aarch64 { - -using namespace dnnl::impl::status; -using namespace dnnl::impl::memory_tracking::names; -using namespace dnnl::impl::utils; - -#define data_blk_off(f, n, c, d, h, w) \ - ((ndims == 3) ? (f).blk_off(n, c, w) \ - : ((ndims == 4) ? (f).blk_off(n, c, h, w) \ - : (f).blk_off(n, c, d, h, w))) -/* convolution forward */ - -template -void jit_sve_512_1x1_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { - const auto &jcp = kernel_->jcp; - auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const dst_data_t *, DNNL_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST); - auto weights_dw = CTX_IN_MEM( - const wei_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); - auto bias_dw = CTX_IN_MEM( - const dst_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS); - const auto post_ops_binary_rhs_arg_vec - = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); - const auto post_ops_binary_rhs_arg_vec_dw = pd()->dw_conv_pd_ - ? binary_injector::prepare_binary_args( - pd()->dw_conv_pd_->jcp_.post_ops, ctx, - pd()->jcp_.post_ops.entry_.size() + 1) - : std::vector {}; - - auto scratchpad = ctx.get_scratchpad_grantor(); - - if (pd()->wants_padded_bias()) { - auto padded_bias - = scratchpad.template get(key_conv_padded_bias); - utils::array_copy(padded_bias, bias, jcp.oc_without_padding); - utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, - jcp.oc - jcp.oc_without_padding); - bias = padded_bias; - } - - parallel(jcp.nthr, [&](const int ithr, const int nthr) { - execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw, - dst, scratchpad, post_ops_binary_rhs_arg_vec.data(), - post_ops_binary_rhs_arg_vec_dw.data()); - }); - - if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST); -} - -template -void jit_sve_512_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, const int nthr, - const src_data_t *src, const wei_data_t *weights, - const dst_data_t *bias, const wei_data_t *weights_dw, - const dst_data_t *bias_dw, dst_data_t *dst, - const memory_tracking::grantor_t &scratchpad, - const void *post_ops_binary_rhs_arg_vec, - const void *post_ops_binary_rhs_arg_vec_dw) const { - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper dw_weights_d( - pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)); - const memory_desc_wrapper dw_bias_d( - pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)); - - const auto &jcp = kernel_->jcp; - auto rtus_space = pd()->rtus_.reduce_src_ - ? scratchpad.get(key_conv_rtus_space) - : nullptr; - - const int ndims = src_d.ndims(); - const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1; - const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4]; - const int stride_w = pd()->desc()->strides[ndims - 3]; - - auto step = [](int default_step, int remaining, int tail_step) { - assert(default_step <= tail_step); - return remaining < tail_step ? 
remaining : default_step; - }; - - auto p = jit_1x1_conv_call_s(); - - auto rp = rtus_driver_t::call_params_t(); - const int nb_oc = jcp.nb_load; - const int nb_ic = jcp.nb_reduce; - const int nb_ic_blocking = jcp.nb_reduce_blocking; - - // override some constants for fused dw_conv - const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block; - const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast; - const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking; - const int nb_bcast_blocking_max - = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max; - const int nb_load_blocking = jcp.nb_load_blocking; - const int nb_load_blocking_max = jcp.with_dw_conv - ? jcp.nb_load_blocking - : jcp.nb_load_blocking_max; - const bool is_dst_layout_nxc = utils::one_of( - jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); - const bool is_src_layout_nxc = utils::one_of( - jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); - - // Begin: declare Variables needed for dw conv. - memory_tracking::grantor_t dw_scratchpad( - scratchpad, memory_tracking::names::prefix_fusion); - dst_data_t *pbuf; - size_t row_offset; - const int nb_buffer = jcp.nb_load_blocking; - std::vector addrs; - // End - - auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g, - int &bcast_step, int &od, int &oh, int &ow, - int &id, int &ih, int &iw) { - int osb {0}; - nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast); - bcast_step = step( - nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max); - bcast_step = nstl::min(bcast_step, bcast_end - iwork); - - const int os = osb * os_block; - od = os / (jcp.oh * jcp.ow); - int os_2d = os % (jcp.oh * jcp.ow); - oh = os_2d / jcp.ow; - ow = os_2d % jcp.ow; - - id = od * stride_d; - ih = oh * stride_h; - iw = ow * stride_w; - rp.iw_start = iw; - - p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block); - rp.os = p.bcast_dim; - }; - - auto init_load = [&](int ocb, int ocb_end, int &load_step) { - load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max); - const auto max_oc - = nstl::min(ocb_end * jcp.oc_block, jcp.oc_without_padding); - p.load_dim = this_block_size( - ocb * jcp.oc_block, max_oc, load_step * jcp.oc_block); - }; - - auto init_reduce = [&](int icb) { - const int nb_ic_blocking_step - = nstl::min(icb + nb_ic_blocking, nb_ic) - icb; - p.first_last_flag = 0 | (icb == 0 ? FLAG_REDUCE_FIRST : 0) - | (icb + nb_ic_blocking_step >= nb_ic ? FLAG_REDUCE_LAST : 0); - - p.reduce_dim = this_block_size( - icb * jcp.ic_block, jcp.ic, nb_ic_blocking_step * jcp.ic_block); - rp.icb = p.reduce_dim; - }; - - auto ker_1x1 = [&](int ocb, int ocb_start, int icb, int n, int g, int od, - int oh, int ow, int id, int ih, int iw) { - const int oc_off_idx = is_dst_layout_nxc - ? g * jcp.oc + ocb * jcp.oc_block - : g * nb_oc + ocb; - const size_t dst_off = data_blk_off(dst_d, n, oc_off_idx, od, oh, ow); - - p.output_data = jcp.with_dw_conv - ? pbuf + (oh % pd()->dw_conv_pd_->jcp_.kh) * row_offset - : &dst[dst_off]; - p.bias_data = bias - ? &bias[oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block)] - : nullptr; - - p.load_data - = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb) - : weights_d.blk_off(ocb, icb)]; - const int ic_off_idx = is_src_layout_nxc - ? g * jcp.ic + icb * jcp.ic_block - : g * nb_ic + icb; - if (pd()->rtus_.reduce_src_) { - rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ - + (is_src_layout_nxc ? 
ic_off_idx - : jcp.is * ic_off_idx * jcp.ic_block); - if (ocb == ocb_start) { - rp.src = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw); - (*rtus_driver_)(&rp); - } - p.bcast_data = rp.ws; - } else - p.bcast_data = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw); - - p.oc_l_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block); - p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; - p.dst_orig = dst; - - (*kernel_)(&p); - }; - auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start, - int ocb_end) { - if (bcast_start >= bcast_end || ocb_start >= ocb_end) return; - - if (jcp.loop_order == loop_rlb) { - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - init_reduce(icb); - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, ocb_end, load_step); - int iwork = bcast_start; - while (iwork < bcast_end) { - int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, - ow {0}, id {0}, ih {0}, iw {0}; - init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, - ow, id, ih, iw); - ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, - iw); - iwork += bcast_step; - } - ocb += load_step; - } - } - } else if (jcp.loop_order == loop_lbr) { - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, ocb_end, load_step); - int iwork = bcast_start; - while (iwork < bcast_end) { - int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0}, - id {0}, ih {0}, iw {0}; - init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, - id, ih, iw); - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - init_reduce(icb); - ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, - iw); - } - iwork += bcast_step; - } - ocb += load_step; - } - } else if (jcp.loop_order == loop_rbl) { - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - init_reduce(icb); - int iwork = bcast_start; - while (iwork < bcast_end) { - int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0}, - id {0}, ih {0}, iw {0}; - init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, - id, ih, iw); - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, ocb_end, load_step); - ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, - iw); - ocb += load_step; - } - iwork += bcast_step; - } - } - } else if (jcp.loop_order == loop_blr) { - int iwork = bcast_start; - while (iwork < bcast_end) { - int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0}, - id {0}, ih {0}, iw {0}; - init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id, - ih, iw); - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, ocb_end, load_step); - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - init_reduce(icb); - ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, - iw); - } - ocb += load_step; - } - iwork += bcast_step; - } - } else { - assert(!"unsupported loop order"); - } - }; - - auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) { - auto &jcp_dw = pd()->dw_conv_pd_->jcp_; - int oh_1x1 = nstl::max(dw_oh * jcp_dw.stride_h - jcp_dw.t_pad, 0); - - for (int i = 0; i < jcp_dw.kh; ++i) - addrs[i] = pbuf + ((oh_1x1++) % jcp_dw.kh) * row_offset; - - const auto ocb_end = ocb_start + load_step; - const auto wch_stride = (is_src_layout_nxc ? 
1 : jcp_dw.iw) - * jcp_dw.nb_ch_blocking * jcp_dw.ch_block; - const int dil_h = jcp_dw.dilate_h + 1; - const int str_h = jcp_dw.stride_h; - const int ch_num = jcp_dw.nb_ch_blocking; - const int ow = 0; - const int kw = 0; - - for (int ch = ocb_start; ch < ocb_end; ch += jcp_dw.nb_ch_blocking) { - - const int i_t_overflow - = nstl::max(0, (int)(jcp_dw.t_pad - dw_oh * str_h)); - const int i_b_overflow - = nstl::max(jcp_dw.ih, - (int)(dw_oh * str_h + (jcp_dw.kh - 1) * dil_h - - jcp_dw.t_pad + 1)) - - jcp_dw.ih; - - const int kh = div_up(i_t_overflow, dil_h); - const int kh_padding = jcp_dw.kh - div_up(i_t_overflow, dil_h) - - div_up(i_b_overflow, dil_h); - - jit_conv_call_s par_conv_dw; - - par_conv_dw.src = addrs.data(); - - const size_t ch_step = is_dst_layout_nxc - ? jcp_dw.ch_block - : dst_d.blk_off(0, 1, 0, 0); - par_conv_dw.dst - = &dst[dst_d.blk_off(n, 0, dw_oh, ow) + ch * ch_step]; - - par_conv_dw.filt - = &weights_dw[dw_weights_d.blk_off(ch, 0, 0, kh, kw)]; - if (bias) - par_conv_dw.bias - = &bias_dw[dw_bias_d.blk_off(ch * jcp_dw.ch_block)]; - - par_conv_dw.kh_padding = (size_t)nstl::max(0, kh_padding); - - par_conv_dw.load_work = (nstl::min(ch + ch_num, jcp_dw.nb_ch) - ch) - * jcp_dw.ch_block; - - par_conv_dw.oc_l_off = ch * jcp_dw.ch_block; - par_conv_dw.post_ops_binary_rhs_arg_vec - = post_ops_binary_rhs_arg_vec_dw; - par_conv_dw.dst_orig = dst; - - (*kernel_dw_)(&par_conv_dw); - - for (int i = 0; i < jcp_dw.kh; ++i) - addrs[i] += wch_stride; - } - }; - - auto conv_dw = [&]() { - // Set variables - auto dw_conv_buffer - = dw_scratchpad.get(key_fusion_inout_buffer); - auto &jcp_dw = pd()->dw_conv_pd_->jcp_; - - const auto dw_conv_buffer_size_ - = (size_t)jcp_dw.kh * jcp.ow * nb_buffer * jcp.oc_block; - pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_; - row_offset = dw_conv_buffer_size_ / jcp_dw.kh; - addrs.resize(jcp_dw.kh); - - int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0}; - balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw.oh, bcast_start, - bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count); - - while (ocb_start < ocb_end) { - int load_step; - init_load(ocb_start, ocb_end, load_step); - - int oh_1x1 = 0; - auto bcast_iter = bcast_start; - while (bcast_iter < bcast_end) { - int n {0}, g {0}, oh_dw {0}; - nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw, - jcp_dw.oh); - if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary - const int oh_1x1_range = oh_dw * jcp_dw.stride_h - jcp_dw.t_pad; - const int oh_1x1_begin = nstl::max(oh_1x1_range, 0); - const int oh_1x1_end - = nstl::min(oh_1x1_range + jcp_dw.kh, jcp.oh); - oh_1x1 = nstl::max( - oh_1x1_begin, oh_1x1); // Skip rows computed previously - - // dw_spatial to 1x1 spatial conversion. 
if jcp.oh != jcp_dw.oh
- const int bcast_start_1x1
- = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
- const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;
-
- conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
- ocb_start + load_step);
- oh_1x1 = oh_1x1_end;
- ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);
-
- bcast_iter += nb_bcast_blocking;
- }
- ocb_start += load_step;
- }
- };
-
- if (jcp.with_dw_conv) {
- conv_dw();
- } else {
-
- const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
- int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
- balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load,
- ocb_start, ocb_end, jcp.load_grp_count);
-
- conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end);
- }
-}
-
-template struct jit_sve_512_1x1_convolution_fwd_t;
-
-/* convolution backward wrt data */
-template
-void jit_sve_512_1x1_convolution_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
-
- const auto &jcp = kernel_->jcp;
- auto rtus_space = pd()->rtus_.reduce_src_
- ? ctx.get_scratchpad_grantor().template get(
- key_conv_rtus_space)
- : nullptr;
-
- const int ndims = diff_src_d.ndims();
-
- assert(jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1);
-
- const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
- const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
- const int stride_w = pd()->desc()->strides[ndims - 3];
-
- const int nb_ic = jcp.nb_load;
- const int nb_oc = jcp.nb_reduce;
- const int os_block = jcp.bcast_block;
- const int nb_oc_blocking = jcp.nb_reduce_blocking;
-
- const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
-
- auto step = [](int default_step, int remaining, int tail_step) {
- assert(default_step <= tail_step);
- return remaining < tail_step ? remaining : default_step;
- };
-
- parallel(jcp.nthr, [&](const int ithr, const int nthr) {
- auto p = jit_1x1_conv_call_s();
- auto rp = rtus_driver_t::call_params_t();
-
- int bcast_start {0}, bcast_end {0}, icb_start {0}, icb_end {0};
- balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load,
- icb_start, icb_end, jcp.load_grp_count);
-
- bool reduce_outer
- = (jcp.loop_order == loop_rbl || jcp.loop_order == loop_rlb);
- int nboc_outer = reduce_outer ? nb_oc : 1;
- int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1;
-
- int nboc_inner = reduce_outer ? 1 : nb_oc;
- int ocb_inner_step = reduce_outer ?
1 : nb_oc_blocking;
- const int max_ic = nstl::min(icb_end * jcp.ic_block, jcp.ic);
-
- for (int ocb_outer = 0; ocb_outer < nboc_outer;
- ocb_outer += ocb_outer_step) {
- size_t cur_ocb_outer
- = nstl::min(ocb_outer + ocb_outer_step, nboc_outer)
- - ocb_outer;
-
- int load_step = 0;
- for (int icb = icb_start; icb < icb_end; icb += load_step) {
- load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
- jcp.nb_load_blocking_max);
-
- p.load_dim = this_block_size(
- icb * jcp.ic_block, max_ic, load_step * jcp.ic_block);
- rp.icb = p.load_dim;
- int bcast_step;
- for (int iwork = bcast_start; iwork < bcast_end;
- iwork += bcast_step) {
- int n {0}, g {0}, osb {0};
- nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
- jcp.nb_bcast);
-
- bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
- jcp.nb_bcast_blocking_max);
- bcast_step = nstl::min(bcast_step, bcast_end - iwork);
-
- const int os = osb * os_block;
- p.bcast_dim = this_block_size(
- os, jcp.os, bcast_step * os_block);
- rp.os = p.bcast_dim;
- const int od = os / (jcp.oh * jcp.ow);
- const int os_2d = os % (jcp.oh * jcp.ow);
- const int oh = os_2d / jcp.ow;
- const int ow = os_2d % jcp.ow;
- const int id = od * stride_d;
- const int ih = oh * stride_h;
- const int iw = ow * stride_w;
- rp.iw_start = iw;
- const bool is_dsrc_layout_nxc
- = utils::one_of(jcp.src_tag, format_tag::nwc,
- format_tag::nhwc, format_tag::ndhwc);
- const int ic_off_idx = is_dsrc_layout_nxc
- ? g * jcp.ic + icb * jcp.ic_block
- : g * nb_ic + icb;
- rp.src = diff_src
- + data_blk_off(
- diff_src_d, n, ic_off_idx, id, ih, iw);
- if (pd()->rtus_.reduce_src_) {
- rp.ws = rtus_space
- + ithr * pd()->rtus_.space_per_thread_;
- p.output_data = rp.ws;
- } else
- p.output_data = rp.src;
-
- for (int ocb_inner = 0; ocb_inner < nboc_inner;
- ocb_inner += ocb_inner_step) {
- int cur_ocb_inner
- = nstl::min(ocb_inner + ocb_inner_step,
- nboc_inner)
- - ocb_inner;
-
- int ocb = reduce_outer ? ocb_outer : ocb_inner;
- int nb_oc_blocking_step
- = reduce_outer ? cur_ocb_outer : cur_ocb_inner;
- const bool is_ddst_layout_nxc
- = utils::one_of(jcp.dst_tag, format_tag::nwc,
- format_tag::nhwc, format_tag::ndhwc);
- const int oc_off_idx = is_ddst_layout_nxc
- ? g * jcp.oc + ocb * jcp.oc_block
- : g * nb_oc + ocb;
- size_t diff_dst_off = data_blk_off(
- diff_dst_d, n, oc_off_idx, od, oh, ow);
- p.bcast_data = &diff_dst[diff_dst_off];
-
- p.load_data = &weights[pd()->with_groups()
- ? weights_d.blk_off(g, ocb, icb)
- : weights_d.blk_off(ocb, icb)];
-
- p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;
-
- p.reduce_dim = this_block_size(ocb * jcp.oc_block,
- jcp.oc, nb_oc_blocking_step * jcp.oc_block);
-
- (*kernel_)(&p);
- }
- if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp);
- }
- }
- }
- });
-}
-
-template struct jit_sve_512_1x1_convolution_bwd_data_t;
-
-/* convolution backward wrt weights */
-
-#define wht_blk_off(d, g, ...) \
- (pd()->with_groups() ?
(d).blk_off((g), __VA_ARGS__) \ - : (d).blk_off(__VA_ARGS__)) - -status_t jit_sve_512_1x1_convolution_bwd_weights_t ::init(engine_t *engine) { - - CHECK(safe_ptr_assign(kernel_, - new jit_sve_512_1x1_conv_kernel( - pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); - CHECK(safe_ptr_assign( - acc_ker_, new cpu_accumulator_1d_t())); - CHECK(safe_ptr_assign(reducer_bias_, - new cpu_reducer_t(pd()->reducer_bia_conf_))); - CHECK(kernel_->create_kernel()); - CHECK(acc_ker_->create_kernel()); - CHECK(reducer_bias_->create_kernel()); - - CHECK(init_rtus_driver(this)); - return status::success; -} - -void jit_sve_512_1x1_convolution_bwd_weights_t::execute_backward_weights( - const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); - auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); - auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS); - auto diff_bias_in = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); - - const auto &jcp = kernel_->jcp; - - const auto scratchpad = ctx.get_scratchpad_grantor(); - auto rtus_space = pd()->rtus_.reduce_src_ - ? scratchpad.get(key_conv_rtus_space) - : NULL; - const bool is_bias_padded - = pd()->with_bias() && jcp.oc_without_padding % jcp.oc_block != 0; - - data_t *diff_bias = is_bias_padded - ? scratchpad.get(key_conv_padded_bias) - : diff_bias_in; - auto wei_reduction = scratchpad.get(key_conv_wei_reduction); - - const int ndims = src_d.ndims(); - const int wei_size = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block) - * rnd_up(jcp.ic, jcp.ic_block); - - simple_barrier::ctx_t reduction_barrier; - simple_barrier::ctx_init(&reduction_barrier); - - const auto reducer_bia_scratchpad - = memory_tracking::grantor_t(scratchpad, prefix_reducer_bia); - auto rb = this->reducer_bias_.get(); - rb->init(reducer_bia_scratchpad); - - // TODO (Roma): remove this restriction - assert(jcp.stride_w == 1 && jcp.stride_h == 1); - - const int nb_ic = jcp.nb_bcast; - const int nb_ic_blocking = jcp.nb_bcast_blocking; - - const int nb_oc = jcp.nb_load; - const int nb_oc_blocking = jcp.nb_load_blocking; - - const int sp_nb = jcp.nb_reduce; - const int mb_sp_work = jcp.mb * sp_nb; - - const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; - const int stride_w = pd()->desc()->strides[ndims - 3]; - - auto step = [](int default_step, int remaining, int tail_step) { - assert(default_step <= tail_step); - return remaining < tail_step ? remaining : default_step; - }; - - const bool is_src_layout_nxc = utils::one_of( - jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); - - const bool is_ddst_layout_nxc = utils::one_of( - jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); - - auto maybe_zero_icpad = [&](const int g_start, const int g_end, - const int ocb_start, const int ocb_end) { - // write zeros to IC padded region. 
- const int ic_tail = jcp.ic_without_padding % jcp.ic_block; - if (is_ddst_layout_nxc && ic_tail != 0) { - for_(int g = g_start; g < g_end; ++g) - for (int z_ocb = ocb_start; z_ocb < ocb_end; ++z_ocb) { - const int z_icb = nb_ic - 1; - const size_t off = wht_blk_off(diff_weights_d, g, z_ocb, z_icb) - + ic_tail * jcp.oc_block; - data_t *z_wei = diff_weights + off; - const int zero_work - = (nb_ic * jcp.ic_block - jcp.ic_without_padding) - * jcp.oc_block; - PRAGMA_OMP_SIMD() - for (int o = 0; o < zero_work; ++o) { - z_wei[o] = 0; - } - } - } - }; - - auto ker = [&](const int ithr, const int nthr) { - assert(nthr == jcp.nthr); - - const int ithr_ic_b = ithr % jcp.nthr_ic_b; - const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; - const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; - const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g; - - /* reduction dimension */ - int mb_sp_b_start {0}, mb_sp_b_end {0}; - balance211( - mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start, mb_sp_b_end); - - /* independent dimensions */ - int g_start {0}, oc_b_start {0}, ic_b_start {0}; - int g_end {0}, oc_b_end {0}, ic_b_end {0}; - - balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); - balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, oc_b_end); - balance211( - jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, ic_b_end); - - const int g_work = g_end - g_start; - const int oc_b_work = oc_b_end - oc_b_start; - const int ic_b_work = ic_b_end - ic_b_start; - const bool cache_aliasing - = (jcp.ic * jcp.ngroups * sizeof(float)) % 1024 == 0; - int reduce_step = jcp.nb_reduce_blocking; - int reduce_step_max = jcp.nb_reduce_blocking_max; - if (is_src_layout_nxc && cache_aliasing) { - // Experiments show 4 is a magic number with the tested shapes. - // TODO: maybe tune for shapes with sp_dim%4 != 0 - reduce_step = nstl::min(4, reduce_step); - reduce_step_max = reduce_step; - } - - data_t *diff_wei = ithr_mb == 0 - ? diff_weights - : wei_reduction + (ithr_mb - 1) * wei_size; - - int sp_b_step = 0; - for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end; - mb_sp_b += sp_b_step) { - int img {0}, sp_b {0}; - nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb); - sp_b_step = step(reduce_step, - nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b), - reduce_step_max); - - for (int g = g_start; g < g_end; ++g) { - int load_step = 0; - int bcast_step = 0; - for (int ic_b = ic_b_start; ic_b < ic_b_end; - ic_b += bcast_step) { - if (is_src_layout_nxc && cache_aliasing) { - bcast_step = ic_b_work; - } else { - bcast_step = step(nb_ic_blocking, ic_b_end - ic_b, - jcp.nb_bcast_blocking_max); - } - - for (int oc_b = oc_b_start; oc_b < oc_b_end; - oc_b += load_step) { - load_step = step(nb_oc_blocking, oc_b_end - oc_b, - jcp.nb_load_blocking_max); - const int _ic_b = g * nb_ic + ic_b; - const int oc_off_idx = is_ddst_layout_nxc - ? g * jcp.oc + oc_b * jcp.oc_block - : g * nb_oc + oc_b; - - data_t *store_to; - - const size_t off - = wht_blk_off(diff_weights_d, g, oc_b, ic_b); - store_to = diff_wei + off; - - const int ic_off_idx - = (is_src_layout_nxc ? 
jcp.ic_block : 1) - * _ic_b; - const data_t *diff_src - = &src[src_d.blk_off(img, ic_off_idx)]; - - int sp_b_end = sp_b + sp_b_step; - const data_t *pdiff_dst = &diff_dst[diff_dst_d.blk_off( - img, oc_off_idx)]; - const data_t *local_src = diff_src; - - auto p = jit_1x1_conv_call_s(); - auto rp = rtus_driver_t::call_params_t(); - p.output_stride = utils::rnd_up(jcp.ic, jcp.ic_block) - * jcp.oc_block * jcp.typesize_out; - - p.load_dim = this_block_size(oc_b * jcp.oc_block, - jcp.oc, load_step * jcp.oc_block); - - p.bcast_dim = this_block_size(ic_b * jcp.ic_block, - jcp.ic, bcast_step * jcp.ic_block); - rp.icb = p.bcast_dim; - p.output_data = store_to; - - p.reduce_dim = sp_b_step * jcp.reduce_block; - rp.os = p.reduce_dim; - p.first_last_flag = 0 - | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST - : 0) - | (sp_b_end == sp_nb ? FLAG_SP_LAST : 0); - - int sp = sp_b * jcp.reduce_block; - int oc_mult - = is_ddst_layout_nxc ? jcp.oc : jcp.oc_block; - p.load_data = pdiff_dst + sp * oc_mult; - - if (pd()->rtus_.reduce_src_) { - const int oh = sp / jcp.ow; - const int ow = sp % jcp.ow; - - const int ih = oh * stride_h; - const int iw = ow * stride_w; - rp.iw_start = iw; - - rp.ws = rtus_space - + ithr * pd()->rtus_.space_per_thread_ - + sp * jcp.ic_block; - - if (ndims == 3) - rp.src = local_src - + iw * src_d.blocking_desc().strides[2]; - else - rp.src = local_src - + ih * src_d.blocking_desc().strides[2] - + iw * src_d.blocking_desc().strides[3]; - (*rtus_driver_)(&rp); - - p.bcast_data = rp.ws; - } else { - int ic_mult - = is_src_layout_nxc ? jcp.ic : jcp.ic_block; - p.bcast_data = local_src + sp * ic_mult; - } - - (*kernel_)(&p); - } - } - } - } - - if (ithr_mb == 0 && ic_b_end >= jcp.nb_bcast) { - maybe_zero_icpad(g_start, g_end, oc_b_start, oc_b_end); - } - - /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */ - if (dnnl_thr_syncable() && jcp.nthr_mb > 1) { - simple_barrier::barrier(&reduction_barrier, jcp.nthr); - const int work = g_work * oc_b_work * ic_b_work; - int start {0}, end {0}; - balance211(work, jcp.nthr_mb, ithr_mb, start, end); - if (start == end) return; - - for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { - int w = start; - int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_start {0}; - nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start, - oc_b_work, sub_ic_b_start, ic_b_work); - while (w < end) { - const int g = g_start + sub_g_start; - const int oc_b = oc_b_start + sub_oc_b_start; - const int ic_b = ic_b_start + sub_ic_b_start; - const int ic_to_accumulate - = nstl::min(end - w, ic_b_work - sub_ic_b_start) - * jcp.ic_block; - const int acc_size - = this_block_size(ic_b * jcp.ic_block, - jcp.ic_without_padding, ic_to_accumulate) - * jcp.oc_block; - - const size_t off - = wht_blk_off(diff_weights_d, g, oc_b, ic_b); - data_t *d = diff_weights + off; - data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off; - - acc_ker_->accumulate(d, s, acc_size); - - nd_iterator_jump(w, end, sub_g_start, g_work, - sub_oc_b_start, oc_b_work, sub_ic_b_start, - ic_b_work); - } - } - } - }; - - auto ker_bias = [&](int ithr, int nthr) { - assert(nthr == rb->balancer().nthr_); - - const int b_job_start = rb->balancer().ithr_job_off(ithr); - const int b_njobs = rb->balancer().ithr_njobs(ithr); - - if (b_njobs == 0) return; - - /* reduction dimension */ - int img_start {0}, img_end {0}; - - balance211(jcp.mb, rb->balancer().nthr_per_group_, - rb->balancer().id_in_group(ithr), img_start, img_end); - - /* jobs */ - int g_start {0}, ocb_start {0}; - nd_iterator_init( - b_job_start, 
g_start, jcp.ngroups, ocb_start, jcp.nb_load); - - for (int img = img_start; img < img_end; ++img) { - int g = g_start, ocb = ocb_start; - for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { - const int oc_off_idx = is_ddst_layout_nxc - ? g * jcp.oc + ocb * jcp.oc_block - : g * jcp.nb_load + ocb; - const data_t *d_dst - = &diff_dst[diff_dst_d.blk_off(img, oc_off_idx)]; - - data_t *d_bias = rb->get_local_ptr(ithr, diff_bias, - reducer_bia_scratchpad) - + b_job_loc * rb->balancer().job_size_; - const int sp_shift = is_ddst_layout_nxc ? jcp.ngroups * jcp.oc - : jcp.oc_block; - const auto max_oc = this_block_size( - ocb * jcp.oc_block, jcp.oc, jcp.oc_block); - if (img == img_start) - for (int o = 0; o < 16; ++o) - d_bias[o] = 0.; - - for (int os = 0; os < jcp.os; ++os) { - PRAGMA_OMP_SIMD() - for (int o = 0; o < max_oc; ++o) - d_bias[o] += d_dst[o]; - d_dst += sp_shift; - } - - nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load); - } - } - - if (dnnl_thr_syncable()) - rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); - }; - - if (dnnl_thr_syncable()) { - parallel(jcp.nthr, [&](const int ithr, const int nthr) { - ker(ithr, jcp.nthr); - if (pd()->with_bias()) ker_bias(ithr, jcp.nthr); - }); - } else { - parallel(jcp.nthr, [&](int ithr, int nthr) { ker(ithr, nthr); }); - if (jcp.nthr_mb > 1) - parallel(jcp.nthr, [&](int ithr, int nthr) { - assert(nthr == jcp.nthr); - - const int ithr_ic_b = ithr % jcp.nthr_ic_b; - const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; - const int ithr_g - = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; - const int ithr_mb - = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g; - - /* independent dimensions */ - int g_start {0}, oc_b_start {0}, ic_b_start {0}; - int g_end {0}, oc_b_end {0}, ic_b_end {0}; - - balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); - balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, - oc_b_end); - balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, - ic_b_end); - - const int g_work = g_end - g_start; - const int oc_b_work = oc_b_end - oc_b_start; - const int ic_b_work = ic_b_end - ic_b_start; - - const int work = g_work * oc_b_work * ic_b_work; - int start {0}, end {0}; - balance211(work, jcp.nthr_mb, ithr_mb, start, end); - if (start == end) return; - - for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { - int w = start; - int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_start {0}; - nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start, - oc_b_work, sub_ic_b_start, ic_b_work); - while (w < end) { - const int g = g_start + sub_g_start; - const int oc_b = oc_b_start + sub_oc_b_start; - const int ic_b = ic_b_start + sub_ic_b_start; - const int ic_to_accumulate - = nstl::min(end - w, ic_b_work - sub_ic_b_start) - * jcp.ic_block; - const int acc_size - = this_block_size(ic_b * jcp.ic_block, - jcp.ic_without_padding, - ic_to_accumulate) - * jcp.oc_block; - - const size_t off - = wht_blk_off(diff_weights_d, g, oc_b, ic_b); - data_t *d = diff_weights + off; - data_t *s - = wei_reduction + (thr_mb - 1) * wei_size + off; - - acc_ker_->accumulate(d, s, acc_size); - - nd_iterator_jump(w, end, sub_g_start, g_work, - sub_oc_b_start, oc_b_work, sub_ic_b_start, - ic_b_work); - } - } - }); - if (pd()->with_bias()) { - parallel(jcp.nthr, - [&](int ithr, int nthr) { ker_bias(ithr, nthr); }); - parallel(jcp.nthr, [&](int ithr, int nthr) { - assert(nthr == rb->balancer().nthr_); - MAYBE_UNUSED(nthr); - if (rb->balancer().ithr_njobs(ithr) == 0) return; - rb->reduce_nolock(ithr, diff_bias, 
reducer_bia_scratchpad); - }); - } - } - - /* TODO: put this in ker_bias */ - if (is_bias_padded) { - assert(IMPLICATION(!is_ddst_layout_nxc, jcp.ngroups == 1)); - const int padded_stride = rnd_up(jcp.oc, jcp.oc_block); - const int stride = jcp.oc_without_padding; - for (int g = 0; g < jcp.ngroups; ++g) { - utils::array_copy(diff_bias_in + g * stride, - diff_bias + g * padded_stride, stride); - } - } -} - -} // namespace aarch64 -} // namespace cpu -} // namespace impl -} // namespace dnnl diff --git a/src/cpu/aarch64/jit_sve_512_1x1_convolution.hpp b/src/cpu/aarch64/jit_sve_512_1x1_convolution.hpp deleted file mode 100644 index ac7b1f2b1d1..00000000000 --- a/src/cpu/aarch64/jit_sve_512_1x1_convolution.hpp +++ /dev/null @@ -1,534 +0,0 @@ -/******************************************************************************* -* Copyright 2021-2023 Intel Corporation -* Copyright 2021-2023 FUJITSU LIMITED -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_AARCH64_JIT_SVE_1X1_CONVOLUTION_HPP -#define CPU_AARCH64_JIT_SVE_1X1_CONVOLUTION_HPP - -#include "common/c_types_map.hpp" -#include "common/dnnl_thread.hpp" -#include "common/memory_tracking.hpp" -#include "common/primitive.hpp" -#include "common/primitive_hashing.hpp" -#include "common/utils.hpp" - -#include "cpu/cpu_convolution_pd.hpp" -#include "cpu/dw_convolution_utils.hpp" -#include "cpu/platform.hpp" - -#include "cpu/aarch64/cpu_reducer.hpp" -#include "cpu/aarch64/jit_sve_512_1x1_conv_kernel.hpp" -#include "cpu/aarch64/jit_uni_1x1_conv_utils.hpp" -#include "cpu/aarch64/jit_uni_dw_convolution.hpp" - -namespace dnnl { -namespace impl { -namespace cpu { -namespace aarch64 { - -template -struct jit_sve_512_1x1_convolution_fwd_t : public primitive_t { - struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) - , jcp_() - , rtus_() {} - pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) { - if (copy(other) != status::success) is_initialized_ = false; - } - - DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", sve_512, ""), - jit_sve_512_1x1_convolution_fwd_t); - - status_t init(engine_t *engine) { - using namespace utils; - - bool ok = true && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(src_type, wei_type, dst_type, dst_type, - data_type::undef) - && attr()->has_default_values( - primitive_attr_t::skip_mask_t::post_ops, dst_type) - && !has_zero_dim_memory() && set_default_formats() - && attr_.set_default_formats(dst_md(0)) == status::success; - if (!ok) return status::unimplemented; - - const convolution_desc_t *conv_d = desc(); - const memory_desc_t *src_d = src_md(); - rtus_prepare(this, conv_d, src_d, dst_md()); - - CHECK(jit_sve_512_1x1_conv_kernel::init_conf(jcp_, *conv_d, *src_d, - *weights_md(), *dst_md(), *attr(), dnnl_get_max_threads(), 
- rtus_.reduce_src_)); - if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine)); - - auto scratchpad = scratchpad_registry().registrar(); - jit_sve_512_1x1_conv_kernel::init_scratchpad(scratchpad, jcp_); - - rtus_prepare_space_info(this, scratchpad, jcp_.nthr); - - return status::success; - } - - const memory_desc_t *dst_md( - int index = 0, bool user_input = false) const override { - return jcp_.with_dw_conv - ? dw_conv_pd_->dst_md(index, user_input) - : cpu_convolution_fwd_pd_t::dst_md(index, user_input); - } - - const memory_desc_t *arg_md( - int arg, bool user_input = false) const override { - if (jcp_.with_dw_conv) { - switch (arg) { - case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_SRC: - return cpu_convolution_fwd_pd_t::dst_md(0, user_input); - case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS: - return dw_conv_pd_->weights_md(0); - case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS: - return dw_conv_pd_->weights_md(1); - default: break; - } - } - return convolution_fwd_pd_t::arg_md(arg, user_input); - } - - arg_usage_t arg_usage(int arg) const override { - if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)) - return arg_usage_t::input; - - if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS) - && attr_post_op_dw_inputs() > 1) - return arg_usage_t::input; - - return convolution_fwd_pd_t::arg_usage(arg); - } - - jit_1x1_conv_conf_t jcp_; - reduce_to_unit_stride_t rtus_; - using dw_pd_t = jit_sve_512_dw_convolution_fwd_t::pd_t; - std::unique_ptr dw_conv_pd_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - const memory_desc_wrapper src_d(&src_md_); - const memory_desc_wrapper dst_d(&dst_md_); - - const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc); - const auto dat_tag_nCx16c - = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); - const auto curr_src_tag - = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); - const auto curr_dst_tag - = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); - const auto is_data_layout_nxc - = IMPLICATION(curr_src_tag != dat_tag_nxc, - src_d.format_kind() == format_kind::any) - && IMPLICATION(curr_dst_tag != dat_tag_nxc, - dst_d.format_kind() == format_kind::any) - && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag); - auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; - auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), - OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, OIdhw16i16o, - gOIdhw16i16o); - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - status_t copy(const pd_t &other) { - jcp_ = other.jcp_; - rtus_ = other.rtus_; - if (other.dw_conv_pd_) { - dw_conv_pd_.reset(other.dw_conv_pd_->clone()); - if (!dw_conv_pd_) return status::out_of_memory; - } - return status::success; - } - - status_t depthwise_po_init(engine_t *engine) { - - using namespace memory_tracking; - auto &jcp_1x1 = jcp_; - primitive_attr_t attr_1x1(*attr()); - if (!attr_1x1.is_initialized()) return status::out_of_memory; - const auto &src_md = dst_md_; - const memory_desc_wrapper src_d(src_md); - const auto nthr = dnnl_get_max_threads(); - auto l2_cache = platform::get_per_core_cache_size(2) * nthr; - - // Note: A robust fusion implementation would be to check if both - // 1x1 conv and dw conv that are considered here for fusion are - // optimal independently. This would require creating a new - // primitive_desc through primitive_iterator & check if they match. 
-            // Due to concern that these creations and/or checks could be heavy,
-            // for 1x1: Check that no better ISA is available.
-            // for dw: Always fuse with same ISA.
-            // Caveat: Maybe a better dw conv exists.
-
-            // TODO: Add a check if a better ISA exists following above note.
-            bool ok = true
-                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
-                    // TODO: Below may be further tuned.
-                    && (l2_cache * 2 < src_d.size())
-                    // load_grp_count check can be redundant due to l2 check
-                    // above. Adding it explicitly as the current driver doesn't
-                    // work if this condition fails.
-                    && (jcp_1x1.load_grp_count < 2);
-            if (!ok) return status::unimplemented;
-
-            int dw_po_index
-                    = attr_1x1.post_ops_.find(primitive_kind::convolution);
-            convolution_desc_t cd_dw;
-            primitive_attr_t attr_dw;
-            CHECK(get_depthwise_conv_desc(
-                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));
-
-            CHECK(safe_ptr_assign(
-                    dw_conv_pd_, new dw_pd_t(&cd_dw, &attr_dw, nullptr)));
-            CHECK(dw_conv_pd_->init(engine));
-            auto &jcp_dw = dw_conv_pd_->jcp_;
-
-            ok = true
-                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
-                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
-                    && IMPLICATION(
-                            jcp_dw.ow_block, jcp_dw.ow_block == jcp_dw.ow);
-            if (!ok) return status::unimplemented;
-
-            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
-            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
-            assert(IMPLICATION(
-                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
-                    dw_conv_pd_->weights_md(1)->format_kind
-                            != format_kind::any));
-
-            jcp_dw.is_fused_conv = true;
-            // TODO: Support/experiment arbitrary oc_work in dw conv.
-            // Until then we keep oc_work perfectly divisible.
-            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
-                --jcp_1x1.nb_load_blocking;
-            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;
-
-            while (jcp_1x1.nb_load_blocking % jcp_dw.nb_ch_blocking != 0)
-                --jcp_dw.nb_ch_blocking;
-
-            jcp_dw.dw_conv_buffer_oc
-                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;
-
-            const auto dat_tag_nxc = utils::pick(ndims() - 3, format_tag::nwc,
-                    format_tag::nhwc, format_tag::ndhwc);
-            const bool is_data_nxc = utils::everyone_is(
-                    dat_tag_nxc, jcp_1x1.src_tag, jcp_1x1.dst_tag);
-            if (!is_data_nxc)
-                jcp_1x1.bcast_loop_output_step = jcp_1x1.ur * jcp_1x1.load_block
-                        * jcp_1x1.typesize_out;
-
-            registrar_t scratchpad(scratchpad_registry_);
-            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);
-
-            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw.kh * jcp_dw.iw
-                    * jcp_dw.dw_conv_buffer_oc;
-            assert(dw_conv_buffer_size_);
-            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
-                    dw_conv_buffer_size_,
-                    types::data_type_size(dw_conv_pd_->src_md()->data_type));
-
-            jit_uni_dw_conv_fwd_kernel::init_scratchpad(dw_scratchpad, jcp_dw);
-
-            return status::success;
-        }
-    };
-
-    template <cpu_isa_t isa, typename conv_t>
-    friend status_t init_rtus_driver(conv_t *self);
-
-    jit_sve_512_1x1_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}
-
-    typedef typename prec_traits<src_type>::type src_data_t;
-    typedef typename prec_traits<wei_type>::type wei_data_t;
-    typedef typename prec_traits<dst_type>::type dst_data_t;
-
-    status_t init(engine_t *engine) override {
-        CHECK(safe_ptr_assign(kernel_,
-                new jit_sve_512_1x1_conv_kernel(
-                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
-        CHECK(kernel_->create_kernel());
-
-        if (pd()->jcp_.with_dw_conv) {
-            CHECK(safe_ptr_assign(
-                    kernel_dw_, new dw_conv_kernel_t(pd()->dw_conv_pd_->jcp_)));
-            CHECK(kernel_dw_->create_kernel());
-        }
-
-        CHECK(init_rtus_driver<sve_512>(this));
-
return status::success; - } - - status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - void execute_forward_thr(const int ithr, const int nthr, - const src_data_t *src, const wei_data_t *weights, - const dst_data_t *bias, const wei_data_t *weights_dw, - const dst_data_t *bias_dw, dst_data_t *dst, - const memory_tracking::grantor_t &scratchpad, - const void *post_ops_binary_rhs_arg_vec, - const void *post_ops_binary_rhs_arg_vec_dw) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - - std::unique_ptr kernel_; - std::unique_ptr> rtus_driver_; - using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel_f32; - std::unique_ptr kernel_dw_; -}; - -using jit_sve_512_1x1_convolution_fwd_f32_t - = jit_sve_512_1x1_convolution_fwd_t; - -template -struct jit_sve_512_1x1_convolution_bwd_data_t : public primitive_t { - struct pd_t : public cpu_convolution_bwd_data_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) - , jcp_() - , rtus_() {} - DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", sve_512, ""), - jit_sve_512_1x1_convolution_bwd_data_t); - - status_t init(engine_t *engine) { - bool ok = true && desc()->prop_kind == prop_kind::backward_data - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(diff_src_type, wei_type, - data_type::undef, diff_dst_type, data_type::undef) - && attr()->has_default_values() && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - const convolution_desc_t *conv_d = desc(); - const memory_desc_t *diff_src_d = diff_src_md(); - rtus_prepare(this, conv_d, diff_src_d, diff_dst_md()); - - status_t status = jit_sve_512_1x1_conv_kernel::init_conf(jcp_, - *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), - *attr(), dnnl_get_max_threads(), rtus_.reduce_src_); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_sve_512_1x1_conv_kernel::init_scratchpad(scratchpad, jcp_); - - rtus_prepare_space_info(this, scratchpad, jcp_.nthr); - - return status::success; - } - - // TODO (Roma): structs conf header cleanup - jit_1x1_conv_conf_t jcp_; - reduce_to_unit_stride_t rtus_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - const memory_desc_wrapper diff_src_d(&diff_src_md_); - const memory_desc_wrapper diff_dst_d(&diff_dst_md_); - - const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc); - const auto dat_tag_nCx16c - = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); - const auto curr_src_tag = diff_src_d.matches_one_of_tag( - dat_tag_nxc, dat_tag_nCx16c); - const auto curr_dst_tag = diff_dst_d.matches_one_of_tag( - dat_tag_nxc, dat_tag_nCx16c); - const auto is_data_layout_nxc - = IMPLICATION(curr_src_tag != dat_tag_nxc, - diff_src_d.format_kind() == format_kind::any) - && IMPLICATION(curr_dst_tag != dat_tag_nxc, - diff_dst_d.format_kind() == format_kind::any) - && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag); - auto dat_tag = is_data_layout_nxc ? 
dat_tag_nxc : dat_tag_nCx16c; - auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), - IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i, IOdhw16o16i, - gIOdhw16o16i); - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - template - friend status_t init_rtus_driver(conv_t *self); - - jit_sve_512_1x1_convolution_bwd_data_t(const pd_t *apd) - : primitive_t(apd) {} - - typedef typename prec_traits::type diff_dst_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type diff_src_data_t; - - status_t init(engine_t *engine) override { - CHECK(safe_ptr_assign(kernel_, - new jit_sve_512_1x1_conv_kernel( - pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); - CHECK(kernel_->create_kernel()); - CHECK(init_rtus_driver(this)); - return status::success; - } - - status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_data(ctx); - return status::success; - } - -private: - void execute_backward_data(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - std::unique_ptr kernel_; - std::unique_ptr> rtus_driver_; -}; - -using jit_sve_512_1x1_convolution_bwd_data_f32_t - = jit_sve_512_1x1_convolution_bwd_data_t; - -/* Backward weight */ -struct jit_sve_512_1x1_convolution_bwd_weights_t : public primitive_t { - struct pd_t : public cpu_convolution_bwd_weights_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) - , jcp_() - , rtus_() {} - - DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", sve_512, ""), - jit_sve_512_1x1_convolution_bwd_weights_t); - - status_t init(engine_t *engine) { - bool ok = true && desc()->prop_kind == prop_kind::backward_weights - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && attr()->has_default_values() && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - const convolution_desc_t *conv_d = desc(); - const memory_desc_t *src_d = src_md(); - rtus_prepare(this, conv_d, src_d, diff_dst_md()); - - status_t status = jit_sve_512_1x1_conv_kernel::init_conf(jcp_, - *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), - *attr(), dnnl_get_max_threads(), rtus_.reduce_src_); - if (status != status::success) return status; - - init_balancers(); - - auto scratchpad = scratchpad_registry().registrar(); - jit_sve_512_1x1_conv_kernel::init_scratchpad(scratchpad, jcp_); - - auto reducer_bia_scratchpad = memory_tracking::registrar_t( - scratchpad, memory_tracking::names::prefix_reducer_bia); - reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); - rtus_prepare_space_info(this, scratchpad, jcp_.nthr); - - return status::success; - } - - // TODO (Roma): structs conf header cleanup - jit_1x1_conv_conf_t jcp_; - cpu_reducer_t::conf_t reducer_bia_conf_; - reduce_to_unit_stride_t rtus_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - const memory_desc_wrapper src_d(&src_md_); - const memory_desc_wrapper diff_dst_d(&diff_dst_md_); - - const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc); - const auto dat_tag_nCx16c - = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); - const auto curr_src_tag - = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); - const auto curr_dst_tag = diff_dst_d.matches_one_of_tag( - dat_tag_nxc, 
dat_tag_nCx16c); - const auto is_data_layout_nxc - = IMPLICATION(curr_src_tag != dat_tag_nxc, - src_d.format_kind() == format_kind::any) - && IMPLICATION(curr_dst_tag != dat_tag_nxc, - diff_dst_d.format_kind() == format_kind::any) - && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag); - - auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; - auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), - OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, OIdhw16i16o, - gOIdhw16i16o); - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - - private: - void init_balancers() { - const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; - if (with_bias()) { - reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, - jcp_.oc_block, jcp_.ngroups * jcp_.nb_load, jcp_.mb, - max_buffer_size, true)); - } - } - }; - - template - friend status_t init_rtus_driver(conv_t *self); - - jit_sve_512_1x1_convolution_bwd_weights_t(const pd_t *apd) - : primitive_t(apd) {} - - typedef typename prec_traits::type data_t; - - status_t init(engine_t *engine) override; - - status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_weights(ctx); - return status::success; - } - -private: - void execute_backward_weights(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - - std::unique_ptr kernel_; - std::unique_ptr> acc_ker_; - std::unique_ptr> reducer_bias_; - // std::unique_ptr trans_kernel_; - std::unique_ptr> rtus_driver_; -}; - -} // namespace aarch64 -} // namespace cpu -} // namespace impl -} // namespace dnnl - -#endif diff --git a/src/cpu/aarch64/jit_sve_512_core_x8s8s32x_deconvolution.cpp b/src/cpu/aarch64/jit_sve_512_core_x8s8s32x_deconvolution.cpp index 593a933f104..0694fc72d6e 100644 --- a/src/cpu/aarch64/jit_sve_512_core_x8s8s32x_deconvolution.cpp +++ b/src/cpu/aarch64/jit_sve_512_core_x8s8s32x_deconvolution.cpp @@ -114,10 +114,11 @@ status_t _jit_sve_512_core_x8s8s32x_deconv_fwd_kernel::init_conf( if (jcp.is_depthwise && (!jcp.signed_input || is_3d)) return status::unimplemented; - if (!zero_points_valid(&attr)) return status::unimplemented; + VDISPATCH_DECONVOLUTION_IC( + zero_points_valid(&attr), VERBOSE_UNSUPPORTED_ZP_CFG); jcp.src_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_SRC); jcp.dst_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_DST); - jcp.zp_src_is_common = attr.zero_points_.common(DNNL_ARG_SRC); + jcp.zp_src_is_common = attr.zero_points_.get_mask(DNNL_ARG_SRC) == 0; format_tag_t dat_tag = utils::pick( ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); @@ -274,12 +275,8 @@ status_t _jit_sve_512_core_x8s8s32x_deconv_fwd_kernel::init_conf( //save post_ops desc for further usage jcp.post_ops = p; - const auto &oscales = attr.output_scales_; - jcp.is_oc_scale = oscales.mask_ == 1 << 1; - - // only common and per-oc-channel scales are supported - const bool oscales_ok = one_of(oscales.mask_, 0, 1 << 1); - if (!oscales_ok) return status::unimplemented; + // TODO: add proper scaling support. + jcp.is_oc_scale = false; jcp.dst_dt = dst_d.data_type(); jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef; @@ -1416,7 +1413,8 @@ status_t jit_sve_512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_1d( const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; const int nb_groups = jcp.nb_ch; - DEFINE_SCALES_BUFFER(oscales); + // TODO: add support for scaling based on latest programming model. 
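The "latest programming model" these TODOs refer to is the v3.x per-argument quantization API: scales and zero points are attached to individual arguments with a mask and supplied at execution time. A hedged user-side sketch of that configuration (assuming the public oneDNN v3.x C++ API; the helper name is illustrative):

#include "oneapi/dnnl/dnnl.hpp"

dnnl::primitive_attr make_int8_attr() {
    dnnl::primitive_attr attr;
    // mask = 0: a single scale for the whole weights tensor;
    // mask = 1 << 1 would request one scale per output channel, the
    // per-oc case the removed oscales code used to special-case.
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, /* mask = */ 0);
    // mask = 0 for zero points is the "common" case that
    // jcp.zp_src_is_common detects above via get_mask() == 0.
    attr.set_zero_points_mask(DNNL_ARG_SRC, /* mask = */ 0);
    return attr;
}

At execution the actual values travel as separate memories under DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS and DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, which is what the DEFINE_ARG_SCALES_BUFFER line below reads back on the library side.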
+ DEFINE_ARG_SCALES_BUFFER(oscales, DNNL_ARG_WEIGHTS); const size_t offset = weights_d.size() - weights_d.additional_buffer_size(); auto w = const_cast(weights); int32_t *compensation = (!jcp.signed_input) @@ -1514,7 +1512,8 @@ status_t jit_sve_512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_2d( size_t dst_h_stride = dst_d.blk_off(0, 0, 1); size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1); - DEFINE_SCALES_BUFFER(oscales); + // TODO: add support for scaling based on latest programming model. + DEFINE_ARG_SCALES_BUFFER(oscales, DNNL_ARG_WEIGHTS); const size_t offset = weights_d.size() - weights_d.additional_buffer_size(); auto w = const_cast(weights); int32_t *compensation = (!jcp.signed_input) @@ -1675,7 +1674,8 @@ status_t jit_sve_512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_3d( size_t wht_kd_stride = wht_blk_off(weights_d, 0, 0, 0, 1); size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1); - DEFINE_SCALES_BUFFER(oscales); + // TODO: add support for scaling based on latest programming model. + DEFINE_ARG_SCALES_BUFFER(oscales, DNNL_ARG_WEIGHTS); size_t offset = weights_d.size() - weights_d.additional_buffer_size(); auto w = const_cast(weights); int32_t *compensation = (!jcp.signed_input) diff --git a/src/cpu/aarch64/jit_sve_512_core_x8s8s32x_deconvolution.hpp b/src/cpu/aarch64/jit_sve_512_core_x8s8s32x_deconvolution.hpp index 2429ac43df5..db1188ea620 100644 --- a/src/cpu/aarch64/jit_sve_512_core_x8s8s32x_deconvolution.hpp +++ b/src/cpu/aarch64/jit_sve_512_core_x8s8s32x_deconvolution.hpp @@ -286,9 +286,8 @@ struct jit_sve_512_core_x8s8s32x_deconvolution_fwd_t : public primitive_t { weights_md(1)->data_type, f32, s32, s8, u8)) && utils::one_of(dst_md(0)->data_type, f32, s32, s8, u8) && desc()->accum_data_type == s32 - && attr()->has_default_values(skip_mask_t::oscale_runtime - | skip_mask_t::post_ops - | skip_mask_t::zero_points_runtime); + && attr()->has_default_values( + skip_mask_t::post_ops | skip_mask_t::zero_points); if (!ok) return status::unimplemented; CHECK(_jit_sve_512_core_x8s8s32x_deconv_fwd_kernel::init_conf(jcp_, @@ -302,7 +301,7 @@ struct jit_sve_512_core_x8s8s32x_deconvolution_fwd_t : public primitive_t { return status::success; } - jit_conv_conf_t jcp_; + jit_conv_conf_t jcp_ = utils::zero(); }; jit_sve_512_core_x8s8s32x_deconvolution_fwd_t(const pd_t *apd) diff --git a/src/cpu/aarch64/jit_sve_512_x8s8s32x_conv_kernel.cpp b/src/cpu/aarch64/jit_sve_512_x8s8s32x_conv_kernel.cpp index c8a2314ae06..bb2442f694e 100644 --- a/src/cpu/aarch64/jit_sve_512_x8s8s32x_conv_kernel.cpp +++ b/src/cpu/aarch64/jit_sve_512_x8s8s32x_conv_kernel.cpp @@ -1426,12 +1426,8 @@ status_t jit_sve_512_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, pick_loop_order(jcp, jcp.nthr); - const auto &oscales = attr.output_scales_; - jcp.is_oc_scale = oscales.mask_ == 1 << 1; - - // only common and per-oc-channel scales are supported - const bool oscales_ok = one_of(oscales.mask_, 0, 1 << 1); - if (!oscales_ok) return status::unimplemented; + // TODO: enable quantization. 
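Elsewhere in these hunks the pd_t structs also start zero-initializing their config members (`jit_conv_conf_t jcp_ = utils::zero();` above) instead of relying on an empty jcp_() member-initializer. A hedged illustration of the idiom, assuming a value-clearing helper along these lines (an illustrative stand-in, not the library's utils::zero):

#include <cstring>

// Illustrative: return a T with every byte cleared, so a plain-old-data
// config such as jit_conv_conf_t starts from a known state even when a
// field is added later and some init_conf path forgets to set it.
template <typename T>
inline T zero() {
    T t;
    std::memset(&t, 0, sizeof(T));
    return t;
}

// Usage mirroring the patch:
//   jit_conv_conf_t jcp_ = zero<jit_conv_conf_t>();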
+ jcp.is_oc_scale = false; jcp.wei_adj_scale = (weights_d.extra().flags & memory_extra_flags::scale_adjust) diff --git a/src/cpu/aarch64/jit_sve_512_x8s8s32x_convolution.cpp b/src/cpu/aarch64/jit_sve_512_x8s8s32x_convolution.cpp index e1c0d5ee562..e7795140983 100644 --- a/src/cpu/aarch64/jit_sve_512_x8s8s32x_convolution.cpp +++ b/src/cpu/aarch64/jit_sve_512_x8s8s32x_convolution.cpp @@ -61,7 +61,8 @@ jit_sve_512_x8s8s32x_convolution_fwd_t::execute_forward_1d( assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); - DEFINE_SCALES_BUFFER(oscales); + // TODO: add support for scaling based on latest programming model. + DEFINE_ARG_SCALES_BUFFER(oscales, DNNL_ARG_WEIGHTS); size_t offset = weights_d.size() - weights_d.additional_buffer_size(); auto w = const_cast(weights); @@ -174,7 +175,8 @@ jit_sve_512_x8s8s32x_convolution_fwd_t::execute_forward_2d( assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); - DEFINE_SCALES_BUFFER(oscales); + // TODO: add support for scaling based on latest programming model. + DEFINE_ARG_SCALES_BUFFER(oscales, DNNL_ARG_WEIGHTS); size_t offset = weights_d.size() - weights_d.additional_buffer_size(); auto w = const_cast(weights); @@ -319,7 +321,8 @@ status_t jit_sve_512_x8s8s32x_convolution_fwd_t(weights); @@ -408,7 +411,8 @@ jit_sve_512_x8s8s32x_convolution_fwd_t::execute_forward_3d( assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); - DEFINE_SCALES_BUFFER(oscales); + // TODO: add support for scaling based on latest programming model. + DEFINE_ARG_SCALES_BUFFER(oscales, DNNL_ARG_WEIGHTS); size_t offset = weights_d.size() - weights_d.additional_buffer_size(); auto w = const_cast(weights); diff --git a/src/cpu/aarch64/jit_sve_512_x8s8s32x_convolution.hpp b/src/cpu/aarch64/jit_sve_512_x8s8s32x_convolution.hpp index e6c73302079..60b66651359 100644 --- a/src/cpu/aarch64/jit_sve_512_x8s8s32x_convolution.hpp +++ b/src/cpu/aarch64/jit_sve_512_x8s8s32x_convolution.hpp @@ -36,9 +36,7 @@ namespace aarch64 { template struct jit_sve_512_x8s8s32x_convolution_fwd_t : public primitive_t { struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_int8:", sve_512, ""), jit_sve_512_x8s8s32x_convolution_fwd_t); @@ -53,9 +51,7 @@ struct jit_sve_512_x8s8s32x_convolution_fwd_t : public primitive_t { utils::one_of(bias_md_.data_type, data_type::f32, data_type::s32, data_type::s8, data_type::u8)) - && attr()->has_default_values( - smask_t::oscale_runtime | smask_t::post_ops, - dst_type) + && attr()->has_default_values(smask_t::post_ops, dst_type) && !has_zero_dim_memory(); if (!ok) return status::unimplemented; @@ -71,15 +67,15 @@ struct jit_sve_512_x8s8s32x_convolution_fwd_t : public primitive_t { return status; } - jit_conv_conf_t jcp_; + jit_conv_conf_t jcp_ = utils::zero(); }; jit_sve_512_x8s8s32x_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits_t::type src_data_t; + typedef typename prec_traits_t::type wei_data_t; + typedef typename prec_traits_t::type dst_data_t; status_t init(engine_t *engine) override { 
CHECK(safe_ptr_assign(kernel_, diff --git a/src/cpu/aarch64/jit_sve_conv_kernel.cpp b/src/cpu/aarch64/jit_sve_conv_kernel.cpp index 3acec02fe29..26cb17300af 100644 --- a/src/cpu/aarch64/jit_sve_conv_kernel.cpp +++ b/src/cpu/aarch64/jit_sve_conv_kernel.cpp @@ -1250,8 +1250,7 @@ void jit_sve_conv_bwd_data_kernel_f32::store_output(int ur_w) { auto out_load = [=](int aux_output_offset, int idx, int prev_ofs) { int ofs = aux_output_offset; - if ((VL_OFS(ofs, isa) < LDRMAX) && (VL_OFS(ofs, isa) >= (-1 * LDRMAX)) - && ((ofs & 0x3f) == 0)) { + if (ldr_imm_check(ofs) && (ofs % 64 == 0)) { add_imm(X_DEFAULT_ADDR, reg_src, ofs, X_TMP_0); ld1w(zreg_tmp(idx).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); } else { @@ -1273,8 +1272,7 @@ void jit_sve_conv_bwd_data_kernel_f32::store_output(int ur_w) { auto out_str = [=](int j, int k, int aux_output_offset, int prev_ofs) { int ofs = aux_output_offset; - if ((VL_OFS(ofs, isa) < LDRMAX) && (VL_OFS(ofs, isa) >= (-1 * LDRMAX)) - && ((ofs & 0x3f) == 0)) { + if (ldr_imm_check(ofs) && (ofs % 64 == 0)) { add_imm(X_DEFAULT_ADDR, reg_src, ofs, X_TMP_0); st1w(zreg_out(j, k).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); @@ -1415,8 +1413,7 @@ void jit_sve_conv_bwd_data_kernel_f32::compute_loop_fma( auto ker_load = [=](int i, int aux_kernel_offset) { int ofs = aux_kernel_offset; - if ((VL_OFS(ofs, isa) < LDRMAX) && (VL_OFS(ofs, isa) >= (-1 * LDRMAX)) - && ((ofs & 0x3f) == 0)) { + if (ldr_imm_check(ofs) && (ofs % 64 == 0)) { add_imm(X_DEFAULT_ADDR, aux_reg_ker, ofs, X_TMP_0); ld1w(zreg_ker(i).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); @@ -4467,4 +4464,4 @@ template struct jit_sve_conv_bwd_weights_kernel_f32; } // namespace aarch64 } // namespace cpu } // namespace impl -} // namespace dnnl \ No newline at end of file +} // namespace dnnl diff --git a/src/cpu/aarch64/jit_sve_convolution.hpp b/src/cpu/aarch64/jit_sve_convolution.hpp index 16b397406d7..628e94c87ed 100644 --- a/src/cpu/aarch64/jit_sve_convolution.hpp +++ b/src/cpu/aarch64/jit_sve_convolution.hpp @@ -39,9 +39,7 @@ template struct jit_sve_convolution_fwd_t : public primitive_t { struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", isa, ""), jit_sve_convolution_fwd_t); @@ -67,14 +65,14 @@ struct jit_sve_convolution_fwd_t : public primitive_t { return status; } - jit_conv_conf_t jcp_; + jit_conv_conf_t jcp_ = utils::zero(); }; jit_sve_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits_t::type src_data_t; + typedef typename prec_traits_t::type wei_data_t; + typedef typename prec_traits_t::type dst_data_t; status_t init(engine_t *engine) override { CHECK(safe_ptr_assign(kernel_, @@ -114,9 +112,7 @@ template struct jit_sve_convolution_bwd_data_t : public primitive_t { struct pd_t : public cpu_convolution_bwd_data_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} + using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t; DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", isa, ""), jit_sve_convolution_bwd_data_t); @@ 
-141,14 +137,14 @@ struct jit_sve_convolution_bwd_data_t : public primitive_t { return status::success; } - jit_conv_conf_t jcp_; + jit_conv_conf_t jcp_ = utils::zero(); }; jit_sve_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type diff_dst_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type diff_src_data_t; + typedef typename prec_traits_t::type diff_dst_data_t; + typedef typename prec_traits_t::type wei_data_t; + typedef typename prec_traits_t::type diff_src_data_t; status_t init(engine_t *engine) override { CHECK(safe_ptr_assign(kernel_, @@ -183,10 +179,8 @@ template struct jit_sve_convolution_bwd_weights_t : public primitive_t { struct pd_t : public cpu_convolution_bwd_weights_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) - , jcp_() {} + using cpu_convolution_bwd_weights_pd_t:: + cpu_convolution_bwd_weights_pd_t; DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", isa, ""), jit_sve_convolution_bwd_weights_t); @@ -218,7 +212,7 @@ struct jit_sve_convolution_bwd_weights_t : public primitive_t { return status; } - jit_conv_conf_t jcp_; + jit_conv_conf_t jcp_ = utils::zero(); typename cpu_reducer_t::conf_t reducer_bia_conf_; @@ -235,9 +229,9 @@ struct jit_sve_convolution_bwd_weights_t : public primitive_t { jit_sve_convolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type diff_dst_data_t; - typedef typename prec_traits::type diff_weights_data_t; + typedef typename prec_traits_t::type src_data_t; + typedef typename prec_traits_t::type diff_dst_data_t; + typedef typename prec_traits_t::type diff_weights_data_t; status_t init(engine_t *engine) override; @@ -274,4 +268,4 @@ struct jit_sve_convolution_bwd_weights_t : public primitive_t { #endif -// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s \ No newline at end of file +// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s diff --git a/src/cpu/aarch64/jit_uni_1x1_conv_utils.hpp b/src/cpu/aarch64/jit_uni_1x1_conv_utils.hpp index 521a8c54dc7..8b52047dc14 100644 --- a/src/cpu/aarch64/jit_uni_1x1_conv_utils.hpp +++ b/src/cpu/aarch64/jit_uni_1x1_conv_utils.hpp @@ -1,6 +1,6 @@ /******************************************************************************* * Copyright 2021-2023 Intel Corporation -* Copyright 2021-2023 FUJITSU LIMITED +* Copyright 2021-2024 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
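The mechanical prec_traits to prec_traits_t renames running through the hunks above follow the library's trailing-_t convention for type names. The trait itself maps a compile-time data_type tag to its C++ storage type; a minimal sketch of the idea, with a made-up enum rather than the real dnnl::impl::data_type namespace:

#include <cstdint>

enum class data_type { f32, s32, s8, u8 };

template <data_type dt> struct prec_traits_t;
template <> struct prec_traits_t<data_type::f32> { using type = float; };
template <> struct prec_traits_t<data_type::s32> { using type = int32_t; };
template <> struct prec_traits_t<data_type::s8> { using type = int8_t; };
template <> struct prec_traits_t<data_type::u8> { using type = uint8_t; };

// e.g. a kernel templated on a data_type picks its element type as
//   typedef typename prec_traits_t<src_type>::type src_data_t;
// which is the pattern the typedefs in these hunks express.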
@@ -42,7 +42,7 @@ struct reduce_to_unit_stride_t { /* 1x1-kernel does not support non-unit strides so far, so the idea is: * - for fwd or bwd_weights: to copy src to a scratch memory (with strides - * equal to 1) and then call the kernel + * equal to 1) and then call the kernel * - for bwd_data: reduce the problem to the one with unit stride by * performing computations in a scratch memory (with strides equal to 1) * and then copy the result to diff_src */ @@ -50,7 +50,6 @@ template inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d, const memory_desc_t *&src_d, const memory_desc_t *dst_d) { const int ndims = src_d->ndims; - bool rtus_applicable = utils::one_of(ndims, 3, 4); if (ndims == 3) rtus_applicable = rtus_applicable && conv_d->strides[0] != 1 @@ -182,13 +181,15 @@ struct rtus_driver_t : public jit_generator { ZReg res = ZReg(idx); if (is_nspc_) { switch (isa) { - case sve_512: res = ZReg(idx); break; + case sve_512: + case sve_256: res = ZReg(idx); break; default: assert(!"Not supported isa"); res = ZReg(idx); } return res; } switch (isa) { case sve_512: + case sve_256: switch (typesize) { case 4: res = ZReg(idx); break; default: @@ -202,7 +203,7 @@ struct rtus_driver_t : public jit_generator { reg_zero = Vmm(0, typesize); reg_v = Vmm(1, typesize); - vlen_ = reg_v.getBit() / 8; + vlen_ = cpu_isa_traits::vlen; vlen_shift_ = 0; int tvlen = is_nspc_ ? typesize_ : vlen_; @@ -217,7 +218,6 @@ struct rtus_driver_t : public jit_generator { void loop_is() { using namespace Xbyak_aarch64; - mov(reg_cur_src, reg_src); mov(reg_cur_iw, reg_iw_start); mov(reg_cur_os, reg_os); @@ -285,7 +285,7 @@ struct rtus_driver_t : public jit_generator { mov(reg_cur_src, reg_src); mov(reg_cur_iw, reg_iw_start); - if (isa == sve_512) { + if (isa == sve_256 || isa == sve_512) { and_(reg_icb_remainder, reg_icb, (vlen_ / typesize_) - 1); mov_imm(X_TMP_0, 0); whilelt(tail_mask.s, X_TMP_0, reg_icb_remainder); @@ -356,8 +356,9 @@ struct rtus_driver_t : public jit_generator { const size_t w_step_factor = ic_ * typesize_; const size_t max_load_store_bytes = typesize_ == 4 ? 32 : 16; - const size_t load_store_size - = isa == sve_512 ? vlen_ : max_load_store_bytes; + const size_t load_store_size = (isa == sve_256 || isa == sve_512) + ? 
vlen_ + : max_load_store_bytes; Label is_loop, ic_loop, ic_loop_tail, ic_loop_finish; L(is_loop); @@ -467,7 +468,7 @@ struct rtus_driver_t : public jit_generator { void generate() override { using namespace Xbyak_aarch64; - assert(isa == sve_512); + assert(isa == sve_256 || isa == sve_512); preamble(); #define READ_PARAM(what) \ diff --git a/src/cpu/aarch64/jit_uni_batch_normalization.hpp b/src/cpu/aarch64/jit_uni_batch_normalization.hpp index 3311f5b665a..7197ce2d815 100644 --- a/src/cpu/aarch64/jit_uni_batch_normalization.hpp +++ b/src/cpu/aarch64/jit_uni_batch_normalization.hpp @@ -42,10 +42,8 @@ struct driver_t; template struct jit_uni_batch_normalization_fwd_t : public primitive_t { struct pd_t : public cpu_batch_normalization_fwd_pd_t { - pd_t(const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) {} + using cpu_batch_normalization_fwd_pd_t:: + cpu_batch_normalization_fwd_pd_t; DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("bnorm_jit:", isa, ""), jit_uni_batch_normalization_fwd_t); @@ -70,10 +68,8 @@ struct jit_uni_batch_normalization_fwd_t : public primitive_t { template struct jit_uni_batch_normalization_bwd_t : public primitive_t { struct pd_t : public cpu_batch_normalization_bwd_pd_t { - pd_t(const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_bwd_pd_t(adesc, attr, hint_fwd_pd) {} + using cpu_batch_normalization_bwd_pd_t:: + cpu_batch_normalization_bwd_pd_t; DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("bnorm_jit:", isa, ""), jit_uni_batch_normalization_bwd_t); diff --git a/src/cpu/aarch64/jit_uni_batch_normalization_s8.hpp b/src/cpu/aarch64/jit_uni_batch_normalization_s8.hpp index 3b96d7ed5ed..3950ffcccfd 100644 --- a/src/cpu/aarch64/jit_uni_batch_normalization_s8.hpp +++ b/src/cpu/aarch64/jit_uni_batch_normalization_s8.hpp @@ -41,10 +41,8 @@ struct driver_t; template struct jit_uni_batch_normalization_s8_fwd_t : public primitive_t { struct pd_t : public cpu_batch_normalization_fwd_pd_t { - pd_t(const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) {} + using cpu_batch_normalization_fwd_pd_t:: + cpu_batch_normalization_fwd_pd_t; DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("bnorm_s8_jit:", isa, ""), jit_uni_batch_normalization_s8_fwd_t); diff --git a/src/cpu/aarch64/jit_uni_binary.cpp b/src/cpu/aarch64/jit_uni_binary.cpp index 2dd83c3b1a6..bfaf2b64b45 100644 --- a/src/cpu/aarch64/jit_uni_binary.cpp +++ b/src/cpu/aarch64/jit_uni_binary.cpp @@ -118,7 +118,7 @@ status_t jit_uni_binary_t::pd_t::init(engine_t *engine) { && data_format_supported(src0_md_, conf_.isa) && set_default_params() == status::success && !has_zero_dim_memory() && IMPLICATION(!conf_.is_i8, src0_md_ == dst_md_) && is_applicable() - && attr()->has_default_values(sm::post_ops | sm::scales_runtime) + && attr()->has_default_values(sm::post_ops | sm::scales) && attr_.set_default_formats(dst_md(0)) == status::success; if (!ok) return status::unimplemented; @@ -140,10 +140,8 @@ status_t jit_uni_binary_t::pd_t::init(engine_t *engine) { po, src0_md_, get_supported_postops_bcast_strategies()); conf_.op_type = get_op_type(src0_md_); assert(conf_.op_type != op_t::none); - conf_.do_scale_src0 = !attr()->scales_.get(DNNL_ARG_SRC_0).defined() - || 
!attr()->scales_.get(DNNL_ARG_SRC_0).has_default_values(); - conf_.do_scale_src1 = !attr()->scales_.get(DNNL_ARG_SRC_1).defined() - || !attr()->scales_.get(DNNL_ARG_SRC_1).has_default_values(); + conf_.do_scale_src0 = !attr()->scales_.has_default_values(DNNL_ARG_SRC_0); + conf_.do_scale_src1 = !attr()->scales_.has_default_values(DNNL_ARG_SRC_1); const auto sum_idx = po.find(primitive_kind::sum); conf_.do_sum = sum_idx != -1 && po.entry_[sum_idx].sum.scale != 0.f; conf_.with_eltwise = po.find(primitive_kind::eltwise) != -1; diff --git a/src/cpu/aarch64/jit_uni_dw_conv_kernel_f32.cpp b/src/cpu/aarch64/jit_uni_dw_conv_kernel_f32.cpp index 557d19166d4..f146c82c728 100644 --- a/src/cpu/aarch64/jit_uni_dw_conv_kernel_f32.cpp +++ b/src/cpu/aarch64/jit_uni_dw_conv_kernel_f32.cpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2021-2022 Intel Corporation * Copyright 2021-2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1099,9 +1100,14 @@ void jit_uni_dw_conv_bwd_weights_kernel_f32::generate() { const int simd_w_ = cpu_isa_traits::vlen / sizeof(float); preamble(); //TO DO : renaming predicate register (P_ALL_ONE) - if (simd_w_ != cpu_sveLen / sizeof(float)) + if (simd_w_ != cpu_sveLen / sizeof(float)) { set_preg(P_ALL_ONE.s, simd_w_, X_TMP_0, X_TMP_1); - if (simd_w_ != 16 || simd_w_ != 8) assert(!"Unsupport: simd_w != 16, 8"); + } + + if (simd_w_ != 16 && simd_w_ != 8) { + assert(!"Unsupported: simd_w != 16, 8"); + } + ldr(reg_input_baddr, ptr(abi_param1, static_cast(offsetof(jit_dw_conv_call_s, input)))); diff --git a/src/cpu/aarch64/jit_uni_dw_convolution.hpp b/src/cpu/aarch64/jit_uni_dw_convolution.hpp index 74adbc4c7b0..de3739681e5 100644 --- a/src/cpu/aarch64/jit_uni_dw_convolution.hpp +++ b/src/cpu/aarch64/jit_uni_dw_convolution.hpp @@ -36,7 +36,9 @@ namespace aarch64 { template struct jit_uni_dw_convolution_fwd_t : public primitive_t { struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, + // Note: check `USING_INHERITED_IS_IMPOSSIBLE` comment in other files + // for details why this ctor can't be removed. 
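Most pd_t classes in this patch replace their hand-written constructor with an inheriting using-declaration; the one just below cannot, because its parameter type is widened relative to the base's. A hedged toy illustration of both cases (illustrative types, not the real pd hierarchy):

struct base_pd_t {
    base_pd_t(const int *desc, const float *attr) {}
};

// The common cleanup: inherit every base constructor verbatim.
struct plain_pd_t : base_pd_t {
    using base_pd_t::base_pd_t;
};

// The exception: the derived ctor deliberately accepts a more general
// descriptor type, which a using-declaration cannot express, so an
// explicit forwarding constructor has to stay.
struct general_pd_t : base_pd_t {
    general_pd_t(const void *desc, const float *attr)
        : base_pd_t(static_cast<const int *>(desc), attr) {}
};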
+ pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const typename pd_t::base_class *hint_fwd_pd) : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} @@ -69,15 +71,15 @@ struct jit_uni_dw_convolution_fwd_t : public primitive_t { return status::success; } - jit_conv_conf_t jcp_; + jit_conv_conf_t jcp_ = utils::zero(); }; jit_uni_dw_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type f32_data_t; - typedef typename prec_traits::type bf16_data_t; - typedef typename prec_traits::type data_t; - typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits_t::type f32_data_t; + typedef typename prec_traits_t::type bf16_data_t; + typedef typename prec_traits_t::type data_t; + typedef typename prec_traits_t::type dst_data_t; status_t init(engine_t *engine) override { CHECK(safe_ptr_assign(kernel_, @@ -105,9 +107,7 @@ template struct jit_uni_dw_convolution_bwd_data_t : public primitive_t { struct pd_t : public cpu_convolution_bwd_data_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} + using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t; DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", jcp_.isa, ""), jit_uni_dw_convolution_bwd_data_t); @@ -134,7 +134,7 @@ struct jit_uni_dw_convolution_bwd_data_t : public primitive_t { return status::success; } - jit_conv_conf_t jcp_; + jit_conv_conf_t jcp_ = utils::zero(); protected: bool set_default_formats() { @@ -158,9 +158,9 @@ struct jit_uni_dw_convolution_bwd_data_t : public primitive_t { jit_uni_dw_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type diff_src_data_t; - typedef typename prec_traits::type diff_dst_data_t; - typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits_t::type diff_src_data_t; + typedef typename prec_traits_t::type diff_dst_data_t; + typedef typename prec_traits_t::type wei_data_t; status_t init(engine_t *engine) override { CHECK(safe_ptr_assign(kernel_, @@ -191,10 +191,9 @@ template struct jit_uni_dw_convolution_bwd_weights_t : public primitive_t { struct pd_t : public cpu_convolution_bwd_weights_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) - , jcp_() {} + using cpu_convolution_bwd_weights_pd_t:: + cpu_convolution_bwd_weights_pd_t; + using jit_uni_dw_convolution_bwd_weights = jit_uni_dw_convolution_bwd_weights_t; @@ -229,7 +228,7 @@ struct jit_uni_dw_convolution_bwd_weights_t : public primitive_t { return status::success; } - jit_conv_conf_t jcp_; + jit_conv_conf_t jcp_ = utils::zero(); protected: bool set_default_formats() { @@ -253,11 +252,11 @@ struct jit_uni_dw_convolution_bwd_weights_t : public primitive_t { jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd); - typedef typename prec_traits::type bf16_data_t; - typedef typename prec_traits::type f32_data_t; - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type diff_dst_data_t; - typedef typename prec_traits::type diff_weights_data_t; + typedef typename prec_traits_t::type bf16_data_t; + typedef typename prec_traits_t::type f32_data_t; + typedef typename prec_traits_t::type src_data_t; + typedef typename prec_traits_t::type diff_dst_data_t; + typedef typename prec_traits_t::type diff_weights_data_t; status_t 
init(engine_t *engine) override { CHECK(safe_ptr_assign(kernel_, diff --git a/src/cpu/aarch64/jit_uni_eltwise.hpp b/src/cpu/aarch64/jit_uni_eltwise.hpp index 97ee5ed4c11..9665a7e9c12 100644 --- a/src/cpu/aarch64/jit_uni_eltwise.hpp +++ b/src/cpu/aarch64/jit_uni_eltwise.hpp @@ -50,7 +50,7 @@ struct jit_uni_eltwise_fwd_t : public primitive_t { jit_uni_eltwise_fwd_t(const pd_t *apd); virtual ~jit_uni_eltwise_fwd_t(); - typedef typename prec_traits::type data_t; + typedef typename prec_traits_t::type data_t; status_t init(engine_t *engine) override; @@ -75,7 +75,7 @@ struct jit_uni_eltwise_bwd_t : public primitive_t { jit_uni_eltwise_bwd_t(const pd_t *apd); virtual ~jit_uni_eltwise_bwd_t(); - typedef typename prec_traits::type data_t; + typedef typename prec_traits_t::type data_t; status_t init(engine_t *engine) override; diff --git a/src/cpu/aarch64/jit_uni_eltwise_int.hpp b/src/cpu/aarch64/jit_uni_eltwise_int.hpp index 7f646a2275a..bb487ff0393 100644 --- a/src/cpu/aarch64/jit_uni_eltwise_int.hpp +++ b/src/cpu/aarch64/jit_uni_eltwise_int.hpp @@ -50,7 +50,7 @@ struct jit_uni_eltwise_int_fwd_t : public primitive_t { jit_uni_eltwise_int_fwd_t(const pd_t *apd); ~jit_uni_eltwise_int_fwd_t(); - typedef typename prec_traits::type data_t; + typedef typename prec_traits_t::type data_t; status_t init(engine_t *engine) override; diff --git a/src/cpu/aarch64/jit_uni_i8i8_pooling.cpp b/src/cpu/aarch64/jit_uni_i8i8_pooling.cpp index 6e14591b282..dfe28ec0475 100644 --- a/src/cpu/aarch64/jit_uni_i8i8_pooling.cpp +++ b/src/cpu/aarch64/jit_uni_i8i8_pooling.cpp @@ -128,8 +128,8 @@ struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator { // thus we need to take into account ratio of sizes s32/i8 = 4 static constexpr data_type_t avg_proc_dt = data_type::s32; enum : int { - s32_to_i8_ratio = sizeof(typename prec_traits::type) - / sizeof(typename prec_traits::type), + s32_to_i8_ratio = sizeof(typename prec_traits_t::type) + / sizeof(typename prec_traits_t::type), max_num_ll = s32_to_i8_ratio, mmx_msk_base_reg = 3 }; diff --git a/src/cpu/aarch64/jit_uni_pooling.cpp b/src/cpu/aarch64/jit_uni_pooling.cpp index 46d26aeb977..e60f8fc4472 100644 --- a/src/cpu/aarch64/jit_uni_pooling.cpp +++ b/src/cpu/aarch64/jit_uni_pooling.cpp @@ -560,7 +560,7 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, const auto post_ops_binary_rhs_arg_vec = binary_injector::prepare_binary_args(jpp.post_ops, ctx); - using wsp_data_t = typename prec_traits::type; + using wsp_data_t = typename prec_traits_t::type; using namespace jit_uni_pooling_utils; const auto transpose_facade @@ -688,7 +688,7 @@ void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, const auto post_ops_binary_rhs_arg_vec = binary_injector::prepare_binary_args(jpp.post_ops, ctx); - using wsp_data_t = typename prec_traits::type; + using wsp_data_t = typename prec_traits_t::type; using namespace jit_uni_pooling_utils; static constexpr int first_ithr = 0; @@ -893,7 +893,7 @@ void jit_uni_pooling_bwd_t::execute_backward( const exec_ctx_t &ctx) const { using namespace jit_uni_pooling_utils; - using wsp_data_t = typename prec_traits::type; + using wsp_data_t = typename prec_traits_t::type; const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); @@ -1018,7 +1018,7 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( const auto &jpp = pd()->jpp_; - using wsp_data_t = typename prec_traits::type; + using wsp_data_t = typename prec_traits_t::type; using namespace jit_uni_pooling_utils; static 
constexpr int first_ithr = 0; diff --git a/src/cpu/aarch64/jit_uni_pooling.hpp b/src/cpu/aarch64/jit_uni_pooling.hpp index 8f6448c4bd7..ac854c75fce 100644 --- a/src/cpu/aarch64/jit_uni_pooling.hpp +++ b/src/cpu/aarch64/jit_uni_pooling.hpp @@ -82,7 +82,7 @@ struct jit_uni_pooling_fwd_t : public primitive_t { jit_uni_pooling_fwd_t &operator=(jit_uni_pooling_fwd_t &&) = default; ~jit_uni_pooling_fwd_t(); - using data_t = typename prec_traits::type; + using data_t = typename prec_traits_t::type; status_t init(engine_t *engine) override; @@ -151,7 +151,7 @@ struct jit_uni_pooling_bwd_t : public primitive_t { jit_uni_pooling_bwd_t &operator=(jit_uni_pooling_bwd_t &&) = default; ~jit_uni_pooling_bwd_t(); - using data_t = typename prec_traits::type; + using data_t = typename prec_traits_t::type; status_t init(engine_t *engine) override; diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp index 6d08a7f55a6..f2aa1f42d2c 100644 --- a/src/cpu/aarch64/jit_uni_reorder.cpp +++ b/src/cpu/aarch64/jit_uni_reorder.cpp @@ -1,7 +1,7 @@ /******************************************************************************* * Copyright 2018-2023 Intel Corporation * Copyright 2020-2024 FUJITSU LIMITED -* Copyright 2022-2024 Arm Ltd. and affiliates +* Copyright 2022-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -161,14 +161,28 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { static bool applicable(const prb_t &p) { using namespace data_type; + bool bf16_ok + = (mayiuse_bf16() && (p.itype == bf16) && (p.otype == bf16) + && !interim_f32_needed(p, false) && p.beta == 0.f) + || (p.itype != bf16 && p.otype != bf16) + || (p.itype == f32 && p.otype == bf16 && mayiuse_bf16() + && p.beta == 0.f) + || (p.itype == bf16 && p.otype == f32 && mayiuse_bf16() + && p.beta == 0.f); + + bool is_f16 = (p.itype == f16 || p.otype == f16); + bool f16_ok = (p.itype == f32 && p.otype == f16 && p.beta == 0.f) + || (p.itype == f16 && p.otype == f32 && p.beta == 0.f); + bool ok = true && p.ndims > 0 - && utils::one_of(p.itype, f32, s32, data_type::s8, u8) - && utils::one_of(p.otype, f32, bf16, s32, data_type::s8, u8) + && utils::one_of( + p.itype, f32, f16, bf16, s32, data_type::s8, u8) + && utils::one_of( + p.otype, f32, f16, bf16, s32, data_type::s8, u8) && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ && utils::one_of(p.beta, 0.f, 1.f) /* anything else? 
*/ && simple_impl_desc_init(p, nullptr) && prb_has_small_strides(p) - && IMPLICATION( - p.otype == bf16, p.itype == f32 && mayiuse_bf16()); + && bf16_ok && IMPLICATION(is_f16, f16_ok); return ok; } @@ -272,7 +286,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { case f32: /* do nothing */ break; + case f16: cvt_v_f16_f32(startIdx, regNum); break; case s32: cvt_z_s32_f32(startIdx, regNum); break; + case bf16: cvt_v_bf16_fp32(startIdx, regNum); break; case data_type::s8: cvt_z_s8_s32(startIdx, regNum); cvt_z_s32_f32(startIdx, regNum); @@ -302,6 +318,12 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { cvt_z_s32_s8(startIdx, regNum); if (idt == u8) cvt_z_u8_s8(startIdx, regNum); break; + case data_type::bf16: + if (idt == f32) cvt_v_f32_bf16(startIdx, regNum); + break; + case data_type::f16: + if (idt == f32) cvt_v_f32_f16(startIdx, regNum); + break; case u8: if (idt == f32) cvt_z_f32_s32(startIdx, regNum); if (utils::one_of(idt, f32, s32)) @@ -614,6 +636,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { /* do nothing */ break; case s32: cvt_v_s32_f32(startIdx, regNum); break; + case bf16: cvt_v_bf16_fp32(startIdx, regNum); break; + case f16: cvt_v_f16_f32(startIdx, regNum); break; case data_type::s8: cvt_v_s8_s32(startIdx, regNum); cvt_v_s32_f32(startIdx, regNum); @@ -629,6 +653,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { auto cvt2odt = [=](const int startIdx, const int regNum, data_type_t odt, data_type_t idt) { switch (odt) { + case f32: + if (idt == bf16) cvt_v_bf16_fp32(startIdx, regNum); + if (idt == f16) cvt_v_f16_f32(startIdx, regNum); + break; case s32: if (idt == f32) cvt_v_f32_s32(startIdx, regNum); @@ -652,6 +680,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { case bf16: if (idt == f32) cvt_v_f32_bf16(startIdx, regNum); break; + case f16: + if (idt == f32) cvt_v_f32_f16(startIdx, regNum); + break; default: assert(!"unreachable"); } }; @@ -702,7 +733,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { const int load_tail_step = !can_load_xmm && can_store_xmm ? ur_step : load_step; - const bool interim_f32 = interim_f32_needed(); + const bool interim_f32 = interim_f32_needed(prb_, compensation_needed_); const bool need_saturation = (utils::one_of(prb_.otype, u8, data_type::s8, s32) @@ -775,7 +806,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { // transposition on the fly const bool fast_return = prb_.src_scale_type != scale_type_t::MANY && prb_.dst_scale_type != scale_type_t::MANY - && prb_.beta == 0.f; + && prb_.beta == 0.f && !prb_.req_src_zp && !prb_.req_dst_zp; if (fast_return) { if (prb_.src_scale_type == scale_type_t::COMMON) for (int ur = 0; ur < reg_unroll; ur += load_step) @@ -1285,17 +1316,17 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { } } - bool interim_f32_needed() { + static bool interim_f32_needed(const prb_t &prb, bool compensation_needed) { using namespace data_type; - - return utils::one_of(f32, prb_.itype, prb_.otype) - || prb_.src_scale_type != scale_type_t::NONE - || prb_.dst_scale_type != scale_type_t::NONE || prb_.beta != 0.f - || ((prb_.req_src_zp || prb_.req_dst_zp) - ? 
!(prb_.itype == s32 && prb_.otype == s32) + bool ret = utils::one_of(f32, prb.itype, prb.otype) + || prb.src_scale_type != scale_type_t::NONE + || prb.dst_scale_type != scale_type_t::NONE || prb.beta != 0.f + || ((prb.req_src_zp || prb.req_dst_zp) + ? !(prb.itype == s32 && prb.otype == s32) : false) - || (prb_.itype != f32 && compensation_needed_) - || prb_.scale_adjust != 1.f; + || (prb.itype != f32 && compensation_needed) + || prb.scale_adjust != 1.f; + return ret; } void process_unroll_generic( @@ -1313,7 +1344,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { int curr = 0; // will switch between 0 and 1 - const bool interim_f32 = interim_f32_needed(); + const bool interim_f32 = interim_f32_needed(prb_, compensation_needed_); if (prb_.req_src_zp) { add_imm(X_DEFAULT_ADDR, PARAM(src_zp), X_TMP_0); @@ -1685,6 +1716,18 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { UNROLL_INST2(bfcvtn, VReg4H(i), VReg4S(i)); } + void cvt_v_bf16_fp32(const size_t startIdx, const size_t regNum) { + UNROLL_INST2(shll, VReg4S(i), VReg4H(i), 16); + } + + void cvt_v_f16_f32(const size_t startIdx, const size_t regNum) { + UNROLL_INST2(fcvtl, VReg4S(i), VReg4H(i)); + } + + void cvt_v_f32_f16(const size_t startIdx, const size_t regNum) { + UNROLL_INST2(fcvtn, VReg4H(i), VReg4S(i)); + } + void cvt_z_s8_s32(const size_t startIdx, const size_t regNum) { cvt_z_b_s(startIdx, regNum); UNROLL_INST(sxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp); @@ -2730,9 +2773,10 @@ static void prb_thread_kernel_balance( if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) { DEBUG({ - printf("split: "); - prb_dump(prb); - printf("ndims_ker_max = %d\n", ndims_ker_max); + verbose_printf( + verbose_t::debuginfo, "split: %s\n", prb_dump(prb).c_str()); + verbose_printf(verbose_t::debuginfo, "ndims_ker_max = %d\n", + ndims_ker_max); }); } } @@ -2767,13 +2811,10 @@ status_t jit_uni_reorder_t::pd_t::init_scratchpad() { compensation_reduce_size); } - const memory_desc_wrapper input_d(src_md()); - int scales_mask = -1; - bool is_set = false; - CHECK(attr()->scales_.get(DNNL_ARG_DST, &scales_mask, &is_set)); - - if (is_set && scales_mask > 0) { - get_D_values(input_d, scales_mask, nullptr, &D_mask_, nullptr); + if (!attr()->scales_.has_default_values(DNNL_ARG_DST)) { + const memory_desc_wrapper input_d(src_md()); + int mask = attr()->scales_.get_mask(DNNL_ARG_DST); + get_D_values(input_d, mask, nullptr, &D_mask_, nullptr); if (D_mask_ > 1) { scratchpad.template book( memory_tracking::names::key_reorder_precomputed_dst_scales, @@ -2797,8 +2838,8 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, prb_block_for_cache(prb); DEBUG({ - printf("cache: "); - prb_dump(prb); + verbose_printf( + verbose_t::debuginfo, "cache: %s\n", prb_dump(prb).c_str()); }); int ndims_ker_max {}; @@ -2817,8 +2858,8 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, return status::unimplemented; DEBUG({ - printf("ker : "); - prb_dump(ker_desc.prb); + verbose_printf(verbose_t::debuginfo, "ker : %s\n", + prb_dump(ker_desc.prb).c_str()); }); auto _pd = make_unique_pd( @@ -3027,12 +3068,12 @@ void jit_uni_reorder_t::omp_driver(const char *in, char *out, out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype); DEBUG({ - printf("prb : "); - tr::prb_dump(pd()->prb_); + verbose_printf(verbose_t::debuginfo, "prb : %s\n", + tr::prb_dump(pd()->prb_).c_str()); }); DEBUG({ - printf("ker : "); - tr::prb_dump(pd()->ker_desc_.prb); + verbose_printf(verbose_t::debuginfo, "ker : %s\n", + 
tr::prb_dump(pd()->ker_desc_.prb).c_str()); }); int ndims = pd()->prb_.ndims; @@ -3236,8 +3277,8 @@ status_t jit_blk_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, prb_tile_normalize(prb); DEBUG({ - printf("tile : "); - prb_dump(prb); + verbose_printf( + verbose_t::debuginfo, "tile : %s\n", prb_dump(prb).c_str()); }); if (!tr::jit_single_blk_kernel_t::applicable(prb)) { diff --git a/src/cpu/aarch64/jit_uni_reorder.hpp b/src/cpu/aarch64/jit_uni_reorder.hpp index 4587fd82e21..83ac55ed855 100644 --- a/src/cpu/aarch64/jit_uni_reorder.hpp +++ b/src/cpu/aarch64/jit_uni_reorder.hpp @@ -149,8 +149,8 @@ void prb_node_swap(prb_t &p, int d0, int d1); * to the right if d0 > d1 */ void prb_node_move(prb_t &p, int d0, int d1); -/** dumps the problem to stdout */ -void prb_dump(const prb_t &p); +/** dumps the problem to a string */ +std::string prb_dump(const prb_t &p); struct call_param_t { const void *in = nullptr; diff --git a/src/cpu/aarch64/jit_uni_reorder_utils.cpp b/src/cpu/aarch64/jit_uni_reorder_utils.cpp index 5000f904f0d..90e78f3877b 100644 --- a/src/cpu/aarch64/jit_uni_reorder_utils.cpp +++ b/src/cpu/aarch64/jit_uni_reorder_utils.cpp @@ -205,9 +205,8 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, bool ok = im_d.is_blocking_desc() && om_d.is_blocking_desc() && !im_d.has_runtime_dims_or_strides() && !im_d.has_zero_dim() && !om_d.has_runtime_dims_or_strides() && !om_d.has_zero_dim() - && attr->has_default_values( - primitive_attr_t::skip_mask_t::scales_runtime - | primitive_attr_t::skip_mask_t::zero_points_runtime + && attr->has_default_values(primitive_attr_t::skip_mask_t::scales + | primitive_attr_t::skip_mask_t::zero_points | primitive_attr_t::skip_mask_t::post_ops) && check_post_ops(attr); if (!ok) return unimplemented; @@ -276,24 +275,21 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, p.src_scale_type = scale_type_t::NONE; int src_mask = 0; - bool is_src_set = false; - CHECK(attr->scales_.get(DNNL_ARG_SRC, &src_mask, &is_src_set)); - if (is_src_set) { + if (!attr->scales_.has_default_values(DNNL_ARG_SRC)) { + src_mask = attr->scales_.get_mask(DNNL_ARG_SRC); p.src_scale_type = src_mask == 0 ? scale_type_t::COMMON : scale_type_t::MANY; } p.dst_scale_type = scale_type_t::NONE; int dst_mask = 0; - bool is_dst_set = false; - CHECK(attr->scales_.get(DNNL_ARG_DST, &dst_mask, &is_dst_set)); - if (is_dst_set) { + if (!attr->scales_.has_default_values(DNNL_ARG_DST)) { + dst_mask = attr->scales_.get_mask(DNNL_ARG_DST); p.dst_scale_type = dst_mask == 0 ? scale_type_t::COMMON : scale_type_t::MANY; } - if (is_src_set && is_dst_set && src_mask != dst_mask) - return status::unimplemented; + if (src_mask != dst_mask) return status::unimplemented; p.scale_adjust = (om_d.extra().flags & memory_extra_flags::scale_adjust) ? om_d.extra().scale_adjust @@ -431,14 +427,14 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, p.beta = sum_idx == -1 ? 
0.f : attr->post_ops_.entry_[sum_idx].sum.scale; DEBUG({ - printf("init : "); - prb_dump(p); + verbose_printf( + verbose_t::debuginfo, "init : %s\n", prb_dump(p).c_str()); }); // Sort the prb array in increasing sizes of the output stride prb_normalize(p); DEBUG({ - printf("norm : "); - prb_dump(p); + verbose_printf( + verbose_t::debuginfo, "norm : %s\n", prb_dump(p).c_str()); }); // compensation strides require prb_normalized @@ -448,8 +444,8 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, * sides of the reorder */ prb_simplify(p); DEBUG({ - printf("smpl : "); - prb_dump(p); + verbose_printf( + verbose_t::debuginfo, "smpl : %s\n", prb_dump(p).c_str()); }); return success; @@ -605,16 +601,20 @@ void prb_node_move(prb_t &p, int d0, int d1) { p.nodes[d1] = node; } -void prb_dump(const prb_t &p) { - printf("@@@ type:%s:%s ndims:%d ", dnnl_dt2str(p.itype), - dnnl_dt2str(p.otype), p.ndims); - for (int d = 0; d < p.ndims; ++d) - printf("[%zu:%zu:%d:%d:%s:%td:%td:%td:%td]", p.nodes[d].n, - p.nodes[d].tail_size, p.nodes[d].dim_id, - p.nodes[d].parent_node_id, - p.nodes[d].is_zero_pad_needed ? "true" : "false", p.nodes[d].is, - p.nodes[d].os, p.nodes[d].ss, p.nodes[d].cs); - printf(" off:%zu:%zu\n", p.ioff, p.ooff); +std::string prb_dump(const prb_t &p) { + std::stringstream ss; + ss << "@@@ type:" << dnnl_dt2str(p.itype) << ':' << dnnl_dt2str(p.otype) + << " ndims:" << p.ndims; + for (int d = 0; d < p.ndims; ++d) { + if (d != 0) ss << 'x'; + const auto &node = p.nodes[d]; + ss << '[' << node.n << ':' << node.tail_size << ':' << node.dim_id + << ':' << node.parent_node_id << ':' + << (node.is_zero_pad_needed ? "true" : "false") << ':' << node.is + << ':' << node.os << ':' << node.ss << ':' << node.cs << ']'; + } + ss << " off:" << p.ioff << ':' << p.ooff; + return ss.str(); } } // namespace tr diff --git a/src/cpu/aarch64/jit_uni_softmax.cpp b/src/cpu/aarch64/jit_uni_softmax.cpp index ecd91200f3f..1450f7788f9 100644 --- a/src/cpu/aarch64/jit_uni_softmax.cpp +++ b/src/cpu/aarch64/jit_uni_softmax.cpp @@ -16,6 +16,7 @@ *******************************************************************************/ #include +#include #include "common/c_types_map.hpp" #include "common/dnnl_thread.hpp" @@ -668,12 +669,10 @@ struct jit_softmax_t : public jit_softmax_base_t { template jit_uni_softmax_fwd_t::jit_uni_softmax_fwd_t(const pd_t *apd) : primitive_t(apd) - , softmax_driver_(new softmax_impl::driver_t(pd())) {} + , softmax_driver_(utils::make_unique>(pd())) {} template -jit_uni_softmax_fwd_t::~jit_uni_softmax_fwd_t() { - delete softmax_driver_; -} +jit_uni_softmax_fwd_t::~jit_uni_softmax_fwd_t() = default; template status_t jit_uni_softmax_fwd_t::init(engine_t *engine) { @@ -725,12 +724,10 @@ status_t jit_uni_softmax_fwd_t::execute(const exec_ctx_t &ctx) const { template jit_uni_softmax_bwd_t::jit_uni_softmax_bwd_t(const pd_t *apd) : primitive_t(apd) - , softmax_driver_(new softmax_impl::driver_t(pd())) {} + , softmax_driver_(utils::make_unique>(pd())) {} template -jit_uni_softmax_bwd_t::~jit_uni_softmax_bwd_t() { - delete softmax_driver_; -} +jit_uni_softmax_bwd_t::~jit_uni_softmax_bwd_t() = default; template status_t jit_uni_softmax_bwd_t::init(engine_t *engine) { diff --git a/src/cpu/aarch64/jit_uni_softmax.hpp b/src/cpu/aarch64/jit_uni_softmax.hpp index b8933442b1e..090d4300b56 100644 --- a/src/cpu/aarch64/jit_uni_softmax.hpp +++ b/src/cpu/aarch64/jit_uni_softmax.hpp @@ -19,6 +19,7 @@ #define CPU_AARCH64_JIT_UNI_SOFTMAX_HPP #include +#include #include 
"common/c_types_map.hpp" #include "common/memory_tracking.hpp" @@ -80,7 +81,7 @@ struct jit_uni_softmax_fwd_t : public primitive_t { utils::one_of(bf16, src_dt, dst_dt), mayiuse_bf16()) && (mayiuse(sve_512) || mayiuse(sve_256) || mayiuse(sve_128)) - && attr()->has_default_values(skip_mask_t::scales_runtime) + && attr()->has_default_values(skip_mask_t::scales) && attr_scales_ok() && set_default_formats() == status::success; if (!ok) return status::unimplemented; @@ -119,7 +120,9 @@ struct jit_uni_softmax_fwd_t : public primitive_t { private: const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - softmax_impl::driver_t *softmax_driver_; + std::unique_ptr> softmax_driver_; + + DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_softmax_fwd_t); }; template @@ -191,7 +194,9 @@ struct jit_uni_softmax_bwd_t : public primitive_t { private: const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - softmax_impl::driver_t *softmax_driver_; + std::unique_ptr> softmax_driver_; + + DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_softmax_bwd_t); }; } // namespace aarch64 diff --git a/src/cpu/aarch64/matmul/brgemm_matmul.cpp b/src/cpu/aarch64/matmul/brgemm_matmul.cpp index bebdae12041..1f7bd0088d6 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul.cpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul.cpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2021-2023 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. and affiliates * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -71,9 +72,9 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { const std::vector supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}; bool ok = attr_scales_ok(supported_args); - if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values() - && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values() - && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) { + if (!attr()->scales_.has_default_values(DNNL_ARG_SRC) + && !attr()->scales_.has_default_values(DNNL_ARG_WEIGHTS) + && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; } @@ -83,8 +84,18 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { return ok; }; - auto check_attr_zero_points - = [&]() -> bool { return attr()->zero_points_.common(); }; + auto check_attr_zero_points = [&]() -> bool { + const auto &zp = attr()->zero_points_; + static const std::vector supported_args { + DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}; + for (int arg : supported_args) { + if (!zp.has_default_values(arg)) { + const int mask = zp.get_mask(arg); + if (mask > 0) return false; + } + } + return true; + }; // The current version supports runtime value for M dimension in the case // of 2d problems only and do not support any runtime strides for B and C @@ -101,9 +112,8 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { VDISPATCH_MATMUL( no_dynamic_strides_for_B_and_C, VERBOSE_RUNTIMEDIM_UNSUPPORTED); VDISPATCH_MATMUL( - attr()->has_default_values( - primitive_attr_t::skip_mask_t::scales_runtime - | primitive_attr_t::skip_mask_t::zero_points_runtime + attr()->has_default_values(primitive_attr_t::skip_mask_t::scales + | primitive_attr_t::skip_mask_t::zero_points | primitive_attr_t::skip_mask_t::post_ops | primitive_attr_t::skip_mask_t::sum_dt, dst_dt), @@ -158,6 +168,9 @@ 
status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { = bgmmc_.post_ops_applicable && bgmmc_.nthr_k > 1; CHECK(brgemm_desc_set_attr(&brg, brgattr)); + + CHECK(brgemm_desc_finalize(&brg)); + bgmmc_.wsp_tile_per_thr_bytes = nstl::max( brg.get_wsp_buffer_size(), bgmmc_.wsp_tile_per_thr_bytes); } @@ -642,7 +655,6 @@ void brgemm_matmul_t::copy_b_chunk_in_buffer( = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx); ctx.current_K_start = k; ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K); - assert(isa == sve_512); (*copy_B_kernel_)(&ctx); } @@ -654,7 +666,6 @@ void brgemm_matmul_t::copy_b_chunk_in_buffer( = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx); ctx.current_K_start = k; ctx.current_K_iters = bgmmc.K % bgmmc.K_blk; - assert(isa == sve_512); (*copy_B_kernel_)(&ctx); } } diff --git a/src/cpu/aarch64/matmul/brgemm_matmul_reorders.cpp b/src/cpu/aarch64/matmul/brgemm_matmul_reorders.cpp index 0aa98fe9c2b..bfb917afbbd 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul_reorders.cpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul_reorders.cpp @@ -85,7 +85,8 @@ status_t brgemm_matmul_matrix_B_reorder_t::pd_t::init( matmul_conf_for_reorder_.K = dims[ndims - 2]; matmul_conf_for_reorder_.N = dims[ndims - 1]; matmul_conf_for_reorder_.wei_n_blk = matmul_conf_for_reorder_.N_blk - = matmul_conf_for_reorder_.LDB = matmul::get_default_n_block(otag); + = matmul_conf_for_reorder_.LDB + = matmul::get_default_n_block(otag, matmul_conf_for_reorder_); matmul_conf_for_reorder_.N_tail = matmul_conf_for_reorder_.N % matmul_conf_for_reorder_.N_blk; matmul_conf_for_reorder_.K_blk = 16 * vnni_granularity; diff --git a/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp b/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp index bd9bc023eaf..0610147c752 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp @@ -1,5 +1,7 @@ /******************************************************************************* * Copyright 2021-2023 Intel Corporation +* Copyright 2023-2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,7 +49,8 @@ using namespace dnnl::impl::utils; using namespace data_type; using namespace format_tag; -int get_default_n_block(format_tag_t matrix_b_tag) { +int get_default_n_block( + format_tag_t matrix_b_tag, brgemm_matmul_conf_t &bgmmc) { // Note: consider using weights mem_descriptor 'inner_blks' to // return B's inner block for non-default cases. switch (matrix_b_tag) { @@ -75,7 +78,23 @@ int get_default_n_block(format_tag_t matrix_b_tag) { case BA16a16b: case BA16a16b2a: case BA16a16b4a: return 16; - default: return 64; + default: { + if (bgmmc.N == 16 || bgmmc.N == 32 || bgmmc.N == 64) return bgmmc.N; + if (!mayiuse(sve_512)) { + if (bgmmc.N <= 16) + return 16; + else { + // It is observed that for M,K>512, N block of 64 works better provided that thread distribution is not hindered. 
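+                    // E.g. N = 2048, M = K = 1024, nthr = 16: N / 64 = 32 >= nthr
+                    // and M, K > 512, so the 64-wide N block is picked below; with
+                    // nthr = 64 the same shape takes the 32-wide block instead, so
+                    // that thread distribution is not hindered.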
+                    if (bgmmc.N / 64 >= bgmmc.nthr && bgmmc.K > 512
+                            && bgmmc.M > 512)
+                        return 64;
+                    else
+                        return 32;
+                }
+
+            } else
+                return 64;
+        }
     }
 }
@@ -128,9 +147,8 @@ bool post_ops_ok(brgemm_matmul_conf_t &bgmmc, const primitive_attr_t &attr,
             true /*sum_requires_same_params*/, bcast_set));
 }
-status_t check_isa_with_datatype(
-        const cpu_isa_t isa, const brgemm_matmul_conf_utils_t &bm_conf_utils) {
-    if (bm_conf_utils.is_f32() && !bm_conf_utils.is_int8()
+status_t check_datatype(const brgemm_matmul_conf_utils_t &bm_conf_utils) {
+    if (bm_conf_utils.is_f32() && !bm_conf_utils.is_bf32()
             && !bm_conf_utils.is_bf16() && !bm_conf_utils.is_f16()
             && !bm_conf_utils.is_int8())
         return status::success;
@@ -178,7 +196,7 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_B_tag(
     if (B_any_layout) {
         const int default_n_block = init_n_tag
-                ? get_default_n_block(format_tag::undef)
+                ? get_default_n_block(format_tag::undef, bgmmc)
                 : bgmmc.N_blk;
         bgmmc.wei_tag = blocked_B_layouts_allowed
                 ? this->pick_blocked_B_layout(default_n_block)
@@ -320,10 +338,6 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
         default: return format_tag::undef;
     }
-    assert(!this->is_bf16());
-    assert(!this->is_f16());
-    assert(!this->is_bf32());
-
     // Note: bf32 assumes f32 blocking
     if (this->is_f32() || this->is_bf32() || this->is_f16()) switch (n_blk) {
             case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
@@ -580,14 +594,17 @@ float compute_blocking_heuristic_sve_256(brgemm_matmul_conf_t &bgmmc,
     const int nthr = bgmmc.nthr;
     const int max_m_blk = nstl::min(/*64*/ 256, matmul.M);
-    int min_m_blk = nstl::min(32, matmul.M); // max_m_blk
+    // It is found that for 2D shapes min_m_blk = 128 works better than 32 for most shapes.
+    int min_m = (matmul.batch > 1) ? 32 : 128;
+    int min_m_blk = nstl::min(min_m, matmul.M); // max_m_blk
     int n_blk = bgmmc.N_blk;
     const int n_chunks = div_up(matmul.N, n_blk);
     const int max_n_chunks = bgmmc.use_buffer_a ? 16 : 1;
     const int n_chunks_start = nstl::min(max_n_chunks, n_chunks);
-    int default_k_blk = 1024;
+    // It is found that for M < 512, a k_blk of 128 works better than 1024 for most shapes.
+    int default_k_blk = (matmul.M >= 512) ? 1024 : 128;
     int k_blk = nstl::min(matmul.K, default_k_blk);
     int start_nthr_k = 1;
@@ -597,7 +614,22 @@ float compute_blocking_heuristic_sve_256(brgemm_matmul_conf_t &bgmmc,
     const bool low_parallel_work = static_cast(nthr) > max_parallel;
     if (low_parallel_work) {
-        min_m_blk = nstl::min(matmul.M, 16);
+        int best_m_blk = 0;
+        float scr = 0, best_scr = 16 * nthr;
+        for (int i = 16; i >= 4; i--) {
+            scr = 0.7 * (matmul.M % i)
+                    + 0.3 * std::abs(nthr - ((float)matmul.M / (float)i));
+            if (scr < best_scr) {
+                best_scr = scr;
+                best_m_blk = i;
+            }
+        }
+        min_m_blk = nstl::min(matmul.M, best_m_blk);
+        // Here min_m_blk is set based on the M value and the number of threads.
+        // Decreasing m_blk increases the number of M blocks, which may use the
+        // threads better, but m_blk being a factor of M was found to matter more
+        // than maximum thread utilisation, so it carries the larger weight (0.7)
+        // in the scoring; this was experimentally verified across multiple shapes.
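+        // E.g. M = 100, nthr = 8: i = 10 scores 0.7 * (100 % 10)
+        // + 0.3 * |8 - 100 / 10| = 0.6, beating i = 16 with a score of
+        // 0.7 * 4 + 0.3 * |8 - 6.25| = 3.325, so min_m_blk becomes 10,
+        // an exact factor of M that yields 10 blocks for 8 threads.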
bool low_spatial_work = matmul.M <= 40; if (low_spatial_work) { @@ -732,8 +764,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, dst_d.format_kind() == format_kind::any, bias_md.format_kind == format_kind::any); - VCHECK_BG(check_isa_with_datatype(isa, bm_conf_utils), - VERBOSE_ISA_DT_MISMATCH); + VCHECK_BG(check_datatype(bm_conf_utils), VERBOSE_UNSUPPORTED_DT); bgmmc.a_dt_sz = bgmmc.tr_a_dt_sz = types::data_type_size(bgmmc.src_dt); bgmmc.b_dt_sz = bgmmc.tr_b_dt_sz = types::data_type_size(bgmmc.wei_dt); @@ -752,21 +783,22 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC); const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); - bgmmc.with_scales = !src_scales.has_default_values() - || !wei_scales.has_default_values(); - if (bgmmc.with_scales) { - bgmmc.is_oscale_per_n = wei_scales.mask_ == 1 << (bgmmc.ndims - 1); + const bool has_wei_scales = !wei_scales.has_default_values(); + bgmmc.with_scales = !src_scales.has_default_values() || has_wei_scales; + if (has_wei_scales) { + bgmmc.is_oscale_per_n + = wei_scales.get_mask() == (1 << (bgmmc.ndims - 1)); // only common and per-oc-channel scales are supported - VCONDCHECK_BG(wei_scales.mask_ == 0 || bgmmc.is_oscale_per_n, + VCONDCHECK_BG(wei_scales.get_mask() == 0 || bgmmc.is_oscale_per_n, VERBOSE_UNSUPPORTED_SCALES_CFG); } const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); bgmmc.with_dst_scales = !dst_scales.has_default_values(); // only common scales are supported - if (bgmmc.with_dst_scales && dst_scales.mask_ != 0) - return status::unimplemented; + VCONDCHECK_BG(!(bgmmc.with_dst_scales && dst_scales.get_mask() > 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); const auto &p = attr.post_ops_; bgmmc.with_sum = p.find(primitive_kind::sum) != -1; @@ -834,7 +866,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, VCHECK_BG(attr.set_default_formats(&dst_md), VERBOSE_UNSUPPORTED_TAG); - bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag); + bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag, bgmmc); bgmmc.blocked_B = bm_conf_utils.get_blocked_B(); bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b(); @@ -1107,4 +1139,4 @@ void init_scratchpad(memory_tracking::registrar_t &scratchpad, } // namespace aarch64 } // namespace cpu } // namespace impl -} // namespace dnnl \ No newline at end of file +} // namespace dnnl diff --git a/src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp b/src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp index fb5d88b14f0..ec4e1b75a27 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright 2021-2023 Intel Corporation +* Copyright 2023-2024 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -312,7 +313,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, void init_scratchpad(memory_tracking::registrar_t &scratchpad, const brgemm_matmul_conf_t &bgmmc); -int get_default_n_block(format_tag_t matrix_b_tag); +int get_default_n_block(format_tag_t, brgemm_matmul_conf_t &bgmmc); } // namespace matmul } // namespace aarch64 diff --git a/src/cpu/aarch64/matmul/jit_int8_kernel_types.hpp b/src/cpu/aarch64/matmul/jit_int8_kernel_types.hpp new file mode 100644 index 00000000000..27d55f19381 --- /dev/null +++ b/src/cpu/aarch64/matmul/jit_int8_kernel_types.hpp @@ -0,0 +1,94 @@ +/******************************************************************************* +* Copyright 2025 FUJITSU LIMITED +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_AARCH64_JIT_INT8_KERNEL_TYPES_HPP +#define CPU_AARCH64_JIT_INT8_KERNEL_TYPES_HPP + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { +namespace matmul { + +typedef enum { + none = 0, + per_tensor = 1, + per_m = 2, + per_n = 3, + per_k = 4, +} jit_int8_broadcast_t; + +struct dyn_vals_t { + int f = 0; + dim_t M = 0; + dim_t K = 0; + dim_t N = 0; + dim_t B = 0; + int is_s8 = 0, is_u8 = 0; + int mtail, ktail, ntail, m_blk, k_blk, n_blk; + int get_min_max = 0, reorder_a = 0, reorder_b = 0, cal_src = 0; + int is_mtail = 0, is_ktail = 0; +}; + +struct dyn_params_t { + const float *dyn_src; + const int8_t *src; + int8_t *dst; + float *max, *min; + int *nk, *nm, *nn; + int *tl, *mtl, *ntl; +}; + +struct brg_int8_t { + int M, K, N; + const int m_blk = 8, n_blk = 4, k_blk = 8; + const int ld_block = 6, rd_block = 4, bd_block = 8; + int na, nb; + int m_tail, n_tail, k_tail; + int is_m_tail, is_k_tail, is_n_tail, is_zp_cal; + int dst_dt_sz; + bool is_s8; + bool is_bias; + bool with_scales; + bool with_dst_scales; + bool is_oc_scales; + jit_int8_broadcast_t zp_type_a = jit_int8_broadcast_t::none; + jit_int8_broadcast_t zp_type_b = jit_int8_broadcast_t::none; + jit_int8_broadcast_t zp_type_c = jit_int8_broadcast_t::none; + bool is_zp_b_int8 = false; + bool b_reo = true; + data_type_t zp_b_dt; + dim_t B; +}; + +struct call_params_t { + const uint8_t *src, *wei; + float *dst; + const float *bias, *scales, *dst_scales; + dim_t M, K, N; + char *buf_B_ptr_; + int *na, *nb; + int32_t *src_zero_point, *wei_zero_point, *dst_zero_point; + const int8_t *wei_zero_point_buf; + float *zp_a_ptr, *zp_b_ptr; +}; + +} // namespace matmul +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl +#endif diff --git a/src/cpu/aarch64/matmul/jit_int8_matmul.cpp b/src/cpu/aarch64/matmul/jit_int8_matmul.cpp new file mode 100644 index 00000000000..97e3a17b45a --- /dev/null +++ b/src/cpu/aarch64/matmul/jit_int8_matmul.cpp @@ -0,0 +1,1478 @@ +/******************************************************************************* +* Copyright 2025 FUJITSU LIMITED +* +* Licensed under the Apache License, Version 2.0 (the "License"); 
+* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/math_utils.hpp" +#include "common/memory_tracking.hpp" +#include "common/nstl.hpp" +#include "common/tag_traits.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" +#include "cpu/aarch64/jit_generator.hpp" +#include "cpu/cpu_primitive.hpp" +#include "cpu/matmul/matmul_utils.hpp" +#include "cpu/scale_utils.hpp" + +#include "cpu/platform.hpp" +#include "cpu/primitive_attr_postops.hpp" + +#include "cpu/aarch64/matmul/jit_int8_kernel_types.hpp" +#include "cpu/aarch64/matmul/jit_int8_matmul.hpp" +#include "cpu/aarch64/matmul/jit_int8_matmul_utils.hpp" + +#define GET_OFF(field) (uint32_t) offsetof(call_params_t, field) + +#define LDR_IMM(reg, addr, off) \ + { \ + const uint64_t IMM12_MASK = ~uint64_t(0xfff); \ + if ((off & IMM12_MASK) == 0) { \ + ldr(reg, ptr(addr, off)); \ + } else { \ + add_imm(X_DEFAULT_ADDR, addr, off, X_TMP_0); \ + ldr(reg, ptr(X_DEFAULT_ADDR)); \ + } \ + } + +#define VCHECK_BG(f, msg, ...) \ + VCHECK(primitive, create, dispatch, brgemm_matmul, f, msg, ##__VA_ARGS__); + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { +namespace matmul { + +using namespace Xbyak_aarch64; +using namespace dnnl::impl::cpu::matmul; +using namespace dnnl::impl::format_tag; +using namespace dnnl::impl::memory_tracking::names; +using namespace dnnl::impl::utils; + +using namespace nstl; + +using namespace data_type; + +struct jit_int8_matmul_kernel_t : public jit_generator { + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_int8_matmul_kernel_t) + + XReg reg_param = abi_param1; + XReg reg_a = x3; + XReg reg_b = x4; + XReg reg_c = x5; + XReg reg_aux_a = x6; + XReg reg_aux_b = x7; + XReg reg_aux_c = x8; + XReg reg_aux_a1 = x9; + XReg reg_zp_aux_b_buf = x10; + XReg reg_aux_c1 = x11; + XReg reg_ld_loop = x12; + XReg reg_rd_loop = x13; + XReg reg_bd_loop = x14; + XReg reg_tmp = x15; + XReg reg_tmp_1 = x16; + XReg reg_bias = x17; + XReg reg_zp_a = x18; + + XReg reg_scales = x20; + XReg reg_aux_scales = x24; //used X_TMP_1 + XReg reg_na = x25; //used X_TMP_2 + XReg reg_zp_b = x26; //used X_TMP_3 + XReg reg_zp_aux_b = x27; //used X_TMP_4 + PReg prd_ld = p1; + PReg prd_st = p2; + PReg prd_b = p3; + PReg prd_8 = p4; + PReg prd_zp_b_tl = p5; + XReg reg_zp_val_c = x2; + + XReg reg_zp_val_a = reg_scales; + XReg reg_zp_val_b = reg_aux_scales; + + call_params_t inp; + + void operator()(const call_params_t *p) { + return jit_generator::operator()(p); + } + + ZReg loadb(int ld) { return ZReg(ld + 1); } + ZReg acc(int bd, int ld) { + return ZReg(bd * brg_.ld_block + ld + brg_.ld_block + 1); + } + void zero_regs() { + for (int a = 0; a < brg_.bd_block / 2; a++) + for (int b = 0; b < brg_.ld_block; b++) + eor(acc(a, b).d, acc(a, b).d, acc(a, b).d); + } + void store_regs(int bdb, int ldb, int tail) { + for (int a = 0; a < bdb; 
a++) { + for (int b = 0; b < ldb; b++) { + if (brg_.is_s8) + scvtf(acc(a, b).s, P_ALL_ONE, acc(a, b).s); + else + ucvtf(acc(a, b).s, P_ALL_ONE, acc(a, b).s); + } + } + + for (int a = 0; a < bdb; a++) { + for (int b = 0; b < ldb; b += 2) { + if (b + 1 < ldb) { + uzp1(z31.d, acc(a, b).d, acc(a, b + 1).d); + uzp2(acc(a, b + 1).d, acc(a, b).d, acc(a, b + 1).d); + mov(acc(a, b).d, z31.d); + } else { + uzp1(z31.d, acc(a, b).d, acc(a, b).d); + uzp2(acc(a, b + 1).d, acc(a, b).d, acc(a, b).d); + mov(acc(a, b).d, z31.d); + } + } + } + + if (brg_.zp_type_a != jit_int8_broadcast_t::none) { + for (int b = 0; b < ldb; b += 2) { + PReg p = (brg_.is_n_tail && b >= ldb - 2) ? prd_b : P_ALL_ONE; + ld1w(z31.s, p, ptr(reg_zp_a, b / 2, MUL_VL)); + for (int a = 0; a < bdb; a++) { + fsub(acc(a, b).s, acc(a, b).s, z31.s); + fsub(acc(a, b + 1).s, acc(a, b + 1).s, z31.s); + } + } + } + + if (brg_.zp_type_b != jit_int8_broadcast_t::none) { + int ao = 0; + if (brg_.is_zp_b_int8) { + mov(reg_tmp_1, reg_zp_aux_b_buf); + int ilp = (brg_.is_n_tail) ? n_blks : 3; + for (int i = 0; i < ilp; i++) { + PReg p = (brg_.is_n_tail && i == ilp - 1) ? prd_zp_b_tl + : prd_8; + ld1b(ZRegB(i + 1), p, ptr(reg_tmp_1)); + if (brg_.zp_b_dt == u8) { + uunpklo(ZRegH(i + 1), ZRegB(i + 1)); + uunpklo(ZRegS(i + 1), ZRegH(i + 1)); + ucvtf(ZRegS(i + 1), P_ALL_ONE, ZRegS(i + 1)); + } else { + sunpklo(ZRegH(i + 1), ZRegB(i + 1)); + sunpklo(ZRegS(i + 1), ZRegH(i + 1)); + scvtf(ZRegS(i + 1), P_ALL_ONE, ZRegS(i + 1)); + } + add_imm(reg_tmp_1, reg_tmp_1, 8, X_TMP_0); + } + } + for (int a = 0; a < bdb; a++) { + ld1rw(z31.s, P_ALL_ONE, ptr(reg_zp_aux_b, ao * 4)); + ld1rw(z0.s, P_ALL_ONE, ptr(reg_zp_aux_b, (ao + 1) * 4)); + for (int b = 0; b < ldb; b += 2) { + if (brg_.is_zp_b_int8) { + fmul(z4.s, z31.s, ZRegS(b / 2 + 1)); + fmul(z5.s, z0.s, ZRegS(b / 2 + 1)); + fsub(acc(a, b).s, acc(a, b).s, z4.s); + fsub(acc(a, b + 1).s, acc(a, b + 1).s, z5.s); + } else { + fsub(acc(a, b).s, acc(a, b).s, z31.s); + fsub(acc(a, b + 1).s, acc(a, b + 1).s, z0.s); + } + } + ao += 2; + } + } + + if (brg_.with_scales) { + for (int b = 0; b < ldb; b += 2) { + PReg p = (brg_.is_n_tail && b >= ldb - 2) ? prd_b : P_ALL_ONE; + if (brg_.is_oc_scales) { + ld1w(z31.s, p, ptr(reg_scales, b / 2, MUL_VL)); + } else { + ld1w(z31.s, p, ptr(reg_scales)); + } + + for (int a = 0; a < bdb; a++) { + fmul(acc(a, b).s, acc(a, b).s, z31.s); + fmul(acc(a, b + 1).s, acc(a, b + 1).s, z31.s); + } + } + } + + if (brg_.is_bias) { + for (int b = 0; b < ldb; b += 2) { + PReg p = (brg_.is_n_tail && b >= ldb - 2) ? 
prd_b : P_ALL_ONE; + ld1w(z31.s, p, ptr(reg_bias, b / 2, MUL_VL)); + for (int a = 0; a < bdb; a++) { + fadd(acc(a, b).s, acc(a, b).s, z31.s); + fadd(acc(a, b + 1).s, acc(a, b + 1).s, z31.s); + } + } + } + + if (brg_.with_dst_scales) { + ld1rw(z31.s, P_ALL_ONE, ptr(reg_aux_scales)); + for (int b = 0; b < ldb; b += 2) { + for (int a = 0; a < bdb; a++) { + fmul(acc(a, b).s, acc(a, b).s, z31.s); + fmul(acc(a, b + 1).s, acc(a, b + 1).s, z31.s); + } + } + } + + if (brg_.zp_type_c != jit_int8_broadcast_t::none) { + LDR_IMM(reg_zp_val_c, reg_param, GET_OFF(dst_zero_point)); + ldr(W_TMP_0, ptr(reg_zp_val_c)); + dup(z0.s, W_TMP_0); + scvtf(z0.s, P_ALL_ONE, z0.s); + for (int b = 0; b < ldb; b += 2) { + for (int a = 0; a < bdb; a++) { + fadd(acc(a, b).s, acc(a, b).s, z0.s); + fadd(acc(a, b + 1).s, acc(a, b + 1).s, z0.s); + } + } + } + + mov(reg_tmp, reg_aux_c); + add_imm(reg_tmp_1, reg_aux_c, brg_.N * brg_.dst_dt_sz, X_TMP_0); + for (int a = 0; a < bdb; a++) { + for (int b = 0; b < ldb; b += 2) { + PReg p = (brg_.is_n_tail && b >= ldb - 2) ? prd_st : P_ALL_ONE; + int vl = b / 2; + st1w(acc(a, b).s, p, ptr(reg_tmp, vl, MUL_VL)); + if (a >= bdb - 1 && brg_.is_m_tail) { + if (brg_.m_tail % 2 == 0) + st1w(acc(a, b + 1).s, p, ptr(reg_tmp_1, vl, MUL_VL)); + } else { + st1w(acc(a, b + 1).s, p, ptr(reg_tmp_1, vl, MUL_VL)); + } + } + add_imm(reg_tmp, reg_tmp, 2 * brg_.N * brg_.dst_dt_sz, X_TMP_0); + add_imm(reg_tmp_1, reg_tmp_1, 2 * brg_.N * brg_.dst_dt_sz, X_TMP_0); + } + } + + void microkernel(int rdb, int bdb, int ldb, int tail) { + int a_off = 0, rd, ld, bd; + mov(reg_tmp, reg_aux_b); + for (rd = 0; rd < rdb; rd++) { + int ao = 0; + + for (ld = 0; ld < ldb; ld++) { + PReg p = (brg_.is_n_tail && ld == ldb - 1) ? prd_ld : P_ALL_ONE; + ld1b(loadb(ld).b, p, ptr(reg_tmp, ld, MUL_VL)); + } + for (bd = 0; bd < bdb; bd++) { + add_imm(X_DEFAULT_ADDR, reg_aux_a, a_off + ao, X_TMP_0); + ld1rqb(z0.b, P_ALL_ONE, ptr(X_DEFAULT_ADDR)); + ao += brg_.m_blk * 2; + + for (ld = 0; ld < ldb; ld++) { + if (brg_.is_s8) + smmla(acc(bd, ld).s, z0.b, loadb(ld).b); + else + ummla(acc(bd, ld).s, z0.b, loadb(ld).b); + } + } + a_off += brg_.m_blk * brg_.k_blk; + add_imm(reg_tmp, reg_tmp, brg_.k_blk * brg_.n_blk * brg_.ld_block, + X_TMP_0); + } + } + + void loop_k(int bdb, int ldb, int tail) { + zero_regs(); + mov(reg_aux_a, reg_aux_a1); + mov(reg_aux_b, reg_b); + if (k_full_blks > 0) { + mov(reg_rd_loop, k_full_blks); + Label l0; + L(l0); + microkernel(brg_.rd_block, bdb, ldb, tail); + add_imm(reg_aux_a, reg_aux_a, + brg_.m_blk * brg_.k_blk * brg_.rd_block, X_TMP_0); + add_imm(reg_aux_b, reg_aux_b, + brg_.k_blk * brg_.n_blk * brg_.ld_block * brg_.rd_block, + X_TMP_0); + sub(reg_rd_loop, reg_rd_loop, 1); + cmp(reg_rd_loop, 0); + b(GT, l0); + } + if (k_tail_blk > 0) { + microkernel(k_tail_blk, bdb, ldb, tail); + add_imm(reg_aux_a, reg_aux_a, brg_.m_blk * brg_.k_blk * k_tail_blk, + X_TMP_0); + add_imm(reg_aux_b, reg_aux_b, + brg_.k_blk * brg_.n_blk * brg_.ld_block * k_tail_blk, + X_TMP_0); + } + if (k_residual_blk > 0) { microkernel(1, bdb, ldb, tail); } + store_regs(bdb, ldb, tail); + } + + void loop_k_zp(int bdb, int ldb, int is_a, int is_b) { + eor(z3.d, z3.d, z3.d); + eor(z4.d, z4.d, z4.d); + for (int i = 0; i < 6; i++) + eor(acc(2, i).d, acc(2, i).d, acc(2, i).d); + mov(reg_aux_a, reg_aux_a1); + mov(reg_aux_b, reg_b); + if (k_full_blks > 0) { + mov(reg_rd_loop, k_full_blks); + Label l0; + L(l0); + zp_comp(brg_.rd_block, bdb, ldb, is_a, is_b); + add_imm(reg_aux_a, reg_aux_a, + brg_.m_blk * brg_.k_blk * brg_.rd_block, X_TMP_0); + 
add_imm(reg_aux_b, reg_aux_b, + brg_.k_blk * brg_.n_blk * brg_.ld_block * brg_.rd_block, + X_TMP_0); + sub(reg_rd_loop, reg_rd_loop, 1); + cmp(reg_rd_loop, 0); + b(GT, l0); + } + if (k_tail_blk > 0) { + zp_comp(k_tail_blk, bdb, ldb, is_a, is_b); + add_imm(reg_aux_a, reg_aux_a, brg_.m_blk * brg_.k_blk * k_tail_blk, + X_TMP_0); + add_imm(reg_aux_b, reg_aux_b, + brg_.k_blk * brg_.n_blk * brg_.ld_block * k_tail_blk, + X_TMP_0); + } + if (k_residual_blk > 0) { zp_comp(1, bdb, ldb, is_a, is_b); } + + if (brg_.zp_type_b != jit_int8_broadcast_t::none && is_b == 1) { + uzp1(z3.d, z3.d, z4.d); + scvtf(z3.s, P_ALL_ONE, z3.s); + if (!brg_.is_zp_b_int8) { + ldr(W_TMP_0, ptr(reg_zp_val_b)); + dup(z0.s, W_TMP_0); + scvtf(z0.s, P_ALL_ONE, z0.s); + fmul(z3.s, P_ALL_ONE, z0.s); + } else { + if (brg_.zp_type_a != jit_int8_broadcast_t::none) { + ldr(W_TMP_0, ptr(reg_zp_val_a)); + dup(z0.s, W_TMP_0); + mov_imm(W_TMP_0, brg_.K); + dup(z1.s, W_TMP_0); + scvtf(z0.s, P_ALL_ONE, z0.s); + scvtf(z1.s, P_ALL_ONE, z1.s); + fmul(z0.s, z1.s, z0.s); + fsub(z3.s, z3.s, z0.s); + } + } + st1w(z3.s, P_ALL_ONE, ptr(reg_zp_b)); + } + + if ((brg_.zp_type_a != jit_int8_broadcast_t::none) && is_a == 1) { + ldr(W_TMP_0, ptr(reg_zp_val_a)); + dup(z2.s, W_TMP_0); + scvtf(z2.s, P_ALL_ONE, z2.s); + uzp1(acc(2, 0).d, acc(2, 0).d, acc(2, 1).d); + uzp1(acc(2, 2).d, acc(2, 2).d, acc(2, 3).d); + uzp1(acc(2, 4).d, acc(2, 4).d, acc(2, 5).d); + + scvtf(acc(2, 0).s, P_ALL_ONE, acc(2, 0).s); + scvtf(acc(2, 2).s, P_ALL_ONE, acc(2, 2).s); + scvtf(acc(2, 4).s, P_ALL_ONE, acc(2, 4).s); + if (brg_.zp_type_b != jit_int8_broadcast_t::none + && !brg_.is_zp_b_int8) { + ldr(W_TMP_0, ptr(reg_zp_val_b)); + dup(z0.s, W_TMP_0); + mov_imm(W_TMP_0, brg_.K); + dup(z1.s, W_TMP_0); + scvtf(z0.s, P_ALL_ONE, z0.s); + scvtf(z1.s, P_ALL_ONE, z1.s); + fmul(z0.s, z1.s, z0.s); + fsub(acc(2, 0).s, acc(2, 0).s, z0.s); + fsub(acc(2, 2).s, acc(2, 2).s, z0.s); + fsub(acc(2, 4).s, acc(2, 4).s, z0.s); + } + fmul(acc(2, 0).s, P_ALL_ONE, z2.s); + fmul(acc(2, 2).s, P_ALL_ONE, z2.s); + fmul(acc(2, 4).s, P_ALL_ONE, z2.s); + + st1w(acc(2, 0).s, P_ALL_ONE, ptr(reg_zp_a)); + st1w(acc(2, 2).s, P_ALL_ONE, ptr(reg_zp_a, 1, MUL_VL)); + st1w(acc(2, 4).s, P_ALL_ONE, ptr(reg_zp_a, 2, MUL_VL)); + } + } + + void han_blk() { + Label ld_loop, bd_loop; + LDR_IMM(reg_tmp, reg_param, GET_OFF(nb)); + LDR_IMM(reg_na, reg_param, GET_OFF(na)); + ldr(WReg(reg_ld_loop.getIdx()), ptr(reg_tmp)); + mov(reg_aux_a1, reg_a); + // mov(reg_b,reg_b); + mov(reg_aux_c1, reg_c); + mov(reg_aux_c, reg_aux_c1); + mov(reg_zp_aux_b, reg_zp_b); + L(ld_loop); + ldr(WReg(reg_bd_loop.getIdx()), ptr(reg_na)); + L(bd_loop); + loop_k(bdb, ldb, 0); + add_imm(reg_aux_a1, reg_aux_a1, + div_up(brg_.K, brg_.k_blk) * brg_.k_blk * brg_.bd_block, + X_TMP_0); + add_imm(reg_aux_c, reg_aux_c, brg_.N * brg_.bd_block * brg_.dst_dt_sz, + X_TMP_0); + add_imm(reg_zp_aux_b, reg_zp_aux_b, brg_.m_blk * brg_.dst_dt_sz, + X_TMP_0); + sub(reg_bd_loop, reg_bd_loop, 1); + cmp(reg_bd_loop, 0); + b(GT, bd_loop); + mov(reg_aux_a1, reg_a); + mov(reg_zp_aux_b, reg_zp_b); + add_imm(reg_b, reg_b, + (brg_.n_blk * brg_.ld_block) * div_up(brg_.K, brg_.k_blk) + * brg_.k_blk, + X_TMP_0); + add_imm(reg_aux_c1, reg_aux_c1, + brg_.dst_dt_sz * (brg_.n_blk * brg_.ld_block), X_TMP_0); + add_imm(reg_zp_a, reg_zp_a, brg_.n_blk * brg_.ld_block * brg_.dst_dt_sz, + X_TMP_0); + if (brg_.is_oc_scales) + add_imm(reg_scales, reg_scales, + brg_.dst_dt_sz * (brg_.n_blk * brg_.ld_block), X_TMP_0); + add_imm(reg_bias, reg_bias, + brg_.dst_dt_sz * (brg_.n_blk * brg_.ld_block), 
X_TMP_0); + mov(reg_aux_c, reg_aux_c1); + sub(reg_ld_loop, reg_ld_loop, 1); + cmp(reg_ld_loop, 0); + b(GT, ld_loop); + } + + void han_blk_zp() { + Label ld_loop, bd_loop, skip_ld_loop, skip_bd_loop; + LDR_IMM(reg_tmp, reg_param, GET_OFF(nb)); + LDR_IMM(reg_na, reg_param, GET_OFF(na)); + ldr(WReg(reg_ld_loop.getIdx()), ptr(reg_tmp)); + ldr(WReg(reg_bd_loop.getIdx()), ptr(reg_na)); + mov(reg_aux_a1, reg_a); + // mov(reg_b,reg_b); + if (brg_.zp_type_b != jit_int8_broadcast_t::none) { + cmp(reg_bd_loop, 0); + b(EQ, skip_bd_loop); + L(bd_loop); + loop_k_zp(bdb, ldb, 0, 1); + add_imm(reg_aux_a1, reg_aux_a1, + div_up(brg_.K, brg_.k_blk) * brg_.k_blk * brg_.bd_block, + X_TMP_0); + add_imm(reg_zp_b, reg_zp_b, brg_.m_blk * brg_.dst_dt_sz, X_TMP_0); + sub(reg_bd_loop, reg_bd_loop, 1); + cmp(reg_bd_loop, 0); + b(GT, bd_loop); + L(skip_bd_loop); + } + if (brg_.zp_type_a != jit_int8_broadcast_t::none) { + cmp(reg_ld_loop, 0); + b(EQ, skip_ld_loop); + L(ld_loop); + loop_k_zp(bdb, ldb, 1, 0); + add_imm(reg_zp_a, reg_zp_a, + brg_.n_blk * brg_.ld_block * brg_.dst_dt_sz, X_TMP_0); + add_imm(reg_b, reg_b, + (brg_.n_blk * brg_.ld_block) * div_up(brg_.K, brg_.k_blk) + * brg_.k_blk, + X_TMP_0); + sub(reg_ld_loop, reg_ld_loop, 1); + cmp(reg_ld_loop, 0); + b(GT, ld_loop); + L(skip_ld_loop); + } + } + + void zp_comp(int rdb, int bdb, int ldb, int is_a, int is_b) { + + dup(z0.b, 1); + int rd, ld; + if (brg_.zp_type_b != jit_int8_broadcast_t::none && is_b == 1) { + mov(reg_tmp, reg_aux_a); + for (rd = 0; rd < rdb; rd++) { + ld1b(z1.b, P_ALL_ONE / T_z, ptr(reg_tmp)); + ld1b(z2.b, P_ALL_ONE / T_z, ptr(reg_tmp, 1, MUL_VL)); + add_imm(reg_tmp, reg_tmp, brg_.k_blk * brg_.m_blk, X_TMP_0); + if (brg_.is_s8) { + smmla(z3.s, z0.b, z1.b); + smmla(z4.s, z0.b, z2.b); + } else { + ummla(z3.s, z0.b, z1.b); + ummla(z4.s, z0.b, z2.b); + } + } + } + if ((brg_.zp_type_a != jit_int8_broadcast_t::none) && is_a == 1) { + mov(reg_tmp, reg_aux_b); + + for (rd = 0; rd < rdb; rd++) { + for (ld = 0; ld < ldb; ld++) { + PReg p = (brg_.is_n_tail && ld == ldb - 1) ? prd_ld + : P_ALL_ONE; + ld1b(acc(1, ld).b, p, ptr(reg_tmp, ld, MUL_VL)); + } + add_imm(reg_tmp, reg_tmp, + brg_.k_blk * brg_.n_blk * brg_.ld_block, X_TMP_0); + for (ld = 0; ld < ldb; ld++) { + if (brg_.is_s8) { + smmla(acc(2, ld).s, z0.b, acc(1, ld).b); + } else { + ummla(acc(2, ld).s, z0.b, acc(1, ld).b); + } + } + } + } + } + + void config() { + int m, pred_st = 0, pred_ld = 0, sv_len = 8, pred_b = 8; + n_blks = div_up(brg_.n_tail, 8); + k_full_blks = brg_.K / (brg_.k_blk * brg_.rd_block); + m = brg_.K % (brg_.k_blk * brg_.rd_block); + k_tail_blk = m / brg_.k_blk; + k_residual_blk = m % brg_.k_blk; + ldb = (brg_.is_n_tail) ? div_up(brg_.n_tail, 4) : brg_.ld_block; + bdb = (brg_.is_m_tail) ? div_up(brg_.m_tail, 2) : brg_.bd_block / 2; + rdb = (brg_.is_k_tail) ? div_up(brg_.k_tail, brg_.k_blk) : 4; + + int pred_zp_b_tl = (brg_.n_tail % 8 == 0) ? 8 : brg_.n_tail % 8; + set_preg(prd_8.b, 8, X_TMP_0, X_TMP_1); + set_preg(prd_zp_b_tl.b, pred_zp_b_tl, X_TMP_0, X_TMP_1); + + if (brg_.is_n_tail) { + pred_b = (brg_.n_tail % 8 == 0) ? sv_len : (brg_.n_tail % 8); + if (brg_.n_tail % brg_.n_blk == 0) { + pred_st = (brg_.n_tail % (brg_.n_blk * 2) == 0) ? sv_len + : sv_len / 2; + pred_ld = sv_len * brg_.dst_dt_sz; + } else { + pred_ld = (brg_.n_tail % brg_.n_blk) * brg_.k_blk; + pred_st = (ldb % 2 == 0) + ? 
(sv_len / 2) + (brg_.n_tail % brg_.n_blk) + : (brg_.n_tail % brg_.n_blk); + } + } + set_preg(prd_ld.b, pred_ld, X_TMP_0, X_TMP_1); + set_preg(prd_st.s, pred_st, X_TMP_0, X_TMP_1); + set_preg(prd_b.s, pred_b, X_TMP_0, X_TMP_1); + } + + void generate() override { + preamble(); + config(); + + LDR_IMM(reg_a, reg_param, GET_OFF(src)); + LDR_IMM(reg_b, reg_param, GET_OFF(wei)); + LDR_IMM(reg_c, reg_param, GET_OFF(dst)); + LDR_IMM(reg_zp_b, reg_param, GET_OFF(zp_b_ptr)); + LDR_IMM(reg_zp_a, reg_param, GET_OFF(zp_a_ptr)); + if (brg_.is_zp_cal) { + LDR_IMM(reg_zp_val_b, reg_param, GET_OFF(wei_zero_point)); + LDR_IMM(reg_zp_val_a, reg_param, GET_OFF(src_zero_point)); + han_blk_zp(); + } else { + + LDR_IMM(reg_bias, reg_param, GET_OFF(bias)); + LDR_IMM(reg_scales, reg_param, GET_OFF(scales)); + LDR_IMM(reg_aux_scales, reg_param, GET_OFF(dst_scales)); + LDR_IMM(reg_zp_aux_b_buf, reg_param, GET_OFF(wei_zero_point_buf)); + han_blk(); + } + + postamble(); + } + + jit_int8_matmul_kernel_t(const brg_int8_t &k) : brg_(k) {} + ~jit_int8_matmul_kernel_t() override = default; + +private: + brg_int8_t brg_; + int ldb; + int bdb; + int rdb; + int k_full_blks; + int k_tail_blk; + int k_residual_blk; + int n_blks; +}; + +status_t jit_int8_matmul_t::pd_t::init(engine_t *engine) { + + const auto src_type = src_md(0)->data_type; + const auto wei_type = weights_md(0)->data_type; + const auto dst_type = dst_md(0)->data_type; + + const memory_desc_wrapper src_d(src_md_); + const memory_desc_wrapper weights_d(weights_md_); + const memory_desc_wrapper dst_d(dst_md_); + const memory_desc_wrapper bias_d(bias_md_); + + const bool no_runtime_dims_or_strides + = !(src_d.has_runtime_dims_or_strides() + || weights_d.has_runtime_dims_or_strides()); + + VDISPATCH_MATMUL( + no_runtime_dims_or_strides, VERBOSE_RUNTIMEDIM_UNSUPPORTED); + + bool is_s8_wei = utils::everyone_is(s8, wei_type); + bool is_u8 = utils::everyone_is(u8, src_type, wei_type); + bool is_s8 = utils::everyone_is(s8, src_type, wei_type); + + int dims = src_d.ndims(); + + auto check_attr_scales = [&]() -> bool { + const std::vector supported_args + = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}; + bool ok = attr_scales_ok(supported_args); + auto is_src_scl + = !attr()->scales_.get(DNNL_ARG_SRC).has_default_values(); + auto is_wei_scl + = !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values(); + auto dst_scl_msk = attr()->scales_.get(DNNL_ARG_DST).get_mask(); + auto wei_scl_msk = attr()->scales_.get(DNNL_ARG_WEIGHTS).get_mask(); + auto src_scl_msk = attr()->scales_.get(DNNL_ARG_SRC).get_mask(); + + if (src_scl_msk > 0 + || (wei_scl_msk > 0 && wei_scl_msk != 1 << (dims - 1)) + || dst_scl_msk > 0) + return false; + + if (is_src_scl && is_wei_scl && wei_scl_msk > 0) { + // This case requires scratchpad. 
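+            // Both src and per-channel weight scales are folded into a single
+            // precomputed buffer of N() elements (see book_precomputed_scales()
+            // below), so N must be known when the descriptor is created.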
+ if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; + } + return ok; + }; + + auto check_bias = [&]() -> bool { + if (bias_d.format_any()) { + if (bias_d.has_runtime_dims_or_strides()) return false; + status_t status = memory_desc_init_by_strides(bias_md_, nullptr); + if (status != status::success) return false; + } + + const auto bia_dt = weights_md(1)->data_type; + return IMPLICATION(with_bias(), bia_dt == f32 && is_bias_1xN()); + }; + + auto init_zp_type = [&](brg_int8_t *brg_) -> bool { + auto zero_points = attr()->zero_points_; + + auto wt_int8 = zero_points.get_data_type(DNNL_ARG_WEIGHTS) == u8 + || zero_points.get_data_type(DNNL_ARG_WEIGHTS) == s8; + if (!zero_points.has_default_data_type(DNNL_ARG_SRC) + || !zero_points.has_default_data_type(DNNL_ARG_DST) + || (!zero_points.has_default_data_type(DNNL_ARG_WEIGHTS) + && !wt_int8)) + return false; + + if (!zero_points.has_default_data_type(DNNL_ARG_WEIGHTS)) { + switch (zero_points.get_data_type(DNNL_ARG_WEIGHTS)) { + case u8: { + brg_->zp_b_dt = u8; + brg_->is_zp_b_int8 = true; + break; + } + case s8: { + brg_->zp_b_dt = s8; + brg_->is_zp_b_int8 = true; + break; + } + case s32: { + brg_->is_zp_b_int8 = false; + break; + } + default: return false; + } + } + + if (zero_points.get_mask(DNNL_ARG_SRC) > 0 + || zero_points.get_mask(DNNL_ARG_DST) > 0 + || (zero_points.get_mask(DNNL_ARG_WEIGHTS) > 0 + && (zero_points.get_mask(DNNL_ARG_WEIGHTS)) + != (3 << (dims - 2)))) + return false; + + brg_->zp_type_a = zero_points.has_default_values(DNNL_ARG_SRC) + ? jit_int8_broadcast_t::none + : jit_int8_broadcast_t::per_tensor; + + brg_->zp_type_b = zero_points.has_default_values(DNNL_ARG_WEIGHTS) + ? jit_int8_broadcast_t::none + : jit_int8_broadcast_t::per_tensor; + + brg_->zp_type_c = zero_points.has_default_values(DNNL_ARG_DST) + ? 
jit_int8_broadcast_t::none + : jit_int8_broadcast_t::per_tensor; + + return true; + }; + + VDISPATCH_MATMUL(init_zp_type(&brg_), VERBOSE_UNSUPPORTED_ZP_CFG); + + VDISPATCH_MATMUL(check_bias(), VERBOSE_UNSUPPORTED_BIAS_CFG); + + VDISPATCH_MATMUL(check_attr_scales(), VERBOSE_UNSUPPORTED_SCALES_CFG); + + bool no_post_ops = attr()->post_ops_.has_default_values(); + const bool problem_dt_correct + = (is_s8 || is_u8) && utils::everyone_is(f32, dst_type); + + VDISPATCH_MATMUL(problem_dt_correct, VERBOSE_UNSUPPORTED_DT); + VDISPATCH_MATMUL(no_post_ops, VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_MATMUL(formats_ok(), VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_MATMUL(get_sve_length() == 32, VERBOSE_UNSUPPORTED_ISA); + + auto is_src_any = src_d.format_kind() == format_kind::any; + auto is_dst_any = dst_d.format_kind() == format_kind::any; + + switch (dims) { + case 2: { + if (is_src_any) + VCHECK_BG(memory_desc_init_by_tag(src_md_, format_tag::ab), + VERBOSE_UNSUPPORTED_TAG); + if (is_dst_any) + VCHECK_BG(memory_desc_init_by_tag(dst_md_, format_tag::ab), + VERBOSE_UNSUPPORTED_TAG); + if (!weights_d.matches_tag(format_tag::ab)) { + brg_.b_reo = false; + VCHECK_BG(memory_desc_init_by_tag( + weights_md_, format_tag::BA24b8a), + VERBOSE_UNSUPPORTED_TAG); + } else { + VCHECK_BG(memory_desc_init_by_tag(weights_md_, format_tag::ab), + VERBOSE_UNSUPPORTED_TAG); + } + break; + } + case 3: { + if (is_src_any) + VCHECK_BG(memory_desc_init_by_tag(src_md_, format_tag::abc), + VERBOSE_UNSUPPORTED_TAG); + if (is_dst_any) + VCHECK_BG(memory_desc_init_by_tag(dst_md_, format_tag::abc), + VERBOSE_UNSUPPORTED_TAG); + if (!weights_d.matches_tag(format_tag::abc)) { + brg_.b_reo = false; + VCHECK_BG(memory_desc_init_by_tag( + weights_md_, format_tag::aCB24c8b), + VERBOSE_UNSUPPORTED_TAG); + } else { + VCHECK_BG(memory_desc_init_by_tag(weights_md_, format_tag::abc), + VERBOSE_UNSUPPORTED_TAG); + } + if (src_d.dims()[0] != weights_d.dims()[0]) + return status::unimplemented; + break; + } + case 4: { + if (is_src_any) + VCHECK_BG(memory_desc_init_by_tag(src_md_, format_tag::abcd), + VERBOSE_UNSUPPORTED_TAG); + if (is_dst_any) + VCHECK_BG(memory_desc_init_by_tag(dst_md_, format_tag::abcd), + VERBOSE_UNSUPPORTED_TAG); + if (!weights_d.matches_tag(format_tag::abcd)) { + brg_.b_reo = false; + VCHECK_BG(memory_desc_init_by_tag( + weights_md_, format_tag::abDC24d8c), + VERBOSE_UNSUPPORTED_TAG); + } else { + VCHECK_BG( + memory_desc_init_by_tag(weights_md_, format_tag::abcd), + VERBOSE_UNSUPPORTED_TAG); + } + if (src_d.dims()[0] != weights_d.dims()[0] + || src_d.dims()[1] != weights_d.dims()[1]) + return status::unimplemented; + break; + } + default: return status::unimplemented; + } + + bool is_scales = !attr()->scales_.get(DNNL_ARG_SRC).has_default_values() + || !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values(); + + bool is_dst_scales + = !attr()->scales_.get(DNNL_ARG_DST).has_default_values(); + + const auto &wei_scales = attr()->scales_.get(DNNL_ARG_WEIGHTS); + + matmul_helper_t helper(src_d, weights_d, dst_d); + brg_.K = helper.K(); + brg_.M = helper.M(); + brg_.N = helper.N(); + brg_.dst_dt_sz = 4; + brg_.na = 1; + brg_.nb = 1; + brg_.m_tail = brg_.M % brg_.m_blk; + brg_.k_tail = brg_.K % (brg_.k_blk * brg_.rd_block); + brg_.n_tail = brg_.N % (brg_.n_blk * brg_.ld_block); + brg_.is_s8 = is_s8_wei; + brg_.is_bias = with_bias(); + brg_.B = batch(); + brg_.with_scales = is_scales; + brg_.with_dst_scales = is_dst_scales; + brg_.is_oc_scales = wei_scales.get_mask() > 0; + dyn_.K = brg_.K; + dyn_.N = brg_.N; + dyn_.M = brg_.M; + 
dyn_.B = brg_.B; + dyn_.mtail = brg_.m_tail; + dyn_.m_blk = brg_.m_blk; + dyn_.k_blk = brg_.k_blk; + dyn_.n_blk = brg_.n_blk * brg_.ld_block; + dyn_.ntail = brg_.n_tail; + dyn_.ktail = dyn_.K % brg_.k_blk; + + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_brgemm_primitive_zp_comp_a, + div_up(brg_.N, (brg_.n_blk * brg_.ld_block)) + * (brg_.n_blk * brg_.ld_block) * brg_.dst_dt_sz * brg_.B, + sizeof(char)); + scratchpad.book(key_brgemm_primitive_zp_comp_b, + div_up(brg_.M, brg_.m_blk) * brg_.m_blk * brg_.dst_dt_sz * brg_.B, + sizeof(char)); + scratchpad.book(key_brgemm_primitive_buffer_a, + brg_.B * div_up(brg_.M, brg_.m_blk) * div_up(brg_.K, brg_.k_blk) + * brg_.m_blk * brg_.k_blk, + sizeof(char)); + scratchpad.book(key_brgemm_primitive_buffer_b, brg_.B * brg_.M * brg_.K, + sizeof(char)); + if (brg_.b_reo) + scratchpad.book(key_gemm_blocked_b, + brg_.B * div_up(brg_.N, (brg_.n_blk * brg_.ld_block)) + * (brg_.n_blk * brg_.ld_block) + * div_up(brg_.K, brg_.k_blk) * brg_.k_blk, + sizeof(char)); + book_precomputed_scales(scratchpad, attr()->scales_, N()); + + return status::success; +} + +status_t jit_int8_matmul_t::init(engine_t *engine) { + + const auto &b1 = pd()->get_b(); + const auto &d1 = pd()->get_d(); + + dyn_vals_t d; + d.K = d1.K; + d.M = d1.M; + d.B = d1.B; + d.N = d1.N; + d.mtail = d1.mtail; + d.ktail = d1.ktail; + d.ntail = d1.ntail; + d.k_blk = d1.k_blk; + d.m_blk = d1.m_blk; + d.n_blk = d1.n_blk; + + brg_int8_t b; + b.M = b1.M; + b.K = b1.K; + b.N = b1.N; + b.na = b1.na; + b.nb = b1.nb; + b.m_tail = b1.m_tail; + b.n_tail = b1.n_tail; + b.k_tail = b1.k_tail; + b.dst_dt_sz = b1.dst_dt_sz; + b.is_s8 = b1.is_s8; + b.B = b1.B; + b.is_bias = b1.is_bias; + b.zp_type_a = b1.zp_type_a; + b.zp_type_b = b1.zp_type_b; + b.zp_type_c = b1.zp_type_c; + b.is_zp_b_int8 = b1.is_zp_b_int8; + b.zp_b_dt = b1.zp_b_dt; + b.with_scales = b1.with_scales; + b.with_dst_scales = b1.with_dst_scales; + b.is_oc_scales = b1.is_oc_scales; + b.b_reo = b1.b_reo; + + for (int z = 0; z < 2; z++) + for (int m = 0; m < 2; m++) + for (int n = 0; n < 2; n++) + for (int k = 0; k < 2; k++) { + int idx = pd()->get_idx(z, m, k, n, b1); + if (idx == -1 || idx > 15) continue; + b.is_m_tail = m; + b.is_k_tail = k; + b.is_n_tail = n; + b.is_zp_cal = z; + int8_kernels_[idx] + = std::unique_ptr { + new jit_int8_matmul_kernel_t(b)}; + if (!int8_kernels_[idx]) return status::runtime_error; + CHECK(int8_kernels_[idx]->create_kernel()); + } + + d.reorder_a = 1; + d.reorder_b = 0; + reo_ker_a_ = std::unique_ptr { + new jit_int8_matmul_utils_kernel_t(d)}; + CHECK(reo_ker_a_->create_kernel()); + + d.reorder_b = 1; + d.reorder_a = 0; + reo_ker_b_ = std::unique_ptr { + new jit_int8_matmul_utils_kernel_t(d)}; + CHECK(reo_ker_b_->create_kernel()); + + return status::success; +} + +jit_int8_matmul_t::jit_int8_matmul_t(const pd_t *apd) : primitive_t(apd) {} +jit_int8_matmul_t::~jit_int8_matmul_t() = default; + +status_t jit_int8_matmul_t::execute(const exec_ctx_t &ctx) const { + const auto *weights_b = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS); + const auto *src_b = CTX_IN_MEM(const float *, DNNL_ARG_SRC); + auto dst = CTX_OUT_MEM(float *, DNNL_ARG_DST); + const auto *bias = CTX_IN_MEM(const float *, DNNL_ARG_BIAS); + + DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC); + DEFINE_ZERO_POINT_VALUE(wei_zero_point, DNNL_ARG_WEIGHTS); + DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST); + DEFINE_ZERO_POINTS_BUFFER(wei_zero_point_buf, DNNL_ARG_WEIGHTS); + DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); + 
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
+    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
+
+    const auto &b = pd()->get_b();
+    const auto &d = pd()->get_d();
+
+    auto &scratchpad = ctx.get_scratchpad_grantor();
+
+    int num_threads = dnnl_get_current_num_threads();
+    char *src = scratchpad.template get<char>(key_brgemm_primitive_buffer_a);
+    char *weights = (b.b_reo)
+            ? scratchpad.template get<char>(key_gemm_blocked_b)
+            : (char *)weights_b;
+    char *zp_ptr_a
+            = scratchpad.template get<char>(key_brgemm_primitive_zp_comp_a);
+    char *zp_ptr_b
+            = scratchpad.template get<char>(key_brgemm_primitive_zp_comp_b);
+    const float *oscales = precompute_scales(
+            scratchpad, src_scales, wei_scales, pd()->N(), pd()->attr());
+
+    const dim_t B = b.B;
+    const dim_t M = b.M;
+    const dim_t N = b.N;
+    const dim_t K = b.K;
+
+    auto reorder_a = [&]() {
+        int m_blks = div_up(M, b.m_blk);
+        int k_blks = div_up(K, b.k_blk);
+        int n_blks = div_up(N, (b.n_blk * b.ld_block));
+        int parallel_work = B * m_blks * k_blks;
+        int parallel_work_mn = B * m_blks * n_blks;
+        int blk_per_bt = m_blks * k_blks;
+        int nt = std::min(num_threads, parallel_work);
+        nt = std::min(parallel_work_mn, nt);
+        auto tmp_src = src_b;
+
+        parallel(nt, [&](const int ithr, const int nthr) {
+            int start {0}, end {0};
+            balance211(parallel_work, nt, ithr, start, end);
+
+            int bt = start / blk_per_bt;
+            int bs = start % blk_per_bt;
+            int nobl = end - start;
+            int nobt = 1;
+            int noblf = end - start, nobll;
+
+            if (bs + nobl > blk_per_bt) {
+                nobt += div_up(nobl - (blk_per_bt - bs), blk_per_bt);
+                noblf = blk_per_bt - bs;
+                nobll = (nobl - (blk_per_bt - bs)) % blk_per_bt;
+                if (nobll == 0) nobll = blk_per_bt;
+            }
+            int nob;
+            for (int i = 0; i < nobt; i++) {
+                nob = (i == 0) ? noblf : ((i == nobt - 1) ? nobll : blk_per_bt);
+                bs = start % blk_per_bt;
+                int m_blk_src = bs / k_blks;
+                int k_blk_src = bs % k_blks;
+                int m_blk_dst = bs / k_blks;
+                int k_blk_dst = bs % k_blks;
+
+                int k1 = std::min(k_blks - k_blk_src, nob);
+                int k_tmp = nob - k1;
+                int m1 = (k_tmp > 0) ? k_tmp / k_blks : 0;
+                int k2 = (k_tmp > 0) ? k_tmp % k_blks : 0;
+                int src_ad = (bt * M * K) + (m_blk_src * b.m_blk * K)
+                        + (k_blk_src * b.k_blk);
+                int dst_ad = (bt * m_blks * k_blks * b.m_blk * b.k_blk)
+                        + (m_blk_dst * k_blks * b.m_blk * b.k_blk)
+                        + (k_blk_dst * b.m_blk * b.k_blk);
+                int src_new = src_ad, dst_new = dst_ad;
+
+                dyn_params_t k;
+
+                if (k1 > 0) {
+                    int a = 1;
+                    int mtl = (d.mtail > 0) ? 1 : 0;
+                    int tl = (d.ktail > 0) ? 1 : 0;
+                    if (k1 + k_blk_src < k_blks) tl = 0;
+                    if (1 + m_blk_src < m_blks) mtl = 0;
+                    k.src = (int8_t *)tmp_src + src_ad;
+                    k.dst = (int8_t *)src + dst_ad;
+                    k.nm = &a;
+                    k.nk = &k1;
+                    k.tl = &tl;
+                    k.mtl = &mtl;
+                    (*reo_ker_a_)(&k);
+                }
+
+                if (m1 > 0) {
+                    int mtl = (d.mtail > 0) ? 1 : 0;
+                    int tl = (d.ktail > 0) ? 1 : 0;
+                    if (1 + m1 + m_blk_src < m_blks) mtl = 0;
+                    if (k1 != k_blks) {
+                        src_new = src_ad - b.k_blk * (k_blks - k1)
+                                + b.m_blk * K;
+                    } else {
+                        src_new = src_ad + b.m_blk * K;
+                    }
+                    dst_new = dst_ad + b.m_blk * b.k_blk * k1;
+                    k.src = (int8_t *)tmp_src + src_new;
+                    k.dst = (int8_t *)src + dst_new;
+                    k.nm = &m1;
+                    k.nk = &k_blks;
+                    k.tl = &tl;
+                    k.mtl = &mtl;
+                    (*reo_ker_a_)(&k);
+                }
+                if (k2 > 0) {
+                    int a = 1, tl = 0;
+                    int mtl = (d.mtail > 0) ?
1 : 0; + if (1 + 1 + m1 + m_blk_src < m_blks) mtl = 0; + if (m1 < 1) { + src_new = src_ad - b.k_blk * (k_blks - k1) + + (b.m_blk * K); + dst_new = dst_ad + b.m_blk * b.k_blk * k1; + } else { + src_new += K * m1 * b.m_blk; + dst_new += b.m_blk * b.k_blk * k_blks * m1; + } + k.src = (int8_t *)tmp_src + src_new; + k.dst = (int8_t *)src + dst_new; + k.nm = &a; + k.nk = &k2; + k.tl = &tl; + k.mtl = &mtl; + (*reo_ker_a_)(&k); + } + bt++; + start += nob; + } + }); + }; + + auto reorder_b = [&]() { + int k_blks = div_up(K, d.k_blk); + int n_blks = div_up(N, d.n_blk); + int parallel_work = B * n_blks * k_blks; + int blk_per_bt = n_blks * k_blks; + int nt = std::min(num_threads, parallel_work); + + parallel(nt, [&](const int ithr, const int nthr) { + int start {0}, end {0}; + balance211(parallel_work, nt, ithr, start, end); + + int bt = start / blk_per_bt; + int bs = start % blk_per_bt; + int nobl = end - start; + int nobt = 1; + int noblf = end - start, nobll; + + if (bs + nobl > blk_per_bt) { + nobt += div_up(nobl - (blk_per_bt - bs), blk_per_bt); + noblf = blk_per_bt - bs; + nobll = (nobl - (blk_per_bt - bs)) % blk_per_bt; + if (nobll == 0) nobll = blk_per_bt; + } + int nob; + for (int i = 0; i < nobt; i++) { + nob = (i == 0) ? noblf : ((i == nobt - 1) ? nobll : blk_per_bt); + bs = start % blk_per_bt; + int n_blk_src = bs / k_blks; + int k_blk_src = bs % k_blks; + int n_blk_dst = bs / k_blks; + int k_blk_dst = bs % k_blks; + + int k1 = std::min(k_blks - k_blk_src, nob); + int k_tmp = nob - k1; + int n1 = (k_tmp > 0) ? k_tmp / k_blks : 0; + int k2 = (k_tmp > 0) ? k_tmp % k_blks : 0; + int src_ad = (bt * N * K) + (n_blk_src * d.n_blk) + + (k_blk_src * d.k_blk * N); + int dst_ad = (bt * n_blks * k_blks * d.k_blk * d.n_blk) + + (n_blk_dst * k_blks * d.k_blk * d.n_blk) + + (k_blk_dst * d.k_blk * d.n_blk); + int src_new = src_ad, dst_new = dst_ad; + + dyn_params_t k; + + if (k1 > 0) { + int a = 1; + int ntl = (d.ntail > 0) ? 1 : 0; + int tl = (d.ktail > 0) ? 1 : 0; + + if (k1 + k_blk_src < k_blks) tl = 0; + if (1 + n_blk_src < n_blks) ntl = 0; + k.src = (int8_t *)weights_b + src_ad; + k.dst = (int8_t *)weights + dst_ad; + k.nn = &a; + k.nk = &k1; + k.tl = &tl; + k.ntl = &ntl; + (*reo_ker_b_)(&k); + } + + if (n1 > 0) { + int ntl = (d.ntail > 0) ? 1 : 0; + int tl = (d.ktail > 0) ? 1 : 0; + if (1 + n1 + n_blk_src < n_blks) ntl = 0; + + if (k1 != k_blks) { + src_new = src_ad - d.k_blk * N * (k_blks - k1) + + d.n_blk; + } else { + src_new = src_ad + d.n_blk; + } + dst_new = dst_ad + d.k_blk * d.n_blk * k1; + k.src = (int8_t *)weights_b + src_new; + k.dst = (int8_t *)weights + dst_new; + k.nn = &n1; + k.nk = &k_blks; + k.tl = &tl; + k.ntl = &ntl; + (*reo_ker_b_)(&k); + } + if (k2 > 0) { + int a = 1, tl = 0; + int ntl = (d.ntail > 0) ? 
1 : 0; + if (1 + 1 + n1 + n_blk_src < n_blks) ntl = 0; + if (n1 < 1) { + src_new = src_ad - d.k_blk * N * (k_blks - k1) + + d.n_blk; + dst_new = dst_ad + d.k_blk * d.n_blk * k1; + } else { + src_new += n1 * d.n_blk; + dst_new += d.k_blk * d.n_blk * k_blks * n1; + } + k.src = (int8_t *)weights_b + src_new; + k.dst = (int8_t *)weights + dst_new; + k.nn = &a; + k.nk = &k2; + k.tl = &tl; + k.ntl = &ntl; + (*reo_ker_b_)(&k); + } + bt++; + start += nob; + } + }); + }; + + auto kernel_execute = [&](int idx, int na, int nb, int m_blk_adr, + int n_blk_adr, int dst_adr, int bias_addr, + int scl_addr, int zp_ptr_a_adr, + int zp_ptr_b_adr, int zp_b_buf) { + call_params_t p; + p.na = &na; + p.nb = &nb; + p.src = (uint8_t *)src + m_blk_adr; + p.wei = (uint8_t *)weights + n_blk_adr; + p.dst = dst + dst_adr; + p.bias = (float *)bias + bias_addr; + p.scales = oscales + scl_addr; + p.dst_scales = dst_scales; + p.src_zero_point = &src_zero_point; + if (b.is_zp_b_int8) + p.wei_zero_point_buf = (int8_t *)wei_zero_point_buf + zp_b_buf; + else + p.wei_zero_point = &wei_zero_point; + p.dst_zero_point = &dst_zero_point; + p.M = M; + p.N = N; + p.K = K; + p.zp_a_ptr = (float *)zp_ptr_a + zp_ptr_a_adr; + p.zp_b_ptr = (float *)zp_ptr_b + zp_ptr_b_adr; + (*int8_kernels_[idx])(&p); + }; + + auto kernel_execute_zp = [&]() { + int num_a_blocks = div_up(M, b.m_blk); + int num_b_blocks = div_up(N, (b.n_blk * b.ld_block)); + int ktail = (b.k_tail == 0) ? 0 : 1; + int parallel_work = B * num_a_blocks; + int nt = std::min(num_threads, parallel_work); + if (b.zp_type_b != jit_int8_broadcast_t::none) + parallel(nt, [&](const int ithr, const int nthr) { + int start {0}, end {0}; + balance211(parallel_work, nt, ithr, start, end); + int batch = start / num_a_blocks; + int m_st = start % num_a_blocks; + int m_ed = end - start + m_st; + int mtail + = (m_ed == num_a_blocks) ? ((b.m_tail > 0) ? 1 : 0) : 0; + int m_blk_adr = (batch + * (num_a_blocks * b.m_blk + * div_up(K, b.k_blk) * b.k_blk)) + + m_st * b.m_blk * div_up(K, b.k_blk) * b.k_blk; + int zp_ptr_b_adr + = (batch * (num_a_blocks * b.m_blk)) + m_st * b.m_blk; + + int idx = pd()->get_idx(1, 0, ktail, 0, b); + if (idx < 0) { + assert(!"Requested int8 matmul kernel was not created."); + return; + } + int n_a = m_ed - m_st; + if (mtail) n_a -= 1; + kernel_execute( + idx, n_a, 0, m_blk_adr, 0, 0, 0, 0, 0, zp_ptr_b_adr, 0); + + if (mtail) { + idx = pd()->get_idx(1, mtail, ktail, 0, b); + if (idx < 0) { + assert(!"Requested int8 matmul kernel was not " + "created."); + return; + } + m_blk_adr += n_a * b.m_blk * div_up(K, b.k_blk) * b.k_blk; + zp_ptr_b_adr += n_a * b.m_blk; + kernel_execute(idx, 1, 0, m_blk_adr, 0, 0, 0, 0, 0, + zp_ptr_b_adr, 0); + } + start++; + }); + + parallel_work = B * num_b_blocks; + nt = std::min(num_threads, parallel_work); + if (b.zp_type_a != jit_int8_broadcast_t::none) + parallel(nt, [&](const int ithr, const int nthr) { + int start {0}, end {0}; + balance211(parallel_work, nt, ithr, start, end); + int batch = start / num_b_blocks; + int n_st = start % num_b_blocks; + int n_ed = n_st + end - start; + int ntail + = (n_ed == num_b_blocks) ? ((b.n_tail > 0) ? 
1 : 0) : 0; + int n_blk_adr = (batch + * (num_b_blocks * (b.n_blk * b.ld_block) + * div_up(K, b.k_blk) * b.k_blk)) + + n_st * (b.n_blk * b.ld_block) * div_up(K, b.k_blk) + * b.k_blk; + int zp_ptr_a_adr + = (batch * num_b_blocks * (b.n_blk * b.ld_block)) + + n_st * (b.n_blk * b.ld_block); + + int idx = pd()->get_idx(1, 0, ktail, 0, b); + if (idx < 0) { + assert(!"Requested int8 matmul kernel was not created."); + return; + } + int n_b = n_ed - n_st; + if (ntail == 1) n_b -= 1; + + kernel_execute( + idx, 0, n_b, 0, n_blk_adr, 0, 0, 0, zp_ptr_a_adr, 0, 0); + + if (ntail) { + idx = pd()->get_idx(1, 0, ktail, 1, b); + if (idx < 0) { + assert(!"Requested int8 matmul kernel was not " + "created."); + return; + } + n_blk_adr += n_b * (b.n_blk * b.ld_block) + * div_up(K, b.k_blk) * b.k_blk; + zp_ptr_a_adr += n_b * (b.n_blk * b.ld_block); + kernel_execute(idx, 0, 1, 0, n_blk_adr, 0, 0, 0, + zp_ptr_a_adr, 0, 0); + } + + start++; + }); + }; + + if (b.b_reo) reorder_b(); + + reorder_a(); + + if (b.zp_type_a != jit_int8_broadcast_t::none + || b.zp_type_b != jit_int8_broadcast_t::none) + kernel_execute_zp(); + + int m_block_sz = 32; + int n_block_sz = 24; + int m_block1 = div_up(m_block_sz, b.m_blk); + int n_block1 = div_up(n_block_sz, (b.n_blk * b.ld_block)); + int m_block1_rs = div_up(M % m_block_sz, b.m_blk); + int n_block1_rs = div_up(N % n_block_sz, (b.n_blk * b.ld_block)); + + int num_a_blocks_act = div_up(M, b.m_blk); + int num_b_blocks_act = div_up(N, (b.n_blk * b.ld_block)); + int num_a_blocks = div_up(M, m_block_sz); + int num_b_blocks = div_up(N, n_block_sz); + int ktail = (b.k_tail == 0) ? 0 : 1; + int parallel_work = B * num_a_blocks * num_b_blocks; + int nt = std::min(num_threads, parallel_work); + + parallel(nt, [&](const int ithr, const int nthr) { + int start {0}, end {0}; + balance211(parallel_work, nt, ithr, start, end); + while (start < end) { + int batch = start / (num_a_blocks * num_b_blocks); + int batch_start = start % (num_a_blocks * num_b_blocks); + int m_block = batch_start % num_a_blocks; + int n_block = batch_start / num_a_blocks; + int mtail + = (m_block1_rs != 0 && m_block == num_a_blocks - 1) ? 1 : 0; + int ntail + = (n_block1_rs != 0 && n_block == num_b_blocks - 1) ? 1 : 0; + int dst_adr = (batch * M * N) + m_block * b.m_blk * m_block1 * N + + n_block * (b.n_blk * b.ld_block) * n_block1; + int m_blk_adr = (batch + * (num_a_blocks_act * b.m_blk + * div_up(K, b.k_blk) * b.k_blk)) + + m_block * b.m_blk * m_block1 * div_up(K, b.k_blk) + * b.k_blk; + int n_blk_adr = (batch + * (num_b_blocks_act * (b.n_blk * b.ld_block) + * div_up(K, b.k_blk) * b.k_blk)) + + n_block * (b.n_blk * b.ld_block) * n_block1 + * div_up(K, b.k_blk) * b.k_blk; + int zp_ptr_a_adr + = (batch * (num_b_blocks_act * (b.n_blk * b.ld_block))) + + n_block * (b.n_blk * b.ld_block) * n_block1; + int zp_ptr_b_adr = (batch * (num_a_blocks_act * b.m_blk)) + + m_block * b.m_blk * m_block1; + int bias_addr = n_block * (b.n_blk * b.ld_block) * n_block1; + int zp_b_buf = n_block * (b.n_blk * b.ld_block) * n_block1; + int scl_addr = (b.is_oc_scales) + ? (n_block * (b.n_blk * b.ld_block) * n_block1) + : 0; + int idx = pd()->get_idx(0, 0, ktail, 0, b); + if (idx < 0) { + assert(!"Requested int8 matmul kernel was not created."); + return; + } + int n_a = m_block1, n_b = n_block1; + n_a = (mtail) ? ((b.m_tail) ? m_block1_rs - 1 : m_block1_rs) + : m_block1; + n_b = (ntail) ? ((b.n_tail) ? 
n_block1_rs - 1 : n_block1_rs) + : n_block1; + + if (n_a > 0 && n_b > 0) { + + kernel_execute(idx, n_a, n_b, m_blk_adr, n_blk_adr, dst_adr, + bias_addr, scl_addr, zp_ptr_a_adr, zp_ptr_b_adr, + zp_b_buf); + } + + if (mtail && b.m_tail > 0 && n_b > 0) { + int new_dst_adr = dst_adr + b.m_blk * n_a * N; + int new_m_blk_adr = m_blk_adr + + b.m_blk * n_a * div_up(K, b.k_blk) * b.k_blk; + int new_zp_ptr_b_adr = zp_ptr_b_adr + b.m_blk * n_a; + idx = pd()->get_idx(0, 1, ktail, 0, b); + if (idx < 0) { + assert(!"Requested int8 matmul kernel was not created."); + return; + } + int na = 1; + kernel_execute(idx, na, n_b, new_m_blk_adr, n_blk_adr, + new_dst_adr, bias_addr, scl_addr, zp_ptr_a_adr, + new_zp_ptr_b_adr, zp_b_buf); + } + + if (ntail && b.n_tail > 0 && n_a > 0) { + int new_dst_adr = dst_adr + (b.n_blk * b.ld_block) * n_b; + int new_n_blk_adr = n_blk_adr + + (b.n_blk * b.ld_block) * n_b * div_up(K, b.k_blk) + * b.k_blk; + int new_zp_b_buf = zp_b_buf + (b.n_blk * b.ld_block) * n_b; + int new_zp_ptr_a_adr + = zp_ptr_a_adr + (b.n_blk * b.ld_block) * n_b; + int new_bias_addr = bias_addr + (b.n_blk * b.ld_block) * n_b; + int new_scl_addr = scl_addr + + ((b.is_oc_scales) ? ((b.n_blk * b.ld_block) * n_b) + : 0); + idx = pd()->get_idx(0, 0, ktail, 1, b); + if (idx < 0) { + assert(!"Requested int8 matmul kernel was not created."); + return; + } + int nb = 1; + + kernel_execute(idx, n_a, nb, m_blk_adr, new_n_blk_adr, + new_dst_adr, new_bias_addr, new_scl_addr, + new_zp_ptr_a_adr, zp_ptr_b_adr, new_zp_b_buf); + } + + if (mtail && b.m_tail > 0 && ntail && b.n_tail > 0) { + int new_dst_adr = dst_adr + (b.n_blk * b.ld_block) * n_b + + b.m_blk * n_a * N; + int new_m_blk_adr = m_blk_adr + + b.m_blk * n_a * div_up(K, b.k_blk) * b.k_blk; + int new_n_blk_adr = n_blk_adr + + (b.n_blk * b.ld_block) * n_b * div_up(K, b.k_blk) + * b.k_blk; + int new_zp_b_buf = zp_b_buf + (b.n_blk * b.ld_block) * n_b; + int new_zp_ptr_a_adr + = zp_ptr_a_adr + (b.n_blk * b.ld_block) * n_b; + int new_zp_ptr_b_adr = zp_ptr_b_adr + b.m_blk * n_a; + int new_bias_addr = bias_addr + (b.n_blk * b.ld_block) * n_b; + int new_scl_addr = scl_addr + + ((b.is_oc_scales) ? ((b.n_blk * b.ld_block) * n_b) + : 0); + idx = pd()->get_idx(0, 1, ktail, 1, b); + if (idx < 0) { + assert(!"Requested int8 matmul kernel was not created."); + return; + } + int nb = 1, na = 1; + kernel_execute(idx, na, nb, new_m_blk_adr, new_n_blk_adr, + new_dst_adr, new_bias_addr, new_scl_addr, + new_zp_ptr_a_adr, new_zp_ptr_b_adr, new_zp_b_buf); + } + start++; + } + }); + + return status::success; +} + +} // namespace matmul +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/aarch64/matmul/jit_int8_matmul.hpp b/src/cpu/aarch64/matmul/jit_int8_matmul.hpp new file mode 100644 index 00000000000..6cc32633a2a --- /dev/null +++ b/src/cpu/aarch64/matmul/jit_int8_matmul.hpp @@ -0,0 +1,111 @@ +/******************************************************************************* +* Copyright 2025 FUJITSU LIMITED +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_AARCH64_JIT_INT8_MATMUL_HPP
+#define CPU_AARCH64_JIT_INT8_MATMUL_HPP
+
+#include "common/c_types_map.hpp"
+#include "common/primitive.hpp"
+#include "common/type_helpers.hpp"
+#include "common/utils.hpp"
+
+#include "cpu/platform.hpp"
+#include "cpu/primitive_attr_postops.hpp"
+
+#include "cpu/aarch64/matmul/jit_int8_kernel_types.hpp"
+#include "cpu/matmul/cpu_matmul_pd.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace aarch64 {
+namespace matmul {
+
+struct jit_int8_matmul_kernel_t;
+struct jit_int8_matmul_utils_kernel_t;
+
+struct jit_int8_matmul_t : public primitive_t {
+    struct pd_t : public dnnl::impl::cpu::matmul::cpu_matmul_pd_t {
+        using ::dnnl::impl::cpu::matmul::cpu_matmul_pd_t::cpu_matmul_pd_t;
+
+        DECLARE_COMMON_PD_T("jit:int8", jit_int8_matmul_t);
+
+        status_t init(engine_t *engine);
+
+        bool formats_ok() const {
+
+            const memory_desc_wrapper src_d(src_md_);
+            const memory_desc_wrapper weights_d(weights_md_);
+            const memory_desc_wrapper dst_d(dst_md_);
+            const bool is_dst = dst_d.matches_one_of_tag(format_tag::ab,
+                                        format_tag::abc, format_tag::abcd)
+                            != format_tag::undef
+                    || dst_d.format_kind() == format_kind::any;
+            const bool is_wei
+                    = weights_d.matches_one_of_tag(format_tag::ab,
+                              format_tag::abc, format_tag::abcd,
+                              format_tag::BA24b8a, format_tag::aCB24c8b,
+                              format_tag::abDC24d8c)
+                            != format_tag::undef
+                    || weights_d.format_kind() == format_kind::any;
+            const bool is_src = src_d.matches_one_of_tag(format_tag::ab,
+                                        format_tag::abc, format_tag::abcd)
+                            != format_tag::undef
+                    || src_d.format_kind() == format_kind::any;
+            return is_dst && is_wei && is_src;
+        }
+        const brg_int8_t &get_b() const { return brg_; }
+
+        const dyn_vals_t &get_d() const { return dyn_; }
+
+        int get_idx(int z, int m, int k, int n, const brg_int8_t b) const {
+
+            if (b.zp_type_a == jit_int8_broadcast_t::none
+                    && b.zp_type_b == jit_int8_broadcast_t::none && z == 1)
+                return -1;
+            int mt = b.M % b.m_blk;
+            int nt = b.N % (b.n_blk * b.ld_block);
+            int kt = b.K % (b.k_blk * 4);
+            if ((m == 1 && mt == 0) || (k == 1 && kt == 0)
+                    || (n == 1 && nt == 0) || (k == 0 && kt == 1))
+                return -1;
+            return k + n * 2 + m * 2 * 2 + z * 2 * 2 * 2;
+        }
+
+    private:
+        brg_int8_t brg_;
+        dyn_vals_t dyn_;
+    };
+
+    jit_int8_matmul_t(const pd_t *apd);
+    ~jit_int8_matmul_t() override;
+    int get_idx(int z, int m, int k, int n, int M, int K, int N);
+    status_t init(engine_t *engine) override;
+    status_t execute(const exec_ctx_t &ctx) const override;
+
+private:
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+    std::unique_ptr<jit_int8_matmul_kernel_t> int8_kernels_[16];
+    std::unique_ptr<jit_int8_matmul_utils_kernel_t> reo_ker_a_;
+    std::unique_ptr<jit_int8_matmul_utils_kernel_t> reo_ker_b_;
+};
+
+} // namespace matmul
+} // namespace aarch64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+#endif
diff --git a/src/cpu/aarch64/matmul/jit_int8_matmul_utils.cpp b/src/cpu/aarch64/matmul/jit_int8_matmul_utils.cpp
new file mode 100644
index 00000000000..3908598677f
--- /dev/null
+++ b/src/cpu/aarch64/matmul/jit_int8_matmul_utils.cpp
@@ -0,0 +1,294 @@
+/*******************************************************************************
+* Copyright 2025 FUJITSU LIMITED
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <cstddef>
+
+#include "common/math_utils.hpp"
+#include "cpu/aarch64/jit_generator.hpp"
+#include "cpu/aarch64/matmul/jit_int8_matmul_utils.hpp"
+
+#define GET_OFF(field) (uint32_t) offsetof(dyn_params_t, field)
+
+#define LDR_IMM(reg, addr, off) \
+    { \
+        const uint64_t IMM12_MASK = ~uint64_t(0xfff); \
+        if ((off & IMM12_MASK) == 0) { \
+            ldr(reg, ptr(addr, off)); \
+        } else { \
+            add_imm(X_DEFAULT_ADDR, addr, off, X_TMP_0); \
+            ldr(reg, ptr(X_DEFAULT_ADDR)); \
+        } \
+    }
+
+#define VCHECK_BG(f, msg, ...) \
+    VCHECK(primitive, create, dispatch, brgemm_matmul, f, msg, ##__VA_ARGS__);
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace aarch64 {
+namespace matmul {
+
+using namespace Xbyak_aarch64;
+using namespace dnnl::impl::format_tag;
+using namespace dnnl::impl::utils;
+
+using namespace nstl;
+
+using namespace data_type;
+
+void jit_int8_matmul_utils_kernel_t::reo_A_8x8(int lp, int kt) {
+    mov(reg_tmp_1, reg_tmp);
+    if (kt > 0) {
+        for (int i = 0; i < lp; i++) {
+            ld1b(ZRegB(i), prd_ld, ptr(reg_tmp_1));
+            add_imm(reg_tmp_1, reg_tmp_1, dyn_.K, X_TMP_0);
+            st1b(ZRegB(i), prd_st, ptr(reg_dst));
+            add_imm(reg_dst, reg_dst, dyn_.k_blk, X_TMP_0);
+        }
+        for (int i = 0; i < dyn_.m_blk - lp; i++) {
+            mov(ZRegB(i), 0);
+            st1b(ZRegB(i), prd_st, ptr(reg_dst));
+            add_imm(reg_dst, reg_dst, dyn_.k_blk, X_TMP_0);
+        }
+    } else {
+        for (int i = 0; i < lp; i++) {
+            ldr(DReg(i), ptr(reg_tmp_1));
+            add_imm(reg_tmp_1, reg_tmp_1, dyn_.K, X_TMP_0);
+            str(DReg(i), ptr(reg_dst));
+            add_imm(reg_dst, reg_dst, dyn_.k_blk, X_TMP_0);
+        }
+        for (int i = 0; i < dyn_.m_blk - lp; i++) {
+            mov(ZRegB(i), 0);
+            st1b(ZRegB(i), prd_st, ptr(reg_dst));
+            add_imm(reg_dst, reg_dst, dyn_.k_blk, X_TMP_0);
+        }
+    }
+}
+
+void jit_int8_matmul_utils_kernel_t::reo_B_8x24(int lp, int nt) {
+    auto p = (nt > 0) ?
prd_p3 : prd_ld; + mov(reg_tmp, reg_aux_a); + for (int i = 0; i < lp; i++) { + ld1b(ZRegB(i), p, ptr(reg_tmp)); + add_imm(reg_tmp, reg_tmp, dyn_.N, X_TMP_4); + } + for (int i = lp; i < dyn_.k_blk; i++) { + mov(ZRegB(i), 0); + } + + zip2(ZRegB(8), ZRegB(0), ZRegB(1)); + zip1(ZRegB(0), ZRegB(0), ZRegB(1)); + zip2(ZRegB(10), ZRegB(2), ZRegB(3)); + zip1(ZRegB(2), ZRegB(2), ZRegB(3)); + zip2(ZRegB(12), ZRegB(4), ZRegB(5)); + zip1(ZRegB(4), ZRegB(4), ZRegB(5)); + zip2(ZRegB(14), ZRegB(6), ZRegB(7)); + zip1(ZRegB(6), ZRegB(6), ZRegB(7)); + + zip2(ZRegH(1), ZRegH(0), ZRegH(2)); + zip1(ZRegH(0), ZRegH(0), ZRegH(2)); + zip2(ZRegH(5), ZRegH(4), ZRegH(6)); + zip1(ZRegH(4), ZRegH(4), ZRegH(6)); + zip1(ZRegH(8), ZRegH(8), ZRegH(10)); + zip1(ZRegH(12), ZRegH(12), ZRegH(14)); + + zip2(ZRegS(2), ZRegS(0), ZRegS(4)); + zip1(ZRegS(0), ZRegS(0), ZRegS(4)); + zip2(ZRegS(6), ZRegS(1), ZRegS(5)); + zip1(ZRegS(1), ZRegS(1), ZRegS(5)); + zip2(ZRegS(10), ZRegS(8), ZRegS(12)); + zip1(ZRegS(8), ZRegS(8), ZRegS(12)); + + str(ZReg(0), ptr(reg_aux_b, 0, MUL_VL)); + str(ZReg(2), ptr(reg_aux_b, 1, MUL_VL)); + str(ZReg(1), ptr(reg_aux_b, 2, MUL_VL)); + str(ZReg(6), ptr(reg_aux_b, 3, MUL_VL)); + str(ZReg(8), ptr(reg_aux_b, 4, MUL_VL)); + str(ZReg(10), ptr(reg_aux_b, 5, MUL_VL)); + + add_imm(reg_aux_b, reg_aux_b, dyn_.n_blk * dyn_.k_blk, X_TMP_4); +} + +void jit_int8_matmul_utils_kernel_t::gen_reo_a() { + + int ktl = (dyn_.ktail) ? dyn_.ktail : dyn_.k_blk; + + set_preg(prd_ld.b, ktl, X_TMP_0, X_TMP_1); + set_preg(prd_st.b, dyn_.k_blk, X_TMP_0, X_TMP_1); + + int lp = (dyn_.mtail) ? dyn_.mtail : dyn_.m_blk; + + Label m_loop, last_m, m_end, k_loop, last_k, k_end, k_loop_1, last_k_1, + k_end_1; + + LDR_IMM(reg_max, reg_param, GET_OFF(nk)); + LDR_IMM(reg_min, reg_param, GET_OFF(nm)); + + LDR_IMM(reg_tmp_2, reg_param, GET_OFF(tl)); + ldr(WReg(reg_tail.getIdx()), ptr(reg_tmp_2)); + + LDR_IMM(reg_tmp_2, reg_param, GET_OFF(mtl)); + ldr(WReg(reg_m_tail.getIdx()), ptr(reg_tmp_2)); + + ldr(WReg(reg_m_loop.getIdx()), ptr(reg_min)); + + cmp(reg_m_loop, 1); + b(EQ, last_m); + L(m_loop); + ldr(WReg(reg_k_loop.getIdx()), ptr(reg_max)); + mov(reg_tmp, reg_src); + cmp(reg_k_loop, 1); + b(EQ, last_k); + L(k_loop); + reo_A_8x8(dyn_.m_blk, 0); + add_imm(reg_tmp, reg_tmp, dyn_.k_blk, X_TMP_0); + sub(reg_k_loop, reg_k_loop, 1); + cmp(reg_k_loop, 1); + b(GT, k_loop); + b(LT, k_end); + L(last_k); + sub(reg_k_loop, reg_k_loop, 1); + cmp(reg_tail, 0); + b(EQ, k_loop); + reo_A_8x8(dyn_.m_blk, 1); + L(k_end); + add_imm(reg_src, reg_src, dyn_.K * dyn_.m_blk, X_TMP_0); + sub(reg_m_loop, reg_m_loop, 1); + cmp(reg_m_loop, 1); + b(GT, m_loop); + b(LT, m_end); + + L(last_m); + sub(reg_m_loop, reg_m_loop, 1); + cmp(reg_m_tail, 0); + b(EQ, m_loop); + ldr(WReg(reg_k_loop.getIdx()), ptr(reg_max)); + mov(reg_tmp, reg_src); + cmp(reg_k_loop, 1); + b(EQ, last_k_1); + L(k_loop_1); + reo_A_8x8(lp, 0); + add_imm(reg_tmp, reg_tmp, dyn_.k_blk, X_TMP_0); + sub(reg_k_loop, reg_k_loop, 1); + cmp(reg_k_loop, 1); + b(GT, k_loop_1); + b(LT, k_end_1); + L(last_k_1); + sub(reg_k_loop, reg_k_loop, 1); + cmp(reg_tail, 0); + b(EQ, k_loop_1); + reo_A_8x8(lp, 1); + L(k_end_1); + L(m_end); +} + +void jit_int8_matmul_utils_kernel_t::gen_reo_b() { + + int lp = (dyn_.ktail > 0) ? 
dyn_.ktail : dyn_.k_blk; + + set_preg(prd_ld.b, dyn_.n_blk, X_TMP_4, X_TMP_1); + set_preg(prd_p3.b, dyn_.ntail, X_TMP_4, X_TMP_1); + + LDR_IMM(reg_max, reg_param, GET_OFF(nn)); + LDR_IMM(reg_min, reg_param, GET_OFF(nk)); + + LDR_IMM(reg_tmp_2, reg_param, GET_OFF(tl)); + ldr(WReg(reg_tail.getIdx()), ptr(reg_tmp_2)); + + LDR_IMM(reg_tmp_2, reg_param, GET_OFF(ntl)); + ldr(WReg(reg_n_tail.getIdx()), ptr(reg_tmp_2)); + + ldr(WReg(reg_n_loop.getIdx()), ptr(reg_max)); + ldr(WReg(reg_k_loop.getIdx()), ptr(reg_min)); + + mov(reg_aux_a, reg_src); + mov(reg_aux_b, reg_dst); + + Label n_loop, last_n, n_end, k_loop, last_k, k_end, k_loop_1, last_k_1, + k_end_1; + + cmp(reg_n_loop, 1); + b(EQ, last_n); + L(n_loop); + ldr(WReg(reg_k_loop.getIdx()), ptr(reg_min)); + mov(reg_aux_a, reg_src); + cmp(reg_k_loop, 1); + b(EQ, last_k); + L(k_loop); + reo_B_8x24(dyn_.k_blk, 0); + add_imm(reg_aux_a, reg_aux_a, dyn_.k_blk * dyn_.N, X_TMP_4); + sub(reg_k_loop, reg_k_loop, 1); + cmp(reg_k_loop, 1); + b(GT, k_loop); + b(LT, k_end); + L(last_k); + sub(reg_k_loop, reg_k_loop, 1); + cmp(reg_tail, 0); + b(EQ, k_loop); + reo_B_8x24(lp, 0); + L(k_end); + add_imm(reg_src, reg_src, dyn_.n_blk, X_TMP_4); + sub(reg_n_loop, reg_n_loop, 1); + cmp(reg_n_loop, 1); + b(GT, n_loop); + b(LT, n_end); + + L(last_n); + sub(reg_n_loop, reg_n_loop, 1); + cmp(reg_n_tail, 0); + b(EQ, n_loop); + ldr(WReg(reg_k_loop.getIdx()), ptr(reg_min)); + mov(reg_aux_a, reg_src); + cmp(reg_k_loop, 1); + b(EQ, last_k_1); + L(k_loop_1); + reo_B_8x24(dyn_.k_blk, dyn_.ntail); + add_imm(reg_aux_a, reg_aux_a, dyn_.k_blk * dyn_.N, X_TMP_4); + sub(reg_k_loop, reg_k_loop, 1); + cmp(reg_k_loop, 1); + b(GT, k_loop_1); + b(LT, k_end_1); + L(last_k_1); + sub(reg_k_loop, reg_k_loop, 1); + cmp(reg_tail, 0); + b(EQ, k_loop_1); + reo_B_8x24(lp, dyn_.ntail); + L(k_end_1); + L(n_end); +} + +void jit_int8_matmul_utils_kernel_t::generate() { + + preamble(); + + if (dyn_.reorder_a == 1) { + LDR_IMM(reg_src, reg_param, GET_OFF(src)); + LDR_IMM(reg_dst, reg_param, GET_OFF(dst)); + gen_reo_a(); + } else if (dyn_.reorder_b == 1) { + LDR_IMM(reg_src, reg_param, GET_OFF(src)); + LDR_IMM(reg_dst, reg_param, GET_OFF(dst)); + gen_reo_b(); + } + + postamble(); +} +} // namespace matmul +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/aarch64/matmul/jit_int8_matmul_utils.hpp b/src/cpu/aarch64/matmul/jit_int8_matmul_utils.hpp new file mode 100644 index 00000000000..d7905f84896 --- /dev/null +++ b/src/cpu/aarch64/matmul/jit_int8_matmul_utils.hpp @@ -0,0 +1,86 @@ +/******************************************************************************* +* Copyright 2025 FUJITSU LIMITED +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_AARCH64_JIT_INT8_MATMUL_UTILS_HPP +#define CPU_AARCH64_JIT_INT8_MATMUL_UTILS_HPP + +// #include "common/primitive.hpp" +#include "cpu/aarch64/jit_generator.hpp" +#include "cpu/aarch64/matmul/jit_int8_kernel_types.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { +namespace matmul { + +using namespace Xbyak_aarch64; +struct jit_int8_matmul_utils_kernel_t : public jit_generator { + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_int8_matmul_utils_kernel_t); + + XReg reg_param = abi_param1; + XReg reg_src = x3; + XReg reg_dst = x4; + XReg reg_scl = x5; + XReg reg_zp = x6; + XReg reg_tmp = x7; + XReg reg_tmp_2 = x8; + XReg reg_max = x9; + XReg reg_min = x10; + XReg reg_tmp_1 = x11; + XReg reg_k_loop = x12; + XReg reg_m_loop = x13; + XReg reg_loop = x14; + XReg reg_tail = x15; + XReg reg_m_tail = x16; + XReg reg_aux_b = x17; + XReg reg_aux_a = x18; + + PReg prd_ld = p1; + PReg prd_st = p2; + PReg prd_p1 = p3; + PReg prd_p2 = p4; + PReg prd_p3 = p5; + + XReg reg_n_loop = reg_m_loop; + XReg reg_n_tail = reg_m_tail; + + int f32_dt_sz = 4; + + void operator()(const dyn_params_t *p) { + return jit_generator::operator()(p); + } + + jit_int8_matmul_utils_kernel_t(const dyn_vals_t &k) : dyn_(k) {} + ~jit_int8_matmul_utils_kernel_t() override = default; + +private: + void gen_reo_a(); + void gen_reo_b(); + void reo_A_8x8(int, int); + void reo_B_8x24(int, int); + void generate() override; + + dyn_vals_t dyn_; +}; + +} // namespace matmul +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl +#endif diff --git a/src/cpu/aarch64/xbyak_aarch64/_clang-format b/src/cpu/aarch64/xbyak_aarch64/_clang-format deleted file mode 100644 index af422e6188f..00000000000 --- a/src/cpu/aarch64/xbyak_aarch64/_clang-format +++ /dev/null @@ -1,127 +0,0 @@ ---- -Language: Cpp -# BasedOnStyle: LLVM -AccessModifierOffset: -2 -AlignAfterOpenBracket: Align -AlignConsecutiveMacros: false -AlignConsecutiveAssignments: false -AlignConsecutiveDeclarations: false -AlignEscapedNewlines: Right -AlignOperands: true -AlignTrailingComments: true -AllowAllArgumentsOnNextLine: true -AllowAllConstructorInitializersOnNextLine: true -AllowAllParametersOfDeclarationOnNextLine: true -AllowShortBlocksOnASingleLine: false -AllowShortCaseLabelsOnASingleLine: false -AllowShortFunctionsOnASingleLine: All -AllowShortLambdasOnASingleLine: All -AllowShortIfStatementsOnASingleLine: false -AllowShortLoopsOnASingleLine: false -AlwaysBreakAfterDefinitionReturnType: None -AlwaysBreakAfterReturnType: None -AlwaysBreakBeforeMultilineStrings: false -AlwaysBreakTemplateDeclarations: MultiLine -BinPackArguments: true -BinPackParameters: true -BraceWrapping: - AfterCaseLabel: false - AfterClass: false - AfterControlStatement: false - AfterEnum: false - AfterFunction: false - AfterNamespace: false - AfterObjCDeclaration: false - AfterStruct: false - AfterUnion: false - AfterExternBlock: false - BeforeCatch: false - BeforeElse: false - IndentBraces: false - SplitEmptyFunction: true - SplitEmptyRecord: true - SplitEmptyNamespace: true -BreakBeforeBinaryOperators: None -BreakBeforeBraces: Attach -BreakBeforeInheritanceComma: false -BreakInheritanceList: BeforeColon -BreakBeforeTernaryOperators: true -BreakConstructorInitializersBeforeComma: false -BreakConstructorInitializers: BeforeColon -BreakAfterJavaFieldAnnotations: false -BreakStringLiterals: true -ColumnLimit: 300 -CommentPragmas: '^ IWYU pragma:' -CompactNamespaces: 
false -ConstructorInitializerAllOnOneLineOrOnePerLine: false -ConstructorInitializerIndentWidth: 4 -ContinuationIndentWidth: 4 -Cpp11BracedListStyle: true -DerivePointerAlignment: false -DisableFormat: false -ExperimentalAutoDetectBinPacking: false -FixNamespaceComments: true -ForEachMacros: - - foreach - - Q_FOREACH - - BOOST_FOREACH -IncludeBlocks: Preserve -IncludeCategories: - - Regex: '^"(llvm|llvm-c|clang|clang-c)/' - Priority: 2 - - Regex: '^(<|"(gtest|gmock|isl|json)/)' - Priority: 3 - - Regex: '.*' - Priority: 1 -IncludeIsMainRegex: '(Test)?$' -IndentCaseLabels: false -IndentPPDirectives: None -IndentWidth: 2 -IndentWrappedFunctionNames: false -JavaScriptQuotes: Leave -JavaScriptWrapImports: true -KeepEmptyLinesAtTheStartOfBlocks: true -MacroBlockBegin: '' -MacroBlockEnd: '' -MaxEmptyLinesToKeep: 1 -NamespaceIndentation: None -ObjCBinPackProtocolList: Auto -ObjCBlockIndentWidth: 2 -ObjCSpaceAfterProperty: false -ObjCSpaceBeforeProtocolList: true -PenaltyBreakAssignment: 2 -PenaltyBreakBeforeFirstCallParameter: 19 -PenaltyBreakComment: 120 -PenaltyBreakFirstLessLess: 120 -PenaltyBreakString: 1000 -PenaltyBreakTemplateDeclaration: 10 -PenaltyExcessCharacter: 1000000 -PenaltyReturnTypeOnItsOwnLine: 60 -PointerAlignment: Right -ReflowComments: true -SortIncludes: true -SortUsingDeclarations: true -SpaceAfterCStyleCast: false -SpaceAfterLogicalNot: false -SpaceAfterTemplateKeyword: true -SpaceBeforeAssignmentOperators: true -SpaceBeforeCpp11BracedList: false -SpaceBeforeCtorInitializerColon: true -SpaceBeforeInheritanceColon: true -SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -SpaceInEmptyParentheses: false -SpacesBeforeTrailingComments: 1 -SpacesInAngles: false -SpacesInContainerLiterals: true -SpacesInCStyleCastParentheses: false -SpacesInParentheses: false -SpacesInSquareBrackets: false -Standard: Cpp11 -StatementMacros: - - Q_UNUSED - - QT_REQUIRE_VERSION -TabWidth: 8 -UseTab: Never -... - diff --git a/src/cpu/acl/CMakeLists.txt b/src/cpu/acl/CMakeLists.txt new file mode 100644 index 00000000000..abe0a5c49eb --- /dev/null +++ b/src/cpu/acl/CMakeLists.txt @@ -0,0 +1,33 @@ +#******************************************************************************* +# Copyright 2020-2022 Arm Ltd. and affiliates +# Copyright 2020-2021 FUJITSU LIMITED +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#******************************************************************************* +file(GLOB_RECURSE SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/*.[ch] + ${CMAKE_CURRENT_SOURCE_DIR}/*.[ch]pp + ) +# If the runtime is not THREADPOOL remove threadpool_scheduler sources. 
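+# (acl_threadpool_scheduler.cpp/hpp wrap a user-supplied threadpool, so they
+# can only be compiled when DNNL_CPU_RUNTIME is THREADPOOL; for other runtimes
+# they are dropped from SOURCES below.)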
+if(NOT DNNL_CPU_RUNTIME STREQUAL "THREADPOOL")
+    list(APPEND ACL_THREADPOOL_FILES
+        ${CMAKE_CURRENT_SOURCE_DIR}/acl_threadpool_scheduler.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/acl_threadpool_scheduler.hpp
+        )
+    list(REMOVE_ITEM SOURCES ${ACL_THREADPOOL_FILES})
+endif()
+set(OBJ_LIB ${DNNL_LIBRARY_NAME}_cpu_acl)
+add_library(${OBJ_LIB} OBJECT ${SOURCES})
+set_property(GLOBAL APPEND PROPERTY DNNL_LIB_DEPS
+    $<TARGET_OBJECTS:${OBJ_LIB}>)
+enable_conditional_compilation4(${OBJ_LIB})
\ No newline at end of file
diff --git a/src/cpu/aarch64/acl_batch_normalization.cpp b/src/cpu/acl/acl_batch_normalization.cpp
similarity index 96%
rename from src/cpu/aarch64/acl_batch_normalization.cpp
rename to src/cpu/acl/acl_batch_normalization.cpp
index 77a723207fc..83f4c5061a0 100644
--- a/src/cpu/aarch64/acl_batch_normalization.cpp
+++ b/src/cpu/acl/acl_batch_normalization.cpp
@@ -14,12 +14,12 @@
 * limitations under the License.
 *******************************************************************************/
 
-#include "cpu/aarch64/acl_batch_normalization.hpp"
+#include "cpu/acl/acl_batch_normalization.hpp"
 
 namespace dnnl {
 namespace impl {
 namespace cpu {
-namespace aarch64 {
+namespace acl {
 
 status_t acl_batch_normalization_fwd_t::execute_forward(
         const exec_ctx_t &ctx) const {
@@ -72,7 +72,7 @@ status_t acl_batch_normalization_fwd_t::execute_forward(
     return status::success;
 }
 
-} // namespace aarch64
+} // namespace acl
 } // namespace cpu
 } // namespace impl
 } // namespace dnnl
diff --git a/src/cpu/aarch64/acl_batch_normalization.hpp b/src/cpu/acl/acl_batch_normalization.hpp
similarity index 95%
rename from src/cpu/aarch64/acl_batch_normalization.hpp
rename to src/cpu/acl/acl_batch_normalization.hpp
index 9e91e8b7279..ef7e4c22cbd 100644
--- a/src/cpu/aarch64/acl_batch_normalization.hpp
+++ b/src/cpu/acl/acl_batch_normalization.hpp
@@ -14,18 +14,18 @@
 * limitations under the License.
*******************************************************************************/
 
-#ifndef CPU_AARCH64_ACL_BATCH_NORMALIZATION_HPP
-#define CPU_AARCH64_ACL_BATCH_NORMALIZATION_HPP
+#ifndef CPU_ACL_BATCH_NORMALIZATION_HPP
+#define CPU_ACL_BATCH_NORMALIZATION_HPP
 
 #include "cpu/cpu_batch_normalization_pd.hpp"
 
-#include "cpu/aarch64/acl_post_ops.hpp"
-#include "cpu/aarch64/acl_utils.hpp"
+#include "cpu/acl/acl_post_ops.hpp"
+#include "cpu/acl/acl_utils.hpp"
 
 namespace dnnl {
 namespace impl {
 namespace cpu {
-namespace aarch64 {
+namespace acl {
 
 struct acl_batch_normalization_obj_t {
     arm_compute::NEBatchNormalizationLayer bnorm;
@@ -92,12 +92,6 @@ struct acl_batch_normalization_fwd_t : public primitive_t {
         using cpu_batch_normalization_fwd_pd_t::
                 cpu_batch_normalization_fwd_pd_t;
 
-        pd_t(const batch_normalization_desc_t *adesc,
-                const primitive_attr_t *attr,
-                const typename pd_t::base_class *hint_fwd_pd)
-            : cpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd)
-            , abp() {}
-
         DECLARE_COMMON_PD_T("acl", acl_batch_normalization_fwd_t);
 
         status_t init(engine_t *engine) {
@@ -240,7 +234,7 @@ struct acl_batch_normalization_fwd_t : public primitive_t {
             return status::success;
         }
 
-        acl_batch_normalization_conf_t abp;
+        acl_batch_normalization_conf_t abp = utils::zero<decltype(abp)>();
 
         acl_post_ops_t post_ops;
     }; // pd_t
@@ -272,7 +266,7 @@ struct acl_batch_normalization_fwd_t : public primitive_t {
     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
 }; // acl_batch_normalization_fwd_t
 
-} // namespace aarch64
+} // namespace acl
 } // namespace cpu
 } // namespace impl
 } // namespace dnnl
diff --git a/src/cpu/aarch64/acl_benchmark_scheduler.cpp b/src/cpu/acl/acl_benchmark_scheduler.cpp
similarity index 96%
rename from src/cpu/aarch64/acl_benchmark_scheduler.cpp
rename to src/cpu/acl/acl_benchmark_scheduler.cpp
index e8658dd6e6a..b2ceb96339e 100644
--- a/src/cpu/aarch64/acl_benchmark_scheduler.cpp
+++ b/src/cpu/acl/acl_benchmark_scheduler.cpp
@@ -14,13 +14,13 @@
 * limitations under the License.
 *******************************************************************************/
 
-#include "cpu/aarch64/acl_benchmark_scheduler.hpp"
+#include "cpu/acl/acl_benchmark_scheduler.hpp"
 #include "common/verbose.hpp"
 
 namespace dnnl {
 namespace impl {
 namespace cpu {
-namespace aarch64 {
+namespace acl {
 
 using namespace arm_compute;
 
 BenchmarkScheduler::BenchmarkScheduler(IScheduler &real_scheduler)
@@ -72,7 +72,7 @@ void BenchmarkScheduler::run_workloads(std::vector<Workload> &workloads) {
     ARM_COMPUTE_ERROR("Can't be reached");
 }
 
-} // namespace aarch64
+} // namespace acl
 } // namespace cpu
 } // namespace impl
 } // namespace dnnl
\ No newline at end of file
diff --git a/src/cpu/aarch64/acl_benchmark_scheduler.hpp b/src/cpu/acl/acl_benchmark_scheduler.hpp
similarity index 92%
rename from src/cpu/aarch64/acl_benchmark_scheduler.hpp
rename to src/cpu/acl/acl_benchmark_scheduler.hpp
index 8fddf7ea298..a23a903c385 100644
--- a/src/cpu/aarch64/acl_benchmark_scheduler.hpp
+++ b/src/cpu/acl/acl_benchmark_scheduler.hpp
@@ -13,8 +13,8 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
*******************************************************************************/
-#ifndef CPU_AARCH64_ACL_BENCHMARK_SCHEDULER_HPP
-#define CPU_AARCH64_ACL_BENCHMARK_SCHEDULER_HPP
+#ifndef CPU_ACL_BENCHMARK_SCHEDULER_HPP
+#define CPU_ACL_BENCHMARK_SCHEDULER_HPP
 
 #include "arm_compute/core/CPP/ICPPKernel.h"
 #include "arm_compute/runtime/IScheduler.h"
@@ -22,7 +22,7 @@
 namespace dnnl {
 namespace impl {
 namespace cpu {
-namespace aarch64 {
+namespace acl {
 
 // BenchmarkScheduler implements ACL's IScheduler interface and acts as an interceptor scheduler
 // when DNNL_VERBOSE=profile,profile_externals. It intercepts calls made by the actual scheduler used by ACL, adds
 // timers to benchmark execution time of ACL kernels, and stores kernel information.
@@ -52,9 +52,9 @@ class BenchmarkScheduler final : public arm_compute::IScheduler {
     IScheduler &_real_scheduler;
 };
 
-#endif // CPU_AARCH64_ACL_BENCHMARK_SCHEDULER_HPP
+#endif // CPU_ACL_BENCHMARK_SCHEDULER_HPP
 
-} // namespace aarch64
+} // namespace acl
 } // namespace cpu
 } // namespace impl
 } // namespace dnnl
diff --git a/src/cpu/aarch64/acl_binary.cpp b/src/cpu/acl/acl_binary.cpp
similarity index 99%
rename from src/cpu/aarch64/acl_binary.cpp
rename to src/cpu/acl/acl_binary.cpp
index b1b70c80636..04418e65e70 100644
--- a/src/cpu/aarch64/acl_binary.cpp
+++ b/src/cpu/acl/acl_binary.cpp
@@ -27,7 +27,7 @@
 namespace dnnl {
 namespace impl {
 namespace cpu {
-namespace aarch64 {
+namespace acl {
 
 status_t acl_binary_t::pd_t::init(engine_t *engine) {
     using namespace acl_utils;
@@ -229,7 +229,7 @@ const acl_binary_t::pd_t *acl_binary_t::pd() const {
     return static_cast<const pd_t *>(primitive_t::pd().get());
 }
 
-} // namespace aarch64
+} // namespace acl
 } // namespace cpu
 } // namespace impl
 } // namespace dnnl
diff --git a/src/cpu/aarch64/acl_binary.hpp b/src/cpu/acl/acl_binary.hpp
similarity index 95%
rename from src/cpu/aarch64/acl_binary.hpp
rename to src/cpu/acl/acl_binary.hpp
index 41ecdded523..7040fe8aa42 100644
--- a/src/cpu/aarch64/acl_binary.hpp
+++ b/src/cpu/acl/acl_binary.hpp
@@ -14,8 +14,8 @@
 * limitations under the License.
 *******************************************************************************/
 
-#ifndef CPU_AARCH64_ACL_BINARY_HPP
-#define CPU_AARCH64_ACL_BINARY_HPP
+#ifndef CPU_ACL_BINARY_HPP
+#define CPU_ACL_BINARY_HPP
 
 #include "acl_utils.hpp"
 #include "cpu/cpu_binary_pd.hpp"
@@ -28,7 +28,7 @@
 namespace dnnl {
 namespace impl {
 namespace cpu {
-namespace aarch64 {
+namespace acl {
 
 struct acl_binary_conf_t {
     arm_compute::TensorInfo src0_info;
@@ -73,7 +73,7 @@ struct acl_binary_t : public primitive_t {
 
 }; // acl_binary_t
 
-} // namespace aarch64
+} // namespace acl
 } // namespace cpu
 } // namespace impl
 } // namespace dnnl
diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/acl/acl_convolution_utils.cpp
similarity index 84%
rename from src/cpu/aarch64/acl_convolution_utils.cpp
rename to src/cpu/acl/acl_convolution_utils.cpp
index 15437746069..fcb3c36e394 100644
--- a/src/cpu/aarch64/acl_convolution_utils.cpp
+++ b/src/cpu/acl/acl_convolution_utils.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2024 Arm Ltd. and affiliates
+* Copyright 2020-2025 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
 * limitations under the License.
*******************************************************************************/
 
-#include "acl_convolution_utils.hpp"
+#include "cpu/acl/acl_convolution_utils.hpp"
 #include "common/convolution_pd.hpp"
 #include "common/utils.hpp"
 #include "oneapi/dnnl/dnnl.h"
@@ -22,7 +22,7 @@
 namespace dnnl {
 namespace impl {
 namespace cpu {
-namespace aarch64 {
+namespace acl {
 
 namespace acl_convolution_utils {
 
@@ -283,9 +283,63 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md,
     return status::success;
 }
+
+status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md,
+        memory_desc_t &weights_md, memory_desc_t &dst_md,
+        memory_desc_t &bias_md, const convolution_desc_t &cd,
+        const primitive_attr_t &attr) {
+
+    // Under these conditions, fall back to the faster GEMM-based convolution
+    // unless the user explicitly specifies the Winograd algorithm
+    // clang-format off
+
+    // Heuristic only for servers
+    if (dnnl_get_max_threads() > 28 && cd.alg_kind == alg_kind::convolution_auto) {
+        return status::unimplemented;
+    }
+    // Heuristic for other devices
+    if (one_of(true, src_md.dims[1] < 64, // ic
+                dst_md.dims[1] < 64) // oc
+            && cd.alg_kind == alg_kind::convolution_auto) {
+        return status::unimplemented;
+    }
+
+    // clang-format on
+
+    // General Compute Library checks, memory tags are also set there
+    acp.alg_winograd = true;
+    CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr));
+
+    const bool shape_ok
+            // only unit strides allowed
+            = (acp.padstride_info.stride()
+                      == std::pair<unsigned int, unsigned int> {1, 1})
+            // Note: Compute Library supports arbitrary padding for wino kernels
+            // but we only allow small padding to be consistent with oneDNN
+            && (acp.padstride_info.pad().first <= 1) // padding left/right
+            && (acp.padstride_info.pad().second <= 1) // padding top/bottom
+            // only non-dilated convolutions allowed
+            && (acp.dilation_info == arm_compute::Size2D(1, 1));
+
+    ACL_CHECK_SUPPORT(!shape_ok, "shape not supported by winograd kernels");
+
+    // clang-format off
+    // Validate convolution manually to check for return status
+    ACL_CHECK_VALID(arm_compute::NEWinogradConvolutionLayer::validate(
+        &acp.src_tensor_info,
+        &acp.wei_tensor_info,
+        acp.with_bias ? &acp.bia_tensor_info : nullptr,
+        &acp.dst_tensor_info,
+        acp.padstride_info,
+        acp.act_info,
+        true)); // enable_fast_math flag in ACL Winograd
+    // clang-format on
+
+    return status::success;
+}
+
 } // namespace acl_convolution_utils
 
-} // namespace aarch64
+} // namespace acl
 } // namespace cpu
 } // namespace impl
 } // namespace dnnl
diff --git a/src/cpu/acl/acl_convolution_utils.hpp b/src/cpu/acl/acl_convolution_utils.hpp
new file mode 100644
index 00000000000..fb616e71a7c
--- /dev/null
+++ b/src/cpu/acl/acl_convolution_utils.hpp
@@ -0,0 +1,239 @@
+/*******************************************************************************
+* Copyright 2020-2025 Arm Ltd. and affiliates
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_ACL_CONVOLUTION_UTILS_HPP
+#define CPU_ACL_CONVOLUTION_UTILS_HPP
+
+#include <map>
+#include "acl_post_ops.hpp"
+#include "acl_utils.hpp"
+#include "arm_compute/runtime/experimental/operators/CpuDepthwiseConv2d.h"
+#include "cpu/cpu_convolution_pd.hpp"
+#include <vector>
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace acl {
+
+template <typename ConvOp>
+struct acl_obj_t {
+    arm_compute::Tensor src_tensor;
+    arm_compute::Tensor wei_tensor;
+    arm_compute::Tensor bia_tensor;
+    arm_compute::Tensor dst_tensor;
+    ConvOp conv;
+    arm_compute::experimental::MemoryRequirements aux_mem_req;
+};
+
+struct acl_conv_conf_t {
+    bool with_bias;
+    bool fast_math;
+    // If this is true, the result of the convolution goes into a temporarily
+    // allocated ACL tensor to be accumulated into the oneDNN dst during postops
+    bool use_dst_acc_for_sum;
+    // Indicates that the selected algorithm is Winograd. This is needed because the
+    // algorithm can be set to algorithm::convolution_auto and later on we need to
+    // skip fixed-format protocol as ACL Winograd does not support it.
+    bool alg_winograd;
+    arm_compute::TensorInfo src_tensor_info;
+    arm_compute::TensorInfo wei_tensor_info;
+    arm_compute::TensorInfo bia_tensor_info;
+    arm_compute::TensorInfo dst_tensor_info;
+
+    arm_compute::PadStrideInfo padstride_info;
+    arm_compute::Size2D dilation_info;
+    // Additional information about the weights not included in wei_tensor_info
+    arm_compute::WeightsInfo weights_info;
+    // Note: this will default to not enabled, and will do nothing
+    arm_compute::ActivationLayerInfo act_info;
+};
+
+namespace acl_convolution_utils {
+
+status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md,
+        memory_desc_t &weights_md, memory_desc_t &dst_md,
+        memory_desc_t &bias_md, const convolution_desc_t &cd,
+        const primitive_attr_t &attr);
+
+status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md,
+        memory_desc_t &weights_md, memory_desc_t &dst_md,
+        memory_desc_t &bias_md, const convolution_desc_t &cd,
+        const primitive_attr_t &attr);
+
+} // namespace acl_convolution_utils
+
+// Keys are anonymous with local linkage. So deduce the type automagically.
+using conv_key_t = decltype(memory_tracking::names::key_gemm_tmp_buffer);
+
+template <typename op_t, typename post_ops_t>
+status_t init_scratchpad(op_t &conv, memory_tracking::registrar_t &scratchpad,
+        const std::map<int, conv_key_t> &conv_keys, engine_t *engine,
+        post_ops_t &post_ops, dnnl::impl::post_ops_t &attr_post_ops,
+        arm_compute::ActivationLayerInfo &act_info, bool &use_dst_acc_for_sum,
+        const dnnl::impl::memory_desc_t &dst_md) {
+
+    // Book temp mem.
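+    // Every auxiliary workspace the ACL operator reports via workspace() is
+    // booked on the oneDNN scratchpad under the key it maps to in conv_keys,
+    // using the size and alignment ACL requested.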
+    const auto aux_mem_req = conv.workspace();
+    for (const auto &key : conv_keys) {
+        const auto id = key.first;
+        if (aux_mem_req[id].size > 0) {
+            scratchpad.book(key.second, aux_mem_req[id].size, 1,
+                    aux_mem_req[id].alignment, aux_mem_req[id].alignment);
+        }
+    }
+
+    CHECK(post_ops.init(engine, attr_post_ops, dst_md, act_info));
+    use_dst_acc_for_sum = post_ops.has_sum();
+
+    if (use_dst_acc_for_sum) {
+        const memory_desc_wrapper dst_d(&dst_md);
+        scratchpad.book(memory_tracking::names::key_generic_acc, dst_d.nelems(),
+                dst_d.data_type_size());
+    }
+
+    return status::success;
+}
+
+template <typename conv_obj_t, typename conv_pd_t, typename src_data_t,
+        typename wei_data_t, typename bia_data_t, typename dst_data_t>
+status_t execute_forward_conv_acl(const exec_ctx_t &ctx,
+        conv_obj_t *acl_conv_obj, const conv_pd_t *pd,
+        const std::map<int, conv_key_t> &conv_keys) {
+
+    auto src_base = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
+    auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
+
+    // import_memory() and free() methods do not allocate/free any additional
+    // memory, only acquire/release pointers.
+    arm_compute::Tensor src_tensor;
+    arm_compute::Tensor wei_tensor;
+    arm_compute::Tensor bia_tensor = nullptr;
+    arm_compute::Tensor dst_tensor;
+
+    auto const acp = pd->acp_;
+    src_tensor.allocator()->init(acp.src_tensor_info);
+    wei_tensor.allocator()->init(acp.wei_tensor_info);
+    dst_tensor.allocator()->init(acp.dst_tensor_info);
+
+    src_tensor.allocator()->import_memory(const_cast<src_data_t *>(src_base));
+    wei_tensor.allocator()->import_memory(const_cast<wei_data_t *>(wei_base));
+
+    const auto scratchpad = ctx.get_scratchpad_grantor();
+
+    // If we have an unfused sum post op, put the result in a scratchpad tensor.
+    // Result will be summed to the dst during acl_post_ops.execute
+    auto dst_base = acp.use_dst_acc_for_sum
+            ? scratchpad.get<void>(memory_tracking::names::key_generic_acc)
+            : CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
+    dst_tensor.allocator()->import_memory(dst_base);
+
+    if (acp.with_bias) {
+        auto bia_base = CTX_IN_MEM(const bia_data_t *, DNNL_ARG_BIAS);
+        bia_tensor.allocator()->init(acp.bia_tensor_info);
+        bia_tensor.allocator()->import_memory(
+                const_cast<bia_data_t *>(bia_base));
+    }
+
+    // Constness of the weight tensor matters for depthwise conv in ACL.
+    // Otherwise, it will package the weights more often than needed, as
+    // it will expect the weights to change within the duration of the run
+    // func.
+    arm_compute::ITensorPack pack;
+    pack.add_tensor(arm_compute::TensorType::ACL_SRC_0, &src_tensor);
+    pack.add_const_tensor(arm_compute::TensorType::ACL_SRC_1, &wei_tensor);
+    pack.add_const_tensor(arm_compute::TensorType::ACL_SRC_2, &bia_tensor);
+    pack.add_tensor(arm_compute::TensorType::ACL_DST, &dst_tensor);
+
+    // Get temp workspaces.
+    const auto aux_mem = acl_conv_obj->aux_mem_req;
+
+    // Hold onto tmp tensors while we need pack.
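+    // The scratchpad buffers booked in init_scratchpad are wrapped in ACL
+    // tensors and added to the pack under the slot ids ACL expects; they
+    // must stay alive until conv.run(pack) below has completed.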
+    std::vector<arm_compute::Tensor> tmp_tensors(aux_mem.size());
+    for (const auto &key : conv_keys) {
+        const auto id = key.first;
+        if (aux_mem[id].size > 0) {
+            const auto info = arm_compute::TensorInfo(
+                    arm_compute::TensorShape(aux_mem[id].size), 1,
+                    arm_compute::DataType::U8);
+            auto buffer = scratchpad.get<void>(key.second);
+            tmp_tensors[id].allocator()->init(info, aux_mem[id].alignment);
+            tmp_tensors[id].allocator()->import_memory(buffer);
+            pack.add_tensor(aux_mem[id].slot, &tmp_tensors[id]);
+        }
+    }
+
+    acl_conv_obj->conv.run(pack);
+
+    void *dst = dst_tensor.buffer();
+    pd->post_ops.execute(ctx, dst);
+
+    return status::success;
+}
+
+template <typename conv_obj_t, typename conv_pd_t, typename src_data_t,
+        typename wei_data_t, typename bia_data_t, typename dst_data_t>
+status_t execute_forward_conv_acl(
+        const exec_ctx_t &ctx, conv_obj_t &acl_conv_obj, const conv_pd_t *pd) {
+    bool with_bias = pd->acp_.with_bias;
+    bool use_dst_acc_for_sum = pd->acp_.use_dst_acc_for_sum;
+
+    auto src_base = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
+    auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
+
+    // import_memory() and free() methods do not allocate/free any additional
+    // memory, only acquire/release pointers.
+    acl_conv_obj.src_tensor.allocator()->import_memory(
+            const_cast<src_data_t *>(src_base));
+    acl_conv_obj.wei_tensor.allocator()->import_memory(
+            const_cast<wei_data_t *>(wei_base));
+
+    const auto scratchpad = ctx.get_scratchpad_grantor();
+
+    // If we have an unfused sum post op, put the result in a scratchpad tensor.
+    // Result will be summed to the dst during acl_post_ops.execute
+    auto dst_base = use_dst_acc_for_sum
+            ? scratchpad.get<void>(memory_tracking::names::key_generic_acc)
+            : CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
+    acl_conv_obj.dst_tensor.allocator()->import_memory(dst_base);
+
+    if (with_bias) {
+        auto bia_base = CTX_IN_MEM(const bia_data_t *, DNNL_ARG_BIAS);
+        acl_conv_obj.bia_tensor.allocator()->import_memory(
+                const_cast<bia_data_t *>(bia_base));
+    }
+
+    acl_conv_obj.conv.run();
+
+    acl_conv_obj.src_tensor.allocator()->free();
+    acl_conv_obj.wei_tensor.allocator()->free();
+    if (with_bias) { acl_conv_obj.bia_tensor.allocator()->free(); }
+
+    void *dst = acl_conv_obj.dst_tensor.buffer();
+    pd->post_ops.execute(ctx, dst);
+
+    acl_conv_obj.dst_tensor.allocator()->free();
+
+    return status::success;
+}
+
+} // namespace acl
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
+#endif // CPU_ACL_CONVOLUTION_UTILS_HPP
diff --git a/src/cpu/aarch64/acl_deconvolution.cpp b/src/cpu/acl/acl_deconvolution.cpp
similarity index 96%
rename from src/cpu/aarch64/acl_deconvolution.cpp
rename to src/cpu/acl/acl_deconvolution.cpp
index cdeca9cb8bb..0eef20dbabc 100644
--- a/src/cpu/aarch64/acl_deconvolution.cpp
+++ b/src/cpu/acl/acl_deconvolution.cpp
@@ -14,12 +14,12 @@
 * limitations under the License.
*******************************************************************************/ -#include "cpu/aarch64/acl_deconvolution.hpp" +#include "cpu/acl/acl_deconvolution.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { status_t acl_deconvolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { @@ -64,7 +64,7 @@ status_t acl_deconvolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { return status::success; } -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_deconvolution.hpp b/src/cpu/acl/acl_deconvolution.hpp similarity index 92% rename from src/cpu/aarch64/acl_deconvolution.hpp rename to src/cpu/acl/acl_deconvolution.hpp index 97413c7ba65..18c8c1f1a67 100644 --- a/src/cpu/aarch64/acl_deconvolution.hpp +++ b/src/cpu/acl/acl_deconvolution.hpp @@ -14,16 +14,16 @@ * limitations under the License. *******************************************************************************/ -#ifndef CPU_AARCH64_ACL_DECONVOLUTION_HPP -#define CPU_AARCH64_ACL_DECONVOLUTION_HPP +#ifndef CPU_ACL_DECONVOLUTION_HPP +#define CPU_ACL_DECONVOLUTION_HPP -#include "cpu/aarch64/acl_post_ops.hpp" +#include "cpu/acl/acl_post_ops.hpp" #include "cpu/cpu_deconvolution_pd.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { struct acl_deconv_obj_t { arm_compute::NEDeconvolutionLayer deconv; @@ -82,10 +82,6 @@ struct acl_deconv_resource_t : public resource_t { struct acl_deconvolution_fwd_t : public primitive_t { struct pd_t : public cpu_deconvolution_fwd_pd_t { using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t; - pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : cpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) - , acl_pd_conf() {} DECLARE_COMMON_PD_T( "acl", acl_deconvolution_fwd_t, USE_GLOBAL_SCRATCHPAD); @@ -193,8 +189,9 @@ struct acl_deconvolution_fwd_t : public primitive_t { } // Data layout - const auto acl_layout = is_nspc ? arm_compute::DataLayout::NHWC - : arm_compute::DataLayout::NCHW; + const arm_compute::DataLayout acl_layout = is_nspc + ? arm_compute::DataLayout::NHWC + : arm_compute::DataLayout::NCHW; acl_pd_conf.src_info = arm_compute::TensorInfo(is_nspc ? arm_compute::TensorShape(ic, iw, ih, mb) @@ -243,18 +240,15 @@ struct acl_deconvolution_fwd_t : public primitive_t { // padding is set for convolution. Otherwise, describe deconvolution as convolution of // upsampling input with stride = 1 and pad = 0. arm_compute::ConvolutionMethod conv_method; - arm_compute::TensorInfo *conv_src_info; + arm_compute::TensorInfo conv_src_info( + acl_pd_conf.src_info.clone()->set_is_resizable(true)); unsigned int pad_left = 0; unsigned int pad_right = 0; unsigned int pad_top = 0; unsigned int pad_bottom = 0; if (sh != 1 || sw != 1) { - arm_compute::TensorInfo scale_out_info( - acl_pd_conf.src_info.clone() - ->set_is_resizable(true) - .reset_padding() - .set_tensor_shape(scale_out_shape)); - conv_src_info = &scale_out_info; + conv_src_info.reset_padding(); + conv_src_info.set_tensor_shape(scale_out_shape); } else { // compute correct padding here pad_left = pr > pl ? 
pr - pl : 0; @@ -269,15 +263,13 @@ struct acl_deconvolution_fwd_t : public primitive_t { pad_right += deconv_pad_x / 2; pad_top += deconv_pad_y / 2; pad_bottom += deconv_pad_y / 2; - - conv_src_info = &acl_pd_conf.src_info; } const arm_compute::PadStrideInfo conv_info(1, 1, pad_left, pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::CEIL); conv_method = arm_compute::NEConvolutionLayer::get_convolution_method( - conv_src_info, &acl_pd_conf.wei_info, + &conv_src_info, &acl_pd_conf.wei_info, &acl_pd_conf.dst_info, conv_info, arm_compute::WeightsInfo(), arm_compute::Size2D(1U, 1U), @@ -302,7 +294,7 @@ struct acl_deconvolution_fwd_t : public primitive_t { return status::success; } - acl_deconv_conf_t acl_pd_conf; + acl_deconv_conf_t acl_pd_conf = utils::zero(); acl_post_ops_t post_ops; private: @@ -338,7 +330,7 @@ struct acl_deconvolution_fwd_t : public primitive_t { const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } }; // acl_deconvolution_fwd_t -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_depthwise_convolution.cpp b/src/cpu/acl/acl_depthwise_convolution.cpp similarity index 96% rename from src/cpu/aarch64/acl_depthwise_convolution.cpp rename to src/cpu/acl/acl_depthwise_convolution.cpp index 4752cfd5852..15edd205f76 100644 --- a/src/cpu/aarch64/acl_depthwise_convolution.cpp +++ b/src/cpu/acl/acl_depthwise_convolution.cpp @@ -14,15 +14,15 @@ * limitations under the License. *******************************************************************************/ -#include "cpu/aarch64/acl_depthwise_convolution.hpp" +#include "cpu/acl/acl_depthwise_convolution.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { namespace { -using data_t = prec_traits::type; +using data_t = prec_traits_t::type; // Keys are anonymous. So deduce the type automagically. using conv_key_t = decltype(memory_tracking::names::key_gemm_tmp_buffer); @@ -87,7 +87,7 @@ status_t acl_depthwise_convolution_fwd_t::init(engine_t *engine) { acl_obj_->aux_mem_req = acl_obj_->conv.workspace(); return status::success; } -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_depthwise_convolution.hpp b/src/cpu/acl/acl_depthwise_convolution.hpp similarity index 81% rename from src/cpu/aarch64/acl_depthwise_convolution.hpp rename to src/cpu/acl/acl_depthwise_convolution.hpp index 3e3f0e1ccbc..61c39332a67 100644 --- a/src/cpu/aarch64/acl_depthwise_convolution.hpp +++ b/src/cpu/acl/acl_depthwise_convolution.hpp @@ -14,8 +14,8 @@ * limitations under the License. 
*******************************************************************************/ -#ifndef CPU_AARCH64_ACL_DEPTHWISE_CONVOLUTION_HPP -#define CPU_AARCH64_ACL_DEPTHWISE_CONVOLUTION_HPP +#ifndef CPU_ACL_DEPTHWISE_CONVOLUTION_HPP +#define CPU_ACL_DEPTHWISE_CONVOLUTION_HPP #include "acl_convolution_utils.hpp" #include "arm_compute/runtime/experimental/operators/CpuDepthwiseConv2d.h" @@ -24,23 +24,21 @@ namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { struct acl_depthwise_convolution_fwd_t : public primitive_t { using Op = arm_compute::experimental::op::CpuDepthwiseConv2d; struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), acp_() {} + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; DECLARE_COMMON_PD_T("depthwise_convolution:acl", acl_depthwise_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD); status_t init(engine_t *engine); - acl_conv_conf_t acp_; + acl_conv_conf_t acp_ = utils::zero(); acl_post_ops_t post_ops; }; @@ -59,9 +57,9 @@ struct acl_depthwise_convolution_fwd_t : public primitive_t { std::unique_ptr> acl_obj_; }; -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl -#endif // CPU_AARCH64_ACL_DEPTHWISE_CONVOLUTION_HPP +#endif // CPU_ACL_DEPTHWISE_CONVOLUTION_HPP diff --git a/src/cpu/aarch64/acl_eltwise.cpp b/src/cpu/acl/acl_eltwise.cpp similarity index 98% rename from src/cpu/aarch64/acl_eltwise.cpp rename to src/cpu/acl/acl_eltwise.cpp index e7789825f42..98f539dd8f9 100644 --- a/src/cpu/aarch64/acl_eltwise.cpp +++ b/src/cpu/acl/acl_eltwise.cpp @@ -19,7 +19,7 @@ namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { status_t acl_eltwise_fwd_t::execute(const exec_ctx_t &ctx) const { return execute_forward(ctx); @@ -108,7 +108,7 @@ status_t acl_eltwise_fwd_t::pd_t::init(engine_t *engine) { return status::success; } -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_eltwise.hpp b/src/cpu/acl/acl_eltwise.hpp similarity index 93% rename from src/cpu/aarch64/acl_eltwise.hpp rename to src/cpu/acl/acl_eltwise.hpp index bd64eac1936..45869414bec 100644 --- a/src/cpu/aarch64/acl_eltwise.hpp +++ b/src/cpu/acl/acl_eltwise.hpp @@ -14,8 +14,8 @@ * limitations under the License. 
*******************************************************************************/ -#ifndef CPU_AARCH64_ACL_ELTWISE_HPP -#define CPU_AARCH64_ACL_ELTWISE_HPP +#ifndef CPU_ACL_ELTWISE_HPP +#define CPU_ACL_ELTWISE_HPP #include #include "cpu/cpu_eltwise_pd.hpp" @@ -27,7 +27,7 @@ namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { struct acl_eltwise_conf_t { arm_compute::ActivationLayerInfo act_info; @@ -71,9 +71,9 @@ struct acl_eltwise_fwd_t : public primitive_t { friend struct acl_post_ops_t; }; // acl_eltwise_fwd_t -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl -#endif // CPU_AARCH64_ACL_ELTWISE_HPP +#endif // CPU_ACL_ELTWISE_HPP diff --git a/src/cpu/aarch64/acl_gemm_convolution.cpp b/src/cpu/acl/acl_gemm_convolution.cpp similarity index 98% rename from src/cpu/aarch64/acl_gemm_convolution.cpp rename to src/cpu/acl/acl_gemm_convolution.cpp index d3a663d8c63..922ea42e396 100644 --- a/src/cpu/aarch64/acl_gemm_convolution.cpp +++ b/src/cpu/acl/acl_gemm_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Arm Ltd. and affiliates +* Copyright 2020-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { namespace { // Keys are anonymous. So deduce the type automagically. @@ -112,7 +112,7 @@ template struct acl_gemm_convolution_fwd_t; template struct acl_gemm_convolution_fwd_t; template struct acl_gemm_convolution_fwd_t; -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_gemm_convolution.hpp b/src/cpu/acl/acl_gemm_convolution.hpp similarity index 75% rename from src/cpu/aarch64/acl_gemm_convolution.hpp rename to src/cpu/acl/acl_gemm_convolution.hpp index 23fe03f2d85..14d0050c7ab 100644 --- a/src/cpu/aarch64/acl_gemm_convolution.hpp +++ b/src/cpu/acl/acl_gemm_convolution.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Arm Ltd. and affiliates +* Copyright 2020-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ * limitations under the License. 
*******************************************************************************/ -#ifndef CPU_AARCH64_ACL_GEMM_CONVOLUTION_HPP -#define CPU_AARCH64_ACL_GEMM_CONVOLUTION_HPP +#ifndef CPU_ACL_GEMM_CONVOLUTION_HPP +#define CPU_ACL_GEMM_CONVOLUTION_HPP #include "common/memory_tracking.hpp" #include "cpu/cpu_convolution_pd.hpp" @@ -27,7 +27,7 @@ namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { template @@ -36,16 +36,14 @@ struct acl_gemm_convolution_fwd_t : public primitive_t { using Op = arm_compute::experimental::op::CpuGemmConv2d; struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), acp_() {} + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; DECLARE_COMMON_PD_T( "gemm:acl", acl_gemm_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD); status_t init(engine_t *engine); - acl_conv_conf_t acp_; + acl_conv_conf_t acp_ = utils::zero(); acl_post_ops_t post_ops; }; @@ -54,10 +52,10 @@ struct acl_gemm_convolution_fwd_t : public primitive_t { status_t init(engine_t *engine) override; - using src_data_t = typename prec_traits::type; - using wei_data_t = typename prec_traits::type; - using dst_data_t = typename prec_traits::type; - using bia_data_t = typename prec_traits::type; + using src_data_t = typename prec_traits_t::type; + using wei_data_t = typename prec_traits_t::type; + using dst_data_t = typename prec_traits_t::type; + using bia_data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_forward(ctx); @@ -69,7 +67,7 @@ struct acl_gemm_convolution_fwd_t : public primitive_t { std::unique_ptr> acl_obj_; }; // acl_gemm_convolution_fwd_t -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_indirect_gemm_convolution.cpp b/src/cpu/acl/acl_indirect_gemm_convolution.cpp similarity index 94% rename from src/cpu/aarch64/acl_indirect_gemm_convolution.cpp rename to src/cpu/acl/acl_indirect_gemm_convolution.cpp index 19ee68062da..d543b2b6b19 100644 --- a/src/cpu/aarch64/acl_indirect_gemm_convolution.cpp +++ b/src/cpu/acl/acl_indirect_gemm_convolution.cpp @@ -22,10 +22,10 @@ namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { namespace { -using data_t = typename prec_traits::type; +using data_t = typename prec_traits_t::type; // Keys are anonymous. So deduce the type automagically. 
using conv_key_t = decltype(memory_tracking::names::key_gemm_tmp_buffer); @@ -96,11 +96,13 @@ status_t acl_indirect_gemm_convolution_fwd_t::pd_t::init(engine_t *engine) { const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) && attr()->has_default_values(smask_t::post_ops, f16); + const bool is_bf16_ok = expect_data_types(bf16, bf16, bf16, bf16, undef) + && attr_.post_ops_.len() == 0; const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) && attr()->has_default_values( smask_t::post_ops | smask_t::fpmath_mode, f32); bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct) - && utils::one_of(true, is_fp16_ok, is_fp32_ok) + && utils::one_of(true, is_fp16_ok, is_bf16_ok, is_fp32_ok) && !has_zero_dim_memory(); if (!ok) return status::unimplemented; @@ -120,7 +122,7 @@ status_t acl_indirect_gemm_convolution_fwd_t::pd_t::init(engine_t *engine) { dst_md_); } -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp b/src/cpu/acl/acl_indirect_gemm_convolution.hpp similarity index 81% rename from src/cpu/aarch64/acl_indirect_gemm_convolution.hpp rename to src/cpu/acl/acl_indirect_gemm_convolution.hpp index d5b914e5fd7..7286cc3ced6 100644 --- a/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp +++ b/src/cpu/acl/acl_indirect_gemm_convolution.hpp @@ -14,8 +14,8 @@ * limitations under the License. *******************************************************************************/ -#ifndef CPU_AARCH64_ACL_INDIRECT_GEMM_CONVOLUTION_HPP -#define CPU_AARCH64_ACL_INDIRECT_GEMM_CONVOLUTION_HPP +#ifndef CPU_ACL_INDIRECT_GEMM_CONVOLUTION_HPP +#define CPU_ACL_INDIRECT_GEMM_CONVOLUTION_HPP #include "cpu/cpu_convolution_pd.hpp" @@ -25,24 +25,21 @@ namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { struct acl_indirect_gemm_convolution_fwd_t : public primitive_t { using Op = arm_compute::experimental::op::CpuGemmDirectConv2d; struct pd_t : public cpu_convolution_fwd_pd_t { - - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), acp_() {} + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; DECLARE_COMMON_PD_T("indirect_gemm:acl", acl_indirect_gemm_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD); status_t init(engine_t *engine); - acl_conv_conf_t acp_; + acl_conv_conf_t acp_ = utils::zero(); acl_post_ops_t post_ops; private: @@ -64,9 +61,9 @@ struct acl_indirect_gemm_convolution_fwd_t : public primitive_t { std::unique_ptr> acl_obj_; }; -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl -#endif // CPU_AARCH64_ACL_INDIRECT_GEMM_CONVOLUTION_HPP +#endif // CPU_ACL_INDIRECT_GEMM_CONVOLUTION_HPP diff --git a/src/cpu/aarch64/acl_inner_product.cpp b/src/cpu/acl/acl_inner_product.cpp similarity index 96% rename from src/cpu/aarch64/acl_inner_product.cpp rename to src/cpu/acl/acl_inner_product.cpp index 34de43ae638..9dcceaa2d30 100644 --- a/src/cpu/aarch64/acl_inner_product.cpp +++ b/src/cpu/acl/acl_inner_product.cpp @@ -14,12 +14,12 @@ * limitations under the License. 
*******************************************************************************/ -#include "cpu/aarch64/acl_inner_product.hpp" +#include "cpu/acl/acl_inner_product.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { status_t acl_inner_product_fwd_t::execute_forward(const exec_ctx_t &ctx) const { @@ -70,7 +70,7 @@ status_t acl_inner_product_fwd_t::execute_forward(const exec_ctx_t &ctx) const { return status::success; } -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/acl/acl_inner_product.hpp similarity index 93% rename from src/cpu/aarch64/acl_inner_product.hpp rename to src/cpu/acl/acl_inner_product.hpp index 336168ba626..4dd84e4fa8b 100644 --- a/src/cpu/aarch64/acl_inner_product.hpp +++ b/src/cpu/acl/acl_inner_product.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Arm Ltd. and affiliates +* Copyright 2021-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,18 +14,18 @@ * limitations under the License. *******************************************************************************/ -#ifndef CPU_AARCH64_ACL_INNER_PRODUCT_HPP -#define CPU_AARCH64_ACL_INNER_PRODUCT_HPP +#ifndef CPU_ACL_INNER_PRODUCT_HPP +#define CPU_ACL_INNER_PRODUCT_HPP -#include "cpu/aarch64/acl_utils.hpp" +#include "cpu/acl/acl_utils.hpp" #include "cpu/cpu_inner_product_pd.hpp" -#include "cpu/aarch64/acl_post_ops.hpp" +#include "cpu/acl/acl_post_ops.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { struct acl_ip_obj_t { arm_compute::NEFullyConnectedLayer fc; @@ -85,10 +85,6 @@ struct acl_inner_product_fwd_t : public primitive_t { struct pd_t : public cpu_inner_product_fwd_pd_t { using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; - pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_inner_product_fwd_pd_t(adesc, attr, hint_fwd_pd), aip() {} - DECLARE_COMMON_PD_T("acl", acl_inner_product_fwd_t); status_t init(engine_t *engine) { @@ -101,6 +97,9 @@ struct acl_inner_product_fwd_t : public primitive_t { const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) && attr()->has_default_values( smask_t::post_ops | smask_t::fpmath_mode, f32); + const bool is_bf16_ok + = expect_data_types(bf16, bf16, bf16, bf16, undef) + && attr()->has_default_values(smask_t::post_ops, bf16); const bool is_fp32_bf16_ok = expect_data_types(f32, bf16, f32, f32, undef) && attr()->has_default_values( @@ -109,8 +108,8 @@ struct acl_inner_product_fwd_t : public primitive_t { = utils::one_of(weights_format_kind_received, format_kind::any, format_kind::blocked); const bool ok = is_fwd() && !has_zero_dim_memory() - && utils::one_of( - true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok) + && utils::one_of(true, is_fp16_ok, is_fp32_ok, + is_fp32_bf16_ok, is_bf16_ok) && is_weights_md_format_ok && set_default_params(true) == status::success; @@ -128,7 +127,7 @@ struct acl_inner_product_fwd_t : public primitive_t { return status::success; } - acl_ip_conf_t aip; + acl_ip_conf_t aip = utils::zero(); acl_post_ops_t post_ops; @@ -257,8 +256,11 @@ struct acl_inner_product_fwd_t : public primitive_t { // Fallback int block_by = arm_compute::block_by(expected_weight_format); + bool is_bf16 = 
src_md()->data_type == data_type::bf16 + && weights_md()->data_type == data_type::bf16 + && dst_md()->data_type == data_type::bf16; if (is_4d && weights_md_.dims[inner_dim] % block_by != 0 - && aip.fc_info.enable_fast_math) { + && (aip.fc_info.enable_fast_math || is_bf16)) { aip.fc_info.enable_fast_math = false; aip.weights_info.set_weight_format( arm_compute::WeightFormat::ANY); @@ -331,9 +333,9 @@ struct acl_inner_product_fwd_t : public primitive_t { const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } }; // acl_inner_product_fwd_t -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl -#endif // CPU_AARCH64_ACL_INNER_PRODUCT_HPP +#endif // CPU_ACL_INNER_PRODUCT_HPP diff --git a/src/cpu/aarch64/acl_layer_normalization.cpp b/src/cpu/acl/acl_layer_normalization.cpp similarity index 94% rename from src/cpu/aarch64/acl_layer_normalization.cpp rename to src/cpu/acl/acl_layer_normalization.cpp index 05bcb1766f1..11c4796d7d5 100644 --- a/src/cpu/aarch64/acl_layer_normalization.cpp +++ b/src/cpu/acl/acl_layer_normalization.cpp @@ -14,12 +14,12 @@ * limitations under the License. *******************************************************************************/ -#include "cpu/aarch64/acl_layer_normalization.hpp" +#include "cpu/acl/acl_layer_normalization.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { status_t acl_layer_normalization_fwd_t::execute_forward( const exec_ctx_t &ctx) const { @@ -48,7 +48,7 @@ status_t acl_layer_normalization_fwd_t::execute_forward( return status::success; } -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_layer_normalization.hpp b/src/cpu/acl/acl_layer_normalization.hpp similarity index 92% rename from src/cpu/aarch64/acl_layer_normalization.hpp rename to src/cpu/acl/acl_layer_normalization.hpp index 80dd681a84b..9363511a521 100644 --- a/src/cpu/aarch64/acl_layer_normalization.hpp +++ b/src/cpu/acl/acl_layer_normalization.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023 Arm Ltd. and affiliates +* Copyright 2023-2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,16 +14,16 @@ * limitations under the License. 
*******************************************************************************/ -#ifndef CPU_AARCH64_ACL_LAYER_NORMALIZATION_HPP -#define CPU_AARCH64_ACL_LAYER_NORMALIZATION_HPP +#ifndef CPU_ACL_LAYER_NORMALIZATION_HPP +#define CPU_ACL_LAYER_NORMALIZATION_HPP -#include "cpu/aarch64/acl_utils.hpp" +#include "cpu/acl/acl_utils.hpp" #include "cpu/cpu_layer_normalization_pd.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { struct acl_msdnorm_obj_t { arm_compute::NEMeanStdDevNormalizationLayer msdNorm; @@ -68,11 +68,6 @@ struct acl_layer_normalization_fwd_t : public primitive_t { struct pd_t : public cpu_layer_normalization_fwd_pd_t { using cpu_layer_normalization_fwd_pd_t:: cpu_layer_normalization_fwd_pd_t; - pd_t(const layer_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_layer_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) - , anp() {} DECLARE_COMMON_PD_T("acl", acl_layer_normalization_fwd_t); @@ -81,9 +76,10 @@ struct acl_layer_normalization_fwd_t : public primitive_t { // dir and flags ACL_CHECK_SUPPORT( !is_fwd(), "ACL lnorm supports forward propagation only"); - ACL_CHECK_SUPPORT(is_training() && !use_global_stats(), - "ACL only supports forward training with lnorm if stats " - "are provided (use global stats)"); + ACL_CHECK_SUPPORT( + is_training(), "ACL supports inference only for lnorm"); + ACL_CHECK_SUPPORT(use_global_stats(), + "ACL does not support global stats with lnorm"); ACL_CHECK_SUPPORT(use_scale() || use_shift(), "ACL does not support lnorm scale and shift"); @@ -219,7 +215,7 @@ struct acl_layer_normalization_fwd_t : public primitive_t { || X * C > acl_better_XC_per_thread * threads); } - acl_msdnorm_conf_t anp; + acl_msdnorm_conf_t anp = utils::zero(); }; // pd_t @@ -250,9 +246,9 @@ struct acl_layer_normalization_fwd_t : public primitive_t { const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } }; // acl_layer_normalization_fwd_t -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl -#endif // CPU_AARCH64_ACL_LAYER_NORMALIZATION_HPP +#endif // CPU_ACL_LAYER_NORMALIZATION_HPP diff --git a/src/cpu/aarch64/acl_pooling.cpp b/src/cpu/acl/acl_pooling.cpp similarity index 96% rename from src/cpu/aarch64/acl_pooling.cpp rename to src/cpu/acl/acl_pooling.cpp index 1aac8c53a34..e3ecd290638 100644 --- a/src/cpu/aarch64/acl_pooling.cpp +++ b/src/cpu/acl/acl_pooling.cpp @@ -14,12 +14,12 @@ * limitations under the License. *******************************************************************************/ -#include "cpu/aarch64/acl_pooling.hpp" +#include "cpu/acl/acl_pooling.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { status_t acl_pooling_fwd_t::execute_forward(const exec_ctx_t &ctx) const { // Lock here is needed because resource_mapper does not support @@ -52,7 +52,7 @@ status_t acl_pooling_fwd_t::execute_forward(const exec_ctx_t &ctx) const { return status; } -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_pooling.hpp b/src/cpu/acl/acl_pooling.hpp similarity index 96% rename from src/cpu/aarch64/acl_pooling.hpp rename to src/cpu/acl/acl_pooling.hpp index a397d69aa7d..a696dac8e69 100644 --- a/src/cpu/aarch64/acl_pooling.hpp +++ b/src/cpu/acl/acl_pooling.hpp @@ -14,16 +14,16 @@ * limitations under the License. 
*******************************************************************************/ -#ifndef CPU_AARCH64_ACL_POOLING_HPP -#define CPU_AARCH64_ACL_POOLING_HPP +#ifndef CPU_ACL_POOLING_HPP +#define CPU_ACL_POOLING_HPP -#include "cpu/aarch64/acl_utils.hpp" +#include "cpu/acl/acl_utils.hpp" #include "cpu/cpu_pooling_pd.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { struct acl_pooling_obj_t { arm_compute::NEPoolingLayer pool; @@ -77,9 +77,6 @@ struct acl_pooling_resource_t : public resource_t { struct acl_pooling_fwd_t : public primitive_t { struct pd_t : public cpu_pooling_fwd_pd_t { using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; - pd_t(const pooling_desc_t *adesc, const primitive_attr_t *attr, - const pooling_fwd_pd_t *hint_fwd_pd) - : cpu_pooling_fwd_pd_t(adesc, attr, hint_fwd_pd), app() {} DECLARE_COMMON_PD_T("acl", acl_pooling_fwd_t); @@ -265,7 +262,7 @@ struct acl_pooling_fwd_t : public primitive_t { return problem_size > cutoff * thread_count; } - acl_pooling_conf_t app; + acl_pooling_conf_t app = utils::zero(); }; acl_pooling_fwd_t(const pd_t *apd) : primitive_t(apd) {} @@ -295,9 +292,9 @@ struct acl_pooling_fwd_t : public primitive_t { const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } }; // acl_pooling_fwd_t -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl -#endif // CPU_AARCH64_ACL_POOLING_HPP +#endif // CPU_ACL_POOLING_HPP diff --git a/src/cpu/aarch64/acl_post_ops.cpp b/src/cpu/acl/acl_post_ops.cpp similarity index 97% rename from src/cpu/aarch64/acl_post_ops.cpp rename to src/cpu/acl/acl_post_ops.cpp index dbb1bf2d53c..816d195a920 100644 --- a/src/cpu/aarch64/acl_post_ops.cpp +++ b/src/cpu/acl/acl_post_ops.cpp @@ -15,12 +15,12 @@ *******************************************************************************/ #include "common/float16.hpp" -#include "cpu/aarch64/acl_gemm_convolution.hpp" +#include "cpu/acl/acl_gemm_convolution.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { status_t acl_post_ops_t::execute( const exec_ctx_t &ctx, void *src, void *dst) const { @@ -97,7 +97,7 @@ status_t acl_post_ops_t::execute( return status::success; } -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_post_ops.hpp b/src/cpu/acl/acl_post_ops.hpp similarity index 94% rename from src/cpu/aarch64/acl_post_ops.hpp rename to src/cpu/acl/acl_post_ops.hpp index 5c80f413463..d5e470e4578 100644 --- a/src/cpu/aarch64/acl_post_ops.hpp +++ b/src/cpu/acl/acl_post_ops.hpp @@ -14,16 +14,16 @@ * limitations under the License. *******************************************************************************/ -#ifndef CPU_AARCH64_ACL_POST_OPS_HPP -#define CPU_AARCH64_ACL_POST_OPS_HPP +#ifndef CPU_ACL_POST_OPS_HPP +#define CPU_ACL_POST_OPS_HPP -#include "cpu/aarch64/acl_binary.hpp" -#include "cpu/aarch64/acl_eltwise.hpp" +#include "cpu/acl/acl_binary.hpp" +#include "cpu/acl/acl_eltwise.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { struct acl_post_ops_t { @@ -142,10 +142,8 @@ struct acl_post_ops_t { CHECK(base_post_ops.set_default_formats(&dst_md)); dst_data_type = dst_md.data_type; - // If the first entry is eltwise, we fuse it, except when the datatype - // is fp16 because in this case we want to execute the eltwise in fp32. 
- if (base_post_ops.len() >= 1 && base_post_ops.entry_[0].is_eltwise() - && dst_data_type != data_type::f16) { + // If the first entry is eltwise, we fuse it + if (base_post_ops.len() >= 1 && base_post_ops.entry_[0].is_eltwise()) { const auto &first_po = base_post_ops.entry_[0].eltwise; ACL_CHECK_SUPPORT(first_po.scale != 1.0f, @@ -178,7 +176,7 @@ struct acl_post_ops_t { std::vector> post_op_primitives; }; -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_prelu.cpp b/src/cpu/acl/acl_prelu.cpp similarity index 96% rename from src/cpu/aarch64/acl_prelu.cpp rename to src/cpu/acl/acl_prelu.cpp index e2aae9392c0..b118fe20811 100644 --- a/src/cpu/aarch64/acl_prelu.cpp +++ b/src/cpu/acl/acl_prelu.cpp @@ -14,12 +14,12 @@ * limitations under the License. *******************************************************************************/ -#include "cpu/aarch64/acl_prelu.hpp" +#include "cpu/acl/acl_prelu.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { status_t acl_prelu_fwd_t::execute_forward(const exec_ctx_t &ctx) const { @@ -51,7 +51,7 @@ status_t acl_prelu_fwd_t::execute_forward(const exec_ctx_t &ctx) const { return status::success; } -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_prelu.hpp b/src/cpu/acl/acl_prelu.hpp similarity index 96% rename from src/cpu/aarch64/acl_prelu.hpp rename to src/cpu/acl/acl_prelu.hpp index 8517d1bb3ee..a7b70402687 100644 --- a/src/cpu/aarch64/acl_prelu.hpp +++ b/src/cpu/acl/acl_prelu.hpp @@ -13,16 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ -#ifndef CPU_AARCH64_ACL_PRELU_HPP -#define CPU_AARCH64_ACL_PRELU_HPP +#ifndef CPU_ACL_PRELU_HPP +#define CPU_ACL_PRELU_HPP -#include "cpu/aarch64/acl_utils.hpp" +#include "cpu/acl/acl_utils.hpp" #include "cpu/cpu_prelu_pd.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { struct acl_prelu_obj_t { arm_compute::NEPReluLayer prelu; @@ -151,9 +151,9 @@ struct acl_prelu_fwd_t : public primitive_t { const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } }; // acl_prelu_fwd_t -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl -#endif // CPU_AARCH64_ACL_PRELU_HPP +#endif // CPU_ACL_PRELU_HPP diff --git a/src/cpu/acl/acl_softmax.cpp b/src/cpu/acl/acl_softmax.cpp new file mode 100644 index 00000000000..9b8fea25759 --- /dev/null +++ b/src/cpu/acl/acl_softmax.cpp @@ -0,0 +1,173 @@ +/******************************************************************************* +* Copyright 2021-2024 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/
+
+#include "cpu/acl/acl_softmax.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace acl {
+
+const acl_softmax_fwd_t::pd_t *acl_softmax_fwd_t::pd() const {
+    return static_cast<const pd_t *>(primitive_t::pd().get());
+}
+
+status_t acl_softmax_fwd_t::pd_t::init(engine_t *engine) {
+
+    bool ok = is_fwd()
+            && set_default_formats() == status::success
+            // ACL only supports matching src/dst (this must come after
+            // set_default_formats() to handle format_kind::any)
+            && *src_md() == *dst_md()
+            && utils::one_of(
+                    src_md()->data_type, data_type::f32, data_type::f16)
+            && attr()->has_default_values();
+    if (!ok) return status::unimplemented;
+
+    // Get memory desc to find sizes and dims
+    const memory_desc_wrapper src_d(src_md());
+    const data_type_t data_type = src_d.data_type();
+
+    // ACL only supports plain tensors, can be permuted but not blocked
+    if (!src_d.is_plain()) return status::unimplemented;
+
+    // Guards against a 0-sized dimension
+    if (src_d.has_zero_dim()) return status::unimplemented;
+
+    // No scaling
+    asp_.beta = 1;
+
+    asp_.is_logsoftmax = is_logsoftmax();
+
+    // The strides give us the in memory inner size
+    dim_t inner_size_ = src_d.blocking_desc().strides[axis()];
+
+    dim_t axis_size_ = axis_size();
+
+    // The outer size is any left-over dimensions not inner or on the axis
+    dim_t outer_size_ = src_d.nelems() / (inner_size_ * axis_size_);
+
+    // In this context, NHWC tells ACL that the logical and physical
+    // dimensions are the same
+    arm_compute::DataLayout acl_layout = arm_compute::DataLayout::NHWC;
+
+    const arm_compute::DataType acl_data_t
+            = acl_utils::get_acl_data_t(data_type);
+
+    const int threads = dnnl_get_max_threads();
+
+    // A rough empirical heuristic created by fitting a polynomial
+    // of the tensor sizes and thread count to the run time of the
+    // ref and ACL softmax. This variable is greater than zero when
+    // ref is faster, and less than zero when ACL is faster. We can
+    // interpret the constant term as the constant overhead
+    // associated with calling the external library and the negative
+    // coefficient on total_size as ACL being faster at processing
+    // each element
+    auto calculate_performance_diff = [&](double axis_coeff) {
+        double acl_ref_performance_diff = 1 + 0.005 * outer_size_
+                + axis_coeff * axis_size_
+                        * std::ceil(double(outer_size_) / threads);
+
+        if (threads > 1 || outer_size_ > 1) {
+            acl_ref_performance_diff
+                    += 17; // Adds constant overhead for using threads within ACL
+        }
+        return acl_ref_performance_diff;
+    };
+
+    if (inner_size_ == 1) {
+        double acl_ref_performance_diff = calculate_performance_diff(-0.0027);
+        if (acl_ref_performance_diff > 0) return status::unimplemented;
+
+        // If the inner size is 1, we can get rid of the dimension.
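+        // (Worked example, for illustration: a plain 8x64 f32 src with
+        //  axis() == 1 has stride 1 on the softmax axis, so inner_size_ == 1,
+        //  axis_size_ == 64 and outer_size_ == 8, and ACL is handed a
+        //  TensorShape(64, 8) with the softmax taken over axis 0.)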
+        // This stops ACL doing an unnecessary permute
+        arm_compute::TensorShape acl_tensor_shape
+                = arm_compute::TensorShape(axis_size_, outer_size_);
+        asp_.axis = 0;
+
+        asp_.src_info = arm_compute::TensorInfo(
+                acl_tensor_shape, 1, acl_data_t, acl_layout);
+        asp_.dst_info = arm_compute::TensorInfo(
+                acl_tensor_shape, 1, acl_data_t, acl_layout);
+    } else {
+        // A rough empirical heuristic, see comment above
+        // The only difference here is that ACL does a reorder, and so
+        // is considerably better
+        double acl_ref_performance_diff
+                = calculate_performance_diff(-0.01 * inner_size_);
+        if (acl_ref_performance_diff > 0) return status::unimplemented;
+
+        // Irrespective of the input dimensions, we construct a tensor
+        // with dimensions such that softmax can be applied over the
+        // middle axis (1), with the correct stride and vector length.
+        arm_compute::TensorShape acl_tensor_shape = arm_compute::TensorShape(
+                inner_size_, axis_size_, outer_size_);
+        asp_.axis = 1;
+
+        asp_.src_info = arm_compute::TensorInfo(
+                acl_tensor_shape, 1, acl_data_t, acl_layout);
+        asp_.dst_info = arm_compute::TensorInfo(
+                acl_tensor_shape, 1, acl_data_t, acl_layout);
+    }
+
+    // Validate manually to check for return status
+    ACL_CHECK_VALID(arm_compute::experimental::op::CpuSoftmax::validate(
+            &asp_.src_info, &asp_.dst_info, asp_.beta, asp_.axis));
+
+    return status::success;
+}
+
+status_t acl_softmax_fwd_t::init(engine_t *engine) {
+    auto asp = pd()->asp_;
+
+    auto op = std::make_unique<arm_compute::experimental::op::CpuSoftmax>();
+
+    softmax_op_ = std::move(op);
+    // Configure the softmax operation; memory allocation happens here.
+    softmax_op_->configure(&asp.src_info, &asp.dst_info, asp.beta, asp.axis,
+            asp.is_logsoftmax);
+
+    return status::success;
+}
+
+status_t acl_softmax_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
+    auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
+    auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
+
+    auto asp = pd()->asp_;
+
+    arm_compute::Tensor src_tensor;
+    arm_compute::Tensor dst_tensor;
+
+    src_tensor.allocator()->init(asp.src_info);
+    src_tensor.allocator()->import_memory(const_cast<void *>(src));
+    dst_tensor.allocator()->init(asp.dst_info);
+    dst_tensor.allocator()->import_memory(dst);
+
+    arm_compute::ITensorPack run_pack {
+            {arm_compute::TensorType::ACL_SRC_0, &src_tensor},
+            {arm_compute::TensorType::ACL_DST, &dst_tensor}};
+
+    softmax_op_->run(run_pack);
+
+    return status::success;
+}
+
+} // namespace acl
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/cpu/acl/acl_softmax.hpp b/src/cpu/acl/acl_softmax.hpp
new file mode 100644
index 00000000000..470eea9a1a3
--- /dev/null
+++ b/src/cpu/acl/acl_softmax.hpp
@@ -0,0 +1,71 @@
+/*******************************************************************************
+* Copyright 2021-2024 Arm Ltd. and affiliates
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/ + +#ifndef CPU_ACL_SOFTMAX_HPP +#define CPU_ACL_SOFTMAX_HPP + +#include "cpu/cpu_softmax_pd.hpp" + +#include "cpu/acl/acl_utils.hpp" + +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/IOperator.h" +#include "arm_compute/runtime/experimental/operators/CpuSoftmax.h" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace acl { + +struct acl_softmax_conf_t { + arm_compute::TensorInfo src_info; + arm_compute::TensorInfo dst_info; + float beta; + int32_t axis; + bool is_logsoftmax; +}; + +struct acl_softmax_fwd_t : public primitive_t { + struct pd_t : public cpu_softmax_fwd_pd_t { + using cpu_softmax_fwd_pd_t::cpu_softmax_fwd_pd_t; + + DECLARE_COMMON_PD_T("acl", acl_softmax_fwd_t); + status_t init(engine_t *engine); + + acl_softmax_conf_t asp_; + }; // pd_t + + // constructor + acl_softmax_fwd_t(const pd_t *apd) : primitive_t(apd) {} + + status_t execute(const exec_ctx_t &ctx) const override { + return execute_forward(ctx); + } + +private: + const pd_t *pd() const; + + status_t init(engine_t *engine) override; + status_t execute_forward(const exec_ctx_t &ctx) const; + std::unique_ptr softmax_op_; +}; // acl_softmax_fwd_t + +} // namespace acl +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/acl/acl_thread.cpp b/src/cpu/acl/acl_thread.cpp new file mode 100644 index 00000000000..5ab2e428605 --- /dev/null +++ b/src/cpu/acl/acl_thread.cpp @@ -0,0 +1,125 @@ +/******************************************************************************* +* Copyright 2022-2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu/acl/acl_thread.hpp" +#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL +#include "cpu/acl/acl_threadpool_scheduler.hpp" +#endif +#include "cpu/acl/acl_benchmark_scheduler.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace acl { + +namespace acl_thread_utils { + +#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP +void acl_thread_bind() { + static std::once_flag flag_once; + // The threads in Compute Library are bound for the cores 0..max_threads-1 + // dnnl_get_max_threads() returns OMP_NUM_THREADS + const int max_threads = dnnl_get_max_threads(); + // arm_compute::Scheduler does not support concurrent access thus a + // workaround here restricts it to only one call + std::call_once(flag_once, [&]() { + arm_compute::Scheduler::get().set_num_threads(max_threads); + }); +} +// Swap BenchmarkScheduler for default ACL scheduler builds (i.e. 
CPPScheduler, OMPScheduler)
+void acl_set_benchmark_scheduler_default() {
+    static std::once_flag flag_once;
+    arm_compute::IScheduler *_real_scheduler = &arm_compute::Scheduler::get();
+    std::shared_ptr<arm_compute::IScheduler> benchmark_scheduler
+            = std::make_unique<BenchmarkScheduler>(*_real_scheduler);
+    // Set the benchmark scheduler in ACL
+    std::call_once(flag_once, [&]() {
+        arm_compute::Scheduler::set(
+                std::static_pointer_cast<arm_compute::IScheduler>(
+                        benchmark_scheduler));
+    });
+}
+#endif
+
+#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
+
+void acl_set_tp_scheduler() {
+    static thread_local std::once_flag flag_once;
+    // Create threadpool scheduler
+    std::call_once(flag_once, [&]() {
+        // Create threadpool scheduler
+        std::shared_ptr<arm_compute::IScheduler> threadpool_scheduler
+                = std::make_unique<ThreadpoolScheduler>();
+        arm_compute::Scheduler::set(threadpool_scheduler);
+    });
+}
+
+void acl_set_threadpool_num_threads() {
+    using namespace dnnl::impl::threadpool_utils;
+    static thread_local std::once_flag flag_once;
+    threadpool_interop::threadpool_iface *tp = get_active_threadpool();
+    // Check active threadpool
+    bool is_main = get_active_threadpool() == tp;
+    if (is_main) {
+        // Set num threads based on threadpool size
+        const int num_threads = (tp) ? dnnl_get_max_threads() : 1;
+        std::call_once(flag_once, [&]() {
+            arm_compute::Scheduler::get().set_num_threads(num_threads);
+        });
+    }
+}
+// Swap BenchmarkScheduler for custom scheduler builds (i.e. ThreadPoolScheduler)
+void acl_set_tp_benchmark_scheduler() {
+    static thread_local std::once_flag flag_once;
+    std::call_once(flag_once, [&]() {
+        // Create threadpool scheduler
+        std::unique_ptr<ThreadpoolScheduler> threadpool_scheduler
+                = std::make_unique<ThreadpoolScheduler>();
+        arm_compute::IScheduler *_real_scheduler = nullptr;
+        _real_scheduler = threadpool_scheduler.release();
+
+        // Create benchmark scheduler and set TP as real scheduler
+        std::shared_ptr<arm_compute::IScheduler> benchmark_scheduler
+                = std::make_unique<BenchmarkScheduler>(*_real_scheduler);
+
+        arm_compute::Scheduler::set(benchmark_scheduler);
+    });
+}
+#endif
+
+void set_acl_threading() {
+#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
+    acl_thread_bind();
+    if (get_verbose(verbose_t::profile_externals)) {
+        acl_set_benchmark_scheduler_default();
+    }
+#endif
+#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
+    if (get_verbose(verbose_t::profile_externals)) {
+        acl_set_tp_benchmark_scheduler();
+    } else {
+        acl_set_tp_scheduler();
+    }
+
+#endif
+}
+
+} // namespace acl_thread_utils
+
+} // namespace acl
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/cpu/aarch64/acl_thread.hpp b/src/cpu/acl/acl_thread.hpp
similarity index 92%
rename from src/cpu/aarch64/acl_thread.hpp
rename to src/cpu/acl/acl_thread.hpp
index f073376e63a..26b65564d79 100644
--- a/src/cpu/aarch64/acl_thread.hpp
+++ b/src/cpu/acl/acl_thread.hpp
@@ -14,8 +14,8 @@
 * limitations under the License.
*******************************************************************************/ -#ifndef CPU_AARCH64_ACL_THREAD_HPP -#define CPU_AARCH64_ACL_THREAD_HPP +#ifndef CPU_ACL_THREAD_HPP +#define CPU_ACL_THREAD_HPP #include "common/dnnl_thread.hpp" #include "common/verbose.hpp" @@ -25,7 +25,7 @@ namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { namespace acl_thread_utils { @@ -49,9 +49,9 @@ void acl_set_tp_benchmark_scheduler(); void set_acl_threading(); } // namespace acl_thread_utils -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl -#endif // CPU_AARCH64_ACL_THREAD_HPP +#endif // CPU_ACL_THREAD_HPP diff --git a/src/cpu/aarch64/acl_threadpool_scheduler.cpp b/src/cpu/acl/acl_threadpool_scheduler.cpp similarity index 84% rename from src/cpu/aarch64/acl_threadpool_scheduler.cpp rename to src/cpu/acl/acl_threadpool_scheduler.cpp index 30910398d9c..ae559c5ead9 100644 --- a/src/cpu/aarch64/acl_threadpool_scheduler.cpp +++ b/src/cpu/acl/acl_threadpool_scheduler.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022-2024 Arm Ltd. and affiliates +* Copyright 2022-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,33 +14,26 @@ * limitations under the License. *******************************************************************************/ -#include "cpu/aarch64/acl_threadpool_scheduler.hpp" +#include "cpu/acl/acl_threadpool_scheduler.hpp" #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL -#include "cpu/aarch64/acl_thread.hpp" - #include "common/counting_barrier.hpp" #include "common/dnnl_thread.hpp" +#include "cpu/acl/acl_thread.hpp" #include "arm_compute/core/CPP/ICPPKernel.h" #include "arm_compute/core/Error.h" -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/Utils.h" #include "arm_compute/runtime/IScheduler.h" -// BARRIER #include #include -#include #include -#include -#include namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { using namespace arm_compute; @@ -51,7 +44,7 @@ class ThreadFeeder { /// Function to check the next element in the range if there is one. bool get_next(unsigned int &next) { - next = atomic_fetch_add_explicit( + next = std::atomic_fetch_add_explicit( &_atomic_counter, 1u, std::memory_order_relaxed); return next < _end; } @@ -70,11 +63,8 @@ void process_workloads(std::vector &workloads, } while (feeder.get_next(workload_index)); } -ThreadpoolScheduler::ThreadpoolScheduler() { - using namespace dnnl::impl::threadpool_utils; - // Set number of threads to one when threadpool is not available. - _num_threads = get_active_threadpool() == nullptr ? 1 : num_threads_hint(); -} +ThreadpoolScheduler::ThreadpoolScheduler() + : _num_threads(dnnl_get_max_threads()) {} ThreadpoolScheduler::~ThreadpoolScheduler() = default; @@ -83,8 +73,8 @@ unsigned int ThreadpoolScheduler::num_threads() const { } void ThreadpoolScheduler::set_num_threads(unsigned int num_threads) { - arm_compute::lock_guard lock(this->_run_workloads_mutex); - _num_threads = num_threads == 0 ? num_threads_hint() : num_threads; + std::lock_guard lock(this->_mtx); + _num_threads = num_threads == 0 ? 
dnnl_get_max_threads() : num_threads; } void ThreadpoolScheduler::schedule(ICPPKernel *kernel, const Hints &hints) { @@ -104,7 +94,7 @@ void ThreadpoolScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, void ThreadpoolScheduler::run_workloads( std::vector &workloads) { - arm_compute::lock_guard lock(this->_run_workloads_mutex); + std::lock_guard lock(this->_mtx); const unsigned int num_threads = std::min(static_cast(_num_threads), @@ -145,7 +135,7 @@ void ThreadpoolScheduler::run_workloads( if (is_async) b.wait(); } -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/acl_threadpool_scheduler.hpp b/src/cpu/acl/acl_threadpool_scheduler.hpp similarity index 84% rename from src/cpu/aarch64/acl_threadpool_scheduler.hpp rename to src/cpu/acl/acl_threadpool_scheduler.hpp index e9ba21c8032..6370141010e 100644 --- a/src/cpu/aarch64/acl_threadpool_scheduler.hpp +++ b/src/cpu/acl/acl_threadpool_scheduler.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022 Arm Ltd. and affiliates +* Copyright 2022, 2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,25 +14,26 @@ * limitations under the License. *******************************************************************************/ -#ifndef CPU_AARCH64_ACL_THREADPOOL_SCHEDULER_HPP -#define CPU_AARCH64_ACL_THREADPOOL_SCHEDULER_HPP +#ifndef CPU_ACL_THREADPOOL_SCHEDULER_HPP +#define CPU_ACL_THREADPOOL_SCHEDULER_HPP #include "oneapi/dnnl/dnnl_config.h" #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL #include "arm_compute/runtime/IScheduler.h" -#include "support/Mutex.h" + +#include namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { class ThreadpoolScheduler final : public arm_compute::IScheduler { public: ThreadpoolScheduler(); - ~ThreadpoolScheduler(); + ~ThreadpoolScheduler() override; /// Sets the number of threads the scheduler will use to run the kernels. void set_num_threads(unsigned int num_threads) override; @@ -54,15 +55,15 @@ class ThreadpoolScheduler final : public arm_compute::IScheduler { void run_workloads(std::vector &workloads) override; private: - uint _num_threads {}; - arm_compute::Mutex _run_workloads_mutex {}; + unsigned int _num_threads {}; + std::mutex _mtx; }; -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl -#endif // CPU_AARCH64_ACL_THREADPOOL_SCHEDULER_HPP +#endif // CPU_ACL_THREADPOOL_SCHEDULER_HPP #endif // DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL diff --git a/src/cpu/aarch64/acl_utils.cpp b/src/cpu/acl/acl_utils.cpp similarity index 99% rename from src/cpu/aarch64/acl_utils.cpp rename to src/cpu/acl/acl_utils.cpp index eaf415df01f..a5d7b8a6048 100644 --- a/src/cpu/aarch64/acl_utils.cpp +++ b/src/cpu/acl/acl_utils.cpp @@ -14,12 +14,12 @@ * limitations under the License. 
*******************************************************************************/
-#include "cpu/aarch64/acl_utils.hpp"
+#include "cpu/acl/acl_utils.hpp"
 namespace dnnl {
 namespace impl {
 namespace cpu {
-namespace aarch64 {
+namespace acl {
 namespace acl_utils {
@@ -345,7 +345,7 @@ void reorder_to_weight_format(arm_compute::TensorInfo &info, memory_desc_t &md,
 } // namespace acl_utils
-} // namespace aarch64
+} // namespace acl
 } // namespace cpu
 } // namespace impl
 } // namespace dnnl
diff --git a/src/cpu/aarch64/acl_utils.hpp b/src/cpu/acl/acl_utils.hpp
similarity index 97%
rename from src/cpu/aarch64/acl_utils.hpp
rename to src/cpu/acl/acl_utils.hpp
index f76a78b9ff1..939d0001f4d 100644
--- a/src/cpu/aarch64/acl_utils.hpp
+++ b/src/cpu/acl/acl_utils.hpp
@@ -14,8 +14,8 @@
 * limitations under the License.
 *******************************************************************************/
-#ifndef CPU_AARCH64_ACL_UTILS_HPP
-#define CPU_AARCH64_ACL_UTILS_HPP
+#ifndef CPU_ACL_UTILS_HPP
+#define CPU_ACL_UTILS_HPP
 #include
@@ -33,7 +33,7 @@ namespace dnnl {
 namespace impl {
 namespace cpu {
-namespace aarch64 {
+namespace acl {
 namespace acl_utils {
@@ -124,9 +124,9 @@ void reorder_to_weight_format(arm_compute::TensorInfo &info, memory_desc_t &md,
 } // namespace acl_utils
-} // namespace aarch64
+} // namespace acl
 } // namespace cpu
 } // namespace impl
 } // namespace dnnl
-#endif // CPU_AARCH64_ACL_UTILS_HPP
+#endif // CPU_ACL_UTILS_HPP
diff --git a/src/cpu/acl/acl_winograd_convolution.cpp b/src/cpu/acl/acl_winograd_convolution.cpp
new file mode 100644
index 00000000000..eb2e0bd9883
--- /dev/null
+++ b/src/cpu/acl/acl_winograd_convolution.cpp
@@ -0,0 +1,44 @@
+/*******************************************************************************
+* Copyright 2020-2023, 2025 Arm Ltd. and affiliates
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "cpu/acl/acl_winograd_convolution.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace acl {
+using data_t = prec_traits_t<data_type::f32>::type;
+
+status_t acl_wino_convolution_fwd_t::execute_forward(
+        const exec_ctx_t &ctx) const {
+    // Lock here is needed because resource_mapper does not support
+    // concurrent multithreaded access.
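+    // A scoped std::lock_guard suffices here: it holds mtx for the full
+    // body of execute_forward() and releases it on return, serialising
+    // concurrent executions of this primitive (mtx is declared 'mutable'
+    // in the header so it can be locked from this const member function).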
+    std::lock_guard<std::mutex> _lock {this->mtx};
+    // Retrieve primitive resource and configured Compute Library objects
+    auto *acl_resource
+            = ctx.get_resource_mapper()->get<acl_wino_resource_t>(this);
+    acl_obj_t<arm_compute::NEWinogradConvolutionLayer> &acl_wino_obj
+            = acl_resource->get_acl_obj();
+
+    return execute_forward_conv_acl<
+            acl_obj_t<arm_compute::NEWinogradConvolutionLayer>, pd_t, data_t>(
+            ctx, acl_wino_obj, pd());
+}
+
+} // namespace acl
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/cpu/acl/acl_winograd_convolution.hpp b/src/cpu/acl/acl_winograd_convolution.hpp
new file mode 100644
index 00000000000..9c29ea376a3
--- /dev/null
+++ b/src/cpu/acl/acl_winograd_convolution.hpp
@@ -0,0 +1,152 @@
+/*******************************************************************************
+* Copyright 2020-2025 Arm Ltd. and affiliates
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_ACL_WINOGRAD_CONVOLUTION_HPP
+#define CPU_ACL_WINOGRAD_CONVOLUTION_HPP
+
+#include "cpu/cpu_convolution_pd.hpp"
+
+#include "cpu/acl/acl_convolution_utils.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace acl {
+
+struct acl_wino_resource_t : public resource_t {
+    acl_wino_resource_t()
+        : acl_wino_obj_(utils::make_unique<
+                acl_obj_t<arm_compute::NEWinogradConvolutionLayer>>()) {}
+
+    status_t configure(const acl_conv_conf_t &acp) {
+        if (!acl_wino_obj_) return status::out_of_memory;
+
+        // Init Compute Library tensors based on info from descriptor
+        acl_wino_obj_->src_tensor.allocator()->init(acp.src_tensor_info);
+        acl_wino_obj_->wei_tensor.allocator()->init(acp.wei_tensor_info);
+        acl_wino_obj_->dst_tensor.allocator()->init(acp.dst_tensor_info);
+        acl_wino_obj_->bia_tensor.allocator()->init(acp.bia_tensor_info);
+
+        // clang-format off
+        acl_wino_obj_->conv.configure(
+            &acl_wino_obj_->src_tensor,
+            &acl_wino_obj_->wei_tensor,
+            acp.with_bias ?
&acl_wino_obj_->bia_tensor : nullptr,
+            &acl_wino_obj_->dst_tensor,
+            acp.padstride_info,
+            acp.act_info,
+            true); // to support 5x5, 7x7 filter shapes in addition to 3x3
+        // clang-format on
+
+        return status::success;
+    }
+
+    acl_obj_t<arm_compute::NEWinogradConvolutionLayer> &get_acl_obj() const {
+        return *acl_wino_obj_;
+    }
+
+    DNNL_DISALLOW_COPY_AND_ASSIGN(acl_wino_resource_t);
+
+private:
+    std::unique_ptr<acl_obj_t<arm_compute::NEWinogradConvolutionLayer>>
+            acl_wino_obj_;
+}; // acl_wino_resource_t
+
+struct acl_wino_convolution_fwd_t : public primitive_t {
+    struct pd_t : public cpu_convolution_fwd_pd_t {
+        using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t;
+
+        DECLARE_COMMON_PD_T(
+                "wino:acl", acl_wino_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD);
+
+        status_t init(engine_t *engine) {
+            using namespace data_type;
+            const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef)
+                    && attr()->has_default_values(
+                            primitive_attr_t::skip_mask_t::post_ops, f16);
+            const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef)
+                    && attr()->has_default_values(
+                            primitive_attr_t::skip_mask_t::post_ops, f32);
+            bool ok = is_fwd()
+                    && utils::one_of(desc()->alg_kind,
+                            alg_kind::convolution_auto,
+                            alg_kind::convolution_winograd)
+                    && utils::one_of(true, is_fp16_ok, is_fp32_ok)
+                    && !has_zero_dim_memory();
+
+            ok = ok && DNNL_CPU_THREADING_RUNTIME != DNNL_RUNTIME_THREADPOOL;
+            if (!ok) return status::unimplemented;
+
+            CHECK(acl_convolution_utils::init_conf_wino(acp_, src_md_,
+                    weights_md_, dst_md_, bias_md_, *desc(), *attr()));
+
+            set_default_alg_kind(alg_kind::convolution_winograd);
+
+            CHECK(post_ops.init(
+                    engine, attr_.post_ops_, dst_md_, acp_.act_info));
+            acp_.use_dst_acc_for_sum = post_ops.has_sum();
+
+            if (acp_.use_dst_acc_for_sum) {
+                const memory_desc_wrapper dst_d(&dst_md_);
+                auto scratchpad = scratchpad_registry().registrar();
+                scratchpad.book(memory_tracking::names::key_generic_acc,
+                        dst_d.nelems(), dst_d.data_type_size());
+            }
+
+            return status::success;
+        }
+
+        acl_conv_conf_t acp_ = utils::zero<acl_conv_conf_t>();
+        acl_post_ops_t post_ops;
+    };
+
+    acl_wino_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}
+
+    status_t create_resource(
+            engine_t *engine, resource_mapper_t &mapper) const override {
+        if (mapper.has_resource(this)) return status::success;
+
+        auto r = utils::make_unique<acl_wino_resource_t>();
+        if (!r) return status::out_of_memory;
+
+        // Configure the resource based on information from primitive descriptor
+        CHECK(r->configure(pd()->acp_));
+        mapper.add(this, std::move(r));
+
+        return status::success;
+    }
+
+    ~acl_wino_convolution_fwd_t() override = default;
+
+    using data_t = typename prec_traits_t<data_type::f32>::type;
+
+    status_t execute(const exec_ctx_t &ctx) const override {
+        return execute_forward(ctx);
+    }
+
+private:
+    // To guard the const execute_forward(), the mutex must be 'mutable'
+    mutable std::mutex mtx;
+    status_t execute_forward(const exec_ctx_t &ctx) const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+}; // acl_wino_convolution_fwd_t
+
+} // namespace acl
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
+#endif // CPU_ACL_WINOGRAD_CONVOLUTION_HPP
diff --git a/src/cpu/aarch64/matmul/acl_lowp_matmul.cpp b/src/cpu/acl/matmul/acl_lowp_matmul.cpp
similarity index 75%
rename from src/cpu/aarch64/matmul/acl_lowp_matmul.cpp
rename to src/cpu/acl/matmul/acl_lowp_matmul.cpp
index 076d5fd321a..925431cea0c 100644
--- a/src/cpu/aarch64/matmul/acl_lowp_matmul.cpp
+++ b/src/cpu/acl/matmul/acl_lowp_matmul.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-*
Copyright 2024 Arm Ltd. and affiliates
+* Copyright 2024-2025 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,12 +14,14 @@
 * limitations under the License.
 *******************************************************************************/
 
-#include "cpu/aarch64/matmul/acl_lowp_matmul.hpp"
+#include "cpu/acl/matmul/acl_lowp_matmul.hpp"
+
+#include "src/cpu/CpuTypes.h"
 
 namespace dnnl {
 namespace impl {
 namespace cpu {
-namespace aarch64 {
+namespace acl {
 namespace matmul {
 
 status_t acl_lowp_matmul_resource_t::configure(
@@ -61,18 +63,25 @@ status_t acl_lowp_matmul_t::pd_t::init(engine_t *engine) {
     VDISPATCH_MATMUL(set_default_formats(), "failed to set default formats");
     using smask_t = primitive_attr_t::skip_mask_t;
-    VDISPATCH_MATMUL(
-            attr()->has_default_values(smask_t::scales_runtime
-                    | smask_t::zero_points_runtime | smask_t::post_ops),
+    VDISPATCH_MATMUL(attr()->has_default_values(smask_t::scales
+                             | smask_t::zero_points | smask_t::post_ops),
             "only scale, zero point and post-ops attrs supported");
-    VDISPATCH_MATMUL(attr()->scales_.get(DNNL_ARG_SRC).mask_ == 0
-                    && attr()->zero_points_.get(DNNL_ARG_SRC) == 0
-                    && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0
-                    && attr()->zero_points_.get(DNNL_ARG_WEIGHTS) == 0
-                    && attr()->scales_.get(DNNL_ARG_DST).mask_ == 0
-                    && attr()->zero_points_.get(DNNL_ARG_DST) == 0,
-            "common scales and zero points only");
+    static const std::vector<int> supported_args {
+            DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
+    for (int arg : supported_args) {
+        if (attr()->scales_.has_default_values(arg)) continue;
+
+        VDISPATCH_MATMUL(attr()->scales_.get_mask(arg) == 0,
+                VERBOSE_UNSUPPORTED_SCALES_CFG);
+    }
+
+    for (int arg : supported_args) {
+        if (attr()->zero_points_.has_default_values(arg)) continue;
+
+        VDISPATCH_MATMUL(attr()->zero_points_.get_mask(arg) == 0,
+                VERBOSE_UNSUPPORTED_SCALES_CFG);
+    }
     VDISPATCH_MATMUL(
             !has_runtime_dims_or_strides(), VERBOSE_RUNTIMEDIM_UNSUPPORTED);
@@ -82,6 +91,14 @@ status_t acl_lowp_matmul_t::pd_t::init(engine_t *engine) {
     const memory_desc_wrapper bia_d(bias_md_);
     const memory_desc_wrapper dst_d(dst_md_);
 
+    cpu::matmul::matmul_helper_t helper(src_d, wei_d, dst_d);
+    const dim_t M = helper.M();
+    const dim_t N = helper.N();
+    const dim_t K = helper.K();
+    const dim_t dst_batch = helper.batch();
+    const dim_t src_batch = helper.src_batch();
+    const dim_t wei_batch = helper.wei_batch();
+
     using namespace data_type;
 
     // Note that has_default_values checks the argument for default zero
@@ -100,39 +117,66 @@ status_t acl_lowp_matmul_t::pd_t::init(engine_t *engine) {
             VERBOSE_UNSUPPORTED_DT_CFG);
     almc_.dst_is_s8 = dst_d.data_type() == s8;
-    VDISPATCH_MATMUL(src_d.matches_tag(format_tag::ab)
-                    && wei_d.matches_tag(format_tag::ab)
-                    && dst_d.matches_tag(format_tag::ab),
-            VERBOSE_UNSUPPORTED_TAG);
+    // Reject the op when it would run on a CPU that lacks the i8mm
+    // instruction set and the destination data type is s8. This is a
+    // temporary fix until the underlying issue is resolved.
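A minimal standalone sketch of what this dispatch amounts to (illustrative only; it reuses the ACL feature query and the oneDNN names already visible in this hunk):

    // FEAT_I8MM is required by the s8-destination kernels; without it this
    // implementation returns unimplemented so a later entry in the
    // implementation list can pick the problem up.
    const bool cpu_has_i8mm = arm_compute::CPUInfo::get().has_i8mm();
    if (!cpu_has_i8mm && dst_d.data_type() == data_type::s8)
        return status::unimplemented;

The actual dispatch line follows.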
+ VDISPATCH_MATMUL( + arm_compute::CPUInfo::get().has_i8mm() || dst_d.data_type() != s8, + "Op not supported on CPUs without i8mm instructions when dest " + "datatype is s8"); + + using namespace format_tag; + auto src_tag = memory_desc_matches_one_of_tag(src_md_, abcd, abc, ab); + auto wei_tag = memory_desc_matches_one_of_tag(weights_md_, abcd, abc, ab); + auto dst_tag = memory_desc_matches_one_of_tag(dst_md_, abcd, abc, ab); - VDISPATCH_MATMUL_SC( - memory_desc_init_by_tag(bias_md_, bias_md_.ndims, bias_md_.dims, - bias_md_.data_type, format_tag::ab), + ACL_CHECK_SUPPORT( + utils::one_of(format_tag::undef, src_tag, wei_tag, dst_tag), + "Format tag is undefined"); + + VDISPATCH_MATMUL_SC(memory_desc_init_by_tag(bias_md_, bias_md_.ndims, + bias_md_.dims, bias_md_.data_type, dst_tag), VERBOSE_UNSUPPORTED_BIAS_CFG); // We set the QuantizationInfo to be dynamic because it is re-set in run() - almc_.src_tensor_info - = arm_compute::TensorInfo(arm_compute::TensorShape(K(), M()), 1, - arm_compute::DataType::QASYMM8_SIGNED, - arm_compute::QuantizationInfo(1.0, 0, true)); + almc_.src_tensor_info = arm_compute::TensorInfo( + arm_compute::TensorShape(K, M, 1, src_batch), 1, + arm_compute::DataType::QASYMM8_SIGNED, + arm_compute::QuantizationInfo(1.0, 0, true)); almc_.src_tensor_info.set_are_values_constant(false); almc_.wei_tensor_info - = arm_compute::TensorInfo(arm_compute::TensorShape(N(), K()), 1, - arm_compute::DataType::QASYMM8_SIGNED, + = arm_compute::TensorInfo(arm_compute::TensorShape(N, K, wei_batch), + 1, arm_compute::DataType::QASYMM8_SIGNED, arm_compute::QuantizationInfo(1.0, 0, true)); almc_.wei_tensor_info.set_are_values_constant(false); almc_.bia_tensor_info = arm_compute::TensorInfo( arm_compute::TensorShape(), 1, arm_compute::DataType::F32); almc_.with_bias = bia_d.format_kind() != format_kind::undef; + if (almc_.with_bias) { - // This is not currently guarded in ACL - VDISPATCH_MATMUL(bia_d.ndims() == 2 && bia_d.dims()[0] == 1 - && bia_d.dims()[1] == N(), - "Only 1xN bias is supported"); - almc_.bia_tensor_info.set_tensor_shape( - arm_compute::TensorShape(bia_d.dims()[1], bia_d.dims()[0])); + switch (bia_d.ndims()) { + case 2: + VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == N, + "Only 1xN bias is supported for 2D input"); + almc_.bia_tensor_info.set_tensor_shape( + arm_compute::TensorShape(bia_d.dims()[1], 1)); + break; + case 3: + VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == 1 + && bia_d.dims()[2] == N, + "Only 1x1xN bias is supported for 3D input"); + almc_.bia_tensor_info.set_tensor_shape( + arm_compute::TensorShape(bia_d.dims()[2], 1, 1)); + break; + case 4: + VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == 1 + && bia_d.dims()[2] == 1 && bia_d.dims()[3] == N, + "Only 1x1x1xN bias is supported for 4D input"); + almc_.bia_tensor_info.set_tensor_shape( + arm_compute::TensorShape(bia_d.dims()[3], 1, 1, 1)); + break; + } } // We can fuse sum if it is the first post op @@ -166,14 +210,15 @@ status_t acl_lowp_matmul_t::pd_t::init(engine_t *engine) { almc_.gemm_info.accumulate() ? 
1 : 0)); almc_.dst_tensor_info = arm_compute::TensorInfo( - arm_compute::TensorShape(N(), M()), arm_compute::Format::F32); + arm_compute::TensorShape(N, M, 1, dst_batch), + arm_compute::Format::F32); almc_.dst_cast_tensor_info = almc_.dst_tensor_info; - almc_.dst_s8_tensor_info - = arm_compute::TensorInfo(arm_compute::TensorShape(N(), M()), 1, - arm_compute::DataType::QASYMM8_SIGNED, - arm_compute::QuantizationInfo(1.0, 0, true)); + almc_.dst_s8_tensor_info = arm_compute::TensorInfo( + arm_compute::TensorShape(N, M, 1, dst_batch), 1, + arm_compute::DataType::QASYMM8_SIGNED, + arm_compute::QuantizationInfo(1.0, 0, true)); ACL_CHECK_VALID(arm_compute::NEGEMMLowpMatrixMultiplyCore::validate( &almc_.src_tensor_info, &almc_.wei_tensor_info, @@ -203,11 +248,11 @@ status_t acl_lowp_matmul_t::pd_t::init_scratchpad( const memory_desc_wrapper dst_d(&dst_md_); if (almc_.use_dst_acc) { scratchpad.book(memory_tracking::names::key_matmul_dst_in_acc_dt, - dst_d.nelems(), sizeof(float32_t)); + dst_d.nelems(), sizeof(arm_compute::float32_t)); } if (almc_.use_cast_acc) { scratchpad.book(memory_tracking::names::key_matmul_dst_cast_acc, - dst_d.nelems(), sizeof(float32_t)); + dst_d.nelems(), sizeof(arm_compute::float32_t)); } return status::success; } @@ -326,7 +371,7 @@ status_t acl_lowp_matmul_t::execute(const exec_ctx_t &ctx) const { }; } // namespace matmul -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/matmul/acl_lowp_matmul.hpp b/src/cpu/acl/matmul/acl_lowp_matmul.hpp similarity index 89% rename from src/cpu/aarch64/matmul/acl_lowp_matmul.hpp rename to src/cpu/acl/matmul/acl_lowp_matmul.hpp index 30502aea1cc..3aaee9a70df 100644 --- a/src/cpu/aarch64/matmul/acl_lowp_matmul.hpp +++ b/src/cpu/acl/matmul/acl_lowp_matmul.hpp @@ -21,16 +21,18 @@ #include "cpu/matmul/cpu_matmul_pd.hpp" #include "cpu/matmul/matmul_utils.hpp" +#include "arm_compute/core/CPP/CPPTypes.h" #include "arm_compute/runtime/NEON/functions/NEDequantizationLayer.h" #include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h" #include "arm_compute/runtime/NEON/functions/NEQuantizationLayer.h" -#include "cpu/aarch64/acl_post_ops.hpp" -#include "cpu/aarch64/acl_utils.hpp" + +#include "cpu/acl/acl_post_ops.hpp" +#include "cpu/acl/acl_utils.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { namespace matmul { struct acl_lowp_matmul_obj_t { @@ -76,11 +78,6 @@ struct acl_lowp_matmul_resource_t : public resource_t { struct acl_lowp_matmul_t : public primitive_t { struct pd_t : public dnnl::impl::cpu::matmul::cpu_matmul_pd_t { - - pd_t(const matmul_desc_t *adesc, const primitive_attr_t *attr, - const cpu_matmul_pd_t *hint_fwd_pd) - : cpu_matmul_pd_t(adesc, attr, hint_fwd_pd), almc_() {} - using cpu_matmul_pd_t::cpu_matmul_pd_t; DECLARE_COMMON_PD_T( @@ -90,7 +87,7 @@ struct acl_lowp_matmul_t : public primitive_t { status_t init_scratchpad(memory_tracking::registrar_t &scratchpad); - acl_lowp_matmul_conf_t almc_; + acl_lowp_matmul_conf_t almc_ = utils::zero(); acl_post_ops_t acl_post_ops; }; @@ -106,9 +103,9 @@ struct acl_lowp_matmul_t : public primitive_t { }; } // namespace matmul -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl -#endif // CPU_AARCH64_ACL_LOWP_MATMUL_HPP \ No newline at end of file +#endif // ACL_LOWP_MATMUL_HPP diff --git a/src/cpu/acl/matmul/acl_lowp_matmul_sq.cpp b/src/cpu/acl/matmul/acl_lowp_matmul_sq.cpp new file mode 100644 index 
00000000000..e4ac7dbb485 --- /dev/null +++ b/src/cpu/acl/matmul/acl_lowp_matmul_sq.cpp @@ -0,0 +1,273 @@ +/******************************************************************************* +* Copyright 2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu/acl/matmul/acl_lowp_matmul_sq.hpp" + +#include "arm_compute/core/utils/quantization/AsymmHelpers.h" +#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h" +#include "arm_compute/runtime/NEON/functions/NEQuantizationLayer.h" + +#include "cpu/acl/acl_utils.hpp" +#include "src/cpu/CpuTypes.h" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace acl { +namespace matmul { +status_t acl_lowp_matmul_sq_resource_t::configure( + const acl_lowp_matmul_sq_conf_t &almc) { + if (!acl_obj_) return status::out_of_memory; + acl_obj_->src_tensor.allocator()->init(almc.src_tensor_info); + acl_obj_->wei_tensor.allocator()->init(almc.wei_tensor_info); + if (almc.with_bias) { + acl_obj_->bia_tensor.allocator()->init(almc.bia_tensor_info); + } + acl_obj_->dst_tensor.allocator()->init(almc.dst_tensor_info); + arm_compute::QuantizationInfo qi {1.0, 0, true}; + acl_obj_->src_tensor.info()->set_quantization_info(qi); + acl_obj_->wei_tensor.info()->set_quantization_info(qi); + acl_obj_->dst_tensor.info()->set_quantization_info(qi); + acl_obj_->gemm.configure(&acl_obj_->src_tensor, &acl_obj_->wei_tensor, + almc.with_bias ? 
&acl_obj_->bia_tensor : nullptr,
+            &acl_obj_->dst_tensor, almc.gemm_info);
+    return status::success;
+}
+status_t acl_lowp_matmul_sq_t::pd_t::init(engine_t *engine) {
+    VDISPATCH_MATMUL(set_default_formats(), "failed to set default formats");
+    using smask_t = primitive_attr_t::skip_mask_t;
+    VDISPATCH_MATMUL(attr()->has_default_values(smask_t::scales
+                             | smask_t::zero_points | smask_t::post_ops),
+            "only scale, zero point and post-ops attrs supported");
+
+    static const std::vector<int> supported_args {
+            DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
+    for (int arg : supported_args) {
+        if (attr()->scales_.has_default_values(arg)) continue;
+
+        VDISPATCH_MATMUL(attr()->scales_.get_mask(arg) == 0,
+                VERBOSE_UNSUPPORTED_SCALES_CFG);
+    }
+
+    for (int arg : supported_args) {
+        if (attr()->zero_points_.has_default_values(arg)) continue;
+
+        VDISPATCH_MATMUL(attr()->zero_points_.get_mask(arg) == 0,
+                VERBOSE_UNSUPPORTED_SCALES_CFG);
+    }
+
+    VDISPATCH_MATMUL(
+            !has_runtime_dims_or_strides(), VERBOSE_RUNTIMEDIM_UNSUPPORTED);
+    const memory_desc_wrapper src_d(src_md_);
+    const memory_desc_wrapper wei_d(weights_md_);
+    const memory_desc_wrapper bia_d(bias_md_);
+    const memory_desc_wrapper dst_d(dst_md_);
+
+    cpu::matmul::matmul_helper_t helper(src_d, wei_d, dst_d);
+    const dim_t M = helper.M();
+    const dim_t N = helper.N();
+    const dim_t K = helper.K();
+    const dim_t dst_batch = helper.batch();
+    const dim_t src_batch = helper.src_batch();
+    const dim_t wei_batch = helper.wei_batch();
+
+    using namespace data_type;
+    VDISPATCH_MATMUL(utils::one_of(src_d.data_type(), s8, u8)
+                    && wei_d.data_type() == s8
+                    && (src_d.data_type() == s8 ? dst_d.data_type() == s8
+                                                : dst_d.data_type() == u8),
+            VERBOSE_UNSUPPORTED_DT_CFG);
+    VDISPATCH_MATMUL(utils::one_of(bia_d.data_type(), f32, undef),
+            VERBOSE_UNSUPPORTED_DT_CFG);
+
+    // Reject the op when it would run on a CPU that lacks the i8mm
+    // instruction set. This is a temporary fix until the underlying issue
+    // is resolved.
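Unlike the guard in acl_lowp_matmul.cpp, the check below is unconditional: without FEAT_I8MM this implementation is skipped entirely. A hedged sketch of how a caller could confirm which implementation the dispatcher settled on, using the public oneDNN C++ API (the descriptor names here are placeholders):

    dnnl::matmul::primitive_desc pd(eng, src_md, wei_md, dst_md, attr);
    // Prints e.g. "lowp_gemm_sq:acl" when this file's kernel is selected.
    std::cout << pd.impl_info_str() << std::endl;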
+ VDISPATCH_MATMUL(arm_compute::CPUInfo::get().has_i8mm(), + "Op not supported on CPUs without i8mm instructions"); + + // ACL batch dimension only support s32 for 3D and 4D + VDISPATCH_MATMUL( + wei_batch == 1, "Batch dimension must be 1 for the weights"); + + using namespace format_tag; + auto src_tag = memory_desc_matches_one_of_tag(src_md_, abcd, abc, ab); + auto wei_tag = memory_desc_matches_one_of_tag(weights_md_, abcd, abc, ab); + auto dst_tag = memory_desc_matches_one_of_tag(dst_md_, abcd, abc, ab); + + ACL_CHECK_SUPPORT( + utils::one_of(format_tag::undef, src_tag, wei_tag, dst_tag), + "Format tag is undefined"); + + VDISPATCH_MATMUL_SC(memory_desc_init_by_tag(bias_md_, bias_md_.ndims, + bias_md_.dims, bias_md_.data_type, dst_tag), + VERBOSE_UNSUPPORTED_BIAS_CFG); + + almc_.bia_tensor_info = arm_compute::TensorInfo( + arm_compute::TensorShape(), 1, arm_compute::DataType::S32); + almc_.with_bias = bia_d.format_kind() != format_kind::undef; + + almc_.src_tensor_info = arm_compute::TensorInfo( + arm_compute::TensorShape(K, M, 1, src_batch), 1, + acl_utils::get_acl_data_t(src_d.data_type(), true), + arm_compute::QuantizationInfo(1.0, 0, true)); + almc_.src_tensor_info.set_are_values_constant(false); + + almc_.wei_tensor_info = arm_compute::TensorInfo( + arm_compute::TensorShape(N, K, 1, wei_batch), 1, + acl_utils::get_acl_data_t(wei_d.data_type(), true), + arm_compute::QuantizationInfo(1.0, 0, true)); + almc_.wei_tensor_info.set_are_values_constant(false); + almc_.dst_tensor_info = arm_compute::TensorInfo( + arm_compute::TensorShape(N, M, 1, dst_batch), 1, + acl_utils::get_acl_data_t(dst_d.data_type(), true), + arm_compute::QuantizationInfo(1.0, 0, true)); + + almc_.bia_tensor_info = arm_compute::TensorInfo( + arm_compute::TensorShape(), 1, arm_compute::DataType::S32); + almc_.with_bias = bia_d.format_kind() != format_kind::undef; + + if (almc_.with_bias) { + switch (bia_d.ndims()) { + case 2: + VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == N, + "Only 1xN bias is supported for 2D input"); + almc_.bia_tensor_info.set_tensor_shape(arm_compute::TensorShape( + bia_d.dims()[1], bia_d.dims()[0])); + break; + case 3: + VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == 1 + && bia_d.dims()[2] == N, + "Only 1x1xN bias is supported for 3D input"); + almc_.bia_tensor_info.set_tensor_shape( + arm_compute::TensorShape(bia_d.dims()[2], 1, 1)); + break; + case 4: + VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == 1 + && bia_d.dims()[2] == 1 && bia_d.dims()[3] == N, + "Only 1x1x1xN bias is supported for 4D input"); + almc_.bia_tensor_info.set_tensor_shape( + arm_compute::TensorShape(bia_d.dims()[3], 1, 1, 1)); + break; + } + } + + arm_compute::GEMMLowpOutputStageInfo info; + info.type = arm_compute::GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT; + info.gemmlowp_multiplier = 1073741824; + info.gemmlowp_shift = -1; + info.gemmlowp_offset = 0; + info.gemmlowp_min_bound = -128; + info.gemmlowp_max_bound = 127; + info.output_data_type = almc_.dst_tensor_info.data_type(); + almc_.gemm_info.set_gemmlowp_output_stage(info); + auto scratchpad = scratchpad_registry().registrar(); + const dnnl::impl::memory_desc_t dst_md_ {desc_.dst_desc}; + arm_compute::ActivationLayerInfo act_info; + + CHECK(init_scratchpad(engine, scratchpad, acl_post_ops, attr_.post_ops_, + act_info, dst_md_)); + almc_.gemm_info.set_activation_info(act_info); + + ACL_CHECK_VALID(arm_compute::NEGEMMLowpMatrixMultiplyCore::validate( + &almc_.src_tensor_info, &almc_.wei_tensor_info, + almc_.with_bias ? 
&almc_.bia_tensor_info : nullptr,
+            &almc_.dst_tensor_info, almc_.gemm_info));
+    return status::success;
+}
+
+status_t acl_lowp_matmul_sq_t::pd_t::init_scratchpad(engine_t *engine,
+        memory_tracking::registrar_t &scratchpad, acl_post_ops_t &post_ops,
+        dnnl::impl::post_ops_t &attr_post_ops,
+        arm_compute::ActivationLayerInfo &act_info,
+        const dnnl::impl::memory_desc_t &dst_md) {
+    CHECK(post_ops.init(engine, attr_post_ops, dst_md, act_info));
+    // ACL only accepts s32 bias for quantization and since
+    // the current bias vector is f32 we need to convert.
+    if (almc_.with_bias) {
+        const memory_desc_wrapper bias_d(&bias_md_);
+        scratchpad.book(memory_tracking::names::key_conv_bias_s32_convert,
+                bias_d.nelems(), bias_d.data_type_size());
+    }
+    return status::success;
+}
+status_t acl_lowp_matmul_sq_t::create_resource(
+        engine_t *engine, resource_mapper_t &mapper) const {
+    if (mapper.has_resource(this)) return status::success;
+    auto r = utils::make_unique<acl_lowp_matmul_sq_resource_t>();
+    if (!r) return status::out_of_memory;
+    CHECK(r->configure(pd()->almc_));
+    mapper.add(this, std::move(r));
+    return status::success;
+}
+status_t acl_lowp_matmul_sq_t::execute(const exec_ctx_t &ctx) const {
+    std::lock_guard<std::mutex> _lock {this->mtx_};
+    bool with_bias = pd()->almc_.with_bias;
+    acl_lowp_matmul_sq_obj_t &acl_obj
+            = ctx.get_resource_mapper()
+                      ->get<acl_lowp_matmul_sq_resource_t>(this)
+                      ->get_acl_obj();
+    auto src = CTX_IN_MEM(const int8_t *, DNNL_ARG_SRC);
+    auto wei = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
+    auto dst = CTX_OUT_MEM(const int8_t *, DNNL_ARG_DST);
+    acl_obj.src_tensor.allocator()->import_memory(const_cast<int8_t *>(src));
+    acl_obj.wei_tensor.allocator()->import_memory(const_cast<int8_t *>(wei));
+    acl_obj.dst_tensor.allocator()->import_memory(const_cast<int8_t *>(dst));
+    DEFINE_ARG_SCALES_BUFFER(src_scale, DNNL_ARG_SRC);
+    DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC);
+    DEFINE_ARG_SCALES_BUFFER(wei_scale, DNNL_ARG_WEIGHTS);
+    DEFINE_ZERO_POINT_VALUE(wei_zero_point, DNNL_ARG_WEIGHTS);
+    DEFINE_ARG_SCALES_BUFFER(dst_scale, DNNL_ARG_DST);
+    DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST);
+    if (with_bias) {
+        const auto scratchpad = ctx.get_scratchpad_grantor();
+        auto bia_s32_base = scratchpad.get<int32_t>(
+                memory_tracking::names::key_conv_bias_s32_convert);
+        auto bia_f32_base
+                = CTX_IN_MEM(const arm_compute::float32_t *, DNNL_ARG_BIAS);
+        const float bias_scale = 1 / (*src_scale * (*wei_scale));
+        const int num_elements = acl_obj.bia_tensor.info()->total_size()
+                / sizeof(arm_compute::float32_t);
+        parallel_nd(num_elements, [&](dim_t e) {
+            const auto b = int32_t(std::round(bia_f32_base[e] * bias_scale));
+            bia_s32_base[e] = b;
+        });
+        acl_obj.bia_tensor.allocator()->init(*acl_obj.bia_tensor.info());
+        acl_obj.bia_tensor.allocator()->import_memory(bia_s32_base);
+    }
+    acl_obj.src_tensor.info()->set_quantization_info(
+            arm_compute::QuantizationInfo(*src_scale, -src_zero_point, true));
+    acl_obj.wei_tensor.info()->set_quantization_info(
+            arm_compute::QuantizationInfo(*wei_scale, -wei_zero_point, true));
+    // for efficiency reasons, oneDNN saves the inverse of the destination
+    // scale, so it is inverted back here before being passed to ACL
+    acl_obj.dst_tensor.info()->set_quantization_info(
+            arm_compute::QuantizationInfo(
+                    1.0 / (*dst_scale), dst_zero_point, true));
+    // The two calls below are stateful and, therefore, not fully thread-safe.
+    // This issue is being addressed, and the lock will be removed when the
+    // matmul stateless work is finished.
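A worked recap of the bias conversion earlier in execute() (illustrative numbers only): ACL's low-precision GEMM consumes an s32 bias, while oneDNN hands over f32, so every element is requantized by the combined input scale:

    // bias_s32[e] = round(bias_f32[e] / (src_scale * wei_scale))
    // e.g. src_scale = 0.05f, wei_scale = 0.1f, bias_f32[e] = 1.5f
    //      -> 1.5f / 0.005f = 300.0f -> bias_s32[e] = 300

The two stateful calls that the comment above refers to come next.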
+    acl_obj.gemm.update_quantization_parameters();
+    acl_obj.gemm.run();
+    // free() here tells ACL it can no longer use this memory; it does not
+    // deallocate it.
+    acl_obj.src_tensor.allocator()->free();
+    acl_obj.wei_tensor.allocator()->free();
+    if (with_bias) { acl_obj.bia_tensor.allocator()->free(); }
+    acl_obj.dst_tensor.allocator()->free();
+    return status::success;
+};
+} // namespace matmul
+} // namespace acl
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/cpu/acl/matmul/acl_lowp_matmul_sq.hpp b/src/cpu/acl/matmul/acl_lowp_matmul_sq.hpp
new file mode 100644
index 00000000000..d9c6192206f
--- /dev/null
+++ b/src/cpu/acl/matmul/acl_lowp_matmul_sq.hpp
@@ -0,0 +1,105 @@
+/*******************************************************************************
+* Copyright 2025 Arm Ltd. and affiliates
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef ACL_LOWP_MATMUL_SQ_HPP
+#define ACL_LOWP_MATMUL_SQ_HPP
+
+#include <mutex>
+
+#include "cpu/cpu_primitive.hpp"
+#include "cpu/matmul/cpu_matmul_pd.hpp"
+#include "cpu/matmul/matmul_utils.hpp"
+
+#include "cpu/acl/acl_post_ops.hpp"
+
+#include "arm_compute/core/CPP/CPPTypes.h"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace acl {
+namespace matmul {
+
+struct acl_lowp_matmul_sq_obj_t {
+    arm_compute::GEMMLowpOutputStageInfo info;
+    arm_compute::NEGEMMLowpMatrixMultiplyCore gemm;
+    arm_compute::Tensor src_tensor;
+    arm_compute::Tensor wei_tensor;
+    arm_compute::Tensor bia_tensor;
+    arm_compute::Tensor dst_tensor;
+};
+
+struct acl_lowp_matmul_sq_conf_t {
+    bool with_bias;
+    arm_compute::TensorInfo src_tensor_info;
+    arm_compute::TensorInfo wei_tensor_info;
+    arm_compute::TensorInfo bia_tensor_info;
+    arm_compute::TensorInfo dst_tensor_info;
+    arm_compute::GEMMInfo gemm_info;
+};
+
+struct acl_lowp_matmul_sq_resource_t : public resource_t {
+    acl_lowp_matmul_sq_resource_t()
+        : acl_obj_(utils::make_unique<acl_lowp_matmul_sq_obj_t>()) {}
+
+    status_t configure(const acl_lowp_matmul_sq_conf_t &almc);
+
+    acl_lowp_matmul_sq_obj_t &get_acl_obj() const { return *acl_obj_; }
+
+    DNNL_DISALLOW_COPY_AND_ASSIGN(acl_lowp_matmul_sq_resource_t);
+
+private:
+    std::unique_ptr<acl_lowp_matmul_sq_obj_t> acl_obj_;
+};
+
+struct acl_lowp_matmul_sq_t : public primitive_t {
+    struct pd_t : public dnnl::impl::cpu::matmul::cpu_matmul_pd_t {
+
+        using cpu_matmul_pd_t::cpu_matmul_pd_t;
+
+        DECLARE_COMMON_PD_T("lowp_gemm_sq:acl", acl_lowp_matmul_sq_t,
+                USE_GLOBAL_SCRATCHPAD);
+
+        status_t init(engine_t *engine);
+
+        status_t init_scratchpad(engine_t *engine,
+                memory_tracking::registrar_t &scratchpad,
+                acl_post_ops_t &post_ops, dnnl::impl::post_ops_t &attr_post_ops,
+                arm_compute::ActivationLayerInfo &act_info,
+                const dnnl::impl::memory_desc_t &dst_md);
+
+        acl_lowp_matmul_sq_conf_t almc_;
+        acl_post_ops_t acl_post_ops;
+    };
+
+    acl_lowp_matmul_sq_t(const pd_t *apd) : primitive_t(apd) {}
+
+    status_t create_resource(engine_t *engine, resource_mapper_t &mapper) const;
+
+    status_t execute(const exec_ctx_t &ctx) const;
+
+private:
+    mutable std::mutex mtx_;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+};
+
+} // namespace matmul
+} // namespace acl
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
+#endif // ACL_LOWP_MATMUL_SQ_HPP
\ No newline at end of file
diff --git a/src/cpu/aarch64/matmul/acl_matmul.cpp b/src/cpu/acl/matmul/acl_matmul.cpp
similarity index 85%
rename from src/cpu/aarch64/matmul/acl_matmul.cpp
rename to src/cpu/acl/matmul/acl_matmul.cpp
index 3d3e95a491d..2da752d7883 100644
--- a/src/cpu/aarch64/matmul/acl_matmul.cpp
+++ b/src/cpu/acl/matmul/acl_matmul.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2021-2024 Arm Ltd. and affiliates
+* Copyright 2021-2025 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,18 +14,20 @@
 * limitations under the License.
 *******************************************************************************/
 
-#include "cpu/aarch64/matmul/acl_matmul.hpp"
+#include "cpu/acl/matmul/acl_matmul.hpp"
+
+#include <mutex>
 
 namespace dnnl {
 namespace impl {
 namespace cpu {
-namespace aarch64 {
+namespace acl {
 namespace matmul {
 
 using namespace data_type;
 
 namespace {
-using data_t = prec_traits<data_type::f32>::type;
+using data_t = prec_traits_t<data_type::f32>::type;
 } // namespace
 
 status_t acl_matmul_t::init(engine_t *engine) {
@@ -76,19 +78,24 @@ status_t acl_matmul_t::pd_t::init(engine_t *engine) {
             = utils::everyone_is(data_type::bf16, src_md()->data_type,
                       weights_md()->data_type, dst_md()->data_type)
             && platform::has_data_type_support(data_type::bf16);
+    const bool is_bf16f32_ok
+            = utils::everyone_is(data_type::bf16, src_md()->data_type,
+                      weights_md()->data_type)
+            && utils::everyone_is(data_type::f32, dst_md()->data_type)
+            && platform::has_data_type_support(data_type::bf16);
 
     // we need to save this state as it can change inside set_default_formats()
     weights_format_kind_ = weights_md_.format_kind;
 
     VDISPATCH_MATMUL(is_dense_format_kind(), VERBOSE_UNSUPPORTED_SPARSE_CFG);
-    VDISPATCH_MATMUL(utils::one_of(true, is_fp32_ok, is_fp16_ok, is_bf16_ok),
+    VDISPATCH_MATMUL(utils::one_of(true, is_fp32_ok, is_fp16_ok, is_bf16_ok,
+                             is_bf16f32_ok),
             VERBOSE_UNSUPPORTED_DT_CFG);
     VDISPATCH_MATMUL(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, "");
     VDISPATCH_MATMUL(set_default_formats(), VERBOSE_UNSUPPORTED_TAG);
-    VDISPATCH_MATMUL(attr()->has_default_values(smask_t::oscale
-                             | smask_t::post_ops | smask_t::fpmath_mode),
+    VDISPATCH_MATMUL(attr()->has_default_values(
+                             smask_t::post_ops | smask_t::fpmath_mode),
             VERBOSE_UNSUPPORTED_ATTR);
-    VDISPATCH_MATMUL(attr_oscale_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG);
     VDISPATCH_MATMUL(
             !has_runtime_dims_or_strides(), VERBOSE_RUNTIMEDIM_UNSUPPORTED);
@@ -145,18 +152,16 @@ status_t acl_matmul_t::pd_t::init(engine_t *engine) {
     auto scratchpad = scratchpad_registry().registrar();
     arm_compute::experimental::MemoryRequirements aux_mem_req;
 
-    // Query buffer memory requirement, if not using fixed-format kernel
-    if (weights_format_kind_ != format_kind::any) {
-        arm_compute::experimental::op::ll::CpuGemmAssemblyDispatch asm_gemm;
-        if (amp_.do_transC) {
-            asm_gemm.configure(&amp_.wei_tensor_info, &amp_.src_tensor_info,
-                    nullptr, &amp_.dst_acc_info, amp_.gemm_info);
-        } else {
-            asm_gemm.configure(&amp_.src_tensor_info, &amp_.wei_tensor_info,
-                    nullptr, &amp_.dst_tensor_info, amp_.gemm_info);
-        }
-        aux_mem_req = asm_gemm.workspace();
+    // Query buffer memory requirement
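This hunk removes the fixed-format special case: the assembly dispatch object is now always configured up front so that its workspace requirements can be booked on the oneDNN scratchpad. A condensed sketch of that flow, where slot_to_key() is a hypothetical stand-in for the matmul_keys map defined in acl_matmul_utils.hpp:

    asm_gemm.configure(&src_info, &wei_info, nullptr, &dst_info, gemm_info);
    for (const auto &req : asm_gemm.workspace()) // per-slot requirements
        scratchpad.book(slot_to_key(req.slot), req.size, 1, req.alignment);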
+    arm_compute::experimental::op::ll::CpuGemmAssemblyDispatch asm_gemm;
+    if (amp_.do_transC) {
+        asm_gemm.configure(&amp_.wei_tensor_info, &amp_.src_tensor_info,
+                nullptr, &amp_.dst_acc_info, amp_.gemm_info);
+    } else {
+        asm_gemm.configure(&amp_.src_tensor_info, &amp_.wei_tensor_info,
+                nullptr, &amp_.dst_tensor_info, amp_.gemm_info);
     }
+    aux_mem_req = asm_gemm.workspace();
 
     CHECK(acl_matmul_utils::init_scratchpad(
             scratchpad, amp_, src_md_, weights_md_, dst_md_, aux_mem_req));
@@ -165,12 +170,20 @@ template <bool IsFixedFormat>
 status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const {
-    status_t status = status::success;
     auto src_base = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
     auto wei_base = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
 
-    auto amp = pd()->amp_;
+    const auto &amp = pd()->amp_;
+
+    std::unique_lock<std::mutex> locker {mtx_, std::defer_lock};
+
+    // Some of the underlying kernels used by ACL still require some state and
+    // are not safe to be called in parallel with different execution contexts.
+    // Eventually, when all kernels are truly stateless, this guard can be
+    // removed.
+    if (!acl_obj_->asm_gemm.has_stateless_impl()) { locker.lock(); }
+
     bool is_transA = amp.is_transA;
     bool is_transB = amp.is_transB;
     bool do_transC = amp.do_transC;
@@ -287,27 +300,28 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const {
     }
 
     // Get pointer to scratchpad memory and create a workspace tensor for
-    // CpuGemm. Fixed-format kernel does not need this workspace tensor.
+    // CpuGemmAssemblyDispatch.
     std::vector<arm_compute::Tensor> tmp_tensors(acl_obj_->aux_mem_req.size());
-    if (!IsFixedFormat) {
-        for (const auto &key : matmul_keys) {
-            const auto id = key.first;
-            if (acl_obj_->aux_mem_req[id].size > 0) {
-                const auto info = arm_compute::TensorInfo(
-                        arm_compute::TensorShape(
-                                acl_obj_->aux_mem_req[id].size),
-                        1, arm_compute::DataType::U8);
-                auto buffer = scratchpad.get(key.second);
-                tmp_tensors[id].allocator()->init(
-                        info, acl_obj_->aux_mem_req[id].alignment);
-                tmp_tensors[id].allocator()->import_memory(buffer);
-                matmul_pack.add_tensor(
-                        acl_obj_->aux_mem_req[id].slot, &tmp_tensors[id]);
-            }
+    for (const auto &key : matmul_keys) {
+        const auto id = key.first;
+        if (acl_obj_->aux_mem_req[id].size > 0) {
+            auto info = arm_compute::TensorInfo(
+                    arm_compute::TensorShape(acl_obj_->aux_mem_req[id].size), 1,
+                    arm_compute::DataType::U8);
+
+            auto *buffer = scratchpad.get(key.second);
+
+            tmp_tensors[id].allocator()->init(
+                    info, acl_obj_->aux_mem_req[id].alignment);
+            tmp_tensors[id].allocator()->import_memory(buffer);
+
+            matmul_pack.add_tensor(
+                    acl_obj_->aux_mem_req[id].slot, &tmp_tensors[id]);
        }
     }
     acl_obj_->asm_gemm.run(matmul_pack);
+
     if (do_act) {
         auto dst_to_use = do_transC ? &dst_acc_tensor : &dst_tensor;
         arm_compute::ITensorPack act_pack;
@@ -337,7 +351,7 @@ template status_t acl_matmul_t::execute_forward(
         const exec_ctx_t &ctx) const;
 
 } // namespace matmul
-} // namespace aarch64
+} // namespace acl
 } // namespace cpu
 } // namespace impl
 } // namespace dnnl
diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/acl/matmul/acl_matmul.hpp
similarity index 76%
rename from src/cpu/aarch64/matmul/acl_matmul.hpp
rename to src/cpu/acl/matmul/acl_matmul.hpp
index 30641a746a7..cd587e2cd60 100644
--- a/src/cpu/aarch64/matmul/acl_matmul.hpp
+++ b/src/cpu/acl/matmul/acl_matmul.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2021-2024 Arm Ltd. and affiliates
+* Copyright 2021-2025 Arm Ltd.
and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,38 +17,29 @@ #ifndef ACL_MATMUL_HPP #define ACL_MATMUL_HPP -#include "common/utils.hpp" -#include "cpu/aarch64/acl_post_ops.hpp" -#include "cpu/aarch64/matmul/acl_matmul_utils.hpp" +#include "cpu/acl/acl_post_ops.hpp" +#include "cpu/acl/matmul/acl_matmul_utils.hpp" +#include "cpu/matmul/cpu_matmul_pd.hpp" + +#include namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { namespace matmul { struct acl_matmul_t : public primitive_t { struct pd_t : public dnnl::impl::cpu::matmul::cpu_matmul_pd_t { - - pd_t(const matmul_desc_t *adesc, const primitive_attr_t *attr, - const cpu_matmul_pd_t *hint_fwd_pd) - : cpu_matmul_pd_t(adesc, attr, hint_fwd_pd), amp_() {} - using cpu_matmul_pd_t::cpu_matmul_pd_t; DECLARE_COMMON_PD_T("gemm:acl", acl_matmul_t, USE_GLOBAL_SCRATCHPAD); status_t init(engine_t *engine); - acl_matmul_conf_t amp_; + acl_matmul_conf_t amp_ = utils::zero(); acl_post_ops_t acl_post_ops; dnnl::impl::format_kind_t weights_format_kind_; - - protected: - bool attr_oscale_ok() const { - const auto &oscale = attr()->output_scales_; - return oscale.mask_ == 0; - } }; acl_matmul_t(const pd_t *apd) @@ -71,10 +62,11 @@ struct acl_matmul_t : public primitive_t { const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } std::unique_ptr acl_obj_; + mutable std::mutex mtx_; }; // acl_matmul_t } // namespace matmul -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/acl/matmul/acl_matmul_utils.cpp similarity index 92% rename from src/cpu/aarch64/matmul/acl_matmul_utils.cpp rename to src/cpu/acl/matmul/acl_matmul_utils.cpp index a921422ac0b..9ea9cfcdde1 100644 --- a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +++ b/src/cpu/acl/matmul/acl_matmul_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Arm Ltd. and affiliates +* Copyright 2021-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,14 @@ * limitations under the License. 
*******************************************************************************/ +#include "cpu/acl/matmul/acl_matmul_utils.hpp" +#include "cpu/acl/acl_utils.hpp" #include "cpu/matmul/matmul_utils.hpp" -#include "cpu/aarch64/matmul/acl_matmul_utils.hpp" - namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { namespace acl_matmul_utils { @@ -47,10 +47,26 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, // for e.g when ab in abcd is 1x1 bool batch_ok = IMPLICATION(src_batch > 1, wei_batch == 1) && IMPLICATION(wei_batch > 1, src_batch == 1); + ACL_CHECK_SUPPORT(src_d.ndims() == 4 && src_batch != wei_batch && !batch_ok, "matmul broadcast supported only for 3D shapes and 4D shapes when " "ab is 1x1"); + if (src_d.ndims() == 4 && src_batch == wei_batch + && src_d.dims()[0] != wei_d.dims()[0]) { // 4D broadcast occurred + if (src_d.dims()[0] == 1 && wei_d.dims()[0] != 1) { // Broadcast src + ACL_CHECK_SUPPORT( + IMPLICATION(src_d.dims()[1] != 1, wei_d.dims()[1] == 1), + "acl only broadcasts one of src or wei at once"); + } + + if (wei_d.dims()[0] == 1 && src_d.dims()[0] != 1) { // Broadcast wei + ACL_CHECK_SUPPORT( + IMPLICATION(src_d.dims()[1] == 1, wei_d.dims()[1] != 1), + "acl only broadcasts one of src or wei at once"); + } + } + // ACL does not support bias bool with_bias = md.bias_desc.format_kind != format_kind::undef; ACL_CHECK_SUPPORT(with_bias, "ACL does not support bias for matmul"); @@ -221,7 +237,7 @@ template status_t init_conf_matmul(acl_matmul_conf_t &, } // namespace acl_matmul_utils -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp b/src/cpu/acl/matmul/acl_matmul_utils.hpp similarity index 85% rename from src/cpu/aarch64/matmul/acl_matmul_utils.hpp rename to src/cpu/acl/matmul/acl_matmul_utils.hpp index cc8eae44ea7..d55cf71263f 100644 --- a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp +++ b/src/cpu/acl/matmul/acl_matmul_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Arm Ltd. and affiliates +* Copyright 2021-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,30 +14,30 @@ * limitations under the License. *******************************************************************************/ -#ifndef CPU_AARCH64_ACL_MATMUL_UTILS_HPP -#define CPU_AARCH64_ACL_MATMUL_UTILS_HPP +#ifndef CPU_ACL_MATMUL_UTILS_HPP +#define CPU_ACL_MATMUL_UTILS_HPP #include "arm_compute/runtime/experimental/low_level/CpuGemmAssemblyDispatch.h" #include "arm_compute/runtime/experimental/operators/CpuActivation.h" #include "arm_compute/runtime/experimental/operators/CpuTranspose.h" -#include "cpu/matmul/cpu_matmul_pd.hpp" - -#include "cpu/aarch64/acl_utils.hpp" +#include "common/memory_tracking.hpp" namespace dnnl { namespace impl { namespace cpu { -namespace aarch64 { +namespace acl { namespace { // Keys are anonymous. So deduce the type automagically. 
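A self-contained illustration of the decltype trick used below (it assumes nothing beyond standard C++): the key constants are enumerators of an unnamed enum, so deduction is the only way to name their type:

    #include <map>
    enum { first_key, second_key };    // anonymous enum: the type is unnamed
    using key_t = decltype(first_key); // recover the enum's type anyway
    const std::map<int, key_t> slot_to_key {{0, first_key}, {1, second_key}};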
using matmul_key_t = decltype(memory_tracking::names::key_gemm_asm_tmp_buffer); // Map: [slot , key] -const std::map matmul_keys - = {{0, matmul_key_t::key_gemm_asm_tmp_buffer}, - {2, matmul_key_t::key_gemm_pretranspose}}; +const std::map matmul_keys = { + {0, matmul_key_t::key_gemm_asm_tmp_buffer}, + {1, matmul_key_t::key_gemm_pretransposed_rhs}, + {2, matmul_key_t::key_gemm_pretranspose}, +}; } // namespace struct acl_matmul_obj_t { @@ -80,9 +80,9 @@ status_t init_scratchpad(memory_tracking::registrar_t &scratchpad, } // namespace acl_matmul_utils -} // namespace aarch64 +} // namespace acl } // namespace cpu } // namespace impl } // namespace dnnl -#endif // CPU_AARCH64_ACL_MATMUL_UTILS_HPP +#endif // CPU_ACL_MATMUL_UTILS_HPP diff --git a/src/cpu/binary_injector_utils.cpp b/src/cpu/binary_injector_utils.cpp index 339979907ce..b6dc0688cfe 100644 --- a/src/cpu/binary_injector_utils.cpp +++ b/src/cpu/binary_injector_utils.cpp @@ -30,7 +30,7 @@ std::vector prepare_binary_args(const post_ops_t &post_ops, unsigned idx = first_arg_idx_offset; for (const auto &post_op : post_ops.entry_) { - if (post_op.is_binary()) { + if (post_op.is_binary() || post_op.is_depthwise() || post_op.is_quantization()) { post_ops_binary_rhs_arg_vec.emplace_back(CTX_IN_MEM(const void *, DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1)); } diff --git a/src/cpu/cpu_batch_normalization_list.cpp b/src/cpu/cpu_batch_normalization_list.cpp index ab093a380f0..cf7490ccbaa 100644 --- a/src/cpu/cpu_batch_normalization_list.cpp +++ b/src/cpu/cpu_batch_normalization_list.cpp @@ -32,11 +32,12 @@ using namespace dnnl::impl::cpu::x64; #if DNNL_AARCH64 #include "cpu/aarch64/jit_uni_batch_normalization.hpp" #include "cpu/aarch64/jit_uni_batch_normalization_s8.hpp" -#if DNNL_AARCH64_USE_ACL -#include "cpu/aarch64/acl_batch_normalization.hpp" -#endif using namespace dnnl::impl::cpu::aarch64; #endif +#if DNNL_USE_ACL +#include "cpu/acl/acl_batch_normalization.hpp" +using namespace dnnl::impl::cpu::acl; +#endif namespace dnnl { namespace impl { @@ -51,52 +52,43 @@ const std::map> &impl_list_map() { static const std::map> the_map = REG_BNORM_P({ {{forward}, { /* fp */ - CPU_INSTANCE_X64(jit_uni_batch_normalization_fwd_t) - CPU_INSTANCE_X64(jit_uni_batch_normalization_fwd_t) - CPU_INSTANCE_X64(jit_uni_batch_normalization_fwd_t) - CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_fwd_t) - CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_fwd_t) - CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_batch_normalization_fwd_t) - CPU_INSTANCE(ncsp_batch_normalization_fwd_t) - CPU_INSTANCE(ncsp_batch_normalization_fwd_t) - CPU_INSTANCE(ncsp_batch_normalization_fwd_t) - CPU_INSTANCE(nspc_batch_normalization_fwd_t) - CPU_INSTANCE(nspc_batch_normalization_fwd_t) - CPU_INSTANCE(nspc_batch_normalization_fwd_t) - CPU_INSTANCE(ref_batch_normalization_fwd_t) - CPU_INSTANCE(ref_batch_normalization_fwd_t) - CPU_INSTANCE(ref_batch_normalization_fwd_t) + CPU_INSTANCE_X64(jit_uni_batch_normalization_fwd_t, avx512_core) + CPU_INSTANCE_X64(jit_uni_batch_normalization_fwd_t, avx2) + CPU_INSTANCE_X64(jit_uni_batch_normalization_fwd_t, sse41) + CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_fwd_t, avx512_core) + CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_fwd_t, avx2) + CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_fwd_t, sse41) + 
CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_fwd_t, sve_512) + CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_fwd_t, asimd) + CPU_INSTANCE(ncsp_batch_normalization_fwd_t, f32) + CPU_INSTANCE(ncsp_batch_normalization_fwd_t, bf16) + CPU_INSTANCE(nspc_batch_normalization_fwd_t, f32) + CPU_INSTANCE(nspc_batch_normalization_fwd_t, bf16) + CPU_INSTANCE(ref_batch_normalization_fwd_t, f32) + CPU_INSTANCE(ref_batch_normalization_fwd_t, bf16) /* int */ - CPU_INSTANCE_X64(jit_uni_batch_normalization_s8_fwd_t) - CPU_INSTANCE_X64(jit_uni_batch_normalization_s8_fwd_t) - CPU_INSTANCE_X64(jit_uni_batch_normalization_s8_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_s8_fwd_t) - CPU_INSTANCE(ref_batch_normalization_fwd_t) + CPU_INSTANCE_X64(jit_uni_batch_normalization_s8_fwd_t, avx512_core) + CPU_INSTANCE_X64(jit_uni_batch_normalization_s8_fwd_t, avx2) + CPU_INSTANCE_X64(jit_uni_batch_normalization_s8_fwd_t, sse41) + CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_s8_fwd_t, sve_512) + CPU_INSTANCE(ref_batch_normalization_fwd_t, s8) nullptr, }}, {{backward}, REG_BWD_PK({ - CPU_INSTANCE_X64(jit_uni_batch_normalization_bwd_t) - CPU_INSTANCE_X64(jit_uni_batch_normalization_bwd_t) - CPU_INSTANCE_X64(jit_uni_batch_normalization_bwd_t) - CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_bwd_t) - CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_bwd_t) - CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_bwd_t) - CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_bwd_t) - CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_bwd_t) - CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_bwd_t) - CPU_INSTANCE(ncsp_batch_normalization_bwd_t) - CPU_INSTANCE(ncsp_batch_normalization_bwd_t) - CPU_INSTANCE(ncsp_batch_normalization_bwd_t) - CPU_INSTANCE(nspc_batch_normalization_bwd_t) - CPU_INSTANCE(nspc_batch_normalization_bwd_t) - CPU_INSTANCE(nspc_batch_normalization_bwd_t) - CPU_INSTANCE(ref_batch_normalization_bwd_t) - CPU_INSTANCE(ref_batch_normalization_bwd_t) - CPU_INSTANCE(ref_batch_normalization_bwd_t) + CPU_INSTANCE_X64(jit_uni_batch_normalization_bwd_t, avx512_core) + CPU_INSTANCE_X64(jit_uni_batch_normalization_bwd_t, avx2) + CPU_INSTANCE_X64(jit_uni_batch_normalization_bwd_t, sse41) + CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_bwd_t, avx512_core) + CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_bwd_t, avx2) + CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_bwd_t, sse41) + CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_bwd_t, sve_512) + CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_bwd_t, asimd) + CPU_INSTANCE(ncsp_batch_normalization_bwd_t, f32) + CPU_INSTANCE(ncsp_batch_normalization_bwd_t, bf16) + CPU_INSTANCE(nspc_batch_normalization_bwd_t, f32) + CPU_INSTANCE(nspc_batch_normalization_bwd_t, bf16) + CPU_INSTANCE(ref_batch_normalization_bwd_t, f32) + CPU_INSTANCE(ref_batch_normalization_bwd_t, bf16) nullptr, })}, }); diff --git a/src/cpu/cpu_binary_list.cpp b/src/cpu/cpu_binary_list.cpp index 49bad158f1c..d37a0d39017 100644 --- a/src/cpu/cpu_binary_list.cpp +++ b/src/cpu/cpu_binary_list.cpp @@ -25,11 +25,12 @@ using namespace dnnl::impl::cpu::x64; #elif DNNL_AARCH64 #include "cpu/aarch64/jit_uni_binary.hpp" -#if DNNL_AARCH64_USE_ACL -#include "cpu/aarch64/acl_binary.hpp" -#endif using namespace dnnl::impl::cpu::aarch64; #endif +#if DNNL_USE_ACL +#include "cpu/acl/acl_binary.hpp" +using namespace dnnl::impl::cpu::acl; +#endif namespace dnnl { namespace impl { @@ -39,10 +40,10 @@ namespace { using namespace dnnl::impl::data_type; // clang-format off -constexpr impl_list_item_t impl_list[] = 
REG_BINARY_P({ +const impl_list_item_t impl_list[] = REG_BINARY_P({ CPU_INSTANCE_X64(jit_uni_binary_t) CPU_INSTANCE_AARCH64(jit_uni_binary_t) - CPU_INSTANCE_AARCH64_ACL(acl_binary_t) + CPU_INSTANCE_ACL(acl_binary_t) CPU_INSTANCE(ref_binary_t) /* eol */ nullptr, diff --git a/src/cpu/cpu_concat.cpp b/src/cpu/cpu_concat.cpp index 0af6336d709..06567411cdf 100644 --- a/src/cpu/cpu_concat.cpp +++ b/src/cpu/cpu_concat.cpp @@ -26,22 +26,24 @@ namespace cpu { namespace { using namespace dnnl::impl::data_type; -#define INSTANCE(...) \ +#define INSTANCE_IMPL(...) \ impl_list_item_t(impl_list_item_t::concat_type_deduction_helper_t< \ - __VA_ARGS__::pd_t>()), + __VA_ARGS__::pd_t>()) +#define INSTANCE(...) DNNL_PRIMITIVE_IMPL(INSTANCE_IMPL, __VA_ARGS__) // clang-format off -constexpr impl_list_item_t cpu_concat_impl_list[] = REG_CONCAT_P({ - INSTANCE(simple_concat_t) - INSTANCE(simple_concat_t) - INSTANCE(simple_concat_t) - INSTANCE(simple_concat_t) - INSTANCE(simple_concat_t) - INSTANCE(simple_concat_t) +const impl_list_item_t cpu_concat_impl_list[] = REG_CONCAT_P({ + INSTANCE(simple_concat_t, f32) + INSTANCE(simple_concat_t, u8) + INSTANCE(simple_concat_t, s8) + INSTANCE(simple_concat_t, s32) + INSTANCE(simple_concat_t, bf16) + INSTANCE(simple_concat_t, f16) INSTANCE(ref_concat_t) nullptr, }); // clang-format on #undef INSTANCE +#undef INSTANCE_IMPL } // namespace const impl_list_item_t * diff --git a/src/cpu/cpu_convolution_list.cpp b/src/cpu/cpu_convolution_list.cpp index daf5cb4915d..bc38bb3f478 100644 --- a/src/cpu/cpu_convolution_list.cpp +++ b/src/cpu/cpu_convolution_list.cpp @@ -1,7 +1,7 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation -* Copyright 2020-2024 Arm Ltd. and affiliates -* Copyright 2020-2024 FUJITSU LIMITED +* Copyright 2019-2025 Intel Corporation +* Copyright 2020-2025 Arm Ltd. and affiliates +* Copyright 2020-2025 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -34,6 +34,7 @@ #include "cpu/x64/gemm_bf16_convolution.hpp" #include "cpu/x64/ip_convolution.hpp" #include "cpu/x64/jit_avx2_1x1_convolution.hpp" +#include "cpu/x64/jit_avx2_1x1_convolution_with_dw_conv.hpp" #include "cpu/x64/jit_avx2_convolution.hpp" #include "cpu/x64/jit_avx512_common_1x1_convolution.hpp" #include "cpu/x64/jit_avx512_common_convolution.hpp" @@ -52,25 +53,30 @@ #include "cpu/x64/jit_sse41_1x1_convolution.hpp" #include "cpu/x64/jit_sse41_convolution.hpp" #include "cpu/x64/jit_uni_dw_convolution.hpp" +#include "cpu/x64/jit_uni_fork_dw_convolution.hpp" +#include "cpu/x64/jit_uni_ncsp_convolution.hpp" #include "cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp" #include "cpu/x64/jit_uni_x8s8s32x_convolution.hpp" +#include "cpu/x64/jit_uni_planar_convolution.hpp" using namespace dnnl::impl::cpu::x64; #elif DNNL_AARCH64 #include "cpu/aarch64/jit_brdgmm_dw_conv.hpp" #include "cpu/aarch64/jit_brgemm_1x1_conv.hpp" #include "cpu/aarch64/jit_brgemm_conv.hpp" -#include "cpu/aarch64/jit_sve_512_1x1_convolution.hpp" +#include "cpu/aarch64/jit_brgemm_conv_bwd.hpp" +#include "cpu/aarch64/jit_sve_1x1_convolution.hpp" #include "cpu/aarch64/jit_sve_512_x8s8s32x_convolution.hpp" #include "cpu/aarch64/jit_sve_convolution.hpp" #include "cpu/aarch64/jit_uni_dw_convolution.hpp" -#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL -#include "cpu/aarch64/acl_depthwise_convolution.hpp" -#include "cpu/aarch64/acl_gemm_convolution.hpp" -#include "cpu/aarch64/acl_indirect_gemm_convolution.hpp" -#include "cpu/aarch64/acl_winograd_convolution.hpp" -#endif using namespace dnnl::impl::cpu::aarch64; #endif +#if DNNL_USE_ACL +#include "cpu/acl/acl_gemm_convolution.hpp" +#include "cpu/acl/acl_indirect_gemm_convolution.hpp" +#include "cpu/acl/acl_depthwise_convolution.hpp" +#include "cpu/acl/acl_winograd_convolution.hpp" +using namespace dnnl::impl::cpu::acl; +#endif namespace dnnl { namespace impl { @@ -84,8 +90,8 @@ using namespace dnnl::impl::prop_kind; { \ {forward, dtsrc, dtwei, dtdst}, { \ CPU_INSTANCE_AMX( \ - brgemm_1x1_convolution_fwd_t) \ - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) \ + brgemm_1x1_convolution_fwd_t, avx10_1_512_amx_fp16) \ + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx10_1_512_amx_fp16) \ CPU_INSTANCE(ref_convolution_fwd_t) nullptr, \ } \ } @@ -121,75 +127,105 @@ const std::map> &impl_list_map() {{forward, f32, f32, f32}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(jit_avx512_common_planar_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_fwd_t) + CPU_INSTANCE_AVX512(jit_avx512_common_fork_dw_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_fwd_f32_t) - CPU_INSTANCE_AVX512(jit_avx512_common_convolution_fwd_t) + CPU_INSTANCE_AVX512(jit_avx512_common_convolution_fwd_t, f32) + CPU_INSTANCE_AVX2(jit_avx2_planar_convolution_fwd_t) CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2) + 
CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(jit_avx2_fork_dw_convolution_fwd_t) + CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_with_dw_conv_fwd_t) CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_fwd_t) CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_fwd_t) + CPU_INSTANCE_SSE41(jit_sse41_fork_dw_convolution_fwd_t) CPU_INSTANCE_SSE41(jit_sse41_1x1_convolution_fwd_t) CPU_INSTANCE_AVX2(jit_avx2_convolution_fwd_t) CPU_INSTANCE_SSE41(jit_sse41_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_wino_convolution_fwd_t) - CPU_INSTANCE_AARCH64(brdgmm_dw_convolution_fwd_t) - CPU_INSTANCE_AARCH64(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AARCH64(brgemm_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t) - CPU_INSTANCE_AARCH64(jit_sve_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_depthwise_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t) - CPU_INSTANCE_AARCH64(brdgmm_dw_convolution_fwd_t) - CPU_INSTANCE_AARCH64(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AARCH64(brgemm_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_convolution_fwd_t) + CPU_INSTANCE_ACL(acl_wino_convolution_fwd_t) + CPU_INSTANCE_AARCH64(brdgmm_dw_convolution_fwd_t,sve_512) + CPU_INSTANCE_AARCH64(brgemm_1x1_convolution_fwd_t,sve_512) + CPU_INSTANCE_AARCH64(brgemm_convolution_fwd_t,sve_512) + CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_fwd_t,sve_512,data_type::f32) + CPU_INSTANCE_AARCH64(jit_sve_1x1_convolution_fwd_t,f32,f32,f32,sve_512) + CPU_INSTANCE_AARCH64(jit_sve_convolution_fwd_t,f32,f32,f32,sve_512) + CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_fwd_t,sve_256,data_type::f32) + CPU_INSTANCE_AARCH64(jit_sve_1x1_convolution_fwd_t,f32,f32,f32,sve_256) + CPU_INSTANCE_AARCH64(jit_sve_convolution_fwd_t,f32,f32,f32,sve_256) + CPU_INSTANCE_ACL(acl_depthwise_convolution_fwd_t) + CPU_INSTANCE_ACL(acl_indirect_gemm_convolution_fwd_t) + CPU_INSTANCE_ACL(acl_gemm_convolution_fwd_t,f32) + CPU_INSTANCE_AARCH64(brdgmm_dw_convolution_fwd_t,sve_256) + CPU_INSTANCE_AARCH64(brgemm_1x1_convolution_fwd_t,sve_256) + CPU_INSTANCE_AARCH64(brgemm_convolution_fwd_t,sve_256) + // CPU_INSTANCE_X64(jit_uni_ncsp_convolution_fwd_t) CPU_INSTANCE(gemm_convolution_fwd_t) CPU_INSTANCE(ref_convolution_fwd_t) CPU_INSTANCE(ref_fused_convolution_fwd_t) nullptr, }}, + {{forward, f32, f16, f32}, { + CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2) + CPU_INSTANCE(ref_convolution_fwd_t) + nullptr, + }}, + {{forward, f32, bf16, f32}, { + CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2) + CPU_INSTANCE(ref_convolution_fwd_t) + nullptr, + }}, {{forward, bf16, bf16, f32}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) 
CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t) - CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_bf16) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_bf16) + CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t, avx512_core, bf16, f32) + CPU_INSTANCE_AVX512(jit_uni_fork_dw_convolution_fwd_t, avx512_core, bf16, f32) + CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_fwd_t, f32) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_fwd_t) - CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) + // CPU_INSTANCE_X64(jit_uni_ncsp_convolution_fwd_t) + CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t,f32) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t,avx2_vnni_2) CPU_INSTANCE(ref_convolution_fwd_t) nullptr, }}, {{forward, bf16, bf16, bf16}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t) - CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_bf16) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_bf16) + CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t, avx512_core, bf16, bf16) + CPU_INSTANCE_AVX512(jit_uni_fork_dw_convolution_fwd_t, avx512_core, bf16, bf16) + CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_fwd_t, bf16) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_fwd_t) - CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t) + // CPU_INSTANCE_X64(jit_uni_ncsp_convolution_fwd_t) + CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t,bf16) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t,avx2_vnni_2) + CPU_INSTANCE_ACL(acl_indirect_gemm_convolution_fwd_t) CPU_INSTANCE(ref_convolution_fwd_t) CPU_INSTANCE(ref_fused_convolution_fwd_t) nullptr, @@ -197,28 +233,30 @@ const std::map> &impl_list_map() {{forward, f16, f16, f32}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t,avx512_core_fp16, f16, f32) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t,avx512_core_amx_fp16) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t,avx512_core_amx_fp16) + 
CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t,avx512_core_fp16) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t,avx512_core_fp16) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t,avx2_vnni_2) CPU_INSTANCE(ref_convolution_fwd_t) nullptr, }}, {{forward, f16, f16, f16}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_wino_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_depthwise_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t) + CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t,avx512_core_fp16, f16, f16) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t,avx512_core_amx_fp16) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t,avx512_core_amx_fp16) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t,avx512_core_fp16) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t,avx512_core_fp16) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t,avx2_vnni_2) + CPU_INSTANCE_ACL(acl_wino_convolution_fwd_t) + CPU_INSTANCE_ACL(acl_depthwise_convolution_fwd_t) + CPU_INSTANCE_ACL(acl_indirect_gemm_convolution_fwd_t) + CPU_INSTANCE_ACL(acl_gemm_convolution_fwd_t, f16) CPU_INSTANCE(ref_convolution_fwd_t) CPU_INSTANCE(ref_fused_convolution_fwd_t) nullptr, @@ -242,80 +280,106 @@ const std::map> &impl_list_map() // BWD_D fp {{backward_data, f32, f32, f32}, REG_BWD_D_PK({ CPU_INSTANCE_X64(ip_convolution_bwd_data_t) - CPU_INSTANCE_AMX(brgemm_convolution_bwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_bwd_data_t) + CPU_INSTANCE_AVX512(jit_avx512_common_fork_dw_convolution_bwd_data_t) CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_bwd_data_f32_t) - CPU_INSTANCE_AVX512(jit_avx512_common_convolution_bwd_data_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AVX512(jit_avx512_common_convolution_bwd_data_t, f32) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t, avx2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2) CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_bwd_data_t) + CPU_INSTANCE_AVX2(jit_avx2_fork_dw_convolution_bwd_data_t) CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_bwd_data_t) - CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_bwd_data_t) + CPU_INSTANCE_SSE41(jit_sse41_fork_dw_convolution_bwd_data_t) CPU_INSTANCE_AVX2(jit_avx2_convolution_bwd_data_t) - CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_data_t) - CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_bwd_data_f32_t) - CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_data_t) - CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_data_t) - CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_data_t) + CPU_INSTANCE_AARCH64(brgemm_convolution_bwd_t,sve_512) + CPU_INSTANCE_AARCH64(brgemm_convolution_bwd_t,sve_256) + 
CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_data_t,sve_512,data_type::f32) + CPU_INSTANCE_AARCH64(jit_sve_1x1_convolution_bwd_data_t,f32,f32,f32,sve_512) + CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_data_t,f32,f32,f32,sve_512) + CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_data_t,sve_256,data_type::f32) + CPU_INSTANCE_AARCH64(jit_sve_1x1_convolution_bwd_data_t,f32,f32,f32,sve_256) + CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_data_t,f32,f32,f32,sve_256) + // CPU_INSTANCE_X64(jit_uni_ncsp_convolution_bwd_data_t) CPU_INSTANCE(gemm_convolution_bwd_data_t) CPU_INSTANCE(ref_convolution_bwd_data_t) nullptr, })}, + {{backward_data, f32, bf16, f32}, REG_BWD_D_PK({ + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t, avx2) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx2) + CPU_INSTANCE(ref_convolution_bwd_data_t) + nullptr, + })}, + {{backward_data, f32, f16, f32}, REG_BWD_D_PK({ + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t, avx2) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx2) + CPU_INSTANCE(ref_convolution_bwd_data_t) + nullptr, + })}, {{backward_data, f32, bf16, bf16}, REG_BWD_D_PK({ CPU_INSTANCE_X64(ip_convolution_bwd_data_t) - CPU_INSTANCE_AMX(brgemm_convolution_bwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_data_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_data_t) - CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_data_t) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t, avx512_core_bf16) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_bf16) + CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_data_t, avx512_core, bf16, f32) + CPU_INSTANCE_AVX512(jit_uni_fork_dw_convolution_bwd_data_t, avx512_core, bf16, f32) + CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_data_t, f32) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_data_t) - CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_data_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + // CPU_INSTANCE_X64(jit_uni_ncsp_convolution_bwd_data_t) + CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_data_t,f32) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t,avx2_vnni_2) CPU_INSTANCE(ref_convolution_bwd_data_t) nullptr, })}, {{backward_data, bf16, bf16, bf16}, REG_BWD_D_PK({ CPU_INSTANCE_X64(ip_convolution_bwd_data_t) - CPU_INSTANCE_AMX(brgemm_convolution_bwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_data_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_data_t) - CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_data_t) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t, avx512_core_bf16) + 
CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_bf16) + CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_data_t, avx512_core, bf16, bf16) + CPU_INSTANCE_AVX512(jit_uni_fork_dw_convolution_bwd_data_t, avx512_core, bf16, bf16) + CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_data_t, bf16) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_data_t) - CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_data_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + // CPU_INSTANCE_X64(jit_uni_ncsp_convolution_bwd_data_t) + CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_data_t,bf16) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t,avx2_vnni_2) CPU_INSTANCE(ref_convolution_bwd_data_t) nullptr, })}, {{backward_data, f32, f16, f16}, REG_BWD_D_PK({ CPU_INSTANCE_X64(ip_convolution_bwd_data_t) - CPU_INSTANCE_AMX(brgemm_convolution_bwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_t, avx512_core_amx_fp16) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx_fp16) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_data_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t, avx512_core_fp16) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_fp16) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni_2) CPU_INSTANCE(ref_convolution_bwd_data_t) nullptr, })}, {{backward_data, f16, f16, f16}, REG_BWD_D_PK({ CPU_INSTANCE_X64(ip_convolution_bwd_data_t) - CPU_INSTANCE_AMX(brgemm_convolution_bwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_t, avx512_core_amx_fp16) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx_fp16) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_data_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t, avx512_core_fp16) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_fp16) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni_2) CPU_INSTANCE(ref_convolution_bwd_data_t) nullptr, })}, @@ -340,39 +404,43 @@ const std::map> &impl_list_map() CPU_INSTANCE_X64(ip_convolution_bwd_weights_t) CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_bwd_weights_t) CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_bwd_weights_t) - CPU_INSTANCE_AVX512(jit_avx512_common_convolution_bwd_weights_t) + CPU_INSTANCE_AVX512(jit_avx512_common_convolution_bwd_weights_t, f32) CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_bwd_weights_t) CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_bwd_weights_t) CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_bwd_weights_t) CPU_INSTANCE_AVX2(jit_avx2_convolution_bwd_weights_t) - CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_weights_t) - CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_bwd_weights_t) - CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_weights_t) - CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_weights_t) - 
CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_weights_t) + CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_weights_t,sve_512,data_type::f32) + CPU_INSTANCE_AARCH64(jit_sve_1x1_convolution_bwd_weights_t,f32,f32,f32,sve_512) + CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_weights_t,f32,f32,f32,sve_512) + CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_weights_t,sve_256,data_type::f32) + CPU_INSTANCE_AARCH64(jit_sve_1x1_convolution_bwd_weights_t,f32,f32,f32,sve_256) + CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_weights_t,f32,f32,f32,sve_256) + // CPU_INSTANCE_X64(jit_uni_ncsp_convolution_bwd_weights_t) CPU_INSTANCE(gemm_convolution_bwd_weights_t) CPU_INSTANCE(ref_convolution_bwd_weights_t) nullptr, })}, {{backward_weights, bf16, f32, bf16}, REG_BWD_PK({ CPU_INSTANCE_X64(ip_convolution_bwd_weights_t) - CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_weights_t) + CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_weights_t, avx512_core, bf16, f32) CPU_INSTANCE_AMX(brgemm_convolution_bwd_weights_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_weights_t) - CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_weights_t) + CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_weights_t, f32) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_weights_t) - CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_weights_t) + // CPU_INSTANCE_X64(jit_uni_ncsp_convolution_bwd_weights_t) + CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_weights_t,f32) CPU_INSTANCE(ref_convolution_bwd_weights_t) nullptr, })}, {{backward_weights, bf16, bf16, bf16}, REG_BWD_PK({ CPU_INSTANCE_X64(ip_convolution_bwd_weights_t) - CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_weights_t) + CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_weights_t, avx512_core, bf16, bf16) CPU_INSTANCE_AMX(brgemm_convolution_bwd_weights_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_weights_t) - CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_weights_t) + CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_weights_t, bf16) CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_weights_t) - CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_weights_t) + // CPU_INSTANCE_X64(jit_uni_ncsp_convolution_bwd_weights_t) + CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_weights_t,bf16) CPU_INSTANCE(ref_convolution_bwd_weights_t) nullptr, })}, @@ -408,25 +476,25 @@ const std::map> &impl_list_map() {{forward, s8, s8, f32}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) 
- CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t, avx2) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t, sse41) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t, sse41) + CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t, s8, f32) CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) CPU_INSTANCE(ref_convolution_int8_fwd_t) CPU_INSTANCE(ref_fused_convolution_fwd_t) @@ -436,14 +504,30 @@ const std::map> &impl_list_map() CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t) + CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) + CPU_INSTANCE(ref_convolution_int8_fwd_t) + nullptr, + }}, + {{forward, s8, s8, f16}, { + CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) + CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) + CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni_2) CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) CPU_INSTANCE(ref_convolution_int8_fwd_t) nullptr, @@ -451,25 +535,25 @@ const std::map> &impl_list_map() {{forward, s8, s8, s32}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) 
- CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t, avx2) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t, sse41) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t, sse41) + CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t, s8, s32) CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) CPU_INSTANCE(ref_convolution_int8_fwd_t) CPU_INSTANCE(ref_fused_convolution_fwd_t) @@ -478,26 +562,26 @@ const std::map> &impl_list_map() {{forward, s8, s8, s8}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni) + 
CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t, avx2) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t, sse41) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t, sse41) + CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t, s8, s8) + CPU_INSTANCE_ACL(acl_gemm_convolution_fwd_t, s8, s8, s8, s32) CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) CPU_INSTANCE(ref_convolution_int8_fwd_t) CPU_INSTANCE(ref_fused_convolution_fwd_t) @@ -506,25 +590,25 @@ const std::map> &impl_list_map() {{forward, s8, s8, u8}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t, avx2) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t, sse41) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t, sse41) + CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t, s8, u8) CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) CPU_INSTANCE(ref_convolution_int8_fwd_t) CPU_INSTANCE(ref_fused_convolution_fwd_t) @@ -534,43 +618,61 @@ const std::map> &impl_list_map() {{forward, u8, s8, f32}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - 
CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t, avx2) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t, sse41) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t, sse41) + CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t, u8, f32) CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) CPU_INSTANCE(ref_convolution_int8_fwd_t) nullptr, }}, {{forward, u8, s8, bf16}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) + CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t) + CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) + CPU_INSTANCE(ref_convolution_int8_fwd_t) + nullptr, + }}, + {{forward, u8, s8, f16}, { + CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t) 
CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni_2) CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) CPU_INSTANCE(ref_convolution_int8_fwd_t) nullptr, @@ -578,25 +680,25 @@ const std::map> &impl_list_map() {{forward, u8, s8, s32}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t, avx2) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t, sse41) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t, sse41) + CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t, u8, s32) CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) CPU_INSTANCE(ref_convolution_int8_fwd_t) nullptr, @@ -604,25 +706,25 @@ const std::map> &impl_list_map() {{forward, u8, s8, s8}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + 
CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t, avx2) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t, sse41) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t, sse41) + CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t, u8, s8) CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) CPU_INSTANCE(ref_convolution_int8_fwd_t) CPU_INSTANCE(ref_fused_convolution_fwd_t) @@ -631,25 +733,25 @@ const std::map> &impl_list_map() {{forward, u8, s8, u8}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t, avx512_core_amx) + CPU_INSTANCE_AMX(brgemm_convolution_fwd_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t, avx512_core) + CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t, avx2) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t, sse41) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t, sse41) + CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t, u8, u8) 
CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) CPU_INSTANCE(ref_convolution_int8_fwd_t) CPU_INSTANCE(ref_fused_convolution_fwd_t) @@ -657,100 +759,100 @@ const std::map> &impl_list_map() }}, // BWD int8 (diff_dst:u8) {{backward_data, f32, s8, u8}, REG_BWD_D_PK({ - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t) CPU_INSTANCE(ref_convolution_int8_bwd_data_t) nullptr, })}, {{backward_data, bf16, s8, u8}, REG_BWD_D_PK({ - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni_2) CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t) CPU_INSTANCE(ref_convolution_int8_bwd_data_t) nullptr, })}, {{backward_data, s32, s8, u8}, REG_BWD_D_PK({ - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t) CPU_INSTANCE(ref_convolution_int8_bwd_data_t) nullptr, })}, {{backward_data, s8, s8, u8}, REG_BWD_D_PK({ - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t) CPU_INSTANCE(ref_convolution_int8_bwd_data_t) nullptr, })}, {{backward_data, u8, s8, u8}, REG_BWD_D_PK({ - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + 
CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t) CPU_INSTANCE(ref_convolution_int8_bwd_data_t) nullptr, })}, // BWD int8 (diff_dst:s8) {{backward_data, f32, s8, s8}, REG_BWD_D_PK({ - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t) CPU_INSTANCE(ref_convolution_int8_bwd_data_t) nullptr, })}, {{backward_data, bf16, s8, s8}, REG_BWD_D_PK({ - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni_2) CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t) CPU_INSTANCE(ref_convolution_int8_bwd_data_t) nullptr, })}, {{backward_data, s32, s8, s8}, REG_BWD_D_PK({ - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t) CPU_INSTANCE(ref_convolution_int8_bwd_data_t) nullptr, })}, {{backward_data, s8, s8, s8}, REG_BWD_D_PK({ - CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t) CPU_INSTANCE(ref_convolution_int8_bwd_data_t) nullptr, })}, {{backward_data, u8, s8, s8}, REG_BWD_D_PK({ - 
CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) - CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t) + CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_convolution_bwd_strided_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_convolution_bwd_strided_t, avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t) CPU_INSTANCE(ref_convolution_int8_bwd_data_t) nullptr, @@ -783,4 +885,4 @@ const impl_list_item_t *get_convolution_impl_list( } // namespace cpu } // namespace impl -} // namespace dnnl \ No newline at end of file +} // namespace dnnl diff --git a/src/cpu/cpu_deconvolution_list.cpp b/src/cpu/cpu_deconvolution_list.cpp index 468f4711452..a904216222b 100644 --- a/src/cpu/cpu_deconvolution_list.cpp +++ b/src/cpu/cpu_deconvolution_list.cpp @@ -1,7 +1,7 @@ /******************************************************************************* * Copyright 2019-2023 Intel Corporation * Copyright 2022 FUJITSU LIMITED -* Copyright 2022 Arm Ltd. and affiliates +* Copyright 2022, 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,9 @@ * limitations under the License. *******************************************************************************/ +#include "common/compiler_workarounds.hpp" #include "cpu/cpu_engine.hpp" - #include "cpu/ref_deconvolution.hpp" - #if DNNL_X64 #include "cpu/x64/jit_avx512_core_amx_deconvolution.hpp" #include "cpu/x64/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp" @@ -30,11 +29,12 @@ using namespace dnnl::impl::cpu::x64; #elif DNNL_AARCH64 #include "cpu/aarch64/jit_sve_512_core_x8s8s32x_deconvolution.hpp" -#if DNNL_AARCH64_USE_ACL -#include "cpu/aarch64/acl_deconvolution.hpp" -#endif using namespace dnnl::impl::cpu::aarch64; #endif +#if DNNL_USE_ACL +#include "cpu/acl/acl_deconvolution.hpp" +using namespace dnnl::impl::cpu::acl; +#endif namespace dnnl { namespace impl { @@ -48,24 +48,24 @@ using namespace dnnl::impl::prop_kind; const std::map> &impl_list_map() { static const std::map> the_map = REG_DECONV_P({ {{forward}, { - CPU_INSTANCE_AMX(brgemm_deconvolution_fwd_t) - CPU_INSTANCE_AMX(brgemm_deconvolution_fwd_t) + CPU_INSTANCE_AMX(brgemm_deconvolution_fwd_t, avx512_core_amx_fp16) + CPU_INSTANCE_AMX(brgemm_deconvolution_fwd_t, avx512_core_amx) CPU_INSTANCE_AMX(jit_avx512_core_amx_deconvolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_deconvolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_deconvolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_deconvolution_fwd_t) - CPU_INSTANCE_AVX512(brgemm_deconvolution_fwd_t) + CPU_INSTANCE_AVX512(brgemm_deconvolution_fwd_t, avx512_core_fp16) + CPU_INSTANCE_AVX512(brgemm_deconvolution_fwd_t, avx512_core_bf16) + CPU_INSTANCE_AVX512(brgemm_deconvolution_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_deconvolution_fwd_t, avx512_core) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t) CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_deconvolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_deconvolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_deconvolution_fwd_t) - CPU_INSTANCE_AVX2(brgemm_deconvolution_fwd_t) - CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_deconvolution_fwd_t) - 
CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_deconvolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_deconvolution_fwd_t) - CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_deconvolution_fwd_t) + CPU_INSTANCE_AVX2(brgemm_deconvolution_fwd_t, avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_deconvolution_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(brgemm_deconvolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_deconvolution_fwd_t, avx2) + CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_deconvolution_fwd_t, avx2) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_deconvolution_fwd_t, sse41) + CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_deconvolution_fwd_t, sse41) CPU_INSTANCE_AARCH64(jit_sve_512_core_x8s8s32x_deconvolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_deconvolution_fwd_t) + CPU_INSTANCE_ACL(acl_deconvolution_fwd_t) CPU_INSTANCE(ref_deconvolution_fwd_t) nullptr, }}, diff --git a/src/cpu/cpu_eltwise_list.cpp b/src/cpu/cpu_eltwise_list.cpp index 03d4f107449..704a18f1c69 100644 --- a/src/cpu/cpu_eltwise_list.cpp +++ b/src/cpu/cpu_eltwise_list.cpp @@ -27,11 +27,12 @@ using namespace dnnl::impl::cpu::x64; #elif DNNL_AARCH64 #include "cpu/aarch64/jit_uni_eltwise.hpp" #include "cpu/aarch64/jit_uni_eltwise_int.hpp" -#if DNNL_AARCH64_USE_ACL -#include "cpu/aarch64/acl_eltwise.hpp" -#endif // DNNL_AARCH64_USE_ACL using namespace dnnl::impl::cpu::aarch64; #endif +#if DNNL_USE_ACL +#include "cpu/acl/acl_eltwise.hpp" +using namespace dnnl::impl::cpu::acl; +#endif namespace dnnl { namespace impl { @@ -45,59 +46,59 @@ using namespace dnnl::impl::prop_kind; const std::map> &impl_list_map() { static const std::map> the_map = REG_ELTWISE_P({ {{forward}, { - CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_eltwise_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_eltwise_int_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_eltwise_int_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_eltwise_int_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_eltwise_fwd_t) - CPU_INSTANCE(ref_eltwise_fwd_t) - CPU_INSTANCE(ref_eltwise_fwd_t) - CPU_INSTANCE(ref_eltwise_fwd_t) - CPU_INSTANCE(ref_eltwise_fwd_t) - CPU_INSTANCE(ref_eltwise_fwd_t) - CPU_INSTANCE(ref_eltwise_fwd_t) - CPU_INSTANCE(ref_eltwise_fwd_t) - CPU_INSTANCE(ref_eltwise_fwd_t) + CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t, avx512_core_amx, f8_e4m3) + CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t, avx512_core_amx, f8_e5m2) + // CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t) + CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t, avx512_core, f32) + CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t, avx512_core, bf16) + // CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t) + CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t, avx2_vnni_2, bf16) + CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t, avx2, f32) + 
CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t, avx, f32) + CPU_INSTANCE_X64(jit_uni_eltwise_fwd_t, sse41, f32) + CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t, avx512_core, s32) + CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t, avx512_core, s8) + CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t, avx512_core, u8) + CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t, avx2, s32) + CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t, avx2, s8) + CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t, avx2, u8) + CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t, sse41, s32) + CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t, sse41, s8) + CPU_INSTANCE_X64(jit_uni_eltwise_int_fwd_t, sse41, u8) + CPU_INSTANCE_AARCH64(jit_uni_eltwise_fwd_t, sve_512, f32) + CPU_INSTANCE_AARCH64(jit_uni_eltwise_fwd_t, sve_256, f32) + CPU_INSTANCE_AARCH64(jit_uni_eltwise_fwd_t, sve_128, f32) + CPU_INSTANCE_AARCH64(jit_uni_eltwise_int_fwd_t, sve_512, s32) + CPU_INSTANCE_AARCH64(jit_uni_eltwise_int_fwd_t, sve_512, s8) + CPU_INSTANCE_AARCH64(jit_uni_eltwise_int_fwd_t, sve_512, u8) + CPU_INSTANCE_ACL(acl_eltwise_fwd_t) + CPU_INSTANCE(ref_eltwise_fwd_t, f32) + CPU_INSTANCE(ref_eltwise_fwd_t, bf16) + // CPU_INSTANCE(ref_eltwise_fwd_t) + CPU_INSTANCE(ref_eltwise_fwd_t, s32) + CPU_INSTANCE(ref_eltwise_fwd_t, s8) + CPU_INSTANCE(ref_eltwise_fwd_t, u8) + CPU_INSTANCE(ref_eltwise_fwd_t, f8_e4m3) + CPU_INSTANCE(ref_eltwise_fwd_t, f8_e5m2) nullptr, }}, {{backward}, REG_BWD_PK({ - CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t) - CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t) - CPU_INSTANCE_AARCH64(jit_uni_eltwise_bwd_t) - CPU_INSTANCE_AARCH64(jit_uni_eltwise_bwd_t) - CPU_INSTANCE_AARCH64(jit_uni_eltwise_bwd_t) - CPU_INSTANCE(ref_eltwise_bwd_t) - CPU_INSTANCE(ref_eltwise_bwd_t) - CPU_INSTANCE(ref_eltwise_bwd_t) - CPU_INSTANCE(ref_eltwise_bwd_t) - CPU_INSTANCE(ref_eltwise_bwd_t) + CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t, avx512_core_amx, f8_e4m3) + CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t, avx512_core_amx, f8_e5m2) + // CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t) + CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t, avx512_core, f32) + CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t, avx512_core, bf16) + CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t, avx2, f32) + CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t, avx, f32) + CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t, sse41, f32) + CPU_INSTANCE_AARCH64(jit_uni_eltwise_bwd_t, sve_512, f32) + CPU_INSTANCE_AARCH64(jit_uni_eltwise_bwd_t, sve_256, f32) + CPU_INSTANCE_AARCH64(jit_uni_eltwise_bwd_t, sve_128, f32) + CPU_INSTANCE(ref_eltwise_bwd_t, f32) + CPU_INSTANCE(ref_eltwise_bwd_t, bf16) + // CPU_INSTANCE(ref_eltwise_bwd_t) + CPU_INSTANCE(ref_eltwise_bwd_t, f8_e4m3) + CPU_INSTANCE(ref_eltwise_bwd_t, f8_e5m2) nullptr, })}, }); diff --git a/src/cpu/cpu_engine.cpp b/src/cpu/cpu_engine.cpp index c9263eda2cd..751038e698b 100644 --- a/src/cpu/cpu_engine.cpp +++ b/src/cpu/cpu_engine.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -30,6 +30,9 @@ namespace cpu { status_t cpu_engine_t::create_memory_storage( memory_storage_t **storage, unsigned flags, size_t size, void *handle) { + assert(runtime_kind() != runtime_kind::sycl); + if (runtime_kind() == runtime_kind::sycl) return status::runtime_error; + auto _storage = new cpu_memory_storage_t(this); if (_storage == nullptr) return status::out_of_memory; status_t status = _storage->init(flags, size, handle); diff --git a/src/cpu/cpu_engine.hpp b/src/cpu/cpu_engine.hpp index 494a5dc7f51..d98f79b95bd 100644 --- a/src/cpu/cpu_engine.hpp +++ b/src/cpu/cpu_engine.hpp @@ -29,21 +29,25 @@ #include "cpu/platform.hpp" -#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL -#include "cpu/aarch64/acl_thread.hpp" +#if DNNL_USE_ACL +#include "cpu/acl/acl_thread.hpp" #endif -#define CPU_INSTANCE(...) \ +#define CPU_INSTANCE_IMPL(...) \ impl_list_item_t( \ - impl_list_item_t::type_deduction_helper_t<__VA_ARGS__::pd_t>()), -#define CPU_INSTANCE_X64(...) DNNL_X64_ONLY(CPU_INSTANCE(__VA_ARGS__)) + impl_list_item_t::type_deduction_helper_t<__VA_ARGS__::pd_t>()) +#define CPU_INSTANCE(...) DNNL_PRIMITIVE_IMPL(CPU_INSTANCE_IMPL, __VA_ARGS__) +// Expanding DNNL_X64_ONLY in order to fix Conditional Compilation failure on Windows + CPU plugin. +// DNNL_X64_ONLY == CONCAT2(Z_DO_IF_, DNNL_X64) +#define CPU_INSTANCE_X64(...) \ + CONCAT2(Z_DO_IF_, DNNL_X64)(CPU_INSTANCE(__VA_ARGS__)) #define CPU_INSTANCE_SSE41(...) REG_SSE41_ISA(CPU_INSTANCE(__VA_ARGS__)) #define CPU_INSTANCE_AVX2(...) REG_AVX2_ISA(CPU_INSTANCE(__VA_ARGS__)) #define CPU_INSTANCE_AVX512(...) REG_AVX512_ISA(CPU_INSTANCE(__VA_ARGS__)) #define CPU_INSTANCE_AMX(...) REG_AMX_ISA(CPU_INSTANCE(__VA_ARGS__)) #define CPU_INSTANCE_AARCH64(...) DNNL_AARCH64_ONLY(CPU_INSTANCE(__VA_ARGS__)) -#define CPU_INSTANCE_AARCH64_ACL(...) \ - DNNL_AARCH64_ACL_ONLY(CPU_INSTANCE(__VA_ARGS__)) +#define CPU_INSTANCE_ARM(...) DNNL_ARM_ONLY(CPU_INSTANCE(__VA_ARGS__)) +#define CPU_INSTANCE_ACL(...) DNNL_ACL_ONLY(CPU_INSTANCE(__VA_ARGS__)) #define CPU_INSTANCE_RV64GCV(...) DNNL_RV64GCV_ONLY(CPU_INSTANCE(__VA_ARGS__)) namespace dnnl { @@ -88,7 +92,7 @@ class cpu_engine_impl_list_t { #define CASE(kind) \ case primitive_kind::kind: \ return get_##kind##_impl_list((const kind##_desc_t *)desc); - switch ((int) desc->kind) { + switch ((int)desc->primitive_kind) { CASE(batch_normalization); CASE(binary); CASE(convolution); @@ -156,8 +160,8 @@ class cpu_engine_factory_t : public engine_factory_t { *engine = new cpu_engine_t(new impl::engine_impl_t( engine_kind::cpu, get_cpu_native_runtime(), 0)); -#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL - dnnl::impl::cpu::aarch64::acl_thread_utils::set_acl_threading(); +#if DNNL_USE_ACL + dnnl::impl::cpu::acl::acl_thread_utils::set_acl_threading(); #endif return status::success; }; diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp index 1f595473047..51f754c8450 100644 --- a/src/cpu/cpu_inner_product_list.cpp +++ b/src/cpu/cpu_inner_product_list.cpp @@ -1,5 +1,6 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2025 Intel Corporation +* Copyright 2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,12 +25,13 @@ #if DNNL_X64 #include "cpu/x64/gemm_bf16_inner_product.hpp" #include "cpu/x64/jit_brgemm_inner_product.hpp" +#include "cpu/x64/matmul_inner_product.hpp" using namespace dnnl::impl::cpu::x64; #endif -#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL -#include "cpu/aarch64/acl_inner_product.hpp" -using namespace dnnl::impl::cpu::aarch64; +#if DNNL_USE_ACL +#include "cpu/acl/acl_inner_product.hpp" +using namespace dnnl::impl::cpu::acl; #endif namespace dnnl { @@ -40,46 +42,166 @@ namespace { using namespace dnnl::impl::data_type; using namespace dnnl::impl::prop_kind; +#define BRGEMM_FP8_FWD_IP(dtsrc, dtwei, dtdst) \ + { \ + {forward, dtsrc, dtwei, dtdst}, { \ + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) \ + CPU_INSTANCE(ref_inner_product_fwd_t) nullptr, \ + } \ + } + // clang-format off const std::map> &impl_list_map() { static const std::map> the_map = REG_IP_P({ {{forward, f32, f32, f32}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) // bf32 - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t) - CPU_INSTANCE(gemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) // bf32 + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2) + CPU_INSTANCE_ACL(acl_inner_product_fwd_t) + CPU_INSTANCE(gemm_inner_product_fwd_t, f32) CPU_INSTANCE(ref_inner_product_fwd_t) nullptr, }}, + {{forward, f32, u8, f32}, { + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2) + nullptr, + }}, + {{forward, f32, s8, f32}, { + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2) + nullptr, + }}, + {{forward, f32, nf4, f32}, { + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2) + nullptr, + }}, + {{forward, f32, f4_e2m1, f32}, { + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2) + nullptr, + }}, + {{forward, f32, s4, f32}, { + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2) + nullptr, + }}, + {{forward, f32, u4, f32}, { + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2_vnni) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2) + nullptr, + }}, + {{forward, f32, f16, f32}, { + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2) + nullptr, + }}, + {{forward, f32, bf16, f32}, { + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2) + nullptr, + }}, {{forward, bf16, bf16, f32}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(gemm_bf16_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + 
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_bf16) + CPU_INSTANCE_AVX512(gemm_bf16_inner_product_fwd_t,f32) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) CPU_INSTANCE(ref_inner_product_fwd_t) nullptr, }}, {{forward, bf16, bf16, bf16}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(gemm_bf16_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_bf16) + CPU_INSTANCE_AVX512(gemm_bf16_inner_product_fwd_t,bf16) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) + CPU_INSTANCE_ACL(acl_inner_product_fwd_t) CPU_INSTANCE(ref_inner_product_fwd_t) nullptr, }}, + {{forward, bf16, u8, f32}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, + {{forward, bf16, u8, bf16}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, + {{forward, bf16, s8, f32}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, + {{forward, bf16, s8, bf16}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, + {{forward, bf16, nf4, f32}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, + {{forward, bf16, nf4, bf16}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, + {{forward, bf16, f4_e2m1, f32}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, + {{forward, bf16, f4_e2m1, bf16}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, + {{forward, bf16, s4, f32}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, + {{forward, bf16, s4, bf16}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, + {{forward, bf16, u4, f32}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, + {{forward, bf16, u4, bf16}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, {{forward, f16, f16, f32}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx_fp16) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_fp16) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) CPU_INSTANCE(ref_inner_product_fwd_t) 
nullptr, }}, {{forward, f16, f16, f16}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx_fp16) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_fp16) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) + CPU_INSTANCE_ACL(acl_inner_product_fwd_t) CPU_INSTANCE(ref_inner_product_fwd_t) nullptr, }}, @@ -89,170 +211,208 @@ const std::map> &impl_list_map() * in fp32 and weights are in bf16 */ {{forward, f32, bf16, f32}, { - CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t) + CPU_INSTANCE_ACL(acl_inner_product_fwd_t) nullptr, }}, + + BRGEMM_FP8_FWD_IP(f8_e5m2, f8_e5m2, f16), + BRGEMM_FP8_FWD_IP(f8_e5m2, f8_e5m2, f32), + BRGEMM_FP8_FWD_IP(f8_e5m2, f8_e5m2, f8_e5m2), + BRGEMM_FP8_FWD_IP(f8_e5m2, f8_e5m2, f8_e4m3), + BRGEMM_FP8_FWD_IP(f8_e5m2, f8_e4m3, f16), + BRGEMM_FP8_FWD_IP(f8_e5m2, f8_e4m3, f32), + BRGEMM_FP8_FWD_IP(f8_e5m2, f8_e4m3, f8_e5m2), + BRGEMM_FP8_FWD_IP(f8_e5m2, f8_e4m3, f8_e4m3), + BRGEMM_FP8_FWD_IP(f8_e4m3, f8_e5m2, f16), + BRGEMM_FP8_FWD_IP(f8_e4m3, f8_e5m2, f32), + BRGEMM_FP8_FWD_IP(f8_e4m3, f8_e5m2, f8_e5m2), + BRGEMM_FP8_FWD_IP(f8_e4m3, f8_e5m2, f8_e4m3), + BRGEMM_FP8_FWD_IP(f8_e4m3, f8_e4m3, f16), + BRGEMM_FP8_FWD_IP(f8_e4m3, f8_e4m3, f32), + BRGEMM_FP8_FWD_IP(f8_e4m3, f8_e4m3, f8_e5m2), + BRGEMM_FP8_FWD_IP(f8_e4m3, f8_e4m3, f8_e4m3), + {{backward_data, f32, f32, f32}, REG_BWD_PK({ - CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t) // bf32 - CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_bwd_data_t) - CPU_INSTANCE(gemm_inner_product_bwd_data_t) + CPU_INSTANCE_X64(matmul_inner_product_bwd_data_t) + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t,avx512_core_amx) // bf32 + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_bwd_data_t,avx2) + CPU_INSTANCE(gemm_inner_product_bwd_data_t,f32) CPU_INSTANCE(ref_inner_product_bwd_data_t) nullptr, })}, {{backward_data, f32, bf16, bf16}, REG_BWD_PK({ - CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t) - CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_data_t) + CPU_INSTANCE_X64(matmul_inner_product_bwd_data_t) + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t,avx512_core_bf16) + CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_data_t,f32) CPU_INSTANCE(ref_inner_product_bwd_data_t) nullptr, })}, {{backward_data, bf16, bf16, bf16}, REG_BWD_PK({ - CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t) - CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_data_t) + CPU_INSTANCE_X64(matmul_inner_product_bwd_data_t) + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t,avx512_core_bf16) + CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_data_t,bf16) CPU_INSTANCE(ref_inner_product_bwd_data_t) nullptr, })}, {{backward_data, f32, f16, f16}, REG_BWD_PK({ - CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t) + CPU_INSTANCE_X64(matmul_inner_product_bwd_data_t) + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t,avx512_core_amx_fp16) + 
CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t,avx512_core_fp16) CPU_INSTANCE(ref_inner_product_bwd_data_t) nullptr, })}, {{backward_data, f16, f16, f16}, REG_BWD_PK({ - CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t) + CPU_INSTANCE_X64(matmul_inner_product_bwd_data_t) + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t,avx512_core_amx_fp16) + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t,avx512_core_fp16) CPU_INSTANCE(ref_inner_product_bwd_data_t) nullptr, })}, {{backward_weights, f32, f32, f32}, REG_BWD_PK({ - CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t) // bf32 - CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_bwd_weights_t) - CPU_INSTANCE(gemm_inner_product_bwd_weights_t) + CPU_INSTANCE_X64(matmul_inner_product_bwd_weights_t) + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t,avx512_core_amx) // bf32 + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_bwd_weights_t,avx2) + CPU_INSTANCE(gemm_inner_product_bwd_weights_t,f32) CPU_INSTANCE(ref_inner_product_bwd_weights_t) nullptr, })}, {{backward_weights, bf16, f32, bf16}, REG_BWD_PK({ - CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t) - CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_weights_t) + CPU_INSTANCE_X64(matmul_inner_product_bwd_weights_t) + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t,avx512_core_bf16) + CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_weights_t,f32) CPU_INSTANCE(ref_inner_product_bwd_weights_t) nullptr, })}, {{backward_weights, bf16, bf16, bf16}, REG_BWD_PK({ - CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t) - CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_weights_t) + CPU_INSTANCE_X64(matmul_inner_product_bwd_weights_t) + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t,avx512_core_bf16) + CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_weights_t,bf16) CPU_INSTANCE(ref_inner_product_bwd_weights_t) nullptr, })}, {{backward_weights, f16, f32, f16}, REG_BWD_PK({ - CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t) + CPU_INSTANCE_X64(matmul_inner_product_bwd_weights_t) + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t,avx512_core_amx_fp16) + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t,avx512_core_fp16) CPU_INSTANCE(ref_inner_product_bwd_weights_t) nullptr, })}, {{backward_weights, f16, f16, f16}, REG_BWD_PK({ - CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t) + CPU_INSTANCE_X64(matmul_inner_product_bwd_weights_t) + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t,avx512_core_amx_fp16) + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t,avx512_core_fp16) CPU_INSTANCE(ref_inner_product_bwd_weights_t) nullptr, })}, {{forward, s8, s8, f32}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) + 
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t) CPU_INSTANCE(ref_inner_product_int8_fwd_t) nullptr, }}, {{forward, s8, s8, s32}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t) CPU_INSTANCE(ref_inner_product_int8_fwd_t) nullptr, }}, {{forward, s8, s8, s8}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t) CPU_INSTANCE(ref_inner_product_int8_fwd_t) nullptr, }}, {{forward, s8, s8, u8}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t) CPU_INSTANCE(ref_inner_product_int8_fwd_t) nullptr, }}, {{forward, u8, s8, f32}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t) CPU_INSTANCE(ref_inner_product_int8_fwd_t) nullptr, }}, {{forward, u8, s8, s32}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + 
//CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t) CPU_INSTANCE(ref_inner_product_int8_fwd_t) nullptr, }}, {{forward, u8, s8, s8}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t) CPU_INSTANCE(ref_inner_product_int8_fwd_t) nullptr, }}, {{forward, u8, s8, u8}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni) CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t) CPU_INSTANCE(ref_inner_product_int8_fwd_t) nullptr, }}, {{forward, s8, s8, bf16}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) CPU_INSTANCE(ref_inner_product_int8_fwd_t) nullptr, }}, {{forward, u8, s8, bf16}, { - CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t) - CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t) + //CPU_INSTANCE_X64(matmul_inner_product_fwd_t) + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t,avx2_vnni_2) CPU_INSTANCE(ref_inner_product_int8_fwd_t) nullptr, }}, diff --git a/src/cpu/cpu_inner_product_pd.hpp b/src/cpu/cpu_inner_product_pd.hpp index 7554af4a81f..0d6742d1a3b 100644 --- a/src/cpu/cpu_inner_product_pd.hpp +++ b/src/cpu/cpu_inner_product_pd.hpp @@ -193,8 +193,8 @@ struct cpu_inner_product_fwd_pd_t : public inner_product_fwd_pd_t { /* with batch = 1, no transpose to use the faster gemv kernels */ /* otherwise, we transpose the weights to improve efficiency of * no-copy kernels */ - if (MB() > 1 && 
transpose_leading_dim(OC(), IC_total())) - transpose_md(weights_md_); +// if (MB() > 1 && transpose_leading_dim(OC(), IC_total())) +// transpose_md(weights_md_); return status::success; }; diff --git a/src/cpu/cpu_layer_normalization_list.cpp b/src/cpu/cpu_layer_normalization_list.cpp index 222233bf74f..d3b33b6c27c 100644 --- a/src/cpu/cpu_layer_normalization_list.cpp +++ b/src/cpu/cpu_layer_normalization_list.cpp @@ -25,8 +25,9 @@ using namespace dnnl::impl::cpu::x64; #endif -#if DNNL_AARCH64_USE_ACL -#include "cpu/aarch64/acl_layer_normalization.hpp" +#if DNNL_USE_ACL +#include "cpu/acl/acl_layer_normalization.hpp" +using namespace dnnl::impl::cpu::acl; #endif namespace dnnl { @@ -37,16 +38,12 @@ namespace { using namespace dnnl::impl::data_type; using namespace dnnl::impl::prop_kind; -#if DNNL_AARCH64_USE_ACL -using namespace dnnl::impl::cpu::aarch64; -#endif - // clang-format off const std::map> &impl_list_map() { static const std::map> the_map = REG_LNORM_P({ {{forward}, { CPU_INSTANCE_X64(jit_uni_layer_normalization_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_layer_normalization_fwd_t) + CPU_INSTANCE_ACL(acl_layer_normalization_fwd_t) CPU_INSTANCE(simple_layer_normalization_fwd_t) CPU_INSTANCE(ref_layer_normalization_fwd_t) nullptr, diff --git a/src/cpu/cpu_lrn_list.cpp b/src/cpu/cpu_lrn_list.cpp index 4f369af72b7..74b12dec11c 100644 --- a/src/cpu/cpu_lrn_list.cpp +++ b/src/cpu/cpu_lrn_list.cpp @@ -36,32 +36,25 @@ using namespace dnnl::impl::prop_kind; const std::map> &impl_list_map() { static const std::map> the_map = REG_LRN_P({ {{forward}, { - CPU_INSTANCE_X64(jit_avx512_common_lrn_fwd_t) - CPU_INSTANCE_X64(jit_avx512_common_lrn_fwd_t) - CPU_INSTANCE_X64(jit_avx512_common_lrn_fwd_t) - CPU_INSTANCE_X64(jit_uni_lrn_fwd_t) - CPU_INSTANCE_X64(jit_uni_lrn_fwd_t) - CPU_INSTANCE_X64(jit_uni_lrn_fwd_t) - CPU_INSTANCE_X64(jit_uni_lrn_fwd_t) - CPU_INSTANCE_X64(jit_uni_lrn_fwd_t) - CPU_INSTANCE_X64(jit_uni_lrn_fwd_t) - CPU_INSTANCE_X64(jit_uni_lrn_fwd_t) - CPU_INSTANCE(ref_lrn_fwd_t) - CPU_INSTANCE(ref_lrn_fwd_t) - CPU_INSTANCE(ref_lrn_fwd_t) + CPU_INSTANCE_X64(jit_avx512_common_lrn_fwd_t, f32) + CPU_INSTANCE_X64(jit_avx512_common_lrn_fwd_t, bf16) + CPU_INSTANCE_X64(jit_uni_lrn_fwd_t, avx512_core, f32) + CPU_INSTANCE_X64(jit_uni_lrn_fwd_t, avx512_core, bf16) + CPU_INSTANCE_X64(jit_uni_lrn_fwd_t, avx2_vnni_2, bf16) + CPU_INSTANCE_X64(jit_uni_lrn_fwd_t, avx2, f32) + CPU_INSTANCE_X64(jit_uni_lrn_fwd_t, sse41, f32) + CPU_INSTANCE(ref_lrn_fwd_t, f32) + CPU_INSTANCE(ref_lrn_fwd_t, bf16) nullptr, }}, {{backward}, REG_BWD_PK({ - CPU_INSTANCE_X64(jit_avx512_common_lrn_bwd_t) - CPU_INSTANCE_X64(jit_avx512_common_lrn_bwd_t) - CPU_INSTANCE_X64(jit_avx512_common_lrn_bwd_t) - CPU_INSTANCE_X64(jit_uni_lrn_bwd_t) - CPU_INSTANCE_X64(jit_uni_lrn_bwd_t) - CPU_INSTANCE_X64(jit_uni_lrn_bwd_t) - CPU_INSTANCE_X64(jit_uni_lrn_bwd_t) - CPU_INSTANCE(ref_lrn_bwd_t) - CPU_INSTANCE(ref_lrn_bwd_t) - CPU_INSTANCE(ref_lrn_bwd_t) + CPU_INSTANCE_X64(jit_avx512_common_lrn_bwd_t, f32) + CPU_INSTANCE_X64(jit_avx512_common_lrn_bwd_t, bf16) + CPU_INSTANCE_X64(jit_uni_lrn_bwd_t, avx512_core, f32) + CPU_INSTANCE_X64(jit_uni_lrn_bwd_t, avx512_core, bf16) + CPU_INSTANCE_X64(jit_uni_lrn_bwd_t, avx2, f32) + CPU_INSTANCE(ref_lrn_bwd_t, f32) + CPU_INSTANCE(ref_lrn_bwd_t, bf16) nullptr, })}, }); diff --git a/src/cpu/cpu_pooling_list.cpp b/src/cpu/cpu_pooling_list.cpp index 951395c44bc..20e060c7e3c 100644 --- a/src/cpu/cpu_pooling_list.cpp +++ b/src/cpu/cpu_pooling_list.cpp @@ -30,15 +30,16 @@ using namespace dnnl::impl::cpu::x64; #include 
"cpu/aarch64/jit_uni_i8i8_pooling.hpp" #include "cpu/aarch64/jit_uni_pooling.hpp" using namespace dnnl::impl::cpu::aarch64; -#if DNNL_AARCH64_USE_ACL -#include "cpu/aarch64/acl_pooling.hpp" -#endif // DNNL_AARCH64_USE_ACL #elif DNNL_RV64 #if DNNL_RISCV_USE_RVV_INTRINSICS #include "cpu/rv64/rvv_nchw_pooling.hpp" using namespace dnnl::impl::cpu::rv64; #endif // DNNL_RISCV_USE_RVV_INTRINSICS #endif +#if DNNL_USE_ACL +#include "cpu/acl/acl_pooling.hpp" +using namespace dnnl::impl::cpu::acl; +#endif namespace dnnl { namespace impl { @@ -53,60 +54,62 @@ const std::map> &impl_list_map() { static const std::map> the_map = REG_POOLING_P({ {{forward}, { /* fp */ - CPU_INSTANCE_X64(jit_uni_pooling_fwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_fwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_fwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_fwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_fwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_fwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_fwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_fwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_fwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_pooling_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_pooling_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_pooling_fwd_t) + CPU_INSTANCE_X64(jit_uni_pooling_fwd_t, avx512_core_fp16, f16) + CPU_INSTANCE_X64(jit_uni_pooling_fwd_t, avx512_core_fp16, f8_e5m2) + CPU_INSTANCE_X64(jit_uni_pooling_fwd_t, avx512_core_fp16, f8_e4m3) + CPU_INSTANCE_X64(jit_uni_pooling_fwd_t, avx512_core, bf16) + CPU_INSTANCE_X64(jit_uni_pooling_fwd_t, avx512_core, f32) + CPU_INSTANCE_X64(jit_uni_pooling_fwd_t, avx2_vnni_2, bf16) + CPU_INSTANCE_X64(jit_uni_pooling_fwd_t, avx2_vnni_2, f16) + CPU_INSTANCE_X64(jit_uni_pooling_fwd_t, avx2, f32) + CPU_INSTANCE_X64(jit_uni_pooling_fwd_t, avx, f32) + CPU_INSTANCE_X64(jit_uni_pooling_fwd_t, sse41, f32) + CPU_INSTANCE_AARCH64(jit_uni_pooling_fwd_t, sve_512, f32) + CPU_INSTANCE_AARCH64(jit_uni_pooling_fwd_t, sve_256, f32) + CPU_INSTANCE_ACL(acl_pooling_fwd_t) CPU_INSTANCE_RV64GCV(riscv_nchw_pooling_fwd_t) - CPU_INSTANCE(nchw_pooling_fwd_t) - CPU_INSTANCE(nchw_pooling_fwd_t) - CPU_INSTANCE(nchw_pooling_fwd_t) - CPU_INSTANCE(nchw_pooling_fwd_t) - CPU_INSTANCE(nchw_pooling_fwd_t) - CPU_INSTANCE(nhwc_pooling_fwd_t) - CPU_INSTANCE(nhwc_pooling_fwd_t) - CPU_INSTANCE(nhwc_pooling_fwd_t) - CPU_INSTANCE(nhwc_pooling_fwd_t) - CPU_INSTANCE(nhwc_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(nchw_pooling_fwd_t, bf16) + CPU_INSTANCE(nchw_pooling_fwd_t, f32) + CPU_INSTANCE(nchw_pooling_fwd_t, f16) + CPU_INSTANCE(nchw_pooling_fwd_t, f8_e5m2) + CPU_INSTANCE(nchw_pooling_fwd_t, f8_e4m3) + CPU_INSTANCE(nhwc_pooling_fwd_t, bf16) + CPU_INSTANCE(nhwc_pooling_fwd_t, f32) + CPU_INSTANCE(nhwc_pooling_fwd_t, f16) + CPU_INSTANCE(nhwc_pooling_fwd_t, f8_e5m2) + CPU_INSTANCE(nhwc_pooling_fwd_t, f8_e4m3) + CPU_INSTANCE(ref_pooling_fwd_t, f32, f32, f32) + CPU_INSTANCE(ref_pooling_fwd_t, bf16, bf16, f32) + CPU_INSTANCE(ref_pooling_fwd_t, f16, f16, f32) + CPU_INSTANCE(ref_pooling_fwd_t, f8_e5m2, f8_e5m2, f32) + CPU_INSTANCE(ref_pooling_fwd_t, f8_e4m3, f8_e4m3, f32) /* int */ - CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t) - CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t) - CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_i8i8_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) + 
CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t, avx512_core) + CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t, avx2) + CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t, sse41) + CPU_INSTANCE_AARCH64(jit_uni_i8i8_pooling_fwd_t, sve_512) + CPU_INSTANCE(ref_pooling_fwd_t, s32, s32, s32) + CPU_INSTANCE(ref_pooling_fwd_t, s8, s8, s32) + CPU_INSTANCE(ref_pooling_fwd_t, s8, f32, f32) + CPU_INSTANCE(ref_pooling_fwd_t, u8, u8, s32) + CPU_INSTANCE(ref_pooling_fwd_t, u8, f32, f32) nullptr, }}, {{backward}, REG_BWD_PK({ - CPU_INSTANCE_X64(jit_uni_pooling_bwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_bwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_bwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_bwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_bwd_t) - CPU_INSTANCE_X64(jit_uni_pooling_bwd_t) - CPU_INSTANCE_AARCH64(jit_uni_pooling_bwd_t) - CPU_INSTANCE_AARCH64(jit_uni_pooling_bwd_t) - CPU_INSTANCE(nchw_pooling_bwd_t) - CPU_INSTANCE(nchw_pooling_bwd_t) - CPU_INSTANCE(nchw_pooling_bwd_t) - CPU_INSTANCE(nhwc_pooling_bwd_t) - CPU_INSTANCE(nhwc_pooling_bwd_t) - CPU_INSTANCE(nhwc_pooling_bwd_t) + CPU_INSTANCE_X64(jit_uni_pooling_bwd_t, avx512_core_fp16, f16) + CPU_INSTANCE_X64(jit_uni_pooling_bwd_t, avx512_core, bf16) + CPU_INSTANCE_X64(jit_uni_pooling_bwd_t, avx512_core, f32) + CPU_INSTANCE_X64(jit_uni_pooling_bwd_t, avx2, f32) + CPU_INSTANCE_X64(jit_uni_pooling_bwd_t, avx, f32) + CPU_INSTANCE_X64(jit_uni_pooling_bwd_t, sse41, f32) + CPU_INSTANCE_AARCH64(jit_uni_pooling_bwd_t, sve_512, f32) + CPU_INSTANCE_AARCH64(jit_uni_pooling_bwd_t, sve_256, f32) + CPU_INSTANCE(nchw_pooling_bwd_t, bf16) + CPU_INSTANCE(nchw_pooling_bwd_t, f32) + CPU_INSTANCE(nchw_pooling_bwd_t, f16) + CPU_INSTANCE(nhwc_pooling_bwd_t, bf16) + CPU_INSTANCE(nhwc_pooling_bwd_t, f32) + CPU_INSTANCE(nhwc_pooling_bwd_t, f16) CPU_INSTANCE(ref_pooling_bwd_t) nullptr, })}, diff --git a/src/cpu/cpu_prelu_list.cpp b/src/cpu/cpu_prelu_list.cpp index 883c356b18e..c7ff78c3424 100644 --- a/src/cpu/cpu_prelu_list.cpp +++ b/src/cpu/cpu_prelu_list.cpp @@ -23,9 +23,9 @@ #include "cpu/x64/prelu/jit_prelu_forward.hpp" using namespace dnnl::impl::cpu::x64; -#elif DNNL_AARCH64 && DNNL_AARCH64_USE_ACL -#include "cpu/aarch64/acl_prelu.hpp" -using namespace dnnl::impl::cpu::aarch64; +#elif DNNL_USE_ACL +#include "cpu/acl/acl_prelu.hpp" +using namespace dnnl::impl::cpu::acl; #endif namespace dnnl { @@ -41,7 +41,7 @@ const std::map> &impl_list_map() { static const std::map> the_map = REG_PRELU_P({ {{forward}, { CPU_INSTANCE_X64(jit_prelu_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_prelu_fwd_t) + CPU_INSTANCE_ACL(acl_prelu_fwd_t) CPU_INSTANCE(ref_prelu_fwd_t) nullptr, }}, diff --git a/src/cpu/cpu_primitive.hpp b/src/cpu/cpu_primitive.hpp index ff315b8bedf..ff531fd2705 100644 --- a/src/cpu/cpu_primitive.hpp +++ b/src/cpu/cpu_primitive.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
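A note on the registration pattern running through these lists: each CPU_INSTANCE_* entry now spells out its ISA and data-type template arguments explicitly, and each map value is an ordered candidate list keyed by the data-type combination. The following is a minimal, self-contained sketch of how such a typed table could be consulted at dispatch time; the names key_t, impl_item, and try_create are illustrative stand-ins, not the library's actual impl_list_item_t machinery.

#include <map>
#include <tuple>
#include <vector>

enum class prop { forward, backward };
enum class dt { f32, bf16, f16, s8, u8 };

// One candidate implementation; try_create() stands in for pd creation,
// which fails when the shapes, ISA, or attributes are unsupported.
struct impl_item {
    bool (*try_create)();
};

// The key mirrors the {prop_kind, src, wei, dst} tuples used in the tables above.
using key_t = std::tuple<prop, dt, dt, dt>;

bool dispatch(const std::map<key_t, std::vector<impl_item>> &table,
        const key_t &key) {
    const auto it = table.find(key);
    if (it == table.end()) return false; // combination not registered
    // Entries are ordered from most to least specialized; the first one that
    // creates successfully wins, with the reference impl as the fallback.
    for (const auto &impl : it->second)
        if (impl.try_create && impl.try_create()) return true;
    return false;
}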
@@ -29,59 +29,33 @@ #include "cpu/ref_io_helper.hpp" -#define DEFINE_SCALES_BUFFER_ATTR_ARG(attr, scales, arg) \ - alignas(16) float CONCAT2(scales, _buf16)[16] = {0}; \ - const float *scales {nullptr}; \ - if ((attr)) { \ - if ((attr)->output_scales_.has_default_values()) { \ - utils::array_set(CONCAT2(scales, _buf16), 1.0f, 16); \ - scales = CONCAT2(scales, _buf16); \ - } else { \ - scales = CTX_IN_MEM(const float *, arg); \ - VCHECK_ATTR(scales != nullptr, \ - "Scales buffer for arg %d is missing", arg); \ - const auto scales_d = ctx.memory_mdw(arg); \ - VCHECK_ATTR(scales_d.data_type() == data_type::f32, \ - "Scales data type is not f32"); \ - VCHECK_ATTR(scales_d.ndims() == 1, "Scales ndims is not 1"); \ - if (scales_d.dims()[0] == 1) { \ - utils::array_set(CONCAT2(scales, _buf16), scales[0], 16); \ - scales = CONCAT2(scales, _buf16); \ - } \ - } \ - } \ - MAYBE_UNUSED(scales); - -#define DEFINE_SCALES_BUFFER_ATTR(attr, scales) \ - DEFINE_SCALES_BUFFER_ATTR_ARG(attr, scales, DNNL_ARG_ATTR_OUTPUT_SCALES); - -#define DEFINE_SCALES_BUFFER(scales) \ - DEFINE_SCALES_BUFFER_ATTR(pd()->attr(), scales) - +//NOLINTBEGIN(bugprone-macro-parentheses) +// These macros are actual pieces of code, can't put certain pieces into `()`. +// TODO: consider making them functions. #define DEFINE_ARG_SCALES_BUFFER_ATTR(attr, scales, arg) \ alignas(16) float CONCAT2(scales, _buf16)[16] = {0}; \ const float *scales {nullptr}; \ if ((attr)) { \ - if ((attr)->scales_.get(arg).has_default_values()) { \ + if ((attr)->scales_.has_default_values(arg)) { \ utils::array_set(CONCAT2(scales, _buf16), 1.0f, 16); \ scales = CONCAT2(scales, _buf16); \ } else { \ - scales = CTX_IN_MEM(const float *, DNNL_ARG_ATTR_SCALES | arg); \ + scales = CTX_IN_MEM(const float *, DNNL_ARG_ATTR_SCALES | (arg)); \ VCHECK_ATTR(scales != nullptr, \ - "Scales buffer for arg %d is missing", arg); \ - const auto scales_d = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | arg); \ + "Scales buffer for arg %d is missing", (arg)); \ + const auto scales_d \ + = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | (arg)); \ VCHECK_ATTR( \ - utils::one_of(scales_d.data_type(), data_type::f32, \ - data_type::f16, data_type::bf16, data_type::e8m0), \ + utils::one_of(scales_d.data_type(), data_type::f32, data_type::e8m0) \ + && (scales_d.ndims() == 1 || scales_d.ndims() == 2), \ "Unsupported scales data type"); \ - if (scales_d.nelems() == 1) { \ - const float s = cpu::io::load_float_value( \ - scales_d.data_type(), scales, 0); \ - if (utils::one_of(arg, DNNL_ARG_DST, \ + if (scales_d.dims()[0] == 1) { \ + if (utils::one_of((arg), DNNL_ARG_DST, \ DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST)) { \ - utils::array_set(CONCAT2(scales, _buf16), 1.f / s, 16); \ + utils::array_set( \ + CONCAT2(scales, _buf16), 1.f / scales[0], 16); \ } else { \ - utils::array_set(CONCAT2(scales, _buf16), s, 16); \ + utils::array_set(CONCAT2(scales, _buf16), scales[0], 16); \ } \ scales = CONCAT2(scales, _buf16); \ } \ @@ -90,24 +64,83 @@ MAYBE_UNUSED(scales); #define DEFINE_ARG_SCALES_BUFFER(scales, arg) \ - DEFINE_ARG_SCALES_BUFFER_ATTR(pd()->attr(), scales, arg) - -#define DEFINE_ZERO_POINTS_BUFFER(zero_points_ptr, mem_arg) \ - int32_t CONCAT2(default_zero_point_, mem_arg) = 0; \ - const int32_t *zero_points_ptr \ - = pd()->attr()->zero_points_.defined(mem_arg) \ - ? 
&CONCAT2(default_zero_point_, mem_arg) \ - : CTX_IN_MEM( \ - const int32_t *, DNNL_ARG_ATTR_ZERO_POINTS | mem_arg); \ - VCHECK_ATTR(zero_points_ptr != nullptr, \ - "Zero points buffer for arg %d is missing", mem_arg); \ + DEFINE_ARG_SCALES_BUFFER_ATTR(pd()->attr(), scales, (arg)) + +#define DEFINE_ZERO_POINTS_BUFFER_ATTR_U8(attr, zero_points_ptr, arg) \ + uint8_t CONCAT2(default_zero_point_, arg) = 0; \ + const uint8_t *zero_points_ptr {nullptr}; \ + if ((attr)) { \ + if ((attr)->zero_points_.has_default_values(arg)) { \ + zero_points_ptr = &CONCAT2(default_zero_point_, arg); \ + } else { \ + /* CAVEAT: type should be void to force proper loads of zero-points. + * Accessing `zero_points_ptr` by index will lead to a crash for + * datatypes different from s32. */ \ + zero_points_ptr = CTX_IN_MEM( \ + const uint8_t *, DNNL_ARG_ATTR_ZERO_POINTS | (arg)); \ + VCHECK_ATTR(zero_points_ptr != nullptr, \ + "Zero points buffer for arg %d is missing", (arg)); \ + const auto zero_points_d \ + = ctx.memory_mdw(DNNL_ARG_ATTR_ZERO_POINTS | (arg)); \ + VCHECK_ATTR(utils::one_of(zero_points_d.data_type(), \ + data_type::s32, data_type::s8, data_type::u8, \ + data_type::s4, data_type::u4, data_type::f32), \ + VERBOSE_INVALID_DATATYPE, "zero points"); \ + } \ + } \ + MAYBE_UNUSED(zero_points_ptr); + +#define DEFINE_ZERO_POINTS_BUFFER_ATTR(attr, zero_points_ptr, arg) \ + int32_t CONCAT2(default_zero_point_, arg) = 0; \ + const int32_t *zero_points_ptr {nullptr}; \ + if ((attr)) { \ + if ((attr)->zero_points_.has_default_values(arg)) { \ + zero_points_ptr = &CONCAT2(default_zero_point_, arg); \ + } else { \ + /* CAVEAT: type should be void to force proper loads of zero-points. + * Accessing `zero_points_ptr` by index will lead to a crash for + * datatypes different from s32. 
*/ \ + zero_points_ptr = CTX_IN_MEM( \ + const int32_t *, DNNL_ARG_ATTR_ZERO_POINTS | (arg)); \ + VCHECK_ATTR(zero_points_ptr != nullptr, \ + "Zero points buffer for arg %d is missing", (arg)); \ + const auto zero_points_d \ + = ctx.memory_mdw(DNNL_ARG_ATTR_ZERO_POINTS | (arg)); \ + VCHECK_ATTR(utils::one_of(zero_points_d.data_type(), \ + data_type::s32, data_type::s8, data_type::u8, \ + data_type::s4, data_type::u4), \ + VERBOSE_INVALID_DATATYPE, "zero points"); \ + } \ + } \ MAYBE_UNUSED(zero_points_ptr); +#define DEFINE_ZERO_POINTS_BUFFER(zero_points_ptr, arg) \ + DEFINE_ZERO_POINTS_BUFFER_ATTR(pd()->attr(), zero_points_ptr, arg) + #define ASSIGN_ARG_SCALE_VALUE(scale, mem_arg) \ alignas(16) float CONCAT2(CONCAT2(scales, _buf16), mem_arg)[16] = {0}; \ - if (pd()->attr()->scales_.get(mem_arg).has_default_values()) { \ + if (pd()->attr()->scales_.has_default_values(mem_arg)) { \ utils::array_set(CONCAT2(CONCAT2(scales, _buf16), mem_arg), 1.0f, 16); \ scale = CONCAT2(CONCAT2(scales, _buf16), mem_arg); \ + } + +#define DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zero_points_ptr, jcp) \ + const uint8_t *input_zero_points_ptr = nullptr; \ + if (jcp.with_input_zp) { \ + input_zero_points_ptr = CTX_IN_MEM(const uint8_t *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC); \ + if (input_zero_points_ptr == nullptr) return status::invalid_arguments; \ + } + +#define DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation_ptr, jcp) \ + const int32_t *output_compensation_ptr = nullptr; \ + if (jcp.with_input_zp) { \ + output_compensation_ptr = CTX_IN_MEM(const int32_t *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST); \ + if (output_compensation_ptr == nullptr) return status::invalid_arguments; \ + } + +#define ASSIGN_INPUT_SCALE_VALUE(scale, mem_arg) \ + if (pd()->attr()->scales_.get(mem_arg).defined()) { \ + scale = pd()->attr()->scales_.get(mem_arg).scales_; \ } else { \ const auto scale_d = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | mem_arg); \ VCHECK_ATTR(scale_d.data_type() == data_type::f32, \ @@ -145,4 +178,6 @@ #define DEFINE_ZERO_POINT_VALUE(zero_point, mem_arg) \ DEFINE_ZERO_POINT_VALUE_ATTR(pd()->attr(), zero_point, mem_arg) +//NOLINTEND(bugprone-macro-parentheses) + #endif // CPU_CPU_PRIMITIVE_HPP diff --git a/src/cpu/cpu_reduction_list.cpp b/src/cpu/cpu_reduction_list.cpp index 6dde9e1d93f..86465bafc1d 100644 --- a/src/cpu/cpu_reduction_list.cpp +++ b/src/cpu/cpu_reduction_list.cpp @@ -31,20 +31,18 @@ namespace { using namespace dnnl::impl::data_type; // clang-format off -constexpr impl_list_item_t impl_list[] = REG_REDUCTION_P({ +const impl_list_item_t impl_list[] = REG_REDUCTION_P({ CPU_INSTANCE_X64(jit_uni_reduction_t) - CPU_INSTANCE(ref_reduction_t) - CPU_INSTANCE(ref_reduction_t) - CPU_INSTANCE(ref_reduction_t) - CPU_INSTANCE(ref_reduction_t) - CPU_INSTANCE(ref_reduction_t) - CPU_INSTANCE(ref_reduction_t) - CPU_INSTANCE(ref_reduction_t) - CPU_INSTANCE(ref_reduction_t) - CPU_INSTANCE(ref_reduction_t) - CPU_INSTANCE(ref_reduction_t) - CPU_INSTANCE(ref_reduction_t) + CPU_INSTANCE(ref_reduction_t, f32, f32, f32) + CPU_INSTANCE(ref_reduction_t, bf16, bf16, f32) + CPU_INSTANCE(ref_reduction_t, bf16, f32, f32) + CPU_INSTANCE(ref_reduction_t, s8, s8, s32) + CPU_INSTANCE(ref_reduction_t, s8, s32, s32) + CPU_INSTANCE(ref_reduction_t, s8, f32, s32) + CPU_INSTANCE(ref_reduction_t, u8, u8, s32) + CPU_INSTANCE(ref_reduction_t, u8, s32, s32) + CPU_INSTANCE(ref_reduction_t, u8, f32, s32) /* eol */ nullptr, }); diff --git a/src/cpu/cpu_shuffle_list.cpp b/src/cpu/cpu_shuffle_list.cpp index e81a19e89c5..cb4681415ba 100644 
--- a/src/cpu/cpu_shuffle_list.cpp +++ b/src/cpu/cpu_shuffle_list.cpp @@ -36,14 +36,14 @@ namespace { using namespace dnnl::impl::data_type; // clang-format off -constexpr impl_list_item_t impl_list[] = REG_SHUFFLE_P({ - CPU_INSTANCE_X64(jit_uni_shuffle_t) - CPU_INSTANCE_X64(jit_uni_shuffle_t) - CPU_INSTANCE_X64(jit_uni_shuffle_t) - CPU_INSTANCE_AARCH64(jit_uni_shuffle_t) - CPU_INSTANCE_AARCH64(jit_uni_shuffle_t) - CPU_INSTANCE_AARCH64(jit_uni_shuffle_t) - CPU_INSTANCE_AARCH64(jit_uni_shuffle_t) +const impl_list_item_t impl_list[] = REG_SHUFFLE_P({ + CPU_INSTANCE_X64(jit_uni_shuffle_t, avx512_core) + CPU_INSTANCE_X64(jit_uni_shuffle_t, avx) + CPU_INSTANCE_X64(jit_uni_shuffle_t, sse41) + CPU_INSTANCE_AARCH64(jit_uni_shuffle_t, sve_512) + CPU_INSTANCE_AARCH64(jit_uni_shuffle_t, sve_256) + CPU_INSTANCE_AARCH64(jit_uni_shuffle_t, sve_128) + CPU_INSTANCE_AARCH64(jit_uni_shuffle_t, asimd) CPU_INSTANCE(ref_shuffle_t) /* eol */ nullptr, diff --git a/src/cpu/cpu_softmax_list.cpp b/src/cpu/cpu_softmax_list.cpp index 5168f0708a1..20017f388a7 100644 --- a/src/cpu/cpu_softmax_list.cpp +++ b/src/cpu/cpu_softmax_list.cpp @@ -22,14 +22,16 @@ #if DNNL_X64 #include "cpu/x64/jit_uni_softmax.hpp" +#include "cpu/x64/jit_uni_fork_softmax.hpp" using namespace dnnl::impl::cpu::x64; #elif DNNL_AARCH64 #include "cpu/aarch64/jit_uni_softmax.hpp" -#if DNNL_AARCH64_USE_ACL -#include "cpu/aarch64/acl_softmax.hpp" -#endif using namespace dnnl::impl::cpu::aarch64; #endif +#if DNNL_USE_ACL +#include "cpu/acl/acl_softmax.hpp" +using namespace dnnl::impl::cpu::acl; +#endif namespace dnnl { namespace impl { @@ -44,18 +46,21 @@ const std::map> &impl_list_map() { static std::map> the_map = REG_SOFTMAX_P({ {{forward}, { CPU_INSTANCE_X64(jit_uni_softmax_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_softmax_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_softmax_fwd_t) - CPU_INSTANCE_AARCH64(jit_uni_softmax_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_softmax_fwd_t) + CPU_INSTANCE_X64(jit_uni_fork_softmax_fwd_t, avx512_core) + CPU_INSTANCE_X64(jit_uni_fork_softmax_fwd_t, avx2) + CPU_INSTANCE_X64(jit_uni_fork_softmax_fwd_t, sse41) + CPU_INSTANCE_AARCH64(jit_uni_softmax_fwd_t, sve_512) + CPU_INSTANCE_AARCH64(jit_uni_softmax_fwd_t, sve_256) + CPU_INSTANCE_AARCH64(jit_uni_softmax_fwd_t, sve_128) + CPU_INSTANCE_ACL(acl_softmax_fwd_t) CPU_INSTANCE(ref_softmax_fwd_t) nullptr, }}, {{backward}, REG_BWD_PK({ CPU_INSTANCE_X64(jit_uni_softmax_bwd_t) - CPU_INSTANCE_AARCH64(jit_uni_softmax_bwd_t) - CPU_INSTANCE_AARCH64(jit_uni_softmax_bwd_t) - CPU_INSTANCE_AARCH64(jit_uni_softmax_bwd_t) + CPU_INSTANCE_AARCH64(jit_uni_softmax_bwd_t, sve_512) + CPU_INSTANCE_AARCH64(jit_uni_softmax_bwd_t, sve_256) + CPU_INSTANCE_AARCH64(jit_uni_softmax_bwd_t, sve_128) CPU_INSTANCE(ref_softmax_bwd_t) nullptr, })}, diff --git a/src/cpu/cpu_stream.hpp b/src/cpu/cpu_stream.hpp index 30d5a6e058b..7bf2cac3a44 100644 --- a/src/cpu/cpu_stream.hpp +++ b/src/cpu/cpu_stream.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
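The INSTANCE macro in the cpu_sum.cpp hunk above wraps each concrete pd type in sum_type_deduction_helper_t so that differently parameterized kernels can live in a single flat array. A rough analogue of that type-erasure idiom, with every name (jit_sum, item, make_item) invented for illustration rather than taken from the library:

#include <vector>

// Stand-in for a concrete, template-parameterized kernel pd.
template <int acc_dt>
struct jit_sum {
    static bool create() { return acc_dt != 0; } // e.g. a mayiuse(...) check
};

// Type-erased list entry: only a uniform function pointer survives.
struct item {
    bool (*create)();
};

template <typename Impl>
item make_item() {
    return item {&Impl::create};
}

// Analogue of SUM_INSTANCE_AVX512(jit_xf16_sum_t) / INSTANCE(simple_sum_t):
const std::vector<item> sum_impl_list = {
    make_item<jit_sum<1>>(),
    make_item<jit_sum<2>>(),
    make_item<jit_sum<0>>(), // reference fallback
};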
@@ -34,7 +34,7 @@ namespace cpu { struct cpu_stream_t : public stream_t { cpu_stream_t(engine_t *engine, impl::stream_impl_t *stream_impl) : stream_t(engine, stream_impl) {} - virtual ~cpu_stream_t() = default; + ~cpu_stream_t() override = default; dnnl::impl::status_t wait() override { // CPU execution is synchronous so return immediately diff --git a/src/cpu/cpu_sum.cpp b/src/cpu/cpu_sum.cpp index 3f1ff8911e9..24095785fcc 100644 --- a/src/cpu/cpu_sum.cpp +++ b/src/cpu/cpu_sum.cpp @@ -32,21 +32,18 @@ namespace cpu { namespace { using namespace dnnl::impl::data_type; + #define INSTANCE(...) \ impl_list_item_t(impl_list_item_t::sum_type_deduction_helper_t< \ __VA_ARGS__::pd_t>()), #define SUM_INSTANCE_AVX512(...) REG_AVX512_ISA(INSTANCE(__VA_ARGS__)) #define SUM_INSTANCE_AVX2(...) REG_AVX2_ISA(INSTANCE(__VA_ARGS__)) // clang-format off -constexpr impl_list_item_t cpu_sum_impl_list[] = REG_SUM_P({ +const impl_list_item_t cpu_sum_impl_list[] = REG_SUM_P({ SUM_INSTANCE_AVX512(jit_xf16_sum_t) SUM_INSTANCE_AVX512(jit_xf16_sum_t) SUM_INSTANCE_AVX2(jit_xf16_sum_t) SUM_INSTANCE_AVX2(jit_xf16_sum_t) - SUM_INSTANCE_AVX2(jit_xf16_sum_t) - SUM_INSTANCE_AVX2(jit_xf16_sum_t) - INSTANCE(simple_sum_t) - INSTANCE(simple_sum_t) INSTANCE(simple_sum_t) INSTANCE(simple_sum_t) INSTANCE(simple_sum_t) diff --git a/src/cpu/dw_convolution_utils.hpp b/src/cpu/dw_convolution_utils.hpp index 088e01b9964..f10a13334b1 100644 --- a/src/cpu/dw_convolution_utils.hpp +++ b/src/cpu/dw_convolution_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2023 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,27 +39,27 @@ inline status_t get_depthwise_conv_desc(convolution_desc_t &cd_dw, || !attr_1x1.post_ops_.entry_[dw_po_index].is_convolution()) return status::invalid_arguments; + // TODO: [AV] remove this check once the original oneDNN dw conv fusion is used + if (attr_1x1.post_ops_.entry_[dw_po_index].is_convolution()) + return status::unimplemented; + // Create new attributes with scales from depthwise post-op and copy // post-ops after depthwise post-op.
auto &dw_po = attr_1x1.post_ops_.entry_[dw_po_index].depthwise_conv; - // erase 1x1 conv scales - for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { - auto &scale = attr_dw.scales_.get(arg); - if (!scale.has_default_values()) attr_dw.scales_.reset(arg); - } - const auto &dw_src_scales = attr_1x1.scales_.get(DNNL_ARG_DST); const auto &dw_wei_scales = attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); const auto &dw_dst_scales = attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST); + + assert(attr_dw.scales_.has_default_values()); if (!dw_src_scales.has_default_values()) - attr_dw.scales_.set(DNNL_ARG_SRC, dw_src_scales.mask_); + CHECK(attr_dw.scales_.set(DNNL_ARG_SRC, dw_src_scales.get_mask())); if (!dw_wei_scales.has_default_values()) - attr_dw.scales_.set(DNNL_ARG_WEIGHTS, dw_wei_scales.mask_); + CHECK(attr_dw.scales_.set(DNNL_ARG_WEIGHTS, dw_wei_scales.get_mask())); if (!dw_dst_scales.has_default_values()) - attr_dw.scales_.set(DNNL_ARG_DST, dw_dst_scales.mask_); + CHECK(attr_dw.scales_.set(DNNL_ARG_DST, dw_dst_scales.get_mask())); auto dw_po_len = attr_1x1.post_ops_.len() - (dw_po_index + 1); attr_dw.post_ops_.entry_.resize(dw_po_len); diff --git a/src/cpu/gemm/f32/ref_gemm_f32.cpp b/src/cpu/gemm/f32/ref_gemm_f32.cpp index e7d69f01727..944df461e3c 100644 --- a/src/cpu/gemm/f32/ref_gemm_f32.cpp +++ b/src/cpu/gemm/f32/ref_gemm_f32.cpp @@ -38,7 +38,10 @@ template <typename data_t> void copy_A( bool isTransA, dim_t K, const data_t *A, const dim_t lda, data_t *ws) { for (dim_t k = 0; k < K; k++) { +#if !defined(_MSC_VER) + // Compiling with '#pragma omp simd' in this place on VS2019 leads to fatal error C1001 PRAGMA_OMP_SIMD() +#endif for (dim_t i = 0; i < unroll_factor<data_t>::m; i++) { ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda]; } diff --git a/src/cpu/gemm_convolution.cpp b/src/cpu/gemm_convolution.cpp index 672997f6171..80edde22afa 100644 --- a/src/cpu/gemm_convolution.cpp +++ b/src/cpu/gemm_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2023 Intel Corporation +* Copyright 2016-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
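The dw_convolution_utils.hpp hunk above replaces bare attr_dw.scales_.set(...) calls with CHECK(...), the early-return convention oneDNN uses for status_t results. A minimal sketch of that pattern under an assumed macro name (the real macro is defined in the common headers):

enum class status { success, invalid_arguments };

// Assumed name; mirrors the early-return behavior of the real CHECK macro.
#define CHECK_SKETCH(expr) \
    do { \
        status s_ = (expr); \
        if (s_ != status::success) return s_; \
    } while (0)

status set_mask(int mask) {
    return mask >= 0 ? status::success : status::invalid_arguments;
}

status propagate_masks() {
    CHECK_SKETCH(set_mask(0)); // bail out on the first failing setter
    CHECK_SKETCH(set_mask(2));
    return status::success;
}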
@@ -23,6 +23,9 @@ #include "common/type_helpers.hpp" #include "common/utils.hpp" #include "cpu/gemm_convolution.hpp" +#if DNNL_X64 +#include "cpu/x64/injectors/jit_uni_postops_injector.hpp" +#endif namespace dnnl { namespace impl { @@ -51,13 +54,20 @@ status_t gemm_convolution_fwd_t::execute_forward_nspc( auto bia_base = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS); auto dst_base = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); +#if DNNL_X64 + const auto post_ops_binary_rhs_arg_vec + = x64::binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); +#else + const auto post_ops_binary_rhs_arg_vec = std::vector(); +#endif + auto scratchpad = ctx.get_scratchpad_grantor(); const conv_gemm_conf_t &jcp = pd()->jcp_; std::atomic st(status::success); parallel(jcp.nthr, [&](const int ithr, const int nthr) { status_t st_thr = execute_forward_thr_nspc(ctx, ithr, nthr, src_base, - wei_base, bia_base, dst_base, scratchpad); + wei_base, bia_base, dst_base, scratchpad, post_ops_binary_rhs_arg_vec); if (st_thr != status::success) st = st_thr; }); @@ -67,7 +77,7 @@ status_t gemm_convolution_fwd_t::execute_forward_nspc( status_t gemm_convolution_fwd_t::execute_forward_thr_nspc(const exec_ctx_t &ctx, const int ithr, const int nthr, const data_t *src_base, const data_t *wei_base, const data_t *bia_base, data_t *dst_base, - const memory_tracking::grantor_t &scratchpad) const { + const memory_tracking::grantor_t &scratchpad, const std::vector& post_ops_binary_rhs_arg_vec) const { const conv_gemm_conf_t &jcp = pd()->jcp_; // Src Format: mb-spatial-groups-input_channels @@ -151,68 +161,16 @@ status_t gemm_convolution_fwd_t::execute_forward_thr_nspc(const exec_ctx_t &ctx, &LDC); if (st != status::success) return st; - if (jcp.with_bias || jcp.with_eltwise || jcp.with_binary) { - parallel(0, [&](int ithr, int nthr) { - dim_t start, end; - balance211(N * jcp.oc, nthr, ithr, start, end); - - const size_t first_oc = start % jcp.oc; - const size_t last_oc = (end - 1) % jcp.oc; - const size_t first_os = start / jcp.oc; - const size_t last_os = (end - 1) / jcp.oc; - - for (size_t os = first_os; os <= last_os; ++os) { - const size_t start_oc = (os == first_os) ? first_oc : 0; - const size_t end_oc - = (os == last_os) ? last_oc : jcp.oc - 1; - - const data_t *__restrict bia_arr - = bia_base ? bia_base + g * jcp.oc : nullptr; - data_t *__restrict dst_arr = dst + os * dst_os_stride; - - if (jcp.with_bias) { - PRAGMA_OMP_SIMD() - for (size_t oc = start_oc; oc <= end_oc; oc++) { - dst_arr[oc] += bia_arr[oc]; - } - } + if (pp_kernel_) { + const size_t first_oc = g * jcp.oc; + const size_t last_oc = jcp.oc; + const size_t first_os = 0; + const size_t last_os = N; - if (jcp.with_eltwise || jcp.with_binary) { - bool fast_relu_done = false; - if (jcp.with_eltwise && jcp.post_ops.len() == 1) { - // fast branch for ReLU case - const auto &eltwise - = jcp.post_ops.entry_.back().eltwise; - - if (eltwise.alg == alg_kind::eltwise_relu) { - const auto alpha = eltwise.alpha; - const auto scale = eltwise.scale; - PRAGMA_OMP_SIMD() - for (size_t oc = start_oc; oc <= end_oc; - oc++) { - if (dst_arr[oc] < 0) - dst_arr[oc] *= alpha; - dst_arr[oc] *= scale; - } - fast_relu_done = true; - } - } - if (!fast_relu_done) { - ref_post_ops_t::args_t args; - args.ctx = &ctx; - args.dst_md = pd()->dst_md(); - - for (size_t oc = start_oc; oc <= end_oc; oc++) { - // jcp.od is not part of jcp.os, so multiply - // jcp.od to get spatial offset. 
- args.l_offset = (g * jcp.oc + oc) - * (jcp.os * jcp.od); - post_ops_->execute(dst_arr[oc], args); - } - } - } - } - }); + for (size_t os = first_os; os < last_os; ++os) { + data_t* dst_local = dst + os * dst_os_stride; + (*pp_kernel_)(dst_base, dst_local, bia_base, 1, first_oc, last_oc, 1, post_ops_binary_rhs_arg_vec); + } } } nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow); @@ -226,16 +184,37 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp( auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS); auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); + auto dst_orig = dst; + +#if DNNL_X64 + const auto post_ops_binary_rhs_arg_vec + = x64::binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); +#else + const auto post_ops_binary_rhs_arg_vec = std::vector(); +#endif auto col = ctx.get_scratchpad_grantor().get(key_conv_gemm_col); const conv_gemm_conf_t &jcp = this->pd()->jcp_; - const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + + // The second arg in template means sub_offset0 = true + // See `blk_off` method definition. + const size_t src_mb_stride = src_d.blk_off(1); + const size_t src_g_stride = src_d.blk_off(0, 1) * jcp.ic; + + const size_t dst_mb_stride = dst_d.blk_off(1); + const size_t dst_g_stride = dst_d.blk_off(0, 1) * jcp.oc; + const size_t weights_oc_size = jcp.ic * jcp.ks; const size_t weights_g_size = weights_oc_size * jcp.oc; const bool is_problem_3d = pd()->ndims() == 5; + src += src_d.off_l(0); + dst += dst_d.off_l(0); + assert(IMPLICATION(is_problem_3d, jcp.os_block == jcp.os && jcp.ic_block == jcp.ic && jcp.os_nb_block == 1)); @@ -254,7 +233,7 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp( auto inner_ker = [&](int spatial, const im_pos_t &curr, im_pos_t &prev, im_pos_t &step, const im_pos_t &end) { const data_t *_src - = src + (curr.n * jcp.ngroups + curr.g) * src_step; + = src + curr.n * src_mb_stride + curr.g * src_g_stride; step.oc = nstl::min( jcp.oc_block, nstl::min(jcp.oc, end.oc) - curr.oc); step.sp = nstl::min(jcp.os_block, @@ -275,10 +254,9 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp( const data_t one = 1.0; const dim_t M = jcp.os * jcp.od; - const size_t dst_step = jcp.oc * M; const dim_t m = step.sp; const dim_t LDA = jcp.im2col_sz ? m : M; - data_t *_dst = dst + (curr.n * jcp.ngroups + curr.g) * dst_step + data_t *_dst = dst + curr.n * dst_mb_stride + curr.g * dst_g_stride + curr.oc * M + curr.od * jcp.os + curr.sp; const dim_t K = step.ic * jcp.ks; const dim_t LDB = jcp.ic * jcp.ks; @@ -296,61 +274,8 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp( &LDA, _weights, &LDB, &beta, _dst, &M); if (st != status::success) return st; - if (curr.ic == jcp.ic - step.ic) { - // TODO: for "outer threading" we have parallel section within - // outermost "parallel". It is not good. Consider to use - // "parallel" here with number of threads passed as parameter - const int oc_start = curr.g * jcp.oc + curr.oc; - if (jcp.with_eltwise || jcp.with_binary) { - bool fast_relu_done = false; - if (jcp.with_eltwise && jcp.post_ops.len() == 1) { - // fast branch for ReLU case - const auto &eltwise - = jcp.post_ops.entry_.back().eltwise; - if (eltwise.alg == alg_kind::eltwise_relu) { - parallel_nd(step.oc, [&](dim_t oc) { - data_t b = jcp.with_bias ? 
bias[oc_start + oc] - : 0; - data_t *d_ = _dst + oc * M; - PRAGMA_OMP_SIMD() - for (int oS = 0; oS < m; ++oS) { - d_[oS] += b; - if (d_[oS] < 0) d_[oS] *= eltwise.alpha; - d_[oS] *= eltwise.scale; - } - }); - fast_relu_done = true; - } - } - if (!fast_relu_done) { - parallel_nd(step.oc, [&](dim_t oc) { - data_t b = jcp.with_bias ? bias[oc_start + oc] : 0; - data_t *d_ = _dst + oc * M; - - ref_post_ops_t::args_t args; - args.ctx = &ctx; - args.dst_md = pd()->dst_md(); - args.l_offset = d_ - dst; - - PRAGMA_OMP_SIMD() - for (int oS = 0; oS < m; ++oS) { - d_[oS] += b; - post_ops_->execute(d_[oS], args); - args.l_offset++; - } - }); - } - - } else if (jcp.with_bias) { - parallel_nd(step.oc, [&](dim_t oc) { - data_t b = bias[oc_start + oc]; - data_t *d_ = _dst + oc * M; - PRAGMA_OMP_SIMD() - for (int oS = 0; oS < m; ++oS) { - d_[oS] += b; - } - }); - } + if (pp_kernel_ && curr.ic == jcp.ic - step.ic) { + (*pp_kernel_)(dst_orig, _dst, bias, m, curr.g * jcp.oc + curr.oc, step.oc, M, post_ops_binary_rhs_arg_vec); } return status::success; @@ -422,13 +347,20 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_nspc( auto bia_base = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS); auto diff_src_base = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); +#if DNNL_X64 + const auto post_ops_binary_rhs_arg_vec + = x64::binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); +#else + const auto post_ops_binary_rhs_arg_vec = std::vector(); +#endif + auto scratchpad = ctx.get_scratchpad_grantor(); const conv_gemm_conf_t &jcp = pd()->jcp_; std::atomic st(status::success); parallel(jcp.nthr, [&](const int ithr, const int nthr) { status_t st_thr = execute_backward_data_thr_nspc(ithr, nthr, - diff_dst_base, wei_base, bia_base, diff_src_base, scratchpad); + diff_dst_base, wei_base, bia_base, diff_src_base, scratchpad, post_ops_binary_rhs_arg_vec); if (st_thr != status::success) st = st_thr; }); @@ -438,7 +370,8 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_nspc( status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc( const int ithr, const int nthr, const data_t *diff_dst_base, const data_t *wei_base, const data_t *bia_base, data_t *diff_src_base, - const memory_tracking::grantor_t &scratchpad) const { + const memory_tracking::grantor_t &scratchpad, + const std::vector& post_ops_binary_rhs_arg_vec) const { const conv_gemm_conf_t &jcp = pd()->jcp_; // Diff_dst Format: mb-spatial-groups-output_channels @@ -458,6 +391,8 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc( // threads share work across mini-batch and groups const dim_t work_amount = jcp.ngroups * jcp.mb; + const auto &p = pd()->attr()->post_ops_; + data_t *__restrict col = scratchpad.get(key_conv_gemm_col) + (ptrdiff_t)ithr * jcp.im2col_sz; const bool acc_needed = jcp.ngroups > 1; @@ -506,6 +441,31 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc( } }); } + if (p.len() > 0) { + std::size_t post_ops_data_idx = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + auto depthwise_base = reinterpret_cast(post_ops_binary_rhs_arg_vec[post_ops_data_idx]); + auto depthwise_weights = depthwise_base + post_op.depthwise.offset[post_op.depthwise.scales]; + auto depthwise_bias = post_op.depthwise.alg == alg_kind::depthwise_scale_shift + ? 
depthwise_base + post_op.depthwise.offset[post_op.depthwise.shifts] + : nullptr; + + parallel_nd(static_cast(jcp.is) * jcp.id, [&](size_t is) { + data_t *__restrict diff_src_arr + = diff_src + is * diff_src_os_stride; + for (int ic = 0; ic < jcp.ic; ic++) { + diff_src_arr[ic] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(diff_src_arr[ic], + depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic); + } + }); + post_ops_data_idx++; + depthwise_inj_idx++; + } + } + } nd_iterator_step(n, jcp.mb, g, jcp.ngroups); } return status::success; @@ -517,13 +477,28 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp( auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); +#if DNNL_X64 + const auto post_ops_binary_rhs_arg_vec + = x64::binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); +#else + const auto post_ops_binary_rhs_arg_vec = std::vector(); +#endif + auto col = ctx.get_scratchpad_grantor().get(key_conv_gemm_col); const conv_gemm_conf_t &jcp = this->pd()->jcp_; const dim_t M = jcp.os * jcp.od; - const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id; - const size_t dst_step = (size_t)jcp.oc * M; + const size_t src_step_to_clean = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id; + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + // The second arg in template means sub_offset0 = true + // See `blk_off` method definition. + const size_t src_step = diff_src_d.blk_off(1) / jcp.ngroups; + const size_t dst_step = diff_dst_d.blk_off(1) / jcp.ngroups; + diff_src += diff_src_d.off_l(0); + diff_dst += diff_dst_d.off_l(0); const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks; const dim_t m = jcp.os_block; @@ -533,6 +508,8 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp( const dim_t work_amount = (size_t)jcp.ngroups * jcp.mb; const bool is_problem_3d = pd()->ndims() == 5; + const auto &p = pd()->attr()->post_ops_; + std::atomic st(status::success); parallel(jcp.nthr, [&](const int ithr, const int nthr) { data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; @@ -547,7 +524,7 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp( if (is_problem_3d && jcp.im2col_sz > 0) { // jit_gemm_convolution_utils::col2im_3d() assumes that the // accumulator is initialized by zeroes - for (size_t i = 0; i < src_step; i++) + for (size_t i = 0; i < src_step_to_clean; i++) _diff_src[i] = (data_t)0; } @@ -580,6 +557,31 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp( } } } + if (p.len() > 0) { + std::size_t post_ops_data_idx = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + auto depthwise_base = reinterpret_cast(post_ops_binary_rhs_arg_vec[post_ops_data_idx]); + auto depthwise_weights = depthwise_base + post_op.depthwise.offset[post_op.depthwise.scales]; + auto depthwise_bias = post_op.depthwise.alg == alg_kind::depthwise_scale_shift + ? 
depthwise_base + post_op.depthwise.offset[post_op.depthwise.shifts] + : nullptr; + parallel_nd(jcp.ic, [&](const int ic) { + for (int id = 0; id < jcp.id; ++id) { + data_t *d_ = _diff_src + ic * jcp.id * jcp.is + id * jcp.is; + for (int iS = 0; iS < jcp.is; ++iS) { + d_[iS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d_[iS], + depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic); + } + } + }); + post_ops_data_idx++; + depthwise_inj_idx++; + } + } + } nd_iterator_step(g, jcp.ngroups, n, jcp.mb); } }); diff --git a/src/cpu/gemm_convolution.hpp b/src/cpu/gemm_convolution.hpp index c321266ebc4..1000d8fa16f 100644 --- a/src/cpu/gemm_convolution.hpp +++ b/src/cpu/gemm_convolution.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2023 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,15 +28,19 @@ #include "cpu/gemm_convolution_utils.hpp" #include "cpu/primitive_attr_postops.hpp" +#include "ref_depthwise_injector.hpp" + +#if DNNL_X64 +#include "cpu/x64/cpu_isa_traits.hpp" +#include "cpu/x64/injectors/jit_uni_binary_injector.hpp" +#endif namespace dnnl { namespace impl { namespace cpu { struct gemm_convolution_fwd_t : public primitive_t { struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; DECLARE_COMMON_PD_T( GEMM_IMPL_STR, gemm_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD); @@ -55,48 +59,59 @@ struct gemm_convolution_fwd_t : public primitive_t { primitive_attr_t::skip_mask_t::post_ops, f32), VERBOSE_UNSUPPORTED_ATTR); VDISPATCH_CONV(post_ops_ok(), VERBOSE_UNSUPPORTED_POSTOP); - auto scratchpad = scratchpad_registry().registrar(); + // TODO: make `init_conf` assign initialized object to `jcp_` + jcp_ = conv_gemm_conf_t(); return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_, dnnl_get_max_threads()); } - conv_gemm_conf_t jcp_; + conv_gemm_conf_t jcp_ = utils::zero(); protected: bool post_ops_ok() const { + using namespace dnnl::impl::primitive_kind; auto const &po = attr()->post_ops_; - auto is_sum_ok = [&](int idx) { - return IMPLICATION(po.entry_[idx].kind == primitive_kind::sum, - idx == 0 && po.entry_[idx].is_sum()); - }; - auto is_binary - = [&](int idx) { return po.entry_[idx].is_binary(); }; - auto is_prelu = [&](int idx) { return po.entry_[idx].is_prelu(); }; - auto is_binary_or_prelu_supported = [&](int idx) { - bool ok = dnnl::impl::get_rhs_arg_broadcasting_strategy( - binary_injector_utils::get_src1_desc( - po.entry_[idx], dst_md_), - dst_md_, - {broadcasting_strategy_t::scalar, - broadcasting_strategy_t::per_oc}) - != broadcasting_strategy_t::unsupported; - return ok; - }; - - if (!ref_post_ops_t::primitive_kind_ok(attr()->post_ops_)) - return false; - for (int idx = 0; idx < po.len(); idx++) { - bool ok = is_sum_ok(idx) - && IMPLICATION(is_binary(idx) || is_prelu(idx), - is_binary_or_prelu_supported(idx)); - if (!ok) return false; - } + auto all_post_ops_supported = [&]() { + for (int i = 0; i < po.len(); i++) { + const auto &post_op = po.entry_[i]; + if (!utils::one_of(post_op.kind, sum, binary, eltwise, + depthwise, quantization)) + return 
false; + +#if DNNL_X64 + using namespace cpu::x64; + cpu_isa_t isa = isa_undef; + if (po.entry_[i].kind == binary) { + auto dst_md = this->dst_md(); + if (mayiuse(avx512_core)) + isa = avx512_core; + else if (mayiuse(avx2)) + isa = avx2; + else if (mayiuse(sse41)) + isa = sse41; + if ((isa == isa_undef) + || !binary_injector::is_supported(isa, + binary_injector::get_src1_desc( + post_op, *dst_md), + *dst_md, default_strategies())) { + return false; + } + } +#endif + } + return true; + }; + auto contain = [&](dnnl::impl::primitive_kind_t kind) { return po.find(kind) != -1; }; + auto position = [&](dnnl::impl::primitive_kind_t kind) { return po.find(kind); }; + auto count = [&](dnnl::impl::primitive_kind_t kind) { return po.count(kind); }; - return true; + return all_post_ops_supported() && + count(primitive_kind::sum) <= 1 && + IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == 0); } }; @@ -104,18 +119,20 @@ struct gemm_convolution_fwd_t : public primitive_t { : primitive_t(apd), post_ops_(nullptr) {} status_t init(engine_t *engine) override { + const auto &post_ops = pd()->attr()->post_ops_; const data_t one = 1.0, zero = 0.0; const auto &jcp = pd()->jcp_; beta_ = jcp.with_sum ? one : zero; - if (jcp.with_eltwise || jcp.with_binary) { - CHECK(safe_ptr_assign(post_ops_, new ref_post_ops_t(jcp.post_ops))); - CHECK(post_ops_->init(pd()->dst_md())); - } - return status::success; + bool has_bias = pd()->with_bias(); + bool has_post_ops = post_ops.len() > 0; + postops_in_ip_ = has_bias || has_post_ops; + + CHECK(safe_ptr_assign(pp_kernel_, pp_kernel_t::create(pd(), pd()->jcp_))); + return (pp_kernel_) ? pp_kernel_->create_kernel() : status::success; } - typedef typename prec_traits::type data_t; + using data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { bool is_nspc = pd()->jcp_.is_nspc; @@ -128,9 +145,13 @@ struct gemm_convolution_fwd_t : public primitive_t { status_t execute_forward_thr_nspc(const exec_ctx_t &ctx, const int ithr, const int nthr, const data_t *src_base, const data_t *wei_base, const data_t *bia_base, data_t *dst_base, - const memory_tracking::grantor_t &scratchpad) const; + const memory_tracking::grantor_t &scratchpad, + const std::vector& post_ops_binary_rhs_arg_vec) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + using pp_kernel_t = gemm_convolution_utils::pp_kernel_t; + std::unique_ptr pp_kernel_; + bool postops_in_ip_; data_t beta_; std::unique_ptr post_ops_; @@ -138,9 +159,7 @@ struct gemm_convolution_fwd_t : public primitive_t { struct gemm_convolution_bwd_data_t : public primitive_t { struct pd_t : public cpu_convolution_bwd_data_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} + using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t; DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_data_t, USE_GLOBAL_SCRATCHPAD); @@ -156,21 +175,56 @@ struct gemm_convolution_bwd_data_t : public primitive_t { VERBOSE_BAD_ALGORITHM); VDISPATCH_CONV(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); VDISPATCH_CONV( - attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR); + is_supported_post_ops(), VERBOSE_UNSUPPORTED_ATTR); auto scratchpad = scratchpad_registry().registrar(); + // TODO: make `init_conf` assign initialized object to `jcp_` + jcp_ = conv_gemm_conf_t(); return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, 
*desc(), diff_src_md_, weights_md_, diff_dst_md_, bias_md_, attr_, dnnl_get_max_threads()); } - conv_gemm_conf_t jcp_; + conv_gemm_conf_t jcp_ = utils::zero(); + + protected: + virtual bool is_supported_post_ops() const { + const auto &p = this->attr()->post_ops_; + if (p.len() > 1) + return false; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); + } }; - gemm_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type data_t; + gemm_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) { + const auto &post_ops = pd()->attr()->post_ops_; + for (int i = 0; i < post_ops.len(); i++) { + auto &post_op = post_ops.entry_[i]; + if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new ref_depthwise_scalar_fwd_t(post_op.depthwise.alg)); + } + } + } + + ~gemm_convolution_bwd_data_t() { + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + } + + using data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { bool is_nspc = pd()->jcp_.is_nspc; @@ -184,17 +238,18 @@ struct gemm_convolution_bwd_data_t : public primitive_t { status_t execute_backward_data_thr_nspc(const int ithr, const int nthr, const data_t *diff_dst_base, const data_t *wei_base, const data_t *bia_base, data_t *diff_src_base, - const memory_tracking::grantor_t &scratchpad) const; + const memory_tracking::grantor_t &scratchpad, + const std::vector& post_ops_binary_rhs_arg_vec) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + + nstl::vector depthwise_injectors; }; struct gemm_convolution_bwd_weights_t : public primitive_t { struct pd_t : public cpu_convolution_bwd_weights_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) - , jcp_() {} + using cpu_convolution_bwd_weights_pd_t:: + cpu_convolution_bwd_weights_pd_t; DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_weights_t, USE_GLOBAL_SCRATCHPAD); @@ -213,17 +268,19 @@ struct gemm_convolution_bwd_weights_t : public primitive_t { attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR); auto scratchpad = scratchpad_registry().registrar(); + // TODO: make `init_conf` assign initialized object to `jcp_` + jcp_ = conv_gemm_conf_t(); return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, *desc(), src_md_, diff_weights_md_, diff_dst_md_, diff_bias_md_, attr_, dnnl_get_max_threads()); } - conv_gemm_conf_t jcp_; + conv_gemm_conf_t jcp_ = utils::zero(); }; gemm_convolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type data_t; + using data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { const bool is_nspc = pd()->jcp_.is_nspc; diff --git a/src/cpu/gemm_convolution_utils.cpp b/src/cpu/gemm_convolution_utils.cpp index 2de4ddcf39f..5060fa1fd03 100644 --- a/src/cpu/gemm_convolution_utils.cpp +++ b/src/cpu/gemm_convolution_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
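The depthwise post-op wired into the backward-data paths above reduces to a per-channel affine transform. As a rough standalone sketch (hypothetical names, no oneDNN dependencies), this is what `compute_scalar` amounts to for the `depthwise_scale_shift` algorithm, applied the same way the loops above apply it per channel:

#include <cstddef>
#include <vector>

// Stand-in for ref_depthwise_scalar_fwd_t with alg == depthwise_scale_shift:
// y = x * scale[c] + shift[c]; the shift pointer may be absent.
inline float depthwise_scale_shift(float x, const float *scale, const float *shift) {
    return x * *scale + (shift ? *shift : 0.f);
}

// Apply the post-op over one CHW image, one (scale, shift) pair per channel,
// mirroring the `depthwise_weights + g * jcp.ic + ic` indexing above.
void apply_depthwise(std::vector<float> &chw, size_t C, size_t HW,
        const float *scales, const float *shifts) {
    for (size_t c = 0; c < C; ++c)
        for (size_t s = 0; s < HW; ++s)
            chw[c * HW + s] = depthwise_scale_shift(chw[c * HW + s],
                    scales + c, shifts ? shifts + c : nullptr);
}

This also shows why the injectors only need two raw pointers per channel: the tensor layout is resolved once via `post_op.depthwise.offset[...]`, after which the kernel just indexes by channel.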
@@ -23,6 +23,10 @@ #include "common/utils.hpp" #include "cpu/gemm_convolution_utils.hpp" #include "cpu/scale_utils.hpp" + +#include "ref_eltwise.hpp" +#include "ref_depthwise_injector.hpp" + #if DNNL_X64 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" #endif @@ -30,6 +34,7 @@ #include "cpu/platform.hpp" #if DNNL_X64 +#include "cpu/x64/jit_gemm_convolution_utils.hpp" #include "cpu/x64/cpu_isa_traits.hpp" #endif @@ -51,13 +56,173 @@ single_gemm_conv_chunk_desc_t::single_gemm_conv_chunk_desc_t(dim_t d_off, , w_off_(w_off) , w_size_(w_size) {} +namespace gemm_convolution_utils { + +struct ref_pp_kernel_t : pp_kernel_t { + ref_pp_kernel_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) + : pp_kernel_t(pd, jcp) { + for (int i = 0; i < post_ops_.len(); i++) { + auto &post_op = post_ops_.entry_[i]; + if (post_op.is_eltwise()) { + ref_eltwise_injectors_.push_back(new ref_eltwise_scalar_fwd_t(post_op.eltwise)); + } else if (post_op.is_depthwise()) { + ref_depthwise_injectors_.push_back(new ref_depthwise_scalar_fwd_t( + post_op.depthwise.alg)); + } + } + } + ~ref_pp_kernel_t() { + for (auto impl : ref_eltwise_injectors_) + delete impl; + ref_eltwise_injectors_.clear(); + for (auto impl : ref_depthwise_injectors_) + delete impl; + ref_depthwise_injectors_.clear(); + } + + virtual void operator()(float *dst_orig, float *dst, const float *bias, const int len, const int oc_start, const int oc_work, const int oc_stride, + const std::vector& post_ops_binary_rhs_arg_vec) const override; + + static bool post_ops_ok(const convolution_pd_t *pd) { + using namespace dnnl::impl::primitive_kind; + const auto& po = pd->attr()->post_ops_; + for (int i = 0; i < po.len(); i++) { + if (!utils::one_of(po.entry_[i].kind, eltwise, depthwise, quantization)) { + return false; + } + } + return true; + } + +private: + nstl::vector ref_eltwise_injectors_; + nstl::vector ref_depthwise_injectors_; +}; + +void ref_pp_kernel_t::operator()(float *dst_orig, float *dst, const float *bias, const int len,const int oc_start, const int oc_work, const int oc_stride, + const std::vector& post_ops_binary_rhs_arg_vec) const { + // TODO: for "outer threading" we have parallel section within + // outermost "parallel". It is not good. Consider to use + // "parallel" here with number of threads passed as parameter + const auto &p = post_ops_; + bool need_bias = do_bias_; + if (p.len() > 0) { + std::size_t post_ops_data_idx = 0; + int eltwise_inj_idx = 0; + int depthwise_inj_idx = 0; + + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + // todo: sum? + if (post_op.is_eltwise()) { + parallel_nd(oc_work, [&](const int oc) { + float b = need_bias ? bias[oc_start + oc] : 0; + float *d_ = dst + oc * oc_stride; + for (int oS = 0; oS < len; ++oS) { + d_[oS] += b; + d_[oS] = ref_eltwise_injectors_[eltwise_inj_idx]->compute_scalar(d_[oS]); + } + }); + + eltwise_inj_idx++; + need_bias = false; + } else if (post_op.is_depthwise()) { + auto depthwise_base = reinterpret_cast(post_ops_binary_rhs_arg_vec[post_ops_data_idx]); + auto depthwise_weights = depthwise_base + post_op.depthwise.offset[post_op.depthwise.scales]; + auto depthwise_bias = depthwise_base + post_op.depthwise.offset[post_op.depthwise.shifts]; + + parallel_nd(oc_work, [&](const int oc) { + float b = need_bias ? 
bias[oc_start + oc] : 0; + float *d_ = dst + oc * oc_stride; + for (int oS = 0; oS < len; ++oS) { + d_[oS] += b; + d_[oS] = ref_depthwise_injectors_[depthwise_inj_idx]->compute_scalar(d_[oS], + depthwise_weights + oc_start + oc, + depthwise_bias + oc_start + oc); + } + }); + + post_ops_data_idx++; + depthwise_inj_idx++; + need_bias = false; + } else if (post_op.is_quantization()) { + auto quant = post_op.quantization; + auto quantization_base = reinterpret_cast(post_ops_binary_rhs_arg_vec[post_ops_data_idx]); + auto pcl = quantization_base + post_op.quantization.offset[quant.crop_low]; + auto pch = quantization_base + post_op.quantization.offset[quant.crop_high]; + auto pisc = quantization_base + post_op.quantization.offset[quant.inp_scale]; + auto pish = quantization_base + post_op.quantization.offset[quant.inp_shift]; + auto posc = quantization_base + post_op.quantization.offset[quant.output_scale]; + auto posh = quantization_base + post_op.quantization.offset[quant.output_shift]; + + parallel_nd(oc_work, [&](const int oc) { + float b = need_bias ? bias[oc_start + oc] : 0; + float *d_ = dst + oc * oc_stride; + + int cl_idx = !quant.per_channel[quant.crop_low] ? 0 : oc_start + oc; + int ch_idx = !quant.per_channel[quant.crop_high] ? 0 : oc_start + oc; + int isc_idx = !quant.per_channel[quant.inp_scale] ? 0 : oc_start + oc; + int ish_idx = !quant.per_channel[quant.inp_shift] ? 0 : oc_start + oc; + int osc_idx = !quant.per_channel[quant.output_scale] ? 0 : oc_start + oc; + int osh_idx = !quant.per_channel[quant.output_shift] ? 0 : oc_start + oc; + + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < len; ++oS) { + d_[oS] += b; + + d_[oS] = nstl::min(pch[ch_idx], nstl::max(pcl[cl_idx], d_[oS])); + d_[oS] = d_[oS] * pisc[isc_idx] + pish[ish_idx]; + d_[oS] = roundf(d_[oS]); + d_[oS] = d_[oS] * posc[osc_idx] + posh[osh_idx]; + } + }); + + post_ops_data_idx++; + need_bias = false; + } + } + } + + if (need_bias) { + parallel_nd(oc_work, [&](const int oc) { + float b = bias[oc_start + oc]; + float *d_ = dst + oc * oc_stride; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < len; ++oS) { + d_[oS] += b; + } + }); + } +} + +// Interface section + +pp_kernel_t::pp_kernel_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) + : do_bias_(pd->with_bias()), post_ops_(pd->attr()->post_ops_) {} + +pp_kernel_t *pp_kernel_t::create( + const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) { +#if DNNL_X64 + auto *res + = x64::gemm_convolution_utils::jit_pp_kernel_create(pd, jcp); + if (res) return res; +#endif + + if (ref_pp_kernel_t::post_ops_ok(pd)) { + return new ref_pp_kernel_t(pd, jcp); + } + + return nullptr; +} +} // namespace gemm_convolution_utils + namespace jit_gemm_convolution_utils { template void im2col_3d(const conv_gemm_conf_t &jcp, const data_type_t *im, data_type_t *col, dim_t od, int spatial_step, int spatial_block) { using data_t = - typename conditional::data_type == bf16, + typename conditional::data_type == bf16, uint16_t, data_type_t>::type; const data_t *__restrict _im = reinterpret_cast(im); @@ -277,13 +442,14 @@ template void transpose_dt(const conv_gemm_conf_t &jcp, /* col[kd][kh][kw][g][ic][od][oh][ow] <-- im2col_dt_3d(im[id][ih][iw][g][ic]) */ template void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr, - orig_col_dt *__restrict _col, dim_t od) { + orig_col_dt *__restrict _col, dim_t od, const uint8_t *__restrict input_zp) { // For performance reasons, use uint16_t as a proxy for bfloat16_t - using im_dt = typename utils::conditional::data_type - == bf16, - 
uint16_t, orig_im_dt>::type; + using im_dt = + typename utils::conditional::data_type + == bf16, + uint16_t, orig_im_dt>::type; using col_dt = - typename utils::conditional::data_type + typename utils::conditional::data_type == bf16, uint16_t, orig_col_dt>::type; const im_dt *__restrict imtr @@ -307,15 +473,18 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr, const dim_t IHW = jcp.ih * jcp.iw; const dim_t OHW = jcp.oh * jcp.ow; + bool with_input_zp = input_zp != nullptr; + if (sd == 1 && sh == 1 && sw == 1 && dd == 1 && dh == 1 && dw == 1) - parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic, + parallel_nd_legacy(jcp.kd, jcp.kh, jcp.kw, jcp.ic, [&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) { col_dt *__restrict col_loc = col + kd * col_kd_s + kh * col_kh_s + kw * col_kw_s + ic * col_ic_s; const dim_t id = od - fp + kd; if (id < 0 || id >= jcp.id) { + col_dt izp = with_input_zp ? (col_dt)input_zp[ic] : shift; for (ptrdiff_t i = 0; i < OHW; i++) - col_loc[i] = shift; + col_loc[i] = izp; return; } const im_dt *__restrict imtr_loc @@ -337,14 +506,15 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr, } }); else if (sd == 2 && sh == 2 && sw == 2 && dd == 1 && dh == 1 && dw == 1) - parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic, + parallel_nd_legacy(jcp.kd, jcp.kh, jcp.kw, jcp.ic, [&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) { col_dt *__restrict col_loc = col + kd * col_kd_s + kh * col_kh_s + kw * col_kw_s + ic * col_ic_s; const dim_t id = od * 2 - fp + kd; if (id < 0 || id >= jcp.id) { + col_dt izp = with_input_zp ? (col_dt)input_zp[ic] : shift; for (ptrdiff_t i = 0; i < OHW; i++) - col_loc[i] = shift; + col_loc[i] = izp; return; } const im_dt *__restrict imtr_loc @@ -368,14 +538,15 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr, } }); else - parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic, + parallel_nd_legacy(jcp.kd, jcp.kh, jcp.kw, jcp.ic, [&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) { col_dt *__restrict col_loc = col + kd * col_kd_s + kh * col_kh_s + kw * col_kw_s + ic * col_ic_s; const dim_t id = od * sd - fp + kd * dd; if (id < 0 || id >= jcp.id) { + col_dt izp = with_input_zp ? 
(col_dt)input_zp[ic] : shift; for (ptrdiff_t i = 0; i < OHW; i++) - col_loc[i] = shift; + col_loc[i] = izp; return; } const im_dt *__restrict imtr_loc @@ -402,13 +573,13 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr, } template void im2col_dt_3d(const conv_gemm_conf_t &jcp, - const void *__restrict im, uint8_t *__restrict col, dim_t od); + const void *__restrict im, uint8_t *__restrict col, dim_t od, const uint8_t *__restrict input_zp); template void im2col_dt_3d(const conv_gemm_conf_t &jcp, - const void *__restrict im, uint8_t *__restrict col, dim_t od); + const void *__restrict im, uint8_t *__restrict col, dim_t od, const uint8_t *__restrict input_zp); template void im2col_dt_3d(const conv_gemm_conf_t &jcp, - const void *__restrict im, float *__restrict col, dim_t od); + const void *__restrict im, float *__restrict col, dim_t od, const uint8_t *__restrict input_zp); template void im2col_dt_3d(const conv_gemm_conf_t &jcp, - const void *__restrict im, bfloat16_t *__restrict col, dim_t od); + const void *__restrict im, bfloat16_t *__restrict col, dim_t od, const uint8_t *__restrict input_zp); /* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */ template @@ -416,7 +587,7 @@ void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im, data_type_t *__restrict col, dim_t ss, dim_t sb, dim_t cs, dim_t cb) { using data_t = - typename utils::conditional::data_type + typename utils::conditional::data_type == bf16, uint16_t, data_type_t>::type; const data_t *__restrict _im @@ -511,7 +682,7 @@ void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im, // Generated code is more optimized for stride_w == 1 // because innermost loop is by width if (sw == 1) - parallel_nd(cb, jcp.kh, jcp.kw, oh_range, + parallel_nd_legacy(cb, jcp.kh, jcp.kw, oh_range, [&](dim_t ic, dim_t kh, dim_t kw, dim_t ohr) { const dim_t oh = ohr + oh_begin; const dim_t ih = oh * sh - tp + kh * dh; @@ -536,7 +707,7 @@ void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im, } }); else - parallel_nd(cb, jcp.kh, jcp.kw, oh_range, + parallel_nd_legacy(cb, jcp.kh, jcp.kw, oh_range, [&](dim_t ic, dim_t kh, dim_t kw, dim_t ohr) { const dim_t oh = ohr + oh_begin; const dim_t ih = oh * sh - tp + kh * dh; @@ -575,13 +746,14 @@ template void im2col(const conv_gemm_conf_t &jcp, template void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict _im, void *__restrict _imtr, orig_col_dt *__restrict _col, dim_t hs, - dim_t hb, dim_t ws, dim_t wb) { + dim_t hb, dim_t ws, dim_t wb, const uint8_t *__restrict input_zp) { // For performance reasons, use uint16_t as a proxy for bfloat16_t - using im_dt = typename utils::conditional::data_type - == bf16, - uint16_t, orig_im_dt>::type; + using im_dt = + typename utils::conditional::data_type + == bf16, + uint16_t, orig_im_dt>::type; using col_dt = - typename utils::conditional::data_type + typename utils::conditional::data_type == bf16, uint16_t, orig_col_dt>::type; const im_dt *__restrict im = reinterpret_cast(_im); @@ -598,6 +770,8 @@ void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict _im, const dim_t tp = jcp.t_pad; const dim_t lp = jcp.l_pad; + bool with_input_zp = input_zp != nullptr; + if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) { /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */ const dim_t hp = hs - tp; @@ -641,61 +815,103 @@ void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict _im, const dim_t ow_start = saturate(dim_t(0), wb, 
ow_kw); const dim_t ow_end = saturate(dim_t(0), wb, ow_kw + iwb); for (dim_t ic = 0; ic < jcp.ic; ic++) { + uint8_t izp = with_input_zp ? input_zp[ic] : (uint8_t) 0; const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str; const dim_t imtr_idx_ic = ic * imtr_ic_stride - imtr_shift; for (dim_t oh = 0; oh < oh_start; oh++) { const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; - for (dim_t ow = 0; ow < wb; ++ow) - col[col_idx_oh + ow] = shift; + if (with_input_zp) { + for (dim_t ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = izp; + } else { + for (dim_t ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } } for (dim_t oh = oh_start; oh < oh_end; oh++) { const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb; - for (dim_t ow = 0; ow < ow_start; ++ow) - col[col_idx_oh + ow] = shift; - for (dim_t ow = ow_start; ow < ow_end; ++ow) - col[col_idx_oh + ow] - = imtr[imtr_idx_oh + ow] + shift; - for (dim_t ow = ow_end; ow < wb; ++ow) - col[col_idx_oh + ow] = shift; + if (with_input_zp) { + for (dim_t ow = 0; ow < ow_start; ++ow) + col[col_idx_oh + ow] = izp; + for (dim_t ow = ow_start; ow < ow_end; ++ow) + col[col_idx_oh + ow] + = imtr[imtr_idx_oh + ow]; + for (dim_t ow = ow_end; ow < wb; ++ow) + col[col_idx_oh + ow] = izp; + } else { + for (dim_t ow = 0; ow < ow_start; ++ow) + col[col_idx_oh + ow] = shift; + for (dim_t ow = ow_start; ow < ow_end; ++ow) + col[col_idx_oh + ow] + = imtr[imtr_idx_oh + ow] + shift; + for (dim_t ow = ow_end; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } } for (dim_t oh = oh_end; oh < hb; oh++) { const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; - for (dim_t ow = 0; ow < wb; ++ow) - col[col_idx_oh + ow] = shift; + if (with_input_zp) { + for (dim_t ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = izp; + } else { + for (dim_t ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } } } } } } else { - parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb, + parallel_nd_legacy(jcp.kh, jcp.kw, jcp.ic, hb, [&](dim_t kh, dim_t kw, dim_t ic, dim_t oh) { const dim_t hp = tp - kh * dh; const dim_t ih = (oh + hs) * sh - hp; const ptrdiff_t col_idx_base = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh) * wb; + uint8_t izp = with_input_zp ? 
input_zp[ic] : (uint8_t) 0; if (ih < 0 || ih >= jcp.ih) - for (dim_t ow = 0; ow < wb; ow++) - col[col_idx_base + ow] = shift; + if (with_input_zp) { + for (dim_t ow = 0; ow < wb; ow++) + col[col_idx_base + ow] = izp; + } else { + for (dim_t ow = 0; ow < wb; ow++) + col[col_idx_base + ow] = shift; + } else { const dim_t wp = lp - kw * dw; const dim_t ow_start = saturate(dim_t(0), wb, div_up(wp, sw) - ws); const dim_t ow_end = saturate( dim_t(0), wb, div_up(jcp.iw + wp, sw) - ws); - for (dim_t ow = 0; ow < ow_start; ow++) - col[col_idx_base + ow] = shift; - const dim_t iw_base = ws * sw - wp; - const ptrdiff_t im_idx_base = ih * im_ih_stride + ic; - for (dim_t ow = ow_start; ow < ow_end; ow++) { - const dim_t iw = iw_base + ow * sw; - const ptrdiff_t im_idx - = im_idx_base + iw * im_iw_stride; - col[col_idx_base + ow] = im[im_idx] + shift; + if (with_input_zp) { + for (dim_t ow = 0; ow < ow_start; ow++) + col[col_idx_base + ow] = izp; + const dim_t iw_base = ws * sw - wp; + const ptrdiff_t im_idx_base = ih * im_ih_stride + ic; + for (dim_t ow = ow_start; ow < ow_end; ow++) { + const dim_t iw = iw_base + ow * sw; + const ptrdiff_t im_idx + = im_idx_base + iw * im_iw_stride; + col[col_idx_base + ow] = im[im_idx]; + } + for (dim_t ow = ow_end; ow < wb; ow++) + col[col_idx_base + ow] = izp; + } else { + for (dim_t ow = 0; ow < ow_start; ow++) + col[col_idx_base + ow] = shift; + const dim_t iw_base = ws * sw - wp; + const ptrdiff_t im_idx_base = ih * im_ih_stride + ic; + for (dim_t ow = ow_start; ow < ow_end; ow++) { + const dim_t iw = iw_base + ow * sw; + const ptrdiff_t im_idx + = im_idx_base + iw * im_iw_stride; + col[col_idx_base + ow] = im[im_idx] + shift; + } + for (dim_t ow = ow_end; ow < wb; ow++) + col[col_idx_base + ow] = shift; } - for (dim_t ow = ow_end; ow < wb; ow++) - col[col_idx_base + ow] = shift; } }); } @@ -703,26 +919,25 @@ void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict _im, template void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict im, void *__restrict imtr, - uint8_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb); + uint8_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb, const uint8_t *__restrict input_zp); template void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict im, void *__restrict imtr, - uint8_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb); + uint8_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb, const uint8_t *__restrict input_zp); template void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict im, void *__restrict imtr, float *__restrict col, - dim_t hs, dim_t hb, dim_t ws, dim_t wb); + dim_t hs, dim_t hb, dim_t ws, dim_t wb, const uint8_t *__restrict input_zp); template void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict im, void *__restrict imtr, - bfloat16_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb); + bfloat16_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb, const uint8_t *__restrict input_zp); /* im[id][ih][iw][ic] <-- col2im_dt_3d(col[od][oh][ow][kd][kh][kw][ic]) */ template void col2im_dt(const conv_gemm_conf_t &jcp, const orig_T *__restrict _col, orig_T *__restrict _im) { // For performance reasons, use uint16_t as a proxy for bfloat16_t - using T = - typename utils::conditional::data_type == bf16, - uint16_t, orig_T>::type; + using T = typename utils::conditional< + data_traits_t::data_type == bf16, uint16_t, orig_T>::type; const T *__restrict col = reinterpret_cast(_col); T *__restrict im = reinterpret_cast(_im); 
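All of the `with_input_zp` branches added to `im2col_dt` above change the same two things: padded (out-of-image) positions are filled with the per-channel input zero point instead of the generic `shift`, and in-image values are copied without the `+ shift` re-encoding. A simplified 1D sketch of that rule (hypothetical helper; unit stride and dilation assumed):

#include <cstddef>
#include <cstdint>

// One im2col row for a single (channel, kernel tap) pair.
void im2col_row_1d(const uint8_t *im, uint8_t *col, ptrdiff_t iw, ptrdiff_t ow,
        ptrdiff_t pad, ptrdiff_t kw_off,
        const uint8_t *input_zp, // per-channel zero point, may be nullptr
        uint8_t shift) {         // e.g. 128 when re-encoding s8 as u8
    const uint8_t fill = input_zp ? *input_zp : shift;
    for (ptrdiff_t o = 0; o < ow; ++o) {
        const ptrdiff_t i = o - pad + kw_off;
        if (i < 0 || i >= iw)
            col[o] = fill; // padding carries the zero point, not a constant
        else
            col[o] = input_zp ? im[i] : uint8_t(im[i] + shift);
    }
}

Duplicating each loop per branch, as the patch does, keeps the `with_input_zp` check out of the innermost loop at the cost of some repetition.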
@@ -1080,16 +1295,16 @@ status_t init_conf(conv_gemm_conf_t &jcp, CHECK(memory_desc_init_by_tag(src_md, desired_src_tag)); src_tag = desired_src_tag; } else { - src_tag = memory_desc_matches_one_of_tag( - src_md, nwc, nhwc, ndhwc, ncw, nchw, ncdhw); + src_tag = src_d.mb_stride_relaxed_match( + nwc, nhwc, ndhwc, ncw, nchw, ncdhw); } if (dst_d.format_kind() == format_kind::any) { CHECK(memory_desc_init_by_tag(dst_md, desired_dst_tag)); dst_tag = desired_dst_tag; } else { - dst_tag = memory_desc_matches_one_of_tag( - dst_md, nwc, nhwc, ndhwc, ncw, nchw, ncdhw); + dst_tag = dst_d.mb_stride_relaxed_match( + nwc, nhwc, ndhwc, ncw, nchw, ncdhw); } if (src_tag == format_tag::undef || dst_tag == format_tag::undef) @@ -1134,6 +1349,29 @@ const bool is_bwd_w = jcp.prop_kind == backward_weights; const bool is_fwd = !is_bwd_d && !is_bwd_w; + const auto src_max_size + = static_cast<size_t>(jcp.iw) * jcp.ih * jcp.id * jcp.ic * 4; + const auto dst_max_size + = static_cast<size_t>(jcp.ow) * jcp.oh * jcp.od * jcp.oc * 4; + VDISPATCH_CONV_IC(src_max_size <= INT_MAX && dst_max_size <= INT_MAX, + VERBOSE_UNSUPPORTED_FEATURE, + "src/dst size > INT_MAX is not supported"); + + jcp.with_input_zp = !attr.input_zero_points_.has_default_values(); + if (jcp.with_input_zp) { + if (attr.input_zero_points_.count_ != 1 && attr.input_zero_points_.count_ != jcp.ic * jcp.ngroups) + return status::unimplemented; + + if (attr.output_compensations_.count_ != jcp.oc * jcp.ngroups) + return status::unimplemented; + } + + jcp.with_weights_zp = !attr.weights_zero_points_.has_default_values(); + if (jcp.with_weights_zp) { + if (attr.weights_zero_points_.count_ != 1 && attr.weights_zero_points_.count_ != jcp.oc * jcp.ngroups) + return status::unimplemented; + } + bool is_int8_conv = (is_fwd ? utils::one_of(src_d.data_type(), s8, u8) : utils::one_of(dst_d.data_type(), s8, u8)) && weights_d.data_type() == s8; @@ -1165,7 +1403,7 @@ VDISPATCH_CONV_IC( post_ops_ok(post_ops_ok_args_t(x64::avx512_core, - {binary, eltwise, sum}, attr.post_ops_, &dst_d, + {binary, eltwise, sum, depthwise, prelu}, attr.post_ops_, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)), VERBOSE_UNSUPPORTED_POSTOP); @@ -1181,6 +1419,8 @@ jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); const int sum_ind = jcp.post_ops.find(primitive_kind::sum); jcp.with_sum = sum_ind != -1; + const int depthwise_ind = jcp.post_ops.find(primitive_kind::depthwise); + jcp.with_depthwise = depthwise_ind != -1; bool is_bf16_conv = false || (is_fwd @@ -2125,8 +2365,8 @@ status_t init_conf(conv_gemm_conf_t &jcp, jcp.dst_os_stride = dst_d.is_blocking_desc() ? 
dst_d.blocking_desc().strides[ndims - 1] : 0; - jcp.scale_idx_mult = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0; - jcp.with_dst_scale = !attr.scales_.get(DNNL_ARG_DST).has_default_values(); + jcp.scale_idx_mult = attr.scales_.get_mask(DNNL_ARG_WEIGHTS) > 0; + jcp.with_dst_scale = !attr.scales_.has_default_values(DNNL_ARG_DST); book_precomputed_scales(scratchpad, attr.scales_, jcp.ngroups * jcp.oc); if (jcp.zp.src_exists) { @@ -2134,8 +2374,8 @@ status_t init_conf(conv_gemm_conf_t &jcp, if (size) scratchpad.book(key_conv_gemm_zp_src_comp, size); } - VDISPATCH_CONV_IC( - scratchpad.size() <= scratchpad_limit, VERBOSE_SCRATCHPAD_LIMIT); + // VDISPATCH_CONV_IC( + // scratchpad.size() <= scratchpad_limit, VERBOSE_SCRATCHPAD_LIMIT); return status::success; } diff --git a/src/cpu/gemm_convolution_utils.hpp b/src/cpu/gemm_convolution_utils.hpp index 43e9784bc44..222b6d5f71b 100644 --- a/src/cpu/gemm_convolution_utils.hpp +++ b/src/cpu/gemm_convolution_utils.hpp @@ -43,6 +43,7 @@ struct conv_gemm_conf_t { bool with_bias; bool with_eltwise; bool with_binary; + bool with_depthwise; bool with_sum; post_ops_t post_ops; bool is_nspc; @@ -69,6 +70,9 @@ struct conv_gemm_conf_t { size_t dst_os_stride; size_t scale_idx_mult; bool with_dst_scale; + + bool with_input_zp; + bool with_weights_zp; }; struct single_gemm_conv_chunk_desc_t { @@ -84,6 +88,28 @@ struct single_gemm_conv_chunk_desc_t { dim_t w_size_ = 0; }; +namespace gemm_convolution_utils { + +struct pp_kernel_t { + static pp_kernel_t *create( + const convolution_pd_t *pd, const conv_gemm_conf_t &jcp); + + virtual ~pp_kernel_t() = default; + + virtual void operator()(float *dst_orig, float *dst, const float *bias, const int len, const int oc_start, const int oc_work, const int oc_stride, + const std::vector& post_ops_binary_rhs_arg_vec) const = 0; + + virtual status_t create_kernel() { return status::success; } + +protected: + pp_kernel_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp); + + bool do_bias_ = false; + post_ops_t post_ops_; +}; + +} // namespace gemm_convolution_utils + namespace jit_gemm_convolution_utils { template void im2col_3d(const conv_gemm_conf_t &jcp, const data_type_t *im, @@ -95,7 +121,7 @@ void transpose_dt(const conv_gemm_conf_t &jcp, const T *__restrict im, template void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict im, - col_dt *__restrict col, dim_t od); + col_dt *__restrict col, dim_t od, const uint8_t *__restrict input_zp = nullptr); template void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im, @@ -104,7 +130,7 @@ void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im, template void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict im, void *__restrict imtr, col_dt *__restrict col, dim_t hs, dim_t hb, - dim_t ws, dim_t wb); + dim_t ws, dim_t wb, const uint8_t *__restrict input_zp = nullptr); template void col2im_dt( diff --git a/src/cpu/gemm_inner_product.hpp b/src/cpu/gemm_inner_product.hpp index 1b7df0d241e..ce32c913024 100644 --- a/src/cpu/gemm_inner_product.hpp +++ b/src/cpu/gemm_inner_product.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
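The new `gemm_convolution_utils::pp_kernel_t` interface above follows the factory pattern used elsewhere in the library: `create()` first asks the x64 backend for a JIT kernel and falls back to the scalar reference implementation only when that fails. A schematic of the dispatch, with made-up stub names standing in for the real kernels:

#include <memory>

struct pp_kernel_iface {
    virtual ~pp_kernel_iface() = default;
    virtual void run(float *dst, int len) const = 0;
};

struct jit_pp_kernel_stub : pp_kernel_iface {
    // Stand-in for x64::gemm_convolution_utils::jit_pp_kernel_create():
    // a real build returns nullptr when the post-op chain is not JIT-able.
    static pp_kernel_iface *try_create(bool jit_ok) {
        return jit_ok ? new jit_pp_kernel_stub() : nullptr;
    }
    void run(float *, int) const override { /* generated code would run here */ }
};

struct ref_pp_kernel_stub : pp_kernel_iface {
    void run(float *dst, int len) const override {
        for (int i = 0; i < len; ++i)
            dst[i] = dst[i] < 0.f ? 0.f : dst[i]; // e.g. a scalar ReLU post-op
    }
};

std::unique_ptr<pp_kernel_iface> create_pp_kernel(bool jit_ok, bool ref_ok) {
    if (auto *k = jit_pp_kernel_stub::try_create(jit_ok))
        return std::unique_ptr<pp_kernel_iface>(k);
    if (ref_ok)
        return std::unique_ptr<pp_kernel_iface>(new ref_pp_kernel_stub());
    return nullptr; // caller reports status::unimplemented
}

In the patch itself, `pp_kernel_t::create` returns a raw pointer that the primitive immediately wraps via `safe_ptr_assign`, so ownership still ends up scoped in a `std::unique_ptr`.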
@@ -120,7 +120,7 @@ struct gemm_inner_product_fwd_t : public primitive_t { return pp_kernel_->create_kernel(); } - typedef typename prec_traits::type data_t; + using data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_forward(ctx); @@ -163,7 +163,7 @@ struct gemm_inner_product_bwd_data_t : public primitive_t { }; gemm_inner_product_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type data_t; + using data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_backward_data(ctx); @@ -208,7 +208,7 @@ struct gemm_inner_product_bwd_weights_t : public primitive_t { }; gemm_inner_product_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type data_t; + using data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_backward_weights(ctx); diff --git a/src/cpu/gemm_inner_product_utils.cpp b/src/cpu/gemm_inner_product_utils.cpp index 815e953898b..2d637d543cf 100644 --- a/src/cpu/gemm_inner_product_utils.cpp +++ b/src/cpu/gemm_inner_product_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -157,17 +157,17 @@ pp_kernel_t::pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride, , bias_data_type_(bias_dt) , acc_data_type_(acc_dt) , dst_data_type_(dst_md->data_type) - , do_scale_(!attr->scales_.get(DNNL_ARG_SRC).has_default_values() - || !attr->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()) + , do_scale_(!attr->scales_.has_default_values(DNNL_ARG_SRC) + || !attr->scales_.has_default_values(DNNL_ARG_WEIGHTS)) , ndims_(dst_md->ndims) { - if (do_scale_) { - int wei_mask = attr->scales_.get(DNNL_ARG_WEIGHTS).mask_; + if (!attr->scales_.has_default_values(DNNL_ARG_WEIGHTS)) { + int wei_mask = attr->scales_.get_mask(DNNL_ARG_WEIGHTS); // matmul: per_oc: 1 << (ndims_ - 1) // ip: per_oc: 1 << 0 scale_idx_mult_ = wei_mask == (1 << (ndims_ - 1)) || wei_mask == 1 << 0; } - do_dst_scale_ = !attr->scales_.get(DNNL_ARG_DST).has_default_values(); + do_dst_scale_ = !attr->scales_.has_default_values(DNNL_ARG_DST); post_ops_ = attr->post_ops_; const int eltwise_ind = post_ops_.find(primitive_kind::eltwise); diff --git a/src/cpu/gemm_x8s8s32x_convolution.cpp b/src/cpu/gemm_x8s8s32x_convolution.cpp index 8482ae65eb0..8f464e14eae 100644 --- a/src/cpu/gemm_x8s8s32x_convolution.cpp +++ b/src/cpu/gemm_x8s8s32x_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2024 Intel Corporation +* Copyright 2017-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
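Several hunks above replace `scales_.get(ARG).mask_` with `scales_.get_mask(ARG)`; in both APIs, mask 0 means a single scale for the whole tensor, while a set bit on the output-channel dimension means one scale per channel. The `scale_idx_mult_` member then collapses the two cases into an index multiplier so the hot loop carries no branch. A minimal sketch of the trick:

#include <cstdint>

// scale_idx_mult == 0: scales[] has one common entry, every oc reads index 0.
// scale_idx_mult == 1: scales[] has one entry per output channel.
inline float dequantize(int32_t acc, const float *scales, int oc,
        int scale_idx_mult) {
    return static_cast<float>(acc) * scales[oc * scale_idx_mult];
}

The switch from `!= 0` to `> 0` is presumably defensive against an "unset" sentinel from `get_mask`; any real per-channel mask is a positive value, so behavior is unchanged for the supported cases.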
@@ -121,6 +121,9 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward( = binary_injector_utils::prepare_binary_args( this->pd()->attr()->post_ops_, ctx); + DEFINE_INPUT_ZERO_POINTS_BUFFER(input_zp_base, jcp); + DEFINE_OUTPUT_COMPENSATION_BUFFER(output_compensation_base, jcp); + auto scratchpad = ctx.get_scratchpad_grantor(); assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); @@ -135,15 +138,15 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, - pd()->IC(), pd()->OC(), false, wei_scale_mask != 0, pd()->attr()); + pd()->IC(), pd()->OC(), false, wei_scale_mask > 0, pd()->attr()); parallel(jcp.nthr, [&](const int ithr, const int nthr) { status_t st_thr = execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base, dst_base, scales, dst_scales, zp, scratchpad, - post_ops_binary_rhs_arg_vec.data(), ctx); + post_ops_binary_rhs_arg_vec.data(), ctx, + input_zp_base, output_compensation_base); if (st_thr != status::success) st = st_thr; }); @@ -163,7 +166,8 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, const char *bia_base, void *dst_base, const float *scales, const float *dst_scales, const zero_point_call_params_t &zp, const memory_tracking::grantor_t &scratchpad, - const void *post_ops_binary_rhs_arg_vec, const exec_ctx_t &ctx) const { + const void *post_ops_binary_rhs_arg_vec, const exec_ctx_t &ctx, + const uint8_t *input_zp_base, const int32_t *output_compensation_base) const { const conv_gemm_conf_t &jcp = this->pd()->jcp_; @@ -190,18 +194,11 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc; const int32_t *_wei_comp - = jcp.signed_input ? get_wei_comp(wei_base, wei_md) : nullptr; - - const bool should_apply_zp_src_comp_pad = jcp.zp.src_exists - && jit_gemm_convolution_utils::padding_exists(jcp); - const bool should_apply_zp_src_comp_pad_jit_pp - = should_apply_zp_src_comp_pad - && gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel( - dst_md.data_type()); - const bool should_apply_zp_src_comp_outside_pp - = should_apply_zp_src_comp_pad - && !gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel( - dst_md.data_type()); + = jcp.signed_input ? get_wei_comp(wei_base, wei_md) : + jcp.with_input_zp ? output_compensation_base : nullptr; + + const bool should_apply_zp_src_comp_pad_jit_pp = false; + const bool should_apply_zp_src_comp_outside_pp = false; dim_t g {0}, n {0}, ohb {0}, owb {0}; dim_t start = 0, end = 0; @@ -217,7 +214,7 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, balance211(work_amount, nthr, ithr, start, end); nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow); const uint8_t shift = jcp.signed_input ? 
128 : 0; - parallel_nd(jcp.im2col_sz, [&](ptrdiff_t i) { col[i] = shift; }); + parallel_nd_legacy(jcp.im2col_sz, [&](ptrdiff_t i) { col[i] = shift; }); status_t st = status::success; @@ -237,6 +234,11 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, for (int od = 0; od < jcp.od; od++) { const auto dst_off = n * dst_mb_stride + g * dst_g_stride + ((od * jcp.oh + oh) * jcp.ow + ow) * jcp.dst_os_stride; + + const uint8_t *__restrict input_zp = nullptr; + if (jcp.with_input_zp) + input_zp = input_zp_base + g * jcp.ic; + char *__restrict dst = (char *)dst_base + types::data_type_size(dst_md.data_type()) * dst_off; if (jcp.im2col_sz) { @@ -244,20 +246,20 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, case data_type::s8: { if (is_problem_3d) jit_gemm_convolution_utils::im2col_dt_3d(jcp, imtr, col, od); + uint8_t>(jcp, imtr, col, od, input_zp); else jit_gemm_convolution_utils::im2col_dt(jcp, src, imtr, col, oh, h_step, - ow, w_step); + ow, w_step, input_zp); } break; case data_type::u8: { if (is_problem_3d) jit_gemm_convolution_utils::im2col_dt_3d(jcp, imtr, col, od); + uint8_t>(jcp, imtr, col, od, input_zp); else jit_gemm_convolution_utils::im2col_dt(jcp, src, imtr, col, oh, h_step, - ow, w_step); + ow, w_step, input_zp); } break; default: assert(!"unsupported data type"); break; } @@ -275,10 +277,10 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr, const float onef = 1.f, zerof = 0.f; const char *__restrict src_od = src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic; - st = gemm_s8u8s32("N", BT, jcp.signed_input ? "C" : "F", &M, &N, &K, + st = gemm_s8u8s32("N", BT, (jcp.signed_input || jcp.with_input_zp) ? "C" : "F", &M, &N, &K, &onef, wei, &LDA, &off_a, jcp.im2col_sz ? col : (uint8_t *)src_od, &LDB, &off_b, - &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c); + &zerof, acc, &M, (jcp.signed_input || jcp.with_input_zp) ? wei_comp : &off_c); if (st != status::success) return st; @@ -358,16 +360,15 @@ status_t gemm_x8s8s32x_convolution_bwd_data_t::execute_backward_data_thr( const auto diff_src_dt_size = types::data_type_size(diff_src_md.data_type()); - const int scale_idx_mult = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ + const int scale_idx_mult = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == (1 << static_cast(pd()->with_groups())); DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, - pd()->IC(), pd()->OC(), false, wei_scale_mask != 0, pd()->attr()); + pd()->IC(), pd()->OC(), false, wei_scale_mask > 0, pd()->attr()); const dim_t work_amount = jcp.ngroups * jcp.mb; diff --git a/src/cpu/gemm_x8s8s32x_convolution.hpp b/src/cpu/gemm_x8s8s32x_convolution.hpp index cb5cccd11b8..866d5f8927f 100644 --- a/src/cpu/gemm_x8s8s32x_convolution.hpp +++ b/src/cpu/gemm_x8s8s32x_convolution.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2024 Intel Corporation +* Copyright 2017-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
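The `"F"` to `"C"` offset-mode change in the `gemm_s8u8s32` call above is the crux of both the signed-input and the new input-zero-point paths: the u8 image handed to the GEMM is biased (by +128, or by the input zero point), so a precomputed per-output-channel term must be folded back into the accumulator, which the column-offset mode does inside the GEMM. The compensation itself is just a weight reduction; a standalone sketch of how such a term could be built (not the library's actual kernel):

#include <cstdint>
#include <vector>

// An accumulator computed on (x + zp) overshoots the true value by
// zp * sum_k w[oc][k], so precompute -zp * sum_k w[oc][k] per channel
// and let the GEMM add it via its offset argument.
std::vector<int32_t> make_compensation(const int8_t *w, int OC, int K,
        int32_t zp) {
    std::vector<int32_t> comp(OC, 0);
    for (int oc = 0; oc < OC; ++oc) {
        int32_t s = 0;
        for (int k = 0; k < K; ++k)
            s += w[oc * K + k];
        comp[oc] = -zp * s;
    }
    return comp;
}

For `jcp.signed_input` the library stores this term alongside the weights (`get_wei_comp`), while the new `jcp.with_input_zp` path receives it ready-made as `output_compensation_base`.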
@@ -39,9 +39,7 @@ namespace cpu { struct gemm_x8s8s32x_convolution_fwd_t : public primitive_t { struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; DECLARE_COMMON_PD_T(src_md()->data_type == data_type::u8 ? IGEMM_S8U8S32_IMPL_STR @@ -71,34 +69,56 @@ struct gemm_x8s8s32x_convolution_fwd_t : public primitive_t { VDISPATCH_CONV(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); - VDISPATCH_CONV( - attr()->has_default_values(skip_mask_t::scales_runtime - | skip_mask_t::zero_points_runtime - | skip_mask_t::post_ops - | skip_mask_t::sum_dt, - dst_type), + VDISPATCH_CONV(attr()->has_default_values(skip_mask_t::scales + | skip_mask_t::zero_points + | skip_mask_t::post_ops + | skip_mask_t::sum_dt + | primitive_attr_t::skip_mask_t::input_zero_points + | primitive_attr_t::skip_mask_t::output_compensations, + dst_type), VERBOSE_UNSUPPORTED_ATTR); - VDISPATCH_CONV(attr()->post_ops_.check_sum_consistency(dst_type, - /* is_int8 */ true), + // VDISPATCH_CONV(attr()->post_ops_.check_sum_consistency(dst_type, + // /* is_int8 */ true), + // VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_CONV(post_ops_ok(), VERBOSE_UNSUPPORTED_POSTOP); VDISPATCH_CONV(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); VDISPATCH_CONV(zero_points_valid(attr()), VERBOSE_UNSUPPORTED_ATTR); auto scratchpad = scratchpad_registry().registrar(); + // TODO: make `init_conf` assign initialized object to `jcp_` + jcp_ = conv_gemm_conf_t(); CHECK(jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_, dnnl_get_max_threads())); - VDISPATCH_CONV(gemm_x8s8s32x_convolution_utils::post_ops_ok( - attr()->post_ops_, &dst_md_), - VERBOSE_UNSUPPORTED_POSTOP); + // VDISPATCH_CONV(gemm_x8s8s32x_convolution_utils::post_ops_ok( + // attr()->post_ops_, &dst_md_), + // VERBOSE_UNSUPPORTED_POSTOP); return status::success; } - conv_gemm_conf_t jcp_; + conv_gemm_conf_t jcp_ = utils::zero(); + + protected: + bool post_ops_ok() const { + using namespace dnnl::impl::primitive_kind; + auto const &po = attr()->post_ops_; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < po.len(); i++) { + ok = ok && utils::one_of(po.entry_[i].kind, sum, binary, eltwise, depthwise, quantization); + } + return ok; + }; + + return all_post_ops_supported(); + } }; gemm_x8s8s32x_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} @@ -121,7 +141,8 @@ struct gemm_x8s8s32x_convolution_fwd_t : public primitive_t { const zero_point_call_params_t &zp, const memory_tracking::grantor_t &scratchpad, const void *post_ops_binary_rhs_arg_vec, - const exec_ctx_t &ctx) const; + const exec_ctx_t &ctx, + const uint8_t *input_zp_base, const int32_t *output_compensation_base) const; using pp_ker_t = gemm_x8s8s32x_convolution_utils::pp_ker_t; std::unique_ptr pp_ker_; @@ -129,9 +150,7 @@ struct gemm_x8s8s32x_convolution_fwd_t : public primitive_t { struct gemm_x8s8s32x_convolution_bwd_data_t : public primitive_t { struct pd_t : public cpu_convolution_bwd_data_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} + using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t; DECLARE_COMMON_PD_T(diff_dst_md()->data_type == data_type::u8 ? 
IGEMM_S8U8S32_IMPL_STR @@ -158,13 +177,15 @@ struct gemm_x8s8s32x_convolution_bwd_data_t : public primitive_t { VERBOSE_BAD_ALGORITHM); VDISPATCH_CONV(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); - VDISPATCH_CONV( - attr()->has_default_values( - primitive_attr_t::skip_mask_t::scales_runtime), + VDISPATCH_CONV(attr()->has_default_values( + primitive_attr_t::skip_mask_t::scales), VERBOSE_UNSUPPORTED_ATTR); VDISPATCH_CONV(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); auto scratchpad = scratchpad_registry().registrar(); + + // TODO: make `init_conf` assign initialized object to `jcp_` + jcp_ = conv_gemm_conf_t(); return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, *desc(), diff_src_md_, weights_md_, diff_dst_md_, bias_md_, attr_, dnnl_get_max_threads()); @@ -172,7 +193,7 @@ struct gemm_x8s8s32x_convolution_bwd_data_t : public primitive_t { bool support_bias() const override { return true; } - conv_gemm_conf_t jcp_; + conv_gemm_conf_t jcp_ = utils::zero(); }; gemm_x8s8s32x_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} diff --git a/src/cpu/gemm_x8s8s32x_convolution_utils.cpp b/src/cpu/gemm_x8s8s32x_convolution_utils.cpp index 4d01a014b52..0df5540fe17 100644 --- a/src/cpu/gemm_x8s8s32x_convolution_utils.cpp +++ b/src/cpu/gemm_x8s8s32x_convolution_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2023 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,11 +39,29 @@ namespace gemm_x8s8s32x_convolution_utils { template struct ref_pp_ker_t : pp_ker_t { ref_pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) - : pp_ker_t(pd, jcp), dst_md_(pd->dst_md()) {} + : pp_ker_t(pd, jcp) { + for (int i = 0; i < post_ops_.len(); i++) { + auto &post_op = post_ops_.entry_[i]; + if (post_op.is_eltwise()) { + ref_eltwise_injectors_.push_back(new ref_eltwise_scalar_fwd_t(post_op.eltwise)); + } else if (post_op.is_depthwise()) { + ref_depthwise_injectors_.push_back(new ref_depthwise_scalar_fwd_t( + post_op.depthwise.alg)); + } + } + } + ~ref_pp_ker_t() { + for (auto impl : ref_eltwise_injectors_) + delete impl; + ref_eltwise_injectors_.clear(); + for (auto impl : ref_depthwise_injectors_) + delete impl; + ref_depthwise_injectors_.clear(); + } using acc_data_t = pp_ker_t::acc_data_t; - void operator()(void *dst, const acc_data_t *acc, const char *bias, + void operator()(void *dst, acc_data_t *acc, const char *bias, const float *scales, float dst_scale, float sum_scale, float signed_scale, int g, size_t start, size_t end, const zero_point_call_params_t &zp, @@ -51,88 +69,182 @@ struct ref_pp_ker_t : pp_ker_t { const exec_ctx_t &ctx, const memory_desc_t &dst_md, const single_gemm_conv_chunk_desc_t &chunk_desc) const override; - status_t create_kernel() override { - if (this->jcp_.with_eltwise || this->jcp_.with_binary) { - ref_post_ops_ - = utils::make_unique(this->jcp_.post_ops); - if (!ref_post_ops_) return status::out_of_memory; - return ref_post_ops_->init(dst_md_); - } - return status::success; - } - private: - std::unique_ptr ref_post_ops_; - const memory_desc_t *dst_md_; + nstl::vector ref_eltwise_injectors_; + nstl::vector ref_depthwise_injectors_; }; template -void ref_pp_ker_t::operator()(void *void_dst, const acc_data_t *acc, +void ref_pp_ker_t::operator()(void *void_dst, acc_data_t *acc, const char *bias, const float *scales, float 
dst_scale, float sum_scale, float signed_scale, int g, size_t start, size_t end, const zero_point_call_params_t &zp, - const void * /* post_ops_binary_rhs_arg_vec */, + const void * post_ops_binary_rhs_arg_vec, const void * /* dst_orig */, const exec_ctx_t &ctx, const memory_desc_t &dst_md, const single_gemm_conv_chunk_desc_t &chunk_desc) const { if (end <= start) return; - assert(data_traits::data_type == jcp_.dst_data_type); + assert(data_traits_t::data_type == dst_data_type_); + dst_data_t *dst = (dst_data_t *)void_dst; - const lldiv_t dv_start = std::div((long long)start, (long long)jcp_.oc); - const lldiv_t dv_end = std::div((long long)(end - 1), (long long)jcp_.oc); - const size_t first_oc = dv_start.rem; - const size_t last_oc = dv_end.rem; - const size_t first_os = dv_start.quot; - const size_t last_os = dv_end.quot; + const size_t first_oc = start % OC_; + const size_t last_oc = (end - 1) % OC_; + const size_t first_os = start / OC_; + const size_t last_os = (end - 1) / OC_; const int32_t zp_dst_val = jcp_.zp.dst_exists ? *(zp.dst) : 0; - ref_post_ops_t::args_t args; - args.ctx = &ctx; - args.dst_md = &dst_md; + if (post_ops_.len() == 0) { + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * dst_os_stride_ + oc; + + float d = (float) (acc[acc_off]); + if (jcp_.signed_input) d *= signed_scale; - for (size_t os = first_os; os <= last_os; os++) { - const size_t start_oc = (os == first_os) ? first_oc : 0; - const size_t end_oc = (os == last_os) ? last_oc : jcp_.oc - 1; - for (size_t oc = start_oc; oc <= end_oc; oc++) { - const size_t acc_off = os * jcp_.oc + oc; - const size_t dst_off = os * jcp_.dst_os_stride + oc; + if (do_bias_) + d += math::get_bias(bias, g * jcp_.oc + oc, bias_data_type_); - int32_t data_s32 = acc[acc_off]; + d *= scales[(g * jcp_.oc + oc) * jcp_.scale_idx_mult]; - if (jcp_.zp.src_exists) { - const auto oc_offset = g * jcp_.oc + oc; - data_s32 += zp.src_comp[oc_offset]; + // quantize data + if (jcp_.with_dst_scale) d *= dst_scale; + if (jcp_.zp.dst_exists) d += zp_dst_val; + + dst[dst_off] = dnnl::impl::cpu::q10n::qz_a1b0_t()(d); } + } + } else { + float* acc_fp = reinterpret_cast(acc); - float data = static_cast(data_s32); + auto load = [&](int idx, size_t oc, size_t os, size_t acc_off, size_t dst_off) { + float d; + if (idx == 0) { + d = (float) (acc[acc_off]); - if (jcp_.signed_input) data *= signed_scale; + if (jcp_.signed_input) + d *= signed_scale; - // dequantize data - data *= scales[(g * jcp_.oc + oc) * jcp_.scale_idx_mult]; + if (do_bias_) + d += math::get_bias(bias, g * jcp_.oc + oc, + bias_data_type_); - if (jcp_.with_bias) { - const float b = io::load_float_value( - jcp_.bias_data_type, bias, g * jcp_.oc + oc); - data += b; + d *= scales[(g * jcp_.oc + oc) * jcp_.scale_idx_mult]; + } else { + d = acc_fp[acc_off]; } - if (jcp_.with_sum) - data += sum_scale - * io::load_float_value( - jcp_.sum_data_type, void_dst, dst_off); - if (jcp_.with_eltwise || jcp_.with_binary) { - args.l_offset = (g * jcp_.oc + oc) * jcp_.os; - ref_post_ops_->execute(data, args); + return d; + }; + + auto store = [&](int idx, float d, size_t acc_off, size_t dst_off) { + if (idx == post_ops_.len() - 1) + dst[dst_off] = dnnl::impl::cpu::q10n::qz_a1b0_t()(d); + else + acc_fp[acc_off] = d; + }; + + auto post_ops_data_ptrs = 
reinterpret_cast(post_ops_binary_rhs_arg_vec); + std::size_t post_ops_data_idx = 0; + int eltwise_inj_idx = 0; + int depthwise_inj_idx = 0; + for (int i = 0; i < post_ops_.len(); i++) { + auto &post_op = post_ops_.entry_[i]; + if (post_op.is_eltwise()) { + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * this->dst_os_stride_ + oc; + + float d = load(i, oc, os, acc_off, dst_off); + + d = ref_eltwise_injectors_[eltwise_inj_idx]->compute_scalar(d); + + store(i, d, acc_off, dst_off); + } + } + eltwise_inj_idx++; + } else if (post_op.is_depthwise()) { + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * this->dst_os_stride_ + oc; + + auto depthwise_base = post_ops_data_ptrs[post_ops_data_idx]; + auto depthwise_weights = depthwise_base + post_op.depthwise.offset[post_op.depthwise.scales]; + auto depthwise_bias = depthwise_base + post_op.depthwise.offset[post_op.depthwise.shifts]; + + float d = load(i, oc, os, acc_off, dst_off); + + d = ref_depthwise_injectors_[depthwise_inj_idx]->compute_scalar(d, depthwise_weights + g * jcp_.oc + oc, + depthwise_bias + g * jcp_.oc + oc); + + store(i, d, acc_off, dst_off); + + } + } + post_ops_data_idx++; + depthwise_inj_idx++; + } else if (post_op.is_quantization()) { + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * this->dst_os_stride_ + oc; + + auto quant = post_op.quantization; + auto quantization_base = post_ops_data_ptrs[post_ops_data_idx]; + auto pcl = quantization_base + post_op.quantization.offset[quant.crop_low]; + auto pch = quantization_base + post_op.quantization.offset[quant.crop_high]; + auto pisc = quantization_base + post_op.quantization.offset[quant.inp_scale]; + auto pish = quantization_base + post_op.quantization.offset[quant.inp_shift]; + auto posc = quantization_base + post_op.quantization.offset[quant.output_scale]; + auto posh = quantization_base + post_op.quantization.offset[quant.output_shift]; + + float d = load(i, oc, os, acc_off, dst_off); + + int cl_idx = !quant.per_channel[quant.crop_low] ? 0 : g * jcp_.oc + oc; + int ch_idx = !quant.per_channel[quant.crop_high] ? 0 : g * jcp_.oc + oc; + int isc_idx = !quant.per_channel[quant.inp_scale] ? 0 : g * jcp_.oc + oc; + int ish_idx = !quant.per_channel[quant.inp_shift] ? 0 : g * jcp_.oc + oc; + int osc_idx = !quant.per_channel[quant.output_scale] ? 0 : g * jcp_.oc + oc; + int osh_idx = !quant.per_channel[quant.output_shift] ? 0 : g * jcp_.oc + oc; + + d = nstl::min(pch[ch_idx], nstl::max(pcl[cl_idx], d)); + d = d * pisc[isc_idx] + pish[ish_idx]; + d = roundf(d); + d = d * posc[osc_idx] + posh[osh_idx]; + + store(i, d, acc_off, dst_off); + + } + } + post_ops_data_idx++; + } else if (post_op.is_sum()) { + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? 
last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * this->dst_os_stride_ + oc; + + float d = load(i, oc, os, acc_off, dst_off); + + d += post_op.sum.scale * math::get_sum((char *) dst, dst_off, post_op.sum.dt); + + store(i, d, acc_off, dst_off); + } + } } - - // quantize data - if (jcp_.with_dst_scale) data *= dst_scale; - if (jcp_.zp.dst_exists) data += static_cast(zp_dst_val); - - io::store_float_value(jcp_.dst_data_type, data, void_dst, dst_off); } } } @@ -140,7 +252,23 @@ void ref_pp_ker_t::operator()(void *void_dst, const acc_data_t *acc, // Interface section pp_ker_t::pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) - : jcp_(jcp) {} + : jcp_(jcp) + , post_ops_(pd->attr()->post_ops_) + , OC_(jcp_.oc) +{ + const auto dst_md = memory_desc_wrapper(pd->dst_md()); + + dst_os_stride_ = dst_md.blocking_desc().strides[pd->ndims() - 1]; + dst_data_type_ = dst_md.data_type(); + // Use weight scale to do DQ. + do_scale_ = !pd->attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values(); + + do_bias_ = pd->with_bias(); + if (do_bias_) { + bias_data_type_ = pd->desc()->bias_desc.data_type; + assert(bias_data_type_ != data_type::undef); + } +} pp_ker_t *pp_ker_t::create( const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) { @@ -160,21 +288,6 @@ pp_ker_t *pp_ker_t::create( return nullptr; } -bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d) { -#if DNNL_X64 - return x64::gemm_x8s8s32x_convolution_utils::post_ops_ok(post_ops, dst_d); -#endif - return std::all_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(), - [](const dnnl_post_ops::entry_t &post_op) { - return post_op.is_eltwise() || post_op.is_sum() - || post_op.is_binary() || post_op.is_prelu(); - }); -} - -bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d) { - const auto dst_md = memory_desc_wrapper(dst_d); - return post_ops_ok(post_ops, &dst_md); -} bool mayiuse_jit_pp_kernel(data_type_t dst_dt) noexcept { #if DNNL_X64 diff --git a/src/cpu/gemm_x8s8s32x_convolution_utils.hpp b/src/cpu/gemm_x8s8s32x_convolution_utils.hpp index e133222f963..86e949ea995 100644 --- a/src/cpu/gemm_x8s8s32x_convolution_utils.hpp +++ b/src/cpu/gemm_x8s8s32x_convolution_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
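The `load`/`store` lambdas in `ref_pp_ker_t::operator()` above chain the post-ops through a single buffer: the first op reads the raw int32 accumulator and dequantizes it (signed-input scale, bias, weight scale), intermediate ops round-trip a float through the same storage (`acc_fp`), and only the last op quantizes into `dst`. A condensed sketch of that chaining (hypothetical op list; reusing the buffer in place relies on sizeof(int32_t) == sizeof(float), exactly as the patch does):

#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

using post_op_fn = std::function<float(float)>;

void apply_chain(int32_t *acc, uint8_t *dst, size_t len, float scale,
        const std::vector<post_op_fn> &ops) {
    float *acc_fp = reinterpret_cast<float *>(acc); // same storage, reused
    for (size_t i = 0; i < ops.size(); ++i)
        for (size_t j = 0; j < len; ++j) {
            // Only the first op sees the int32 view; later ops see float.
            float d = (i == 0) ? static_cast<float>(acc[j]) * scale : acc_fp[j];
            d = ops[i](d);
            if (i + 1 == ops.size()) // last op: saturate and store to dst
                dst[j] = static_cast<uint8_t>(
                        d < 0.f ? 0.f : d > 255.f ? 255.f : d + 0.5f);
            else
                acc_fp[j] = d;
        }
}

The empty-chain case is handled by a separate branch in the patch, which is why a sketch like this may assume `ops` is non-empty.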
@@ -32,9 +32,9 @@ struct pp_ker_t { const convolution_pd_t *pd, const conv_gemm_conf_t &jcp); virtual ~pp_ker_t() = default; - typedef typename prec_traits::type acc_data_t; + using acc_data_t = typename prec_traits_t::type; - virtual void operator()(void *dst, const acc_data_t *acc, const char *bias, + virtual void operator()(void *dst, acc_data_t *acc, const char *bias, const float *scales, float dst_scale, float sum_scale, float signed_scale, int g, size_t start, size_t end, const zero_point_call_params_t &zp, @@ -42,17 +42,25 @@ struct pp_ker_t { const exec_ctx_t &ctx, const memory_desc_t &dst_md, const single_gemm_conv_chunk_desc_t &chunk_desc) const = 0; + size_t dst_os_stride_; + virtual status_t create_kernel() { return status::success; } protected: pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp); const conv_gemm_conf_t &jcp_; -}; + const post_ops_t &post_ops_; + size_t OC_; + + bool mayiuse_jit_pp_kernel(data_type_t dst_dt) noexcept; -bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d); -bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d); -bool mayiuse_jit_pp_kernel(data_type_t dst_dt) noexcept; + bool do_bias_ = false; + bool do_scale_ = false; + + data_type_t bias_data_type_ = data_type::undef; + data_type_t dst_data_type_ = data_type::undef; +}; } // namespace gemm_x8s8s32x_convolution_utils } // namespace cpu diff --git a/src/cpu/gemm_x8s8s32x_inner_product.cpp b/src/cpu/gemm_x8s8s32x_inner_product.cpp index 341a584a276..cad125ea7be 100644 --- a/src/cpu/gemm_x8s8s32x_inner_product.cpp +++ b/src/cpu/gemm_x8s8s32x_inner_product.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,10 +64,9 @@ status_t gemm_x8s8s32x_inner_product_fwd_t::execute_forward( DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); auto scratchpad = ctx.get_scratchpad_grantor(); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, - IC, OC, false, wei_scale_mask != 0, pd()->attr()); + IC, OC, false, wei_scale_mask > 0, pd()->attr()); int32_t *acc = pd()->dst_is_acc_ ? (int32_t *)dst diff --git a/src/cpu/gemm_x8s8s32x_inner_product.hpp b/src/cpu/gemm_x8s8s32x_inner_product.hpp index ea62c604e05..bda7860417b 100644 --- a/src/cpu/gemm_x8s8s32x_inner_product.hpp +++ b/src/cpu/gemm_x8s8s32x_inner_product.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
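In the attribute changes above, the migration from `mask_ != 0` to `get_mask(...) > 0` keeps the usual oneDNN convention: a zero scale mask means one scale for the whole tensor, while a set bit marks a dimension that carries its own scales. A minimal sketch of how a post-processing loop consumes that convention (the helper name is illustrative, not library code):

#include <cstddef>

// Illustrative: with mask == 0 every output channel reads scales[0];
// with any bit set the kernel indexes one scale per output channel.
inline float weight_scale(const float *scales, int wei_scale_mask,
        std::size_t oc) {
    const bool per_oc = wei_scale_mask > 0;
    return scales[per_oc ? oc : 0];
}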
@@ -69,7 +69,7 @@ struct gemm_x8s8s32x_inner_product_fwd_t : public primitive_t { VERBOSE_UNSUPPORTED_DT); VDISPATCH_INNER_PRODUCT( attr()->has_default_values( - primitive_attr_t::skip_mask_t::scales_runtime + primitive_attr_t::skip_mask_t::scales | primitive_attr_t::skip_mask_t::post_ops, dst_md()->data_type), VERBOSE_UNSUPPORTED_ATTR); diff --git a/src/cpu/jit_utils/jit_utils.cpp b/src/cpu/jit_utils/jit_utils.cpp index 431cd71a4a9..d95484401c0 100644 --- a/src/cpu/jit_utils/jit_utils.cpp +++ b/src/cpu/jit_utils/jit_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2025 Intel Corporation * Copyright 2021 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +16,7 @@ *******************************************************************************/ #include +#include #include "common/utils.hpp" #include "common/verbose.hpp" @@ -31,7 +32,7 @@ #endif #if DNNL_ENABLE_JIT_PROFILING -#include "common/ittnotify/jitprofiling.h" +#include "ittnotify/jitprofiling.h" #ifdef __linux__ #include "cpu/jit_utils/linux_perf/linux_perf.hpp" #endif @@ -60,7 +61,7 @@ void dump_jit_code(const void *code, size_t code_size, const char *code_name) { // TODO (Roma): support prefix for code / linux perf dumps snprintf(fname, MAX_FNAME_LEN, DUMP_BASE_FNAME "%s" DUMP_EXT_FNAME, code_name); - + std::cout << "[ oneDNN ] dump_jit_code: " << fname << std::endl; FILE *fp = fopen(fname, "wb+"); // Failure to dump code is not fatal if (fp) { @@ -97,7 +98,7 @@ void register_jit_code_vtune(const void *code, size_t code_size, } #else if (flags & DNNL_JIT_PROFILE_VTUNE) - VERROR(primitive, jit_profiling, + VWARN(primitive, jit_profiling, "VTune Profiler integration is not supported"); #endif #else @@ -137,7 +138,9 @@ void register_jit_code(const void *code, size_t code_size, char unique_code_name[MAX_CODENAME_LEN + 1]; snprintf(unique_code_name, MAX_CODENAME_LEN, "%s.%d", code_name, unique_id++); - + if (code && get_jit_dump()) { + std::cout << "[ oneDNN ] register_jit_code: " << unique_code_name << ", " << code_name << std::endl; + } dump_jit_code(code, code_size, unique_code_name); // VTune Profiler does not need a unique name, because it uses // unique method_id diff --git a/src/cpu/jit_utils/linux_perf/linux_perf.cpp b/src/cpu/jit_utils/linux_perf/linux_perf.cpp index 2a815d77505..0dc561518d9 100644 --- a/src/cpu/jit_utils/linux_perf/linux_perf.cpp +++ b/src/cpu/jit_utils/linux_perf/linux_perf.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2024 Intel Corporation * Copyright 2021 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -195,7 +195,7 @@ class linux_perf_jitdump_t { } #else if (use_tsc) { - VERROR(primitive, linux_perf, + VWARN(primitive, linux_perf, "TSC timestamps is not supported. clock_gettime() is used " "instead."); } diff --git a/src/cpu/matmul/cpu_matmul_list.cpp b/src/cpu/matmul/cpu_matmul_list.cpp index 6a53d0920c6..6868128bcf0 100644 --- a/src/cpu/matmul/cpu_matmul_list.cpp +++ b/src/cpu/matmul/cpu_matmul_list.cpp @@ -1,7 +1,7 @@ /******************************************************************************* * Copyright 2019-2024 Intel Corporation -* Copyright 2024 FUJITSU LIMITED -* Copyright 2021-2024 Arm Ltd. and affiliates +* Copyright 2024-2025 FUJITSU LIMITED +* Copyright 2021-2025 Arm Ltd. 
and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,15 +30,23 @@ #include "cpu/x64/matmul/jit_uni_sparse_matmul.hpp" using namespace dnnl::impl::cpu::x64::matmul; using namespace dnnl::impl::cpu::x64; -#elif DNNL_AARCH64 +#endif + +#if DNNL_AARCH64 #include "cpu/aarch64/matmul/brgemm_matmul.hpp" -#ifdef DNNL_AARCH64_USE_ACL -#include "cpu/aarch64/matmul/acl_lowp_matmul.hpp" -#include "cpu/aarch64/matmul/acl_matmul.hpp" +#include "cpu/aarch64/matmul/jit_int8_matmul.hpp" #endif + +#ifdef DNNL_USE_ACL +#include "cpu/acl/matmul/acl_lowp_matmul.hpp" +#include "cpu/acl/matmul/acl_lowp_matmul_sq.hpp" +#include "cpu/acl/matmul/acl_matmul.hpp" +#if DNNL_AARCH64 using namespace dnnl::impl::cpu::aarch64::matmul; using namespace dnnl::impl::cpu::aarch64; - +#endif +using namespace dnnl::impl::cpu::acl::matmul; +using namespace dnnl::impl::cpu::acl; #endif namespace dnnl { @@ -71,25 +79,27 @@ using namespace dnnl::impl::cpu::matmul; #endif // clang-format off -constexpr impl_list_item_t impl_list[] = REG_MATMUL_P({ - - CPU_INSTANCE_AARCH64(brgemm_matmul_t) - CPU_INSTANCE_AARCH64_ACL(acl_lowp_matmul_t) - CPU_INSTANCE_AARCH64_ACL(acl_matmul_t) - CPU_INSTANCE_AARCH64(brgemm_matmul_t) - CPU_INSTANCE_AMX(brgemm_matmul_t) - CPU_INSTANCE_AMX(brgemm_matmul_t) - CPU_INSTANCE_AVX512(brgemm_matmul_t) - CPU_INSTANCE_AVX512(brgemm_matmul_t) - CPU_INSTANCE_AVX512(brgemm_matmul_t) - CPU_INSTANCE_AVX512(brgemm_matmul_t) - CPU_INSTANCE_AVX2(brgemm_matmul_t) - CPU_INSTANCE_AVX2(brgemm_matmul_t) +const impl_list_item_t impl_list[] = REG_MATMUL_P({ + + CPU_INSTANCE_AARCH64(brgemm_matmul_t) + CPU_INSTANCE_ACL(acl_lowp_matmul_sq_t) + CPU_INSTANCE_ACL(acl_lowp_matmul_t) + CPU_INSTANCE_ACL(acl_matmul_t) + CPU_INSTANCE_AARCH64(brgemm_matmul_t,sve_256) + CPU_INSTANCE_AARCH64(jit_int8_matmul_t) + CPU_INSTANCE_AMX(brgemm_matmul_t,avx512_core_amx_fp16) + CPU_INSTANCE_AMX(brgemm_matmul_t,avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_matmul_t,avx512_core_fp16) + CPU_INSTANCE_AVX512(brgemm_matmul_t,avx512_core_bf16) + CPU_INSTANCE_AVX512(brgemm_matmul_t,avx512_core_vnni) + CPU_INSTANCE_AVX512(brgemm_matmul_t,avx512_core) + CPU_INSTANCE_AVX2(brgemm_matmul_t,avx2_vnni_2) + CPU_INSTANCE_AVX2(brgemm_matmul_t,avx2_vnni) CPU_INSTANCE(gemm_f32_matmul_t) - CPU_INSTANCE(gemm_bf16_matmul_t) - CPU_INSTANCE(gemm_bf16_matmul_t) + CPU_INSTANCE(gemm_bf16_matmul_t, f32) + CPU_INSTANCE(gemm_bf16_matmul_t, bf16) CPU_INSTANCE(gemm_x8s8s32x_matmul_t) - CPU_INSTANCE_AVX2(brgemm_matmul_t) + CPU_INSTANCE_AVX2(brgemm_matmul_t, avx2) CPU_INSTANCE(ref_matmul_t) CPU_INSTANCE(ref_matmul_int8_t) // These implementations are enabled only when DNNL_EXPERIMENTAL_SPARSE @@ -112,4 +122,4 @@ const impl_list_item_t *get_matmul_impl_list(const matmul_desc_t *desc) { } // namespace cpu } // namespace impl -} // namespace dnnl +} // namespace dnnl \ No newline at end of file diff --git a/src/cpu/matmul/gemm_bf16_matmul.cpp b/src/cpu/matmul/gemm_bf16_matmul.cpp index cd415b94743..c61b54be5cf 100644 --- a/src/cpu/matmul/gemm_bf16_matmul.cpp +++ b/src/cpu/matmul/gemm_bf16_matmul.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
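The reshuffled `impl_list[]` above matters because implementations are tried in declaration order: primitive descriptor creation conceptually walks the list and the first entry that initializes successfully is used, which is why AMX and AVX-512 specializations sit before generic GEMM and reference entries. A simplified model of that selection, with stand-in types rather than the library's internals:

#include <functional>
#include <vector>

// Stand-in model: each entry attempts to create itself for the problem;
// the first one that succeeds wins, so list order encodes preference.
enum class status_model { success, unimplemented };
using try_create_fn = std::function<status_model()>;

inline int select_impl(const std::vector<try_create_fn> &impl_list) {
    for (int i = 0; i < (int)impl_list.size(); ++i)
        if (impl_list[i]() == status_model::success) return i;
    return -1; // no implementation accepted the problem
}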
@@ -66,9 +66,9 @@ status_t gemm_bf16_matmul_t::pd_t::init(engine_t *engine) { VDISPATCH_MATMUL(x64::mayiuse(x64::avx512_core), VERBOSE_UNSUPPORTED_ISA); #endif - VDISPATCH_MATMUL(attr()->has_default_values( - primitive_attr_t::skip_mask_t::scales_runtime - | primitive_attr_t::skip_mask_t::post_ops), + VDISPATCH_MATMUL( + attr()->has_default_values(primitive_attr_t::skip_mask_t::scales + | primitive_attr_t::skip_mask_t::post_ops), VERBOSE_UNSUPPORTED_ATTR); VDISPATCH_MATMUL(attr()->post_ops_.check_sum_consistency(dst_type, /* is_int8 */ false), @@ -105,9 +105,9 @@ status_t gemm_bf16_matmul_t::pd_t::check_and_configure_attributes( engine_t *engine) { auto check_attr_scales = [&]() -> bool { bool ok = attr_scales_ok(); - if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values() - && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values() - && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) { + if (!attr()->scales_.has_default_values(DNNL_ARG_SRC) + && !attr()->scales_.has_default_values(DNNL_ARG_WEIGHTS) + && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad with unknown size if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; } @@ -145,11 +145,15 @@ status_t gemm_bf16_matmul_t::pd_t::check_and_configure_attributes( // set state CHECK(params_.pp_attr_.copy_from(*attr())); params_.gemm_applies_output_scales_ - = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0 && !with_bias(); + = attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == 0 && !with_bias(); if (params_.gemm_applies_output_scales_) { - params_.pp_attr_.scales_.reset(DNNL_ARG_SRC); - params_.pp_attr_.scales_.reset(DNNL_ARG_WEIGHTS); + VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.set( + DNNL_ARG_SRC, default_quant_entry()), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.set( + DNNL_ARG_WEIGHTS, default_quant_entry()), + VERBOSE_UNSUPPORTED_SCALES_CFG); } // check post-ops @@ -203,11 +207,10 @@ status_t gemm_bf16_matmul_t::execute_ref( DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); auto scratchpad = ctx.get_scratchpad_grantor(); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, src_d.dims()[ndims - 1], dst_d.dims()[ndims - 1], false, - wei_scale_mask != 0, pd()->attr()); + wei_scale_mask > 0, pd()->attr()); if (src_d.has_zero_dim() || weights_d.has_zero_dim() || dst_d.has_zero_dim()) @@ -254,7 +257,7 @@ status_t gemm_bf16_matmul_t::execute_ref( const float beta = params.gemm_beta_; const dim_t acc_ldc = dst_is_acc ? 
ldc : N; const int scale_idx_mult - = this->pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ + = this->pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == (1 << (ndims - 1)); std::atomic st(status::success); @@ -271,11 +274,7 @@ status_t gemm_bf16_matmul_t::execute_ref( const dim_t acc_stride = gemm_based::get_scratchpad_block_elements( batch, M, N, use_single_gemm_call, nthr); -#ifdef GCC_WA_LAMBDA_C_CAST - parallel(nthr, [= WA_THIS_COPY_CAPTURE, &st](int ithr, int nthr) { -#else parallel(nthr, [&](int ithr, int nthr) { -#endif size_t t_work_start {0}, t_work_end {0}; balance211(work_amount, nthr, ithr, t_work_start, t_work_end); diff --git a/src/cpu/matmul/gemm_bf16_matmul.hpp b/src/cpu/matmul/gemm_bf16_matmul.hpp index 0df8bdd2317..0556db01a3e 100644 --- a/src/cpu/matmul/gemm_bf16_matmul.hpp +++ b/src/cpu/matmul/gemm_bf16_matmul.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -90,10 +90,10 @@ struct gemm_bf16_matmul_t : public primitive_t { static constexpr data_type_t weights_type = data_type::bf16; static constexpr data_type_t acc_type = data_type::f32; - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type weights_data_t; - typedef typename prec_traits::type dst_data_t; - typedef typename prec_traits::type acc_data_t; + using src_data_t = typename prec_traits_t::type; + using weights_data_t = typename prec_traits_t::type; + using dst_data_t = typename prec_traits_t::type; + using acc_data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_ref(ctx); diff --git a/src/cpu/matmul/gemm_f32_matmul.cpp b/src/cpu/matmul/gemm_f32_matmul.cpp index de57af38944..76c9d1b7be9 100644 --- a/src/cpu/matmul/gemm_f32_matmul.cpp +++ b/src/cpu/matmul/gemm_f32_matmul.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
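`scale_idx_mult` above is a small branch-free indexing trick: it is 1 when the weight-scale mask equals 1 << (ndims - 1) (one scale per N) and 0 when a single scale is broadcast, so the epilogue can always write scales[n * scale_idx_mult]. Restated as a standalone sketch (illustrative names):

#include <cstdint>

// Illustrative: one expression covers both per-tensor and per-N scales;
// scale_idx_mult == 0 makes every column read scales[0].
inline void apply_output_scales(float *row, const float *scales,
        int scale_idx_mult, std::int64_t N) {
    for (std::int64_t n = 0; n < N; ++n)
        row[n] *= scales[n * scale_idx_mult];
}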
@@ -50,9 +50,9 @@ status_t gemm_f32_matmul_t::pd_t::init(engine_t *engine) { auto check_attr_scales = [&]() -> bool { bool ok = attr_scales_ok(); - if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values() - && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values() - && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) { + if (!attr()->scales_.has_default_values(DNNL_ARG_SRC) + && !attr()->scales_.has_default_values(DNNL_ARG_WEIGHTS) + && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad with unknown size if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; } @@ -92,11 +92,11 @@ status_t gemm_f32_matmul_t::pd_t::init(engine_t *engine) { VDISPATCH_MATMUL(is_dense_format_kind(), VERBOSE_UNSUPPORTED_SPARSE_CFG); VDISPATCH_MATMUL(problem_dt_correct, VERBOSE_UNSUPPORTED_DT_CFG); VDISPATCH_MATMUL(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); - VDISPATCH_MATMUL(attr()->has_default_values( - primitive_attr_t::skip_mask_t::scales_runtime - | primitive_attr_t::skip_mask_t::post_ops - | primitive_attr_t::skip_mask_t::sum_dt, - dst_type), + VDISPATCH_MATMUL( + attr()->has_default_values(primitive_attr_t::skip_mask_t::scales + | primitive_attr_t::skip_mask_t::post_ops + | primitive_attr_t::skip_mask_t::sum_dt, + dst_type), VERBOSE_UNSUPPORTED_ATTR); VDISPATCH_MATMUL(attr()->post_ops_.check_sum_consistency(dst_type, /* is_int8 */ false), @@ -131,10 +131,14 @@ status_t gemm_f32_matmul_t::pd_t::configure_attributes() { CHECK(params_.pp_attr_.copy_from(*attr())); params_.gemm_applies_output_scales_ - = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0 && !with_bias(); + = attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == 0 && !with_bias(); if (params_.gemm_applies_output_scales_) { - params_.pp_attr_.scales_.reset(DNNL_ARG_SRC); - params_.pp_attr_.scales_.reset(DNNL_ARG_WEIGHTS); + VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.set( + DNNL_ARG_SRC, default_quant_entry()), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.set( + DNNL_ARG_WEIGHTS, default_quant_entry()), + VERBOSE_UNSUPPORTED_SCALES_CFG); } const auto &po = params_.pp_attr_.post_ops_; @@ -186,11 +190,10 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const { DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); auto scratchpad = ctx.get_scratchpad_grantor(); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, src_d.dims()[ndims - 1], dst_d.dims()[ndims - 1], false, - wei_scale_mask != 0, pd()->attr()); + wei_scale_mask > 0, pd()->attr()); if (src_d.has_zero_dim() || weights_d.has_zero_dim() || dst_d.has_zero_dim()) @@ -237,7 +240,7 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const { const dim_t acc_ldc = dst_is_acc ? 
ldc : N; const int scale_idx_mult - = this->pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ + = this->pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == (1 << (ndims - 1)); std::atomic st(status::success); diff --git a/src/cpu/matmul/gemm_f32_matmul.hpp b/src/cpu/matmul/gemm_f32_matmul.hpp index 447de227565..dac4206c198 100644 --- a/src/cpu/matmul/gemm_f32_matmul.hpp +++ b/src/cpu/matmul/gemm_f32_matmul.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -88,10 +88,10 @@ struct gemm_f32_matmul_t : public primitive_t { static constexpr data_type_t dst_type = data_type::f32; static constexpr data_type_t acc_type = data_type::f32; - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type weights_data_t; - typedef typename prec_traits::type dst_data_t; - typedef typename prec_traits::type acc_data_t; + using src_data_t = typename prec_traits_t::type; + using weights_data_t = typename prec_traits_t::type; + using dst_data_t = typename prec_traits_t::type; + using acc_data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_ref(ctx); diff --git a/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp b/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp index 5fab321d7af..a9a7e209928 100644 --- a/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp +++ b/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
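`gemm_applies_output_scales_` above captures a fast path shared by these GEMM-based implementations: when the weight-scale mask is 0 (a single scale for the whole tensor) and there is no bias, the scale can be folded into the GEMM alpha, and the scales in the copied attributes are reset to defaults so post-processing does not apply them a second time. A hedged sketch of the decision only (not the library's actual plumbing):

// Illustrative: fold a per-tensor scale into C = alpha * A * B instead of
// rescaling every C element in a separate post-processing pass.
inline float pick_gemm_alpha(float combined_scale, int wei_scale_mask,
        bool with_bias) {
    const bool gemm_applies_scales = wei_scale_mask == 0 && !with_bias;
    return gemm_applies_scales ? combined_scale : 1.0f;
}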
@@ -61,17 +61,27 @@ status_t gemm_x8s8s32x_matmul_t::pd_t::init(engine_t *engine) { auto check_attr_scales = [&]() -> bool { bool ok = attr_scales_ok(); - if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values() - && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values() - && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) { + if (!attr()->scales_.has_default_values(DNNL_ARG_SRC) + && !attr()->scales_.has_default_values(DNNL_ARG_WEIGHTS) + && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad with unknown size if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; } return ok; }; - auto check_attr_zero_points - = [&]() -> bool { return attr()->zero_points_.common(); }; + auto check_attr_zero_points = [&]() -> bool { + const auto &zp = attr()->zero_points_; + static const std::vector supported_args { + DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}; + for (int arg : supported_args) { + if (!zp.has_default_values(arg)) { + const int mask = zp.get_mask(arg); + if (mask > 0) return false; + } + } + return true; + }; auto check_attr_post_ops = [&]() -> bool { using namespace primitive_kind; @@ -117,9 +127,8 @@ status_t gemm_x8s8s32x_matmul_t::pd_t::init(engine_t *engine) { VDISPATCH_MATMUL(check_attr_scales(), VERBOSE_UNSUPPORTED_SCALES_CFG); VDISPATCH_MATMUL(check_attr_zero_points(), VERBOSE_UNSUPPORTED_ATTR); VDISPATCH_MATMUL( - attr()->has_default_values( - primitive_attr_t::skip_mask_t::scales_runtime - | primitive_attr_t::skip_mask_t::zero_points_runtime + attr()->has_default_values(primitive_attr_t::skip_mask_t::scales + | primitive_attr_t::skip_mask_t::zero_points | primitive_attr_t::skip_mask_t::post_ops | primitive_attr_t::skip_mask_t::sum_dt, dst_md()->data_type), @@ -203,11 +212,10 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const { DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); auto &scratchpad = ctx.get_scratchpad_grantor(); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, src_d.dims()[ndims - 1], dst_d.dims()[ndims - 1], false, - wei_scale_mask != 0, pd()->attr()); + wei_scale_mask > 0, pd()->attr()); DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC); DEFINE_ZERO_POINT_VALUE(weights_zero_point, DNNL_ARG_WEIGHTS); @@ -245,7 +253,9 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const { const char transB = helper.transB(); const dim_t lda = helper.lda(); const dim_t ldb = helper.ldb(); - const dim_t ldc = helper.ldc(); + const dim_t ldc = dst_d.ndims() == 2 && dst_d.count_non_unit_dims(1) + ? N + : helper.ldc(); const int ldx_dim_idx = pd()->ndims() - 2; const dim_t *src_strides = &src_d.blocking_desc().strides[ldx_dim_idx]; const dim_t *weights_strides @@ -276,7 +286,7 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const { const float beta = params.gemm_beta_; const dim_t acc_ldc = dst_is_acc ? 
ldc : N; const int scale_idx_mult - = this->pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ + = this->pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == (1 << (ndims - 1)); std::atomic st(status::success); @@ -297,11 +307,7 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const { bool postops_in_matmul = need_post_processing(pd(), dst_zero_point_f32); assert(IMPLICATION(postops_in_matmul, params.has_pp_kernel_)); -#ifdef GCC_WA_LAMBDA_C_CAST - parallel(nthr, [= WA_THIS_COPY_CAPTURE, &st](int ithr, int nthr) { -#else parallel(nthr, [&](int ithr, int nthr) { -#endif size_t t_work_start {0}, t_work_end {0}; balance211(work_amount, nthr, ithr, t_work_start, t_work_end); diff --git a/src/cpu/matmul/matmul_utils.hpp b/src/cpu/matmul/matmul_utils.hpp index 28b3c7310d3..c1e9011bd82 100644 --- a/src/cpu/matmul/matmul_utils.hpp +++ b/src/cpu/matmul/matmul_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * Copyright 2022 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,6 +19,7 @@ #define CPU_MATMUL_UTILS_HPP #include "common/memory_desc_wrapper.hpp" +#include "common/tag_traits.hpp" #include "common/utils.hpp" #include "cpu/binary_injector_utils.hpp" @@ -150,6 +151,50 @@ struct matmul_helper_t { return true; } + // TODO: consolidate these functions with ones in simple_reorder.hpp, as they + // are copy-pasted, and address TODOs from there. + static status_t get_quant_md(memory_desc_t &md, const int ndims, + const dims_t in_dims, const int quant_mask, const dim_t g0, + const dim_t g1, const data_type_t dt) { + if (dt == data_type::undef || quant_mask < 0) { + md = glob_zero_md; + return status::success; + } + + dims_t quant_dims {}; + utils::copy_dims_with_mask(quant_dims, in_dims, ndims, quant_mask, + /* fill_with_ones = */ true); + if (ndims >= 2) { + quant_dims[ndims - 1] /= g1; + quant_dims[ndims - 2] /= g0; + } + + CHECK(memory_desc_init_by_tag( + md, ndims, quant_dims, dt, get_abx_tag(ndims))); + return status::success; + } + + static dim_t get_quant_off(const dims_t &input_idx, const int ndims, + const int quant_mask, const dim_t g0, const dim_t g1, + const memory_desc_t &quant_md) { + if (types::is_zero_md(&quant_md)) return 0; + + dims_t quant_idx {}; + utils::array_copy(quant_idx, input_idx, ndims); + utils::apply_mask_on_dims(quant_idx, ndims, quant_mask); + // Note: an `idx` must divide by a group value as grouped quantization + // applies to consecutive points. + // Using quant dimensions in `l_dims_by_l_offset` will lead to wrapping + // around dimensions instead of applying consecutively. + if (ndims >= 2) { + quant_idx[ndims - 1] /= g1; + quant_idx[ndims - 2] /= g0; + } + + const memory_desc_wrapper q_mdw(quant_md); + return q_mdw.off_v(quant_idx); + } + private: mdw_t src_md_; mdw_t weights_md_; diff --git a/src/cpu/matmul/ref_matmul.cpp b/src/cpu/matmul/ref_matmul.cpp index c1a182e2721..19dc7ea2894 100644 --- a/src/cpu/matmul/ref_matmul.cpp +++ b/src/cpu/matmul/ref_matmul.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
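The `get_quant_md()` / `get_quant_off()` helpers added to `matmul_utils.hpp` above derive a scale/zero-point tensor's shape from the data shape: dimensions outside the quantization mask collapse to 1, and the two innermost dimensions are divided by the group sizes before offsets are computed through a regular memory descriptor. The same arithmetic for the common 2D weights case, as a standalone sketch (not the library helper):

#include <cassert>
#include <cstdint>

// Illustrative: quant dims for 2D weights [K, N]; mask bit 0 covers K,
// bit 1 covers N; g0 groups consecutive K points, g1 groups N points.
inline void quant_dims_2d(const std::int64_t dims[2], int mask,
        std::int64_t g0, std::int64_t g1, std::int64_t out[2]) {
    out[0] = (mask & (1 << 0)) ? dims[0] : 1;
    out[1] = (mask & (1 << 1)) ? dims[1] : 1;
    assert(out[0] % g0 == 0 && out[1] % g1 == 0);
    out[0] /= g0; // one quant entry per group of g0 K points
    out[1] /= g1;
}
// Example: dims {128, 64}, mask 0b11, g0 = 32, g1 = 1 gives {4, 64},
// i.e. per-N entries with groups of 32 along K.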
@@ -18,6 +18,7 @@ #include #include +#include #include "common/c_types_map.hpp" #include "common/dnnl_thread.hpp" #include "common/math_utils.hpp" @@ -87,17 +88,14 @@ status_t ref_matmul_t::execute_ref(const exec_ctx_t &ctx) const { const auto &attr_zps = pd()->attr()->zero_points_; const bool with_wei_zero_points = !attr_zps.has_default_values(DNNL_ARG_WEIGHTS); - int wei_zp_mask = 0; - attr_zps.get(DNNL_ARG_WEIGHTS, &wei_zp_mask); - const bool wei_zp_per_n = wei_zp_mask & pd()->wei_qmask_N(); - const bool wei_zp_per_k = wei_zp_mask & pd()->wei_qmask_K(); - const dim_t wei_zp_stride_n = wei_zp_per_n ? 1 : 0; - const dim_t wei_zp_stride_k = wei_zp_per_k ? wei_zp_per_n ? N : 1 : 0; + int wei_zp_mask = attr_zps.get_mask(DNNL_ARG_WEIGHTS); const auto &wei_zp_dt = attr_zps.get_data_type(DNNL_ARG_WEIGHTS); - const auto wei_zp_group_ndims = attr_zps.get_groups_ndims(DNNL_ARG_WEIGHTS); - const auto wei_zp_group_k = wei_zp_group_ndims > 0 - ? attr_zps.get_groups(DNNL_ARG_WEIGHTS)[0] - : 1; + const auto wei_zp_group_k = attr_zps.get_group(DNNL_ARG_WEIGHTS, 0); + const auto wei_zp_group_n = attr_zps.get_group(DNNL_ARG_WEIGHTS, 1); + // Initialize a memory desc for quant entries for easier offset calculation. + memory_desc_t wei_zp_md {}; + CHECK(matmul_helper_t::get_quant_md(wei_zp_md, ndims, weights_d.dims(), + wei_zp_mask, wei_zp_group_k, wei_zp_group_n, wei_zp_dt)); const int src_mask = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims); @@ -108,25 +106,23 @@ status_t ref_matmul_t::execute_ref(const exec_ctx_t &ctx) const { // arg scales section const auto &attr_scales = pd()->attr()->scales_; - const bool with_src_scales - = !attr_scales.get(DNNL_ARG_SRC).has_default_values(); + const bool with_src_scales = !attr_scales.has_default_values(DNNL_ARG_SRC); const bool with_wei_scales - = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values(); - const bool with_dst_scales - = !attr_scales.get(DNNL_ARG_DST).has_default_values(); - const auto wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; - const bool wei_scale_per_n = wei_scale_mask & pd()->wei_qmask_N(); - const bool wei_scale_per_k = wei_scale_mask & pd()->wei_qmask_K(); - const dim_t wei_scale_stride_n = wei_scale_per_n ? 1 : 0; - const dim_t wei_scale_stride_k - = wei_scale_per_k ? wei_scale_per_n ? N : 1 : 0; - const auto &wei_scale_dt = attr_scales.get(DNNL_ARG_WEIGHTS).data_type_; - const auto scales_d + = !attr_scales.has_default_values(DNNL_ARG_WEIGHTS); + const bool with_dst_scales = !attr_scales.has_default_values(DNNL_ARG_DST); + const auto wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); + const dim_t wei_scale_stride_n + = (wei_scale_mask & pd()->wei_qmask_N()) ? 1 : 0; + const auto &wei_scale_dt = attr_scales.get_data_type(DNNL_ARG_WEIGHTS); + const auto wei_scales_d = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); - const auto wei_scale_group_ndim = attr_scales.get(DNNL_ARG_WEIGHTS).ndims_; - const auto wei_scale_group_k = wei_scale_group_ndim > 0 - ? attr_scales.get(DNNL_ARG_WEIGHTS).group_dims_[0] - : 1; + const auto wei_scale_group_k = attr_scales.get_group(DNNL_ARG_WEIGHTS, 0); + const auto wei_scale_group_n = attr_scales.get_group(DNNL_ARG_WEIGHTS, 1); + // Initialize a memory desc for quant entries for easier offset calculation. 
+ memory_desc_t wei_scale_md {}; + CHECK(matmul_helper_t::get_quant_md(wei_scale_md, ndims, weights_d.dims(), + wei_scale_mask, wei_scale_group_k, wei_scale_group_n, + wei_scale_dt)); auto dst_rnd_mode = pd()->attr()->rounding_mode_.get(DNNL_ARG_DST); @@ -152,17 +148,24 @@ status_t ref_matmul_t::execute_ref(const exec_ctx_t &ctx) const { weights_d.data_type(), weights, weights_off); // weights decompression should happen before the operation if (with_wei_decompression) { - if (with_wei_zero_points) - w -= io::load_float_value(wei_zp_dt, wei_zero_points, - wei_zp_stride_n * n - + wei_zp_stride_k * (k / wei_zp_group_k)); + if (with_wei_zero_points) { + const dim_t wei_zp_offset = matmul_helper_t::get_quant_off( + weights_dims_idx, ndims, wei_zp_mask, + wei_zp_group_k, wei_zp_group_n, wei_zp_md); + const auto wei_zp = io::load_int_value( + wei_zp_dt, wei_zero_points, wei_zp_offset); + w -= wei_zp; + } if (with_wei_scales) { - float wei_scale = scales_d.nelems() == 1 + const dim_t wei_scale_offset + = matmul_helper_t::get_quant_off(weights_dims_idx, + ndims, wei_scale_mask, wei_scale_group_k, + wei_scale_group_n, wei_scale_md); + // Single scale value was already converted into f32. + const float wei_scale = wei_scales_d.nelems() == 1 ? wei_scales[0] - : io::load_float_value(wei_scale_dt, wei_scales, - wei_scale_stride_n * n - + wei_scale_stride_k - * (k / wei_scale_group_k)); + : io::load_float_value( + wei_scale_dt, wei_scales, wei_scale_offset); w *= wei_scale; } } @@ -182,36 +185,52 @@ status_t ref_matmul_t::execute_ref(const exec_ctx_t &ctx) const { auto sum_dt = pd()->attr()->post_ops_.get_sum_dt(dst_d.data_type()); bool with_dropout = !pd()->attr()->dropout_.has_default_values(); - // computations - parallel_nd(batch, M, N, [&](dim_t mb, dim_t m, dim_t n) { - dims_t dst_dims_idx; - // account for M, N dims for index calculations - const size_t l_offset = mb * M * N + m * N + n; - utils::l_dims_by_l_offset(dst_dims_idx, l_offset, dst_d.dims(), ndims); - float d = ker(dst_dims_idx, m, n); - if (with_src_scales) d *= src_scales[0]; - if (with_wei_scales && !with_wei_decompression) - d *= wei_scales[wei_scale_stride_n * n]; - if (bias) d += ker_bias(dst_dims_idx); - - const auto dst_off = dst_d.off_v(dst_dims_idx); - if (non_default_attrs) { - if (with_dropout) - d = ref_dropout(d, dropout_mask, dst_off, *p, *seed); - ref_post_ops_t::args_t args; - args.dst_val = io::load_float_value(sum_dt, dst, dst_off); - args.ctx = &ctx; - args.l_offset = l_offset; - args.dst_md = pd()->dst_md(); - ref_post_ops->execute(d, args); - } - if (with_dst_scales) d *= dst_scales[0]; - if (dst_rnd_mode == rounding_mode::stochastic) - d = math::stochastic_round_fwd( - d, dst_off, rnd_seed[0], dst_d.data_type()); - io::store_float_value(dst_d.data_type(), d, dst, dst_off); - utils::dim_iterator(dst_d.dims(), dst_dims_idx, batch_ndims); - }); + // computations Note: If dst type is < 8 bits, we cannot split a + // byte during store or we get a race condition. To simplify + // logic, we limit parallelization on M and N by a factor of 2. 
+ parallel_nd(batch, utils::div_up(M, 2), utils::div_up(N, 2), + [&](dim_t mb, dim_t m_, dim_t n_) { + for_(int m = 2 * m_; m < std::min(2 * (m_ + 1), M); m++) + for (int n = 2 * n_; n < std::min(2 * (n_ + 1), N); n++) { + dims_t dst_dims_idx; + // account for M, N dims for index calculations + const size_t l_offset = mb * M * N + m * N + n; + utils::l_dims_by_l_offset( + dst_dims_idx, l_offset, dst_d.dims(), ndims); + float d = ker(dst_dims_idx, m, n); + if (with_src_scales) d *= src_scales[0]; + if (with_wei_scales && !with_wei_decompression) { + // Single scale value was already converted into f32. + const float wei_scale = wei_scales_d.nelems() == 1 + ? wei_scales[0] + : io::load_float_value(wei_scale_dt, wei_scales, + wei_scale_stride_n * n); + d *= wei_scale; + } + if (bias) d += ker_bias(dst_dims_idx); + + const auto dst_off = dst_d.off_v(dst_dims_idx); + if (non_default_attrs) { + if (with_dropout) + d = ref_dropout( + d, dropout_mask, dst_off, *p, *seed); + ref_post_ops_t::args_t args; + args.dst_val + = io::load_float_value(sum_dt, dst, dst_off); + args.ctx = &ctx; + args.l_offset = l_offset; + args.dst_md = pd()->dst_md(); + ref_post_ops->execute(d, args); + } + if (with_dst_scales) d *= dst_scales[0]; + if (dst_rnd_mode == rounding_mode::stochastic) + d = math::stochastic_round_fwd( + d, dst_off, rnd_seed[0], dst_d.data_type()); + io::store_float_value(dst_d.data_type(), d, dst, dst_off); + utils::dim_iterator( + dst_d.dims(), dst_dims_idx, batch_ndims); + } + }); return status::success; } diff --git a/src/cpu/matmul/ref_matmul.hpp b/src/cpu/matmul/ref_matmul.hpp index 1c2d3ea392b..19dc04adc0a 100644 --- a/src/cpu/matmul/ref_matmul.hpp +++ b/src/cpu/matmul/ref_matmul.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,78 +49,110 @@ struct ref_matmul_t : public primitive_t { const auto bia_type = weights_md(1)->data_type; const auto dst_type = dst_md(0)->data_type; - bool ok = is_dense_format_kind() - && utils::one_of(src_type, f32, bf16, f16, f8_e5m2, f8_e4m3) - && utils::one_of(wei_type, f32, bf16, f16, f8_e5m2, f8_e4m3, - u8, s8, u4, s4) - && utils::one_of(dst_type, f32, bf16, f16, f8_e5m2, f8_e4m3) - && (src_type == wei_type - || utils::one_of(wei_type, u8, s8, u4, s4)) - /* int8 weights decompression support */ - && IMPLICATION(utils::one_of(wei_type, u8, s8), - attr_.mayiconvert(wei_type, src_type)) - && IMPLICATION(src_type == f32, dst_type == f32) - && IMPLICATION(src_type == bf16, - utils::one_of(dst_type, f32, bf16)) - && IMPLICATION( - src_type == f16, utils::one_of(dst_type, f32, f16)) - // TODO: any implication on allowed dst data type for fp8? 
- && IMPLICATION(with_bias(), + VDISPATCH_MATMUL( + is_dense_format_kind(), VERBOSE_UNSUPPORTED_SPARSE_CFG); + VDISPATCH_MATMUL(utils::one_of(src_type, f32, bf16, f16, f8_e5m2, + f8_e4m3, f4_e2m1, f4_e3m0), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_MATMUL(utils::one_of(wei_type, f32, bf16, f16, f8_e5m2, + f8_e4m3, f4_e2m1, f4_e3m0, u8, s8, u4, s4), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_MATMUL(utils::one_of(dst_type, f32, bf16, f16, f8_e5m2, + f8_e4m3, f4_e2m1, f4_e3m0), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_MATMUL((src_type == wei_type + || utils::one_of(wei_type, bf16, f16, u8, + s8, u4, s4, f4_e3m0)), + VERBOSE_UNSUPPORTED_DT); + /* int8 weights decompression support */ + VDISPATCH_MATMUL(IMPLICATION(utils::one_of(wei_type, u8, s8), + attr_.mayiconvert(wei_type, src_type)), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_MATMUL(IMPLICATION(src_type == f32, dst_type == f32), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_MATMUL(IMPLICATION(src_type == bf16, + utils::one_of(dst_type, f32, bf16)), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_MATMUL(IMPLICATION(src_type == f16, + utils::one_of(dst_type, f32, f16)), + VERBOSE_UNSUPPORTED_DT); + // TODO: any implication on allowed dst data type for fp8? + VDISPATCH_MATMUL( + IMPLICATION(with_bias(), utils::one_of( bia_type, f32, bf16, f16, f8_e5m2, f8_e4m3) && IMPLICATION( - src_type == f32, bia_type == f32) - && IMPLICATION(src_type == f16, + wei_type == f32, bia_type == f32) + && IMPLICATION(wei_type == f16, utils::one_of(bia_type, f32, f16)) - && IMPLICATION(src_type == bf16, + && IMPLICATION(wei_type == bf16, utils::one_of(bia_type, f32, bf16)) // TODO: any implication on allowed bias // data type for fp8? - ) - && platform::has_data_type_support(src_type) - && attr()->has_default_values( - smask_t::scales_runtime_data_type - | smask_t::scales_runtime_groups - | smask_t::zero_points_runtime_data_type - | smask_t::zero_points_runtime_groups + ), + VERBOSE_UNSUPPORTED_BIAS_CFG); + VDISPATCH_MATMUL(platform::has_data_type_support(src_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_MATMUL( + attr()->has_default_values(smask_t::scales_data_type + | smask_t::scales_groups + | smask_t::zero_points_data_type + | smask_t::zero_points_groups | smask_t::post_ops | smask_t::sum_dt | smask_t::fpmath_mode | smask_t::dropout | smask_t::rounding_mode, - dst_type) - && attr_.post_ops_.check_sum_consistency(dst_type, - /* is_int8 */ false) - && ref_post_ops_t::primitive_kind_ok(attr()->post_ops_) - && attr_scales_ok() && set_default_formats() - && zero_points_ok() - && attr_.set_default_formats(dst_md(0)) == status::success - && IMPLICATION(!attr_.dropout_.has_default_values(), + dst_type), + VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_MATMUL(attr_.post_ops_.check_sum_consistency(dst_type, + /* is_int8 */ false), + VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_MATMUL( + ref_post_ops_t::primitive_kind_ok(attr()->post_ops_), + VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_MATMUL(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); + VDISPATCH_MATMUL(set_default_formats(), VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_MATMUL(zero_points_ok(), VERBOSE_UNSUPPORTED_ZP_CFG); + VDISPATCH_MATMUL( + attr_.set_default_formats(dst_md(0)) == status::success, + VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_MATMUL( + IMPLICATION(!attr_.dropout_.has_default_values(), utils::one_of( attr_.dropout_.dropout_desc_.data_type, u8, - s8)) - && IMPLICATION(!attr_.dropout_.has_default_values(), + s8)), + VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_MATMUL( + IMPLICATION(!attr_.dropout_.has_default_values(), 
memory_desc_wrapper(dst_md(0)).similar_to( - attr_.dropout_.dropout_desc_, true, false)); - return ok ? status::success : status::unimplemented; + attr_.dropout_.dropout_desc_, true, false)), + VERBOSE_UNSUPPORTED_ATTR); + + return status::success; } private: bool zero_points_ok() const { + const auto &zp = attr()->zero_points_; + if (!zp.has_default_values(DNNL_ARG_SRC)) { return false; } /* weights decompression requires zero points support */ - int mask_wei = 0; - attr()->zero_points_.get(DNNL_ARG_WEIGHTS, &mask_wei); - const auto wei_group_ndims - = attr()->zero_points_.get_groups_ndims(DNNL_ARG_WEIGHTS); - const auto wei_group_dims - = attr()->zero_points_.get_groups(DNNL_ARG_WEIGHTS); - - return attr()->zero_points_.has_default_values(DNNL_ARG_SRC) - && attr()->zero_points_.has_default_values(DNNL_ARG_DST) - && utils::one_of(mask_wei, 0, wei_qmask_N(), - wei_qmask_N() + wei_qmask_K()) - && utils::one_of(wei_group_ndims, 0, 2) - && IMPLICATION(wei_group_ndims == 2, - wei_group_dims[1] == 1 - && K() % wei_group_dims[0] == 0); + if (!zp.has_default_values(DNNL_ARG_WEIGHTS)) { + if (!zp.get(DNNL_ARG_WEIGHTS).has_default_groups()) { + const auto gK = zp.get_group(DNNL_ARG_WEIGHTS, 0); + bool ok = IMPLICATION(gK > 1, K() % gK == 0); + if (!ok) return false; + + const auto gN = zp.get_group(DNNL_ARG_WEIGHTS, 1); + ok = IMPLICATION(gN > 1, N() % gN == 0); + if (!ok) return false; + + // Only one non-unit group is supported. + ok = utils::one_of(1, gK, gN); + if (!ok) return false; + } + } + if (!zp.has_default_values(DNNL_ARG_DST)) { return false; } + + return true; } }; diff --git a/src/cpu/matmul/ref_matmul_int8.cpp b/src/cpu/matmul/ref_matmul_int8.cpp index e336f46b3c0..daea0cfa3d9 100644 --- a/src/cpu/matmul/ref_matmul_int8.cpp +++ b/src/cpu/matmul/ref_matmul_int8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,7 +47,7 @@ status_t ref_matmul_int8_t::execute_ref(const exec_ctx_t &ctx) const { DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); + DEFINE_ZERO_POINTS_BUFFER(src_zero_points, DNNL_ARG_SRC); DEFINE_ZERO_POINTS_BUFFER(wei_zero_points, DNNL_ARG_WEIGHTS); DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); @@ -78,24 +78,26 @@ status_t ref_matmul_int8_t::execute_ref(const exec_ctx_t &ctx) const { const auto &attr_zps = pd()->attr()->zero_points_; const bool with_src_zero_points = !attr_zps.has_default_values(DNNL_ARG_SRC); + int src_zp_mask = attr_zps.get_mask(DNNL_ARG_SRC); + const auto &src_zp_dt = attr_zps.get_data_type(DNNL_ARG_SRC); + const auto src_zp_group_k = attr_zps.get_group(DNNL_ARG_SRC, 1); + const auto src_zp_ngroups_k = K / src_zp_group_k; + // Initialize a memory desc for quant entries for easier offset calculation. 
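Both `zero_points_ok()` in ref_matmul.hpp above and `attr_zero_points_ok()` in ref_matmul_int8.hpp below reduce the grouped-weights checks to the same three conditions on the group pair (gK, gN): each non-unit group must divide its dimension, and at most one of the two groups may be non-unit. As a standalone predicate (illustrative, not library code):

#include <cstdint>

// Illustrative: validate grouped quantization along weights dims K and N.
inline bool wei_groups_ok(std::int64_t K, std::int64_t N, std::int64_t gK,
        std::int64_t gN) {
    if (gK > 1 && K % gK != 0) return false; // group must divide K
    if (gN > 1 && N % gN != 0) return false; // group must divide N
    return gK == 1 || gN == 1; // only one non-unit group is supported
}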
+ memory_desc_t src_zp_md {}; + CHECK(matmul_helper_t::get_quant_md(src_zp_md, ndims, src_d.dims(), + src_zp_mask, 1, src_zp_group_k, src_zp_dt)); + const bool with_wei_zero_points = !attr_zps.has_default_values(DNNL_ARG_WEIGHTS); - int src_zp_mask = 0; - int wei_zp_mask = 0; - attr_zps.get(DNNL_ARG_SRC, &src_zp_mask); - attr_zps.get(DNNL_ARG_WEIGHTS, &wei_zp_mask); - const bool src_zp_per_k = src_zp_mask & pd()->src_qmask_K(); - const bool wei_zp_per_n = wei_zp_mask & pd()->wei_qmask_N(); - const bool wei_zp_per_k = wei_zp_mask & pd()->wei_qmask_K(); + int wei_zp_mask = attr_zps.get_mask(DNNL_ARG_WEIGHTS); const auto &wei_zp_dt = attr_zps.get_data_type(DNNL_ARG_WEIGHTS); - const auto wei_zp_group_ndims = attr_zps.get_groups_ndims(DNNL_ARG_WEIGHTS); - const auto wei_zp_group_k = wei_zp_group_ndims > 0 - ? attr_zps.get_groups(DNNL_ARG_WEIGHTS)[0] - : (wei_zp_per_k ? 1 : K); - const dim_t src_zp_stride_k = src_zp_per_k ? 1 : 0; - const dim_t wei_zp_stride_n = wei_zp_per_n ? 1 : 0; - const dim_t wei_zp_stride_k = wei_zp_group_k < K ? wei_zp_per_n ? N : 1 : 0; + const auto wei_zp_group_k = attr_zps.get_group(DNNL_ARG_WEIGHTS, 0); + const auto wei_zp_group_n = attr_zps.get_group(DNNL_ARG_WEIGHTS, 1); const auto wei_zp_ngroups_k = K / wei_zp_group_k; + // Initialize a memory desc for quant entries for easier offset calculation. + memory_desc_t wei_zp_md {}; + CHECK(matmul_helper_t::get_quant_md(wei_zp_md, ndims, weights_d.dims(), + wei_zp_mask, wei_zp_group_k, wei_zp_group_n, wei_zp_dt)); const int src_mask = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims); @@ -105,46 +107,40 @@ status_t ref_matmul_int8_t::execute_ref(const exec_ctx_t &ctx) const { = utils::get_dims_mask(dst_d.dims(), bia_d.dims(), ndims); // zp_idx_mult = 1 for per_dim1 zero points and 0, otherwise - const int dst_zp_idx_mult = !attr_zps.common(DNNL_ARG_DST); + const int dst_zp_idx_mult = !attr_zps.has_default_values(DNNL_ARG_DST) + && attr_zps.get_mask(DNNL_ARG_DST) > 0; // arg scales section const auto &attr_scales = pd()->attr()->scales_; - const bool with_src_scales - = !attr_scales.get(DNNL_ARG_SRC).has_default_values(); const bool with_wei_scales - = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values(); - const bool with_dst_scales - = !attr_scales.get(DNNL_ARG_DST).has_default_values(); - const int src_scale_mask = attr_scales.get(DNNL_ARG_SRC).mask_; - const int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; - const auto &src_scales_dt = attr_scales.get_data_type(DNNL_ARG_SRC); - const auto &wei_scales_dt = attr_scales.get_data_type(DNNL_ARG_WEIGHTS); - const bool src_scale_per_k = src_scale_mask & pd()->src_qmask_K(); - const bool src_scale_per_m = src_scale_mask & pd()->src_qmask_M(); - const bool wei_scale_per_n = wei_scale_mask & pd()->wei_qmask_N(); - const bool wei_scale_per_k = wei_scale_mask & pd()->wei_qmask_K(); - const auto src_scale_group_ndim = attr_scales.get(DNNL_ARG_SRC).ndims_; - const auto wei_scale_group_ndim = attr_scales.get(DNNL_ARG_WEIGHTS).ndims_; - const auto src_scale_group_k = src_scale_group_ndim > 0 - ? attr_scales.get(DNNL_ARG_SRC).group_dims_[1] - : (src_scale_per_k ? 1 : K); - const auto wei_scale_group_k = wei_scale_group_ndim > 0 - ? attr_scales.get(DNNL_ARG_WEIGHTS).group_dims_[0] - : (wei_scale_per_k ? 
1 : K); - const auto src_scale_ngroups_k = K / src_scale_group_k; + = !attr_scales.has_default_values(DNNL_ARG_WEIGHTS); + const bool with_dst_scales = !attr_scales.has_default_values(DNNL_ARG_DST); + const int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); + const auto &wei_scale_dt = attr_scales.get_data_type(DNNL_ARG_WEIGHTS); + const auto wei_scale_group_k = attr_scales.get_group(DNNL_ARG_WEIGHTS, 0); + const auto wei_scale_group_n = attr_scales.get_group(DNNL_ARG_WEIGHTS, 1); const auto wei_scale_ngroups_k = K / wei_scale_group_k; - const dim_t wei_scale_stride_n = wei_scale_per_n ? 1 : 0; - const dim_t src_scale_stride_k = src_scale_group_k < K ? 1 : 0; - const dim_t wei_scale_stride_k - = wei_scale_group_k < K ? wei_scale_per_n ? N : 1 : 0; - const dim_t src_scale_stride_m = src_scale_per_m - ? src_scale_group_k < K ? src_scale_ngroups_k : 1 - : 0; - const auto scale_ngroups_k - = std::max(src_scale_ngroups_k, wei_scale_ngroups_k); + // Initialize a memory desc for quant entries for easier offset calculation. + memory_desc_t wei_scale_md {}; + CHECK(matmul_helper_t::get_quant_md(wei_scale_md, ndims, weights_d.dims(), + wei_scale_mask, wei_scale_group_k, wei_scale_group_n, + wei_scale_dt)); + + const bool with_src_scales = !attr_scales.has_default_values(DNNL_ARG_SRC); + const int src_scale_mask = attr_scales.get_mask(DNNL_ARG_SRC); + const auto &src_scale_dt = attr_scales.get_data_type(DNNL_ARG_SRC); + const auto src_scale_group_k = attr_scales.get_group(DNNL_ARG_SRC, 1); + const auto src_scale_ngroups_k = K / src_scale_group_k; + // Initialize a memory desc for quant entries for easier offset calculation. + memory_desc_t src_scale_md {}; + CHECK(matmul_helper_t::get_quant_md(src_scale_md, ndims, src_d.dims(), + src_scale_mask, 1, src_scale_group_k, src_scale_dt)); // For compute kernel, the minimal group is picked. 
- const auto ngroups_k = std::max(wei_zp_ngroups_k, scale_ngroups_k); + const auto zp_ngroups_k = std::max(src_zp_ngroups_k, wei_zp_ngroups_k); + const auto scale_ngroups_k + = std::max(src_scale_ngroups_k, wei_scale_ngroups_k); + const auto ngroups_k = std::max(zp_ngroups_k, scale_ngroups_k); const auto group_k = K / ngroups_k; // mm kernel @@ -161,6 +157,7 @@ status_t ref_matmul_int8_t::execute_ref(const exec_ctx_t &ctx) const { utils::copy_dims_with_mask(src_dims_idx, dst_dims_idx, ndims, src_mask); utils::copy_dims_with_mask( weights_dims_idx, dst_dims_idx, ndims, wei_mask); + src_dims_idx[ndims - 2] = m; weights_dims_idx[ndims - 1] = n; auto &src_k_dim = src_dims_idx[ndims - 1]; @@ -176,12 +173,17 @@ status_t ref_matmul_int8_t::execute_ref(const exec_ctx_t &ctx) const { int w = io::load_int_value( weights_d.data_type(), weights, weights_off); if (with_src_zero_points) { - s -= io::load_int_value(data_type::s32, src_zero_point, - src_zp_stride_k * k); + const dim_t src_zp_offset = matmul_helper_t::get_quant_off( + src_dims_idx, ndims, src_zp_mask, 1, src_zp_group_k, + src_zp_md); + const auto src_zp = io::load_int_value( + src_zp_dt, src_zero_points, src_zp_offset); + s -= src_zp; } if (with_wei_zero_points) { - const auto wei_zp_offset = wei_zp_stride_n * n - + wei_zp_stride_k * (wei_k_dim / wei_zp_group_k); + const dim_t wei_zp_offset = matmul_helper_t::get_quant_off( + weights_dims_idx, ndims, wei_zp_mask, + wei_zp_group_k, wei_zp_group_n, wei_zp_md); const auto wei_zp = io::load_int_value( wei_zp_dt, wei_zero_points, wei_zp_offset); w -= wei_zp; @@ -192,24 +194,25 @@ status_t ref_matmul_int8_t::execute_ref(const exec_ctx_t &ctx) const { // Apply scaling after computing a group. float acc_f = static_cast(acc); if (with_src_scales) { + const dim_t src_scale_offset = matmul_helper_t::get_quant_off( + src_dims_idx, ndims, src_scale_mask, 1, + src_scale_group_k, src_scale_md); // Single scale value was already converted into f32. - const auto src_scale_offset - = src_scale_stride_k * (src_k_dim / src_scale_group_k) - + src_scale_stride_m * m; const float src_scale = src_scales_d.nelems() == 1 ? src_scales[0] : io::load_float_value( - src_scales_dt, src_scales, src_scale_offset); + src_scale_dt, src_scales, src_scale_offset); acc_f *= src_scale; } if (with_wei_scales) { + const dim_t wei_scale_offset = matmul_helper_t::get_quant_off( + weights_dims_idx, ndims, wei_scale_mask, + wei_scale_group_k, wei_scale_group_n, wei_scale_md); // Single scale value was already converted into f32. - const auto wei_scale_offset = wei_scale_stride_n * n - + wei_scale_stride_k * (wei_k_dim / wei_scale_group_k); const float wei_scale = wei_scales_d.nelems() == 1 ? wei_scales[0] : io::load_float_value( - wei_scales_dt, wei_scales, wei_scale_offset); + wei_scale_dt, wei_scales, wei_scale_offset); acc_f *= wei_scale; } d += acc_f; diff --git a/src/cpu/matmul/ref_matmul_int8.hpp b/src/cpu/matmul/ref_matmul_int8.hpp index e0ac608d8c6..d9cee62d993 100644 --- a/src/cpu/matmul/ref_matmul_int8.hpp +++ b/src/cpu/matmul/ref_matmul_int8.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
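The `ngroups_k` computation above chooses the finest K granularity across all source/weights scale and zero-point groups, so a single blocking of the K loop is valid for every quantization parameter: the kernel accumulates integers over group_k consecutive points, then applies the per-group scales to the partial sum. The pattern, with zero points omitted for brevity, in a compact standalone sketch:

#include <cstdint>

// Illustrative: int8 dot product over K with per-group scaling; scales must
// be constant within each group of group_k consecutive k values.
inline float dot_grouped(const std::int8_t *a, const std::int8_t *b,
        std::int64_t K, std::int64_t group_k, const float *group_scales) {
    float d = 0.0f;
    for (std::int64_t g = 0; g < K / group_k; ++g) {
        std::int32_t acc = 0; // integer accumulation inside one group
        for (std::int64_t k = g * group_k; k < (g + 1) * group_k; ++k)
            acc += std::int32_t(a[k]) * std::int32_t(b[k]);
        d += group_scales[g] * float(acc); // scale applied per group
    }
    return d;
}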
@@ -48,50 +48,84 @@ struct ref_matmul_int8_t : public primitive_t { const auto bia_type = weights_md(1)->data_type; const auto dst_type = dst_md(0)->data_type; - bool ok = is_dense_format_kind() && utils::one_of(src_type, s8, u8) - && utils::one_of(wei_type, s8, u8, s4, u4) - && IMPLICATION(with_bias(), - utils::one_of( - bia_type, f32, bf16, f16, s32, s8, u8)) - && utils::one_of(dst_type, f32, bf16, f16, s32, s8, u8) - && attr()->has_default_values( - smask_t::scales_runtime_data_type - | smask_t::scales_runtime_groups - | smask_t::zero_points_runtime_data_type - | smask_t::zero_points_runtime_groups + VDISPATCH_MATMUL( + is_dense_format_kind(), VERBOSE_UNSUPPORTED_SPARSE_CFG); + VDISPATCH_MATMUL( + utils::one_of(src_type, s8, u8), VERBOSE_UNSUPPORTED_DT); + VDISPATCH_MATMUL(utils::one_of(wei_type, s8, u8, s4, u4), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_MATMUL(IMPLICATION(with_bias(), + utils::one_of(bia_type, f32, bf16, f16, + s32, s8, u8)), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_MATMUL( + utils::one_of(dst_type, f32, bf16, f16, s32, s8, u8), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_MATMUL( + attr()->has_default_values(smask_t::scales_data_type + | smask_t::scales_groups + | smask_t::zero_points_data_type + | smask_t::zero_points_groups | smask_t::post_ops | smask_t::sum_dt, - dst_type) - && attr_.post_ops_.check_sum_consistency(dst_type, - /* is_int8 */ true) - && ref_post_ops_t::primitive_kind_ok(attr()->post_ops_) - && attr_scales_ok() && attr_zero_points_ok() - && set_default_formats() - && attr_.set_default_formats(dst_md(0)) == status::success; - return ok ? status::success : status::unimplemented; + dst_type), + VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_MATMUL(attr_.post_ops_.check_sum_consistency(dst_type, + /* is_int8 */ true), + VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_MATMUL( + ref_post_ops_t::primitive_kind_ok(attr()->post_ops_), + VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_MATMUL(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); + VDISPATCH_MATMUL(attr_zero_points_ok(), VERBOSE_UNSUPPORTED_ZP_CFG); + VDISPATCH_MATMUL(set_default_formats(), VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_MATMUL( + attr_.set_default_formats(dst_md(0)) == status::success, + VERBOSE_UNSUPPORTED_POSTOP); + + return status::success; } private: bool attr_zero_points_ok() const { - int mask_src = 0, mask_wei = 0, mask_dst = 0; - CHECK_BOOL(attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src)); - CHECK_BOOL(attr()->zero_points_.get(DNNL_ARG_WEIGHTS, &mask_wei)); - CHECK_BOOL(attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst)); - - const auto wei_group_ndims - = attr()->zero_points_.get_groups_ndims(DNNL_ARG_WEIGHTS); - const auto wei_group_dims - = attr()->zero_points_.get_groups(DNNL_ARG_WEIGHTS); - - bool mask_src_ok = utils::one_of(mask_src, 0, wei_qmask_N()); - bool mask_wei_ok = utils::one_of( - mask_wei, 0, wei_qmask_N(), wei_qmask_K() + wei_qmask_N()); - bool mask_dst_ok = utils::one_of(mask_dst, 0, wei_qmask_N()); - - return mask_src_ok && mask_wei_ok && mask_dst_ok - && utils::one_of(wei_group_ndims, 0, 2) - && IMPLICATION(wei_group_ndims == 2, - wei_group_dims[1] == 1 - && K() % wei_group_dims[0] == 0); + const auto &zp = attr()->zero_points_; + if (!zp.has_default_values(DNNL_ARG_SRC)) { + int mask_src = zp.get_mask(DNNL_ARG_SRC); + bool ok = utils::one_of(mask_src, 0, src_qmask_K(), + src_qmask_M() + src_qmask_K()); + if (!ok) return false; + + if (!zp.get(DNNL_ARG_SRC).has_default_groups()) { + const auto gM = zp.get_group(DNNL_ARG_SRC, 0); + ok = gM == 1; + if (!ok) return false; + + const auto gK = 
zp.get_group(DNNL_ARG_SRC, 1); + ok = IMPLICATION(gK > 1, K() % gK == 0); + if (!ok) return false; + } + } + /* weights decompression requires zero points support */ + if (!zp.has_default_values(DNNL_ARG_WEIGHTS)) { + if (!zp.get(DNNL_ARG_WEIGHTS).has_default_groups()) { + const auto gK = zp.get_group(DNNL_ARG_WEIGHTS, 0); + bool ok = IMPLICATION(gK > 1, K() % gK == 0); + if (!ok) return false; + + const auto gN = zp.get_group(DNNL_ARG_WEIGHTS, 1); + ok = IMPLICATION(gN > 1, N() % gN == 0); + if (!ok) return false; + + // Only one non-unit group is supported. + ok = utils::one_of(1, gK, gN); + if (!ok) return false; + } + } + if (!zp.has_default_values(DNNL_ARG_DST)) { + int mask_dst = zp.get_mask(DNNL_ARG_DST); + bool ok = utils::one_of(mask_dst, 0, wei_qmask_N()); + if (!ok) return false; + } + return true; } }; diff --git a/src/cpu/matmul/ref_sparse_matmul.cpp b/src/cpu/matmul/ref_sparse_matmul.cpp index f95a35c505a..d08b1b09d73 100644 --- a/src/cpu/matmul/ref_sparse_matmul.cpp +++ b/src/cpu/matmul/ref_sparse_matmul.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023 Intel Corporation +* Copyright 2023-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ #include "common/math_utils.hpp" #include "common/type_helpers.hpp" +#include "cpu/ref_io_helper.hpp" + #include "cpu/matmul/ref_sparse_matmul.hpp" namespace dnnl { @@ -27,7 +29,7 @@ namespace matmul { status_t ref_sparse_matmul_t::execute(const exec_ctx_t &ctx) const { status_t status = status::success; - auto dst = CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DST, status); + auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status); CHECK(status); const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md()); @@ -38,48 +40,161 @@ status_t ref_sparse_matmul_t::execute(const exec_ctx_t &ctx) const { const dim_t N = dst_d.dims()[1]; const dim_t K = src_d.dims()[1]; - parallel_nd(M, N, [&](dim_t i, dim_t j) { dst[i * N + j] = 0.0f; }); + const data_type_t mm_dt = src_d.data_type(); + auto scratchpad = ctx.get_scratchpad_grantor(); + + parallel_nd(M, N, [&](dim_t i, dim_t j) { + const dim_t dst_idx = i * N + j; + io::store_float_value(dst_d.data_type(), 0.0f, dst, dst_idx); + }); if (weights_d.is_sparse_desc()) { - const auto src = CTX_IN_MEM(const float *, DNNL_ARG_SRC); - const auto wei_values = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS, 0); - const auto wei_indices - = CTX_IN_MEM(const int32_t *, DNNL_ARG_WEIGHTS, 1); - const auto wei_pointers - = CTX_IN_MEM(const int32_t *, DNNL_ARG_WEIGHTS, 2); + const auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); + const auto wei_values = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS, 0); + auto wei_buffer_1 = CTX_IN_MEM(const int32_t *, DNNL_ARG_WEIGHTS, 1); + auto wei_buffer_2 = CTX_IN_MEM(const int32_t *, DNNL_ARG_WEIGHTS, 2); + + // Both COO- and CSR-encoded data are operated on using the CSR kernel for + // matrix multiplication. + // For COO encoding, data preparation includes using a temporary + // buffer to convert the data to the CSR format. + // Matrix multiplication is then carried out using the CSR-encoded data. + const int32_t *wei_indices = nullptr; + const int32_t *wei_pointers = nullptr; + + if (weights_d.encoding() == sparse_encoding::csr) { + // For CSR encodings, pointer and indices assignment is + // straightforward as, + // index 1 - index buffer, index 2 - pointer buffer.
+            wei_indices = wei_buffer_1;
+            wei_pointers = wei_buffer_2;
+        } else if (weights_d.encoding() == sparse_encoding::coo) {
+            // For COO encodings, the two index buffers hold the row and column
+            // indices respectively. For CSR conversion, the row indices are
+            // compressed to generate the CSR pointers.
+            wei_indices = wei_buffer_2;
+
+            int32_t *wei_row_pointers = scratchpad.template get<int32_t>(
+                    memory_tracking::names::key_matmul_sparse_tmp_ptr);
+
+            parallel_nd(K + 1, [&](dim_t k) {
+                io::store_float_value(
+                        weights_d.metadata_type(0), 0, wei_row_pointers, k);
+            });
+
+            cvt_coo_indices_to_csr_pointers(
+                    wei_buffer_1, wei_row_pointers, weights_d.nnz(), K);
+
+            wei_pointers = wei_row_pointers;
+        }
+
+        run_csr_kernel(src, wei_values, wei_indices, wei_pointers, dst, M, N, K,
+                mm_dt, src_d.is_sparse_desc());
+
+    } else if (src_d.is_sparse_desc()) {
+        const auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS);
+        const auto src_values = CTX_IN_MEM(const void *, DNNL_ARG_SRC, 0);
+        auto src_buffer_1 = CTX_IN_MEM(const int32_t *, DNNL_ARG_SRC, 1);
+        auto src_buffer_2 = CTX_IN_MEM(const int32_t *, DNNL_ARG_SRC, 2);
+
+        // Both COO and CSR encoded data are operated on by the CSR kernel
+        // for matrix multiplication.
+        // For COO encoding, data preparation includes using a temporary
+        // buffer to convert the data to the CSR format.
+        // Matrix multiplication is then carried out using the CSR encoded data.
+        const int32_t *src_indices = nullptr;
+        const int32_t *src_pointers = nullptr;
+
+        if (src_d.encoding() == sparse_encoding::csr) {
+            // For CSR encodings, pointer and indices assignment is
+            // straightforward:
+            // index 1 - index buffer, index 2 - pointer buffer.
+            src_indices = src_buffer_1;
+            src_pointers = src_buffer_2;
+        } else if (src_d.encoding() == sparse_encoding::coo) {
+            // For COO encodings, the two index buffers hold the row and column
+            // indices respectively. For CSR conversion, the row indices are
+            // compressed to generate the CSR pointers.
+            src_indices = src_buffer_2;
+
+            int32_t *src_row_pointers = scratchpad.template get<int32_t>(
+                    memory_tracking::names::key_matmul_sparse_tmp_ptr);
+
+            parallel_nd(M + 1, [&](dim_t m) {
+                io::store_float_value(
+                        src_d.metadata_type(0), 0, src_row_pointers, m);
+            });
+
+            cvt_coo_indices_to_csr_pointers(
+                    src_buffer_1, src_row_pointers, src_d.nnz(), M);
+            src_pointers = src_row_pointers;
+        }
+
+        run_csr_kernel(weights, src_values, src_indices, src_pointers, dst, M,
+                N, K, mm_dt, src_d.is_sparse_desc());
+    }
+    return status::success;
+}
+
+void ref_sparse_matmul_t::cvt_coo_indices_to_csr_pointers(
+        const int32_t *indices, int32_t *pointers, const int nnz,
+        const int nrows) const {
+    parallel_nd(
+            nnz, [&](dim_t i) { fetch_and_add(&pointers[indices[i] + 1], 1); });
+    for (int i = 0; i < nrows; ++i) {
+        pointers[i + 1] += pointers[i];
+    }
+}
+
+void ref_sparse_matmul_t::run_csr_kernel(const void *dmat, const void *values,
+        const int32_t *indices, const int32_t *pointers, void *res,
+        const dim_t M, const dim_t N, const dim_t K, const data_type_t mm_dt,
+        bool is_src_sparse) const {
+
+    if (is_src_sparse) {
+        // With a sparse source tensor, the matrix multiplication is carried out
+        // for a sparse multiplier with parallelization over the sparse rows
+        // of the multiplier matrix.
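(Editorial aside: cvt_coo_indices_to_csr_pointers above is a histogram over row indices followed by a prefix sum. A serial sketch of the same idea — the patch performs the counting pass in parallel with fetch_and_add:)

    #include <cstdint>
    #include <vector>

    std::vector<int32_t> coo_rows_to_csr_pointers(
            const std::vector<int32_t> &coo_rows, int nrows) {
        std::vector<int32_t> ptr(nrows + 1, 0);
        for (int32_t r : coo_rows)
            ++ptr[r + 1]; // histogram: entries per row
        for (int i = 0; i < nrows; ++i)
            ptr[i + 1] += ptr[i]; // prefix sum: row start offsets
        // e.g. coo_rows = {0, 0, 1, 3}, nrows = 4 -> ptr = {0, 2, 3, 3, 4}
        return ptr;
    }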
parallel_nd(M, [&](dim_t m) { - for (dim_t k = 0; k < K; k++) { - const dim_t row_start = wei_pointers[k]; - const dim_t row_end = wei_pointers[k + 1]; - for (dim_t n = row_start; n < row_end; n++) { - const dim_t src_idx = m * K + k; - const dim_t dst_idx = m * N + wei_indices[n]; - dst[dst_idx] = dst[dst_idx] + src[src_idx] * wei_values[n]; + const dim_t row_start = pointers[m]; + const dim_t row_end = pointers[m + 1]; + + for (dim_t n = 0; n < N; n++) { + const dim_t c_idx = m * N + n; + float c_val = io::load_float_value(mm_dt, res, c_idx); + + for (dim_t k = row_start; k < row_end; k++) { + const dim_t b_idx = indices[k] * N + n; + const float a_val = io::load_float_value(mm_dt, values, k); + const float b_val + = io::load_float_value(mm_dt, dmat, b_idx); + c_val += a_val * b_val; } + io::store_float_value(mm_dt, c_val, res, c_idx); } }); - } else if (src_d.is_sparse_desc()) { - const auto weights = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS); - const auto src_values = CTX_IN_MEM(const float *, DNNL_ARG_SRC, 0); - const auto src_indices = CTX_IN_MEM(const int32_t *, DNNL_ARG_SRC, 1); - const auto src_pointers = CTX_IN_MEM(const int32_t *, DNNL_ARG_SRC, 2); - + } else { + // With a sparse weights tensor, the matrix multiplication is carried + // out for a sparse multiplicand with parallelization over the dense + // rows of the multiplier matrix. parallel_nd(M, [&](dim_t m) { - const dim_t row_start = src_pointers[m]; - const dim_t row_end = src_pointers[m + 1]; - for (dim_t k = row_start; k < row_end; k++) { - for (dim_t n = 0; n < N; n++) { - const dim_t dst_idx = m * N + n; - const dim_t wei_idx = src_indices[k] * N + n; - dst[dst_idx] - = dst[dst_idx] + src_values[k] * weights[wei_idx]; + for (dim_t k = 0; k < K; k++) { + const dim_t row_start = pointers[k]; + const dim_t row_end = pointers[k + 1]; + for (dim_t n = row_start; n < row_end; n++) { + const dim_t a_idx = m * K + k; + const dim_t c_idx = m * N + indices[n]; + const float a_val + = io::load_float_value(mm_dt, dmat, a_idx); + const float b_val = io::load_float_value(mm_dt, values, n); + float c_val = io::load_float_value(mm_dt, res, c_idx); + c_val += a_val * b_val; + io::store_float_value(mm_dt, c_val, res, c_idx); } } }); } - - return status::success; } } // namespace matmul diff --git a/src/cpu/matmul/ref_sparse_matmul.hpp b/src/cpu/matmul/ref_sparse_matmul.hpp index 2b7dbae8c08..16d63318deb 100644 --- a/src/cpu/matmul/ref_sparse_matmul.hpp +++ b/src/cpu/matmul/ref_sparse_matmul.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023 Intel Corporation +* Copyright 2023-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -44,25 +44,62 @@ struct ref_sparse_matmul_t : public primitive_t {
             memory_desc_wrapper src_d(src_md());
             memory_desc_wrapper wei_d(weights_md(0));
 
-            const bool ok
-                    = utils::everyone_is(f32, src_type, wei_type, dst_type)
-                    && utils::one_of(true, wei_d.is_sparse_desc(),
-                            src_d.is_sparse_desc())
-                    && IMPLICATION(wei_d.is_sparse_desc(),
-                            wei_d.encoding() == sparse_encoding::csr)
-                    && IMPLICATION(src_d.is_sparse_desc(),
-                            src_d.encoding() == sparse_encoding::csr)
-                    && IMPLICATION(
-                            wei_d.is_sparse_desc(), !src_d.is_sparse_desc())
-                    && IMPLICATION(src_d.is_sparse_desc(),
-                            utils::everyone_is(s32, src_d.metadata_type(0),
-                                    src_d.metadata_type(1)))
-                    && IMPLICATION(wei_d.is_sparse_desc(),
-                            utils::everyone_is(s32, wei_d.metadata_type(0),
-                                    wei_d.metadata_type(1)))
-                    && !with_bias() && attr()->has_default_values()
-                    && set_default_formats() && formats_ok(src_d, wei_d);
-            return ok ? status::success : status::unimplemented;
+            VDISPATCH_MATMUL(wei_d.is_sparse_desc() || src_d.is_sparse_desc(),
+                    VERBOSE_UNSUPPORTED_SPARSE_CFG);
+            VDISPATCH_MATMUL(wei_d.is_sparse_desc() ^ src_d.is_sparse_desc(),
+                    VERBOSE_UNSUPPORTED_SPARSE_CFG);
+
+            VDISPATCH_MATMUL(IMPLICATION(src_d.is_sparse_desc(),
+                                     utils::one_of(src_d.encoding(),
+                                             sparse_encoding::csr,
+                                             sparse_encoding::coo)),
+                    VERBOSE_UNSUPPORTED_SPARSE_CFG);
+            VDISPATCH_MATMUL(IMPLICATION(wei_d.is_sparse_desc(),
+                                     utils::one_of(wei_d.encoding(),
+                                             sparse_encoding::csr,
+                                             sparse_encoding::coo)),
+                    VERBOSE_UNSUPPORTED_SPARSE_CFG);
+
+            VDISPATCH_MATMUL(
+                    utils::everyone_is(f16, src_type, wei_type, dst_type)
+                            || utils::everyone_is(
+                                    f32, src_type, wei_type, dst_type),
+                    VERBOSE_UNSUPPORTED_DT_CFG);
+
+            if (src_d.is_sparse_desc()) {
+                sparse_mem_encoding = src_d.encoding();
+                VDISPATCH_MATMUL(
+                        IMPLICATION(sparse_mem_encoding == sparse_encoding::coo,
+                                s32 == src_d.metadata_type(0)),
+                        VERBOSE_UNSUPPORTED_SPARSE_CFG);
+                VDISPATCH_MATMUL(
+                        IMPLICATION(sparse_mem_encoding == sparse_encoding::csr,
+                                utils::everyone_is(s32, src_d.metadata_type(0),
+                                        src_d.metadata_type(1))),
+                        VERBOSE_UNSUPPORTED_SPARSE_CFG);
+            }
+            if (wei_d.is_sparse_desc()) {
+                sparse_mem_encoding = wei_d.encoding();
+                VDISPATCH_MATMUL(
+                        IMPLICATION(sparse_mem_encoding == sparse_encoding::coo,
+                                s32 == wei_d.metadata_type(0)),
+                        VERBOSE_UNSUPPORTED_SPARSE_CFG);
+
+                VDISPATCH_MATMUL(
+                        IMPLICATION(sparse_mem_encoding == sparse_encoding::csr,
+                                utils::everyone_is(s32, wei_d.metadata_type(0),
+                                        wei_d.metadata_type(1))),
+                        VERBOSE_UNSUPPORTED_SPARSE_CFG);
+            }
+
+            VDISPATCH_MATMUL(!with_bias(), VERBOSE_UNSUPPORTED_BIAS_CFG);
+            VDISPATCH_MATMUL(
+                    attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);
+            VDISPATCH_MATMUL(set_default_formats(), VERBOSE_UNSUPPORTED_ATTR);
+            VDISPATCH_MATMUL(formats_ok(src_d, wei_d), VERBOSE_UNSUPPORTED_TAG);
+
+            init_scratchpad();
+            return status::success;
         }
 
         bool formats_ok(const memory_desc_wrapper &src_d,
@@ -76,10 +113,41 @@ struct ref_sparse_matmul_t : public primitive_t {
                 return src_d.matches_one_of_tag(format_tag::ab);
             return false;
         }
+
+    private:
+        void init_scratchpad() {
+            using namespace memory_tracking::names;
+            const memory_desc_wrapper src_d(src_md());
+            const memory_desc_wrapper wei_d(weights_md());
+
+            if (sparse_mem_encoding == sparse_encoding::coo) {
+                auto scratchpad = scratchpad_registry().registrar();
+                const bool is_wei_sparse = wei_d.is_sparse_desc();
+                const auto ptr_size
+                        = src_d.dims()[static_cast<int>(is_wei_sparse)] + 1;
+                scratchpad.template book<int32_t>(
+                        key_matmul_sparse_tmp_ptr, ptr_size);
+            }
+        }
+
+        sparse_encoding_t sparse_mem_encoding = sparse_encoding::undef;
     };
 
     ref_sparse_matmul_t(const pd_t *apd) : primitive_t(apd) {}
 
+    // COO sparse encodings are converted to the CSR format by
+    // compressing the respective row indices into CSR pointers.
+    void cvt_coo_indices_to_csr_pointers(const int32_t *indices,
+            int32_t *pointers, const int nnz, const int nrows) const;
+
+    // Executes the matrix multiplication C = A x B, where one of the input
+    // matrices is dense. Operand indices are determined depending on
+    // whether the multiplier or the multiplicand is dense.
+    void run_csr_kernel(const void *dmat, const void *values,
+            const int32_t *indices, const int32_t *pointers, void *res,
+            const dim_t M, const dim_t N, const dim_t K,
+            const data_type_t mm_dt, bool is_src_sparse) const;
+
     status_t execute(const exec_ctx_t &ctx) const override;
 
 private:
diff --git a/src/cpu/nchw_pooling.cpp b/src/cpu/nchw_pooling.cpp
index 6b709e4cd3f..6454e68d22b 100644
--- a/src/cpu/nchw_pooling.cpp
+++ b/src/cpu/nchw_pooling.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2017-2024 Intel Corporation
+* Copyright 2017-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -36,13 +36,18 @@ template <>
 status_t nchw_pooling_fwd_t<data_type::f32>::execute_forward(
         const exec_ctx_t &ctx) const {
     const auto alg = pd()->desc()->alg_kind;
-    const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
+    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
     auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
     auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);
 
     const memory_desc_wrapper ws_d(pd()->workspace_md());
+    const memory_desc_wrapper src_d(pd()->src_md());
+    const memory_desc_wrapper dst_d(pd()->dst_md());
     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
 
+    src += src_d.off_l(0);
+    dst += dst_d.off_l(0);
+
     const dim_t MB = pd()->MB();
     const dim_t C = pd()->OC();
     const dim_t OD = pd()->OD();
@@ -61,7 +66,7 @@ status_t nchw_pooling_fwd_t<data_type::f32>::execute_forward(
     const dim_t padT = pd()->padT();
     const dim_t padL = pd()->padL();
 
-    const auto apply_offset = [](int index, int offset) {
+    const auto apply_offset = [](dim_t index, dim_t offset) {
         return (index > offset) ?
index - offset : 0;
    };
 
@@ -74,7 +79,7 @@ status_t nchw_pooling_fwd_t<data_type::f32>::execute_forward(
                     + (size_t)OW * oh + (size_t)ow;
             if (ws_dt == data_type::u8) {
                 assert(0 <= value
-                        && value <= numeric_limits<typename prec_traits<
+                        && value <= numeric_limits<typename prec_traits_t<
                                 data_type::u8>::type>::max());
                 ws[ws_offset] = value;
             } else
@@ -87,6 +92,10 @@ status_t nchw_pooling_fwd_t<data_type::f32>::execute_forward(
         const auto src_off = IW * IH * ID * C * mb + IW * IH * ID * c;
         const auto *src_loc = &src[src_off];
 
+        data_t d_val = d[0];
+        dim_t kd_max = 0;
+        dim_t kh_max = 0;
+        dim_t kw_max = 0;
         for_(dim_t kd = 0; kd < KD; ++kd)
         for_(dim_t kh = 0; kh < KH; ++kh)
         for (dim_t kw = 0; kw < KW; ++kw) {
@@ -99,11 +108,18 @@ status_t nchw_pooling_fwd_t<data_type::f32>::execute_forward(
                 const auto src_off_loc = IW * IH * id + IW * ih + iw;
                 const auto &s = src_loc[src_off_loc];
 
-                if (s > d[0]) {
-                    d[0] = s;
-                    set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
+                if (s > d_val) {
+                    d_val = s;
+                    kd_max = kd;
+                    kh_max = kh;
+                    kw_max = kw;
                 }
             }
+
+        if (d_val > d[0]) {
+            d[0] = d_val;
+            set_ws(mb, c, od, oh, ow, kd_max * KH * KW + kh_max * KW + kw_max);
+        }
     };
 
     const auto ker_avg = [=](data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
@@ -254,7 +270,7 @@
     const size_t blocked_size = src_size / simd_w;
     const size_t tail_size = src_size % simd_w;
 
-    auto apply_offset = [=](int index, int offset) {
+    auto apply_offset = [=](dim_t index, dim_t offset) {
         return (index > offset) ? index - offset : 0;
     };
 
@@ -267,7 +283,7 @@
                     + (size_t)OW * oh + (size_t)ow;
             if (ws_dt == data_type::u8) {
                 assert(0 <= value
-                        && value <= numeric_limits<typename prec_traits<
+                        && value <= numeric_limits<typename prec_traits_t<
                                 data_type::u8>::type>::max());
                 ws[ws_offset] = value;
             } else
@@ -280,6 +296,10 @@
         const auto src_off = IW * IH * ID * C * mb + IW * IH * ID * c;
         const auto *src_loc = &cvt_wsp[src_off];
 
+        float d_val = d[0];
+        dim_t kd_max = 0;
+        dim_t kh_max = 0;
+        dim_t kw_max = 0;
         for_(dim_t kd = 0; kd < KD; ++kd)
         for_(dim_t kh = 0; kh < KH; ++kh)
         for (dim_t kw = 0; kw < KW; ++kw) {
@@ -292,11 +312,18 @@
                 const auto src_off_loc = IW * IH * id + IW * ih + iw;
                 const auto &s = src_loc[src_off_loc];
 
-                if (s > d[0]) {
-                    d[0] = s;
-                    set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
+                if (s > d_val) {
+                    d_val = s;
+                    kd_max = kd;
+                    kh_max = kh;
+                    kw_max = kw;
                 }
             }
+
+        if (d_val > d[0]) {
+            d[0] = d_val;
+            set_ws(mb, c, od, oh, ow, kd_max * KH * KW + kh_max * KW + kw_max);
+        }
     };
 
     auto ker_avg = [=](float *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
@@ -442,7 +469,7 @@ status_t nchw_pooling_bwd_t<data_type::f32>::execute_backward(
     const dim_t padT = pd()->padT();
     const dim_t padL = pd()->padL();
 
-    auto apply_offset = [=](int index, int offset) {
+    auto apply_offset = [=](dim_t index, dim_t offset) {
        return (index > offset) ?
index - offset : 0; }; @@ -486,7 +513,7 @@ status_t nchw_pooling_bwd_t::execute_backward( diff_src[diff_src_offset] += d[0]; }; - auto ker_avg = [=](const data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh, + auto ker_avg = [=](data_t d, dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { dim_t id_start = apply_offset(od * SD, padF); dim_t ih_start = apply_offset(oh * SH, padT); @@ -506,7 +533,7 @@ status_t nchw_pooling_bwd_t::execute_backward( size_t diff_src_offset = (size_t)mb * C * ID * IH * IW + (size_t)c * ID * IH * IW + (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw; - diff_src[diff_src_offset] += d[0] / num_summands; + diff_src[diff_src_offset] += d / num_summands; } }; @@ -544,7 +571,7 @@ status_t nchw_pooling_bwd_t::execute_backward( size_t diff_dst_offset = diff_dst_offset_b + (size_t)od * OH * OW + (size_t)oh * OW; for (dim_t ow = ow_start; ow < ow_end; ++ow) { - const data_t *d = &diff_dst[diff_dst_offset + ow]; + data_t d = diff_dst[diff_dst_offset + ow]; ker_avg(d, mb, c, od, oh, ow); } } @@ -595,7 +622,7 @@ status_t nchw_pooling_bwd_t::execute_backward( const size_t dst_sp_size = pd()->OD() * pd()->OH() * pd()->OW(); const size_t src_sp_size = pd()->ID() * pd()->IH() * pd()->IW(); - auto apply_offset = [=](int index, int offset) { + auto apply_offset = [=](dim_t index, dim_t offset) { return (index > offset) ? index - offset : 0; }; @@ -638,8 +665,8 @@ status_t nchw_pooling_bwd_t::execute_backward( diff_src[diff_src_offset] += d[0]; }; - auto ker_avg = [=](const float *d, float *diff_src, dim_t mb, dim_t c, - dim_t od, dim_t oh, dim_t ow) { + auto ker_avg = [=](float d, float *diff_src, dim_t mb, dim_t c, dim_t od, + dim_t oh, dim_t ow) { auto id_start = apply_offset(od * SD, padF); auto ih_start = apply_offset(oh * SH, padT); auto iw_start = apply_offset(ow * SW, padL); @@ -657,7 +684,7 @@ status_t nchw_pooling_bwd_t::execute_backward( for (dim_t iw = iw_start; iw < iw_end; ++iw) { size_t diff_src_offset = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw; - diff_src[diff_src_offset] += d[0] / num_summands; + diff_src[diff_src_offset] += d / num_summands; } }; @@ -677,6 +704,7 @@ status_t nchw_pooling_bwd_t::execute_backward( if (alg == alg_kind::pooling_max) { parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk), [&](int ithr, int, dim_t mb, dim_t cb) { + assert(ithr < pd()->nbuf_); bool is_last_c_block = c_blk_tail > 0 && (cb + 1) * c_blk > C; dim_t curr_c_block = is_last_c_block ? c_blk_tail : c_blk; @@ -713,6 +741,7 @@ status_t nchw_pooling_bwd_t::execute_backward( } else { parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk), [&](int ithr, int, dim_t mb, dim_t cb) { + assert(ithr < pd()->nbuf_); bool is_last_c_block = c_blk_tail > 0 && (cb + 1) * c_blk > C; dim_t curr_c_block = is_last_c_block ? 
c_blk_tail : c_blk;
@@ -734,8 +763,7 @@ status_t nchw_pooling_bwd_t<d_type>::execute_backward(
                         size_t diff_dst_offset = (size_t)c * OD * OH * OW
                                 + (size_t)od * OH * OW + (size_t)oh * OW;
                         for (dim_t ow = ow_start; ow < ow_end; ++ow) {
-                            const float *d
-                                    = &diff_dst_fp32[diff_dst_offset + ow];
+                            float d = diff_dst_fp32[diff_dst_offset + ow];
                             ker_avg(d, &diff_src_fp32[c * ID * IH * IW], mb,
                                     cb * c_blk + c, od, oh, ow);
                         }
diff --git a/src/cpu/nchw_pooling.hpp b/src/cpu/nchw_pooling.hpp
index ae3b2fc5367..7ea95f2a66e 100644
--- a/src/cpu/nchw_pooling.hpp
+++ b/src/cpu/nchw_pooling.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2017-2024 Intel Corporation
+* Copyright 2017-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -51,6 +51,9 @@ struct nchw_pooling_fwd_t : public primitive_t {
                             alg_kind::pooling_avg_include_padding,
                             alg_kind::pooling_avg_exclude_padding),
                     VERBOSE_BAD_ALGORITHM);
+            VDISPATCH_POOLING(
+                    memory_desc_wrapper(dst_md()).is_dense(false),
+                    VERBOSE_UNSUPPORTED_SPARSE_CFG);
             VDISPATCH_POOLING(utils::everyone_is(d_type, src_md()->data_type,
                                       dst_md()->data_type),
                     VERBOSE_UNSUPPORTED_DT);
@@ -101,7 +104,7 @@ struct nchw_pooling_fwd_t : public primitive_t {
 
     nchw_pooling_fwd_t(const pd_t *apd) : primitive_t(apd) {}
 
-    using data_t = typename prec_traits<d_type>::type;
+    using data_t = typename prec_traits_t<d_type>::type;
 
     status_t init(engine_t *engine) override {
         ref_post_ops_
@@ -174,8 +177,9 @@ struct nchw_pooling_bwd_t : public primitive_t {
             return status::success;
         }
 
-        dim_t channel_block_size_;
+        dim_t channel_block_size_ {1};
        int nthr_; // To not exceed the limit in execute used for set up.
+        int nbuf_ {0};
 
    private:
        void init_scratchpad() {
@@ -185,31 +189,39 @@
            size_t src_sz_ = ID() * IH() * IW();
            auto scratchpad = scratchpad_registry().registrar();
+            // The value of nbuf_ must be consistent with the arguments of
+            // parallel_nd_ext called from execute_backward for
+            // data_type != f32.
+            nbuf_ = nstl::min(static_cast<dim_t>(nthr_),
+                    MB() * utils::div_up(IC(), channel_block_size_));
+
            scratchpad.template book<float>(key_pool_src_bf16cvt,
-                    src_sz_ * nthr_ * channel_block_size_);
+                    src_sz_ * nbuf_ * channel_block_size_);
            scratchpad.template book<float>(key_pool_dst_bf16cvt,
-                    dst_sz_ * nthr_ * channel_block_size_);
+                    dst_sz_ * nbuf_ * channel_block_size_);
        }
    }
 
    void calculate_channel_block_size() {
-        // calculate channels block size at which the data fits into half
-        // of L1, it allows to improve performance for problems with small
-        // spatial
-        dim_t dst_sz_ = OD() * OH() * OW();
-        dim_t src_sz_ = ID() * IH() * IW();
-        dim_t C_per_thr = nstl::min(MB() * IC() / nthr_, IC());
-        const dim_t max_block_size
-                = platform::get_per_core_cache_size(1) / 2;
-        dim_t data_size_per_ch = (dst_sz_ + src_sz_) * 6; // f32 + bf16
-        channel_block_size_ = nstl::max(
-                nstl::min(C_per_thr, max_block_size / data_size_per_ch),
-                (dim_t)1);
+        using namespace memory_tracking::names;
+        if (diff_dst_md()->data_type != data_type::f32) {
+            // Calculate the channel block size at which the data fits into
+            // half of L1; this improves performance for problems with small
+            // spatial sizes.
+            dim_t dst_sz_ = OD() * OH() * OW();
+            dim_t src_sz_ = ID() * IH() * IW();
+            dim_t C_per_thr = nstl::min(MB() * IC() / nthr_, IC());
+            const dim_t max_block_size
+                    = platform::get_per_core_cache_size(1) / 2;
+            dim_t data_size_per_ch = (dst_sz_ + src_sz_) * 6; // f32 + bf16
+            channel_block_size_ = nstl::max(
+                    nstl::min(C_per_thr, max_block_size / data_size_per_ch),
+                    (dim_t)1);
+        }
    }
 };
 
    nchw_pooling_bwd_t(const pd_t *apd) : primitive_t(apd) {}
 
-    typedef typename prec_traits<d_type>::type data_t;
+    using data_t = typename prec_traits_t<d_type>::type;
 
    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_backward(ctx);
diff --git a/src/cpu/ncsp_batch_normalization.hpp b/src/cpu/ncsp_batch_normalization.hpp
index 2cfe4d834b4..0cde9f513b5 100644
--- a/src/cpu/ncsp_batch_normalization.hpp
+++ b/src/cpu/ncsp_batch_normalization.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2024 Intel Corporation
+* Copyright 2018-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
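(Editorial aside: the nbuf_ sizing above caps the number of f32 conversion buffers at the number of (mb, cb) tasks that can actually run. A worked example under hypothetical shapes:)

    #include <algorithm>
    #include <cstdint>

    int64_t div_up(int64_t a, int64_t b) { return (a + b - 1) / b; }

    // Mirrors nbuf_ = min(nthr, MB * div_up(IC, channel_block_size)).
    int64_t nbuf_for(int64_t nthr, int64_t mb, int64_t ic, int64_t c_blk) {
        return std::min(nthr, mb * div_up(ic, c_blk));
    }
    // Hypothetical: nbuf_for(32, 2, 24, 16) == 4, so the scratchpad books
    // 4 src/dst conversion buffers instead of 32 -- an 8x reduction.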
@@ -112,11 +112,11 @@ struct ncsp_batch_normalization_fwd_t : public primitive_t { } }; - typedef typename prec_traits::type data_t; - typedef float acc_data_t; + using data_t = typename prec_traits_t::type; + using acc_data_t = float; ncsp_batch_normalization_fwd_t(const pd_t *apd) : primitive_t(apd) {} - ~ncsp_batch_normalization_fwd_t() {} + ~ncsp_batch_normalization_fwd_t() override = default; status_t execute(const exec_ctx_t &ctx) const override { return execute_forward(ctx); @@ -209,11 +209,11 @@ struct ncsp_batch_normalization_bwd_t : public primitive_t { } }; - typedef typename prec_traits::type data_t; - typedef float acc_data_t; + using data_t = typename prec_traits_t::type; + using acc_data_t = float; ncsp_batch_normalization_bwd_t(const pd_t *apd) : primitive_t(apd) {} - ~ncsp_batch_normalization_bwd_t() {} + ~ncsp_batch_normalization_bwd_t() override = default; status_t execute(const exec_ctx_t &ctx) const override { return execute_backward(ctx); diff --git a/src/cpu/ncsp_group_normalization.hpp b/src/cpu/ncsp_group_normalization.hpp index 85c5f68bb6d..5c8237a0bf3 100644 --- a/src/cpu/ncsp_group_normalization.hpp +++ b/src/cpu/ncsp_group_normalization.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023-2024 Intel Corporation +* Copyright 2023-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -68,8 +68,7 @@ struct ncsp_group_normalization_fwd_t : public primitive_t { VDISPATCH_GNORM(memory_desc_matches_one_of_tag( *dst_md(), ncdhw, nchw, ncw, nc), VERBOSE_UNSUPPORTED_TAG_S, "dst"); - VDISPATCH_GNORM( - attr()->has_default_values(skip_mask_t::scales_runtime) + VDISPATCH_GNORM(attr()->has_default_values(skip_mask_t::scales) && attr_scales_ok(), VERBOSE_UNSUPPORTED_ATTR); nthr_ = dnnl_get_max_threads(); diff --git a/src/cpu/nhwc_pooling.cpp b/src/cpu/nhwc_pooling.cpp index b20ee0a92a0..754b5ab4c40 100644 --- a/src/cpu/nhwc_pooling.cpp +++ b/src/cpu/nhwc_pooling.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -55,17 +55,17 @@ namespace cpu { = MEM_D(name).blocking_desc().strides[ndims - 1]; namespace nhwc_pooling { -size_t strided_offset(const int _n, const size_t _sn, const int _d, - const size_t _sd, const int _h, const size_t _sh, const int _w, +size_t strided_offset(const dim_t _n, const size_t _sn, const dim_t _d, + const size_t _sd, const dim_t _h, const size_t _sh, const dim_t _w, const size_t _sw) { return _n * _sn + _d * _sd + _h * _sh + _w * _sw; } } // namespace nhwc_pooling template -void nhwc_pooling_fwd_t::array_div_by_const(const int n, +void nhwc_pooling_fwd_t::array_div_by_const(const dim_t n, const ker_data_t *src, const size_t num, ker_data_t *dst) const { - for (int i = 0; i < n; ++i) { + for (dim_t i = 0; i < n; ++i) { const float ftmp = ((float)src[i]) / num; dst[i] = q10n::out_round(ftmp); } @@ -73,21 +73,21 @@ void nhwc_pooling_fwd_t::array_div_by_const(const int n, template void nhwc_pooling_fwd_t::array_add( - const int n, const ker_data_t *src, ker_data_t *dst) const { - for (int i = 0; i < n; ++i) { + const dim_t n, const ker_data_t *src, ker_data_t *dst) const { + for (dim_t i = 0; i < n; ++i) { dst[i] += src[i]; } } template -void nhwc_pooling_fwd_t::array_nhwc_max(const int n, ker_data_t *dst, +void nhwc_pooling_fwd_t::array_nhwc_max(const dim_t n, ker_data_t *dst, const ker_data_t *src, unsigned char *ws, const size_t ws_offset, const data_type_t ws_dt, const int index) const { assert(ws); #if SAFE_TO_USE_OMP_SIMD PRAGMA_OMP_SIMD() #endif - for (int oc = 0; oc < n; ++oc) { + for (dim_t oc = 0; oc < n; ++oc) { const auto s = src[oc]; ker_data_t mv = dst[oc]; @@ -130,14 +130,14 @@ void nhwc_pooling_fwd_t::array_nhwc_max(const int n, ker_data_t *dst, } template -void nhwc_pooling_fwd_t::array_nhwc_initialize(const int n, +void nhwc_pooling_fwd_t::array_nhwc_initialize(const dim_t n, ker_data_t *dst, unsigned char *ws, const size_t ws_offset, const data_type_t ws_dt) const { assert(ws && (ws_dt == data_type::u8 || ws_dt == data_type::s32)); #if SAFE_TO_USE_OMP_SIMD PRAGMA_OMP_SIMD() #endif - for (int oc = 0; oc < n; ++oc) { + for (dim_t oc = 0; oc < n; ++oc) { if (ws_dt == data_type::u8) ws[ws_offset + oc] = 0; else @@ -189,7 +189,7 @@ status_t nhwc_pooling_fwd_t::execute_forward( DECLARE_READ_STRIDES(src); DECLARE_READ_STRIDES(dst); - const auto apply_offset = [](int index, int offset) { + const auto apply_offset = [](dim_t index, dim_t offset) { return (index > offset) ? index - offset : 0; }; diff --git a/src/cpu/nhwc_pooling.hpp b/src/cpu/nhwc_pooling.hpp index 44c71049b9e..98fb378ddcd 100644 --- a/src/cpu/nhwc_pooling.hpp +++ b/src/cpu/nhwc_pooling.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
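(Editorial aside: the int -> dim_t widening in strided_offset and the array helpers above matters for large tensors: a flattened offset is a product of several dimensions and can exceed INT_MAX long before any single dimension does. A minimal sketch with a hypothetical shape:)

    #include <cstdint>

    using dim_t = int64_t; // oneDNN's dim_t is a 64-bit signed type

    // Hypothetical batch stride of a 512x512x512 spatial, C = 32 tensor:
    // 512 * 512 * 512 * 32 == 2^32, already past INT32_MAX.
    const dim_t SN = dim_t(512) * 512 * 512 * 32;

    dim_t offset64(dim_t n) { return n * SN; } // correct for any batch index
    int offset32(int n) { return n * int(SN); } // 32-bit math corrupts the offset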
@@ -35,8 +35,8 @@ namespace impl { namespace cpu { namespace nhwc_pooling { -size_t strided_offset(const int _n, const size_t _sn, const int _d, - const size_t _sd, const int _h, const size_t _sh, const int _w, +size_t strided_offset(const dim_t _n, const size_t _sn, const dim_t _d, + const size_t _sd, const dim_t _h, const size_t _sh, const dim_t _w, const size_t _sw); } @@ -113,8 +113,8 @@ struct nhwc_pooling_fwd_t : public primitive_t { nhwc_pooling_fwd_t(const pd_t *apd) : primitive_t(apd) {} - using data_t = typename prec_traits::type; - using ker_data_t = typename prec_traits::type; + using data_t = typename prec_traits_t::type; + using ker_data_t = typename prec_traits_t::type; status_t init(engine_t *engine) override { ref_post_ops_ @@ -130,14 +130,15 @@ struct nhwc_pooling_fwd_t : public primitive_t { private: status_t execute_forward(const exec_ctx_t &ctx) const; - void array_div_by_const(const int n, const ker_data_t *src, + void array_div_by_const(const dim_t n, const ker_data_t *src, const size_t num, ker_data_t *dst) const; - void array_add(const int n, const ker_data_t *src, ker_data_t *dst) const; - void array_nhwc_max(const int n, ker_data_t *dst, const ker_data_t *src, + void array_add(const dim_t n, const ker_data_t *src, ker_data_t *dst) const; + void array_nhwc_max(const dim_t n, ker_data_t *dst, const ker_data_t *src, unsigned char *ws, const size_t ws_offset, const data_type_t ws_dt, const int index) const; - void array_nhwc_initialize(const int n, ker_data_t *dst, unsigned char *ws, - const size_t ws_offset, const data_type_t ws_dt) const; + void array_nhwc_initialize(const dim_t n, ker_data_t *dst, + unsigned char *ws, const size_t ws_offset, + const data_type_t ws_dt) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } std::unique_ptr ref_post_ops_; @@ -210,7 +211,7 @@ struct nhwc_pooling_bwd_t : public primitive_t { }; nhwc_pooling_bwd_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type data_t; + using data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_backward(ctx); diff --git a/src/cpu/nspc_batch_normalization.hpp b/src/cpu/nspc_batch_normalization.hpp index 90a8a2e0029..456f5c6b9f0 100644 --- a/src/cpu/nspc_batch_normalization.hpp +++ b/src/cpu/nspc_batch_normalization.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -36,10 +36,8 @@ namespace cpu { template struct nspc_batch_normalization_fwd_t : public primitive_t { struct pd_t : public cpu_batch_normalization_fwd_pd_t { - pd_t(const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) {} + using cpu_batch_normalization_fwd_pd_t:: + cpu_batch_normalization_fwd_pd_t; DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_fwd_t); @@ -111,11 +109,11 @@ struct nspc_batch_normalization_fwd_t : public primitive_t { } }; - typedef typename prec_traits::type data_t; - typedef float acc_data_t; + using data_t = typename prec_traits_t::type; + using acc_data_t = float; nspc_batch_normalization_fwd_t(const pd_t *apd) : primitive_t(apd) {} - ~nspc_batch_normalization_fwd_t() {} + ~nspc_batch_normalization_fwd_t() override = default; status_t execute(const exec_ctx_t &ctx) const override { return execute_forward(ctx); @@ -129,10 +127,8 @@ struct nspc_batch_normalization_fwd_t : public primitive_t { template struct nspc_batch_normalization_bwd_t : public primitive_t { struct pd_t : public cpu_batch_normalization_bwd_pd_t { - pd_t(const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_bwd_pd_t(adesc, attr, hint_fwd_pd) {} + using cpu_batch_normalization_bwd_pd_t:: + cpu_batch_normalization_bwd_pd_t; DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_bwd_t); @@ -204,11 +200,11 @@ struct nspc_batch_normalization_bwd_t : public primitive_t { } }; - typedef typename prec_traits::type data_t; - typedef float acc_data_t; + using data_t = typename prec_traits_t::type; + using acc_data_t = float; nspc_batch_normalization_bwd_t(const pd_t *apd) : primitive_t(apd) {} - ~nspc_batch_normalization_bwd_t() {} + ~nspc_batch_normalization_bwd_t() override = default; status_t execute(const exec_ctx_t &ctx) const override { return execute_backward(ctx); diff --git a/src/cpu/platform.cpp b/src/cpu/platform.cpp index f372543ccb7..d4f5a217c0f 100644 --- a/src/cpu/platform.cpp +++ b/src/cpu/platform.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * Copyright 2020-2024 FUJITSU LIMITED * Copyright 2022-2024 Arm Ltd. and affiliates * @@ -34,10 +34,12 @@ #include "cpu/x64/cpu_isa_traits.hpp" #elif DNNL_AARCH64 #include "cpu/aarch64/cpu_isa_traits.hpp" -#if DNNL_AARCH64_USE_ACL +#endif +#if DNNL_USE_ACL // For checking if fp16 isa is supported on the platform #include "arm_compute/core/CPP/CPPTypes.h" -#endif +// For setting the number of threads for ACL +#include "src/common/cpuinfo/CpuInfo.h" #endif // For DNNL_X64 build we compute the timestamp using rdtsc. 
Use std::chrono for @@ -82,6 +84,8 @@ status_t set_max_cpu_isa(dnnl_cpu_isa_t isa) { status_t set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints) { #if DNNL_X64 return x64::set_cpu_isa_hints(isa_hints); +#elif DNNL_AARCH64 + return status::success; #else return status::unimplemented; #endif @@ -124,7 +128,7 @@ bool has_data_type_support(data_type_t data_type) { #if DNNL_X64 return x64::mayiuse(x64::avx512_core_fp16) || x64::mayiuse(x64::avx2_vnni_2); -#elif DNNL_AARCH64_USE_ACL +#elif DNNL_USE_ACL return arm_compute::CPUInfo::get().has_fp16(); #else return false; @@ -151,7 +155,7 @@ bool has_training_support(data_type_t data_type) { #if defined(USE_CBLAS) && defined(BLAS_HAS_SBGEMM) && defined(__MMA__) return true; #endif -#elif DNNL_AARCH64_USE_ACL +#elif DNNL_USE_ACL return arm_compute::CPUInfo::get().has_bf16(); #else return false; @@ -159,7 +163,7 @@ bool has_training_support(data_type_t data_type) { case data_type::f16: #if DNNL_X64 return x64::mayiuse(x64::avx512_core_fp16); -#elif DNNL_AARCH64_USE_ACL +#elif DNNL_USE_ACL return arm_compute::CPUInfo::get().has_fp16(); #else return false; @@ -205,8 +209,8 @@ unsigned get_per_core_cache_size(int level) { unsigned get_num_cores() { #if DNNL_X64 return x64::cpu().getNumCores(Xbyak::util::CoreLevel); -#elif DNNL_AARCH64_USE_ACL - return aarch64::cpu().getNumCores(Xbyak_aarch64::util::CoreLevel); +#elif DNNL_USE_ACL + return arm_compute::cpuinfo::num_threads_hint(); #else return 1; #endif @@ -256,9 +260,9 @@ unsigned get_max_threads_to_use() { int get_vector_register_size() { #if DNNL_X64 using namespace x64; - if (mayiuse(avx512_core)) return cpu_isa_traits::vlen; - if (mayiuse(avx)) return cpu_isa_traits::vlen; - if (mayiuse(sse41)) return cpu_isa_traits::vlen; + if (mayiuse(avx512_core)) return cpu_isa_traits_t::vlen; + if (mayiuse(avx)) return cpu_isa_traits_t::vlen; + if (mayiuse(sse41)) return cpu_isa_traits_t::vlen; #elif DNNL_AARCH64 using namespace aarch64; if (mayiuse(asimd)) return cpu_isa_traits::vlen; diff --git a/src/cpu/platform.hpp b/src/cpu/platform.hpp index 1de81f578e6..af0d6e944a8 100644 --- a/src/cpu/platform.hpp +++ b/src/cpu/platform.hpp @@ -26,7 +26,9 @@ // Possible architectures: // - DNNL_X64 +// - DNNL_X86 // - DNNL_AARCH64 +// - DNNL_ARM // - DNNL_PPC64 // - DNNL_S390X // - DNNL_RV64 @@ -35,12 +37,19 @@ #if defined(DNNL_X64) + defined(DNNL_AARCH64) + defined(DNNL_PPC64) \ + defined(DNNL_S390X) + defined(DNNL_RV64) \ + + defined(DNNL_ARM) + defined(DNNL_X86) \ + defined(DNNL_ARCH_GENERIC) \ == 0 -#if defined(__x86_64__) || defined(_M_X64) +#if defined(__amd64__) || defined(__amd64) || defined(__x86_64__) || \ + defined(__x86_64) || defined(_M_X64) || defined(_M_AMD64) #define DNNL_X64 1 -#elif defined(__aarch64__) +#elif defined(i386) || defined(__i386) || defined(__i386__) || defined(__IA32__) || defined(_M_I86) || \ + defined(_M_IX86) || defined(__X86__) || defined(_X86_) || defined(__I86__) || defined(__386) +#define DNNL_X86 1 +#elif defined(__aarch64__) || defined(_M_ARM64) #define DNNL_AARCH64 1 +#elif defined(__arm__) || defined(_M_ARM) || defined(__ARMEL__) +#define DNNL_ARM 1 #elif defined(__powerpc64__) || defined(__PPC64__) || defined(_ARCH_PPC64) #define DNNL_PPC64 1 #elif defined(__s390x__) @@ -54,6 +63,7 @@ #if defined(DNNL_X64) + defined(DNNL_AARCH64) + defined(DNNL_PPC64) \ + defined(DNNL_S390X) + defined(DNNL_RV64) \ + + defined(DNNL_ARM) + defined(DNNL_X86) \ + defined(DNNL_ARCH_GENERIC) \ != 1 #error One and only one architecture should be defined at a time @@ -62,9 +72,15 @@ #if 
!defined(DNNL_X64) #define DNNL_X64 0 #endif +#if !defined(DNNL_X86) +#define DNNL_X86 0 +#endif #if !defined(DNNL_AARCH64) #define DNNL_AARCH64 0 #endif +#if !defined(DNNL_ARM) +#define DNNL_ARM 0 +#endif #if !defined(DNNL_PPC64) #define DNNL_PPC64 0 #endif @@ -84,6 +100,7 @@ #define DNNL_PPC64_ONLY(...) Z_CONDITIONAL_DO(DNNL_PPC64_ONLY, __VA_ARGS__) #define DNNL_S390X_ONLY(...) Z_CONDITIONAL_DO(DNNL_S390X_ONLY, __VA_ARGS__) #define DNNL_AARCH64_ONLY(...) Z_CONDITIONAL_DO(DNNL_AARCH64, __VA_ARGS__) +#define DNNL_ARM_ONLY(...) Z_CONDITIONAL_DO(DNNL_ARM, __VA_ARGS__) // Using RISC-V implementations optimized with RVV Intrinsics is optional for RISC-V builds // and can be enabled with DNNL_ARCH_OPT_FLAGS="-march=" option, where @@ -98,11 +115,11 @@ #define DNNL_NON_X64_ONLY(...) Z_CONDITIONAL_DO(Z_NOT(DNNL_X64), __VA_ARGS__) // Using Arm Compute Library kernels is optional for AArch64 builds -// and can be enabled with the DNNL_AARCH64_USE_ACL CMake option -#if defined(DNNL_AARCH64) && defined(DNNL_AARCH64_USE_ACL) -#define DNNL_AARCH64_ACL_ONLY(...) __VA_ARGS__ +// and can be enabled with the DNNL_USE_ACL CMake option +#ifdef DNNL_USE_ACL +#define DNNL_ACL_ONLY(...) __VA_ARGS__ #else -#define DNNL_AARCH64_ACL_ONLY(...) +#define DNNL_ACL_ONLY(...) #endif // Primitive ISA section for configuring knobs. diff --git a/src/cpu/ppc64/ppc64_gemm_s8x8s32.cpp b/src/cpu/ppc64/ppc64_gemm_s8x8s32.cpp index 33f88cb17c4..f7ce2b90ac3 100644 --- a/src/cpu/ppc64/ppc64_gemm_s8x8s32.cpp +++ b/src/cpu/ppc64/ppc64_gemm_s8x8s32.cpp @@ -150,8 +150,9 @@ dnnl_status_t cblas_gemm_s8x8s32_ppc64(int ATflag, int BTflag, } } for (int i = 0; i < m; ++i) { - comparray[i] = out_round(saturate( - ((double)comparray[i]) * alpha * -128.0)); + comparray[i] = cpu::q10n::out_round( + cpu::q10n::saturate( + ((double)comparray[i]) * alpha * -128.0)); } for (int j = 0; j < n; ++j) { int *ca = comparray; diff --git a/src/cpu/primitive_attr_postops.cpp b/src/cpu/primitive_attr_postops.cpp index fa80cb23683..d5ddd73cd77 100644 --- a/src/cpu/primitive_attr_postops.cpp +++ b/src/cpu/primitive_attr_postops.cpp @@ -26,7 +26,7 @@ namespace cpu { using namespace alg_kind; using namespace math; -float compute_binary_scalar(alg_kind_t alg, float x, float y) { +float compute_binary_scalar(alg_kind_t alg, float x, float y, bool c) { switch (alg) { case binary_add: return x + y; case binary_div: return x / y; @@ -40,7 +40,9 @@ float compute_binary_scalar(alg_kind_t alg, float x, float y) { case binary_lt: return x < y; case binary_eq: return x == y; case binary_ne: return x != y; - default: assert(!"not supported operation!"); return NAN; + case binary_select: return c ? x : y; + case binary_prelu: return x >= 0 ? 
x : x * y; + default: assert(!"unsupported operation!"); return NAN; } } @@ -69,6 +71,9 @@ float compute_eltwise_scalar_fwd( case eltwise_mish: d = mish_fwd(s); break; case eltwise_hardsigmoid: d = hardsigmoid_fwd(s, alpha, beta); break; case eltwise_hardswish: d = hardswish_fwd(s, alpha, beta); break; + case eltwise_hsigmoid: d = hsigmoid_fwd(s); break; + case eltwise_round_half_away_from_zero: d = round_half_away_from_zero_fwd(s); break; + case eltwise_round_half_to_even: d = round_half_to_even_fwd(s); break; case eltwise_relu_use_dst_for_bwd: d = relu_fwd(s, alpha); break; case eltwise_tanh_use_dst_for_bwd: d = tanh_fwd(s); break; case eltwise_elu_use_dst_for_bwd: d = elu_fwd(s, alpha); break; @@ -136,15 +141,16 @@ ref_binary_scalar_t::ref_binary_scalar_t(alg_kind_t alg) : alg_(alg) { alg_kind::binary_min, alg_kind::binary_mul, alg_kind::binary_div, alg_kind::binary_sub, alg_kind::binary_ge, alg_kind::binary_gt, alg_kind::binary_le, alg_kind::binary_lt, alg_kind::binary_eq, - alg_kind::binary_ne)); + alg_kind::binary_ne, alg_kind::binary_select, alg_kind::binary_prelu)); } ref_binary_scalar_t::ref_binary_scalar_t( const post_ops_t::entry_t::binary_t &binary) : ref_binary_scalar_t(binary.alg) {} -float ref_binary_scalar_t::compute_scalar(float src0, float src1) const { - return compute_binary_scalar(alg_, src0, src1); +float ref_binary_scalar_t::compute_scalar( + float src0, float src1, bool src2) const { + return compute_binary_scalar(alg_, src0, src1, src2); } ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t( @@ -155,11 +161,12 @@ ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t( eltwise_soft_relu, eltwise_mish, eltwise_logistic, eltwise_exp, eltwise_gelu_tanh, eltwise_swish, eltwise_log, eltwise_clip, eltwise_clip_v2, eltwise_pow, eltwise_gelu_erf, eltwise_round, - eltwise_hardsigmoid, eltwise_hardswish, - eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd, - eltwise_elu_use_dst_for_bwd, eltwise_sqrt_use_dst_for_bwd, - eltwise_logistic_use_dst_for_bwd, eltwise_exp_use_dst_for_bwd, - eltwise_clip_v2_use_dst_for_bwd)); + eltwise_hardswish, eltwise_hardsigmoid, + eltwise_hsigmoid, eltwise_round_half_away_from_zero, eltwise_round_half_to_even, + eltwise_relu_use_dst_for_bwd, + eltwise_tanh_use_dst_for_bwd, eltwise_elu_use_dst_for_bwd, + eltwise_sqrt_use_dst_for_bwd, eltwise_logistic_use_dst_for_bwd, + eltwise_exp_use_dst_for_bwd, eltwise_clip_v2_use_dst_for_bwd)); } ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t( @@ -179,6 +186,8 @@ ref_post_ops_t::ref_post_ops_t(const post_ops_t &po, bool skip_sum) eltwise_po_.emplace_back(e.eltwise); } else if (po_.contain(primitive_kind::binary, idx)) { binary_po_.emplace_back(e.binary); + } else if (po_.contain(primitive_kind::depthwise, idx)) { + depthwise_po_.emplace_back(e.depthwise.alg); } } } @@ -273,12 +282,13 @@ float ref_dropout( return (m) ? 
src * inv_q : 0; } -void ref_post_ops_t::execute(float &res, const args_t &args) const { +void ref_post_ops_t::execute(float &res, const args_t &args, const size_t oc) const { if (po_.len() == 0) return; auto it_eltwise_po = eltwise_po_.begin(); auto it_binary_po = binary_po_.begin(); auto it_prelu_md = prelu_md_.begin(); + auto it_depthwise_po = depthwise_po_.begin(); for (auto idx = 0; idx < po_.len(); ++idx) { const auto &e = po_.entry_[idx]; switch (e.kind) { @@ -308,7 +318,7 @@ void ref_post_ops_t::execute(float &res, const args_t &args) const { (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1)); const float val_po = io::load_float_value( src1_desc.data_type, src1_binary_po, off); - res = it_binary_po->compute_scalar(res, val_po); + res = it_binary_po->compute_scalar(res, val_po, false); ++it_binary_po; } break; case primitive_kind::prelu: { @@ -339,6 +349,46 @@ void ref_post_ops_t::execute(float &res, const args_t &args) const { res = weights_value * res; ++it_prelu_md; } break; + case primitive_kind::depthwise: { + const exec_ctx_t &ctx = *args.ctx; + auto depthwise_base = CTX_IN_MEM(const float *, (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1)); + auto depthwise_weights = depthwise_base + e.depthwise.offset[e.depthwise.scales]; + auto depthwise_bias = depthwise_base + e.depthwise.offset[e.depthwise.shifts]; + + res = it_depthwise_po->compute_scalar(res, depthwise_weights + oc, depthwise_bias + oc); + + ++it_depthwise_po; + } break; + case primitive_kind::quantization: { + bool do_dequantization = e.quantization.alg == alg_kind::quantization_quantize_dequantize; + bool do_rounding = do_dequantization || args.dst_md->data_type == dnnl_f32 || idx != po_.len() - 1; + + auto quant = e.quantization; + const exec_ctx_t &ctx = *args.ctx; + auto quantization_base = CTX_IN_MEM(const float *, (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1)); + const auto pcl = quantization_base + quant.offset[quant.crop_low]; + const auto pch = quantization_base + quant.offset[quant.crop_high]; + const auto pisc = quantization_base + quant.offset[quant.inp_scale]; + const auto pish = quantization_base + quant.offset[quant.inp_shift]; + const auto posc = quantization_base + quant.offset[quant.output_scale]; + const auto posh = quantization_base + quant.offset[quant.output_shift]; + + int cl_idx = !quant.per_channel[quant.crop_low] ? 0 : oc; + int ch_idx = !quant.per_channel[quant.crop_high] ? 0 : oc; + int isc_idx = !quant.per_channel[quant.inp_scale] ? 0 : oc; + int ish_idx = !quant.per_channel[quant.inp_shift] ? 0 : oc; + int osc_idx = !quant.per_channel[quant.output_scale] ? 0 : oc; + int osh_idx = !quant.per_channel[quant.output_shift] ? 
0 : oc;
+
+                res = nstl::min(pch[ch_idx], nstl::max(pcl[cl_idx], res));
+                res = res * pisc[isc_idx] + pish[ish_idx];
+
+                if (do_rounding)
+                    res = roundf(res);
+
+                if (do_dequantization)
+                    res = res * posc[osc_idx] + posh[osh_idx];
+            } break;
            default: assert(!"unsupported post op primitive kind!");
        }
    }
diff --git a/src/cpu/primitive_attr_postops.hpp b/src/cpu/primitive_attr_postops.hpp
index bcb09b2e004..3759b89a8d4 100644
--- a/src/cpu/primitive_attr_postops.hpp
+++ b/src/cpu/primitive_attr_postops.hpp
@@ -22,11 +22,13 @@
 #include "common/primitive.hpp"
 #include "common/primitive_attr.hpp"
 
+#include "ref_depthwise_injector.hpp"
+
 namespace dnnl {
 namespace impl {
 namespace cpu {
 
-float compute_binary_scalar(alg_kind_t alg, float x, float y);
+float compute_binary_scalar(alg_kind_t alg, float x, float y, bool c);
 float compute_eltwise_scalar_fwd(
         const alg_kind_t alg, float s, float alpha, float beta);
 float compute_eltwise_scalar_bwd(
@@ -36,7 +38,7 @@ struct ref_binary_scalar_t {
     ref_binary_scalar_t(alg_kind_t alg);
     ref_binary_scalar_t(const post_ops_t::entry_t::binary_t &binary);
 
-    float compute_scalar(float src0, float src1) const;
+    float compute_scalar(float src0, float src1, bool src2) const;
 
 private:
     const alg_kind_t alg_;
@@ -71,7 +73,7 @@ struct ref_post_ops_t {
 
     status_t init(const memory_desc_t *dst_md);
 
-    void execute(float &res, const args_t &args = args_t()) const;
+    void execute(float &res, const args_t &args = args_t(), const size_t oc = 0) const;
 
     static bool primitive_kind_ok(const post_ops_t &po) {
         using namespace primitive_kind;
@@ -86,6 +88,7 @@ struct ref_post_ops_t {
 
     std::vector<ref_eltwise_scalar_fwd_t> eltwise_po_;
     std::vector<ref_binary_scalar_t> binary_po_;
+    std::vector<ref_depthwise_scalar_fwd_t> depthwise_po_;
     std::vector<memory_desc_t> prelu_md_;
 };
 
diff --git a/src/cpu/ref_batch_normalization.cpp b/src/cpu/ref_batch_normalization.cpp
index 0e4a23e8d7e..6ab3c20742f 100644
--- a/src/cpu/ref_batch_normalization.cpp
+++ b/src/cpu/ref_batch_normalization.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2016-2024 Intel Corporation
+* Copyright 2016-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -158,8 +158,8 @@ status_t ref_batch_normalization_fwd_t<d_type>::execute_forward(
         }
     }
     if (d_type == s8)
-        dst[d_off]
-                = q10n::qz_a1b0<float, data_t>()(maybe_post_op(bn_res));
+        dst[d_off] = q10n::qz_a1b0_t<float, data_t>()(
+                maybe_post_op(bn_res));
     else
         dst[d_off] = maybe_post_op(bn_res);
 }
diff --git a/src/cpu/ref_batch_normalization.hpp b/src/cpu/ref_batch_normalization.hpp
index 2e712945533..2932b2a9c7d 100644
--- a/src/cpu/ref_batch_normalization.hpp
+++ b/src/cpu/ref_batch_normalization.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2016-2024 Intel Corporation
+* Copyright 2016-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
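(Editorial aside: the quantization post-op added above is a clamp, an input scale/shift onto the quantized grid, an optional round, and an output scale/shift back. A scalar sketch with made-up per-tensor parameters — the real code reads them from per-channel arrays via the offsets shown above:)

    #include <algorithm>
    #include <cmath>

    float quant_dequant_sketch(float res) {
        const float crop_lo = 0.f, crop_hi = 6.f; // hypothetical clamp bounds
        const float isc = 25.5f, ish = 0.f; // input scale / shift
        const float osc = 1.f / 25.5f, osh = 0.f; // output scale / shift
        res = std::min(crop_hi, std::max(crop_lo, res)); // crop
        res = res * isc + ish; // map onto the quantized grid
        res = std::round(res); // the do_rounding branch
        return res * osc + osh; // the do_dequantization branch
    }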
@@ -35,10 +35,8 @@ namespace cpu { template struct ref_batch_normalization_fwd_t : public primitive_t { struct pd_t : public cpu_batch_normalization_fwd_pd_t { - pd_t(const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) {} + using cpu_batch_normalization_fwd_pd_t:: + cpu_batch_normalization_fwd_pd_t; DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_fwd_t); @@ -80,7 +78,7 @@ struct ref_batch_normalization_fwd_t : public primitive_t { ref_batch_normalization_fwd_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type data_t; + using data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_forward(ctx); @@ -94,10 +92,8 @@ struct ref_batch_normalization_fwd_t : public primitive_t { template struct ref_batch_normalization_bwd_t : public primitive_t { struct pd_t : public cpu_batch_normalization_bwd_pd_t { - pd_t(const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_bwd_pd_t(adesc, attr, hint_fwd_pd) {} + using cpu_batch_normalization_bwd_pd_t:: + cpu_batch_normalization_bwd_pd_t; DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_bwd_t); @@ -138,7 +134,7 @@ struct ref_batch_normalization_bwd_t : public primitive_t { }; ref_batch_normalization_bwd_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits::type data_t; + using data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_backward(ctx); diff --git a/src/cpu/ref_binary.cpp b/src/cpu/ref_binary.cpp index 8d6788b77df..ad7f28ef592 100644 --- a/src/cpu/ref_binary.cpp +++ b/src/cpu/ref_binary.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2022 Intel Corporation +* Copyright 2019-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -37,6 +37,8 @@ namespace cpu {
 status_t ref_binary_t::execute_ref(const exec_ctx_t &ctx) const {
     const auto src0 = CTX_IN_MEM(const void *, DNNL_ARG_SRC_0);
     const auto src1 = CTX_IN_MEM(const void *, DNNL_ARG_SRC_1);
+    const auto src2 = CTX_IN_MEM(const void *, DNNL_ARG_SRC_2);
+
     auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
 
     const float *scales[2];
@@ -45,10 +47,12 @@ status_t ref_binary_t::execute_ref(const exec_ctx_t &ctx) const {
 
     const memory_desc_wrapper src0_d(pd()->src_md(0));
     const memory_desc_wrapper src1_d(pd()->src_md(1));
+    const memory_desc_wrapper src2_d(pd()->src_md(2));
     const memory_desc_wrapper dst_d(pd()->dst_md());
 
     const auto src0_dt = src0_d.data_type();
     const auto src1_dt = src1_d.data_type();
+    const auto src2_dt = src2_d.data_type();
     const auto dst_dt = dst_d.data_type();
 
     const auto alg = pd()->desc()->alg_kind;
@@ -85,10 +89,11 @@ status_t ref_binary_t::execute_ref(const exec_ctx_t &ctx) const {
     }
 
     parallel_nd(nelems, [&](dim_t i) {
-        dims_t dims_src0, dims_src1; // decomposition for physical offsets
+        // decomposition for physical offsets
+        dims_t dims_src0, dims_src1, dims_src2;
         utils::l_dims_by_l_offset(dims_src0, i, dst_d.dims(), ndims);
         utils::l_dims_by_l_offset(dims_src1, i, dst_d.dims(), ndims);
-        auto off_C = dst_d.off_v(dims_src0);
+        auto off_D = dst_d.off_v(dims_src0);
 
         int mask_src0
                 = utils::get_dims_mask(dst_d.dims(), src0_d.dims(), ndims);
@@ -101,12 +106,22 @@ status_t ref_binary_t::execute_ref(const exec_ctx_t &ctx) const {
 
         float x_f = io::load_float_value(src0_dt, src0, off_A);
         float y_f = io::load_float_value(src1_dt, src1, off_B);
-        float dst_f = io::load_float_value(dst_dt, dst, off_C);
+        float dst_f = io::load_float_value(dst_dt, dst, off_D);
 
         x_f *= scales[0][0];
         y_f *= scales[1][0];
 
-        float acc = compute_binary_scalar(alg, x_f, y_f);
+        bool c_f = false;
+        if (pd()->is_ternary_op()) {
+            utils::l_dims_by_l_offset(dims_src2, i, dst_d.dims(), ndims);
+            int mask_src2
+                    = utils::get_dims_mask(dst_d.dims(), src2_d.dims(), ndims);
+            utils::apply_mask_on_dims(dims_src2, ndims, mask_src2);
+            const auto off_C = src2_d.off_v(dims_src2);
+            c_f = static_cast<bool>(io::load_int_value(src2_dt, src2, off_C));
+        }
+
+        float acc = compute_binary_scalar(alg, x_f, y_f, c_f);
 
         if (has_postops) {
             ref_post_ops_t::args_t args;
@@ -117,7 +132,7 @@ status_t ref_binary_t::execute_ref(const exec_ctx_t &ctx) const {
             ref_post_ops->execute(acc, args);
         }
 
-        io::store_float_value(dst_dt, acc, dst, off_C);
+        io::store_float_value(dst_dt, acc, dst, off_D);
     });
 
     return status::success;
diff --git a/src/cpu/ref_binary.hpp b/src/cpu/ref_binary.hpp
index 459475a580d..b59c0688fbc 100644
--- a/src/cpu/ref_binary.hpp
+++ b/src/cpu/ref_binary.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2024 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
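(Editorial aside: per element, the ternary binary_select wired up above reduces to a conditional pick driven by the integer src2 tensor. A simplified sketch that ignores broadcasting, scales, and post-ops:)

    #include <cstddef>
    #include <cstdint>

    void select_sketch(const float *src0, const float *src1,
            const uint8_t *src2, float *dst, size_t nelems) {
        for (size_t i = 0; i < nelems; ++i)
            dst[i] = src2[i] ? src0[i] : src1[i]; // dst = cond ? A : B
    }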
@@ -49,13 +49,17 @@ struct ref_binary_t : public primitive_t { VDISPATCH_BINARY( platform::has_data_type_support(src_md(1)->data_type), VERBOSE_UNSUPPORTED_DT); + VDISPATCH_BINARY(IMPLICATION(is_ternary_op(), + platform::has_data_type_support( + src_md(2)->data_type)), + VERBOSE_UNSUPPORTED_DT); VDISPATCH_BINARY( platform::has_data_type_support(dst_md()->data_type), VERBOSE_UNSUPPORTED_DT); VDISPATCH_BINARY(set_default_params() == status::success, VERBOSE_UNSUPPORTED_TAG); - VDISPATCH_BINARY(attr()->has_default_values( - sm::post_ops | sm::scales_runtime), + VDISPATCH_BINARY( + attr()->has_default_values(sm::post_ops | sm::scales), VERBOSE_UNSUPPORTED_ATTR); VDISPATCH_BINARY(IMPLICATION(!attr()->scales_.has_default_values(), check_scales_mask()), diff --git a/src/cpu/ref_concat.hpp b/src/cpu/ref_concat.hpp index 81a86883f17..6b87295dfcf 100644 --- a/src/cpu/ref_concat.hpp +++ b/src/cpu/ref_concat.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2024 Intel Corporation +* Copyright 2017-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,19 +31,15 @@ namespace cpu { struct ref_concat_t : public primitive_t { struct pd_t : public cpu_concat_pd_t { - pd_t(const primitive_attr_t *attr, const memory_desc_t *dst_md, int n, - int concat_dim, const memory_desc_t *const *src_mds) - : cpu_concat_pd_t(attr, dst_md, n, concat_dim, src_mds) - , tent_dst_md_(types::zero_md()) {} - pd_t(const pd_t &rhs) = default; - ~pd_t() = default; + using cpu_concat_pd_t::cpu_concat_pd_t; DECLARE_CONCAT_PD_T("ref:any", ref_concat_t); status_t init(engine_t *engine) { using sm = primitive_attr_t::skip_mask_t; - VDISPATCH_CONCAT(attr()->has_default_values(sm::scales_runtime), + VDISPATCH_CONCAT(attr()->has_default_values(sm::scales), VERBOSE_UNSUPPORTED_ATTR); + tent_dst_md_ = types::zero_md(); status_t status = cpu_concat_pd_t::init(); if (status != status::success) { assert(dst_md_.format_kind != format_kind::undef); @@ -62,11 +58,10 @@ struct ref_concat_t : public primitive_t { reorder_pds_.resize(n_ + use_tent_dst()); for (int i = 0; i < n_; ++i) { primitive_attr_t r_attr; - if (!sc.get(DNNL_ARG_MULTIPLE_SRC + i).has_default_values()) { - int mask = 0; - CHECK(sc.get(DNNL_ARG_MULTIPLE_SRC + i, &mask, nullptr)); - if (mask != 0) return status::unimplemented; - r_attr.scales_.set(DNNL_ARG_SRC, mask); + if (!sc.has_default_values(DNNL_ARG_MULTIPLE_SRC + i)) { + int mask = sc.get_mask(DNNL_ARG_MULTIPLE_SRC + i); + VDISPATCH_CONCAT(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG); + CHECK(r_attr.scales_.set(DNNL_ARG_SRC, mask)); } CHECK(reorder_primitive_desc_create(reorder_pds_[i], engine, src_md(i), src_image_md(i), &r_attr)); @@ -114,7 +109,7 @@ struct ref_concat_t : public primitive_t { return status::success; } - ~ref_concat_t() = default; + ~ref_concat_t() override = default; status_t execute(const exec_ctx_t &ctx) const override { using namespace memory_tracking::names; @@ -145,8 +140,10 @@ struct ref_concat_t : public primitive_t { = scratchpad.get_memory_storage(key_concat_tent_dst); for (int i = 0; i < n; ++i) { - memory_t tent_dst_i(engine, pd()->src_image_md(i), - tent_dst_storage->clone()); + std::unique_ptr tent_dst_i; + CHECK(safe_ptr_assign(tent_dst_i, + new memory_t(engine, pd()->src_image_md(i), + tent_dst_storage->clone()))); const auto &src_scales_arg = ctx.args().find( DNNL_ARG_ATTR_SCALES | (DNNL_ARG_MULTIPLE_SRC + i)); const 
memory_arg_t *src_scales = nullptr; @@ -154,18 +151,22 @@ struct ref_concat_t : public primitive_t { src_scales = &src_scales_arg->second; execute_reorder(reorders_[i], ctx.args().at(DNNL_ARG_MULTIPLE_SRC + i), - {&tent_dst_i, false}, src_scales, i); + {tent_dst_i.get(), false}, src_scales, i); } - memory_t tent_dst( - engine, &pd()->tent_dst_md_, tent_dst_storage->clone()); - execute_reorder(reorders_[n], {&tent_dst, true}, + std::unique_ptr tent_dst; + CHECK(safe_ptr_assign(tent_dst, + new memory_t(engine, &pd()->tent_dst_md_, + tent_dst_storage->clone()))); + execute_reorder(reorders_[n], {tent_dst.get(), true}, ctx.args().at(DNNL_ARG_DST), nullptr, n); } else { auto &dst_mem_storage = CTX_OUT_STORAGE(DNNL_ARG_DST); for (int i = 0; i < n; ++i) { - memory_t tent_dst_i( - engine, pd()->src_image_md(i), dst_mem_storage.clone()); + std::unique_ptr tent_dst_i; + CHECK(safe_ptr_assign(tent_dst_i, + new memory_t(engine, pd()->src_image_md(i), + dst_mem_storage.clone()))); const auto &src_scales_arg = ctx.args().find( DNNL_ARG_ATTR_SCALES | (DNNL_ARG_MULTIPLE_SRC + i)); const memory_arg_t *src_scales = nullptr; @@ -173,7 +174,7 @@ struct ref_concat_t : public primitive_t { src_scales = &src_scales_arg->second; execute_reorder(reorders_[i], ctx.args().at(DNNL_ARG_MULTIPLE_SRC + i), - {&tent_dst_i, false}, src_scales, i); + {tent_dst_i.get(), false}, src_scales, i); } } return status::success; diff --git a/src/cpu/ref_convolution.cpp b/src/cpu/ref_convolution.cpp index b97c7b76942..22c48174819 100644 --- a/src/cpu/ref_convolution.cpp +++ b/src/cpu/ref_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2021 Intel Corporation +* Copyright 2016-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -35,6 +35,8 @@ status_t ref_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS); auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS); + const auto rnd_seed + = CTX_IN_MEM(const uint32_t *, DNNL_ARG_ATTR_ROUNDING_SEED); auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status); CHECK(status); @@ -73,6 +75,7 @@ status_t ref_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { const auto padL = pd()->padL(); const auto ndims = pd()->desc()->src_desc.ndims; + const auto dst_rnd_mode = pd()->attr()->rounding_mode_.get(DNNL_ARG_DST); auto ker = [=](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) { float d = 0; @@ -211,7 +214,11 @@ status_t ref_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { args.ctx = &ctx; args.l_offset = dst_l_off; args.dst_md = pd()->dst_md(); - ref_post_ops->execute(d, args); + ref_post_ops->execute(d, args, g*OC + oc); + if (dst_rnd_mode == rounding_mode::stochastic) + d = math::stochastic_round_fwd( + d, dst_off, rnd_seed[0], dst_d.data_type()); + io::store_float_value(dst_d.data_type(), d, dst, dst_off); }); @@ -387,6 +394,8 @@ status_t ref_convolution_bwd_data_t::execute_backward_data( return ds; }; + const auto &p = pd()->attr()->post_ops_; + parallel_nd(G, MB, IC, ID, IH, IW, [&](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, dim_t iw) { float ds = 0; @@ -396,6 +405,19 @@ status_t ref_convolution_bwd_data_t::execute_backward_data( else ds += ker(g, mb, ic, id, ih, iw); + int depthwise_inj_idx = 0; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_depthwise()) { + auto depthwise_base = CTX_IN_MEM(const float *, (DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1)); + auto depthwise_weights = depthwise_base + post_op.depthwise.offset[post_op.depthwise.scales]; + auto depthwise_bias = depthwise_base + post_op.depthwise.offset[post_op.depthwise.shifts]; + + ds = depthwise_injectors[depthwise_inj_idx]->compute_scalar(ds, depthwise_weights + g * IC + ic, depthwise_bias + g * IC + ic); + depthwise_inj_idx++; + } + } + const auto diff_src_off = ref_conv_utils::get_data_off( diff_src_d, ndims, mb, g * IC + ic, id, ih, iw); io::store_float_value( diff --git a/src/cpu/ref_convolution.hpp b/src/cpu/ref_convolution.hpp index dcbb7c909c0..d5ef320b6ea 100644 --- a/src/cpu/ref_convolution.hpp +++ b/src/cpu/ref_convolution.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2023 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -27,6 +27,8 @@ #include "cpu/cpu_convolution_pd.hpp" #include "cpu/primitive_attr_postops.hpp" +#include "ref_depthwise_injector.hpp" + namespace dnnl { namespace impl { namespace cpu { @@ -45,23 +47,42 @@ struct ref_convolution_fwd_t : public primitive_t { const auto bia_type = weights_md(1)->data_type; const auto dst_type = dst_md(0)->data_type; - bool ok = is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && platform::has_data_type_support(src_type) - && platform::has_data_type_support(bia_type) - && platform::has_data_type_support(dst_type) - && utils::one_of(src_type, f32, bf16, f16, f8_e5m2, f8_e4m3) - && src_type == wei_type - && utils::one_of(dst_type, src_type, f32) - && utils::one_of(bia_type, data_type::undef, src_type, f32) - && set_default_formats() - && attr()->has_default_values( - smask_t::post_ops | smask_t::sum_dt, dst_type) - && attr()->post_ops_.check_sum_consistency( - dst_type, /* is_int8 */ false) - && post_ops_ok() - && attr_.set_default_formats(dst_md(0)) == status::success; - return ok ? status::success : status::unimplemented; + VDISPATCH_CONV(is_fwd(), VERBOSE_BAD_PROPKIND); + VDISPATCH_CONV(set_default_alg_kind(alg_kind::convolution_direct), + VERBOSE_BAD_ALGORITHM); + VDISPATCH_CONV(platform::has_data_type_support(src_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(platform::has_data_type_support(bia_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(platform::has_data_type_support(dst_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV( + utils::one_of(src_type, f32, bf16, f16, f8_e5m2, f8_e4m3), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(IMPLICATION(src_type != wei_type, + utils::one_of(wei_type, f16, bf16) + && src_type == f32), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(utils::one_of(dst_type, src_type, f32), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV( + utils::one_of(bia_type, data_type::undef, src_type, f32), + VERBOSE_UNSUPPORTED_BIAS_CFG); + VDISPATCH_CONV(set_default_formats(), VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_CONV( + attr()->has_default_values(smask_t::post_ops + | smask_t::sum_dt | smask_t::rounding_mode, + dst_type), + VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_CONV(attr()->post_ops_.check_sum_consistency( + dst_type, /* is_int8 */ false), + VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_CONV(post_ops_ok(), VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_CONV( + attr_.set_default_formats(dst_md(0)) == status::success, + VERBOSE_UNSUPPORTED_POSTOP); + + return status::success; } protected: @@ -111,16 +132,31 @@ struct ref_convolution_bwd_data_t : public primitive_t { const auto wei_type = weights_md(0)->data_type; const auto diff_dst_type = diff_dst_md(0)->data_type; - bool ok = desc()->prop_kind == prop_kind::backward_data - && set_default_alg_kind(alg_kind::convolution_direct) - && platform::has_data_type_support(diff_src_type) - && platform::has_data_type_support(diff_dst_type) - && utils::one_of(diff_dst_type, f32, bf16, f16) - && wei_type == diff_dst_type - && utils::one_of(diff_src_type, f32, diff_dst_type) - && set_default_formats() && attr()->has_default_values(); - - return ok ? 
status::success : status::unimplemented; + VDISPATCH_CONV(desc()->prop_kind == prop_kind::backward_data, + VERBOSE_BAD_PROPKIND); + VDISPATCH_CONV(set_default_alg_kind(alg_kind::convolution_direct), + VERBOSE_BAD_ALGORITHM); + VDISPATCH_CONV(platform::has_data_type_support(diff_src_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(platform::has_data_type_support(diff_dst_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(utils::one_of(diff_dst_type, f32, bf16, f16), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(IMPLICATION(wei_type != diff_dst_type, + utils::one_of(wei_type, f16, bf16) + && diff_dst_type == f32), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(utils::one_of(diff_src_type, f32, diff_dst_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(set_default_formats(), VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_CONV( + attr()->has_default_values(primitive_attr_t::skip_mask_t::post_ops), VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_CONV( + is_supported_post_ops(), VERBOSE_UNSUPPORTED_POSTOP); + + return status::success; } protected: @@ -132,9 +168,41 @@ struct ref_convolution_bwd_data_t : public primitive_t { : utils::pick(ndims() - 3, oiw, oihw, oidhw); return set_default_formats_common(dat_tag, wei_tag, dat_tag); } + + bool is_supported_post_ops() const { + const auto &p = this->attr()->post_ops_; + if (p.len() > 1) + return false; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise); + } + return ok; + }; + + return all_post_ops_supported(); + } }; - ref_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} + ref_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) { + const auto &post_ops = pd()->attr()->post_ops_; + + for (int i = 0; i < post_ops.len(); i++) { + auto &post_op = post_ops.entry_[i]; + if (post_op.is_depthwise()) { + depthwise_injectors.push_back(new ref_depthwise_scalar_fwd_t(post_op.depthwise.alg)); + } + } + } + + ~ref_convolution_bwd_data_t() override { + for (auto inj : depthwise_injectors) + delete inj; + depthwise_injectors.clear(); + } status_t execute(const exec_ctx_t &ctx) const override { return execute_backward_data(ctx); @@ -143,6 +211,8 @@ struct ref_convolution_bwd_data_t : public primitive_t { private: status_t execute_backward_data(const exec_ctx_t &ctx) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + + nstl::vector<ref_depthwise_scalar_fwd_t *> depthwise_injectors; }; struct ref_convolution_bwd_weights_t : public primitive_t { @@ -159,17 +229,27 @@ struct ref_convolution_bwd_weights_t : public primitive_t { const auto diff_bia_type = diff_weights_md(1)->data_type; const auto diff_dst_type = diff_dst_md(0)->data_type; - bool ok = desc()->prop_kind == prop_kind::backward_weights - && set_default_alg_kind(alg_kind::convolution_direct) - && platform::has_data_type_support(src_type) - && platform::has_data_type_support(diff_wei_type) - && utils::one_of(src_type, f32, bf16, f16) - && diff_dst_type == src_type - && utils::one_of(diff_wei_type, f32, src_type) - && utils::one_of( - diff_bia_type, data_type::undef, f32, src_type) - && set_default_formats() && attr()->has_default_values(); - return ok ?
status::success : status::unimplemented; + VDISPATCH_CONV(desc()->prop_kind == prop_kind::backward_weights, + VERBOSE_BAD_PROPKIND); + VDISPATCH_CONV(set_default_alg_kind(alg_kind::convolution_direct), + VERBOSE_BAD_ALGORITHM); + VDISPATCH_CONV(platform::has_data_type_support(src_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(platform::has_data_type_support(diff_wei_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(utils::one_of(src_type, f32, bf16, f16), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(diff_dst_type == src_type, VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(utils::one_of(diff_wei_type, f32, src_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(utils::one_of(diff_bia_type, data_type::undef, f32, + src_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(set_default_formats(), VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_CONV( + attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR); + + return status::success; } protected: diff --git a/src/cpu/ref_convolution_int8.cpp b/src/cpu/ref_convolution_int8.cpp index b1c99eb8cda..f2c36332888 100644 --- a/src/cpu/ref_convolution_int8.cpp +++ b/src/cpu/ref_convolution_int8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2023 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ namespace { void dequantize(float &d, dim_t g, dim_t C, dim_t c, const float *wei_scales, bool with_groups, int wei_mask, const float *src_scales) { // scale_idx_mult = 1 for per_channel scales and 0, otherwise - const int wei_scale_idx_mult = wei_mask != 0; + const int wei_scale_idx_mult = wei_mask > 0; float scale = 1.0f; if (src_scales) scale *= src_scales[0]; if (wei_scales) scale *= wei_scales[(g * C + c) * wei_scale_idx_mult]; @@ -63,8 +63,7 @@ status_t ref_convolution_int8_fwd_t::execute_forward( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); @@ -107,9 +106,9 @@ status_t ref_convolution_int8_fwd_t::execute_forward( // zp_idx_mult = 1 for per_dim1 zero points and 0, otherwise const int src_zp_idx_mult - = !pd()->attr()->zero_points_.common(DNNL_ARG_SRC); + = pd()->attr()->zero_points_.get_mask(DNNL_ARG_SRC) > 0; const int dst_zp_idx_mult - = !pd()->attr()->zero_points_.common(DNNL_ARG_DST); + = pd()->attr()->zero_points_.get_mask(DNNL_ARG_DST) > 0; auto ker = [=](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) { int d = 0; @@ -290,8 +289,7 @@ status_t ref_convolution_int8_bwd_data_t::execute_backward_data( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(diff_dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); diff --git a/src/cpu/ref_convolution_int8.hpp b/src/cpu/ref_convolution_int8.hpp index 86a2b6a1554..b6b650c7ac5 100644 --- a/src/cpu/ref_convolution_int8.hpp +++ b/src/cpu/ref_convolution_int8.hpp @@ -1,5 +1,5 @@ 
/******************************************************************************* -* Copyright 2021-2023 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,22 +45,35 @@ struct ref_convolution_int8_fwd_t : public primitive_t { const auto bia_type = weights_md(1)->data_type; const auto dst_type = dst_md(0)->data_type; - bool ok = is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && utils::one_of(src_type, s8, u8) && wei_type == s8 - && IMPLICATION(with_bias(), - utils::one_of(bia_type, f32, bf16, s32, s8, u8)) - && utils::one_of(dst_type, f32, bf16, s32, s8, u8) - && set_default_formats() - && attr()->has_default_values(smask_t::scales_runtime - | smask_t::zero_points_runtime - | smask_t::post_ops | smask_t::sum_dt, - dst_type) - && attr()->post_ops_.check_sum_consistency(dst_type, - /* is_int8 */ true) - && attr_scales_ok() && zero_points_ok() && post_ops_ok() - && attr_.set_default_formats(dst_md(0)) == status::success; - return ok ? status::success : status::unimplemented; + VDISPATCH_CONV(is_fwd(), VERBOSE_BAD_PROPKIND); + VDISPATCH_CONV(set_default_alg_kind(alg_kind::convolution_direct), + VERBOSE_BAD_ALGORITHM); + VDISPATCH_CONV( + utils::one_of(src_type, s8, u8), VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(wei_type == s8, VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV( + IMPLICATION(with_bias(), + utils::one_of(bia_type, f32, bf16, s32, s8, u8)), + VERBOSE_UNSUPPORTED_BIAS_CFG); + VDISPATCH_CONV(utils::one_of(dst_type, f32, bf16, s32, s8, u8), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(set_default_formats(), VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_CONV( + attr()->has_default_values(smask_t::scales + | smask_t::zero_points | smask_t::post_ops + | smask_t::sum_dt, + dst_type), + VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_CONV(attr()->post_ops_.check_sum_consistency(dst_type, + /* is_int8 */ true), + VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_CONV(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); + VDISPATCH_CONV(zero_points_ok(), VERBOSE_UNSUPPORTED_ZP_CFG); + VDISPATCH_CONV(post_ops_ok(), VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_CONV( + attr_.set_default_formats(dst_md(0)) == status::success, + VERBOSE_UNSUPPORTED_POSTOP); + return status::success; } protected: @@ -74,13 +87,18 @@ struct ref_convolution_int8_fwd_t : public primitive_t { } bool zero_points_ok() const { - int mask_src = 0, mask_dst = 0; - attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src); - attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst); - - return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS) - && (mask_src == 0 || mask_src == 1 << 1) - && (mask_dst == 0 || mask_dst == 1 << 1); + if (!attr()->zero_points_.has_default_values(DNNL_ARG_SRC)) { + int mask_src = attr()->zero_points_.get_mask(DNNL_ARG_SRC); + const bool ok = mask_src == 0 || mask_src == (1 << 1); + if (!ok) return false; + } + if (!attr()->zero_points_.has_default_values(DNNL_ARG_DST)) { + int mask_dst = attr()->zero_points_.get_mask(DNNL_ARG_DST); + const bool ok = mask_dst == 0 || mask_dst == (1 << 1); + if (!ok) return false; + } + + return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS); } bool post_ops_ok() const { @@ -120,16 +138,22 @@ struct ref_convolution_int8_bwd_data_t : public primitive_t { const auto wei_type = weights_md(0)->data_type; const auto diff_dst_type = diff_dst_md(0)->data_type; - bool ok = desc()->prop_kind == prop_kind::backward_data - && 
set_default_alg_kind(alg_kind::convolution_direct) - && utils::one_of(diff_dst_type, s8, u8) && wei_type == s8 - && utils::one_of(diff_src_type, f32, bf16, s32, s8, u8) - && set_default_formats() - && attr()->has_default_values( - primitive_attr_t::skip_mask_t::scales_runtime) - && attr_scales_ok(); - - return ok ? status::success : status::unimplemented; + VDISPATCH_CONV(desc()->prop_kind == prop_kind::backward_data, + VERBOSE_BAD_PROPKIND); + VDISPATCH_CONV(set_default_alg_kind(alg_kind::convolution_direct), + VERBOSE_BAD_ALGORITHM); + VDISPATCH_CONV(utils::one_of(diff_dst_type, s8, u8), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(wei_type == s8, VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(utils::one_of(diff_src_type, f32, bf16, s32, s8, u8), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CONV(set_default_formats(), VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_CONV(attr()->has_default_values( + primitive_attr_t::skip_mask_t::scales), + VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_CONV(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); + + return status::success; } protected: diff --git a/src/cpu/ref_deconvolution.cpp b/src/cpu/ref_deconvolution.cpp index facacbd2ffd..f14126764e0 100644 --- a/src/cpu/ref_deconvolution.cpp +++ b/src/cpu/ref_deconvolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2023 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -174,8 +174,7 @@ status_t ref_deconvolution_fwd_t::compute_oscale( DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const memory_desc_wrapper dst_d(pd()->dst_md()); @@ -190,7 +189,7 @@ status_t ref_deconvolution_fwd_t::compute_oscale( const auto maybe_oscale = [](float &d, dim_t oc, const float *src_scales, const float *wei_scales, int wei_mask) { // scale_idx_mult = 1 for per_oc scales and 0, otherwise - const int wei_scale_idx_mult = wei_mask != 0; + const int wei_scale_idx_mult = wei_mask > 0; d *= src_scales[0] * wei_scales[oc * wei_scale_idx_mult]; }; @@ -216,11 +215,14 @@ status_t ref_deconvolution_fwd_t::compute_ref_attrs(const exec_ctx_t &ctx, auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int dst_scale_mask = pd()->attr()->scales_.get(DNNL_ARG_DST).mask_; + const bool has_dst_scales + = !pd()->attr()->scales_.has_default_values(DNNL_ARG_DST); + const int dst_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_DST); DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); - const bool is_dst_zp_common - = pd()->attr()->zero_points_.common(DNNL_ARG_DST); + const bool has_dst_zp + = !pd()->attr()->zero_points_.has_default_values(DNNL_ARG_DST); + const int dst_zp_mask = pd()->attr()->zero_points_.get_mask(DNNL_ARG_DST); const memory_desc_wrapper dst_d(pd()->dst_md()); @@ -232,20 +234,6 @@ status_t ref_deconvolution_fwd_t::compute_ref_attrs(const exec_ctx_t &ctx, const auto OCP = dst_d.padded_dims()[1]; const auto ndims = pd()->desc()->src_desc.ndims; - const auto maybe_dst_zero_point = [=](float &result, dim_t oc) { - if (is_dst_zp_common) - result += dst_zero_point[0]; - else - result += dst_zero_point[oc]; - }; - - const auto maybe_scale - = [](float &d, dim_t oc, 
const float *scales, int mask) { - // scale_idx_mult = 1 for per_oc scales and 0, otherwise - const int scale_idx_mult = mask != 0; - d *= scales[oc * scale_idx_mult]; - }; - const auto sum_dt = pd()->attr()->post_ops_.get_sum_dt(dst_d.data_type()); parallel_nd(MB, OCP, OD, OH, OW, @@ -269,8 +257,13 @@ status_t ref_deconvolution_fwd_t::compute_ref_attrs(const exec_ctx_t &ctx, args.l_offset = dst_l_off; args.dst_md = pd()->dst_md(); ref_post_ops->execute(tmp_result, args); - maybe_scale(tmp_result, ocp, dst_scales, dst_scale_mask); - maybe_dst_zero_point(tmp_result, ocp); + if (has_dst_scales) { + // scale_idx_mult = 1 for per_oc scales and 0, otherwise + tmp_result *= dst_scales[ocp * (dst_scale_mask > 0)]; + } + if (has_dst_zp) { + tmp_result += dst_zero_point[ocp * (dst_zp_mask > 0)]; + } } io::store_float_value( dst_d.data_type(), tmp_result, dst, dst_off); @@ -300,7 +293,7 @@ dim_t get_weights_off(const memory_desc_wrapper &wei_d, bool with_groups, template <data_type_t wei_type> static void compute_src_zp_compensation(const exec_ctx_t &ctx, const int32_t *src_zero_point, const bool is_src_zp_common, - typename prec_traits<wei_type>::type *wei, + typename prec_traits_t<wei_type>::type *wei, const cpu_deconvolution_fwd_pd_t *pd) { using namespace memory_tracking::names; @@ -347,7 +340,8 @@ template <data_type_t wei_type> static std::function prepare_zp_pad_comp_ker(const dim_t ndims, const int32_t *src_zero_point, - const bool is_src_zp_common, typename prec_traits<wei_type>::type *wei, + const bool is_src_zp_common, + typename prec_traits_t<wei_type>::type *wei, const cpu_deconvolution_fwd_pd_t *deconv_pd) { const auto KH = deconv_pd->KH(); @@ -423,7 +417,7 @@ prepare_zp_pad_comp_ker(const dim_t ndims, const int32_t *src_zero_point, template <data_type_t wei_type> static status_t apply_src_zero_point(const exec_ctx_t &ctx, const cpu_deconvolution_fwd_pd_t *deconv_pd, float *conv_output) { - using wei_data_t = typename prec_traits<wei_type>::type; + using wei_data_t = typename prec_traits_t<wei_type>::type; using namespace memory_tracking::names; using namespace data_type; @@ -432,7 +426,7 @@ static status_t apply_src_zero_point(const exec_ctx_t &ctx, const auto wei = CTX_OUT_MEM(wei_data_t *, DNNL_ARG_WEIGHTS); DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); const bool is_src_zp_common - = deconv_pd->attr()->zero_points_.common(DNNL_ARG_SRC); + = deconv_pd->attr()->zero_points_.get_mask(DNNL_ARG_SRC) == 0; const auto scratchpad = ctx.get_scratchpad_grantor(); const int32_t *const zp_src_compensation @@ -487,9 +481,11 @@ status_t ref_deconvolution_fwd_t::execute(const exec_ctx_t &ctx) const { // Create intermediate memory for f32 output if needed. auto dst = args.at(DNNL_ARG_DST); - memory_t tmp_memory(dst.mem->engine(), pd()->conv_pd_->diff_src_md(), - scratchpad.get_memory_storage(key_deconv_bias)); - memory_arg_t tmp_conv_output = {&tmp_memory, false}; + std::unique_ptr<memory_t> tmp_memory; + CHECK(safe_ptr_assign(tmp_memory, + new memory_t(dst.mem->engine(), pd()->conv_pd_->diff_src_md(), + scratchpad.get_memory_storage(key_deconv_bias)))); + memory_arg_t tmp_conv_output = {tmp_memory.get(), false}; conv_args[DNNL_ARG_DIFF_SRC] = ref_bias || non_default_attr ?
tmp_conv_output : dst; @@ -534,11 +530,10 @@ status_t ref_deconvolution_fwd_t::execute(const exec_ctx_t &ctx) const { float *conv_output = scratchpad.get<float>(key_deconv_bias); - const auto &arg_scales = pd()->attr()->scales_; - const auto &src_scales = arg_scales.get(DNNL_ARG_SRC); - const auto &wei_scales = arg_scales.get(DNNL_ARG_WEIGHTS); + const auto &scales = pd()->attr()->scales_; - if (!src_scales.has_default_values() || !wei_scales.has_default_values()) { + if (!scales.has_default_values(DNNL_ARG_SRC) + || !scales.has_default_values(DNNL_ARG_WEIGHTS)) { compute_oscale(ctx, conv_output); } @@ -599,8 +594,8 @@ void ref_deconvolution_bwd_weights_t::compute_bwd_bias( template <data_type_t dbia_type, data_type_t ddst_type> void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw( - typename prec_traits<dbia_type>::type *diff_bias, - const typename prec_traits<ddst_type>::type *diff_dst) const { + typename prec_traits_t<dbia_type>::type *diff_bias, + const typename prec_traits_t<ddst_type>::type *diff_dst) const { const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const auto OC = pd()->OC(); @@ -622,8 +617,8 @@ void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw( template <data_type_t dbia_type, data_type_t ddst_type> void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ndhwc( - typename prec_traits<dbia_type>::type *diff_bias, - const typename prec_traits<ddst_type>::type *diff_dst) const { + typename prec_traits_t<dbia_type>::type *diff_bias, + const typename prec_traits_t<ddst_type>::type *diff_dst) const { const auto MB = pd()->MB(); const auto SP = pd()->OW() * pd()->OH() * pd()->OD(); const auto OC = pd()->OC(); @@ -637,14 +632,15 @@ void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ndhwc( db += diff_dst[offset]; } } - diff_bias[oc] = static_cast<typename prec_traits<dbia_type>::type>(db); + diff_bias[oc] + = static_cast<typename prec_traits_t<dbia_type>::type>(db); }); } template <data_type_t dbia_type, data_type_t ddst_type> void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc( - typename prec_traits<dbia_type>::type *diff_bias, - const typename prec_traits<ddst_type>::type *diff_dst) const { + typename prec_traits_t<dbia_type>::type *diff_bias, + const typename prec_traits_t<ddst_type>::type *diff_dst) const { const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const auto OC = pd()->OC(); @@ -677,8 +673,8 @@ void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc( template <data_type_t dbia_type, data_type_t ddst_type> void ref_deconvolution_bwd_weights_t::compute_bias( const exec_ctx_t &ctx) const { - using dbia_data_t = typename prec_traits<dbia_type>::type; - using ddst_data_t = typename prec_traits<ddst_type>::type; + using dbia_data_t = typename prec_traits_t<dbia_type>::type; + using ddst_data_t = typename prec_traits_t<ddst_type>::type; auto diff_bias = CTX_OUT_MEM(dbia_data_t *, DNNL_ARG_DIFF_BIAS); auto diff_dst = CTX_IN_MEM(const ddst_data_t *, DNNL_ARG_DIFF_DST); diff --git a/src/cpu/ref_deconvolution.hpp b/src/cpu/ref_deconvolution.hpp index 05f88e54470..04e81e52399 100644 --- a/src/cpu/ref_deconvolution.hpp +++ b/src/cpu/ref_deconvolution.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * Copyright 2022 Arm Ltd.
and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -93,9 +93,7 @@ static status_t conv_descr_create(const deconvolution_desc_t *dd, struct ref_deconvolution_fwd_t : public primitive_t { struct pd_t : public cpu_deconvolution_fwd_pd_t { - pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : cpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {} + using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t; pd_t(const pd_t &other) : cpu_deconvolution_fwd_pd_t(other) @@ -104,8 +102,6 @@ struct ref_deconvolution_fwd_t : public primitive_t { , dst_tag_(other.dst_tag_) , name_(other.name_) {} - ~pd_t() = default; - DECLARE_COMMON_PD_T(name_.c_str(), ref_deconvolution_fwd_t); status_t init_convolution(engine_t *engine) { @@ -167,14 +163,23 @@ struct ref_deconvolution_fwd_t : public primitive_t { using smask_t = primitive_attr_t::skip_mask_t; auto skip_mask = smask_t::post_ops | smask_t::sum_dt; if (utils::one_of(desc()->src_desc.data_type, s8, u8)) - skip_mask |= smask_t::scales_runtime - | smask_t::zero_points_runtime; + skip_mask |= smask_t::scales | smask_t::zero_points; VDISPATCH_DECONVOLUTION(is_fwd(), VERBOSE_BAD_PROPKIND); VDISPATCH_DECONVOLUTION(utils::one_of(desc()->alg_kind, alg_kind::deconvolution_direct, alg_kind::deconvolution_winograd), VERBOSE_BAD_ALGORITHM); + // This implementation will check data types requirements through + // an underlying convolution implementation, however, convolution + // might be called without bias, thus, need to check bias data type + // if it was requested. + if (with_bias()) { + const auto bia_type = invariant_wei_md(1)->data_type; + VDISPATCH_DECONVOLUTION(utils::one_of(bia_type, f32, bf16, f16, + f8_e5m2, f8_e4m3), + VERBOSE_UNSUPPORTED_DT); + } VDISPATCH_DECONVOLUTION(attr()->has_default_values(skip_mask), VERBOSE_UNSUPPORTED_ATTR); VDISPATCH_DECONVOLUTION( @@ -256,16 +261,25 @@ struct ref_deconvolution_fwd_t : public primitive_t { } bool zero_points_ok() const { + const auto &zp = attr()->zero_points_; + using namespace data_type; - int mask_src = 0, mask_dst = 0; - attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src); - attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst); - - return IMPLICATION(!utils::one_of(src_md()->data_type, s8, u8), - attr()->zero_points_.has_default_values()) - && attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS) - && (mask_src == 0 || mask_src == 1 << 1) - && (mask_dst == 0 || mask_dst == 1 << 1); + bool ok = IMPLICATION(!utils::one_of(src_md()->data_type, s8, u8), + zp.has_default_values()); + if (!ok) return false; + + if (!zp.has_default_values(DNNL_ARG_SRC)) { + int mask_src = zp.get_mask(DNNL_ARG_SRC); + ok = utils::one_of(mask_src, 0, (1 << 1)); + if (!ok) return false; + } + if (!zp.has_default_values(DNNL_ARG_DST)) { + int mask_dst = zp.get_mask(DNNL_ARG_DST); + ok = utils::one_of(mask_dst, 0, (1 << 1)); + if (!ok) return false; + } + + return zp.has_default_values(DNNL_ARG_WEIGHTS); } }; @@ -312,17 +326,13 @@ struct ref_deconvolution_fwd_t : public primitive_t { struct ref_deconvolution_bwd_data_t : public primitive_t { struct pd_t : public cpu_deconvolution_bwd_data_pd_t { - pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : cpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {} + using cpu_deconvolution_bwd_data_pd_t::cpu_deconvolution_bwd_data_pd_t; pd_t(const pd_t &other) : cpu_deconvolution_bwd_data_pd_t(other) , 
conv_pd_(other.conv_pd_->clone()) , name_(other.name_) {} - ~pd_t() = default; - DECLARE_COMMON_PD_T(name_.c_str(), ref_deconvolution_bwd_data_t); status_t init_convolution(engine_t *engine) { @@ -357,7 +367,9 @@ struct ref_deconvolution_bwd_data_t : public primitive_t { VERBOSE_BAD_PROPKIND); VDISPATCH_DECONVOLUTION(utils::one_of(wei_type, f32, bf16, f16), VERBOSE_UNSUPPORTED_DT); - VDISPATCH_DECONVOLUTION(ddst_type == wei_type, + VDISPATCH_DECONVOLUTION(IMPLICATION(ddst_type != wei_type, + utils::one_of(wei_type, bf16, f16) + && ddst_type == f32), VERBOSE_INCONSISTENT_DT, "diff_dst", "weights"); VDISPATCH_DECONVOLUTION(utils::one_of(dsrc_type, wei_type, f32), VERBOSE_UNSUPPORTED_DT); @@ -396,7 +408,7 @@ struct ref_deconvolution_bwd_data_t : public primitive_t { } }; - typedef typename prec_traits<data_type::f32>::type data_t; + using data_t = typename prec_traits_t<data_type::f32>::type; ref_deconvolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} return pd()->conv_pd_->create_primitive(conv_p_, engine); } -#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL +#if DNNL_USE_ACL status_t create_resource( engine_t *engine, resource_mapper_t &mapper) const override { CHECK(conv_p_->create_resource(engine, mapper)); @@ -421,9 +433,8 @@ struct ref_deconvolution_bwd_data_t : public primitive_t { struct ref_deconvolution_bwd_weights_t : public primitive_t { struct pd_t : public cpu_deconvolution_bwd_weights_pd_t { - pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : cpu_deconvolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {} + using cpu_deconvolution_bwd_weights_pd_t:: + cpu_deconvolution_bwd_weights_pd_t; pd_t(const pd_t &other) : cpu_deconvolution_bwd_weights_pd_t(other) , dst_tag_(other.dst_tag_) , name_(other.name_) {} - ~pd_t() = default; - DECLARE_COMMON_PD_T(name_.c_str(), ref_deconvolution_bwd_weights_t); status_t init_convolution(engine_t *engine) { @@ -469,18 +478,31 @@ struct ref_deconvolution_bwd_weights_t : public primitive_t { status_t init(engine_t *engine) { using namespace format_tag; using namespace data_type; - auto src_type = desc()->src_desc.data_type; - auto dwei_type = desc()->diff_weights_desc.data_type; - auto ddst_type = desc()->diff_dst_desc.data_type; + auto src_type = invariant_src_md()->data_type; + auto wei_type = invariant_wei_md(0)->data_type; + auto dst_type = invariant_dst_md()->data_type; VDISPATCH_DECONVOLUTION( desc()->prop_kind == prop_kind::backward_weights, VERBOSE_BAD_PROPKIND); VDISPATCH_DECONVOLUTION(utils::one_of(src_type, f32, bf16, f16), VERBOSE_UNSUPPORTED_DT); - VDISPATCH_DECONVOLUTION(ddst_type == src_type, + VDISPATCH_DECONVOLUTION(dst_type == src_type, VERBOSE_INCONSISTENT_DT, "diff_dst", "src"); - VDISPATCH_DECONVOLUTION(utils::one_of(dwei_type, src_type, f32), + VDISPATCH_DECONVOLUTION(utils::one_of(wei_type, src_type, f32), VERBOSE_UNSUPPORTED_DT); + // This implementation will check data types requirements through + // an underlying convolution implementation, however, convolution + // might be called without bias, thus, need to check bias data type + // if it was requested.
+ if (with_bias()) { + const auto bia_type = invariant_wei_md(1)->data_type; + VDISPATCH_DECONVOLUTION(utils::one_of(bia_type, f32, bf16, f16) + && (bia_type == dst_type + || (bia_type == f32 + && utils::one_of( + dst_type, bf16, f16))), + VERBOSE_UNSUPPORTED_DT); + } VDISPATCH_DECONVOLUTION(utils::one_of(desc()->alg_kind, alg_kind::deconvolution_direct, alg_kind::deconvolution_winograd), @@ -539,18 +561,18 @@ struct ref_deconvolution_bwd_weights_t : public primitive_t { template <data_type_t dbia_type, data_type_t ddst_type> void compute_bwd_bias_ncdhw( - typename prec_traits<dbia_type>::type *diff_bias, - const typename prec_traits<ddst_type>::type *diff_dst) const; + typename prec_traits_t<dbia_type>::type *diff_bias, + const typename prec_traits_t<ddst_type>::type *diff_dst) const; template <data_type_t dbia_type, data_type_t ddst_type> void compute_bwd_bias_ndhwc( - typename prec_traits<dbia_type>::type *diff_bias, - const typename prec_traits<ddst_type>::type *diff_dst) const; + typename prec_traits_t<dbia_type>::type *diff_bias, + const typename prec_traits_t<ddst_type>::type *diff_dst) const; template <data_type_t dbia_type, data_type_t ddst_type> void compute_bwd_bias_nCdhwXc( - typename prec_traits<dbia_type>::type *diff_bias, - const typename prec_traits<ddst_type>::type *diff_dst) const; + typename prec_traits_t<dbia_type>::type *diff_bias, + const typename prec_traits_t<ddst_type>::type *diff_dst) const; template <data_type_t dbia_type, data_type_t ddst_type> void compute_bias(const exec_ctx_t &ctx) const; diff --git a/src/cpu/ref_depthwise_injector.cpp b/src/cpu/ref_depthwise_injector.cpp new file mode 100644 index 00000000000..585b661324a --- /dev/null +++ b/src/cpu/ref_depthwise_injector.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "ref_depthwise_injector.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { + +using namespace alg_kind; + +template <typename T> inline T scale_shift_fwd(T s_val, T w_val, T b_val) { + return s_val*w_val + b_val; +} + +template <typename T> inline T prelu_fwd(T s_val, T w_val) { + return s_val >= 0 ? s_val : s_val*w_val; +} + +union float_raw { + float f; + unsigned short i[2]; +}; + +static float bf16tof32(bfloat16_t bf16) { + union float_raw t = { 0 }; + t.i[1] = bf16; + t.i[0] = 0; + return t.f; +} + +static bfloat16_t f32tobf16(float f32) { + union float_raw t = { 0 }; + t.f = f32; + return t.i[1]; +} + +inline bfloat16_t bf16_scale_shift_fwd(bfloat16_t s_val, bfloat16_t w_val, bfloat16_t b_val) { + return f32tobf16(bf16tof32(s_val) * bf16tof32(w_val) + bf16tof32(b_val)); +} + +inline bfloat16_t bf16_prelu_fwd(bfloat16_t s_val, bfloat16_t w_val) { + return s_val >= 0 ?
s_val : f32tobf16(bf16tof32(s_val) * bf16tof32(w_val)); +} + +ref_depthwise_scalar_fwd_t::ref_depthwise_scalar_fwd_t(const alg_kind_t alg_) + : alg(alg_) { + using namespace alg_kind; + + assert(utils::one_of(alg, depthwise_scale_shift, depthwise_prelu)); +} + +float ref_depthwise_scalar_fwd_t::compute_scalar(float s, const float* weights, const float* bias) const { + switch (alg) { + case depthwise_scale_shift: return scale_shift_fwd(s, *weights, *bias); + case depthwise_prelu: return prelu_fwd(s, *weights); + default: assert(!"unknown depthwise alg_kind"); + } + + return 0.0f; +} + +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/ref_depthwise_injector.hpp b/src/cpu/ref_depthwise_injector.hpp new file mode 100644 index 00000000000..1a56e28cdc2 --- /dev/null +++ b/src/cpu/ref_depthwise_injector.hpp @@ -0,0 +1,40 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_DEPTHWISE_INJECTOR_HPP +#define REF_DEPTHWISE_INJECTOR_HPP + +#include "common/primitive.hpp" +#include "common/primitive_attr.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { + +struct ref_depthwise_scalar_fwd_t { +public: + explicit ref_depthwise_scalar_fwd_t(alg_kind_t alg); + float compute_scalar(float s, const float* weights, const float* bias) const; + +private: + alg_kind_t alg; +}; + +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/ref_eltwise.hpp b/src/cpu/ref_eltwise.hpp index df0724e40b9..2adaa11c32c 100644 --- a/src/cpu/ref_eltwise.hpp +++ b/src/cpu/ref_eltwise.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -96,7 +96,7 @@ struct ref_eltwise_fwd_t : public primitive_t { return status::success; } - using data_t = typename prec_traits<data_type>::type; + using data_t = typename prec_traits_t<data_type>::type; status_t execute(const exec_ctx_t &ctx) const override { if (pd()->use_dense_) @@ -172,7 +172,7 @@ struct ref_eltwise_bwd_t : public primitive_t { }; ref_eltwise_bwd_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits<data_type>::type data_t; + using data_t = typename prec_traits_t<data_type>::type; status_t execute(const exec_ctx_t &ctx) const override { if (pd()->use_dense_) diff --git a/src/cpu/ref_fused_convolution.hpp b/src/cpu/ref_fused_convolution.hpp index 5fa764fcf3b..c01e2d1d008 100644 --- a/src/cpu/ref_fused_convolution.hpp +++ b/src/cpu/ref_fused_convolution.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2023 Intel Corporation +* Copyright 2020-2025 Intel Corporation * Copyright 2022 Arm Ltd.
and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -74,23 +74,17 @@ struct ref_fused_convolution_fwd_t : public primitive_t { }; struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) { - name_ = "ref_fused_convolution:any"; - } - - pd_t(const pd_t &other) = default; + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; DECLARE_COMMON_PD_T(name_.c_str(), ref_fused_convolution_fwd_t); virtual status_t init(engine_t *engine) { using namespace primitive_kind; - bool ok = true && is_fwd() - && attr()->post_ops_.has_default_values( - {binary, eltwise, convolution}); - if (!ok) return status::unimplemented; + VDISPATCH_CONV(is_fwd(), VERBOSE_BAD_PROPKIND); + VDISPATCH_CONV(attr()->post_ops_.has_default_values( + {binary, eltwise, convolution}), + VERBOSE_UNSUPPORTED_ATTR); CHECK(init_ops(engine)); init_name(); @@ -99,21 +93,29 @@ struct ref_fused_convolution_fwd_t : public primitive_t { const memory_desc_t *src_md( int index = 0, bool user_input = false) const override { + if (op_pds_.empty()) + return cpu_convolution_fwd_pd_t::src_md(index, user_input); return op_pds_.front()->src_md(index, user_input); } const memory_desc_t *dst_md( int index = 0, bool user_input = false) const override { + if (op_pds_.empty()) + return cpu_convolution_fwd_pd_t::dst_md(index, user_input); return op_pds_.back()->dst_md(index, user_input); } const memory_desc_t *weights_md( int index = 0, bool user_input = false) const override { + if (op_pds_.empty()) + return cpu_convolution_fwd_pd_t::weights_md(index, user_input); return op_pds_.front()->weights_md(index, user_input); // for now } const memory_desc_t *arg_md( int arg, bool user_input = false) const override { + if (op_pds_.empty()) + return cpu_convolution_fwd_pd_t::arg_md(arg, user_input); // Binary post-op: // format_tag::any should be supported here since output dst_md // may be different from the intermediate one and they should be @@ -157,9 +159,9 @@ struct ref_fused_convolution_fwd_t : public primitive_t { if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)) return arg_usage_t::input; - if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS) - && attr_post_op_dw_inputs() > 1) - return arg_usage_t::input; + if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)) + return attr_post_op_dw_inputs() > 1 ? 
arg_usage_t::input + : arg_usage_t::unused; if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_SRC)) return arg_usage_t::input; @@ -175,7 +177,7 @@ struct ref_fused_convolution_fwd_t : public primitive_t { std::vector<arg_cache_t> args_; private: - std::string name_; + std::string name_ = "ref_fused_convolution:any"; const unsigned int max_fusions_ = 1; status_t append_op(std::shared_ptr<primitive_desc_t> &op_pd, @@ -222,10 +224,10 @@ struct ref_fused_convolution_fwd_t : public primitive_t { primitive_attr_t attr_1x1(*attr()); // erase dw_conv post-op scales for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { - auto &scale - = attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | arg); - if (!scale.has_default_values()) - attr_1x1.scales_.reset(DNNL_ARG_ATTR_POST_OP_DW | arg); + if (!attr_1x1.scales_.has_default_values( + DNNL_ARG_ATTR_POST_OP_DW | arg)) + CHECK(attr_1x1.scales_.set(DNNL_ARG_ATTR_POST_OP_DW | arg, + default_quant_entry())); } // erase post-ops after fusion as they will be handled separately auto &e = attr_1x1.post_ops_.entry_; @@ -248,7 +250,7 @@ struct ref_fused_convolution_fwd_t : public primitive_t { arg_cache.append_ctx_arg(DNNL_ARG_SRC); arg_cache.append_ctx_arg(DNNL_ARG_WEIGHTS); for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) - if (!attr_1x1.scales_.get(arg).has_default_values()) + if (!attr_1x1.scales_.has_default_values(arg)) arg_cache.append_ctx_arg(DNNL_ARG_ATTR_SCALES | arg); if (desc()->bias_desc.data_type != data_type::undef) arg_cache.append_ctx_arg(DNNL_ARG_BIAS); @@ -314,12 +316,12 @@ struct ref_fused_convolution_fwd_t : public primitive_t { arg_cache.append_ctx_arg(DNNL_ARG_WEIGHTS, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); for (auto arg : {DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) - if (!attr_dw.scales_.get(arg).has_default_values()) + if (!attr_dw.scales_.has_default_values(arg)) arg_cache.append_ctx_arg(DNNL_ARG_ATTR_SCALES | arg, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_SCALES | arg); // dw_conv src_scale = 1x1_conv dst_scale - if (!attr_1x1.scales_.get(DNNL_ARG_DST).has_default_values()) + if (!attr_1x1.scales_.has_default_values(DNNL_ARG_DST)) arg_cache.append_ctx_arg( DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); @@ -387,7 +389,7 @@ struct ref_fused_convolution_fwd_t : public primitive_t { return status::success; } -#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL +#if DNNL_USE_ACL status_t create_resource( engine_t *engine, resource_mapper_t &mapper) const override { for (auto &p : primitives_) { @@ -406,7 +408,7 @@ struct ref_fused_convolution_fwd_t : public primitive_t { const auto &ctx_args = ctx.args(); const auto op_count = primitives_.size(); - std::vector> inout_memory; + std::vector> inout_memory; for (size_t i = 0; i < op_count; ++i) { const auto &op = primitives_[i]; diff --git a/src/cpu/ref_group_normalization.hpp b/src/cpu/ref_group_normalization.hpp index 6a4e0aba676..86397164f98 100644 --- a/src/cpu/ref_group_normalization.hpp +++ b/src/cpu/ref_group_normalization.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023-2024 Intel Corporation +* Copyright 2023-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -50,9 +50,8 @@ struct ref_group_normalization_fwd_t : public primitive_t { && platform::has_data_type_support( dst_md()->data_type), VERBOSE_UNSUPPORTED_DT); - VDISPATCH_GNORM( - attr()->has_default_values(skip_mask_t::scales_runtime - | skip_mask_t::post_ops), + VDISPATCH_GNORM(attr()->has_default_values(skip_mask_t::scales + | skip_mask_t::post_ops), VERBOSE_UNSUPPORTED_ATTR); VDISPATCH_GNORM(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); VDISPATCH_GNORM(post_ops_ok(), VERBOSE_UNSUPPORTED_POSTOP); diff --git a/src/cpu/ref_inner_product.hpp b/src/cpu/ref_inner_product.hpp index 98f8df93557..042be1fb7ed 100644 --- a/src/cpu/ref_inner_product.hpp +++ b/src/cpu/ref_inner_product.hpp @@ -57,7 +57,8 @@ struct ref_inner_product_fwd_t : public primitive_t { VERBOSE_UNSUPPORTED_DT); VDISPATCH_INNER_PRODUCT(platform::has_data_type_support(dst_type), VERBOSE_UNSUPPORTED_DT); - VDISPATCH_INNER_PRODUCT(utils::one_of(src_type, f32, bf16, f16), + VDISPATCH_INNER_PRODUCT( + utils::one_of(src_type, f32, bf16, f16, f8_e5m2, f8_e4m3), VERBOSE_UNSUPPORTED_DT); VDISPATCH_INNER_PRODUCT(wei_type == src_type, VERBOSE_INCONSISTENT_DT, "weights", "src"); diff --git a/src/cpu/ref_inner_product_int8.cpp b/src/cpu/ref_inner_product_int8.cpp index 91198c680ab..322f39da638 100644 --- a/src/cpu/ref_inner_product_int8.cpp +++ b/src/cpu/ref_inner_product_int8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2023 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -76,13 +76,12 @@ status_t ref_inner_product_int8_fwd_t::execute_forward( DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); const auto &attr_scales = pd()->attr()->scales_; - const bool with_dst_scales - = !attr_scales.get(DNNL_ARG_DST).has_default_values(); + const bool with_dst_scales = !attr_scales.has_default_values(DNNL_ARG_DST); auto maybe_oscale = [&](float &d, dim_t oc) { // scale_idx_mult = 1 for per_oc scales and 0, otherwise const int scale_idx_mult - = attr_scales.get(DNNL_ARG_WEIGHTS).mask_ == (1 << 0); + = attr_scales.get_mask(DNNL_ARG_WEIGHTS) == (1 << 0); d *= src_scales[0] * wei_scales[oc * scale_idx_mult]; }; diff --git a/src/cpu/ref_inner_product_int8.hpp b/src/cpu/ref_inner_product_int8.hpp index f905715803f..4f16d2e368d 100644 --- a/src/cpu/ref_inner_product_int8.hpp +++ b/src/cpu/ref_inner_product_int8.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -68,7 +68,7 @@ struct ref_inner_product_int8_fwd_t : public primitive_t { set_default_params(allow_all_tags) == status::success, VERBOSE_UNSUPPORTED_TAG); VDISPATCH_INNER_PRODUCT( - attr()->has_default_values(smask_t::scales_runtime + attr()->has_default_values(smask_t::scales | smask_t::post_ops | smask_t::sum_dt), VERBOSE_UNSUPPORTED_ATTR); VDISPATCH_INNER_PRODUCT( diff --git a/src/cpu/ref_io_helper.hpp b/src/cpu/ref_io_helper.hpp index fc5ddb22998..046425bc044 100644 --- a/src/cpu/ref_io_helper.hpp +++ b/src/cpu/ref_io_helper.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ inline int load_int_value(data_type_t dt, const void *ptr, dim_t idx) { #define CASE(dt) \ case dt: \ return static_cast<int>( \ - reinterpret_cast<const typename prec_traits<dt>::type *>( \ + reinterpret_cast<const typename prec_traits_t<dt>::type *>( \ ptr)[idx]); using namespace data_type; @@ -44,17 +44,15 @@ inline int load_int_value(data_type_t dt, const void *ptr, dim_t idx) { CASE(s8); CASE(u8); case s4: { - const auto shift = idx % 2 ? int4_extract_t::high_half - : int4_extract_t::low_half; - auto val = int4_t::extract( - reinterpret_cast<const uint8_t *>(ptr)[idx / 2], shift); + const nibble2_t nibble_pair( + reinterpret_cast<const uint8_t *>(ptr)[idx / 2]); + int4_t val(nibble_pair.get(idx % 2)); return static_cast<int>(val); } case u4: { - const auto shift = idx % 2 ? int4_extract_t::high_half - : int4_extract_t::low_half; - auto val = uint4_t::extract( - reinterpret_cast<const uint8_t *>(ptr)[idx / 2], shift); + const nibble2_t nibble_pair( + reinterpret_cast<const uint8_t *>(ptr)[idx / 2]); + uint4_t val(nibble_pair.get(idx % 2)); return static_cast<int>(val); } default: assert(!"bad data_type"); @@ -64,12 +62,12 @@ inline int load_int_value(data_type_t dt, const void *ptr, dim_t idx) { return INT_MAX; } -inline float load_float_value(data_type_t dt, const void *ptr, dim_t idx) { +FORCE_INLINE float load_float_value(data_type_t dt, const void *ptr, dim_t idx) { assert(ptr); #define CASE(dt) \ case dt: \ return static_cast<float>( \ - reinterpret_cast<const typename prec_traits<dt>::type *>( \ + reinterpret_cast<const typename prec_traits_t<dt>::type *>( \ ptr)[idx]); using namespace data_type; @@ -84,17 +82,27 @@ inline float load_float_value(data_type_t dt, const void *ptr, dim_t idx) { CASE(u8); CASE(e8m0); case s4: { - const auto shift = idx % 2 ? int4_extract_t::high_half - : int4_extract_t::low_half; - auto val = int4_t::extract( - reinterpret_cast<const uint8_t *>(ptr)[idx / 2], shift); + const nibble2_t nibble_pair( + static_cast<const uint8_t *>(ptr)[idx / 2]); + int4_t val(nibble_pair.get(idx % 2)); return static_cast<float>(val); } case u4: { - const auto shift = idx % 2 ?
int4_extract_t::high_half - : int4_extract_t::low_half; - auto val = uint4_t::extract( - reinterpret_cast<const uint8_t *>(ptr)[idx / 2], shift); + const nibble2_t nibble_pair( + static_cast<const uint8_t *>(ptr)[idx / 2]); + uint4_t val(nibble_pair.get(idx % 2)); + return static_cast<float>(val); + } + case f4_e2m1: { + const nibble2_t nibble_pair + = reinterpret_cast<const nibble2_t *>(ptr)[idx / 2]; + float4_e2m1_t val(nibble_pair.get(idx % 2), true); + return static_cast<float>(val); + } + case f4_e3m0: { + const nibble2_t nibble_pair + = reinterpret_cast<const nibble2_t *>(ptr)[idx / 2]; + float4_e3m0_t val(nibble_pair.get(idx % 2), true); return static_cast<float>(val); } default: assert(!"bad data_type"); @@ -108,7 +116,7 @@ inline void store_float_value(data_type_t dt, float val, void *ptr, dim_t idx) { assert(ptr); #define CASE(dt) \ case dt: { \ - using type_ = typename prec_traits<dt>
::type; \ + using type_ = typename prec_traits_t<dt>
::type; \ *(reinterpret_cast<type_ *>(ptr) + idx) \ = cpu::q10n::saturate_and_round<type_>(val); \ } break; @@ -123,6 +131,22 @@ inline void store_float_value(data_type_t dt, float val, void *ptr, dim_t idx) { CASE(s32); CASE(s8); CASE(u8); + case f4_e2m1: { + auto dst_ = reinterpret_cast<nibble2_t *>(ptr); + nibble2_t nibble_pair = dst_[idx / 2]; + float4_e2m1_t f4_val(val); + nibble_pair.set(f4_val.raw_bits_, idx % 2); + dst_[idx / 2] = nibble_pair; + break; + } + case f4_e3m0: { + auto dst_ = reinterpret_cast<nibble2_t *>(ptr); + nibble2_t nibble_pair = dst_[idx / 2]; + float4_e3m0_t f4_val(val); + nibble_pair.set(f4_val.raw_bits_, idx % 2); + dst_[idx / 2] = nibble_pair; + break; + } default: assert(!"bad data_type"); } diff --git a/src/cpu/ref_layer_normalization.hpp b/src/cpu/ref_layer_normalization.hpp index e6865cb7546..20cf4eb0fde 100644 --- a/src/cpu/ref_layer_normalization.hpp +++ b/src/cpu/ref_layer_normalization.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -63,9 +63,8 @@ struct ref_layer_normalization_fwd_t : public primitive_t { VDISPATCH_LNORM(check_scale_shift_data_type({f32, bf16, f16}), VERBOSE_UNSUPPORTED_FEATURE, "unsupported scale or shift data type"); - VDISPATCH_LNORM( - attr()->has_default_values(skip_mask_t::scales_runtime - | skip_mask_t::post_ops), + VDISPATCH_LNORM(attr()->has_default_values(skip_mask_t::scales + | skip_mask_t::post_ops), VERBOSE_UNSUPPORTED_ATTR); VDISPATCH_LNORM(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); VDISPATCH_LNORM(post_ops_ok(), VERBOSE_UNSUPPORTED_POSTOP); diff --git a/src/cpu/ref_lrn.hpp b/src/cpu/ref_lrn.hpp index 85dfa0c9b97..6fe97419eb6 100644 --- a/src/cpu/ref_lrn.hpp +++ b/src/cpu/ref_lrn.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -69,7 +69,7 @@ struct ref_lrn_fwd_t : public primitive_t { }; ref_lrn_fwd_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits<d_type>::type data_t; + using data_t = typename prec_traits_t<d_type>::type; status_t execute(const exec_ctx_t &ctx) const override { using namespace format_tag; @@ -127,7 +127,7 @@ struct ref_lrn_bwd_t : public primitive_t { }; ref_lrn_bwd_t(const pd_t *apd) : primitive_t(apd) {} - typedef typename prec_traits<d_type>::type data_t; + using data_t = typename prec_traits_t<d_type>::type; status_t execute(const exec_ctx_t &ctx) const override { using namespace format_tag; diff --git a/src/cpu/ref_pooling.cpp b/src/cpu/ref_pooling.cpp index 00dcb566860..6c3368f0aad 100644 --- a/src/cpu/ref_pooling.cpp +++ b/src/cpu/ref_pooling.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -43,13 +43,13 @@ static inline dim_t get_offset(const memory_desc_wrapper &mdw, dim_t n, dim_t c, using namespace nstl; -template <data_type_t data_type, data_type_t acc_type> -status_t ref_pooling_fwd_t<data_type, acc_type>::execute_forward( +template <data_type_t src_type, data_type_t dst_type, data_type_t acc_type> +status_t ref_pooling_fwd_t<src_type, dst_type, acc_type>::execute_forward( const exec_ctx_t &ctx) const { status_t status = status::success; - auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); - auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status); + auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); + auto dst = CTX_OUT_CLEAN_MEM(dst_data_t *, DNNL_ARG_DST, status); CHECK(status); auto ws = CTX_OUT_CLEAN_MEM(unsigned char *, DNNL_ARG_WORKSPACE, status); CHECK(status); @@ -89,7 +89,7 @@ status_t ref_pooling_fwd_t<data_type, acc_type>::execute_forward( const auto off = get_offset(ws_d, mb, oc, od, oh, ow); if (ws_dt == data_type::u8) { assert(0 <= value - && value <= numeric_limits<typename prec_traits<data_type::u8>::type>::max()); + && value <= numeric_limits<typename prec_traits_t<data_type::u8>::type>::max()); ws[off] = value; } else @@ -167,12 +167,39 @@ status_t ref_pooling_fwd_t<data_type, acc_type>::execute_forward( * (KW - iw_start_excluded - iw_end_excluded); } d /= num_summands; + + const auto &p = pd()->attr()->post_ops_; + for (int i = 0; i < p.len(); i++) { + auto &post_op = p.entry_[i]; + if (post_op.is_quantization()) { + auto quant = post_op.quantization; + auto quantization_base = CTX_IN_MEM(const float *, (DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1)); + const auto crop_low_data = quantization_base + quant.offset[quant.crop_low]; + const auto crop_high_data = quantization_base + quant.offset[quant.crop_high]; + const auto inp_scale_data = quantization_base + quant.offset[quant.inp_scale]; + const auto inp_shift_data = quantization_base + quant.offset[quant.inp_shift]; + const auto output_scale_data = quantization_base + quant.offset[quant.output_scale]; + const auto output_shift_data = quantization_base + quant.offset[quant.output_shift]; + + float cl = crop_low_data[!quant.per_channel[quant.crop_low] ? 0 : oc]; + float ch = crop_high_data[!quant.per_channel[quant.crop_high] ? 0 : oc]; + float isc = inp_scale_data[!quant.per_channel[quant.inp_scale] ? 0 : oc]; + float ish = inp_shift_data[!quant.per_channel[quant.inp_shift] ? 0 : oc]; + float osc = output_scale_data[!quant.per_channel[quant.output_scale] ? 0 : oc]; + float osh = output_shift_data[!quant.per_channel[quant.output_shift] ? 0 : oc]; + + d = nstl::min(ch, nstl::max(cl, d)); + d = d * isc + ish; + d = roundf(d); + d = d * osc + osh; + } + } }; const bool is_max_pool = alg == alg_kind::pooling_max; float base_res - = is_max_pool ? (float)numeric_limits<data_t>::lowest() : 0.f; + = is_max_pool ? (float)numeric_limits<src_data_t>::lowest() : 0.f; using ker_t = std::function<void(float &, dim_t, dim_t, dim_t, dim_t, dim_t)>; ker_t kernel = is_max_pool ?
(ker_t)ker_max : (ker_t)ker_avg; @@ -191,7 +218,7 @@ status_t ref_pooling_fwd_t::execute_forward( args.dst_md = pd()->dst_md(); ref_post_ops->execute(res, args); - dst[data_p_off] = cpu::q10n::saturate_and_round<data_t>(res); + dst[data_p_off] = cpu::q10n::saturate_and_round<dst_data_t>(res); }); return status::success; @@ -371,14 +398,16 @@ status_t ref_pooling_bwd_t::execute(const exec_ctx_t &ctx) const { return status::success; } -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; } // namespace cpu } // namespace impl diff --git a/src/cpu/ref_pooling.hpp b/src/cpu/ref_pooling.hpp index d6e89f5b195..c11de0703b5 100644 --- a/src/cpu/ref_pooling.hpp +++ b/src/cpu/ref_pooling.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ namespace dnnl { namespace impl { namespace cpu { -template <data_type_t data_type, data_type_t acc_type = data_type> +template <data_type_t src_type, data_type_t dst_type, data_type_t acc_type = dst_type> struct ref_pooling_fwd_t : public primitive_t { struct pd_t : public cpu_pooling_fwd_pd_t { using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; @@ -43,23 +43,29 @@ struct ref_pooling_fwd_t : public primitive_t { status_t init(engine_t *engine) { using sm = primitive_attr_t::skip_mask_t; - VDISPATCH_POOLING(platform::has_data_type_support(data_type), + VDISPATCH_POOLING(platform::has_data_type_support(src_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_POOLING(platform::has_data_type_support(dst_type), VERBOSE_UNSUPPORTED_DT); VDISPATCH_POOLING(set_default_params() == status::success, VERBOSE_UNSUPPORTED_TAG); VDISPATCH_POOLING(is_fwd(), VERBOSE_BAD_PROPKIND); - VDISPATCH_POOLING(utils::everyone_is(data_type, src_md()->data_type, - dst_md()->data_type), + VDISPATCH_POOLING(utils::everyone_is(src_type, src_md()->data_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_POOLING(utils::everyone_is(dst_type, dst_md()->data_type), VERBOSE_UNSUPPORTED_DT); VDISPATCH_POOLING(desc()->accum_data_type == acc_type, VERBOSE_UNSUPPORTED_DT); VDISPATCH_POOLING(attr()->has_default_values(sm::post_ops), VERBOSE_UNSUPPORTED_ATTR); + // VDISPATCH_POOLING( + // ref_post_ops_t::primitive_kind_ok(attr()->post_ops_), + // VERBOSE_UNSUPPORTED_POSTOP); VDISPATCH_POOLING( - ref_post_ops_t::primitive_kind_ok(attr()->post_ops_), + attr_.set_default_formats(dst_md(0)) == status::success, VERBOSE_UNSUPPORTED_POSTOP); VDISPATCH_POOLING( - attr_.set_default_formats(dst_md(0)) == status::success, + is_supported_post_ops(), VERBOSE_UNSUPPORTED_POSTOP); bool is_training = desc_.prop_kind == prop_kind::forward_training; @@ -68,6 +74,24 @@ struct ref_pooling_fwd_t : public primitive_t { return status::success; } + + virtual bool is_supported_post_ops() const { + const auto &p = this->attr()->post_ops_; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { +
ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::quantization); + } + return ok; + }; + + return all_post_ops_supported() && + IMPLICATION(p.len() > 0, (desc()->alg_kind == dnnl_pooling_avg_include_padding || desc()->alg_kind == dnnl_pooling_avg_exclude_padding) && + src_type != data_type::bf16); + + } }; ref_pooling_fwd_t(const pd_t *apd) : primitive_t(apd) {} @@ -80,8 +104,9 @@ struct ref_pooling_fwd_t : public primitive_t { return status::success; } - using data_t = typename prec_traits<data_type>::type; - using acc_data_t = typename prec_traits<acc_type>::type; + using src_data_t = typename prec_traits_t<src_type>::type; + using dst_data_t = typename prec_traits_t<dst_type>::type; + using acc_data_t = typename prec_traits_t<acc_type>::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_forward(ctx); diff --git a/src/cpu/ref_reduction.hpp b/src/cpu/ref_reduction.hpp index 5017b5c57ef..e59033d3461 100644 --- a/src/cpu/ref_reduction.hpp +++ b/src/cpu/ref_reduction.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -75,9 +75,9 @@ struct ref_reduction_t : public primitive_t { return status::success; } - using src_t = typename prec_traits<src_type>::type; - using acc_t = typename prec_traits<acc_type>::type; - using dst_t = typename prec_traits<dst_type>::type; + using src_t = typename prec_traits_t<src_type>::type; + using acc_t = typename prec_traits_t<acc_type>::type; + using dst_t = typename prec_traits_t<dst_type>::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_ref(ctx); diff --git a/src/cpu/ref_resampling.cpp b/src/cpu/ref_resampling.cpp index 740ae062084..63456c620b1 100644 --- a/src/cpu/ref_resampling.cpp +++ b/src/cpu/ref_resampling.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ template <data_type_t type> load_fn_t create_load() { return [](const byte *base, dim_t offset) -> float { return static_cast<float>( - reinterpret_cast<const typename prec_traits<type>::type *>( + reinterpret_cast<const typename prec_traits_t<type>::type *>( base)[offset]); }; } @@ -55,7 +55,7 @@ load_fn_t create_load() { } template <data_type_t type> store_fn_t create_store() { - using dst_t = typename prec_traits<type>::type; + using dst_t = typename prec_traits_t<type>::type; return [](const float val, byte *base, const dim_t offset) { *reinterpret_cast<dst_t *>(base + sizeof(dst_t) * offset) = cpu::q10n::saturate_and_round<dst_t>(val); diff --git a/src/cpu/ref_resampling.hpp b/src/cpu/ref_resampling.hpp index bb0c4e63465..cc6941ca58e 100644 --- a/src/cpu/ref_resampling.hpp +++ b/src/cpu/ref_resampling.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
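The create_load()/create_store() factories above keep the resampling reference kernel type-agnostic: instead of templating the whole kernel on the element type, the primitive materializes one type-erased load/store lambda per data type, and the kernel body then computes purely in float. A condensed sketch of the same factory pattern under simplified assumptions (byte_t, load_fn_t, and make_load are illustrative names; the real code dispatches on dnnl data types via prec_traits_t and saturates on store):

#include <cstdint>
#include <functional>

using byte_t = unsigned char;
using load_fn_t = std::function<float(const byte_t *, std::ptrdiff_t)>;

// Build a loader that reinterprets the buffer as T and widens to float.
template <typename T>
load_fn_t make_load() {
    return [](const byte_t *base, std::ptrdiff_t off) -> float {
        return static_cast<float>(reinterpret_cast<const T *>(base)[off]);
    };
}

// Usage: select the loader once per primitive, then call it per element.
// load_fn_t load = make_load<std::int8_t>();
// float v = load(src_buf, 42);

The per-element std::function indirection is the price of the type erasure; it is acceptable for a reference implementation that optimized kernels bypass anyway.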
@@ -67,7 +67,8 @@ struct ref_resampling_fwd_t : public primitive_t { }; ref_resampling_fwd_t(const pd_t *apd); - ~ref_resampling_fwd_t(); + + ~ref_resampling_fwd_t() override; status_t init(engine_t *engine) override { ref_post_ops_ @@ -114,7 +115,8 @@ struct ref_resampling_bwd_t : public primitive_t { }; ref_resampling_bwd_t(const pd_t *apd); - ~ref_resampling_bwd_t(); + + ~ref_resampling_bwd_t() override; status_t execute(const exec_ctx_t &ctx) const override { execute_backward(ctx); diff --git a/src/cpu/ref_shuffle.cpp b/src/cpu/ref_shuffle.cpp index 0e7d86eeb82..81f8b58296f 100644 --- a/src/cpu/ref_shuffle.cpp +++ b/src/cpu/ref_shuffle.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2023 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ template <int data_type_size> status_t ref_shuffle_t::execute_(const exec_ctx_t &ctx) const { using namespace prop_kind; using namespace utils; - using data_t = typename typesize_traits<data_type_size>::type; + using data_t = typename typesize_traits_t<data_type_size>::type; const memory_desc_wrapper src_d( pd()->is_fwd() ? pd()->src_md() : pd()->diff_src_md()); diff --git a/src/cpu/ref_shuffle.hpp b/src/cpu/ref_shuffle.hpp index 5d2adf13407..168c7cd6170 100644 --- a/src/cpu/ref_shuffle.hpp +++ b/src/cpu/ref_shuffle.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -92,7 +92,7 @@ struct ref_shuffle_t : public primitive_t { return dnnl_success; } - ~ref_shuffle_t() { free(rev_transposed_); } + ~ref_shuffle_t() override { free(rev_transposed_); } status_t execute(const exec_ctx_t &ctx) const override { const memory_desc_wrapper src_d( diff --git a/src/cpu/ref_softmax.cpp b/src/cpu/ref_softmax.cpp index 93c93f13a02..b709ddf88ea 100644 --- a/src/cpu/ref_softmax.cpp +++ b/src/cpu/ref_softmax.cpp @@ -53,7 +53,9 @@ status_t ref_softmax_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const { const memory_desc_wrapper src_d(pd()->src_md()); const memory_desc_wrapper dst_d(pd()->dst_md()); - const auto interim_dt = data_type::f32; + const auto interim_dt = pd()->need_intermediate_scratchpad() + ? data_type::f32 + : dst_d.data_type(); const auto is_inplace = (src == dst); const auto has_padding = is_padding(dst_d); const auto zero_padding = has_padding && !is_inplace; @@ -210,7 +212,9 @@ status_t ref_softmax_fwd_t::execute_forward_generic( void *interim_ptr = pd()->need_intermediate_scratchpad() ? interim_scratchpad : dst; - const auto interim_dt = data_type::f32; + const auto interim_dt = pd()->need_intermediate_scratchpad() + ?
data_type::f32 + : dst_d.data_type(); const auto is_inplace = (src == dst); const auto has_padding = is_padding(dst_d); if (has_padding && !is_inplace) { diff --git a/src/cpu/ref_softmax.hpp b/src/cpu/ref_softmax.hpp index 8f05eb08b78..de75fee361d 100644 --- a/src/cpu/ref_softmax.hpp +++ b/src/cpu/ref_softmax.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,11 +30,6 @@ #include "cpu/cpu_softmax_pd.hpp" -#define VCHECK_SOFTMAX(cond, msg, ...) \ - VCONDCHECK(primitive, create, dispatch, softmax, (cond), \ - status::unimplemented, "%s," msg, this->info(engine), \ - ##__VA_ARGS__) - namespace dnnl { namespace impl { namespace cpu { @@ -49,26 +44,30 @@ struct ref_softmax_fwd_t : public primitive_t { using namespace data_type; using skip_mask_t = primitive_attr_t::skip_mask_t; - bool ok = is_fwd() - && utils::one_of( - src_md()->data_type, f32, bf16, f16, s8, u8) - && utils::one_of( - dst_md()->data_type, f32, bf16, f16, s8, u8) - && platform::has_data_type_support(src_md()->data_type) - && platform::has_data_type_support(dst_md()->data_type); - if (!ok) return status::unimplemented; - - VCHECK_SOFTMAX( - attr()->has_default_values(skip_mask_t::scales_runtime - | skip_mask_t::post_ops), + VDISPATCH_SOFTMAX(is_fwd(), VERBOSE_BAD_PROPKIND); + VDISPATCH_SOFTMAX( + platform::has_data_type_support(src_md()->data_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_SOFTMAX( + platform::has_data_type_support(dst_md()->data_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_SOFTMAX( + utils::one_of(src_md()->data_type, f32, bf16, f16, s8, u8), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_SOFTMAX( + utils::one_of(dst_md()->data_type, f32, bf16, f16, s8, u8), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_SOFTMAX(attr()->has_default_values(skip_mask_t::scales + | skip_mask_t::post_ops), VERBOSE_UNSUPPORTED_ATTR); - VCHECK_SOFTMAX(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); - VCHECK_SOFTMAX(post_ops_ok(), VERBOSE_UNSUPPORTED_POSTOP); -#undef VCHECK_SOFTMAX + VDISPATCH_SOFTMAX(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); + VDISPATCH_SOFTMAX(post_ops_ok(), VERBOSE_UNSUPPORTED_POSTOP); - ok = set_default_formats() == status::success - && attr_.set_default_formats(dst_md(0)) == status::success; - if (!ok) return status::unimplemented; + VDISPATCH_SOFTMAX(set_default_formats() == status::success, + VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_SOFTMAX( + attr_.set_default_formats(dst_md(0)) == status::success, + VERBOSE_UNSUPPORTED_POSTOP); nthr_ = 0; init_scratchpad(); @@ -79,9 +78,19 @@ struct ref_softmax_fwd_t : public primitive_t { int nthr_; // To not exceed the limit in execute used for set up. bool need_intermediate_scratchpad() const { - return dst_md()->data_type - != types::default_accum_data_type( - src_md()->data_type, dst_md()->data_type); + const auto src_dt = src_md()->data_type; + const auto dst_dt = dst_md()->data_type; + // Relaxed accumulation allows to downconvert intermediate results + // directly from xf16 or xf8 to dst avoiding scratchpad memory. 
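+ // For example, an f16 softmax with an f16 dst created with + // accumulation_mode::relaxed (or ::any) makes relaxed_acc true below, + // so no f32 scratchpad is requested and execute_forward_dense()/_generic() + // keep interim_dt equal to the dst data type; integral dst types are + // excluded here and always take the f32 scratchpad path.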
+ const bool relaxed_acc = src_dt == dst_dt + && !types::is_integral_dt(dst_dt) + && utils::one_of(attr()->acc_mode_, + accumulation_mode::relaxed, accumulation_mode::any); + const bool need_scratchpad = dst_md()->data_type + != types::default_accum_data_type( + src_md()->data_type, dst_md()->data_type) + && !relaxed_acc; + return need_scratchpad; } private: diff --git a/src/cpu/ref_sum.hpp b/src/cpu/ref_sum.hpp index 917308f6122..e0e2cfb8669 100644 --- a/src/cpu/ref_sum.hpp +++ b/src/cpu/ref_sum.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2024 Intel Corporation +* Copyright 2017-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ struct ref_sum_t : public primitive_t { reorder_pds_.resize(n_ + need_output_reorder()); for (int i = 0; i < n_; ++i) { primitive_attr_t r_attr; - r_attr.scales_.set(DNNL_ARG_SRC, 0); + CHECK(r_attr.scales_.set(DNNL_ARG_SRC, 0)); if (i != 0) r_attr.post_ops_.append_sum(1.0); CHECK(reorder_primitive_desc_create(reorder_pds_[i], engine, src_md(i), dst_acc_md(), &r_attr)); @@ -97,9 +97,10 @@ struct ref_sum_t : public primitive_t { scales_mem_.resize(n); for (size_t i = 0; i < n; ++i) - scales_mem_[i] = std::make_shared<memory_t>(get_service_engine(), - &scales_md, use_runtime_ptr, - const_cast<float *>(&(scales[i]))); + CHECK(safe_ptr_assign(scales_mem_[i], + new memory_t(get_service_engine(), &scales_md, + use_runtime_ptr, + const_cast<float *>(&(scales[i]))))); return status::success; } @@ -116,14 +117,17 @@ struct ref_sum_t : public primitive_t { key_sum_reduction) : nullptr; auto dst = ctx.args().at(DNNL_ARG_DST); - memory_t acc( - dst.mem->engine(), pd()->dst_acc_md(), std::move(sum_reduce)); - memory_arg_t dst_acc = {&acc, false}; + + std::unique_ptr<memory_t> acc; + CHECK(safe_ptr_assign(acc, + new memory_t(dst.mem->engine(), pd()->dst_acc_md(), + std::move(sum_reduce)))); + memory_arg_t dst_acc = {acc.get(), false}; /* fix: clang MemorySanitizer: use-of-uninitialized-value */ if (pd()->need_output_reorder()) { - const memory_desc_wrapper acc_d(acc.md()); - std::memset(acc.memory_storage()->data_handle(), 0, acc_d.size()); + const memory_desc_wrapper acc_d(acc->md()); + std::memset(acc->memory_storage()->data_handle(), 0, acc_d.size()); } for (int i = 0; i < n; ++i) { @@ -140,7 +144,7 @@ struct ref_sum_t : public primitive_t { } if (pd()->need_output_reorder()) { - dst_acc = {&acc, true}; + dst_acc.is_const = true; r_args[DNNL_ARG_SRC] = dst_acc; r_args[DNNL_ARG_DST] = dst; exec_ctx_t r_ctx(ctx, std::move(r_args)); @@ -155,7 +159,7 @@ struct ref_sum_t : public primitive_t { private: const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } std::vector<std::shared_ptr<primitive_t>> reorders_; - std::vector<std::shared_ptr<memory_t>> scales_mem_; + std::vector<std::unique_ptr<memory_t>> scales_mem_; }; } // namespace cpu diff --git a/src/cpu/reorder/cpu_reorder.cpp b/src/cpu/reorder/cpu_reorder.cpp index d9d8912d91b..1e8b0961636 100644 --- a/src/cpu/reorder/cpu_reorder.cpp +++ b/src/cpu/reorder/cpu_reorder.cpp @@ -25,6 +25,8 @@ namespace cpu { static const std::map<reorder_impl_key_t, const impl_list_map_t *> & regular_impl_list_map() { static const std::map<reorder_impl_key_t, const impl_list_map_t *> the_map = { + {{f32, f4_e2m1, 0}, &regular_fp4_impl_list_map()}, + {{f32, f4_e3m0, 0}, &regular_fp4_impl_list_map()}, {{f32, e8m0, 0}, &regular_f32_fp8_impl_list_map()}, {{f32, f8_e5m2, 0}, &regular_f32_fp8_impl_list_map()}, {{f32, f8_e4m3, 0}, &regular_f32_fp8_impl_list_map()}, @@ -34,8 +36,11 @@ regular_impl_list_map() { {{f32, s32, 0}, &regular_f32_s32_impl_list_map()},
{{f32, s8, 0}, &regular_f32_s8_impl_list_map()}, {{f32, u8, 0}, &regular_f32_u8_impl_list_map()}, + {{f4_e3m0, data_type::undef, 0}, &regular_fp4_impl_list_map()}, {{f8_e5m2, data_type::undef, 0}, &regular_fp8_impl_list_map()}, {{f8_e4m3, data_type::undef, 0}, &regular_fp8_impl_list_map()}, + {{e8m0, data_type::undef, 0}, &regular_fp8_impl_list_map()}, + {{f32, bin, 0}, &regular_f32_bin_impl_list_map()}, {{bf16, data_type::undef, 0}, &regular_bf16_impl_list_map()}, {{f16, data_type::undef, 0}, &regular_f16_impl_list_map()}, {{s32, data_type::undef, 0}, &regular_s32_impl_list_map()}, @@ -43,8 +48,13 @@ regular_impl_list_map() { {{u8, data_type::undef, 0}, &regular_u8_impl_list_map()}, {{f32, s4, 0}, &regular_s4_impl_list_map()}, {{f32, u4, 0}, &regular_u4_impl_list_map()}, - {{s4, f32, 0}, &regular_s4_impl_list_map()}, + {{s4, data_type::undef, 0}, &regular_s4_impl_list_map()}, {{u4, f32, 0}, &regular_u4_impl_list_map()}, + {{bin, data_type::undef, 0}, &regular_bin_impl_list_map()}, + {{nf4, data_type::undef, 0}, &regular_nf4_impl_list_map()}, + {{f4_e2m1, data_type::undef, 0}, &regular_f4_impl_list_map()}, + {{s4, data_type::undef, 0}, &regular_s4_impl_list_map()}, + {{u4, data_type::undef, 0}, &regular_u4_impl_list_map()}, }; return the_map; } diff --git a/src/cpu/reorder/cpu_reorder.hpp b/src/cpu/reorder/cpu_reorder.hpp index dc0105966b1..7816e0397a4 100644 --- a/src/cpu/reorder/cpu_reorder.hpp +++ b/src/cpu/reorder/cpu_reorder.hpp @@ -33,13 +33,14 @@ #if DNNL_X64 #include "cpu/x64/jit_uni_reorder.hpp" +#include "cpu/x64/jit_uni_reorder_direct_copy.hpp" #include "cpu/x64/matmul/brgemm_matmul_reorders.hpp" #elif DNNL_AARCH64 #include "cpu/aarch64/jit_uni_reorder.hpp" #include "cpu/aarch64/matmul/brgemm_matmul_reorders.hpp" #endif -#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL +#if DNNL_AARCH64 && DNNL_USE_ACL #include "cpu/aarch64/acl_reorder.hpp" #endif @@ -74,6 +75,7 @@ using impl_list_map_t = std::map<reorder_impl_key_t, std::vector<impl_list_item_t>>; /* regular reorders */ +extern const impl_list_map_t &regular_fp4_impl_list_map(); extern const impl_list_map_t &regular_f32_fp8_impl_list_map(); extern const impl_list_map_t &regular_f32_bf16_impl_list_map(); extern const impl_list_map_t &regular_f32_f16_impl_list_map(); @@ -82,6 +84,7 @@ extern const impl_list_map_t &regular_f32_s32_impl_list_map(); extern const impl_list_map_t &regular_f32_s8_impl_list_map(); extern const impl_list_map_t &regular_f32_u8_impl_list_map(); extern const impl_list_map_t &regular_fp8_impl_list_map(); +extern const impl_list_map_t &regular_f32_bin_impl_list_map(); extern const impl_list_map_t &regular_bf16_impl_list_map(); extern const impl_list_map_t &regular_f16_impl_list_map(); extern const impl_list_map_t &regular_s32_impl_list_map(); @@ -89,6 +92,11 @@ extern const impl_list_map_t &regular_s8_impl_list_map(); extern const impl_list_map_t &regular_u8_impl_list_map(); extern const impl_list_map_t &regular_s4_impl_list_map(); extern const impl_list_map_t &regular_u4_impl_list_map(); +extern const impl_list_map_t &regular_bin_impl_list_map(); +extern const impl_list_map_t &regular_nf4_impl_list_map(); +extern const impl_list_map_t &regular_f4_impl_list_map(); +extern const impl_list_map_t &regular_s4_impl_list_map(); +extern const impl_list_map_t &regular_u4_impl_list_map(); /* conv reorders w/ compensation */ extern const impl_list_map_t &comp_f32_s8_impl_list_map(); @@ -96,6 +104,8 @@ extern const impl_list_map_t &comp_bf16_s8_impl_list_map(); extern const impl_list_map_t &comp_s8_s8_impl_list_map(); // clang-format off +#define REG_SPARSE_SR(idt, ifmt, odt, ofmt, ...)
\ + CPU_REORDER_INSTANCE(simple_sparse_reorder_t, idt, ifmt, odt, ofmt, __VA_ARGS__) // Some compilers do not allow guarding implementations with macros // in the impl list. @@ -115,17 +125,44 @@ extern const impl_list_map_t &comp_s8_s8_impl_list_map(); #define REG_SPARSE_SR_X64(...) #endif -#define REG_SR(idt, ifmt, odt, ofmt, ...) \ +using spec_reference = spec::reference; +using spec_direct_copy = spec::direct_copy; +using spec_direct_copy_except_dim_0 = spec::direct_copy_except_dim_0; +using spec_conv_req_comp = spec::conv_req_comp; +constexpr bool fmt_order_keep = fmt_order::keep; +constexpr bool fmt_order_reverse = fmt_order::reverse; +constexpr bool fmt_order_any = fmt_order::any; + +#if DNNL_X64 +using x64_jit_blk_reorder_t = x64::jit_blk_reorder_t; +using x64_jit_uni_reorder_t = x64::jit_uni_reorder_t; +using x64_brgemm_matmul_copy_reorder_t = x64::brgemm_matmul_copy_reorder_t; +using x64_jit_uni_reorder_direct_copy_t = x64::jit_uni_reorder_direct_copy_t; +#elif DNNL_AARCH64 +using aarch64_jit_blk_reorder_t = aarch64::jit_blk_reorder_t; +using aarch64_jit_uni_reorder_t = aarch64::jit_uni_reorder_t; +#endif + +#define CPU_REORDER_INSTANCE_IMPL(...) \ impl_list_item_t(impl_list_item_t::reorder_type_deduction_helper_t< \ - simple_reorder_t<idt, ifmt, odt, ofmt, __VA_ARGS__>::pd_t>()), + __VA_ARGS__::pd_t>()) + +#define CPU_REORDER_INSTANCE(...) \ + DNNL_PRIMITIVE_IMPL(CPU_REORDER_INSTANCE_IMPL, __VA_ARGS__) + +#define REG_SR(idt, ifmt, odt, ofmt, ...) \ + CPU_REORDER_INSTANCE(simple_reorder_t, idt, ifmt, odt, ofmt, __VA_ARGS__) + /* impl_list_item_t(impl_list_item_t::reorder_type_deduction_helper_t< \ + simple_reorder_t<idt, ifmt, odt, ofmt, __VA_ARGS__>::pd_t>()), + */ #define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \ - REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep) \ - REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse) + REG_SR(idt, ifmt, odt, ofmt, fmt_order_keep) \ + REG_SR(idt, ifmt, odt, ofmt, fmt_order_reverse) #define REG_SR_DIRECT_COPY(idt, odt) \ - REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy) \ - REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0) + REG_SR(idt, any, odt, any, fmt_order_any, spec_direct_copy) \ + REG_SR(idt, any, odt, any, fmt_order_any, spec_direct_copy_except_dim_0) // clang-format on @@ -147,10 +184,6 @@ extern const impl_list_map_t &comp_s8_s8_impl_list_map(); #define REG_FAST_DIRECT_COPY(sdt, ddt) #endif -#define CPU_REORDER_INSTANCE(...)
\ - impl_list_item_t(impl_list_item_t::reorder_type_deduction_helper_t< \ - __VA_ARGS__::pd_t>()), - } // namespace cpu } // namespace impl } // namespace dnnl diff --git a/src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp b/src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp index d34f4fe1e48..759706c737b 100644 --- a/src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp +++ b/src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp @@ -26,170 +26,170 @@ const impl_list_map_t &comp_bf16_s8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // bf16 -> s8 {{bf16, s8, 2}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oi, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, format_tag::io, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oi, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, format_tag::io, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oi, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, format_tag::io, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp)) - REG_SR(bf16, ab, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, ab, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, ab, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, ab, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, ba, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, ba, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, ba, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, ba, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oi, s8, OI4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, format_tag::io, s8, OI4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oi, s8, OI4i32o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, format_tag::io, s8, OI4i32o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oi, s8, OI4i64o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, format_tag::io, s8, OI4i64o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a16b4a, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a32b4a, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a48b4a, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a64b4a, fmt_order_keep, spec_conv_req_comp)) + 
DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a16b4a, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a32b4a, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a48b4a, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a64b4a, fmt_order_keep, spec_conv_req_comp)) + REG_SR(bf16, ab, s8, BA16a16b4a, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, ab, s8, BA16a32b4a, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, ab, s8, BA16a48b4a, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, ab, s8, BA16a64b4a, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, ba, s8, BA16a16b4a, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, ba, s8, BA16a32b4a, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, ba, s8, BA16a48b4a, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, ba, s8, BA16a64b4a, fmt_order_keep, spec_conv_req_comp) nullptr, }}, // bf16 -> s8 {{bf16, s8, 3}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, wio, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, Owi16o, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, Owi16o, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, Owi16o, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) - REG_SR(bf16, abc, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, abc, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, abc, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp) - 
REG_SR(bf16, abc, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, acb, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, acb, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, acb, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp) - REG_SR(bf16, acb, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, wio, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4i32o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4i64o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4i32o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4i64o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4i32o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4i64o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, Owi16o, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, Owi16o, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, Owi16o, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + REG_SR(bf16, abc, s8, aCB16b16c4b, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, abc, s8, aCB16b32c4b, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, abc, s8, aCB16b48c4b, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, abc, s8, aCB16b64c4b, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, acb, s8, aCB16b16c4b, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, acb, s8, aCB16b32c4b, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, acb, s8, aCB16b48c4b, fmt_order_keep, spec_conv_req_comp) + REG_SR(bf16, acb, s8, aCB16b64c4b, fmt_order_keep, spec_conv_req_comp) nullptr, }}, {{bf16, s8, 4}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, hwio, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, wigo, fmt_order::keep, spec::conv_req_comp)) - 
DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp)) - 
DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, hwio, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, wigo, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, Goiw16g, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, Goiw16g, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, Goiw8g, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, Goiw8g, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, Goiw4g, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, Goiw4g, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOwi16o, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOwi16o, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + 
DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, Owhi16o, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, Owhi16o, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, Owhi16o, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OhwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OhwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OhwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) nullptr, }}, {{bf16, s8, 5}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, hwigo, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, dhwio, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, Goihw8g, fmt_order::keep, 
spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, hwigo, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, dhwio, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4o4i, fmt_order_keep, 
spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, Goihw16g, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, Goihw16g, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, Goihw8g, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, Goihw8g, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, Goihw4g, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, Goihw4g, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOwhi16o, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOwhi16o, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOhwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOhwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) nullptr, }}, {{bf16, s8, 6}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, dhwigo, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, dhwigo, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) nullptr, }}, }); diff --git a/src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp b/src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp index a296ac8d2f6..104868ac072 100644 --- a/src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp +++ b/src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp @@ -27,168 +27,168 @@ const impl_list_map_t &comp_f32_s8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // f32 -> s8 {{f32, s8, 2}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - 
DNNL_NON_X64_ONLY(REG_SR(f32, oi, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, format_tag::io, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oi, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, format_tag::io, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oi, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, format_tag::io, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp))
- REG_SR(f32, ab, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, ab, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, ab, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, ab, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, ba, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, ba, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, ba, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, ba, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp)
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oi, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, format_tag::io, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oi, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, format_tag::io, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oi, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, format_tag::io, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ REG_SR(f32, ab, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, ab, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, ab, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, ab, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, ba, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, ba, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, ba, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, ba, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp)
 nullptr,
 }},
 // f32 -> s8
 {{f32, s8, 3}, {
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
- DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
- DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, wio, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, Owi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, Owi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, Owi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- REG_SR(f32, abc, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, abc, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, abc, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, abc, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, acb, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, acb, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, acb, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(f32, acb, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp)
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+ DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, wio, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, Owi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, Owi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, Owi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ REG_SR(f32, abc, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, abc, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, abc, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, abc, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, acb, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, acb, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, acb, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(f32, acb, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp)
 nullptr,
 }},
 {{f32, s8, 4}, {
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
- DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
- DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, wigo, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+ DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, wigo, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
 nullptr,
 }},
 {{f32, s8, 5}, {
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
- DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
- DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, dhwio, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, dhwio, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
 nullptr,
 }},
 {{f32, s8, 6}, {
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
- DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
- DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, dhwigo, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+ DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, dhwigo, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
 nullptr,
 }},
 });
diff --git a/src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp b/src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp
index 0b8c9322e48..4cb92ea0832 100644
--- a/src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp
+++ b/src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 * Copyright 2023 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -27,170 +27,175 @@ const impl_list_map_t &comp_s8_s8_impl_list_map() {
 static const impl_list_map_t the_map = REG_REORDER_P({
 // s8 -> s8
 {{s8, s8, 2}, {
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
- DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
- DNNL_NON_X64_ONLY(REG_SR(s8, oi, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, format_tag::io, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oi, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, format_tag::io, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oi, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, format_tag::io, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp))
- REG_SR(s8, ab, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, ab, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, ab, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, ab, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, ba, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, ba, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, ba, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, ba, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp)
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oi, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, format_tag::io, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oi, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, format_tag::io, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oi, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, format_tag::io, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ REG_SR(s8, ab, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, ab, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, ab, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, ab, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, ba, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, ba, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, ba, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, ba, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp)
 nullptr,
 }},
 // s8 -> s8
 {{s8, s8, 3}, {
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
- DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
- DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, wio, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, Owi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, Owi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, Owi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- REG_SR(s8, abc, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, abc, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, abc, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, abc, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, acb, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, acb, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, acb, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp)
- REG_SR(s8, acb, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp)
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+ DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, wio, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, Owi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, Owi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, Owi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ REG_SR(s8, abc, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, abc, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, abc, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, abc, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, acb, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, acb, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, acb, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp)
+ REG_SR(s8, acb, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp)
 nullptr,
 }},
 {{s8, s8, 4}, {
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
- DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
- DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, wigo, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+ DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, wigo, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
 nullptr,
 }},
 {{s8, s8, 5}, {
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
- DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
- DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, dhwio, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+ DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, dhwio, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
 nullptr,
 }},
 {{s8, s8, 6}, {
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
- DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
- DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, dhwigo, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
- DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+ DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, dhwigo, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw4o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOdhwI16o4i, fmt_order::keep, spec::conv_req_comp))
+ DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp))
 nullptr,
 }},
 });
diff --git a/src/cpu/reorder/cpu_reorder_pd.hpp b/src/cpu/reorder/cpu_reorder_pd.hpp
index d1c8499c151..1fbac15ee32 100644
--- a/src/cpu/reorder/cpu_reorder_pd.hpp
+++ b/src/cpu/reorder/cpu_reorder_pd.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2016-2024 Intel Corporation
+* Copyright 2016-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -38,6 +38,9 @@ struct cpu_reorder_pd_t : public reorder_pd_t {
 post_ops.len() == 1
 && post_ops.entry_[0].kind == primitive_kind::sum);
 VDISPATCH_REORDER(args_ok, VERBOSE_UNSUPPORTED_POSTOP);
+ auto gpu_zp = memory_extra_flags::compensation_gpu_conv_asymmetric_src;
+ VDISPATCH_REORDER(!(dst_md()->extra.flags & gpu_zp),
+ VERBOSE_UNSUPPORTED_MD_FLAG, "extra");
 return status::success;
 }
@@ -82,15 +85,15 @@ struct cpu_reorder_pd_t : public reorder_pd_t {
 const float *dst_scales) const {
 using namespace dnnl::impl::memory_tracking::names;
- int mask = -1;
- bool is_set = false;
- auto status = attr->scales_.get(DNNL_ARG_DST, &mask, &is_set);
- if (status != status::success) return nullptr;
+ if (attr->scales_.has_default_values(DNNL_ARG_DST)) {
+ return dst_scales;
+ }
 // It's possible that mask > 0 but `count` is still `1`. This case is
 // covered by `DEFINE_ARG_SCALES_BUFFER` macro and no need to inverse
 // in such case.
- if (is_set && mask > 0 && count > 1) {
+ int mask = attr->scales_.get_mask(DNNL_ARG_DST);
+ if (mask > 0 && count > 1) {
 auto loc_scales = scratchpad.template get<float>(
 key_reorder_precomputed_dst_scales);
 if (!loc_scales) return nullptr;
diff --git a/src/cpu/reorder/cpu_reorder_regular_bf16.cpp b/src/cpu/reorder/cpu_reorder_regular_bf16.cpp
index 192d36137bb..388afbfa39c 100644
--- a/src/cpu/reorder/cpu_reorder_regular_bf16.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_bf16.cpp
@@ -26,11 +26,11 @@ const impl_list_map_t &regular_bf16_impl_list_map() {
 static const impl_list_map_t the_map = REG_REORDER_P({
 // bf16 ->
 {{bf16, data_type::undef, 0}, {
- CPU_REORDER_INSTANCE(rnn_weights_reorder_t)
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t))
-
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ CPU_REORDER_INSTANCE(rnn_weights_reorder_t, bf16, bf16)
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 DNNL_NON_X64_ONLY(REG_SR_BIDIR(bf16, any, f32, nChw16c))
 DNNL_NON_X64_ONLY(REG_SR_BIDIR(bf16, any, f32, nCdhw16c))
@@ -53,14 +53,14 @@ const impl_list_map_t &regular_bf16_impl_list_map() {
 DNNL_NON_X64_ONLY(REG_SR_BIDIR(bf16, any, u8, OIdhw16o16i))
 DNNL_NON_X64_ONLY(REG_SR_BIDIR(bf16, any, u8, OIdhw16i16o))
- DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
- REG_SR(bf16, any, bf16, any, fmt_order::any, spec::reference)
- REG_SR(bf16, any, f32, any, fmt_order::any, spec::reference)
- REG_SR(bf16, any, s8, any, fmt_order::any, spec::reference)
- REG_SR(bf16, any, u8, any, fmt_order::any, spec::reference)
- REG_SR(bf16, any, f8_e5m2, any, fmt_order::any, spec::reference)
- REG_SR(bf16, any, f8_e4m3, any, fmt_order::any, spec::reference)
+ REG_SR(bf16, any, bf16, any, fmt_order::any, spec::reference)
+ REG_SR(bf16, any, f32, any, fmt_order::any, spec::reference)
+ REG_SR(bf16, any, s8, any, fmt_order::any, spec::reference)
+ REG_SR(bf16, any, u8, any, fmt_order::any, spec::reference)
+ REG_SR(bf16, any, f8_e5m2, any, fmt_order::any, spec::reference)
+ REG_SR(bf16, any, f8_e4m3, any, fmt_order::any, spec::reference)
 nullptr,
 }},
diff --git a/src/cpu/reorder/cpu_reorder_regular_bin.cpp b/src/cpu/reorder/cpu_reorder_regular_bin.cpp
new file mode 100644
index 00000000000..3078feb1c2b
--- /dev/null
+++ b/src/cpu/reorder/cpu_reorder_regular_bin.cpp
@@ -0,0 +1,47 @@
+/*******************************************************************************
+* Copyright 2021 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "cpu/reorder/cpu_reorder.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+
+// clang-format off
+
+const impl_list_map_t &regular_bin_impl_list_map() {
+ static const impl_list_map_t the_map = REG_REORDER_P({
+ // bin ->
+ {{bin, data_type::undef, 4}, {
+ REG_SR_DIRECT_COPY(bin, bin)
+
+ REG_SR(bin, any, bin, OIhw8o32i, fmt_order::keep)
+
+ REG_SR(bin, any, bin, OIhw16o32i, fmt_order::keep)
+
+ REG_SR_BIDIR(u8, any, u8, nChw8c)
+
+ nullptr,
+ }},
+ });
+ return the_map;
+}
+
+// clang-format on
+
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/cpu/reorder/cpu_reorder_regular_f16.cpp b/src/cpu/reorder/cpu_reorder_regular_f16.cpp
index 6d3bd322fef..5d6bb97ac57 100644
--- a/src/cpu/reorder/cpu_reorder_regular_f16.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_f16.cpp
@@ -28,16 +28,19 @@ const impl_list_map_t &regular_f16_impl_list_map() {
 {{f16, data_type::undef, 0}, {
 DNNL_AARCH64_ONLY(REG_SR_DIRECT_COPY(f16, f16))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
-
- REG_SR(f16, any, f8_e5m2, any, fmt_order::any, spec::reference)
- REG_SR(f16, any, f8_e4m3, any, fmt_order::any, spec::reference)
- REG_SR(f16, any, f16, any, fmt_order::any, spec::reference)
- REG_SR(f16, any, f32, any, fmt_order::any, spec::reference)
- REG_SR(f16, any, s8, any, fmt_order::any, spec::reference)
- REG_SR(f16, any, u8, any, fmt_order::any, spec::reference)
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+
+ REG_SR(f16, any, f8_e5m2, any, fmt_order::any, spec::reference)
+ REG_SR(f16, any, f8_e4m3, any, fmt_order::any, spec::reference)
+ REG_SR(f16, any, f16, any, fmt_order::any, spec::reference)
+ REG_SR(f16, any, f32, any, fmt_order::any, spec::reference)
+ REG_SR(f16, any, s8, any, fmt_order::any, spec::reference)
+ REG_SR(f16, any, u8, any, fmt_order::any, spec::reference)
 nullptr,
 }},
diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_bf16.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_bf16.cpp
index 213f44723f7..9b6d5cd4f2d 100644
--- a/src/cpu/reorder/cpu_reorder_regular_f32_bf16.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_f32_bf16.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2022 Intel Corporation
+* Copyright 2020-2024 Intel Corporation
 * Copyright 2023 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -27,15 +27,16 @@ const impl_list_map_t &regular_f32_bf16_impl_list_map() {
 static const impl_list_map_t the_map = REG_REORDER_P({
 // f32 -> bf16
 {{f32, bf16, 0}, {
- CPU_REORDER_INSTANCE(rnn_weights_reorder_t)
+ CPU_REORDER_INSTANCE(rnn_weights_reorder_t, f32, bf16)
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
 DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, bf16, nChw16c))
 DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, bf16, nCdhw16c))
- DNNL_AARCH64_ACL_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t))
+ DNNL_AARCH64_ONLY(DNNL_ACL_ONLY(CPU_REORDER_INSTANCE(acl::acl_reorder_fwd_t)))
 DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
 DNNL_NON_X64_ONLY(REG_SR(f32, oihw, bf16, OIhw8i16o2i, fmt_order::keep))
@@ -47,7 +48,7 @@ const impl_list_map_t &regular_f32_bf16_impl_list_map() {
 DNNL_NON_X64_ONLY(REG_SR(f32, oihw, bf16, OIhw16i16o, fmt_order::keep))
 DNNL_NON_X64_ONLY(REG_SR(f32, goihw, bf16, gOIhw16i16o, fmt_order::keep))
- REG_SR(f32, any, bf16, any, fmt_order::any, spec::reference)
+ REG_SR(f32, any, bf16, any, fmt_order::any, spec::reference)
 nullptr,
 }},
diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_bin.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_bin.cpp
new file mode 100644
index 00000000000..f050b6a648e
--- /dev/null
+++ b/src/cpu/reorder/cpu_reorder_regular_f32_bin.cpp
@@ -0,0 +1,42 @@
+/*******************************************************************************
+* Copyright 2020 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "cpu/reorder/cpu_reorder.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+
+// clang-format off
+
+const impl_list_map_t &regular_f32_bin_impl_list_map() {
+ static const impl_list_map_t the_map = REG_REORDER_P({
+ // bin ->
+ {{f32, bin, 4}, {
+ REG_SR_BIDIR(f32, nchw, bin, nhwc)
+ REG_SR_BIDIR(f32, nhwc, bin, nhwc)
+
+ nullptr,
+ }},
+ });
+ return the_map;
+}
+
+// clang-format on
+
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_f16.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_f16.cpp
index a7b1b006549..d4da37cc42d 100644
--- a/src/cpu/reorder/cpu_reorder_regular_f32_f16.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_f32_f16.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
-* Copyright 2020-2022 Intel Corporation
+* Copyright 2020-2024 Intel Corporation
+* Copyright 2025 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -26,11 +27,13 @@ const impl_list_map_t &regular_f32_f16_impl_list_map() {
 static const impl_list_map_t the_map = REG_REORDER_P({
 // f32 -> f16
 {{f32, f16, 0}, {
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
- REG_SR(f32, any, f16, any, fmt_order::any, spec::reference)
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
+
+ REG_SR(f32, any, f16, any, fmt_order::any, spec::reference)
 nullptr,
 }},
 });
diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp
index a01fa058785..b2f86ff709d 100644
--- a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp
@@ -1,6 +1,6 @@
 /*******************************************************************************
 * Copyright 2020-2024 Intel Corporation
-* Copyright 2022 FUJITSU LIMITED
+* Copyright 2022-2024 FUJITSU LIMITED
 * Copyright 2023 Arm Ltd. and affiliates
and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,30 +28,34 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // f32 -> f32 {{f32, f32, 0}, { - REG_FAST_DIRECT_COPY_F32_F32 - - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_AARCH64_ACL_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t)) + DNNL_AARCH64_ONLY(DNNL_ACL_ONLY(CPU_REORDER_INSTANCE(acl::acl_reorder_fwd_t))) DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::brgemm_matmul_matrix_B_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - REG_SR(f32, any, f32, any, fmt_order::any, spec::reference) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + + REG_FAST_DIRECT_COPY_F32_F32 + + REG_SR(f32, any, f32, any, fmt_order_any, spec::reference) nullptr, }}, {{f32, f32, 3}, { - REG_FAST_DIRECT_COPY_F32_F32 - - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::brgemm_matmul_matrix_B_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + + REG_FAST_DIRECT_COPY_F32_F32 + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCw8c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCw4c)) @@ -66,23 +70,26 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, OIw16o16i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, OIw16i16o)) + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, IOw8o8i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, IOw16o16i)) - REG_SR(f32, any, f32, any, fmt_order::any, spec::reference) + REG_SR(f32, any, f32, any, fmt_order_any, spec_reference) nullptr, }}, {{f32, f32, 4}, { - CPU_REORDER_INSTANCE(rnn_weights_reorder_t) + CPU_REORDER_INSTANCE(rnn_weights_reorder_t, f32, f32) - REG_FAST_DIRECT_COPY_F32_F32 + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + 
DNNL_AARCH64_ONLY(DNNL_ACL_ONLY(CPU_REORDER_INSTANCE(acl::acl_reorder_fwd_t))) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_AARCH64_ACL_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + REG_FAST_DIRECT_COPY_F32_F32 DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nChw8c)) @@ -98,6 +105,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gOIw16o16i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gOIw16i16o)) + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gIOw8o8i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gIOw16o16i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, OIhw4i4o)) @@ -113,24 +121,27 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, Ohwi16o)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, OIhw16o16i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, OIhw16i16o)) + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, IOhw8o8i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, IOhw16o16i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i)) - REG_SR(f32, any, f32, any, fmt_order::any, spec::reference) + REG_SR(f32, any, f32, any, fmt_order_any, spec_reference) nullptr, }}, {{f32, f32, 5}, { - CPU_REORDER_INSTANCE(rnn_weights_reorder_t) + CPU_REORDER_INSTANCE(rnn_weights_reorder_t, f32, f32) - REG_FAST_DIRECT_COPY_F32_F32 + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + REG_FAST_DIRECT_COPY_F32_F32 DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCdhw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, nCdhw8c)) @@ -151,6 +162,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gOhwi16o)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gOIhw16o16i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gOIhw16i16o)) + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gIOhw8o8i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gIOhw16o16i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, OIdhw4i4o)) @@ -164,23 +176,24 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, Odhwi16o)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, OIdhw16o16i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, OIdhw16i16o)) + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, IOdhw8o8i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, IOdhw16o16i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i)) - REG_SR(f32, any, f32, any, fmt_order::any, spec::reference) + REG_SR(f32, any, f32, any, fmt_order_any, spec_reference) nullptr, }}, {{f32, f32, 6}, { - REG_FAST_DIRECT_COPY_F32_F32 + 
DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + REG_FAST_DIRECT_COPY_F32_F32 DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gOIdhw4o4i)) @@ -193,9 +206,10 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gOdhwi16o)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o)) + DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gIOdhw8o8i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gIOdhw16o16i)) - REG_SR(f32, any, f32, any, fmt_order::any, spec::reference) + REG_SR(f32, any, f32, any, fmt_order_any, spec_reference) nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_fp8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_fp8.cpp index e313c77fb1e..dd642125d53 100644 --- a/src/cpu/reorder/cpu_reorder_regular_f32_fp8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_f32_fp8.cpp @@ -31,6 +31,7 @@ const impl_list_map_t ®ular_f32_fp8_impl_list_map() { }}, // f32 -> f8_e5m2 {{f32, f8_e5m2, 0}, { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) @@ -40,6 +41,7 @@ const impl_list_map_t ®ular_f32_fp8_impl_list_map() { }}, // f32 -> f8_e4m3 {{f32, f8_e4m3, 0}, { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp index b1881df80e0..7961f8f361b 100644 --- a/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2024 Intel Corporation * Copyright 2022 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,15 +27,18 @@ const impl_list_map_t ®ular_f32_s32_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // f32 -> s32 {{f32, s32, 0}, { - REG_FAST_DIRECT_COPY(f32, s32) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + REG_FAST_DIRECT_COPY(f32, s32) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s32, nChw16c)) 
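+ // Note: each of these lists is scanned in order when a reorder primitive descriptor is created, and the first implementation that reports success is taken, so keeping the JIT direct-copy and blocked reorders ahead of REG_FAST_DIRECT_COPY and the reference REG_SR fallbacks is what gives them dispatch priority.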
- REG_SR(f32, any, s32, any, fmt_order::any, spec::reference) + + REG_SR(f32, any, s32, any, fmt_order_any, spec_reference) nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp index 7ce25752c7c..a3878c5d630 100644 --- a/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2023 Intel Corporation +* Copyright 2020-2024 Intel Corporation * Copyright 2022 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,23 +27,28 @@ const impl_list_map_t ®ular_f32_s8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // f32 -> s8 {{f32, s8, 0}, { - CPU_REORDER_INSTANCE(rnn_data_reorder_t) - CPU_REORDER_INSTANCE(rnn_weights_reorder_s8_t) - CPU_REORDER_INSTANCE(rnn_brgemm_weights_reorder_s8_t) + // TODO: move it down when checks for sparse md are implemented in other implementations. + DNNL_X64_ONLY(REG_SPARSE_SR(f32, oi, s8, OI16i64o4i, sparse_inputs_order::keep, sparse_spec::reference)) + DNNL_X64_ONLY(REG_SPARSE_SR(f32, format_tag::io, s8, OI16i64o4i, sparse_inputs_order::keep, sparse_spec::reference)) - REG_FAST_DIRECT_COPY(f32, s8) + CPU_REORDER_INSTANCE(rnn_data_reorder_t, f32, s8) + CPU_REORDER_INSTANCE(rnn_weights_reorder_s8_t, f32) + CPU_REORDER_INSTANCE(rnn_brgemm_weights_reorder_s8_t, f32, s8) + + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + REG_FAST_DIRECT_COPY(f32, s8) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s8, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i)) - REG_SR(f32, any, s8, any, fmt_order::any, spec::reference) + REG_SR(f32, any, s8, any, fmt_order_any, spec_reference) REG_SPARSE_SR_X64(f32, any, s8, any) diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp index d306c3abeb8..923e74a28ac 100644 --- a/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2024 Intel Corporation * Copyright 2022 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,17 +27,20 @@ const impl_list_map_t ®ular_f32_u8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // f32 -> u8 {{f32, u8, 0}, { - CPU_REORDER_INSTANCE(rnn_data_reorder_t) + CPU_REORDER_INSTANCE(rnn_data_reorder_t, f32, u8) - REG_FAST_DIRECT_COPY(f32, u8) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) + 
DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + REG_FAST_DIRECT_COPY(f32, u8) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, u8, nChw16c)) - REG_SR(f32, any, u8, any, fmt_order::any, spec::reference) + + REG_SR(f32, any, u8, any, fmt_order_any, spec_reference) nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_regular_f4.cpp b/src/cpu/reorder/cpu_reorder_regular_f4.cpp new file mode 100644 index 00000000000..f42b401726c --- /dev/null +++ b/src/cpu/reorder/cpu_reorder_regular_f4.cpp @@ -0,0 +1,48 @@ +/******************************************************************************* +* Copyright 2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu/reorder/cpu_reorder.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { + +// clang-format off + +const impl_list_map_t ®ular_f4_impl_list_map() { + static const impl_list_map_t the_map = REG_REORDER_P({ + // f4_e2m1 -> + {{f4_e2m1, data_type::undef, 0}, { + REG_SR(f4_e2m1, any, f4_e2m1, OI8i8o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI8i16o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI8i24o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI8i32o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI8i64o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI16i16o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI16i32o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI16i48o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI16i64o2i, fmt_order_keep) + nullptr, + }}, + }); + return the_map; +} + +// clang-format on + +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/reorder/cpu_reorder_regular_fp4.cpp b/src/cpu/reorder/cpu_reorder_regular_fp4.cpp new file mode 100644 index 00000000000..49e3a0ae604 --- /dev/null +++ b/src/cpu/reorder/cpu_reorder_regular_fp4.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "cpu/reorder/cpu_reorder.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { + +// clang-format off + +const impl_list_map_t ®ular_fp4_impl_list_map() { + static const impl_list_map_t the_map = REG_REORDER_P({ + {{f32, f4_e2m1, 0}, { + REG_SR(f32, any, f4_e2m1, any, fmt_order::any, spec::reference) + nullptr, + }}, + {{f4_e2m1, data_type::undef, 0}, { + REG_SR(f4_e2m1, any, f32, any, fmt_order::any, spec::reference) + nullptr, + }}, + {{f32, f4_e3m0, 0}, { + REG_SR(f32, any, f4_e3m0, any, fmt_order::any, spec::reference) + nullptr, + }}, + {{f4_e3m0, data_type::undef, 0}, { + REG_SR(f4_e3m0, any, f32, any, fmt_order::any, spec::reference) + nullptr, + }}, + }); + return the_map; +} + +// clang-format on + +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/reorder/cpu_reorder_regular_fp8.cpp b/src/cpu/reorder/cpu_reorder_regular_fp8.cpp index 81ef168d728..bd08fda826d 100644 --- a/src/cpu/reorder/cpu_reorder_regular_fp8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_fp8.cpp @@ -26,6 +26,7 @@ const impl_list_map_t ®ular_fp8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // f8_e5m2 -> {{f8_e5m2, data_type::undef, 0}, { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) @@ -38,6 +39,7 @@ const impl_list_map_t ®ular_fp8_impl_list_map() { }}, // f8_e4m3 -> {{f8_e4m3, data_type::undef, 0}, { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) @@ -46,6 +48,12 @@ const impl_list_map_t ®ular_fp8_impl_list_map() { REG_SR(f8_e4m3, any, bf16, any, fmt_order::any, spec::reference) REG_SR(f8_e4m3, any, f32, any, fmt_order::any, spec::reference) + nullptr, + }}, + // f8_e8m0 -> + {{e8m0, data_type::undef, 0}, { + REG_SR(e8m0, any, e8m0, any, fmt_order::any, spec::reference) + nullptr, }}, }); diff --git a/src/cpu/reorder/cpu_reorder_regular_nf4.cpp b/src/cpu/reorder/cpu_reorder_regular_nf4.cpp new file mode 100644 index 00000000000..67d7f3bb96c --- /dev/null +++ b/src/cpu/reorder/cpu_reorder_regular_nf4.cpp @@ -0,0 +1,49 @@ +/******************************************************************************* +* Copyright 2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "cpu/reorder/cpu_reorder.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { + +// clang-format off + +const impl_list_map_t ®ular_nf4_impl_list_map() { + static const impl_list_map_t the_map = REG_REORDER_P({ + // nf4 -> + {{nf4, data_type::undef, 0}, { + REG_SR(nf4, any, nf4, OI8i8o2i, fmt_order_keep) + REG_SR(nf4, any, nf4, OI8i16o2i, fmt_order_keep) + REG_SR(nf4, any, nf4, OI8i24o2i, fmt_order_keep) + REG_SR(nf4, any, nf4, OI8i32o2i, fmt_order_keep) + REG_SR(nf4, any, nf4, OI8i64o2i, fmt_order_keep) + REG_SR(nf4, any, nf4, OI16i16o2i, fmt_order_keep) + REG_SR(nf4, any, nf4, OI16i32o2i, fmt_order_keep) + REG_SR(nf4, any, nf4, OI16i48o2i, fmt_order_keep) + REG_SR(nf4, any, nf4, OI16i64o2i, fmt_order_keep) + REG_SR(nf4, any, f32, any, fmt_order_keep, spec::reference) + nullptr, + }}, + }); + return the_map; +} + +// clang-format on + +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/reorder/cpu_reorder_regular_s32.cpp b/src/cpu/reorder/cpu_reorder_regular_s32.cpp index a8197402b0a..30cd1392b37 100644 --- a/src/cpu/reorder/cpu_reorder_regular_s32.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_s32.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation +* Copyright 2020-2024 Intel Corporation * Copyright 2022 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,26 +27,27 @@ const impl_list_map_t ®ular_s32_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // s32 -> {{s32, data_type::undef, 0}, { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + REG_FAST_DIRECT_COPY(s32, f32) REG_FAST_DIRECT_COPY(s32, s32) REG_FAST_DIRECT_COPY(s32, s8) REG_FAST_DIRECT_COPY(s32, u8) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR_BIDIR(s32, any, f32, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(s32, any, s32, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(s32, any, s8, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(s32, any, u8, nChw16c)) - REG_SR(s32, any, f32, any, fmt_order::any, spec::reference) - REG_SR(s32, any, s32, any, fmt_order::any, spec::reference) - REG_SR(s32, any, s8, any, fmt_order::any, spec::reference) - REG_SR(s32, any, u8, any, fmt_order::any, spec::reference) + REG_SR(s32, any, f32, any, fmt_order_any, spec_reference) + REG_SR(s32, any, s32, any, fmt_order_any, spec_reference) + REG_SR(s32, any, s8, any, fmt_order_any, spec_reference) + REG_SR(s32, any, u8, any, fmt_order_any, spec_reference) nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_regular_s4.cpp b/src/cpu/reorder/cpu_reorder_regular_s4.cpp index 17bfdba758e..901a683df6c 100644 --- a/src/cpu/reorder/cpu_reorder_regular_s4.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_s4.cpp @@ -28,9 +28,26 @@ const impl_list_map_t ®ular_s4_impl_list_map() { REG_SR(f32, any, s4, any, fmt_order::any, spec::reference) nullptr, }}, + {{s4, data_type::undef, 0}, { 
+ REG_SR(s4, any, s4, OI8i8o2i, fmt_order_keep) + REG_SR(s4, any, s4, OI8i16o2i, fmt_order_keep) + REG_SR(s4, any, s4, OI8i24o2i, fmt_order_keep) + REG_SR(s4, any, s4, OI8i32o2i, fmt_order_keep) + REG_SR(s4, any, s4, OI8i64o2i, fmt_order_keep) + REG_SR(s4, any, s4, OI16i16o2i, fmt_order_keep) + REG_SR(s4, any, s4, OI16i32o2i, fmt_order_keep) + REG_SR(s4, any, s4, OI16i48o2i, fmt_order_keep) + REG_SR(s4, any, s4, OI16i64o2i, fmt_order_keep) + REG_SR(s4, any, u8, any, fmt_order_keep, spec::reference) + REG_SR(s4, any, f32, any, fmt_order_keep, spec::reference) + REG_SR(s4, any, f32, any, fmt_order::any, spec::reference) + REG_SR(s4, any, bf16, any, fmt_order::any, spec::reference) + REG_SR(s4, any, f16, any, fmt_order::any, spec::reference) + nullptr, + }}, {{s4, f32, 0}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t)) REG_SR(s4, any, f32, any, fmt_order::any, spec::reference) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) nullptr, }}, }); diff --git a/src/cpu/reorder/cpu_reorder_regular_s8.cpp b/src/cpu/reorder/cpu_reorder_regular_s8.cpp index c7199e1c41f..a346f1f50b5 100644 --- a/src/cpu/reorder/cpu_reorder_regular_s8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_s8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2023 Intel Corporation +* Copyright 2020-2024 Intel Corporation * Copyright 2022 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,9 +28,13 @@ const impl_list_map_t &regular_s8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // s8 -> {{s8, data_type::undef, 0}, { - CPU_REORDER_INSTANCE(rnn_weights_reorder_s8_t) - CPU_REORDER_INSTANCE(rnn_brgemm_weights_reorder_s8_t) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t)) + // TODO: move it down when checks for sparse md are implemented in other implementations.
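+ // The sparse entries just below are registered first so that the dense implementations, which do not yet check for sparse mds, cannot be picked for sparse inputs by mistake (see the TODO above).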
+ DNNL_X64_ONLY(REG_SPARSE_SR(s8, oi, s8, OI16i64o4i, sparse_inputs_order::keep, sparse_spec::reference)) + DNNL_X64_ONLY(REG_SPARSE_SR(s8, format_tag::io, s8, OI16i64o4i, sparse_inputs_order::keep, sparse_spec::reference)) + + CPU_REORDER_INSTANCE(rnn_weights_reorder_s8_t, s8) + CPU_REORDER_INSTANCE(rnn_brgemm_weights_reorder_s8_t, s8, s8) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) REG_FAST_DIRECT_COPY(s8, f32) REG_FAST_DIRECT_COPY(s8, s32) @@ -39,12 +43,12 @@ const impl_list_map_t &regular_s8_impl_list_map() { REG_FAST_DIRECT_COPY(s8, s8) REG_FAST_DIRECT_COPY(s8, u8) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, f32, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, s32, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, bf16, nChw16c)) @@ -58,12 +62,12 @@ const impl_list_map_t &regular_s8_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, bf16, gOIhw4i16o4i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i)) - REG_SR(s8, any, f32, any, fmt_order::any, spec::reference) - REG_SR(s8, any, s32, any, fmt_order::any, spec::reference) - REG_SR(s8, any, bf16, any, fmt_order::any, spec::reference) - REG_SR(s8, any, f16, any, fmt_order::any, spec::reference) - REG_SR(s8, any, s8, any, fmt_order::any, spec::reference) - REG_SR(s8, any, u8, any, fmt_order::any, spec::reference) + REG_SR(s8, any, f32, any, fmt_order_any, spec_reference) + REG_SR(s8, any, s32, any, fmt_order_any, spec_reference) + REG_SR(s8, any, bf16, any, fmt_order_any, spec_reference) + REG_SR(s8, any, f16, any, fmt_order_any, spec_reference) + REG_SR(s8, any, s8, any, fmt_order_any, spec_reference) + REG_SR(s8, any, u8, any, fmt_order_any, spec_reference) REG_SPARSE_SR_X64(s8, any, s8, any) diff --git a/src/cpu/reorder/cpu_reorder_regular_u4.cpp b/src/cpu/reorder/cpu_reorder_regular_u4.cpp index 60a85da4a30..3cb62066af7 100644 --- a/src/cpu/reorder/cpu_reorder_regular_u4.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_u4.cpp @@ -28,8 +28,29 @@ const impl_list_map_t &regular_u4_impl_list_map() { REG_SR(f32, any, u4, any, fmt_order::any, spec::reference) nullptr, }}, + {{u4, data_type::undef, 0}, { + REG_SR(u4, any, u4, OI8i8o2i, fmt_order_keep) + REG_SR(u4, any, u4, OI8i16o2i, fmt_order_keep) + REG_SR(u4, any, u4, OI8i24o2i, fmt_order_keep) + REG_SR(u4, any, u4, OI8i32o2i, fmt_order_keep) + REG_SR(u4, any, u4, OI8i64o2i, fmt_order_keep) + REG_SR(u4, any, u4, OI16i16o2i, fmt_order_keep) + REG_SR(u4, any, u4, OI16i32o2i, fmt_order_keep) + REG_SR(u4, any, u4, OI16i48o2i, fmt_order_keep) + REG_SR(u4, any, u4, OI16i64o2i, fmt_order_keep) + REG_SR(u4, any, u4, OI16i16o4i, fmt_order_keep) + REG_SR(u4, any, u4, OI16i32o4i, fmt_order_keep) + REG_SR(u4, any, u4, OI16i48o4i, fmt_order_keep) + REG_SR(u4, any, u4, OI16i64o4i, fmt_order_keep) + REG_SR(u4, any, u8, any, fmt_order_keep, spec::reference) + REG_SR(u4, any, f32, any, fmt_order_keep, spec::reference) + REG_SR(u4, any, f32, any, fmt_order::any, spec::reference) +
REG_SR(u4, any, bf16, any, fmt_order::any, spec::reference) + REG_SR(u4, any, f16, any, fmt_order::any, spec::reference) + nullptr, + }}, {{u4, f32, 0}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) REG_SR(u4, any, f32, any, fmt_order::any, spec::reference) nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_regular_u8.cpp b/src/cpu/reorder/cpu_reorder_regular_u8.cpp index c96343e19b8..97c5c135420 100644 --- a/src/cpu/reorder/cpu_reorder_regular_u8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_u8.cpp @@ -27,7 +27,14 @@ const impl_list_map_t ®ular_u8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // u8 -> {{u8, data_type::undef, 0}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) + + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) REG_FAST_DIRECT_COPY(u8, f32) REG_FAST_DIRECT_COPY(u8, s32) @@ -35,23 +42,17 @@ const impl_list_map_t ®ular_u8_impl_list_map() { REG_FAST_DIRECT_COPY(u8, s8) REG_FAST_DIRECT_COPY(u8, u8) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR_BIDIR(u8, any, f32, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(u8, any, s32, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(u8, any, bf16, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(u8, any, s8, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(u8, any, u8, nChw16c)) - REG_SR(u8, any, f32, any, fmt_order::any, spec::reference) - REG_SR(u8, any, s32, any, fmt_order::any, spec::reference) - REG_SR(u8, any, bf16, any, fmt_order::any, spec::reference) - REG_SR(u8, any, u8, any, fmt_order::any, spec::reference) - REG_SR(u8, any, s8, any, fmt_order::any, spec::reference) + REG_SR(u8, any, f32, any, fmt_order_any, spec_reference) + REG_SR(u8, any, s32, any, fmt_order_any, spec_reference) + REG_SR(u8, any, bf16, any, fmt_order_any, spec_reference) + REG_SR(u8, any, u8, any, fmt_order_any, spec_reference) + REG_SR(u8, any, s8, any, fmt_order_any, spec_reference) nullptr, }}, diff --git a/src/cpu/reorder/simple_reorder.hpp b/src/cpu/reorder/simple_reorder.hpp index 63aee106f49..115c4419db9 100644 --- a/src/cpu/reorder/simple_reorder.hpp +++ b/src/cpu/reorder/simple_reorder.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
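The simple_reorder.hpp hunks below rename the q10n quantization functors (qz_a1b0 -> qz_a1b0_t, qz -> qz_t, qz_b0 -> qz_b0_t) without changing what they compute. As a rough, illustrative sketch of the qz_b0-style contract only - scale, round, saturate, with no accumulation term - under assumed semantics (qz_b0_sketch and its fixed f32->s8 signature are hypothetical, not the library's templates):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical stand-in for a q10n::qz_b0_t<f32, s8>-like functor.
struct qz_b0_sketch {
    int8_t operator()(float in, float scale) const {
        // Scale, round to nearest (ties to even under the default FP
        // environment), then saturate into the s8 range.
        float v = std::nearbyint(in * scale);
        v = std::max(-128.f, std::min(127.f, v));
        return static_cast<int8_t>(v);
    }
};

A call such as qz_b0_sketch()(inp[plain_off], s[oc] * adj_scale * d[oc]) then mirrors the shape of the call sites in the hunks that follow.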
@@ -43,13 +43,13 @@ using bd = block_dim_t; using ib = inner_blk_t; template <data_type_t dt> -using data_t = typename prec_traits<dt>::type; +using data_t = typename prec_traits_t<dt>::type; template <data_type_t type_i, data_type_t type_o> -using _qz_a1b0 = q10n::qz_a1b0<data_t<type_i>, data_t<type_o>>; +using _qz_a1b0 = q10n::qz_a1b0_t<data_t<type_i>, data_t<type_o>>; template <data_type_t type_i, data_type_t type_o> -using _qz = q10n::qz<data_t<type_i>, data_t<type_o>>; +using _qz = q10n::qz_t<data_t<type_i>, data_t<type_o>>; namespace fmt_order { const bool keep = true; @@ -79,6 +79,9 @@ struct conv_req_comp {}; // {s8, u8: asymmetric quantization} const auto output_d = ctx.memory_mdw(DNNL_ARG_TO, pd->dst_md()); \ DEFINE_ARG_SCALES_BUFFER_ATTR(pd->attr(), src_scales, DNNL_ARG_FROM); \ DEFINE_ARG_SCALES_BUFFER_ATTR(pd->attr(), dst_scales_, DNNL_ARG_TO); \ + const auto src_scales_d \ + = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | DNNL_ARG_FROM); \ + MAYBE_UNUSED(src_scales_d); \ int src_scales_mask, dst_scales_mask; \ CHECK(get_scales_mask(pd->attr(), &src_scales_mask, &dst_scales_mask)); \ int scales_mask = std::max(src_scales_mask, dst_scales_mask); \ @@ -88,7 +91,12 @@ struct conv_req_comp {}; // {s8, u8: asymmetric quantization} const float *dst_scales = pd->precompute_scales( \ scratchpad, pd->attr(), D_mask, dst_scales_); \ MAYBE_UNUSED(dst_scales); \ - DEFINE_ZERO_POINT_VALUE_ATTR(pd->attr(), src_zp, DNNL_ARG_FROM); \ + DEFINE_ZERO_POINTS_BUFFER_ATTR(pd->attr(), src_zero_points, DNNL_ARG_FROM) \ + const auto src_zps_d \ + = ctx.memory_mdw(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_FROM); \ + MAYBE_UNUSED(src_zps_d); \ + int src_zp = src_zero_points ? src_zero_points[0] : 0; \ + MAYBE_UNUSED(src_zp); \ DEFINE_ZERO_POINT_VALUE_ATTR(pd->attr(), dst_zp, DNNL_ARG_TO); \ const float alpha = src_scales[0] * dst_scales[0]; \ MAYBE_UNUSED(alpha); \ @@ -125,12 +133,12 @@ inline status_t get_scales_mask( return status::invalid_arguments; *src_mask = 0; - if (!s.get(DNNL_ARG_SRC).has_default_values()) - *src_mask = s.get(DNNL_ARG_SRC).mask_; + if (!s.has_default_values(DNNL_ARG_SRC)) + *src_mask = s.get_mask(DNNL_ARG_SRC); *dst_mask = 0; - if (!s.get(DNNL_ARG_DST).has_default_values()) - *dst_mask = s.get(DNNL_ARG_DST).mask_; + if (!s.has_default_values(DNNL_ARG_DST)) + *dst_mask = s.get_mask(DNNL_ARG_DST); // This is used in a check function. if (*src_mask > 0 && *dst_mask > 0 && *dst_mask != *src_mask) @@ -140,15 +148,68 @@ inline bool simple_attr_check(const primitive_attr_t *attr, bool many_scales_support, bool sum_support) { using smask_t = primitive_attr_t::skip_mask_t; - smask_t skip_mask = smask_t::scales_runtime; + smask_t skip_mask = smask_t::scales; if (sum_support) skip_mask = skip_mask | smask_t::post_ops; if (!attr->has_default_values(skip_mask)) return false; + for (int arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) { + // Non-default data types for scales are not supported. + if (!attr->scales_.has_default_data_type(arg)) return false; + // Non-default groups are not supported. + if (!attr->scales_.get(arg).has_default_groups()) return false; + } if (many_scales_support) return true; int src_mask, dst_mask; if (get_scales_mask(attr, &src_mask, &dst_mask) != status::success) return false; return src_mask == 0 && dst_mask == 0; } + +// TODO: once the re-factoring for quantization happens, maintain for each entry an md +// in compliance with the corresponding argument for easier offset computation. +inline status_t get_quant_md(memory_desc_t &md, const int ndims, + const dims_t in_dims, const int quant_mask, const dim_t g0, + const dim_t g1, const data_type_t dt) { + dims_t quant_dims {}; + // TODO: incorporate groups into `utils::copy_dims_with_mask` to simplify + // the logic.
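+ // Illustrative example (values assumed, not taken from the patch): with ndims = 2, in_dims = {64, 512}, quant_mask = 3 (both dims kept by copy_dims_with_mask below) and groups g0 = 32, g1 = 1, the division below yields quant_dims = {64 / 32, 512 / 1} = {2, 512}, i.e. one quantization entry per 32x1 block of the input.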
+ utils::copy_dims_with_mask(quant_dims, in_dims, ndims, quant_mask, + /* fill_with_ones = */ true); + if (ndims >= 2) { + if (utils::one_of(0, g0, g1)) return status::runtime_error; + quant_dims[ndims - 1] /= g1; + quant_dims[ndims - 2] /= g0; + } + + CHECK(memory_desc_init_by_tag( + md, ndims, quant_dims, dt, get_abx_tag(ndims))); + return status::success; +} + +// Returns the offset of a quantization entry based on the logical offset dimensions +// of the corresponding input - `input_idx`, `quant_mask`, groups `g0` and +// `g1` when they are supported (otherwise, pass `1`), and `quant_md`. +// +// The offset always coincides with the logical index because quantization entries +// don't have a notion of physical formats. +inline dim_t get_quant_off(const dims_t &input_idx, const int ndims, + const int quant_mask, const dim_t g0, const dim_t g1, + const memory_desc_t &quant_md) { + dims_t quant_idx {}; + utils::array_copy(quant_idx, input_idx, ndims); + utils::apply_mask_on_dims(quant_idx, ndims, quant_mask); + // Note: an `idx` must be divided by the group value, as grouped quantization + // applies to consecutive points. + // Using quant dimensions in `l_dims_by_l_offset` would lead to wrapping + // around dimensions instead of applying them consecutively. + if (ndims >= 2) { + quant_idx[ndims - 1] /= g1; + quant_idx[ndims - 2] /= g0; + } + + const memory_desc_wrapper q_mdw(quant_md); + return q_mdw.off_v(quant_idx); +} + } // namespace /* specific reorders: implementation */ @@ -160,16 +221,16 @@ struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { using namespace data_type; using namespace utils; - if (input_d.has_runtime_dims_or_strides()) return false; + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); int src_scales_mask, dst_scales_mask; - auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask); - if (status != status::success) return false; + CHECK(get_scales_mask(attr, &src_scales_mask, &dst_scales_mask)); int scales_mask = std::max(src_scales_mask, dst_scales_mask); static constexpr bool w_groups = one_of( @@ -184,16 +245,32 @@ struct simple_reorder_impl - o = q10n::qz_b0<data_t<type_i>, data_t<type_o>>()( + o = q10n::qz_b0_t<data_t<type_i>, data_t<type_o>>()( i, s * adj_scale * d); if (req_comp) cp[g * OC + oc] -= (int32_t)o; if (has_asymmetric_comp) zp[g * OC + oc] -= (int32_t)o; @@ -320,23 +397,23 @@ struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { using namespace format_tag; using namespace data_type; using namespace utils; - if (input_d.has_runtime_dims_or_strides()) return false; + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); int src_scales_mask, dst_scales_mask; - auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask); - if (status != status::success) return false; + CHECK(get_scales_mask(attr, &src_scales_mask, &dst_scales_mask)); int scales_mask = std::max(src_scales_mask, dst_scales_mask); - const bool w_groups = !one_of(tag_o, OIw4i16o4i, OIw2i8o4i, OIw4o4i, - OIhw4i16o4i, OIhw2i8o4i, OIhw4o4i, OIdhw4i16o4i, OIdhw2i8o4i, - OIdhw4o4i, OI4i16o4i, OI4i32o4i, OI4i64o4i, OIw4i32o4i, - OIw4i64o4i, OIhw4i32o4i, OIhw4i64o4i, OIdhw4i32o4i, + static constexpr bool
w_groups = !one_of(tag_o, OIw4i16o4i, OIw2i8o4i, + OIw4o4i, OIhw4i16o4i, OIhw2i8o4i, OIhw4o4i, OIdhw4i16o4i, + OIdhw2i8o4i, OIdhw4o4i, OI4i16o4i, OI4i32o4i, OI4i64o4i, + OIw4i32o4i, OIw4i64o4i, OIhw4i32o4i, OIhw4i64o4i, OIdhw4i32o4i, OIdhw4i64o4i); const bool req_comp = output_d.extra().flags @@ -348,16 +425,31 @@ struct simple_reorder_impl::inner_blks, - ib::_4a4b, ib::_4b4c) + constexpr dim_t icblksize + = utils::one_of( + tag_traits_t::inner_blks, ib::_4a4b, ib::_4b4c) ? 4 - : utils::one_of(tag_traits::inner_blks, ib::_2c8b4c, + : utils::one_of(tag_traits_t::inner_blks, ib::_2c8b4c, ib::_2b8a4b) ? 8 : 16; constexpr dim_t ocblksize - = tag_traits::inner_blks == ib::_4b32a4b ? 32 - : tag_traits::inner_blks == ib::_4b64a4b ? 64 - : icblksize; + = tag_traits_t::inner_blks == ib::_4b32a4b ? 32 + : tag_traits_t::inner_blks == ib::_4b64a4b ? 64 + : icblksize; const auto &plain_d = order_keep ? input_d : output_d; const auto &dims = input_d.dims(); @@ -444,7 +537,7 @@ struct simple_reorder_impl::inner_blks> +#define index AB_or_BC_blk_off::inner_blks> for_(dim_t ic = 0; ic < ic_block; ++ic) for (dim_t oc = 0; oc < oc_block; ++oc) { const auto plain_off @@ -454,7 +547,7 @@ struct simple_reorder_impl, data_t>()( + = q10n::qz_b0_t, data_t>()( inp[plain_off], src_scale * adj_scale * dst_scale); if (req_comp) c[oc] -= (128 * (int32_t)(out[index(oc, ic)])); @@ -536,37 +629,45 @@ struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { using namespace format_tag; using namespace data_type; using namespace utils; - if (input_d.has_runtime_dims_or_strides()) return false; + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); - const bool w_groups = !one_of(tag_o, Owi16o, Owhi16o); + static constexpr bool w_groups = !one_of(tag_o, Owi16o, Owhi16o); // Current formats are only used in jit kernels that natively // support s8 instructions, hence, there is no need for signed // compensation. const bool req_comp = output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8; - const bool req_asymmetric_comp = output_d.extra().flags & memory_extra_flags::compensation_conv_asymmetric_src; auto mask_ok = [&](bool check, int mask) { - const int c_mask = 0x1, - g_mask = 0x3; // mask for i/o-channel and ngroups - return IMPLICATION(check, mask == (w_groups ? g_mask : c_mask)); + return IMPLICATION(check, mask == (w_groups ? 
0x3 : 0x1)); }; - return simple_attr_check(attr, true, false) - && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o) - && mask_ok(req_asymmetric_comp, - output_d.extra().asymm_compensation_mask) - && one_of(input_d.data_type(), f32, s8, bf16) - && output_d.data_type() == s8 && !req_comp; + VDISPATCH_REORDER_IC(one_of(input_d.data_type(), f32, s8, bf16), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_REORDER_IC( + output_d.data_type() == s8, VERBOSE_UNSUPPORTED_DT); + VDISPATCH_REORDER_IC( + simple_attr_check(attr, true, false), VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_REORDER_IC(input_d.matches_tag(tag_i), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src"); + VDISPATCH_REORDER_IC(output_d.matches_tag(tag_o), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "dst"); + VDISPATCH_REORDER_IC(!req_comp, "compensation is not supported"); + VDISPATCH_REORDER_IC(mask_ok(req_asymmetric_comp, + output_d.extra().asymm_compensation_mask), + "zero-points compensation configuration is not supported"); + + return status::success; } GET_SCRATCHPAD_SIZE_ZERO(); @@ -609,7 +710,7 @@ struct simple_reorder_impl, data_t>()( + out[oc] = q10n::qz_b0_t, data_t>()( inp[plain_off], s[oc] * adj_scale * d[oc]); if (has_asymmetric_comp) zp[oc] -= (int32_t)(out[oc]); } @@ -689,45 +790,55 @@ struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { using namespace format_tag; using namespace data_type; using namespace utils; - if (input_d.has_runtime_dims_or_strides()) return false; + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); int src_scales_mask, dst_scales_mask; - auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask); - if (status != status::success) return false; + CHECK(get_scales_mask(attr, &src_scales_mask, &dst_scales_mask)); int scales_mask = std::max(src_scales_mask, dst_scales_mask); - const bool w_groups = !one_of(tag_o, OwI16o4i, OIw16i16o4i, OhwI16o4i, - OIhw16i16o4i, OdhwI16o4i, OIdhw16i16o4i); + static constexpr bool w_groups = !one_of(tag_o, OwI16o4i, OIw16i16o4i, + OhwI16o4i, OIhw16i16o4i, OdhwI16o4i, OIdhw16i16o4i); // Current formats are only used in jit kernels that natively // support s8 instructions, hence, there is no need for signed // compensation. const bool req_comp = output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8; - const bool req_asymmetric_comp = output_d.extra().flags & memory_extra_flags::compensation_conv_asymmetric_src; auto mask_ok = [&](bool check, int mask) { - const int c_mask = 0x1, - g_mask = 0x3; // mask for o-channel and ngroups - return IMPLICATION(check, mask == (w_groups ? g_mask : c_mask)); + return IMPLICATION(check, mask == (w_groups ? 
0x3 : 0x1)); }; - return simple_attr_check(attr, true, false) - && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o) - && mask_ok(req_asymmetric_comp, - output_d.extra().asymm_compensation_mask) - && one_of(input_d.data_type(), f32, s8, bf16) - && IMPLICATION(!w_groups, one_of(scales_mask, 0, 0x1)) - && IMPLICATION(w_groups, one_of(scales_mask, 0, 0x3)) - && output_d.data_type() == s8 && !req_comp; + VDISPATCH_REORDER_IC(one_of(input_d.data_type(), f32, s8, bf16), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_REORDER_IC( + output_d.data_type() == s8, VERBOSE_UNSUPPORTED_DT); + VDISPATCH_REORDER_IC( + simple_attr_check(attr, true, false), VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_REORDER_IC(input_d.matches_tag(tag_i), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src"); + VDISPATCH_REORDER_IC(output_d.matches_tag(tag_o), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "dst"); + VDISPATCH_REORDER_IC(!req_comp, "compensation is not supported"); + VDISPATCH_REORDER_IC(mask_ok(req_asymmetric_comp, + output_d.extra().asymm_compensation_mask), + "zero-points compensation configuration is not supported"); + VDISPATCH_REORDER_IC( + IMPLICATION(!w_groups, one_of(scales_mask, 0, 0x1)), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VDISPATCH_REORDER_IC(IMPLICATION(w_groups, one_of(scales_mask, 0, 0x3)), + VERBOSE_UNSUPPORTED_SCALES_CFG); + + return status::success; } GET_SCRATCHPAD_SIZE_ZERO(); @@ -746,11 +857,11 @@ struct simple_reorder_impl::inner_blks, ib::_16b16a4b, + = utils::one_of(tag_traits_t::inner_blks, ib::_16b16a4b, ib::_16c16b4c) ? 64 - : utils::one_of( - tag_traits::inner_blks, ib::_16a4b, ib::_16b4c) + : utils::one_of(tag_traits_t::inner_blks, ib::_16a4b, + ib::_16b4c) ? 4 : 1; assert(ic_blksize != 1); @@ -791,9 +902,9 @@ struct simple_reorder_impl::inner_blks>( + auto index = AB_or_BC_blk_off::inner_blks>( oc, ic); - out[index] = q10n::qz_b0, data_t>()( + out[index] = q10n::qz_b0_t, data_t>()( inp[plain_off], s[oc] * adj_scale * d[oc]); if (has_asymmetric_comp) zp[oc] -= (int32_t)(out[index]); @@ -857,13 +968,20 @@ struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { using namespace format_tag; using namespace data_type; using namespace utils; - if (input_d.has_runtime_dims_or_strides()) return false; + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); + + int src_scales_mask, dst_scales_mask; + CHECK(get_scales_mask(attr, &src_scales_mask, &dst_scales_mask)); + int scales_mask = std::max(src_scales_mask, dst_scales_mask); + const size_t D_mask + = array_product(input_d.dims(), math::ilog2q(scales_mask + 1)); const bool req_comp = output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8; @@ -876,21 +994,26 @@ struct simple_reorder_impl a, d0 <-> b, d1 <-> c constexpr dim_t D0_blksize = 64; constexpr dim_t D1_blksize - = (utils::one_of(tag_traits::inner_blks, ib::_16a64b4a, + = (utils::one_of(tag_traits_t::inner_blks, ib::_16a64b4a, ib::_16b64c4b)) ? 64 - : (utils::one_of(tag_traits::inner_blks, ib::_16a48b4a, + : (utils::one_of(tag_traits_t::inner_blks, ib::_16a48b4a, ib::_16b48c4b)) ? 48 - : (utils::one_of(tag_traits::inner_blks, ib::_16a32b4a, + : (utils::one_of(tag_traits_t::inner_blks, ib::_16a32b4a, ib::_16b32c4b)) ? 32 - : (utils::one_of(tag_traits::inner_blks, ib::_16a16b4a, + : (utils::one_of(tag_traits_t::inner_blks, ib::_16a16b4a, ib::_16b16c4b)) ? 
16 : 1; @@ -952,10 +1075,11 @@ struct simple_reorder_impl::inner_blks>( + = AB_or_BC_blk_off::inner_blks>( d0, d1); - out[index] = q10n::qz_b0, data_t>()( - inp[plain_off], s[0] * adj_scale * d[0]); + out[index] + = q10n::qz_b0_t, data_t>()( + inp[plain_off], s[0] * adj_scale * d[0]); auto o = static_cast(out[index]); if (req_comp) cp[d1] -= (128 * o); @@ -963,18 +1087,19 @@ struct simple_reorder_impl::inner_blks>( + = AB_or_BC_blk_off::inner_blks>( d0, d1); - out[index] = q10n::qz_b0, data_t>()( - 0, s[0] * adj_scale * d[0]); + out[index] + = q10n::qz_b0_t, data_t>()( + 0, s[0] * adj_scale * d[0]); } } for_(int d0 = d0_block; d0 < D0_blksize; ++d0) for (int d1 = 0; d1 < D1_blksize; ++d1) { - auto index = AB_or_BC_blk_off::inner_blks>( + auto index = AB_or_BC_blk_off::inner_blks>( d0, d1); - out[index] = q10n::qz_b0, data_t>()( + out[index] = q10n::qz_b0_t, data_t>()( 0, s[0] * adj_scale * d[0]); } }; @@ -1041,16 +1166,16 @@ struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { using namespace data_type; using namespace utils; - if (input_d.has_runtime_dims_or_strides()) return false; + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); int src_scales_mask, dst_scales_mask; - auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask); - if (status != status::success) return false; + CHECK(get_scales_mask(attr, &src_scales_mask, &dst_scales_mask)); int scales_mask = std::max(src_scales_mask, dst_scales_mask); const dim_t g = input_d.dims()[0]; @@ -1064,22 +1189,36 @@ struct simple_reorder_impl, data_t>()( + out[g] = q10n::qz_b0_t, data_t>()( inp[i_off], src_scale * adj_scale * dst_scale); } }; @@ -1220,15 +1359,26 @@ struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { using namespace data_type; - if (input_d.has_runtime_dims_or_strides()) return false; + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); + + VDISPATCH_REORDER_IC(order_keep, "unsupported internal impl detail"); + VDISPATCH_REORDER_IC( + input_d.data_type() == f32, VERBOSE_UNSUPPORTED_DT); + VDISPATCH_REORDER_IC( + output_d.data_type() == bf16, VERBOSE_UNSUPPORTED_DT); + VDISPATCH_REORDER_IC( + attr->has_default_values(), VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_REORDER_IC(input_d.matches_tag(tag_i), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src"); + VDISPATCH_REORDER_IC(output_d.matches_tag(tag_o), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "dst"); - return order_keep && input_d.matches_tag(tag_i) - && output_d.matches_tag(tag_o) && input_d.data_type() == f32 - && output_d.data_type() == bf16 && attr->has_default_values(); + return status::success; } static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, @@ -1322,25 +1472,39 @@ struct simple_reorder_impl struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { using namespace data_type; - if (input_d.has_runtime_dims_or_strides()) return false; + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + 
VERBOSE_RUNTIMEDIM_UNSUPPORTED); + + VDISPATCH_REORDER_IC(order_keep, "unsupported internal impl detail"); + VDISPATCH_REORDER_IC( + input_d.data_type() == f32, VERBOSE_UNSUPPORTED_DT); + VDISPATCH_REORDER_IC( + output_d.data_type() == bf16, VERBOSE_UNSUPPORTED_DT); + VDISPATCH_REORDER_IC( + attr->has_default_values(), VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_REORDER_IC(input_d.matches_tag(tag_i), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src"); + VDISPATCH_REORDER_IC(output_d.matches_tag(tag_o), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "dst"); - return input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o) - && input_d.data_type() == f32 && output_d.data_type() == bf16 - && attr->has_default_values(); + return status::success; } static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d) { + constexpr int ndims = tag_traits_t::ndims; const size_t blksize = 16; - const size_t W = input_d.dims()[3]; + const size_t W = input_d.dims()[ndims - 1]; return sizeof(float) * blksize * W * dnnl_get_max_threads(); } @@ -1348,14 +1512,15 @@ struct simple_reorder_impl::ndims; const auto &flat_d = input_d; const auto &dims = input_d.dims(); const auto &pdims = output_d.padded_dims(); const dim_t C = dims[1]; - const dim_t H = dims[2]; - const dim_t W = dims[3]; + const dim_t H = ndims == 3 ? 1 : dims[ndims - 2]; + const dim_t W = dims[ndims - 1]; const dim_t wsp_size = W * blksize; float *wspace = scratchpad.template get( @@ -1368,7 +1533,7 @@ struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { - return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d) - && simple_attr_check(attr, false, true); + using namespace data_type; + + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); + + VDISPATCH_REORDER_IC( + simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d), + "unsupported configuration"); + VDISPATCH_REORDER_IC( + simple_attr_check(attr, false, true), VERBOSE_UNSUPPORTED_ATTR); + + return status::success; } GET_SCRATCHPAD_SIZE_ZERO(); @@ -1423,7 +1598,7 @@ struct simple_reorder_impl::inner_blks == ib::_4b ? 4 : 8; + = tag_traits_t::inner_blks == ib::_4b ? 4 : 8; constexpr dim_t blksize_16 = 16; constexpr dim_t ic_mult = order_keep ? blksize_16 / blksize_i : 1; @@ -1519,24 +1694,28 @@ struct simple_reorder_impl struct simple_reorder_impl::block_dims == bd::_A - || tag_traits::block_dims == bd::_B) - && tag_traits::ndims >= 3 - && tag_traits::ndims <= 6>::type> { + && (tag_traits_t::block_dims == bd::_A + || tag_traits_t::block_dims == bd::_B) + && tag_traits_t::ndims >= 3 + && tag_traits_t::ndims <= 6>::type> { PLAIN_TO_BLOCKED_IS_APPLICABLE(); GET_SCRATCHPAD_SIZE_ZERO(); @@ -1549,8 +1728,8 @@ struct simple_reorder_impl::ndims; - const int blk_idx = tag_traits::block_dims == bd::_A ? 0 : 1; + const int ndims = tag_traits_t::ndims; + const int blk_idx = tag_traits_t::block_dims == bd::_A ? 
0 : 1; const dim_t H0 = dims[0]; const dim_t H1 = dims[1]; @@ -1565,7 +1744,7 @@ struct simple_reorder_impl::inner_blks) { + switch (tag_traits_t::inner_blks) { case ib::_4a: case ib::_4b: blksize = 4; break; case ib::_8a: @@ -1684,14 +1863,17 @@ struct simple_reorder_impl struct simple_reorder_impl::block_dims == bd::_AB - || tag_traits::block_dims == bd::_BC) - && IMPLICATION(tag_traits::block_dims == bd::_AB, - tag_traits::ndims >= 3 - && tag_traits::ndims <= 5) - && IMPLICATION(tag_traits::block_dims == bd::_BC, - tag_traits::ndims >= 4 - && tag_traits::ndims <= 6)>::type> { + && (tag_traits_t::block_dims == bd::_AB + || tag_traits_t::block_dims == bd::_BC) + && IMPLICATION(tag_traits_t::block_dims == bd::_AB, + tag_traits_t::ndims >= 3 + && tag_traits_t::ndims <= 5) + && IMPLICATION(tag_traits_t::block_dims == bd::_BC, + tag_traits_t::ndims >= 4 + && tag_traits_t::ndims <= 6) + && (type_i != dnnl_bin && type_o != dnnl_bin) + && (type_i != dnnl_nf4 && type_o != dnnl_nf4) + && (type_i != dnnl_f4_e2m1 && type_o != dnnl_f4_e2m1)>::type> { PLAIN_TO_BLOCKED_IS_APPLICABLE(); GET_SCRATCHPAD_SIZE_ZERO(); @@ -1704,9 +1886,10 @@ struct simple_reorder_impl::ndims; + constexpr int ndims = tag_traits_t::ndims; - static constexpr bool with_g = tag_traits::block_dims == bd::_BC; + static constexpr bool with_g + = tag_traits_t::block_dims == bd::_BC; const dim_t G = with_g ? dims[0] : 1; const dim_t H0 = dims[0 + with_g]; @@ -1723,7 +1906,7 @@ struct simple_reorder_impl::inner_blks) { + switch (tag_traits_t::inner_blks) { case ib::_4b4a: case ib::_4b4c: case ib::_4c4b: @@ -1777,7 +1960,7 @@ struct simple_reorder_impl *i, data_t *o, const int block_h0, const int block_h1) { -#define blk_off AB_or_BC_blk_off::inner_blks> +#define blk_off AB_or_BC_blk_off::inner_blks> if (alpha == 1.0 && beta == 0.0) { for (int h0 = 0; h0 < block_h0; ++h0) { for (int h1 = 0; h1 < block_h1; ++h1) { @@ -1872,6 +2055,322 @@ struct simple_reorder_impl +struct simple_reorder_impl::type> +{ + static status_t is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + VDISPATCH_REORDER_IC( + simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d), + "unsupported configuration"); + VDISPATCH_REORDER_IC( + simple_attr_check(attr, false, false), VERBOSE_UNSUPPORTED_ATTR); + return status::success; + } + + GET_SCRATCHPAD_SIZE_ZERO(); + + static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { + DECLARE_COMMON_PARAMS(); + + const auto &dims = input_d.dims(); + const int C = dims[1]; + const int H = dims[2]; + const int W = dims[3]; + + int nbits = 8; + const int CB = utils::div_up(C, nbits); + + auto ker = [&](const data_t *i, data_t *o) { + for (int cb = 0; cb < CB; ++cb) { + uint8_t bin_val = 0x00; + for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) { + const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[1]; + + auto bit = uint8_t((i[flat_off] > 0) ? 
0x01 : 0x00); + bin_val |= (bit << shift); + } + + o[cb] = bin_val; + } + }; + + parallel_nd(dims[0], H, W, + [&](int n, int h, int w) { + auto iidx = input_d.blk_off(n, 0, h, w); + auto oidx = output_d.blk_off(n, 0, h, w); + + auto i = &input[iidx]; + auto o = &output[oidx / nbits]; + ker(i, o); + }); + + return status::success; + } +}; + +template +struct simple_reorder_impl::type> +{ + PLAIN_TO_BLOCKED_IS_APPLICABLE(); + + GET_SCRATCHPAD_SIZE_ZERO(); + + static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { + DECLARE_COMMON_PARAMS(); + + static constexpr bool w_groups = false; + constexpr int blksize_o = tag_o == format_tag::OIhw8o32i ? 8 : 16; + constexpr int blksize_i = 32; + + const auto &dims = input_d.dims(); + const auto &pdims = order_keep + ? output_d.padded_dims() + : input_d.padded_dims(); + + const int G = w_groups ? dims[0] : 1; + const int OC = dims[w_groups + 0]; + const int NB_OC = pdims[w_groups + 0] / blksize_o; + const int IC = dims[w_groups + 1]; + const int NB_IC = pdims[w_groups + 1] / blksize_i; + const int H = dims[w_groups + 2]; + const int W = dims[w_groups + 3]; + + constexpr int i_mult_o = blksize_o; + constexpr int i_mult_i = blksize_i; + constexpr int nbits = 8; + + auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t { + return (uint8_t) ((val >> bit) & 0x0001); + }; + + parallel_nd(G, NB_OC, NB_IC, H, W, + [&](int g, int nb_oc, int nb_ic, int h, int w) { + const int oc_block = nstl::min(blksize_o, OC - nb_oc * blksize_o); + const int ic_block = nstl::min(blksize_i, IC - nb_ic * blksize_i); + + for (int oc = 0; oc < oc_block; ++oc) { + for (int icb = 0; icb < utils::div_up(ic_block, nbits); ++icb) { + + uint8_t bin_val = 0x00; + for (int ic = icb*nbits, shift = 0; ic < std::min(IC, (icb + 1)*nbits); ic++, shift++) { + size_t iidx = (i_mult_o * nb_oc + oc) * input_d.blocking_desc().strides[0] + + (i_mult_i * nb_ic + ic) * input_d.blocking_desc().strides[1] + + h * input_d.blocking_desc().strides[2] + + w; + + uint8_t bit = extract_bit(input[iidx / nbits], (uint8_t)(iidx % nbits)); + bin_val |= (bit << shift); + } + + size_t oidx = output_d.blk_off(g, nb_oc, nb_ic, h, w) + oc * blksize_i + icb * nbits; + output[oidx / nbits] = bin_val; + + } + } + }); + + return status::success; + } +}; + +template +struct simple_reorder_impl::block_dims == bd::_AB && + utils::one_of(type_i, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) && + type_i == type_o>::type> +{ + static status_t is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + if (!(!input_d.has_runtime_dims_or_strides() && + simple_attr_check(attr, false, true) && + (order_keep ? 
output_d.matches_tag(tag_o) && input_d.is_plain() + : input_d.matches_tag(tag_o) && output_d.is_plain()))) + return status::invalid_arguments; + + if (output_d.blocking_desc().inner_nblks != 3 || + !utils::one_of(output_d.blocking_desc().inner_blks[2], 2, 4) || + output_d.blocking_desc().inner_idxs[2] != 1) + return status::invalid_arguments; + + return status::success; + } + + GET_SCRATCHPAD_SIZE_ZERO(); + + static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { + DECLARE_COMMON_PARAMS(); + + int blksize_o = 1; + int blksize_i = 1; + + for (int i = 0; i < output_d.blocking_desc().inner_nblks; i++) { + if (output_d.blocking_desc().inner_idxs[i] == 0) + blksize_o *= output_d.blocking_desc().inner_blks[i]; + else + blksize_i *= output_d.blocking_desc().inner_blks[i]; + } + + const auto &dims = input_d.dims(); + const auto &pdims = order_keep + ? output_d.padded_dims() + : input_d.padded_dims(); + + const int OC = dims[0]; + const int NB_OC = pdims[0] / blksize_o; + const int IC = dims[1]; + const int NB_IC = pdims[1] / blksize_i; + + int i_mult_o = blksize_o; + int i_mult_i = blksize_i; + + auto extract_half_byte = [&](uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 4 : 0; + + return (uint8_t) ((val >> shift) & 0x000F); + }; + + auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 0 : 4; + return dst | (uint8_t) (val << shift); + }; + + if (output_d.blocking_desc().inner_blks[2] == 4) { + parallel_nd(NB_OC, NB_IC, + [&](int nb_oc, int nb_ic) { + const int oc_block = nstl::min(blksize_o, OC - nb_oc * blksize_o); + const int ic_block = nstl::min(blksize_i, IC - nb_ic * blksize_i); + + for (int icb = 0; icb < utils::div_up(ic_block, 8); ++icb) { + for (int oc = 0; oc < oc_block; ++oc) { + const int ic_int_block = nstl::min(8, ic_block - icb * 8); + for (int ic = 0; ic < ic_int_block; ++ic) { + size_t iidx = (i_mult_o * nb_oc + oc) * input_d.blocking_desc().strides[0] + + (i_mult_i * nb_ic + icb * 8 + ic) * input_d.blocking_desc().strides[1]; + size_t oidx = output_d.blk_off(nb_oc, nb_ic) + icb * blksize_o * 8 + oc * 8 + 2 * (ic % 4) + ic / 4; + const uint8_t* packed_val = reinterpret_cast(input); + auto src_val = extract_half_byte(packed_val[iidx / 2], (uint8_t)(iidx % 2)); + uint8_t* output_val = reinterpret_cast(output); + uint8_t dst_val = oidx % 2 == 0 ? 0 : output_val[oidx / 2]; + dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(oidx % 2)); + output_val[oidx / 2] = dst_val; + } + } + } + }); + } else { + parallel_nd(NB_OC, NB_IC, + [&](int nb_oc, int nb_ic) { + const int oc_block = nstl::min(blksize_o, OC - nb_oc * blksize_o); + const int ic_block = nstl::min(blksize_i, IC - nb_ic * blksize_i); + + for (int icb = 0; icb < utils::div_up(ic_block, 2); ++icb) { + for (int oc = 0; oc < oc_block; ++oc) { + for (int ic = 0; ic < 2; ++ic) { + size_t iidx = (i_mult_o * nb_oc + oc) * input_d.blocking_desc().strides[0] + + (i_mult_i * nb_ic + icb *2 + ic) * input_d.blocking_desc().strides[1]; + size_t oidx = output_d.blk_off(nb_oc, nb_ic) + icb * blksize_o * 2 + oc * 2 + ic; + const uint8_t* packed_val = reinterpret_cast(input); + auto src_val = extract_half_byte(packed_val[iidx / 2], (uint8_t)(iidx % 2)); + uint8_t* output_val = reinterpret_cast(output); + uint8_t dst_val = ic == 1 ? 
output_val[oidx / 2] : 0; + dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(oidx % 2)); + output_val[oidx / 2] = dst_val; + } + } + } + }); + } + + return status::success; + } +}; + +template +struct simple_reorder_impl::type> { + static status_t is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + if (!input_d.has_runtime_dims_or_strides() + && input_d.is_dense() && output_d.is_dense() + && simple_attr_check(attr, false, true)) { + return status::success; + } + return status::invalid_arguments; + } + + GET_SCRATCHPAD_SIZE_ZERO(); + + static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { + DECLARE_COMMON_PARAMS(); + using namespace utils; + + input += input_d.blk_off(0); + output += output_d.blk_off(0); + + const dim_t work_amount = input_d.nelems(); + + auto extract_half_byte = [&](uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 4 : 0; + + return (uint8_t)((val >> shift) & 0x000F); + }; + + parallel(0, [&](const int ithr, const int nthr) { + dim_t start {0}, end {0}; + balance211(work_amount, nthr, ithr, start, end); + if (utils::one_of(type_i, dnnl_s4, dnnl_u4)) { + PRAGMA_OMP_SIMD() + for (dim_t idx = start; idx < end; idx++) { + const auto i_off = input_d.off_l(idx); + const auto o_off = output_d.off_l(idx); + const int8_t src_val = extract_half_byte(input[i_off / 2], i_off % 2); + output[o_off] = _qz_a1b0()(src_val); + } + } else { + static const std::array lookup = {-1.0f, + -0.6961928009986877f, + -0.5250730514526367f, + -0.39491748809814453f, + -0.28444138169288635f, + -0.18477343022823334f, + -0.09105003625154495f, + 0.0f, + 0.07958029955625534f, + 0.16093020141124725f, + 0.24611230194568634f, + 0.33791524171829224f, + 0.44070982933044434f, + 0.5626170039176941f, + 0.7229568362236023f, + 1.0f}; + + PRAGMA_OMP_SIMD() + for (dim_t idx = start; idx < end; idx++) { + const auto i_off = input_d.off_l(idx); + const auto o_off = output_d.off_l(idx); + const uint8_t idx_val = extract_half_byte(input[i_off / 2], i_off % 2); + output[o_off] = lookup[idx_val]; + } + } + }); + + return status::success; + } +}; + /* generic and direct-copy reorders */ template @@ -1880,12 +2379,21 @@ struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { - return !input_d.has_runtime_dims_or_strides() - && input_d.similar_to(output_d, true, false, 0) - && input_d.is_dense() && output_d.is_dense() - && simple_attr_check(attr, false, true); + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); + + VDISPATCH_REORDER_IC( + simple_attr_check(attr, false, true), VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_REORDER_IC( + input_d.is_dense(), VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src"); + VDISPATCH_REORDER_IC( + output_d.is_dense(), VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "dst"); + VDISPATCH_REORDER_IC(input_d.similar_to(output_d, true, false, 0), + VERBOSE_TENSOR_FORMAT_MISMATCH, "src", "dst"); + + return status::success; } GET_SCRATCHPAD_SIZE_ZERO(); @@ -1911,25 +2419,26 @@ struct simple_reorder_impl, data_t>()( - input[e]); + output[e] + = q10n::qz_a1b0_t, data_t>()( + input[e]); } } else if (alpha == 1.0) { PRAGMA_OMP_SIMD() for (size_t e = start; e < end; ++e) { - output[e] = q10n::qz_a1, data_t>()( + output[e] = q10n::qz_a1_t, data_t>()( input[e], output[e], beta); } } else if (beta 
== 0.0) { PRAGMA_OMP_SIMD() for (size_t e = start; e < end; ++e) { - output[e] = q10n::qz_b0, data_t>()( + output[e] = q10n::qz_b0_t, data_t>()( input[e], alpha); } } else { PRAGMA_OMP_SIMD() for (size_t e = start; e < end; ++e) { - output[e] = q10n::qz, data_t>()( + output[e] = q10n::qz_t, data_t>()( input[e], output[e], alpha, beta); } } @@ -1938,28 +2447,27 @@ struct simple_reorder_impl, + output[e] = q10n::qz_a1b0_t, data_t>()(input[e]); } } else if (alpha == 1.0) { PRAGMA_OMP_SIMD() for (size_t e = nelems - rem_elems; e < nelems; ++e) { - output[e] - = q10n::qz_a1, data_t>()( - input[e], output[e], beta); + output[e] = q10n::qz_a1_t, + data_t>()(input[e], output[e], beta); } } else if (beta == 0.0) { PRAGMA_OMP_SIMD() for (size_t e = nelems - rem_elems; e < nelems; ++e) { - output[e] - = q10n::qz_b0, data_t>()( - input[e], alpha); + output[e] = q10n::qz_b0_t, + data_t>()(input[e], alpha); } } else { PRAGMA_OMP_SIMD() for (size_t e = nelems - rem_elems; e < nelems; ++e) { - output[e] = q10n::qz, data_t>()( - input[e], output[e], alpha, beta); + output[e] + = q10n::qz_t, data_t>()( + input[e], output[e], alpha, beta); } } } @@ -1971,13 +2479,23 @@ struct simple_reorder_impl struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { - return !input_d.has_runtime_dims_or_strides() && input_d.is_dense() - && output_d.is_dense() && simple_attr_check(attr, false, true); + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); + + VDISPATCH_REORDER_IC( + simple_attr_check(attr, false, true), VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_REORDER_IC( + input_d.is_dense(), VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src"); + VDISPATCH_REORDER_IC( + output_d.is_dense(), VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "dst"); + + return status::success; } static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, @@ -2024,19 +2542,19 @@ struct simple_reorder_impl()(wspace[i0_off]); + + const auto i1_off + = need_transform ? idx + 1 : input_d.off_l(idx + 1); + auto val1 = _qz_a1b0()(wspace[i1_off]); + const auto o_off = need_transform ? idx : output_d.off_l(idx); - const auto shift = i % 2 ? int4_extract_t::high_half - : int4_extract_t::low_half; - auto src_val - = _qz_a1b0()(wspace[i_off]); - const uint8_t dst_val = i == 0 - ? 
0 - : reinterpret_cast(output)[o_off / 2]; - output[o_off / 2] = src_val.insert(dst_val, shift); + nibble2_t o_val(val0.raw_bits_, val1.raw_bits_); + reinterpret_cast(output)[o_off / 2] = o_val.get(); } }); @@ -2048,13 +2566,33 @@ template struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { - return !input_d.has_runtime_dims_or_strides() && input_d.is_dense() - && output_d.is_dense() && simple_attr_check(attr, false, true); + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); + + VDISPATCH_REORDER_IC( + input_d.nelems() % 2 == 0, "Unsupported dimensions"); + VDISPATCH_REORDER_IC( + input_d.is_dense(), VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src"); + VDISPATCH_REORDER_IC( + output_d.is_dense(), VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "dst"); + + using smask_t = primitive_attr_t::skip_mask_t; + smask_t skip_mask = smask_t::scales_data_type | smask_t::scales_groups + | smask_t::zero_points_data_type | smask_t::zero_points_groups; + VDISPATCH_REORDER_IC( + attr->has_default_values(skip_mask), VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_REORDER_IC(attr->scales_.has_default_values(DNNL_ARG_DST), + VERBOSE_UNSUPPORTED_SCALES_CFG); + return status::success; } static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, @@ -2069,53 +2607,112 @@ struct simple_reorder_impl *wspace = scratchpad.template get>( memory_tracking::names::key_reorder_space); - // When formats of the input and the output are not identical, the idea - // is to reorder the data from the input format to the output format - // but within the same data type, and after the format reorder apply - // the compression into int4 as on `abx` format. + // The implementation splits data conversion and format conversion in + // two passes for cases when it's not straightforward to perform both + // at once. The second pass applicability is determined by: + // * Transformation between incompatible formats is needed, especially + // when int4 source in not dense in the last dimension... const bool need_transform = input_d.strides()[input_d.ndims() - 1] != 1; - wspace = need_transform ? wspace : output; + // * Post-processing, including advanced dequantization parameters as + // groups. + const auto &scales = pd->attr()->scales_; + const bool has_src_scales = !scales.has_default_values(DNNL_ARG_SRC); + const auto &zps = pd->attr()->zero_points_; + const bool has_src_zps = !zps.has_default_values(DNNL_ARG_SRC); + + const bool need_second_pass + = need_transform || has_src_scales || has_src_zps; + wspace = need_second_pass ? wspace : output; // To avoid clashes between threads each byte (or 2 elements) // is handled by a single thread const dim_t work_amount = input_d.nelems() / 2; parallel(0, [&](const int ithr, const int nthr) { + auto u8_input = reinterpret_cast(input); dim_t start {0}, end {0}; balance211(work_amount, nthr, ithr, start, end); PRAGMA_OMP_SIMD() - for_(dim_t j = start; j < end; j++) - for (int i = 0; i < 2; ++i) { - const auto idx = 2 * j + i; - const auto i_off = need_transform ? idx : input_d.off_l(idx); - const auto o_off = need_transform ? idx : output_d.off_l(idx); - const auto shift = i % 2 ? 
int4_extract_t::high_half - : int4_extract_t::low_half; - auto src_val = data_t::extract( - reinterpret_cast(input)[i_off / 2], - shift); - reinterpret_cast *>(wspace)[o_off] - = static_cast(src_val); + for (dim_t j = start; j < end; j++) { + const auto idx = 2 * j; + const auto i_off = need_second_pass ? idx : input_d.off_l(idx); + const nibble2_t in_nibble(u8_input[i_off / 2]); + + for (int i = 0; i < 2; ++i) { + const auto o_off = need_second_pass + ? idx + i + : output_d.off_l(idx + i); + data_t src_val(in_nibble.get(i)); + reinterpret_cast *>(wspace)[o_off] + = static_cast(src_val); + } } }); - if (need_transform) { - const dim_t work_amount = output_d.nelems(); - parallel(0, [&](const int ithr, const int nthr) { - dim_t start {0}, end {0}; - balance211(work_amount, nthr, ithr, start, end); - PRAGMA_OMP_SIMD() - for (dim_t idx = start; idx < end; idx++) { - const auto i_off = input_d.off_l(idx); - const auto o_off = output_d.off_l(idx); - output[o_off] = wspace[i_off]; - } - }); + if (!need_second_pass) return status::success; + + const int ndims = input_d.ndims(); + // Applied to the pre-last dimension. + const auto src_scales_group0 = scales.get_group(DNNL_ARG_SRC, 0); + // Applied to the last dimension. + const auto src_scales_group1 = scales.get_group(DNNL_ARG_SRC, 1); + + memory_desc_t src_scales_md {}; + if (has_src_scales) { + get_quant_md(src_scales_md, ndims, input_d.dims(), src_scales_mask, + src_scales_group0, src_scales_group1, + src_scales_d.data_type()); } + int src_zps_mask = zps.get_mask(DNNL_ARG_SRC); + // Applied to the pre-last dimension. + const auto src_zps_group0 = zps.get_group(DNNL_ARG_SRC, 0); + // Applied to the last dimension. + const auto src_zps_group1 = zps.get_group(DNNL_ARG_SRC, 1); + memory_desc_t src_zps_md {}; + if (has_src_zps) { + get_quant_md(src_zps_md, ndims, input_d.dims(), src_zps_mask, + src_zps_group0, src_zps_group1, src_zps_d.data_type()); + } + + parallel_nd(input_d.nelems(), [&](dim_t idx) { + // Must be per thread; when shared, race condition happens. + dims_t input_idx {}; + float src_scale = 1.f; + if (has_src_scales || has_src_zps) { + utils::l_dims_by_l_offset( + input_idx, idx, input_d.dims(), ndims); + } + if (has_src_scales) { + const dim_t src_scales_off = get_quant_off(input_idx, ndims, + src_scales_mask, src_scales_group0, src_scales_group1, + src_scales_md); + // A single scale has already been pre-processed by the + // library-managed macros. + src_scale = src_scales_d.nelems() == 1 + ? src_scales[0] + : io::load_float_value(src_scales_d.data_type(), + src_scales, src_scales_off); + } + + int src_zp_val = 0; // Avoid clashing with the one defined for rest. 
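+ // Note on the grouped-quantization lookup below: with group sizes
+ // (group0, group1) over the last two dimensions, one scale/zero-point
+ // serves a group0 x group1 tile of elements, so the parameter offset
+ // divides the element coordinate by the group size along each grouped
+ // dimension. A minimal sketch of the idea (illustrative only;
+ // quant_off_2d is a hypothetical helper, the real code goes through
+ // get_quant_md() and get_quant_off()):
+ //     dim_t quant_off_2d(dim_t r, dim_t c, dim_t g0, dim_t g1,
+ //             dim_t row_stride) {
+ //         return (r / g0) * row_stride + (c / g1);
+ //     }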
+ if (has_src_zps) { + const dim_t src_zps_off + = get_quant_off(input_idx, ndims, src_zps_mask, + src_zps_group0, src_zps_group1, src_zps_md); + src_zp_val = io::load_float_value( + src_zps_d.data_type(), src_zero_points, src_zps_off); + } + + const auto i_off = input_d.off_l(idx); + const auto o_off = output_d.off_l(idx); + output[o_off] = src_scale * (wspace[i_off] - src_zp_val); + }); + return status::success; } }; @@ -2126,15 +2723,24 @@ struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); + auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) { return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d); }; - return !input_d.has_runtime_dims_or_strides() - && input_d.similar_to(output_d, true, false, 1) - && is_dense_no_0(input_d) && is_dense_no_0(output_d) - && simple_attr_check(attr, false, true); + VDISPATCH_REORDER_IC( + simple_attr_check(attr, false, true), VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_REORDER_IC(is_dense_no_0(input_d), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src"); + VDISPATCH_REORDER_IC(is_dense_no_0(output_d), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "dst"); + VDISPATCH_REORDER_IC(input_d.similar_to(output_d, true, false, 1), + VERBOSE_TENSOR_FORMAT_MISMATCH, "src", "dst"); + + return status::success; } GET_SCRATCHPAD_SIZE_ZERO(); @@ -2205,11 +2811,7 @@ struct simple_reorder_impl::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { /* supported smask: 0x0...011..10...0, * i.e. 
1 should be contiguous */ @@ -2241,17 +2845,31 @@ struct simple_reorder_impl 0 && smask & 0x1; smask >>= 1) ; - if (smask != 0) return false; + VDISPATCH_REORDER_IC(smask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG); + } + + using smask_t = primitive_attr_t::skip_mask_t; + smask_t skip_mask = smask_t::scales_data_type | smask_t::scales_groups + | smask_t::zero_points_data_type | smask_t::zero_points_groups + | smask_t::post_ops; + VDISPATCH_REORDER_IC( + attr->has_default_values(skip_mask), VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_REORDER_IC(simple_po_check(attr), VERBOSE_UNSUPPORTED_POSTOP); + const auto &scales = attr->scales_; + const bool has_dst_scales = !scales.has_default_values(DNNL_ARG_DST); + if (has_dst_scales) { + VDISPATCH_REORDER_IC(scales.has_default_data_type(DNNL_ARG_DST) + && scales.has_default_groups(DNNL_ARG_DST), + VERBOSE_UNSUPPORTED_SCALES_CFG); } + VDISPATCH_REORDER_IC( + input_d.is_blocking_desc() && !input_d.is_additional_buffer(), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src"); + VDISPATCH_REORDER_IC( + output_d.is_blocking_desc() && !output_d.is_additional_buffer(), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "dst"); - using skip_mask_t = dnnl_primitive_attr::skip_mask_t; - return input_d.is_blocking_desc() && output_d.is_blocking_desc() - && !output_d.is_additional_buffer() - && !input_d.is_additional_buffer() - && attr->has_default_values(skip_mask_t::scales_runtime - | skip_mask_t::zero_points_runtime - | skip_mask_t::post_ops) - && simple_po_check(attr); + return status::success; } GET_SCRATCHPAD_SIZE_ZERO(); @@ -2264,23 +2882,83 @@ struct simple_reorder_impl()(f); - }); + const int ndims = input_d.ndims(); + const auto &scales = pd->attr()->scales_; + const bool has_src_scales = !scales.has_default_values(DNNL_ARG_SRC); + // Applied to the pre-last dimension. + const auto src_scales_group0 = scales.get_group(DNNL_ARG_SRC, 0); + // Applied to the last dimension. + const auto src_scales_group1 = scales.get_group(DNNL_ARG_SRC, 1); + memory_desc_t src_scales_md {}; + if (has_src_scales) { + get_quant_md(src_scales_md, ndims, input_d.dims(), src_scales_mask, + src_scales_group0, src_scales_group1, + src_scales_d.data_type()); + } + const bool has_dst_scales = !scales.has_default_values(DNNL_ARG_DST); + memory_desc_t dst_scales_md {}; + if (has_dst_scales) { + get_quant_md(dst_scales_md, ndims, input_d.dims(), dst_scales_mask, + 1, 1, data_type::f32); + } + + const auto &zps = pd->attr()->zero_points_; + int src_zps_mask = zps.get_mask(DNNL_ARG_SRC); + const bool has_src_zps = !zps.has_default_values(DNNL_ARG_SRC); + // Applied to the pre-last dimension. + const auto src_zps_group0 = zps.get_group(DNNL_ARG_SRC, 0); + // Applied to the last dimension. + const auto src_zps_group1 = zps.get_group(DNNL_ARG_SRC, 1); + memory_desc_t src_zps_md {}; + if (has_src_zps) { + get_quant_md(src_zps_md, ndims, input_d.dims(), src_zps_mask, + src_zps_group0, src_zps_group1, src_zps_d.data_type()); + } + + parallel_nd(input_d.nelems(), [&](dim_t idx) { + // Must be per thread; when shared, race condition happens. + dims_t input_idx {}; + float src_scale = 1.f; + if (has_src_scales || has_dst_scales || has_src_zps) { + utils::l_dims_by_l_offset( + input_idx, idx, input_d.dims(), ndims); + } + if (has_src_scales) { + const dim_t src_scales_off = get_quant_off(input_idx, ndims, + src_scales_mask, src_scales_group0, src_scales_group1, + src_scales_md); + // A single scale has already been pre-processed by the + // library-managed macros. + src_scale = src_scales_d.nelems() == 1 + ? 
src_scales[0] + : io::load_float_value(src_scales_d.data_type(), + src_scales, src_scales_off); + } + + float dst_scale = 1.f; + if (has_dst_scales) { + const dim_t dst_scales_off = get_quant_off( + input_idx, ndims, dst_scales_mask, 1, 1, dst_scales_md); + dst_scale = dst_scales[dst_scales_off]; + } + + int src_zp_val = 0; // Avoid clashing with the one defined for rest. + if (has_src_zps) { + const dim_t src_zps_off + = get_quant_off(input_idx, ndims, src_zps_mask, + src_zps_group0, src_zps_group1, src_zps_md); + src_zp_val = io::load_float_value( + src_zps_d.data_type(), src_zero_points, src_zps_off); + } + + const auto i_off = input_d.off_l(idx); + const auto o_off = output_d.off_l(idx); + float d = src_scale * (input[i_off] - src_zp_val); + if (beta) d += beta * output[o_off]; + d = d * dst_scale + dst_zp; + output[o_off] = _qz_a1b0()(d); + }); return status::success; } }; @@ -2299,24 +2977,36 @@ struct simple_reorder_t : public primitive_t { const primitive_attr_t *attr, engine_t *src_engine, const memory_desc_t *src_md, engine_t *dst_engine, const memory_desc_t *dst_md) { - using skip_mask_t = dnnl_primitive_attr::skip_mask_t; - bool args_ok = impl::is_dense_format_kind({src_md, dst_md}) - && src_md->data_type == type_i - && dst_md->data_type == type_o - && attr->has_default_values(skip_mask_t::scales_runtime - | skip_mask_t::zero_points - | skip_mask_t::zero_points_runtime - | skip_mask_t::post_ops) - && simple_reorder_impl::is_applicable(src_md, dst_md, attr); - if (!args_ok) return status::invalid_arguments; + // Since `type_i` and `type_o` are templated arguments, no need + // to put them under verbose_dispatch logic. + bool ok = src_md->data_type == type_i + && dst_md->data_type == type_o; + if (!ok) return status::invalid_arguments; + + VDISPATCH_REORDER_IC(impl::is_dense_format_kind({src_md, dst_md}), + VERBOSE_UNSUPPORTED_SPARSE_CFG); + + using skip_mask_t = primitive_attr_t::skip_mask_t; + VDISPATCH_REORDER_IC( + attr->has_default_values(skip_mask_t::scales_data_type + | skip_mask_t::scales_groups + | skip_mask_t::zero_points_data_type + | skip_mask_t::zero_points_groups + | skip_mask_t::post_ops), + VERBOSE_UNSUPPORTED_ATTR); + + auto status = simple_reorder_impl::is_applicable(src_md, dst_md, attr); + if (status != status::success) return status; - int mask = -1; - bool is_set = false; - CHECK(attr->scales_.get(DNNL_ARG_DST, &mask, &is_set)); const memory_desc_wrapper input_d(src_md); - if (input_d.has_runtime_dims_or_strides() && is_set && mask > 0) - return status::unimplemented; + + int mask = -1; + if (!attr->scales_.has_default_values(DNNL_ARG_DST)) { + mask = attr->scales_.get_mask(DNNL_ARG_DST); + if (input_d.has_runtime_dims_or_strides() && mask > 0) + return status::unimplemented; + } auto _pd = make_unique_pd(attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); @@ -2330,7 +3020,7 @@ struct simple_reorder_t : public primitive_t { scratchpad.book(memory_tracking::names::key_reorder_space, scratchpad_sz_, 1, 16); - if (is_set && mask > 0) { + if (mask > 0) { dim_t D_mask; _pd->get_D_values(input_d, mask, nullptr, &D_mask, nullptr); scratchpad.template book( diff --git a/src/cpu/reorder/simple_sparse_reorder.hpp b/src/cpu/reorder/simple_sparse_reorder.hpp index 2eb0cd4c203..b65dbefa97c 100644 --- a/src/cpu/reorder/simple_sparse_reorder.hpp +++ b/src/cpu/reorder/simple_sparse_reorder.hpp @@ -1,37 +1,38 @@ /******************************************************************************* -* Copyright 2023 Intel Corporation +* Copyright 2023-2025 Intel 
Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 +* * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ + #ifndef CPU_REORDER_SIMPLE_SPARSE_REORDER_HPP #define CPU_REORDER_SIMPLE_SPARSE_REORDER_HPP -#include -#include - #include + +#include "simple_reorder.hpp" + #include "common/c_types_map.hpp" #include "common/dnnl_thread.hpp" #include "common/math_utils.hpp" #include "common/primitive.hpp" -#include "common/reorder.hpp" - #include "common/primitive_attr.hpp" -#include "common/stream.hpp" #include "common/tag_traits.hpp" #include "common/type_helpers.hpp" #include "common/utils.hpp" + #include "cpu/cpu_primitive.hpp" #include "cpu/reorder/cpu_reorder_pd.hpp" + #include "cpu/simple_q10n.hpp" namespace dnnl { @@ -40,9 +41,6 @@ namespace cpu { // The following cases can be covered: // -// Note: `sparse_tag` is a regular format tag describing -// a regular tensor with sparse data. -// // - sparse_tag -> sparse_tag // - encoding -> encoding // @@ -55,10 +53,23 @@ namespace cpu { // - dense_tag -> encoding // - encoding -> dense_tag #define SIMPLE_SPARSE_REORDER_TEMPL_DECL \ - impl::data_type_t type_i, typename fmt_i_t, fmt_i_t fmt_i, \ - impl::data_type_t type_o, typename fmt_o_t, fmt_o_t fmt_o + impl::data_type_t type_i, format_tag_t fmt_i, \ + impl::data_type_t type_o, format_tag_t fmt_o, \ + bool order_keep + #define SIMPLE_SPARSE_REORDER_TEMPL_CALL \ - type_i, fmt_i_t, fmt_i, type_o, fmt_o_t, fmt_o + type_i, fmt_i, type_o, fmt_o, order_keep + +// TODO: move common code to reorder_utils.hpp. +namespace sparse_spec { +struct reference {}; +} // namespace sparse_spec + +namespace sparse_inputs_order { +constexpr bool keep = true; +constexpr bool reverse = false; +constexpr bool any = keep; +} // namespace sparse_inputs_order template struct simple_sparse_reorder_impl {}; @@ -66,113 +77,132 @@ struct simple_sparse_reorder_impl {}; namespace { template constexpr bool is_format_tag(T) { - return std::is_same::value; + return std::is_same::value ? true : false; } } // namespace +using namespace data_type; + +// TODO: think about combining compression reorders with sparse reorders. +/* specific reorders: IP compression */ template struct simple_sparse_reorder_impl::type> { + && (fmt_i == format_tag::oi + || fmt_i == format_tag::io)) + && (is_format_tag(fmt_o) + && fmt_o == format_tag::OI16i64o4i), + sparse_spec::reference>::type> { - static bool is_applicable(const memory_desc_wrapper &input_d, + static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { - // This reorder expects a non-plain format for destination. 
- return input_d.is_blocking_desc() && output_d.is_sparse_desc() - && output_d.sparse_desc().encoding == sparse_encoding::packed - && output_d.blocking_desc().inner_nblks > 0 - && output_d.blk_size() % 64 == 0; - } - - static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, - const memory_desc_wrapper &output_d) { - const auto nelems = output_d.nelems(true); - const auto tmp_output_sz = nelems * output_d.data_type_size(); - const auto nnz_per_blocks_sz - = nelems / output_d.blk_size() * sizeof(dim_t); - return tmp_output_sz + nnz_per_blocks_sz; - } - static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx, - const std::shared_ptr<primitive_t> &reorder) { - auto output_values = CTX_OUT_MEM(data_t<type_o> *, DNNL_ARG_TO, 0); - auto output_offsets = CTX_OUT_MEM(int64_t *, DNNL_ARG_TO, 1); - auto output_bitmask = CTX_OUT_MEM(uint64_t *, DNNL_ARG_TO, 2); + VDISPATCH_REORDER_IC(!input_d.has_runtime_dims_or_strides(), + VERBOSE_RUNTIMEDIM_UNSUPPORTED); + const size_t D_mask = utils::array_product( + input_d.dims(), math::ilog2q(attr->scales_.get_mask(DNNL_ARG_SRC) - INT_MIN + 1)); + const size_t oc = (input_d.dims()[0]); + VDISPATCH_REORDER_IC(input_d.matches_tag(fmt_i), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src"); + VDISPATCH_REORDER_IC(output_d.matches_tag(fmt_o), + VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "dst"); + // VDISPATCH_REORDER_IC( + // input_d.is_blocking_desc(), VERBOSE_UNSUPPORTED_FORMAT_KIND); + VDISPATCH_REORDER_IC( + output_d.is_sparse_desc(), VERBOSE_UNSUPPORTED_FORMAT_KIND); + VDISPATCH_REORDER_IC( + output_d.sparse_desc().encoding == sparse_encoding::packed, + VERBOSE_UNSUPPORTED_FEATURE, + "only sparse_encoding::packed is supported for dst"); + // VDISPATCH_REORDER_IC(output_d.blocking_desc().inner_nblks > 0, + // VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "dst"); + // VDISPATCH_REORDER_IC(output_d.blk_size() % 64 == 0, + // VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "dst"); + VDISPATCH_REORDER_IC(utils::one_of(input_d.data_type(), data_type::f32, data_type::s8), + VERBOSE_UNSUPPORTED_DT, "src"); + VDISPATCH_REORDER_IC(utils::one_of(output_d.data_type(), data_type::s8) && (D_mask == 1 || D_mask == oc), + VERBOSE_UNSUPPORTED_DT, "dst"); - engine_t *engine = ctx.stream()->engine(); - const auto scratchpad = ctx.get_scratchpad_grantor(); - auto wspace_mem_storage = scratchpad.get_memory_storage( - memory_tracking::names::key_reorder_space); - memory_t wspace_mem( - engine, reorder->pd()->dst_md(), std::move(wspace_mem_storage)); - - exec_args_t r_args; - r_args[DNNL_ARG_SRC] = ctx.args().at(DNNL_ARG_FROM); - r_args[DNNL_ARG_DST] = {&wspace_mem, false}; - exec_ctx_t r_ctx(ctx, std::move(r_args)); + return status::success; + } - nested_scratchpad_t ns( - ctx, memory_tracking::names::key_nested, reorder); - r_ctx.set_scratchpad_grantor(ns.grantor()); - reorder->execute(r_ctx); + GET_SCRATCHPAD_SIZE_ZERO(); - auto *wspace = scratchpad.template get<data_t<type_o>>( - memory_tracking::names::key_reorder_space); + static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { + auto input = CTX_IN_MEM(const data_t<type_i> *, DNNL_ARG_FROM); + auto output = CTX_OUT_MEM(data_t<type_o> *, DNNL_ARG_TO); + const auto input_d = ctx.memory_mdw(DNNL_ARG_FROM, pd->src_md()); const auto output_d = ctx.memory_mdw(DNNL_ARG_TO, pd->dst_md()); - const auto nelems = output_d.nelems(true); - const auto blk_sz = output_d.blk_size(); - const auto nblks = nelems / blk_sz; - - dim_t *nnz_per_blocks - = reinterpret_cast<dim_t *>(reinterpret_cast<char *>(wspace) - + nelems * output_d.data_type_size()); - - static constexpr int bitmask_step = sizeof(uint64_t) * CHAR_BIT; - // Fill output_bitmask and move non-zero elements to
the beginning of the - blocks. Also, remember number of non-zero elements per-block to - calculate output_offsets later. - parallel_nd(nblks, [&](dim_t b) { - dim_t nnz_per_blk = 0; - for (dim_t i = 0; i < blk_sz / bitmask_step; i++) { - uint64_t &bm = output_bitmask[b * blk_sz / bitmask_step + i]; - bm = 0; - for (dim_t j = 0; j < bitmask_step; j++) { - const auto v = wspace[b * blk_sz + bitmask_step * i + j]; - if (v != 0) { - wspace[b * blk_sz + nnz_per_blk++] = v; - bm |= (uint64_t(1) << j); + + const auto &input_dims = input_d.dims(); + const auto &padded_dims = output_d.padded_dims(); + constexpr int i_outer_blksize = 16; + constexpr int i_blksize = i_outer_blksize * 4; + constexpr int o_blksize = 64; + + const int OC = input_dims[0]; + const int NB_OC = padded_dims[0] / o_blksize; + const int IC = input_dims[1]; + const int NB_IC = padded_dims[1] / i_blksize; + const int plain_o_stride = input_d.blocking_desc().strides[0]; + const int plain_i_stride = input_d.blocking_desc().strides[1]; + size_t offset = padded_dims[0] * padded_dims[1]; + + int total_blocks = offset / 4096; + using comp_tile_len_type = int; + comp_tile_len_type *comp_tile_len_ptr = reinterpret_cast<comp_tile_len_type *>(output); + int comp_tile_len_index = 0; + int cl_length = 0; + // Wasting memory space due to allocating a buffer for the whole tensor? + int output_offset = ceil((float)total_blocks * sizeof(comp_tile_len_type) / 64.0) * 64; + uint64_t *bitmask_ptr = reinterpret_cast<uint64_t *>(output + output_offset + offset); + auto outp = &output[output_d.blk_off(0, 0, 0, 0) + output_offset]; + + // TODO: add threading. + for (int O = 0; O < NB_OC; O++) { + for (int I = 0; I < NB_IC; I++) { + auto inp + = &input[input_d.blk_off(o_blksize * O, i_blksize * I)]; + const int oc_block = nstl::min(o_blksize, OC - O * o_blksize); + const int ic_block = nstl::min(i_blksize, IC - I * i_blksize); + int non_zeros = 0; + int bitmask_idx = (O * NB_IC + I) * i_blksize; + comp_tile_len_ptr[comp_tile_len_index] = cl_length; + + for (int ic_base = 0; ic_base < ic_block; + ic_base += 4) { // 64, steps of 4 + bitmask_ptr[bitmask_idx] = 0; + int bit = 0; + int count = 0; + for (int oc = 0; oc < oc_block; oc++) { // 64 + if (count % 64 == 0) { + bitmask_ptr[bitmask_idx] = 0; + bit = 0; + } + int plain_off = oc * plain_o_stride + + ic_base * plain_i_stride; + int ic_block_here = nstl::min(4, ic_block - ic_base); + for (int ic = 0; ic < ic_block_here; ic++) { // 4 + data_t<type_i> o = inp[plain_off]; + if (o != 0) { + *outp++ = o; + bitmask_ptr[bitmask_idx] |= (1UL << bit); + non_zeros++; + } + plain_off += plain_i_stride; + bit++; + count++; + } + if (count % 64 == 0) { bitmask_idx++; } } } + comp_tile_len_type cl = (comp_tile_len_type)ceil(non_zeros / 64.0); + comp_tile_len_index++; + cl_length = comp_tile_len_ptr[comp_tile_len_index - 1] + cl; + int unused_bytes_in_cl = 64 - (non_zeros % 64); + if (unused_bytes_in_cl == 64) { unused_bytes_in_cl = 0; } + outp += unused_bytes_in_cl; // 64: next output starts in new cacheline } - nnz_per_blocks[b] = nnz_per_blk; - }); - - // Calculate output_offsets using previously computed number of non-zero - // elements in each block. - parallel_nd(nblks, [&](dim_t b) { - dim_t off = 0; - if (b != 0) { - for (dim_t i = 0; i < b; i++) { - off += nnz_per_blocks[i]; - } - } - output_offsets[b] = off; - }); - - // Use the calculated output_offsets and number of non-zero elements - // per block to copy the non-zero elements that we moved to the - // beginning of the blocks to output_values.
- parallel_nd(nblks, [&](dim_t b) { - const auto nnz_per_blk = nnz_per_blocks[b]; - const auto blk_off = output_offsets[b]; - for (dim_t i = 0; i < nnz_per_blk; i++) { - output_values[blk_off + i] = wspace[b * blk_sz + i]; - } - }); - + } return status::success; } }; @@ -181,9 +211,8 @@ template struct simple_sparse_reorder_t : public primitive_t { struct pd_t : public cpu_reorder_pd_t { using cpu_reorder_pd_t::cpu_reorder_pd_t; - DECLARE_COMMON_PD_T("simple::any", simple_sparse_reorder_t); - std::shared_ptr reorder_pd_; + DECLARE_COMMON_PD_T("simple_sparse:any", simple_sparse_reorder_t); private: static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, @@ -191,39 +220,24 @@ struct simple_sparse_reorder_t : public primitive_t { const memory_desc_t *src_md, engine_t *dst_engine, const memory_desc_t *dst_md) { - const bool args_ok = src_md->data_type == type_i + const bool ok = src_md->data_type == type_i && dst_md->data_type == type_o + && attr->has_default_values() && simple_sparse_reorder_impl< - SIMPLE_SPARSE_REORDER_TEMPL_CALL>:: - is_applicable(src_md, dst_md, attr); - if (!args_ok) return status::invalid_arguments; + SIMPLE_SPARSE_REORDER_TEMPL_CALL, + spec>::is_applicable(src_md, dst_md, attr) == status::success; + if (!ok) return status::invalid_arguments; - auto _pd = make_unique_pd(attr, src_engine->kind(), src_md, + auto _pd = new pd_t(attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); if (_pd == nullptr) return status::out_of_memory; - CHECK(_pd->init(engine, src_engine, dst_engine)); - - CHECK(_pd->init_scratchpad_md()); - return safe_ptr_assign(*reorder_pd, _pd.release()); - } + if (_pd->init(engine, src_engine, dst_engine) != status::success) { + delete _pd; + return status::unimplemented; + } - status_t init( - engine_t *engine, engine_t *src_engine, engine_t *dst_engine) { - // Convert sparse packed desc to blocking desc. 
- auto converted_dst_md = cvt_sparse_packed2blocked(*this->dst_md()); - - CHECK(reorder_primitive_desc_create( - reorder_pd_, engine, src_md(), &converted_dst_md, attr())); - - const size_t scratchpad_sz_ = simple_sparse_reorder_impl< - SIMPLE_SPARSE_REORDER_TEMPL_CALL>:: - get_scratchpad_size(src_md(), dst_md()); - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.book(memory_tracking::names::key_reorder_space, - scratchpad_sz_, 1, 16); - scratchpad.book(memory_tracking::names::key_nested, - reorder_pd_->scratchpad_registry()); - return status::success; + _pd->init_scratchpad_md(); + return safe_ptr_assign(*reorder_pd, _pd); } friend dnnl::impl::impl_list_item_t; @@ -231,18 +245,13 @@ struct simple_sparse_reorder_t : public primitive_t { simple_sparse_reorder_t(const pd_t *apd) : primitive_t(apd) {} - status_t init(engine_t *engine) override { - return pd()->reorder_pd_->create_primitive(reorder_, engine); - } - status_t execute(const exec_ctx_t &ctx) const override { - return simple_sparse_reorder_impl< - SIMPLE_SPARSE_REORDER_TEMPL_CALL>::execute(pd(), ctx, reorder_); + return simple_sparse_reorder_impl::execute(pd(), ctx); } private: const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - std::shared_ptr reorder_; }; #undef SIMPLE_SPARSE_REORDER_TEMPL_DECL @@ -251,4 +260,5 @@ struct simple_sparse_reorder_t : public primitive_t { } // namespace cpu } // namespace impl } // namespace dnnl + #endif diff --git a/src/cpu/rnn/postgemm_dispatcher.hpp b/src/cpu/rnn/postgemm_dispatcher.hpp index 1c38e44d2fa..be43824f7aa 100644 --- a/src/cpu/rnn/postgemm_dispatcher.hpp +++ b/src/cpu/rnn/postgemm_dispatcher.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
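The hunk that follows swaps the dispatcher's typedef declarations for using-aliases over prec_traits_t. The underlying mechanism is a compile-time map from a data_type_t enumerator to a concrete C type; a minimal self-contained sketch of that shape (hypothetical names, not the library's exact trait set):

#include <cstdint>

// Sketch: enum -> C type mapping in the style of prec_traits_t.
enum class dt { f32, s32, s8, u8 };
template <dt> struct prec_traits_sketch;
template <> struct prec_traits_sketch<dt::f32> { using type = float; };
template <> struct prec_traits_sketch<dt::s32> { using type = int32_t; };
template <> struct prec_traits_sketch<dt::s8> { using type = int8_t; };
template <> struct prec_traits_sketch<dt::u8> { using type = uint8_t; };

template <dt d>
using data_of = typename prec_traits_sketch<d>::type; // data_of<dt::s8> == int8_t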
@@ -52,14 +52,14 @@ template struct rnn_postgemm_dispatcher { - typedef typename prec_traits::type src_layer_t; - typedef typename prec_traits::type src_iter_t; - typedef typename prec_traits::type dst_layer_t; - typedef typename prec_traits::type dst_iter_t; - typedef typename prec_traits::type gemm_acc_t; - typedef typename prec_traits::type scratch_t; - typedef typename prec_traits::type ht_t; - typedef typename prec_traits::type gates_t; + using src_layer_t = typename prec_traits_t::type; + using src_iter_t = typename prec_traits_t::type; + using dst_layer_t = typename prec_traits_t::type; + using dst_iter_t = typename prec_traits_t::type; + using gemm_acc_t = typename prec_traits_t::type; + using scratch_t = typename prec_traits_t::type; + using ht_t = typename prec_traits_t::type; + using gates_t = typename prec_traits_t::type; using class_name = rnn_postgemm_dispatcher; @@ -253,20 +253,25 @@ struct rnn_postgemm_dispatcher { && !mayiuse(avx512_core)) return status::success; +//NOLINTBEGIN(bugprone-macro-parentheses) +// Can't put types into `()`: +// error: expected type-specifier before ‘)’ token #define CREATE_WITH_DIR(k, ker_t) \ do { \ if (mayiuse(avx512_core)) \ - k.reset(new ker_t(rnn, pd_)); \ + (k).reset( \ + new ker_t(rnn, pd_)); \ else if (mayiuse(avx2)) \ - k.reset(new ker_t(rnn, pd_)); \ + (k).reset(new ker_t(rnn, pd_)); \ else \ - k.reset(new ker_t(rnn, pd_)); \ + (k).reset(new ker_t(rnn, pd_)); \ } while (0) #define CREATE(k, ker_t) \ do { \ - if (jit_fwd) CREATE_WITH_DIR(k, CONCAT2(ker_t, _fwd)); \ - if (jit_bwd) CREATE_WITH_DIR(k, CONCAT2(ker_t, _bwd)); \ + if (jit_fwd) CREATE_WITH_DIR((k), CONCAT2(ker_t, _fwd)); \ + if (jit_bwd) CREATE_WITH_DIR((k), CONCAT2(ker_t, _bwd)); \ } while (0) + //NOLINTEND(bugprone-macro-parentheses) if (pd_->cell_kind() == alg_kind::vanilla_lstm) { CREATE(rnn_postgemm_, jit_uni_lstm_cell_postgemm); diff --git a/src/cpu/rnn/ref_postgemm_lstm.cpp b/src/cpu/rnn/ref_postgemm_lstm.cpp index 8fef036b710..d452eb39fc2 100644 --- a/src/cpu/rnn/ref_postgemm_lstm.cpp +++ b/src/cpu/rnn/ref_postgemm_lstm.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -188,7 +188,7 @@ rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::lstm_postgemm) { const auto quantize_f32_u8 = [&](float f) { float qf = f * data_scale + data_shift; - return q10n::qz_a1b0()(qf); + return q10n::qz_a1b0_t()(qf); }; const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) { @@ -229,7 +229,7 @@ rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::lstm_postgemm) { const auto quantize_f32_s8 = [&](float f) { float qf = f * data_scale + data_shift; - return q10n::qz_a1b0()(qf); + return q10n::qz_a1b0_t()(qf); }; const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) { diff --git a/src/cpu/rnn/ref_postgemm_lstm_projection.cpp b/src/cpu/rnn/ref_postgemm_lstm_projection.cpp index 153603ecafe..5a3c728d8db 100644 --- a/src/cpu/rnn/ref_postgemm_lstm_projection.cpp +++ b/src/cpu/rnn/ref_postgemm_lstm_projection.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
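The lstm_postgemm hunks above funnel every f32-to-integer conversion through the renamed qz_a1b0_t functor. Behaviorally, an alpha = 1, beta = 0 quantizer rounds to nearest and saturates to the destination range; a stand-alone sketch of the u8 case (an illustration of the semantics, not the library's exact implementation):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Sketch: qz_a1b0-style f32 -> u8 conversion (alpha = 1, beta = 0).
static inline uint8_t quantize_f32_u8_sketch(float f) {
    float r = std::nearbyintf(f);          // round, ties-to-even by default
    r = std::min(std::max(r, 0.f), 255.f); // saturate to the u8 range
    return static_cast<uint8_t>(r);
}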
@@ -104,7 +104,7 @@ rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::lstm_projection_postgemm) { float qf = f * data_scale + data_shift; qf = nstl::min(qf, 255.0f); qf = nstl::max(qf, 0.0f); - return q10n::qz_a1b0()(qf); + return q10n::qz_a1b0_t()(qf); }; const auto dequantize_s32_f32 = [&](gemm_acc_t s, int j) { @@ -149,7 +149,7 @@ rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::lstm_projection_postgemm) { const auto quantize_f32_s8 = [&](float f) { const float qf = f * data_scale + data_shift; - return q10n::qz_a1b0()(qf); + return q10n::qz_a1b0_t()(qf); }; const auto dequantize_s32_f32 = [&](gemm_acc_t s, int j) { diff --git a/src/cpu/rnn/ref_rnn.cpp b/src/cpu/rnn/ref_rnn.cpp index 2df2652daea..21ccbd20d54 100644 --- a/src/cpu/rnn/ref_rnn.cpp +++ b/src/cpu/rnn/ref_rnn.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -320,6 +320,9 @@ _ref_rnn_common_t::pd_t::init_brgemm( bool allow_down_conversion_to_bf16 = is_f32 && is_fpmath_bf16 && is_impl_bf16; + // Initialized rnn_ early to get correct verbose output + rnn_ = zero(); + rnn_.is_brgemm = true; VDISPATCH_RNN( one_of(cell_kind, alg_kind::vanilla_rnn, alg_kind::vanilla_lstm, alg_kind::vanilla_gru, alg_kind::lbr_gru, @@ -352,8 +355,6 @@ _ref_rnn_common_t::pd_t::init_brgemm( VERBOSE_UNSUPPORTED_ATTR); VDISPATCH_RNN(this->with_bias(), VERBOSE_UNSUPPORTED_BIAS_CFG); - rnn_ = zero(); - rnn_.is_brgemm = true; VDISPATCH_RNN(init_conf(rnn_, *this->desc(), *this->attr(), this->src_md(0), this->src_md(1), this->src_md(2), this->weights_md(0), this->weights_md(1), @@ -413,7 +414,7 @@ _ref_rnn_common_t::pd_t::init_brgemm( VDISPATCH_RNN( !(rnn_.is_signed_int8_conf() && !is_superset(isa, avx512_core_amx)), VERBOSE_ISA_DT_MISMATCH); - VDISPATCH_RNN(!(rnn_.is_int8_conf() && !is_superset(isa, avx512_core_vnni)), + VDISPATCH_RNN(!(rnn_.is_int8_conf() && !is_superset(isa, avx2)), VERBOSE_ISA_DT_MISMATCH); VDISPATCH_RNN(!(rnn_.is_f32_conf() && !is_superset(isa, avx2)), VERBOSE_ISA_DT_MISMATCH); @@ -829,20 +830,33 @@ template ::execute_matmul)) { - engine_t *engine = ctx.stream()->engine(); + // Service engine is just a global classic CPU engine that is used + // when it's required to create memory_t objects for classic CPU + // engine regardless of the CPU runtime. For example, SYCL CPU engine + // cannot be used to create such objects. + engine_t *service_engine = get_service_engine(); constexpr auto mem_flag = memory_flags_t::use_runtime_ptr; - memory_t src_mem( - engine, matmul_prim->pd()->src_md(), mem_flag, (void *)(a_)); - memory_t wei_mem( - engine, matmul_prim->pd()->weights_md(), mem_flag, (void *)(b_)); - memory_t dst_mem( - engine, matmul_prim->pd()->dst_md(), mem_flag, (void *)(c_)); + + // a_, b_ and c_ are regular, raw CPU pointers that can only be used with + // memory_t objects created for the classic CPU engine. 
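+ // A condensed sketch of the wrapping pattern used below (same calls as
+ // the code that follows, simplified names):
+ //     std::unique_ptr<memory_t> mem;
+ //     CHECK(safe_ptr_assign(mem,
+ //             new memory_t(get_service_engine(), md,
+ //                     memory_flags_t::use_runtime_ptr, raw_ptr)));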
+ std::unique_ptr src_mem; + CHECK(safe_ptr_assign(src_mem, + new memory_t(service_engine, matmul_prim->pd()->src_md(), mem_flag, + (void *)(a_)))); + std::unique_ptr wei_mem; + CHECK(safe_ptr_assign(wei_mem, + new memory_t(service_engine, matmul_prim->pd()->weights_md(), + mem_flag, (void *)(b_)))); + std::unique_ptr dst_mem; + CHECK(safe_ptr_assign(dst_mem, + new memory_t(service_engine, matmul_prim->pd()->dst_md(), mem_flag, + (void *)(c_)))); exec_args_t matmul_args; // Note Matmul src and wei may not directly map to RNN primitive src and wei - matmul_args[DNNL_ARG_SRC] = {&wei_mem, true}; - matmul_args[DNNL_ARG_WEIGHTS] = {&src_mem, true}; - matmul_args[DNNL_ARG_DST] = {&dst_mem, false}; + matmul_args[DNNL_ARG_SRC] = {wei_mem.get(), true}; + matmul_args[DNNL_ARG_WEIGHTS] = {src_mem.get(), true}; + matmul_args[DNNL_ARG_DST] = {dst_mem.get(), false}; exec_ctx_t matmul_ctx(ctx, std::move(matmul_args)); nested_scratchpad_t ns(ctx, key_nested_multiple, matmul_prim); @@ -1409,7 +1423,7 @@ void copy_init_iter_fwd_template(const rnn_conf_t &rnn, const rnn_pd_t *pd, const auto maybe_q = [&](input_data_t f) { if (quantize) { float qf = f * data_scale + data_shift; - return q10n::qz_a1b0()(qf); + return q10n::qz_a1b0_t()(qf); } else return (src_data_t)f; }; @@ -1575,7 +1589,7 @@ void copy_res_layer_fwd_template(const rnn_conf_t &rnn, const rnn_pd_t *pd, PRAGMA_OMP_SIMD() for (int s = 0; s < rnn.dlc; s++) { float val = (float)ss[s] + dd[s]; - val = q10n::qz_a1b0()(val); + val = q10n::qz_a1b0_t()(val); dd[s] = (dst_layer_dt)((val - 2 * shift) / scale); } } else if (rnn_u8u8_case @@ -2132,11 +2146,13 @@ status_t _ref_rnn_common_t::execute( auto wei_iter_mem = scratchpad.get_memory_storage(key_rnn_bf32_wei_iter_trans); { - memory_t reorder_dst( - engine, &wei_layer_desc, std::move(wei_layer_mem)); + std::unique_ptr reorder_dst; + CHECK(safe_ptr_assign(reorder_dst, + new memory_t(engine, &wei_layer_desc, + std::move(wei_layer_mem)))); exec_args_t reorder_args; reorder_args[DNNL_ARG_SRC] = ctx.args().at(DNNL_ARG_WEIGHTS_LAYER); - reorder_args[DNNL_ARG_DST] = {&reorder_dst, false}; + reorder_args[DNNL_ARG_DST] = {reorder_dst.get(), false}; exec_ctx_t reorder_ctx(ctx, std::move(reorder_args)); nested_scratchpad_t ns( ctx, key_nested_multiple, bf32_wei_layer_reorder_); @@ -2148,11 +2164,13 @@ status_t _ref_rnn_common_t::execute( } { - memory_t reorder_dst( - engine, &wei_iter_desc, std::move(wei_iter_mem)); + std::unique_ptr reorder_dst; + CHECK(safe_ptr_assign(reorder_dst, + new memory_t( + engine, &wei_iter_desc, std::move(wei_iter_mem)))); exec_args_t reorder_args; reorder_args[DNNL_ARG_SRC] = ctx.args().at(DNNL_ARG_WEIGHTS_ITER); - reorder_args[DNNL_ARG_DST] = {&reorder_dst, false}; + reorder_args[DNNL_ARG_DST] = {reorder_dst.get(), false}; exec_ctx_t reorder_ctx(ctx, std::move(reorder_args)); nested_scratchpad_t ns( ctx, key_nested_multiple, bf32_wei_iter_reorder_); diff --git a/src/cpu/rnn/ref_rnn.hpp b/src/cpu/rnn/ref_rnn.hpp index a479867bd26..bb2262fa257 100644 --- a/src/cpu/rnn/ref_rnn.hpp +++ b/src/cpu/rnn/ref_rnn.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * Copyright 2018-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -62,7 +62,7 @@ void gates_reduction(const rnn_utils::rnn_conf_t &rnn, // @todo block k on simd-width to enable vectorization in // parallel_nd path #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP && 
_OPENMP >= 201307 \ - && (!defined(__INTEL_COMPILER) || __INTEL_COMPILER < 1910) + && defined __INTEL_COMPILER && __INTEL_COMPILER < 1910 #pragma omp parallel for simd collapse(2) for (int i = 0; i < rnn.n_gates; i++) for (int k = 0; k < rnn.dhc; k++) @@ -97,16 +97,16 @@ struct _ref_rnn_common_t : public primitive_t { rnn_postgemm_bwd_t>::type; /* These types are defined for each element in the cell execution */ - typedef typename prec_traits::type src_layer_t; - typedef typename prec_traits::type src_iter_t; - typedef typename prec_traits::type dst_layer_t; - typedef typename prec_traits::type dst_iter_t; - typedef typename prec_traits::type weights_t; - typedef typename prec_traits::type gemm_data_t; - typedef typename prec_traits::type gemm_acc_t; - typedef typename prec_traits::type scratch_t; - typedef typename prec_traits::type ht_t; - typedef typename prec_traits::type gates_t; + using src_layer_t = typename prec_traits_t::type; + using src_iter_t = typename prec_traits_t::type; + using dst_layer_t = typename prec_traits_t::type; + using dst_iter_t = typename prec_traits_t::type; + using weights_t = typename prec_traits_t::type; + using gemm_data_t = typename prec_traits_t::type; + using gemm_acc_t = typename prec_traits_t::type; + using scratch_t = typename prec_traits_t::type; + using ht_t = typename prec_traits_t::type; + using gates_t = typename prec_traits_t::type; using class_name = _ref_rnn_common_t; @@ -172,7 +172,7 @@ struct _ref_rnn_common_t : public primitive_t { : primitive_t(apd), rnn_postgemm_(nullptr) {} status_t init(engine_t *engine) override; - virtual ~_ref_rnn_common_t() { delete rnn_postgemm_; } + ~_ref_rnn_common_t() override { delete rnn_postgemm_; } status_t execute(const exec_ctx_t &ctx) const override; diff --git a/src/cpu/rnn/rnn_reorders.hpp b/src/cpu/rnn/rnn_reorders.hpp index 5156350d860..79b1ff21e93 100644 --- a/src/cpu/rnn/rnn_reorders.hpp +++ b/src/cpu/rnn/rnn_reorders.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,7 +61,7 @@ template static inline void quantize_igo(int8_t *scratch_quantized, const memory_desc_wrapper &src_d, const float *src, int mask, float *scales) { - typedef typename prec_traits::type in_data_t; + using in_data_t = typename prec_traits_t::type; // TODO: trivial strides assumes here. // Use proper strides where appropriate @@ -76,7 +76,7 @@ static inline void quantize_igo(int8_t *scratch_quantized, for (int go = 0; go < G * O; go++) { const float s = scales[(mask == 0) ? 0 : go]; scratch_quantized[ldi * G * O + go] - = q10n::qz_b0()( + = q10n::qz_b0_t()( src[ldi * G * O + go], s); } } @@ -87,7 +87,7 @@ template static inline void quantize_goi(int8_t *scratch_quantized, const memory_desc_wrapper &src_d, const float *src, int mask, float *scales) { - typedef typename prec_traits::type in_data_t; + using in_data_t = typename prec_traits_t::type; // TODO: trivial strides assumes here. 
// Use proper strides where appropriate @@ -100,7 +100,7 @@ static inline void quantize_goi(int8_t *scratch_quantized, PRAGMA_OMP_SIMD() for (dim_t i = 0; i < I; i++) { scratch_quantized[ld * I * G * O + i * G * O + go] - = q10n::qz_b0()( + = q10n::qz_b0_t()( src[ld * G * O * I + go * I + i], s); } }); @@ -232,8 +232,8 @@ struct rnn_data_reorder_t : public primitive_t { rnn_data_reorder_t(const pd_t *apd) : primitive_t(apd) {} private: - typedef typename prec_traits::type in_data_t; - typedef typename prec_traits::type out_data_t; + using in_data_t = typename prec_traits_t::type; + using out_data_t = typename prec_traits_t::type; bool is_dense() const { const memory_desc_wrapper &input_d = pd()->src_md(); @@ -271,7 +271,7 @@ struct rnn_data_reorder_t : public primitive_t { PRAGMA_OMP_SIMD() for (int j = 0; j < inner_dim; ++j) { const float in = (float)i_[j] * scale + shift; - o_[j] = q10n::qz_a1b0()(in); + o_[j] = q10n::qz_a1b0_t()(in); } } }); @@ -288,7 +288,8 @@ struct rnn_data_reorder_t : public primitive_t { const size_t nelems = input_d.nelems(); parallel_nd(nelems, [&](size_t i) { const float in = (float)input[input_d.off_l(i)] * scale + shift; - output[output_d.off_l(i)] = q10n::qz_a1b0()(in); + output[output_d.off_l(i)] + = q10n::qz_a1b0_t()(in); }); return status::success; } @@ -428,7 +429,7 @@ struct rnn_weights_reorder_s8_t : public primitive_t { rnn_weights_reorder_s8_t(const pd_t *apd) : primitive_t(apd) {} private: - typedef typename prec_traits::type in_data_t; + using in_data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { // TODO: trivial strides assumed here. @@ -615,8 +616,8 @@ struct rnn_weights_reorder_t : public primitive_t { rnn_weights_reorder_t(const pd_t *apd) : primitive_t(apd) {} private: - typedef typename prec_traits::type in_data_t; - typedef typename prec_traits::type out_data_t; + using in_data_t = typename prec_traits_t::type; + using out_data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { // TODO: trivial strides assumed here. @@ -779,12 +780,7 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t { return unimplemented; // Check the proper memory desc has been passed to u8s8 and s8s8 - // Note: currently rnn_u8s8_compensation and rnn_s8s8_compensation - // have common bit so we have to perform additional checks to - // separate these two cases const bool check_u8s8 = (od.extra().flags & rnn_u8s8_compensation) - && !types::extra_flag_rnn_s8s8_compensation_is_set( - od.extra().flags) && od.extra().compensation_mask == ((id.ndims() == 5) ? 
27 /* 11011 */ : 13 /* 1101 */); @@ -802,7 +798,8 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t { format_tag_t otag, itag; itag = id.matches_one_of_tag(ldigo, ldio); - otag = od.matches_one_of_tag(ldgOI64o4i, ldgOI32o4i, ldOI32o4i); + otag = od.matches_one_of_tag( + ldgOI64o4i, ldgOI32o4i, ldgOI16o4i, ldOI32o4i, ldOI16o4i); if (itag != format_tag::undef && otag != format_tag::undef) { _pd->itag_ = itag; _pd->otag_ = otag; @@ -842,8 +839,8 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t { rnn_brgemm_weights_reorder_s8_t(const pd_t *apd) : primitive_t(apd) {} private: - typedef typename prec_traits::type in_data_t; - typedef typename prec_traits::type out_data_t; + using in_data_t = typename prec_traits_t::type; + using out_data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { using namespace format_tag; @@ -860,15 +857,13 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t { return status::success; } - const auto &blocked_d = dst_d; - const auto &pdims = blocked_d.padded_dims(); - - const int o_block = pd()->otag_ == ldgOI64o4i ? 64 : 32; + const int o_block = dst_d.blocking_desc().inner_blks[0]; static constexpr int i_block = 4; dim_t L, D, I, G, O; init_dims(L, D, I, G, O, src_d); + const auto &pdims = dst_d.padded_dims(); const dim_t pI = pdims[2]; const dim_t pO = (src_d.ndims() == 5) ? pdims[4] : pdims[3]; const dim_t IB = pI / i_block; @@ -886,9 +881,7 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t { .template get(memory_tracking::names:: key_reorder_rnn_weights_reduction); float *comp = reinterpret_cast(dst + compensation_offset); - const bool req_s8s8_comp = (dst_d.extra().flags & rnn_u8s8_compensation) - && !types::extra_flag_rnn_s8s8_compensation_is_set( - dst_d.extra().flags); + const bool req_s8s8_comp = dst_d.extra().flags & rnn_u8s8_compensation; const auto mask_ok = [&](int mask) { return mask == ((src_d.ndims() == 5) ? 27 /* 11011 */ diff --git a/src/cpu/rnn/rnn_utils.cpp b/src/cpu/rnn/rnn_utils.cpp index e4342fe93eb..a4a51608ad1 100644 --- a/src/cpu/rnn/rnn_utils.cpp +++ b/src/cpu/rnn/rnn_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2023 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
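An aside on the rnn_brgemm_weights_reorder_s8_t hunk above: deriving `o_block` from `dst_d.blocking_desc().inner_blks[0]` instead of switching on the format tag is what lets the new ldgOI16o4i/ldOI16o4i layouts pass through with no extra branches. A minimal standalone sketch of the idea (the two-field `blocking_desc_t` below is a simplified stand-in for the library's descriptor, not the real type):

#include <array>
#include <cassert>

// Simplified stand-in: only the inner-block fields matter here.
struct blocking_desc_t {
    int inner_nblks; // number of innermost blocks
    std::array<int, 12> inner_blks; // block sizes, outermost block first
};

// For ldgOI{64,32,16}o4i-style layouts the first inner block runs along
// O, so the O-block size falls out of the descriptor directly and no
// per-tag switch ("otag == ldgOI64o4i ? 64 : 32") is needed.
int o_block_of(const blocking_desc_t &bd) {
    assert(bd.inner_nblks >= 1);
    return bd.inner_blks[0];
}

int main() {
    const blocking_desc_t ldgOI16o4i {2, {16, 4}};
    const blocking_desc_t ldgOI64o4i {2, {64, 4}};
    assert(o_block_of(ldgOI16o4i) == 16 && o_block_of(ldgOI64o4i) == 64);
    return 0;
}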
@@ -76,8 +76,8 @@ bool rnn_utils::is_ldoi(const memory_desc_wrapper &mdw) { bool rnn_utils::is_ldigo_blocked(const memory_desc_wrapper &mdw) { format_tag_t md_format_tag = mdw.matches_one_of_tag(format_tag::ldgOi32o, format_tag::ldgOI32o2i, format_tag::ldgOI32o4i, - format_tag::ldgOI64o2i, format_tag::ldgOI64o4i, - format_tag::ldgOi16o); + format_tag::ldgOI16o4i, format_tag::ldgOI64o2i, + format_tag::ldgOI64o4i, format_tag::ldgOi16o); return md_format_tag != format_tag::undef; } @@ -88,8 +88,8 @@ bool rnn_utils::is_ldgoi_blocked(const memory_desc_wrapper &mdw) { } bool rnn_utils::is_ldio_blocked(const memory_desc_wrapper &mdw) { - format_tag_t md_format_tag = mdw.matches_one_of_tag( - format_tag::ldOi32o, format_tag::ldOI32o4i, format_tag::ldOi16o); + format_tag_t md_format_tag = mdw.matches_one_of_tag(format_tag::ldOi32o, + format_tag::ldOI32o4i, ldOI16o4i, format_tag::ldOi16o); return md_format_tag != format_tag::undef; } @@ -286,14 +286,16 @@ status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn, if (weights_type == weights_type_t::projection) { if (rnn.is_int8_conf()) - tag = format_tag::ldOI32o4i; + tag = utils::map(n_block, format_tag::undef, 32, + format_tag::ldOI32o4i, 16, format_tag::ldOI16o4i); else tag = utils::map(n_block, format_tag::undef, 32, format_tag::ldOi32o, 16, format_tag::ldOi16o); } else if (rnn.is_fwd) { if (rnn.is_int8_conf()) tag = utils::map(n_block, format_tag::undef, 64, - format_tag::ldgOI64o4i, 32, ldgOI32o4i); + format_tag::ldgOI64o4i, 32, ldgOI32o4i, 16, + ldgOI16o4i); else if (rnn.is_xf16_conf()) tag = utils::map(n_block, format_tag::undef, 64, format_tag::ldgOI64o2i, 32, ldgOI32o2i); diff --git a/src/cpu/rnn/rnn_utils.hpp b/src/cpu/rnn/rnn_utils.hpp index 0bd61ba9365..f120e733cd3 100644 --- a/src/cpu/rnn/rnn_utils.hpp +++ b/src/cpu/rnn/rnn_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -316,7 +316,7 @@ struct rnn_conf_t { size_t weights_iter_comp_offset = 0, weights_iter_pack_size = 0; size_t weights_projection_comp_offset = 0, weights_projection_pack_size = 0; - bool copy_bias = 0; + bool copy_bias = false; int weights_layer_ld = 0, weights_layer_nld = 0; int diff_weights_layer_ld = 0, diff_weights_layer_nld = 0; int weights_iter_ld = 0, weights_iter_nld = 0; @@ -347,9 +347,10 @@ struct rnn_conf_t { int dst_iter_c_ld_ = 0, dst_iter_c_nld_ = 0; int weights_iter_compensation_size = 0, weights_layer_compensation_size = 0; - bool is_fwd = 0, is_training = 0, is_lbr = 0, is_lstm_peephole = 0, - is_lstm_projection = 0, is_augru = 0, is_orig_gru = 0; - bool use_workspace = 0; + bool is_fwd = false, is_training = false, is_lbr = false, + is_lstm_peephole = false, is_lstm_projection = false, is_augru = false, + is_orig_gru = false; + bool use_workspace = false; // Size of workspace for each tensor in bytes // Notes: @@ -630,7 +631,7 @@ struct rnn_conf_t { int dhc_block_peephole, dhc_tail_peephole, dhc_blocks_peephole; bool brgemm_fwd_iter_layer_fuse_possible = false; - dim_t nthr; + int nthr; #if DNNL_X64 x64::cpu_isa_t brgemm_isa; #endif @@ -683,7 +684,7 @@ bool init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, rnn.dst_iter_c_dt = dst_iter_c_d.is_zero() ? 
data_type::f32 : dst_iter_c_d.data_type(); - rnn.cell_dt = data_traits::data_type; + rnn.cell_dt = data_traits_t::data_type; switch (rd.direction) { case dnnl_unidirectional_left2right: rnn.exec_dir = l2r; break; case dnnl_unidirectional_right2left: rnn.exec_dir = r2l; break; diff --git a/src/cpu/rv64/CMakeLists.txt b/src/cpu/rv64/CMakeLists.txt index 6be072fd790..fe5f9bb37e0 100644 --- a/src/cpu/rv64/CMakeLists.txt +++ b/src/cpu/rv64/CMakeLists.txt @@ -32,7 +32,7 @@ if(NOT DNNL_RISCV_USE_RVV_INTRINSICS) endif() endif() -if(NOT DNNL_CPU_RUNTIME STREQUAL "SEQ") - message(FATAL_ERROR "Only sequential runtime is now supported for a RISC-V CPU") +if(NOT (DNNL_CPU_RUNTIME STREQUAL "SEQ" OR DNNL_CPU_RUNTIME STREQUAL "OMP")) + message(FATAL_ERROR "Only sequential and OpenMP runtimes are supported for a RISC-V CPU") endif() diff --git a/src/cpu/rv64/rvv_nchw_pooling.cpp b/src/cpu/rv64/rvv_nchw_pooling.cpp index 5ded8584251..e4a1b566c7f 100644 --- a/src/cpu/rv64/rvv_nchw_pooling.cpp +++ b/src/cpu/rv64/rvv_nchw_pooling.cpp @@ -57,9 +57,9 @@ void MaxPooling(const float *src, float *dst, const dim_t batch, int ow_offset = ow * strideW - padLeft; size_t size = std::min(ow_offset + kerW, inW) - std::max(ow_offset, 0); - size_t cycleLength = vsetvl_e32m8(size); - vfloat32m8_t vmax - = vle32_v_f32m8(&arr_flt_min[0], cycleLength); + size_t cycleLength = __riscv_vsetvl_e32m8(size); + vfloat32m8_t vmax = __riscv_vle32_v_f32m8( + &arr_flt_min[0], cycleLength); for (int id = std::max(od_offset, 0); id < std::min(od_offset + kerD, inD); id++) @@ -73,34 +73,35 @@ void MaxPooling(const float *src, float *dst, const dim_t batch, size_t iw = 0; for (; iw < size - cycleLength; iw += cycleLength) { - vfloat32m8_t vsrc = vle32_v_f32m8( + vfloat32m8_t vsrc = __riscv_vle32_v_f32m8( &local_src[local_src_offset + iw], cycleLength); - vmax = vfmax_vv_f32m8( + vmax = __riscv_vfmax_vv_f32m8( vsrc, vmax, cycleLength); } - size_t tailLength = vsetvl_e32m8(size - iw); + size_t tailLength + = __riscv_vsetvl_e32m8(size - iw); { - vfloat32m8_t vsrc = vle32_v_f32m8( + vfloat32m8_t vsrc = __riscv_vle32_v_f32m8( &local_src[local_src_offset + iw], tailLength); - vmax = vfmax_vv_f32m8( + vmax = __riscv_vfmax_vv_f32m8( vsrc, vmax, tailLength); } } vfloat32m1_t min_scalar; float min = -__FLT_MAX__; - min_scalar = vle32_v_f32m1(&min, 1); + min_scalar = __riscv_vle32_v_f32m1(&min, 1); - cycleLength = vsetvl_e32m8(size); + cycleLength = __riscv_vsetvl_e32m8(size); vfloat32m1_t vred_res; - vred_res = vfredmax_vs_f32m8_f32m1( - vred_res, vmax, min_scalar, cycleLength); + vred_res = __riscv_vfredmax_vs_f32m8_f32m1( + vmax, min_scalar, cycleLength); float red_res; - vse32_v_f32m1(&red_res, vred_res, 1); + __riscv_vse32_v_f32m1(&red_res, vred_res, 1); dst[dst_offset] = red_res; } } diff --git a/src/cpu/rv64/rvv_nchw_pooling.hpp b/src/cpu/rv64/rvv_nchw_pooling.hpp index 4fc0d134b47..86df99c6ec0 100644 --- a/src/cpu/rv64/rvv_nchw_pooling.hpp +++ b/src/cpu/rv64/rvv_nchw_pooling.hpp @@ -1,5 +1,5 @@ /****************************************************************************** -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * Copyright 2023 KNS Group LLC (YADRO) * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -66,7 +66,7 @@ struct riscv_nchw_pooling_fwd_t : public primitive_t { riscv_nchw_pooling_fwd_t(const pd_t *apd); - using data_t = typename prec_traits::type; + using data_t = typename prec_traits_t::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_forward(ctx); diff --git a/src/cpu/scale_utils.cpp b/src/cpu/scale_utils.cpp index c6d92a33e2f..ad4e502c473
100644 --- a/src/cpu/scale_utils.cpp +++ b/src/cpu/scale_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022-2024 Intel Corporation +* Copyright 2022-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,21 +32,18 @@ constexpr size_t scales_simd_w = 16; } void book_precomputed_scales(memory_tracking::registrar_t &scratchpad, - const arg_scales_t &attr_scales, size_t wei_scale_count, + const scales_t &attr_scales, size_t wei_scale_count, bool force_scales_book) { using namespace dnnl::impl::memory_tracking::names; - const bool with_src_scales - = !attr_scales.get(DNNL_ARG_SRC).has_default_values(); + const bool with_src_scales = !attr_scales.has_default_values(DNNL_ARG_SRC); const bool with_wei_scales - = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values(); - const auto wei_scales_dt = attr_scales.get(DNNL_ARG_WEIGHTS).data_type_; - const auto wei_scale_groups_ndims - = attr_scales.get(DNNL_ARG_WEIGHTS).ndims_; + = !attr_scales.has_default_values(DNNL_ARG_WEIGHTS); + if ((with_src_scales && with_wei_scales) || force_scales_book - || (wei_scales_dt != data_type::f32 && with_wei_scales) - || (wei_scale_groups_ndims > 0 && with_wei_scales)) { - const int wei_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; + || !attr_scales.has_default_data_type(DNNL_ARG_WEIGHTS) + || !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_groups()) { + const int wei_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); const size_t precomputed_scales_size = wei_mask == 0 ? scales_simd_w : nstl::max( @@ -60,27 +57,26 @@ void book_precomputed_scales(memory_tracking::registrar_t &scratchpad, bool req_copy_scales( const primitive_attr_t *attr, const float scale_adjust_factor) { const auto &attr_scales = attr->scales_; - const bool with_src_scales - = !attr_scales.get(DNNL_ARG_SRC).has_default_values(); + const bool with_src_scales = !attr_scales.has_default_values(DNNL_ARG_SRC); const bool with_wei_scales - = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values(); - const auto wei_scales_dt = attr_scales.get(DNNL_ARG_WEIGHTS).data_type_; - const auto wei_scale_groups_ndims - = attr_scales.get(DNNL_ARG_WEIGHTS).ndims_; + = !attr_scales.has_default_values(DNNL_ARG_WEIGHTS); return (with_src_scales && with_wei_scales) || scale_adjust_factor != 1.0f - || (wei_scales_dt != data_type::f32 && with_wei_scales) - || (wei_scale_groups_ndims > 0 && with_wei_scales); + || !attr_scales.has_default_data_type(DNNL_ARG_WEIGHTS) + || !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_groups(); } const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, const float *src_scales, const float *wei_scales, dim_t oc, const primitive_attr_t *attr, float scale_adjust_factor) { - // Note: per-ic-channel is no supported in default - const int wei_scale_mask = attr->scales_.get(DNNL_ARG_WEIGHTS).mask_; + // Note: per-ic-channel is not supported by default. + const int wei_scale_mask = attr->scales_.get_mask(DNNL_ARG_WEIGHTS); return precompute_scales(scratchpad, src_scales, wei_scales, 1, oc, false, - wei_scale_mask != 0, attr, scale_adjust_factor, false); + wei_scale_mask > 0, attr, scale_adjust_factor, false); } +// Note: `wei_scale_per_ic` and `wei_scale_per_oc` could be derived inside +// this function if all primitives shared the same definition of the `per_ic` +// and `per_oc` masks; matmul, notably, defines them differently from the rest.
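To make the mask caveat in the note above concrete before the function body follows: a scales mask sets one bit per logical dimension that carries its own scale, so the numeric value of a "per-oc" mask depends on where the output-channel dimension sits in a given primitive's weights tensor. A toy sketch under the matmul convention documented later in this patch (`scale_index` is our illustrative helper, not a library function):

#include <cassert>
#include <cstdint>

// mask == 0: one common scale. Bit d set: one scale per index along
// dimension d. For a 2D matmul the n dimension is last, so per-n
// scales use mask = 1 << (ndims - 1) = 1 << 1.
int64_t scale_index(int mask, int64_t m, int64_t n, int64_t N) {
    if (mask == 0) return 0; // common scale
    if (mask == (1 << 1)) return n; // per-n for ndims == 2
    if (mask == ((1 << 0) | (1 << 1))) return m * N + n; // per-element
    assert(!"unsupported mask");
    return -1;
}

int main() {
    assert(scale_index(0, 3, 5, 8) == 0);
    assert(scale_index(1 << 1, 3, 5, 8) == 5);
    // For a 3D (batched) matmul the same "per-n" intent would be
    // mask = 1 << 2 instead, which is exactly why per_ic/per_oc cannot
    // be derived uniformly across primitives.
    return 0;
}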
const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, const float *src_scales, const float *wei_scales, dim_t IC, dim_t OC, const bool wei_scale_per_ic, const bool wei_scale_per_oc, @@ -89,18 +85,16 @@ const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, using namespace dnnl::impl::memory_tracking::names; const auto &attr_scales = attr->scales_; - const bool with_src_scales - = !attr_scales.get(DNNL_ARG_SRC).has_default_values(); + const bool with_src_scales = !attr_scales.has_default_values(DNNL_ARG_SRC); const auto wei_scale_count = (wei_scale_per_ic ? IC : 1) * (wei_scale_per_oc ? OC : 1); const float *scales = nullptr; if (req_copy_scales(attr, scale_adjust_factor)) { - const int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; size_t size = 0; auto loc_scales = scratchpad.template get(key_precomputed_scales, &size); - if (wei_scale_mask == 0 || wei_scale_count == 1) { + if (wei_scale_count == 1) { const size_t count = nstl::min(size / sizeof(float), scales_simd_w); utils::array_set(loc_scales, src_scales[0] * wei_scales[0] * scale_adjust_factor, count); @@ -108,12 +102,9 @@ const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, const dim_t count = nstl::min( static_cast(size / sizeof(float)), wei_scale_count); const auto wei_scale_dt - = attr_scales.get(DNNL_ARG_WEIGHTS).data_type_; - const auto wei_scale_groups_ndims - = attr_scales.get(DNNL_ARG_WEIGHTS).ndims_; - const auto wei_scale_groups_ic = wei_scale_groups_ndims > 0 - ? attr_scales.get(DNNL_ARG_WEIGHTS).group_dims_[0] - : 1; + = attr_scales.get_data_type(DNNL_ARG_WEIGHTS); + const auto wei_scale_groups_ic + = attr_scales.get_group(DNNL_ARG_WEIGHTS, 0); // Note: per-ic-channel scales is only supported for // weights decompression for now if ((wei_scale_per_ic && wei_scale_groups_ic > 1) diff --git a/src/cpu/scale_utils.hpp b/src/cpu/scale_utils.hpp index 7c1ce535889..48164b776d4 100644 --- a/src/cpu/scale_utils.hpp +++ b/src/cpu/scale_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022-2024 Intel Corporation +* Copyright 2022-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ namespace impl { namespace cpu { void book_precomputed_scales(memory_tracking::registrar_t &scratchpad, - const arg_scales_t &attr_scales, size_t wei_scales_count, + const scales_t &attr_scales, size_t wei_scales_count, bool force_scales_book = false); bool req_copy_scales( diff --git a/src/cpu/simple_concat.cpp b/src/cpu/simple_concat.cpp index 234f6bf6d7f..f4a49d264f7 100644 --- a/src/cpu/simple_concat.cpp +++ b/src/cpu/simple_concat.cpp @@ -74,6 +74,16 @@ status_t simple_concat_t::execute(const exec_ctx_t &ctx) const { // Applies when concat axis is the outermost dimension, e.g. 
concat_axis = 0 // or concat_axis = 1, and dims[0] = 1; if (!has_outer_loop) { + // @todo CPU_PLUGIN: + // the following implementation was used to fix some performance issues. + // Now that upstream oneDNN has redesigned this piece, it no longer + // seems applicable. + // for (int a = 0; a < num_arrs; ++a) { + // const data_t *i = &iptrs[a][0]; + // data_t *o = &optrs[a][0]; + // parallel_nd_legacy(nelems_to_copy[a], [&](dim_t e) { o[e] = i[e]; }); + // } + int nthr = dnnl_get_max_threads(); parallel(nthr, [&](int ithr, int nthr) { for (int a = 0; a < num_arrs; ++a) { @@ -104,7 +114,7 @@ status_t simple_concat_t::execute(const exec_ctx_t &ctx) const { const auto L1_size = platform::get_per_core_cache_size(1); UNUSED(L1_size); // for Windows - parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3], + parallel_nd_legacy(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3], phys_dims[4], num_arrs, [&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, dim_t a) { // check if zero memory diff --git a/src/cpu/simple_concat.hpp b/src/cpu/simple_concat.hpp index ff0c5e22deb..ece8014e483 100644 --- a/src/cpu/simple_concat.hpp +++ b/src/cpu/simple_concat.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2024 Intel Corporation +* Copyright 2017-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -168,7 +168,7 @@ struct simple_concat_t : public primitive_t { status_t execute(const exec_ctx_t &ctx) const override; - typedef typename prec_traits::type data_t; + using data_t = typename prec_traits_t::type; private: const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } diff --git a/src/cpu/simple_layer_normalization.cpp b/src/cpu/simple_layer_normalization.cpp index e80f8cbbf48..493115fab57 100644 --- a/src/cpu/simple_layer_normalization.cpp +++ b/src/cpu/simple_layer_normalization.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
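For context on the no-outer-loop concat path above: when the concat axis is outermost, every source occupies one contiguous span of the destination, so the copies can simply be chunked across threads. A self-contained sketch of that partitioning, with plain std::thread standing in for the library's parallel() runtime:

#include <algorithm>
#include <cstring>
#include <thread>
#include <vector>

// Each source array a occupies one contiguous span of dst (concat axis
// is outermost), so thread t copies the t-th chunk of every span.
void concat_outer(const std::vector<const float *> &srcs,
        const std::vector<size_t> &nelems, float *dst, int nthr) {
    std::vector<std::thread> pool;
    for (int t = 0; t < nthr; ++t)
        pool.emplace_back([&, t]() {
            size_t off = 0;
            for (size_t a = 0; a < srcs.size(); ++a) {
                const size_t chunk = (nelems[a] + nthr - 1) / nthr;
                const size_t beg
                        = std::min(static_cast<size_t>(t) * chunk, nelems[a]);
                const size_t end = std::min(beg + chunk, nelems[a]);
                if (end > beg)
                    std::memcpy(dst + off + beg, srcs[a] + beg,
                            (end - beg) * sizeof(float));
                off += nelems[a];
            }
        });
    for (auto &th : pool)
        th.join();
}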
@@ -52,8 +52,8 @@ status_t simple_layer_normalization_fwd_t::pd_t::init(engine_t *engine) { VDISPATCH_LNORM(stat_md()->data_type == f32, VERBOSE_UNSUPPORTED_DT); VDISPATCH_LNORM(check_scale_shift_data_type(), VERBOSE_UNSUPPORTED_FEATURE, "unsupported scale or shift data type"); - VDISPATCH_LNORM(attr()->has_default_values(skip_mask_t::scales_runtime - | skip_mask_t::post_ops), + VDISPATCH_LNORM(attr()->has_default_values( + skip_mask_t::scales | skip_mask_t::post_ops), VERBOSE_UNSUPPORTED_ATTR); VDISPATCH_LNORM(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); VDISPATCH_LNORM(post_ops_ok(), VERBOSE_UNSUPPORTED_POSTOP); diff --git a/src/cpu/simple_layer_normalization.hpp b/src/cpu/simple_layer_normalization.hpp index 95b01ba1788..aacf6b127e6 100644 --- a/src/cpu/simple_layer_normalization.hpp +++ b/src/cpu/simple_layer_normalization.hpp @@ -105,24 +105,29 @@ struct simple_layer_normalization_fwd_t : public primitive_t { auto scratchpad = ctx.get_scratchpad_grantor(); auto mean_mem = scratchpad.get_memory_storage(key_lnorm_tmp_mean); auto variance_mem = scratchpad.get_memory_storage(key_lnorm_tmp_var); - memory_t mean(engine, &(pd()->reordered_stat_md_), std::move(mean_mem)); - memory_t variance( - engine, &(pd()->reordered_stat_md_), std::move(variance_mem)); + std::unique_ptr mean; + CHECK(safe_ptr_assign(mean, + new memory_t(engine, &(pd()->reordered_stat_md_), + std::move(mean_mem)))); + std::unique_ptr variance; + CHECK(safe_ptr_assign(variance, + new memory_t(engine, &(pd()->reordered_stat_md_), + std::move(variance_mem)))); // reorder input stats if (pd()->stats_are_src() && reorder_) { - reorder_stat( - ctx, engine, ctx.args().at(DNNL_ARG_MEAN), {&mean, false}); + reorder_stat(ctx, engine, ctx.args().at(DNNL_ARG_MEAN), + {mean.get(), false}); reorder_stat(ctx, engine, ctx.args().at(DNNL_ARG_VARIANCE), - {&variance, false}); + {variance.get(), false}); } status_t status = execute_forward(ctx); if (status != status::success) return status; // reorder output stats if (!pd()->stats_are_src() && reorder_) { - reorder_stat( - ctx, engine, {&mean, true}, ctx.args().at(DNNL_ARG_MEAN)); - reorder_stat(ctx, engine, {&variance, true}, + reorder_stat(ctx, engine, {mean.get(), true}, + ctx.args().at(DNNL_ARG_MEAN)); + reorder_stat(ctx, engine, {variance.get(), true}, ctx.args().at(DNNL_ARG_VARIANCE)); } @@ -208,14 +213,18 @@ struct simple_layer_normalization_bwd_t : public primitive_t { auto mean_mem = scratchpad.get_memory_storage(key_lnorm_tmp_mean); auto variance_mem = scratchpad.get_memory_storage(key_lnorm_tmp_var); - memory_t mean( - engine, &(pd()->reordered_stat_md_), std::move(mean_mem)); - memory_t variance(engine, &(pd()->reordered_stat_md_), - std::move(variance_mem)); - reorder_stat( - ctx, engine, ctx.args().at(DNNL_ARG_MEAN), {&mean, false}); + std::unique_ptr mean; + CHECK(safe_ptr_assign(mean, + new memory_t(engine, &(pd()->reordered_stat_md_), + std::move(mean_mem)))); + std::unique_ptr variance; + CHECK(safe_ptr_assign(variance, + new memory_t(engine, &(pd()->reordered_stat_md_), + std::move(variance_mem)))); + reorder_stat(ctx, engine, ctx.args().at(DNNL_ARG_MEAN), + {mean.get(), false}); reorder_stat(ctx, engine, ctx.args().at(DNNL_ARG_VARIANCE), - {&variance, false}); + {variance.get(), false}); } return execute_backward(ctx); diff --git a/src/cpu/simple_q10n.hpp b/src/cpu/simple_q10n.hpp index 9b31cb120c4..10f2ca62a06 100644 --- a/src/cpu/simple_q10n.hpp +++ b/src/cpu/simple_q10n.hpp @@ -1,5 +1,5 @@ 
/******************************************************************************* -* Copyright 2017-2024 Intel Corporation +* Copyright 2017-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ saturate(const acc_t &x) { acc_t v = x; acc_t lbound = (acc_t)nstl::numeric_limits::lowest(); // Pick up a modified version of max value when do f32 -> s32. - acc_t ubound = types::max_value(data_traits::data_type); + acc_t ubound = types::max_value(data_traits_t::data_type); if (v < lbound) v = lbound; if (v > ubound) v = ubound; return v; @@ -82,33 +82,33 @@ inline out_t saturate_and_round(acc_t f) { /* Quantization with alpha == 1 and beta == 0 */ template -struct qz_a1b0 { +struct qz_a1b0_t { out_t operator()(in_t in) { return saturate_and_round((float)in); } }; template -struct qz_a1b0::value && !is_subset::value>::type> { out_t operator()(in_t in) { return saturate(in); } }; template -struct qz_a1b0::value>::type> { out_t operator()(in_t in) { return (out_t)in; } }; /* Quantization with alpha == 1 */ template -struct qz_a1 { +struct qz_a1_t { out_t operator()(in_t in, out_t out, float beta) { return saturate_and_round((float)in + beta * out); } }; template -struct qz_a1 { +struct qz_a1_t { float operator()(in_t in, float out, float beta) { return (float)in + beta * out; } @@ -116,55 +116,55 @@ struct qz_a1 { /* Quantization with beta == 0 */ template -struct qz_b0 { +struct qz_b0_t { out_t operator()(in_t in, float alpha) { return saturate_and_round(alpha * in); } }; template -struct qz_b0 { +struct qz_b0_t { float operator()(in_t in, float alpha) { return alpha * in; } }; /* Quantization */ template -struct qz { +struct qz_t { out_t operator()(in_t in, out_t out, float alpha, float beta) { return saturate_and_round(alpha * in + (beta ? beta * out : 0)); } }; template -struct qz { +struct qz_t { float operator()(in_t in, float out, float alpha, float beta) { return alpha * in + (beta ? beta * out : 0); } }; template <> -struct qz { +struct qz_t { float operator()(bfloat16_t in, bfloat16_t out, float alpha, float beta) { return (bfloat16_t)(alpha * (float)in + (beta ? beta * (float)out : 0)); } }; template <> -struct qz { +struct qz_t { float operator()(float in, bfloat16_t out, float alpha, float beta) { return (bfloat16_t)(alpha * in + (beta ? beta * out : 0)); } }; template <> -struct qz { +struct qz_t { float operator()(float16_t in, float16_t out, float alpha, float beta) { return (float16_t)(alpha * (float)in + (beta ? beta * (float)out : 0)); } }; template <> -struct qz { +struct qz_t { float operator()(float in, float16_t out, float alpha, float beta) { return (float16_t)(alpha * in + (beta ? beta * out : 0)); } diff --git a/src/cpu/simple_resampling.cpp b/src/cpu/simple_resampling.cpp index 0babdbe8265..7838c01ef9b 100644 --- a/src/cpu/simple_resampling.cpp +++ b/src/cpu/simple_resampling.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
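For readers following the q10n renames above: `qz_a1b0_t` is the alpha == 1, beta == 0 case, i.e. round to nearest, then saturate to the output range. A standalone sketch of that path, assuming the default round-to-nearest-even FP environment (the library additionally special-cases the f32 -> s32 upper bound, which the plain cast below would get wrong for 32-bit outputs):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Quantize with alpha == 1, beta == 0: round to nearest (even), then
// clamp to the output type's representable range.
template <typename out_t>
out_t qz_a1b0_sketch(float in) {
    float v = std::nearbyintf(in);
    v = std::max(v, (float)std::numeric_limits<out_t>::lowest());
    v = std::min(v, (float)std::numeric_limits<out_t>::max());
    return (out_t)v;
}

int main() {
    const bool ok = qz_a1b0_sketch<int8_t>(200.7f) == 127 // saturates up
            && qz_a1b0_sketch<uint8_t>(-3.2f) == 0; // saturates down
    return ok ? 0 : 1;
}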
@@ -40,8 +40,8 @@ template struct simple_resampling_kernel_t : public simple_resampling_base_t { simple_resampling_kernel_t(const resampling_pd_t *pd); - using src_data_t = typename prec_traits::type; - using dst_data_t = typename prec_traits::type; + using src_data_t = typename prec_traits_t::type; + using dst_data_t = typename prec_traits_t::type; status_t init() override; status_t execute(const exec_ctx_t &ctx) const override; @@ -179,25 +179,19 @@ void simple_resampling_kernel_t::fill_coeffs() { if (pd_->is_fwd()) { linear_coeffs_.reserve(pd_->OD() + pd_->OH() + pd_->OW()); for (dim_t od = 0; od < pd_->OD(); od++) - linear_coeffs_.emplace_back( - linear_coeffs_t(od, pd_->OD(), pd_->ID())); + linear_coeffs_.emplace_back(od, pd_->OD(), pd_->ID()); for (dim_t oh = 0; oh < pd_->OH(); oh++) - linear_coeffs_.emplace_back( - linear_coeffs_t(oh, pd_->OH(), pd_->IH())); + linear_coeffs_.emplace_back(oh, pd_->OH(), pd_->IH()); for (dim_t ow = 0; ow < pd_->OW(); ow++) - linear_coeffs_.emplace_back( - linear_coeffs_t(ow, pd_->OW(), pd_->IW())); + linear_coeffs_.emplace_back(ow, pd_->OW(), pd_->IW()); } else { bwd_linear_coeffs_.reserve(pd_->ID() + pd_->IH() + pd_->IW()); for (dim_t id = 0; id < pd_->ID(); id++) - bwd_linear_coeffs_.emplace_back( - bwd_linear_coeffs_t(id, pd_->OD(), pd_->ID())); + bwd_linear_coeffs_.emplace_back(id, pd_->OD(), pd_->ID()); for (dim_t ih = 0; ih < pd_->IH(); ih++) - bwd_linear_coeffs_.emplace_back( - bwd_linear_coeffs_t(ih, pd_->OH(), pd_->IH())); + bwd_linear_coeffs_.emplace_back(ih, pd_->OH(), pd_->IH()); for (dim_t iw = 0; iw < pd_->IW(); iw++) - bwd_linear_coeffs_.emplace_back( - bwd_linear_coeffs_t(iw, pd_->OW(), pd_->IW())); + bwd_linear_coeffs_.emplace_back(iw, pd_->OW(), pd_->IW()); } } diff --git a/src/cpu/simple_resampling.hpp b/src/cpu/simple_resampling.hpp index f632baa27a4..a9ccef95af2 100644 --- a/src/cpu/simple_resampling.hpp +++ b/src/cpu/simple_resampling.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -103,7 +103,8 @@ struct simple_resampling_fwd_t : public primitive_t { simple_resampling_fwd_t(const pd_t *apd); status_t init(engine_t *engine) override; - ~simple_resampling_fwd_t() = default; + + ~simple_resampling_fwd_t() override = default; status_t execute(const exec_ctx_t &ctx) const override; @@ -149,7 +150,8 @@ struct simple_resampling_bwd_t : public primitive_t { simple_resampling_bwd_t(const pd_t *apd); status_t init(engine_t *engine) override; - ~simple_resampling_bwd_t() = default; + + ~simple_resampling_bwd_t() override = default; status_t execute(const exec_ctx_t &ctx) const override; diff --git a/src/cpu/simple_sum.hpp b/src/cpu/simple_sum.hpp index e8b72a21910..db2d4b1ace4 100644 --- a/src/cpu/simple_sum.hpp +++ b/src/cpu/simple_sum.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2024 Intel Corporation +* Copyright 2017-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
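The resampling change above is a classic emplace_back cleanup: passing constructor arguments straight through builds the element in place, whereas wrapping them in an explicit temporary forces an extra move or copy. A minimal illustration with a stand-in coefficient type (`coeffs_t` is ours; the library's linear_coeffs_t carries more state):

#include <vector>

struct coeffs_t {
    long out_pos, out_dim, in_dim;
    coeffs_t(long op, long od, long id)
        : out_pos(op), out_dim(od), in_dim(id) {}
};

int main() {
    std::vector<coeffs_t> v;
    v.reserve(2);
    v.emplace_back(coeffs_t(0, 4, 2)); // builds a temporary, then moves
    v.emplace_back(1, 4, 2); // constructs directly inside the vector
    return 0;
}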
@@ -124,9 +124,9 @@ struct simple_sum_t : public primitive_t { status_t execute(const exec_ctx_t &ctx) const override; enum { max_num_arrs = 16 }; - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type dst_data_t; - typedef typename prec_traits::type acc_data_t; + using src_data_t = typename prec_traits_t::type; + using dst_data_t = typename prec_traits_t::type; + using acc_data_t = typename prec_traits_t::type; private: const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } diff --git a/src/cpu/sycl/engine.hpp b/src/cpu/sycl/engine.hpp index 9a2f8a67b4d..0563ab53dcd 100644 --- a/src/cpu/sycl/engine.hpp +++ b/src/cpu/sycl/engine.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,6 +44,9 @@ class engine_t : public cpu::cpu_engine_t { status_t create_memory_storage(memory_storage_t **storage, unsigned flags, size_t size, void *handle) override { + assert(runtime_kind() == runtime_kind::sycl); + if (runtime_kind() != runtime_kind::sycl) return status::runtime_error; + return impl()->create_memory_storage( storage, this, flags, size, handle); } @@ -53,15 +56,12 @@ class engine_t : public cpu::cpu_engine_t { return cpu::sycl::stream_t::create_stream(stream, this, stream_impl); } - const ::sycl::device &device() const { return impl()->device(); } - const ::sycl::context &context() const { return impl()->context(); } - - xpu::sycl::backend_t backend() const { return impl()->backend(); } - bool mayiuse_system_memory_allocators() const override { return impl()->mayiuse_system_memory_allocators(); } + DECLARE_COMMON_SYCL_ENGINE_FUNCTIONS(); + protected: const xpu::sycl::engine_impl_t *impl() const { return (const xpu::sycl::engine_impl_t *)impl::engine_t::impl(); diff --git a/src/cpu/sycl/stream_cpu_thunk.cpp b/src/cpu/sycl/stream_cpu_thunk.cpp index c6fc0723758..fe6d5276936 100644 --- a/src/cpu/sycl/stream_cpu_thunk.cpp +++ b/src/cpu/sycl/stream_cpu_thunk.cpp @@ -41,6 +41,9 @@ void dnnl_impl_sycl_cpu_thunk(const thunk_params_t *params) { prim_iface->execute(submit_ctx->exec_ctx); + for (auto &m : submit_ctx->exec_ctx.args()) + m.second.mem->release(); + const_cast(prim_iface)->release(); delete submit_ctx; diff --git a/src/cpu/sycl/stream_submit_cpu_primitive.cpp b/src/cpu/sycl/stream_submit_cpu_primitive.cpp index 4fb94689fce..f2df6ca2d92 100644 --- a/src/cpu/sycl/stream_submit_cpu_primitive.cpp +++ b/src/cpu/sycl/stream_submit_cpu_primitive.cpp @@ -109,6 +109,7 @@ void submit_cpu_primitive(stream_t *stream, const primitive_iface_t *prim_iface, std::vector sycl_mem_storages; for (auto &a : exec_ctx.args()) { + a.second.mem->retain(); if (a.second.mem->engine()->runtime_kind() == runtime_kind::sycl) { auto *mem_storage = a.second.mem->memory_storage(); if (!mem_storage->is_null()) { diff --git a/src/cpu/ukernel/attr_params.cpp b/src/cpu/ukernel/attr_params.cpp new file mode 100644 index 00000000000..f56da41e789 --- /dev/null +++ b/src/cpu/ukernel/attr_params.cpp @@ -0,0 +1,83 @@ +/******************************************************************************* +* Copyright 2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dnnl/dnnl_ukernel.h" + +#include "cpu/ukernel/c_types_map.hpp" + +#if DNNL_X64 +#include "cpu/x64/ukernel/attr_params.hpp" +#endif + +#ifdef DNNL_EXPERIMENTAL_UKERNEL + +using namespace dnnl::impl; +using namespace dnnl::impl::cpu; +using namespace dnnl::impl::cpu::ukernel; + +status_t dnnl_ukernel_attr_params_create(attr_params_t **attr_params) { +#if DNNL_X64 + return x64::ukernel::dnnl_ukernel_attr_params_create(attr_params); +#endif + return status::unimplemented; +} + +status_t dnnl_ukernel_attr_params_set_post_ops_args( + attr_params_t *attr_params, const void **post_ops_args) { +#if DNNL_X64 + return x64::ukernel::dnnl_ukernel_attr_params_set_post_ops_args( + attr_params, post_ops_args); +#endif + return status::unimplemented; +} + +status_t dnnl_ukernel_attr_params_set_A_scales( + attr_params_t *attr_params, const void *a_scales) { +#if DNNL_X64 + return x64::ukernel::dnnl_ukernel_attr_params_set_A_scales( + attr_params, a_scales); +#endif + return status::unimplemented; +} + +status_t dnnl_ukernel_attr_params_set_B_scales( + attr_params_t *attr_params, const void *b_scales) { +#if DNNL_X64 + return x64::ukernel::dnnl_ukernel_attr_params_set_B_scales( + attr_params, b_scales); +#endif + return status::unimplemented; +} + +status_t dnnl_ukernel_attr_params_set_D_scales( + attr_params_t *attr_params, const void *d_scales) { +#if DNNL_X64 + return x64::ukernel::dnnl_ukernel_attr_params_set_D_scales( + attr_params, d_scales); +#endif + return status::unimplemented; +} + +status_t dnnl_ukernel_attr_params_destroy(attr_params_t *attr_params) { +#if DNNL_X64 + return x64::ukernel::dnnl_ukernel_attr_params_destroy(attr_params); +#endif + return status::unimplemented; +} + +#endif + +//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s diff --git a/src/cpu/ukernel/brgemm.cpp b/src/cpu/ukernel/brgemm.cpp new file mode 100644 index 00000000000..bb3c27de1f9 --- /dev/null +++ b/src/cpu/ukernel/brgemm.cpp @@ -0,0 +1,157 @@ +/******************************************************************************* +* Copyright 2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "oneapi/dnnl/dnnl_ukernel.h" + +#include "cpu/ukernel/c_types_map.hpp" + +#if DNNL_X64 +#include "cpu/x64/ukernel/brgemm.hpp" +#endif + +#ifdef DNNL_EXPERIMENTAL_UKERNEL + +using namespace dnnl::impl; +using namespace dnnl::impl::cpu; +using namespace dnnl::impl::cpu::ukernel; + +status_t dnnl_brgemm_create(brgemm_t **brgemm, dim_t M, dim_t N, dim_t K, + dim_t batch_size, dim_t lda, dim_t ldb, dim_t ldc, data_type_t a_dt, + data_type_t b_dt, data_type_t c_dt) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_create( + brgemm, M, N, K, batch_size, lda, ldb, ldc, a_dt, b_dt, c_dt); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_set_add_C(brgemm_t *brgemm, int add_C) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_set_add_C(brgemm, add_C); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_set_post_ops(brgemm_t *brgemm, dim_t ldd, data_type_t d_dt, + const post_ops_t *post_ops) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_set_post_ops(brgemm, ldd, d_dt, post_ops); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_set_A_scales(brgemm_t *brgemm, int a_scale_mask) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_set_A_scales(brgemm, a_scale_mask); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_set_B_scales(brgemm_t *brgemm, int b_scale_mask) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_set_B_scales(brgemm, b_scale_mask); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_set_D_scales(brgemm_t *brgemm, int d_scale_mask) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_set_D_scales(brgemm, d_scale_mask); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_finalize(brgemm_t *brgemm) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_finalize(brgemm); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_get_B_pack_type( + pack_type_t *pack_type, data_type_t dt_a, data_type_t dt_b) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_get_B_pack_type(pack_type, dt_a, dt_b); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_get_scratchpad_size(const brgemm_t *brgemm, size_t *size) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_get_scratchpad_size(brgemm, size); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_is_execute_postops_valid( + const brgemm_t *brgemm, int *valid) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_is_execute_postops_valid(brgemm, valid); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_set_hw_context(const brgemm_t *brgemm) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_set_hw_context(brgemm); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_release_hw_context() { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_release_hw_context(); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_generate(brgemm_t *brgemm) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_generate(brgemm); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_execute(const brgemm_t *brgemm, const void *A_ptr, + const void *B_ptr, const dim_t *A_B_offsets, void *C_ptr, + void *scratchpad_ptr) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_execute( + brgemm, A_ptr, B_ptr, A_B_offsets, C_ptr, scratchpad_ptr); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_execute_postops(const brgemm_t *brgemm, const void *A_ptr, + const void *B_ptr, const dim_t 
*A_B_offsets, const void *C_ptr, + void *D_ptr, void *scratchpad_ptr, const attr_params_t *attr_params) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_execute_postops(brgemm, A_ptr, B_ptr, + A_B_offsets, C_ptr, D_ptr, scratchpad_ptr, attr_params); +#endif + return status::unimplemented; +} + +status_t dnnl_brgemm_destroy(brgemm_t *brgemm) { +#if DNNL_X64 + return x64::ukernel::dnnl_brgemm_destroy(brgemm); +#endif + return status::unimplemented; +} + +#endif + +//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s diff --git a/src/cpu/ukernel/c_types_map.hpp b/src/cpu/ukernel/c_types_map.hpp new file mode 100644 index 00000000000..f4835779137 --- /dev/null +++ b/src/cpu/ukernel/c_types_map.hpp @@ -0,0 +1,53 @@ +/******************************************************************************* +* Copyright 2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_UKERNEL_C_TYPES_MAP_HPP +#define CPU_UKERNEL_C_TYPES_MAP_HPP + +#include "oneapi/dnnl/dnnl_ukernel_types.h" + +#include "common/c_types_map.hpp" + +#ifdef DNNL_EXPERIMENTAL_UKERNEL + +// A section identical to c_map_types.hpp but just for brgemm ukernel so far. +namespace dnnl { +namespace impl { +namespace cpu { +namespace ukernel { + +using pack_type_t = dnnl_pack_type_t; +namespace pack_type { +const pack_type_t undef = dnnl_pack_type_undef; +const pack_type_t no_trans = dnnl_pack_type_no_trans; +const pack_type_t trans = dnnl_pack_type_trans; +const pack_type_t pack32 = dnnl_pack_type_pack32; +} // namespace pack_type + +using attr_params_t = dnnl_ukernel_attr_params; +using brgemm_t = dnnl_brgemm; +using transform_t = dnnl_transform; + +} // namespace ukernel +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif + +#endif + +//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s diff --git a/src/cpu/ukernel/transform.cpp b/src/cpu/ukernel/transform.cpp new file mode 100644 index 00000000000..d76fb5ece5a --- /dev/null +++ b/src/cpu/ukernel/transform.cpp @@ -0,0 +1,65 @@ +/******************************************************************************* +* Copyright 2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "oneapi/dnnl/dnnl_ukernel.h" + +#include "cpu/ukernel/c_types_map.hpp" + +#if DNNL_X64 +#include "cpu/x64/ukernel/transform.hpp" +#endif + +#ifdef DNNL_EXPERIMENTAL_UKERNEL + +using namespace dnnl::impl; +using namespace dnnl::impl::cpu; +using namespace dnnl::impl::cpu::ukernel; + +status_t dnnl_transform_create(transform_t **transform, dim_t K, dim_t N, + pack_type_t in_pack_type, dim_t in_ld, dim_t out_ld, data_type_t in_dt, + data_type_t out_dt) { +#if DNNL_X64 + return x64::ukernel::dnnl_transform_create( + transform, K, N, in_pack_type, in_ld, out_ld, in_dt, out_dt); +#endif + return status::unimplemented; +} + +status_t dnnl_transform_generate(transform_t *transform) { +#if DNNL_X64 + return x64::ukernel::dnnl_transform_generate(transform); +#endif + return status::unimplemented; +} + +status_t dnnl_transform_execute( + const transform_t *transform, const void *in_ptr, void *out_ptr) { +#if DNNL_X64 + return x64::ukernel::dnnl_transform_execute(transform, in_ptr, out_ptr); +#endif + return status::unimplemented; +} + +status_t dnnl_transform_destroy(transform_t *transform) { +#if DNNL_X64 + return x64::ukernel::dnnl_transform_destroy(transform); +#endif + return status::unimplemented; +} + +#endif + +//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s diff --git a/src/cpu/x64/CMakeLists.txt b/src/cpu/x64/CMakeLists.txt index a03c573ecea..9f232c929f7 100644 --- a/src/cpu/x64/CMakeLists.txt +++ b/src/cpu/x64/CMakeLists.txt @@ -93,3 +93,4 @@ set(OBJ_LIB ${LIB_PACKAGE_NAME}_cpu_x64) add_library(${OBJ_LIB} OBJECT ${SOURCES}) set_property(GLOBAL APPEND PROPERTY DNNL_LIB_DEPS $) +enable_conditional_compilation4(${OBJ_LIB}) diff --git a/src/cpu/x64/amx_tile_configure.cpp b/src/cpu/x64/amx_tile_configure.cpp index 64bb3d80deb..9464c604617 100644 --- a/src/cpu/x64/amx_tile_configure.cpp +++ b/src/cpu/x64/amx_tile_configure.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,12 +22,12 @@ namespace impl { namespace cpu { namespace x64 { -struct jit_amx_tilecfg_t : public jit_generator { +struct jit_amx_tilecfg_t : public jit_generator_t { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_amx_tilecfg_t) // TODO: Need to check status jit_amx_tilecfg_t(bool lazy = false) - : jit_generator(jit_name(), avx512_core_amx), is_lazy_(lazy) { + : jit_generator_t(jit_name(), avx512_core_amx), is_lazy_(lazy) { create_kernel(); } @@ -54,10 +54,11 @@ struct jit_amx_tilecfg_t : public jit_generator { sttilecfg(ptr[abi_param2]); // Move tilecfg into Zmm for further comparison. vmovdqu64(Xbyak::Zmm(0), ptr[abi_param2]); - // Sets `1` per word if values are equal. + // Sets `1` per word (32 words in a Zmm) if values are equal. vpcmpeqw(Xbyak::Opmask(0), Xbyak::Zmm(0), ptr[abi_param1]); - // `kortestw` will set CF=1 if all `1` in the mask. - kortestw(Xbyak::Opmask(0), Xbyak::Opmask(0)); + // `kortestd` sets CF=1 if all bits in the mask are `1`; the + // doubleword form covers all 32 compare bits. + kortestd(Xbyak::Opmask(0), Xbyak::Opmask(0)); // Checks if CF=1. If it is, everything matched, skipping config... jc(skip_tilecfg, T_NEAR); // ... otherwise, configure tile with user palette.
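The kortestw -> kortestd change above is a width fix: vpcmpeqw over a 64-byte Zmm produces 32 mask bits (one per 16-bit word), and kortestw inspects only the low 16 of them, so two tile configurations differing solely in the upper 32 bytes compared as equal and reconfiguration was wrongly skipped. A scalar model of the check (our illustration, not library code):

#include <cstdint>

// Models vpcmpeqw + kortest{w,d} on a 64-byte tile configuration.
bool tilecfg_matches(const uint16_t cur[32], const uint16_t req[32],
        bool use_kortestd) {
    uint32_t k = 0;
    for (int i = 0; i < 32; ++i)
        if (cur[i] == req[i]) k |= uint32_t(1) << i; // one bit per word
    // kortest sets CF when every tested mask bit is 1; the w form only
    // tests 16 bits, the d form all 32.
    const uint32_t tested = use_kortestd ? 0xffffffffu : 0x0000ffffu;
    return (k & tested) == tested;
}

int main() {
    uint16_t a[32] = {}, b[32] = {};
    b[31] = 1; // configs differ only in the last word
    const bool fixed = !tilecfg_matches(a, b, true); // kortestd: mismatch seen
    const bool buggy = tilecfg_matches(a, b, false); // kortestw: missed it
    return (fixed && buggy) ? 0 : 1;
}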
@@ -71,11 +72,11 @@ struct jit_amx_tilecfg_t : public jit_generator { } }; -struct jit_amx_tilerelease_t : public jit_generator { +struct jit_amx_tilerelease_t : public jit_generator_t { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_amx_tilerelease_t) // TODO: Need to check status - jit_amx_tilerelease_t() : jit_generator(jit_name(), avx512_core_amx) { + jit_amx_tilerelease_t() : jit_generator_t(jit_name(), avx512_core_amx) { create_kernel(); } diff --git a/src/cpu/x64/brgemm/brgemm.cpp b/src/cpu/x64/brgemm/brgemm.cpp index 70c383978f2..4afa6f012c9 100644 --- a/src/cpu/x64/brgemm/brgemm.cpp +++ b/src/cpu/x64/brgemm/brgemm.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,7 +81,9 @@ void brgemm_desc_t::cleanup_dst_md() { void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *scratch, - const brgemm_dynamic_values_t *dynamic_values) { + const brgemm_dynamic_values_t *dynamic_values, + const void *ptr_wei_scales, const void *ptr_wei_zero_points, + const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) { brgemm_kernel_params_t brgemm_p; brgemm_p.batch = batch; @@ -101,6 +103,11 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, brgemm_p.dynamic_LDC = dynamic_values->dynamic_LDC; brgemm_p.dynamic_LDD = dynamic_values->dynamic_LDD; } + brgemm_p.ptr_wei_scales = ptr_wei_scales; + brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points; + brgemm_p.ptr_src_scales = ptr_src_scales; + brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum; + brgemm_p.ic = ic; assert(brg_kernel); @@ -110,7 +117,9 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, const void *addr_A, const void *addr_B, const brgemm_batch_element_t *batch, void *ptr_C, void *scratch, - const brgemm_dynamic_values_t *dynamic_values) { + const brgemm_dynamic_values_t *dynamic_values, + const void *ptr_wei_scales, const void *ptr_wei_zero_points, + const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) { brgemm_kernel_params_t brgemm_p; brgemm_p.batch = batch; @@ -124,13 +133,17 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, brgemm_p.do_apply_comp = 0; brgemm_p.skip_accm = 0; brgemm_p.BS = bs; + brgemm_p.ptr_wei_scales = ptr_wei_scales; + brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points; + brgemm_p.ptr_src_scales = ptr_src_scales; + brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum; + brgemm_p.ic = ic; if (dynamic_values) { brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA; brgemm_p.dynamic_LDB = dynamic_values->dynamic_LDB; brgemm_p.dynamic_LDC = dynamic_values->dynamic_LDC; brgemm_p.dynamic_LDD = dynamic_values->dynamic_LDD; } - assert(brg_kernel); (*brg_kernel)(&brgemm_p); } @@ -138,7 +151,9 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D, const brgemm_post_ops_data_t &post_ops_data, void *scratch, - const brgemm_dynamic_values_t *dynamic_values) { + const brgemm_dynamic_values_t *dynamic_values, + const void *ptr_wei_scales, const void *ptr_wei_zero_points, + const void 
*ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) { brgemm_kernel_params_t brgemm_p; brgemm_p.batch = batch; @@ -165,13 +180,17 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations; brgemm_p.c_zp_values = post_ops_data.c_zp_values; brgemm_p.ptr_dst_scales = post_ops_data.dst_scales; + brgemm_p.ptr_wei_scales = ptr_wei_scales; + brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points; + brgemm_p.ptr_src_scales = ptr_src_scales; + brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum; + brgemm_p.ic = ic; if (dynamic_values) { brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA; brgemm_p.dynamic_LDB = dynamic_values->dynamic_LDB; brgemm_p.dynamic_LDC = dynamic_values->dynamic_LDC; brgemm_p.dynamic_LDD = dynamic_values->dynamic_LDD; } - assert(brg_kernel); (*brg_kernel)(&brgemm_p); } @@ -180,7 +199,9 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, const void *addr_A, const void *addr_B, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D, const brgemm_post_ops_data_t &post_ops_data, void *scratch, - const brgemm_dynamic_values_t *dynamic_values) { + const brgemm_dynamic_values_t *dynamic_values, + const void *ptr_wei_scales, const void *ptr_wei_zero_points, + const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) { brgemm_kernel_params_t brgemm_p; brgemm_p.batch = batch; @@ -205,8 +226,14 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, brgemm_p.first_mb_matrix_addr_off = post_ops_data.first_mb_matrix_addr_off; brgemm_p.a_zp_compensations = post_ops_data.a_zp_compensations; brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations; + brgemm_p.a_zp_values = post_ops_data.a_zp_values; brgemm_p.c_zp_values = post_ops_data.c_zp_values; brgemm_p.ptr_dst_scales = post_ops_data.dst_scales; + brgemm_p.ptr_wei_scales = ptr_wei_scales; + brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points; + brgemm_p.ptr_src_scales = ptr_src_scales; + brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum; + brgemm_p.ic = ic; if (dynamic_values) { brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA; brgemm_p.dynamic_LDB = dynamic_values->dynamic_LDB; @@ -218,11 +245,13 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, (*brg_kernel)(&brgemm_p); } +// from ov dyn_quant status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, impl::data_type_t dt_a, impl::data_type_t dt_b, bool transA, bool transB, brgemm_layout_t layout, float alpha, float beta, dim_t LDA, dim_t LDB, - dim_t LDC, dim_t M, dim_t N, dim_t K, const brgemm_strides_t *strides) { + dim_t LDC, dim_t M, dim_t N, dim_t K, const brgemm_strides_t *strides, + bool is_weights_decompression, bool is_src_dynamic_quantization, const memory_desc_t *wei_md, const primitive_attr_t *attr) { /* m - number of rows of the matrix op(A) and number of rows of the matrix C n - number of columns of the matrix op(B) and number of columns of the matrix C @@ -230,37 +259,95 @@ status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa, Matrices are in row-major layouts: A: lda * m, LDA - lda must be at least max(1, k) - B: ldb * k, LDB - ldb must be at least max(1, n) - C: ldc * m, LDC - ldc must be at least max(1, n) + B: ldb * k, LDB - ldb must be at least max(1, n) + C: ldc * m, LDC - ldc must be at least max(1, n) - Matrices are in column-major layouts: + Matrices are in column-major layouts: A: lda * k, LDA - lda must be at least max(1, m) 
- B: ldb * n, LDB - ldb must be at least max(1, k) - C: ldc * n, LDC - ldc must be at least max(1, m) - */ - if (brg == nullptr) return status::invalid_arguments; + B: ldb * n, LDB - ldb must be at least max(1, k) + C: ldc * n, LDC - ldc must be at least max(1, m) + */ + if (brg == nullptr) return status::invalid_arguments; if (transA || transB) return status::unimplemented; - if (type == brgemm_batch_kind_t::brgemm_batch_kind_undef) - return status::invalid_arguments; + + brg->with_wei_decomp = is_weights_decompression; + brg->with_src_dyn_quant = is_src_dynamic_quantization; brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout, alpha, beta, LDA, LDB, LDC, M, N, K, strides); - if (utils::one_of(true, brg->is_runtime_lda, brg->is_runtime_ldb)) - return status::unimplemented; - if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments; + // Upper bound, this can likely be improved by accounting for blocking + dim_t max_a_stride = brg->LDA * types::data_type_size(brg->dt_a) + * (brg->layout == brgemm_col_major ? K : M); + dim_t max_b_stride = brg->LDB * types::data_type_size(brg->dt_b) + * (brg->layout == brgemm_col_major ? N : K); + dim_t max_c_stride = brg->LDC * types::data_type_size(brg->dt_c) + * (brg->layout == brgemm_col_major ? N : M); + + // Required for EVEX encoding for offsets + const dim_t max_stride = std::numeric_limits::max(); + if ((max_a_stride > max_stride && !brg->is_runtime_lda) + || (max_b_stride > max_stride && !brg->is_runtime_ldb) + || (max_c_stride >= max_stride && !brg->is_runtime_ldc)) + return status::unimplemented; + if (utils::everyone_is(false, brg->is_int8, brg->is_bf16, brg->is_f32, - brg->is_f16, brg->is_fp8)) + brg->is_f16/*, brg->is_fp8*/)) return status::unimplemented; - // Only amx_int8 kernel supports u8 weights. + // Only avx512_core_amx kernel supports u8 weights. if (!IMPLICATION( - brg->dt_b == u8, is_superset(brg->isa_impl, avx512_core_amx))) + brg->dt_b == u8, is_superset(brg->isa_impl, avx512_core_amx)) && !brg->with_wei_decomp) return status::unimplemented; - CHECK(brgemm_blocking(brg)); + const memory_desc_wrapper wei_d(wei_md); + if (brg->with_wei_decomp) { + brg->with_grouped_wei_decomp = false; + + auto wei_scales = attr->scales_.get(DNNL_ARG_WEIGHTS); + brg->with_wei_decomp_scales = !wei_scales.has_default_values(); + brg->wei_decomp_scales_group_size = wei_d.dims()[1]; + if (brg->with_wei_decomp_scales) { + brg->wei_decomp_scales_dt = wei_scales.get_data_type(); + if (!one_of(brg->wei_decomp_scales_dt, f32, e8m0)) + return status::unimplemented; + + auto ld_dim = wei_scales.get_dims()[0]; + brg->wei_decomp_scales_stride = ld_dim > 1 ? ld_dim : 0; + brg->wei_decomp_scales_group_size = wei_d.dims()[1] / wei_scales.get_dims()[1]; + brg->with_grouped_wei_decomp |= wei_scales.get_dims()[1] != 1; + } + + brg->with_wei_decomp_zero_points = !attr->zero_points_.has_default_values(DNNL_ARG_WEIGHTS); + brg->wei_decomp_zero_points_group_size = wei_d.dims()[1]; + if (brg->with_wei_decomp_zero_points) { + brg->wei_decomp_zero_points_dt = attr->zero_points_.get_data_type(DNNL_ARG_WEIGHTS); + if (!one_of(brg->wei_decomp_zero_points_dt, f32, u8)) + return status::unimplemented; + + auto ld_dim = attr->zero_points_.get_dims(DNNL_ARG_WEIGHTS)[0]; + brg->wei_decomp_zero_points_stride = ld_dim > 1 ? 
ld_dim : 0; + brg->wei_decomp_zero_points_group_size = wei_d.dims()[1] / attr->zero_points_.get_dims(DNNL_ARG_WEIGHTS)[1]; + brg->with_grouped_wei_decomp |= attr->zero_points_.get_dims(DNNL_ARG_WEIGHTS)[1] != 1; + } + } + + brg->src_scales_group_size = wei_d.dims()[1]; + if (brg->with_src_dyn_quant) { + brg->src_scales_group_size = attr->src_dyn_quant_params_.get(); + brg->with_grouped_wei_decomp = true; + brg->src_scales_stride = div_up(wei_d.dims()[1], brg->src_scales_group_size); + } + + CHECK(brgemm_desc_finalize(brg)); + + brg->src_sum_group_size = wei_d.dims()[1]; + if (brg->with_src_dyn_quant) { + brg->src_sum_group_size = brg->rd_block; + brg->src_grouped_sum_stride = div_up(wei_d.dims()[1], brg->src_sum_group_size); + } // avx2_vnni_2 kernel with xf16 data type requires blocked weights. if (brg->isa_impl == avx2_vnni_2 && brg->is_xf16() @@ -290,14 +377,13 @@ status_t brdgmm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa, false, brg->is_int8, brg->is_bf16, brg->is_f32, brg->is_f16)) return status::unimplemented; - CHECK(brdgmm_blocking(brg)); - return status::success; } status_t brgemm_desc_set_postops(brgemm_desc_t *brg, const primitive_attr_t *attr, const memory_desc_t *dst_md, dim_t LDD, - impl::data_type_t dt_bias) { + impl::data_type_t dt_bias, + bool is_weights_decompression) { if (!brg || !dst_md) return status::invalid_arguments; brg->set_attr(attr); @@ -348,13 +434,15 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg, data_type::f16))) return status::unimplemented; const auto bias_f8_e5m2_compatible - = one_of(dt_d, data_type::f32, data_type::f16, data_type::f8_e5m2) + = one_of(dt_d, data_type::f32, data_type::f16, data_type::bf16, + data_type::f8_e5m2) && one_of(dt_bias, data_type::undef, data_type::f32, data_type::f16, - data_type::f8_e5m2, data_type::f8_e4m3); + data_type::bf16, data_type::f8_e5m2, data_type::f8_e4m3); const auto bias_f8_e4m3_compatible - = one_of(dt_d, data_type::f32, data_type::f16, data_type::f8_e4m3) + = one_of(dt_d, data_type::f32, data_type::f16, data_type::bf16, + data_type::f8_e4m3) && one_of(dt_bias, data_type::undef, data_type::f32, data_type::f16, - data_type::f8_e4m3, data_type::f8_e5m2); + data_type::bf16, data_type::f8_e4m3, data_type::f8_e5m2); if (!IMPLICATION(brg->is_fp8, bias_f8_e5m2_compatible || bias_f8_e4m3_compatible)) return status::unimplemented; @@ -371,9 +459,6 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg, brg->is_bf16_emu = !(mayiuse(avx512_core_bf16) || brg->isa_impl == avx2_vnni_2); - // Rerun blocking heuristic due to reduced zmm register count - if (brg->is_bf16_emu && brg->is_dgmm) CHECK(brdgmm_blocking(brg)); - if (!brg->attr()) return status::success; using namespace injector; @@ -400,6 +485,7 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg, false /*sum_requires_zp_zero*/, true /*sum_requires_same_params*/, {broadcasting_strategy_t::per_oc, + broadcasting_strategy_t::per_oc_d, broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_mb, broadcasting_strategy_t::per_mb_spatial, @@ -426,55 +512,60 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg, const auto &wei_scales = attr->scales_.get(DNNL_ARG_WEIGHTS); brg->with_scales = !brg->skip_scales && (!src_scales.has_default_values() - || !wei_scales.has_default_values() + || (!wei_scales.has_default_values() && !is_weights_decompression) || brg->with_weights_scale_adjust); if (brg->with_scales) { // Note. 
the current version supports only two different output scale // types: - // 1) common (mask_ = 0) + // 1) common (mask = 0) // 2) per_n_dim_scale - broadcast across n dimension; - // for convolution and inner product primitives it corresponds - // to "per_oc" mask_ = 1 << 1; for matmul - to - // mask_ = (1 << (ndims - 1))), where ndims is number of + // to "per_oc" mask = 1 << 1; for matmul - to + // mask = (1 << (ndims - 1))), where ndims is number of // dimensions for original matmul problem - // So if wei_scales.mask_ != 0 (not common) it's assumed here that scale - // type is per_n_dim_scale and driver which calls brgemm kernel checked - // that mask has correct value for this case - brg->is_oc_scale = wei_scales.mask_ != 0; + // So if wei_scales.get_mask() > 0 (not common) it is assumed here that + // the scale type is per_n_dim_scale and the driver that calls the brgemm + // kernel has checked that the mask is correct for this case + brg->is_oc_scale = wei_scales.get_mask() > 0; } const auto &dst_scales = attr->scales_.get(DNNL_ARG_DST); brg->with_dst_scales = !dst_scales.has_default_values(); - const bool scales_ok = src_scales.mask_ == 0 && dst_scales.mask_ == 0 - && attr->scales_.has_default_values( - {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}); + const bool scales_ok = attr->scales_.has_default_values({DNNL_ARG_SRC, + DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) + && IMPLICATION(!src_scales.has_default_values(), + src_scales.get_mask() == 0) + && IMPLICATION(!dst_scales.has_default_values(), + dst_scales.get_mask() == 0); if (!scales_ok) return status::unimplemented; auto init_zp_type = [&](brgemm_broadcast_t &zp_type, int mem_arg) -> status_t { - auto zero_points = attr->zero_points_; - - // common zero point type is supported for now - if (!zero_points.common(mem_arg)) return status::unimplemented; + const auto &zp = attr->zero_points_; + // Always init a default value. + zp_type = brgemm_broadcast_t::none; const bool skip_zero_point - = mem_arg == DNNL_ARG_WEIGHTS && brg->skip_zp_b_compensation; - zp_type = zero_points.has_default_values(mem_arg) || skip_zero_point - ?
brgemm_broadcast_t::none - : brgemm_broadcast_t::per_tensor; + = (mem_arg == DNNL_ARG_WEIGHTS && brg->skip_zp_b_compensation); + if (skip_zero_point) return status::success; + if (!zp.has_default_values(mem_arg)) { + int mask = zp.get_mask(mem_arg); + if (mask == 0) { + zp_type = brgemm_broadcast_t::per_tensor; + } else if (mask == (1 << 1)) { + zp_type = brgemm_broadcast_t::per_n; + } else if (mask == 1 && mem_arg == DNNL_ARG_WEIGHTS ) { + zp_type = brgemm_broadcast_t::none; + } else { + return status::unimplemented; + } + } return status::success; }; - init_zp_type(brg->zp_type_a, DNNL_ARG_SRC); - init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS); - init_zp_type(brg->zp_type_c, DNNL_ARG_DST); - - // Post-ops may use vector registers so brgemm/brdgmm blocking may need to - // be updated - if (brg->is_dgmm) - CHECK(brdgmm_blocking(brg)); - else - CHECK(brgemm_blocking(brg)); + CHECK(init_zp_type(brg->zp_type_a, DNNL_ARG_SRC)); + CHECK(init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS)); + CHECK(init_zp_type(brg->zp_type_c, DNNL_ARG_DST)); return status::success; } @@ -505,42 +596,6 @@ status_t brgemm_desc_set_attr( if (brgattr.fpmath_mode != fpmath_mode::strict) maybe_try_bf32(brg); - const int max_vpad = nstl::max(brgattr.max_top_vpad, - brgattr.max_bottom_vpad); // these should be equal - bool hint_blocking_set - = (brgattr.hint_bd_block != 0 || brgattr.hint_bd_block2 != 0 - || brgattr.hint_ld_block != 0 || brgattr.hint_ld_block2 != 0 - || brgattr.hint_load_nt_A != brgemm_hint_nt_undef - || brgattr.hint_load_nt_B != brgemm_hint_nt_undef - || brgattr.hint_bs_group > 1); - if (brgattr.use_uker || brg->is_bf16_tmm || hint_blocking_set - || brgattr.bd_mask_level - || brgattr.fpmath_mode != fpmath_mode::strict || max_vpad > 0) { - if (brg->is_dgmm) - CHECK(brdgmm_blocking(brg)); - else - CHECK(brgemm_blocking(brg)); - } - - if (!brg->is_dgmm) { - // virtual padding is restricted by bd_block size due to - // brgemm_kernel implementation. TODO: remove this restriction - const int min_bd_block - = brg->bdb_tail > 0 ? brg->bdb_tail : brg->bd_block; - if ((max_vpad > min_bd_block)) return status::unimplemented; - } - - brg->LDA2 = (brgattr.LDA2 != 0) ? brgattr.LDA2 : brg->LDA; - brg->LDB2 = (brgattr.LDB2 != 0) ? brgattr.LDB2 : brg->LDB; - brg->LDC2_M = (brgattr.LDC2_M != 0) ? brgattr.LDC2_M : brg->LDC; - brg->LDC2_N = (brgattr.LDC2_N != 0) ? brgattr.LDC2_N : brg->ld_block; - - brg->is_blocked = (brg->LDA2 != brg->LDA || brg->LDB2 != brg->LDB - || brg->LDC2_M != brg->LDC || brg->LDC2_N != brg->ld_block); - - if (!IMPLICATION(brg->is_blocked, brg->layout == brgemm_row_major)) - return status::invalid_arguments; - // virtual padding is not supported for "amx" if ((brgattr.max_top_vpad > 0 || brgattr.max_bottom_vpad > 0) && (brg->is_tmm)) @@ -568,6 +623,28 @@ status_t brgemm_desc_set_attr( return status::success; } +status_t brgemm_desc_finalize(brgemm_desc_t *brg) { + if (brg == nullptr) return status::invalid_arguments; + + const int max_vpad = nstl::max( + brg->brgattr.max_top_vpad, brg->brgattr.max_bottom_vpad); + + if (brg->is_dgmm) + CHECK(brdgmm_blocking(brg)); + else + CHECK(brgemm_blocking(brg)); + + if (!brg->is_dgmm) { + // virtual padding is restricted by bd_block size due to + // brgemm_kernel implementation. TODO: remove this restriction + const int min_bd_block + = brg->bdb_tail > 0 ? 
brg->bdb_tail : brg->bd_block; + if ((max_vpad > min_bd_block)) return status::unimplemented; + } + + return status::success; +} + status_t brgemm_kernel_create( brgemm_kernel_t **brg_kernel, const brgemm_desc_t &brg) { if (!brg_kernel) return status::invalid_arguments; @@ -617,10 +694,11 @@ status_t brgemm_kernel_destroy(brgemm_kernel_t *brg_kernel) { status_t brgemm_init_tiles(const brgemm_desc_t &brg, char palette[64]) { if (!brg.is_tmm) return status::unimplemented; - //TODO: Add support of tail processing by reduction dimension auto rd_block = (!brg.rdb && brg.rdb_tail) ? brg.rdb_tail : brg.rd_block; if (brg.is_input_convert()) rd_block = utils::rnd_up(rd_block, 2 /*vnni_granularity*/); + else + rd_block = utils::rnd_up(rd_block, brg.rd_step); palette_config_t *buff = (palette_config_t *)(palette); @@ -762,11 +840,13 @@ int brgemm_cmp(const brgemm_desc_t &lhs, const brgemm_desc_t &rhs) { CMP_BRGEMM_FIELD(brgattr.hint_prfB.dist2); CMP_BRGEMM_FIELD(brgattr.hint_prfC.dist1); CMP_BRGEMM_FIELD(brgattr.hint_prfC.dist2); - CMP_BRGEMM_FIELD(brgattr.wary_tail_read); + CMP_BRGEMM_FIELD(brgattr.wary_A_k_tail_read); + CMP_BRGEMM_FIELD(brgattr.extendable_k); CMP_BRGEMM_FIELD(brgattr.generate_skip_accumulation); CMP_BRGEMM_FIELD(brgattr.bd_mask_level); CMP_BRGEMM_FIELD(brgattr.use_uker); CMP_BRGEMM_FIELD(brgattr.use_interleave_stores); + CMP_BRGEMM_FIELD(brgattr.b_is_vnni); CMP_BRGEMM_FIELD(brgattr.fpmath_mode); CMP_BRGEMM_FIELD(brgattr.LDA2); CMP_BRGEMM_FIELD(brgattr.LDB2); diff --git a/src/cpu/x64/brgemm/brgemm.hpp b/src/cpu/x64/brgemm/brgemm.hpp index 084013e13c5..1f1b7771606 100644 --- a/src/cpu/x64/brgemm/brgemm.hpp +++ b/src/cpu/x64/brgemm/brgemm.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,12 +62,15 @@ namespace x64 { /// @param strides Strides between the matrices in the batch. Can be nullptr. /// TODO: what does "Can be nullptr" mean? /// + status_t DNNL_API brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, impl::data_type_t dt_a, impl::data_type_t dt_b, bool transA, bool transB, brgemm_layout_t layout, float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, dim_t N, dim_t K, - const brgemm_strides_t *strides = nullptr); + const brgemm_strides_t *strides = nullptr, + bool is_weights_decompression = false, bool is_src_dynamic_quantization = false, + const memory_desc_t *wei_md = nullptr, const primitive_attr_t *attr = nullptr); /// Initializes a BRGEMM descriptor with B matrix as a diagonal matrix /// represented in packed vector format. @@ -119,7 +122,8 @@ status_t DNNL_API brdgmm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa, /// status_t DNNL_API brgemm_desc_set_postops(brgemm_desc_t *brg, const primitive_attr_t *attr, const memory_desc_t *dst_md, dim_t LDD, - impl::data_type_t dt_bias = impl::data_type::undef); + impl::data_type_t dt_bias = impl::data_type::undef, + bool is_weights_decompression = false); /// Adds BRGEMM attributes to BRGEMM descriptor /// @@ -130,6 +134,15 @@ status_t DNNL_API brgemm_desc_set_postops(brgemm_desc_t *brg, status_t DNNL_API brgemm_desc_set_attr( brgemm_desc_t *brg, const brgemm_attr_t &brgattr); +/// Finalize BRGEMM descriptor. 
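+///
+/// A minimal call-order sketch (illustrative only; the ISA, data types,
+/// sizes, `strides`, `attr`, `brgattr`, and `dst_md` below are assumed to
+/// be provided by the caller):
+/// @code
+/// brgemm_desc_t brg;
+/// CHECK(brgemm_desc_init(&brg, isa, brgemm_strd, dt_a, dt_b,
+///         /*transA=*/false, /*transB=*/false, brgemm_row_major,
+///         /*alpha=*/1.0f, /*beta=*/0.0f, LDA, LDB, LDC, M, N, K, &strides));
+/// CHECK(brgemm_desc_set_postops(&brg, attr, dst_md, LDD));
+/// CHECK(brgemm_desc_set_attr(&brg, brgattr));
+/// CHECK(brgemm_desc_finalize(&brg)); // blocking is computed here
+/// brgemm_kernel_t *kernel = nullptr;
+/// CHECK(brgemm_kernel_create(&kernel, brg));
+/// @endcode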
+/// +/// @param brg Output BRGEMM descriptor +/// This function must be called after all the fields of the descriptor are set. +/// It finalizes the descriptor including internal blocking parameters to +/// prepare it for the kernel creation. +/// +status_t DNNL_API brgemm_desc_finalize(brgemm_desc_t *brg); + /// Generates a BRGEMM kernel based on descriptor /// /// @param brg_kernel Output BRGEMM kernel @@ -169,7 +182,9 @@ status_t DNNL_API brgemm_kernel_destroy(brgemm_kernel_t *brg_kernel); void DNNL_API brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *scratch = nullptr, - const brgemm_dynamic_values_t *dynamic_values = nullptr); + const brgemm_dynamic_values_t *dynamic_values = nullptr, + const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr, + const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0); /// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version) /// @@ -197,7 +212,9 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, const void *addr_A, const void *addr_B, const brgemm_batch_element_t *batch, void *ptr_C, void *scratch = nullptr, - const brgemm_dynamic_values_t *dynamic_values = nullptr); + const brgemm_dynamic_values_t *dynamic_values = nullptr, + const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr, + const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0); /// Execute BRGEMM kernel (brgemm_addr version) /// @@ -224,7 +241,9 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D, const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr, - const brgemm_dynamic_values_t *dynamic_values = nullptr); + const brgemm_dynamic_values_t *dynamic_values = nullptr, + const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr, + const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0); /// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version) /// @@ -255,7 +274,9 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, const void *addr_A, const void *addr_B, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D, const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr, - const brgemm_dynamic_values_t *dynamic_values = nullptr); + const brgemm_dynamic_values_t *dynamic_values = nullptr, + const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr, + const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0); /// AMX utilities: Creates a palette based on BRGEMM descriptor /// diff --git a/src/cpu/x64/brgemm/brgemm_containers.cpp b/src/cpu/x64/brgemm/brgemm_containers.cpp index b0a3d047d43..4fc7df86066 100644 --- a/src/cpu/x64/brgemm/brgemm_containers.cpp +++ b/src/cpu/x64/brgemm/brgemm_containers.cpp @@ -49,6 +49,10 @@ bool brgemm_desc_container_t::insert(int idx, brgemm_desc_t &brg, brg.brgattr.static_offsets = static_offsets_list_.back().data(); const auto ret = map_.insert({brg, idx}); + const int ref_size = refs_.size(); + if (idx > ref_size - 1) { + refs_.resize(idx + 1); + } refs_[idx] = &(ret.first->first); // if there was no insertion then clean bd_mask and static_offsets if (!ret.second) { diff --git 
a/src/cpu/x64/brgemm/brgemm_containers.hpp b/src/cpu/x64/brgemm/brgemm_containers.hpp index 5a2eabe4b4b..5f5a7c67177 100644 --- a/src/cpu/x64/brgemm/brgemm_containers.hpp +++ b/src/cpu/x64/brgemm/brgemm_containers.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023-2024 Intel Corporation +* Copyright 2023-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ namespace brgemm_containers { struct brgemm_desc_container_t { public: - brgemm_desc_container_t() {} + brgemm_desc_container_t() = default; brgemm_desc_container_t(size_t ns) { resize(ns); } void resize(size_t ns) { refs_.resize(ns); } inline const brgemm_desc_t *operator[](int idx) const { return refs_[idx]; } @@ -71,7 +71,7 @@ struct brgemm_desc_container_t { // #define BRGEMM_KERNEL_GLOBAL_STORAGE struct brgemm_kernel_container_t { - brgemm_kernel_container_t() {} + brgemm_kernel_container_t() = default; brgemm_kernel_container_t(size_t ns) { resize(ns); } void resize(size_t ns) { refs_.resize(ns); } inline const brgemm_kernel_t *operator[](int idx) const { @@ -111,9 +111,9 @@ struct brgemm_kernel_container_t { }; struct brgemm_palette_container_t { - typedef std::array<char, 64> S_t; + using S_t = std::array<char, 64>; - brgemm_palette_container_t() {} + brgemm_palette_container_t() = default; brgemm_palette_container_t(size_t ns) { resize(ns); } void resize(size_t ns) { refs_.resize(ns); } diff --git a/src/cpu/x64/brgemm/brgemm_types.hpp b/src/cpu/x64/brgemm/brgemm_types.hpp index 9054081576e..624777de8ac 100644 --- a/src/cpu/x64/brgemm/brgemm_types.hpp +++ b/src/cpu/x64/brgemm/brgemm_types.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ #include "common/primitive_attr.hpp" #include "cpu/platform.hpp" #include "cpu/x64/cpu_isa_traits.hpp" +#include "cpu/x64/jit_generator.hpp" namespace dnnl { namespace impl { @@ -96,11 +97,7 @@ struct brgemm_prf_t { }; struct brgemm_batch_element_t { - brgemm_batch_element_t() { - ptr.A = ptr.B = nullptr; - vvpad.top = vvpad.bottom = 0; - has_s8s8_comp_batch_pad = 0; - } + brgemm_batch_element_t() { ptr.A = ptr.B = nullptr; } union { struct { const void *A; @@ -112,14 +109,14 @@ } offset; }; struct { - dim_t top; - dim_t bottom; + dim_t top = 0; + dim_t bottom = 0; } vvpad; // w.r.t. M dimension // Used to calculate compensation when batch padding is present. // Note: batch_pad represents the overlap between weights and the height // dimension w.r.t. convolution dimensions. - dim_t has_s8s8_comp_batch_pad; + dim_t has_s8s8_comp_batch_pad = 0; }; struct DNNL_API brgemm_attr_t { @@ -138,7 +135,22 @@ = brgemm_kernel_prefetching_t::brgemm_prf_default; brgemm_prf_t hint_prfA, hint_prfB, hint_prfC; - bool wary_tail_read; + // This parameter determines how we will read the tail by K dimension from + // matrix A. For AMX, if the parameter is true, the brgemm will first + // copy the data to an intermediate buffer and only then use the tileload.
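+    // A rough sketch of that AMX path (an editorial simplification;
+    // `a_row`, `k_tail_bytes`, and the tile load below are illustrative
+    // names, not the actual kernel code):
+    //     alignas(64) char scratch[64] = {0};        // zero-padded row
+    //     std::memcpy(scratch, a_row, k_tail_bytes); // touch only valid bytes
+    //     // tileload from scratch: a full-width load is now safe
+    // i.e. an extra copy is traded for never reading past the end of A.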
+ // For non-AMX the A data are loaded byte by byte if flag is set + bool wary_A_k_tail_read {false}; + // For AMX the K dimension given to the brgemm is required to be divisible + // by vnni granularity. In addition blocking by K dimension may not be + // optimal if K greater than tile size and divisible by it. + // The parameter 'extendable_k' enables the brgemm to use the optimal K + // block size assuming that the following requirements for the matrix B are + // fulfilled: + // - It is ​​properly blocked (64 bytes block by K dimension). + // - The dimension K is padded by zeros. + // For K tail handling in this case the brgemm behavior is determined by the + // 'wary_A_k_tail_read' parameter. + bool extendable_k {false}; bool generate_skip_accumulation; // Value of bd_mask_level specifies how bd_mask is used in brgemm kernel // 0 – bd_mask is not used @@ -152,6 +164,7 @@ struct DNNL_API brgemm_attr_t { // interleave stores or not bool use_interleave_stores; impl::fpmath_mode_t fpmath_mode = fpmath_mode::strict; + bool b_is_vnni {false}; // Second level leading dimension describing distance between 16-line // blocks in case of blocked layout. Used to calculate address of next // bd block. By default are equal to regular leading dimension parameters @@ -189,7 +202,7 @@ struct DNNL_API brgemm_attr_t { }; struct brgemm_desc_t { - brgemm_desc_t() {} + brgemm_desc_t() = default; brgemm_desc_t(const brgemm_desc_t &other); DNNL_API ~brgemm_desc_t(); @@ -235,6 +248,7 @@ struct brgemm_desc_t { bool with_scales = false; bool skip_zp_b_compensation = false; bool skip_scales = false; + bool n_bcast_1_load = false; brgemm_broadcast_t zp_type_a = brgemm_broadcast_t::none; brgemm_broadcast_t zp_type_b = brgemm_broadcast_t::none; @@ -294,6 +308,7 @@ struct brgemm_desc_t { static constexpr int MAX_VPAD = 100; static constexpr int AMX_TILES_NUM = 8; + static constexpr int tilesize = 1024; void set_attr(const primitive_attr_t *ppdattr); void set_dst_md(const memory_desc_t *pdst_md); @@ -307,6 +322,22 @@ struct brgemm_desc_t { bool is_input_convert() const { return is_bf32 || is_fp8_via_convert(); } + bool with_wei_decomp = false; + bool with_grouped_wei_decomp = false; + bool with_wei_decomp_scales = false; + bool with_wei_decomp_zero_points = false; + int wei_decomp_scales_stride = 0; + int wei_decomp_zero_points_stride = 0; + int wei_decomp_scales_group_size = 0; + int wei_decomp_zero_points_group_size = 0; + impl::data_type_t wei_decomp_scales_dt = data_type::undef; + impl::data_type_t wei_decomp_zero_points_dt = data_type::undef; + bool with_src_dyn_quant = false; + int src_scales_group_size = 0; + int src_scales_stride = 0; + int src_sum_group_size = 0; + int src_grouped_sum_stride = 0; + bool is_row_major() const { assert(layout != brgemm_layout_undef); return layout == brgemm_row_major; @@ -364,43 +395,113 @@ struct brgemm_desc_t { return (get_num_C_tiles() + get_num_A_tiles() + N); } + int get_convert_wsp_buffer_size() const noexcept { + if (!is_input_convert()) return 0; + const int n_bdb = bd_block2; + const int n_rdb = rdb + (rdb_tail != 0); + const int n_ldb = ldb + (ldb_tail != 0); + const int downcvt_tiles = brgattr.max_bs * n_rdb * (n_bdb + n_ldb); + return downcvt_tiles * tilesize; + } + int get_wsp_buffer_size() const noexcept { int sz = 0; if (is_tmm) { - constexpr int tilesize = 1024; sz = get_num_C_tiles() * tilesize; // postops buffer - if (is_input_convert()) { - const int n_bdb = bd_block2; - const int n_rdb = rdb + (rdb_tail != 0); - const int n_ldb = ldb + (ldb_tail != 0); - const 
int downcvt_tiles - = brgattr.max_bs * n_rdb * (n_bdb + n_ldb); - sz += downcvt_tiles * tilesize; - } + sz += get_convert_wsp_buffer_size(); + if (amx_wary_k_tail()) sz += tilesize; } return sz; } + // A class version of the `static` version of the function. + // Note: used in benchdnn only, not used inside the library. + bool is_b_data_layout_vnni() const { + return is_b_data_layout_vnni(dt_a, dt_b, brgattr.b_is_vnni, isa_impl); + } + // This function indicates when VNNI granularity packing is expected by the // kernel. // - // Note: used in benchdnn only, not used inside the library. + // Note: used as the `static` function in ukernel only, not anywhere else. + // `static`-ness is required to identify if the transform routine must be + // used for the ukernel to work properly. This information is critical + // because the transform routine accepts only 4 `ldb` values which affects + // ukernel creation. Otherwise, the user must create the ukernel object, + // query the packing info, and if it's required, likely re-create the + // object with a different `ldb` value, which may not work because + // creation stage for user's application may not provide all the info to + // create a ukernel object. // Note: for `bf32` (or brgattr.fpmath_mode_ == bf16) the function returns // `true` because the data transformation to vnni layout is internal and // transparent to the user. - bool is_b_data_layout_vnni() const { + // Note: the library MUST NOT break the ability to provide this information + // without brgemm_desc_t object creation. + static bool is_b_data_layout_vnni(data_type_t dt_a, data_type_t dt_b, + bool attr_b_is_vnni, cpu_isa_t isa) { using namespace data_type; switch (dt_b) { case f32: return false; // Note: `dt_a == f32` means implicit up-conversion of B to f32. - case f16: return (isa_impl != avx512_core_fp16) && (dt_a != f32); + case f16: + return dt_a != f32 + && (is_f16_b_non_amx_vnni(dt_b, attr_b_is_vnni, isa) + || is_superset(isa, avx512_core_amx_fp16) + || is_superset(isa, avx2_vnni_2)); // Note: `dt_a == f32` means implicit up-conversion of B to f32. case bf16: return dt_a != f32; default: return true; } } + + // This function indicates when the kernel would operate with the D pointer + // (`true`) and when not (`false`). It's important to distinguish these two + // cases due to the fact that kernel would ignore D pointer completely if + // no post-accumulation work is identified. + // + // Correspondent decisions are done in `store_accumulators` function. + // The function is used inside kernel generation and ukernel API. + // TODO: extend usage to primitives (each of them utilize their own copy + // of this definition). + bool are_post_ops_applicable() const { + const bool has_zero_points = !utils::everyone_is( + brgemm_broadcast_t::none, zp_type_a, zp_type_b, zp_type_c); + return dt_c != dt_d || with_eltwise || with_binary || with_scales + || with_bias || with_sum || req_s8s8_compensation + || has_zero_points || with_dst_scales; + } + bool is_xf16() const noexcept { return is_bf16 || is_f16; } + bool is_f16_b_non_amx_vnni() const { + return is_f16_b_non_amx_vnni(dt_b, brgattr.b_is_vnni, isa_impl); + } + + // Note: `static` version appears because of `static is_b_data_layout_vnni`. + static bool is_f16_b_non_amx_vnni( + data_type_t dt_b, bool attr_b_is_vnni, cpu_isa_t isa) { + // This function controls the code section which relies on + // `avx512_core_fp16` instructions directly. 
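+        // (Sketch of the descriptor-free query pattern described above,
+        // with hypothetical caller-side names:
+        //     const bool packing_required
+        //             = brgemm_desc_t::is_b_data_layout_vnni(
+        //                     dt_a, dt_b, attr_b_is_vnni, isa);
+        // so a ukernel user can pick `ldb` before creating any object.)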
+ return dt_b == data_type::f16 && attr_b_is_vnni + && isa == avx512_core_fp16; + } + + bool reduce_by_words() const { + return is_bf16_tmm || is_f16_tmm || is_input_convert(); + } + int max_rd_block() const { return reduce_by_words() ? 32 : 64; } + int rd_block_step() const { return (reduce_by_words() && !is_fp8) ? 2 : 4; } + + bool amx_may_extend_k() const { + return (is_superset(isa_impl, avx512_core_amx) && brgattr.extendable_k + && (reduce_dim % data_type_vnni_granularity(dt_a) + || (reduce_dim > max_rd_block() + && reduce_dim % max_rd_block()))); + } + bool amx_wary_k_tail() const { + return amx_may_extend_k() && brgattr.wary_A_k_tail_read; + } + bool operator==(const brgemm_desc_t &rhs) const; bool operator<(const brgemm_desc_t &rhs) const; @@ -430,6 +531,12 @@ struct brgemm_dynamic_values_t { , dynamic_LDD(LDD) {} }; +struct brgemm_decomp_kernel_params_t { + const void *ptr_B; + const void *scratch_buf; + const void *bitmask_ptr; +}; + struct brgemm_kernel_params_t { const void *ptr_A; const void *ptr_B; @@ -464,10 +571,17 @@ struct brgemm_kernel_params_t { const void *a_zp_compensations = nullptr; const void *b_zp_compensations = nullptr; + const void *a_zp_values = nullptr; const void *c_zp_values = nullptr; size_t skip_accm = 0; int32_t zp_a_val = 1; const void *ptr_dst_scales = nullptr; + + const void *ptr_wei_scales = nullptr; + const void *ptr_wei_zero_points = nullptr; + const void *ptr_src_scales = nullptr; + const void *ptr_src_grouped_sum = nullptr; + size_t ic; dim_t dynamic_LDA = 0; dim_t dynamic_LDB = 0; dim_t dynamic_LDC = 0; @@ -479,24 +593,34 @@ struct jit_brgemm_kernel_t; struct jit_brgemm_amx_uker_base_t; template struct jit_brdgmm_kernel_base_t; -class jit_generator; +class jit_generator_t; struct brgemm_kernel_t { - brgemm_kernel_t() {}; - virtual ~brgemm_kernel_t() {}; + brgemm_kernel_t() = default; + virtual ~brgemm_kernel_t() = default; virtual status_t create_kernel() = 0; virtual void operator()(brgemm_kernel_params_t *) const = 0; - virtual const jit_generator *get_jit_generator() const = 0; + virtual const jit_generator_t *get_jit_generator() const = 0; + virtual const brgemm_desc_t &get_brg() const = 0; +}; + +struct jit_base_brgemm_kernel_t : public jit_generator_t { + jit_base_brgemm_kernel_t(const char *impl_name, cpu_isa_t isa_impl) + : jit_generator_t(impl_name, isa_impl) {} + virtual const brgemm_desc_t &get_brg() const = 0; }; template struct brgemm_kernel_common_t : public brgemm_kernel_t { brgemm_kernel_common_t(const brgemm_desc_t &abrd); - ~brgemm_kernel_common_t(); + ~brgemm_kernel_common_t() override; - status_t create_kernel(); - void operator()(brgemm_kernel_params_t *) const; - virtual const jit_generator *get_jit_generator() const; + status_t create_kernel() override; + void operator()(brgemm_kernel_params_t *) const override; + const jit_generator_t *get_jit_generator() const override; + const brgemm_desc_t &get_brg() const override { + return ((jit_base_brgemm_kernel_t *)brgemm_kernel_)->get_brg(); + } private: jit_brgemm_kernel_t *brgemm_kernel_ = nullptr; @@ -506,11 +630,14 @@ struct brgemm_kernel_common_t : public brgemm_kernel_t { struct brgemm_amx_uker_t : public brgemm_kernel_t { brgemm_amx_uker_t(const brgemm_desc_t &abrd); - ~brgemm_amx_uker_t(); + ~brgemm_amx_uker_t() override; - status_t create_kernel(); - void operator()(brgemm_kernel_params_t *) const; - virtual const jit_generator *get_jit_generator() const; + status_t create_kernel() override; + void operator()(brgemm_kernel_params_t *) const override; + const 
jit_generator_t *get_jit_generator() const override; + const brgemm_desc_t &get_brg() const override { + return ((jit_base_brgemm_kernel_t *)brgemm_kernel_)->get_brg(); + } private: jit_brgemm_amx_uker_base_t *brgemm_kernel_ = nullptr; @@ -521,11 +648,14 @@ struct brgemm_amx_uker_t : public brgemm_kernel_t { template struct brdgmm_kernel_t : public brgemm_kernel_t { brdgmm_kernel_t(const brgemm_desc_t &abrd); - ~brdgmm_kernel_t(); + ~brdgmm_kernel_t() override; - status_t create_kernel(); - void operator()(brgemm_kernel_params_t *) const; - virtual const jit_generator *get_jit_generator() const; + status_t create_kernel() override; + void operator()(brgemm_kernel_params_t *) const override; + const jit_generator_t *get_jit_generator() const override; + const brgemm_desc_t &get_brg() const override { + return ((jit_base_brgemm_kernel_t *)brgemm_kernel_)->get_brg(); + } private: jit_brdgmm_kernel_base_t *brgemm_kernel_ = nullptr; @@ -574,7 +704,8 @@ struct brgemm_post_ops_data_t { const void *b_zp_compensations = nullptr, const void *c_zp_values = nullptr, bool skip_accumulation = false, int32_t zp_a_val = 1, bool do_only_comp = false, - bool do_only_zp_a_val = false, const float *dst_scales = nullptr) + bool do_only_zp_a_val = false, const float *dst_scales = nullptr, + const void *a_zp_values = nullptr) : bias(bias) , scales(scales) , binary_post_ops_rhs(binary_post_ops_rhs) @@ -589,7 +720,8 @@ struct brgemm_post_ops_data_t { , zp_a_val {zp_a_val} , do_only_comp {do_only_comp} , do_only_zp_a_val {do_only_zp_a_val} - , dst_scales(dst_scales) {} + , dst_scales(dst_scales) + , a_zp_values(a_zp_values) {} const void *bias = nullptr; const float *scales = nullptr; @@ -606,6 +738,7 @@ struct brgemm_post_ops_data_t { const bool do_only_comp = false; const bool do_only_zp_a_val = false; const float *dst_scales = nullptr; + const void *a_zp_values = nullptr; }; } // namespace x64 diff --git a/src/cpu/x64/brgemm/brgemm_utils.cpp b/src/cpu/x64/brgemm/brgemm_utils.cpp index d62d908f101..48760f79b6c 100644 --- a/src/cpu/x64/brgemm/brgemm_utils.cpp +++ b/src/cpu/x64/brgemm/brgemm_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022-2024 Intel Corporation +* Copyright 2022-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,10 +51,14 @@ void init_kernel_datatype( brgemm_desc_t *brg, impl::data_type_t dt_a, impl::data_type_t dt_b) { assert(dt_a != data_type::undef && dt_b != data_type::undef); brg->is_int8 = utils::one_of(dt_a, data_type::u8, data_type::s8) - && utils::one_of(dt_b, data_type::u8, data_type::s8); - brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16); - brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32); - brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b); + && utils::one_of(dt_b, data_type::u8, data_type::s8, data_type::u4); + brg->is_bf16 = (dt_a == data_type::bf16) && utils::one_of(dt_b, data_type::bf16, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1); + // Note: f32:bf16 is treated as f32 case while f32:f16 has already been + // treated as f16. Probably, need a common ground here. 
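+    // Illustrative examples of the resulting classification (derived from
+    // the surrounding assignments; not an exhaustive table):
+    //     dt_a = f32,  dt_b = bf16 -> is_f32  (B is up-converted)
+    //     dt_a = bf16, dt_b = bf16 -> is_bf16
+    //     dt_a = bf16, dt_b = u8   -> is_bf16 (weights decompression)
+    //     dt_a = f16,  dt_b = f16  -> is_f16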
+ brg->is_f32 = (dt_a == data_type::f32) + && utils::one_of( + dt_b, data_type::f32, data_type::f16, data_type::bf16, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1); + brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b) && !brg->is_f32; brg->is_fp8 = one_of(dt_a, data_type::f8_e5m2, data_type::f8_e4m3) && one_of(dt_b, data_type::f8_e5m2, data_type::f8_e4m3); assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16 @@ -131,24 +135,45 @@ void set_isa_impl(brgemm_desc_t *brg) { is_isa_ok(avx512_core_fp16), avx512_core_fp16, is_isa_ok(avx2), avx2); } else if (brg->is_bf16) { - brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx), - avx512_core_amx, is_isa_ok(avx512_core_bf16), avx512_core_bf16, - is_isa_ok(avx2_vnni_2), avx2_vnni_2); + if (brg->dt_a == data_type::f32 && brg->dt_b == data_type::bf16) { + // Distinguish f32:bf16 case upconversion for bf16 on AVX512_CORE + // and AVX2. + brg->isa_impl = utils::map(true, isa_undef, + is_isa_ok(avx512_core_amx), avx512_core_amx, + is_isa_ok(avx512_core_bf16), avx512_core_bf16, + is_isa_ok(avx512_core), avx512_core, is_isa_ok(avx2_vnni_2), + avx2_vnni_2, is_isa_ok(avx2), avx2); + } else { + brg->isa_impl = utils::map(true, isa_undef, + is_isa_ok(avx512_core_amx), avx512_core_amx, + is_isa_ok(avx512_core_bf16), avx512_core_bf16, + is_isa_ok(avx2_vnni_2), avx2_vnni_2); + } } else if (brg->is_f16) { if (everyone_is(data_type::f16, brg->dt_a, brg->dt_b)) { brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx_fp16), avx512_core_amx_fp16, is_isa_ok(avx512_core_fp16), avx512_core_fp16, is_isa_ok(avx2_vnni_2), avx2_vnni_2); + } else if (brg->dt_a == data_type::f32 && brg->dt_b == data_type::f16) { + // Distinguish f32:f16 case upconversion for f16 on AVX512_CORE and + // AVX2. + brg->isa_impl = utils::map(true, isa_undef, + is_isa_ok(avx512_core_fp16), avx512_core_fp16, + is_isa_ok(avx512_core), avx512_core, is_isa_ok(avx2), avx2); } else { brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_fp16), avx512_core_fp16); } } else if (brg->is_int8) { - brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx), - avx512_core_amx, is_isa_ok(avx512_core_vnni), avx512_core_vnni, - is_isa_ok(avx512_core), avx512_core, is_isa_ok(avx2_vnni_2), - avx2_vnni_2, is_isa_ok(avx2_vnni), avx2_vnni); + brg->isa_impl + = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx_fp16), + avx512_core_amx_fp16, is_isa_ok(avx512_core_amx), + avx512_core_amx, is_isa_ok(avx512_core_fp16), + avx512_core_fp16, is_isa_ok(avx512_core_vnni), + avx512_core_vnni, is_isa_ok(avx512_core), avx512_core, + is_isa_ok(avx2_vnni_2), avx2_vnni_2, + is_isa_ok(avx2_vnni), avx2_vnni, is_isa_ok(avx2), avx2); } else if (brg->is_fp8) { brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx10_1_512_amx_fp16), avx10_1_512_amx_fp16); @@ -190,17 +215,19 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) { || brg->brgattr.max_bottom_vpad > 0) && brg->zp_type_a != brgemm_broadcast_t::none; const int beta_regs = !one_of(brg->beta, 1.f, 0.f); + const int b_vnni_regs = brg->is_f16_b_non_amx_vnni() ? 
2 : 0; const int max_isa_regs = isa_num_vregs(brg->isa_impl); // note: the 'adj_ld_block2' already removes the necessary registers // for 'embd_bcst' auto max_reg_count = max_isa_regs - max_bcst_regs - beta_regs - - req_compensation - req_zp_a_comp_pads; + - req_compensation - req_zp_a_comp_pads - b_vnni_regs; if (req_zp_a_comp_pads) max_reg_count = nstl::min(max_reg_count, max_isa_regs - max_bcst_regs - 5); - const int postops_regs = brg->attr() + // For dynamic quantization case it is more performant to maximize the amount of accumulators + const int postops_regs = brg->attr() && !brg->with_src_dyn_quant ? injector::aux_vec_count( brg->attr()->post_ops_, brg->isa_impl, true) : 0; @@ -218,238 +245,175 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) { // non-VNNI INT8 dot product required 2 temp vectors if (brg->is_int8 && !brg->has_int8_vnni) max_bcast_block -= 2; + if (one_of(brg->dt_b, data_type::nf4) && brg->isa_impl == avx2) max_bcast_block -= 5; + if (one_of(brg->dt_b, data_type::f4_e2m1) && brg->isa_impl == avx2) max_bcast_block -= 2; + if (one_of(brg->dt_b, data_type::nf4, data_type::f4_e2m1) && brg->isa_impl != avx2) max_bcast_block -= 1; + if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0 && !brg->with_src_dyn_quant) max_bcast_block -= 1; + if (brg->with_src_dyn_quant) max_bcast_block -= 1; + max_bcast_block /= adj_ld_block2; return max_bcast_block; } -status_t brgemm_blocking(brgemm_desc_t *brg) { - - set_isa_impl(brg); - if (brg->isa_impl == isa_undef) return status::unimplemented; - assert(!brg->is_dgmm); // should not be called from brdgmm - if (brg->is_dgmm) return status::unimplemented; - set_brg_vmm(brg); - if (!(brg->is_tmm || brg->is_zmm || brg->is_ymm)) - return status::unimplemented; - - if (!brg->is_tmm) { - const int simd_w = is_superset(brg->isa_impl, avx512_core) ? 
16 : 8; - brg->ld_block = simd_w; - brg->ldb = brg->load_dim / brg->ld_block; - brg->ldb_tail = brg->load_dim % brg->ld_block; - - int adj_ld_block2 = calculate_ldb_params(brg, 4); - int max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2); - - // reduce 'ld_block2' to allow a larger 'bd_block' - const int max_vpad = nstl::max( - brg->brgattr.max_top_vpad, brg->brgattr.max_bottom_vpad); - if (is_superset(brg->isa_impl, avx2) && max_bcast_block < max_vpad) { - adj_ld_block2 = calculate_ldb_params(brg, 2); - max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2); +status_t brgemm_blocking_tmm(brgemm_desc_t *brg) { + const auto L1 = platform::get_per_core_cache_size(1); + + // Blocking configuration for AMX + const auto BD = brg->bcast_dim; + const auto BD_R16 = rnd_up(BD, 16); + const auto LD = brg->load_dim; + const auto LD_R16 = rnd_up(LD, 16); + + const int max_width = 16, min_width = 1; + brg->ld_block = 16; + brg->ldb = LD / brg->ld_block; + brg->ldb_tail = LD % brg->ld_block; + + auto find_bdb_bd_mask = [&](int bd_block, int &bdb, int &bdb_tail) { + if (brg->brgattr.bd_mask_level != 2 || BD == 0) { + bdb = div_up(BD, bd_block); + bdb_tail = BD % bd_block; + return; } - const int min_block = 1; - float best_bd_block_eff = 0.f; - brg->bd_block = 1; - for (int bd_block = max_bcast_block; bd_block >= min_block; - bd_block--) { - const auto bd_block_disb = static_cast(brg->bcast_dim) - / rnd_up(brg->bcast_dim, bd_block); - const auto brgemm_microkernel_eff - = (static_cast(adj_ld_block2) * bd_block) - / (((adj_ld_block2) + bd_block) * max_bcast_block); - const auto bd_block_eff = bd_block_disb * brgemm_microkernel_eff; - - float block_foot_print = static_cast(brg->typesize_A) - * (bd_block * brg->reduce_dim); - if (block_foot_print <= static_cast( - platform::get_per_core_cache_size(1)) - && (bd_block_eff > best_bd_block_eff)) { - brg->bd_block = bd_block; - best_bd_block_eff = bd_block_eff; + bdb = 0; + bdb_tail = 0; + for (int i = 0; i < BD;) { + if (brg->brgattr.bd_mask_level == 2 + && brg->brgattr.bd_mask[i] == 0) { + i++; + } else { + i += bd_block; + if (i > BD) { + bdb_tail = BD - i + bd_block; + if (brg->brgattr.use_uker) bdb++; + } else + bdb++; } } - brg->bdb = brg->bcast_dim / brg->bd_block; - brg->bdb_tail = brg->bcast_dim % brg->bd_block; - - const int rd_unroll = 4; - const data_type_t rd_block_dt = get_mac_emu_data_type( - brg->dt_a, brg->isa_impl, brg->isa_impl != avx2_vnni_2); - if (rd_block_dt == dnnl_data_type_undef) return status::unimplemented; - const int vnni_granularity = data_type_vnni_granularity(rd_block_dt); - brg->rd_block = rd_unroll * vnni_granularity; - brg->rdb = brg->reduce_dim / brg->rd_block; - brg->rdb_tail = brg->reduce_dim % brg->rd_block; + }; - brg->is_M_tail = false; - } else { - // Blocking configuration for AMX - const int max_width = 16, min_width = 1; - brg->ld_block = 16; - brg->ldb = brg->load_dim / brg->ld_block; - brg->ldb_tail = brg->load_dim % brg->ld_block; - - auto find_bdb_bd_mask = [&](int bd_block, int &bdb, int &bdb_tail) { - if (brg->brgattr.bd_mask_level != 2 || brg->bcast_dim == 0) { - bdb = div_up(brg->bcast_dim, bd_block); - bdb_tail = brg->bcast_dim % bd_block; - return; + auto find_bd_block_for_bd_mask = [&]() { + if (brg->brgattr.bd_mask_level != 2 || BD == 0) return false; + + auto min_bdb = INT_MAX; + const auto start_bd_block = nstl::min(max_width, BD); + auto best_bd_block = start_bd_block; + for (auto bd_block = start_bd_block; bd_block > 0; bd_block--) { + int bdb = 0; + int bdb_tail = 0; + 
find_bdb_bd_mask(bd_block, bdb, bdb_tail); + // bcast_dim should be divided by bd_block + if (bdb < min_bdb && bdb_tail == 0) { + min_bdb = bdb; + best_bd_block = bd_block; } + } + brg->bd_block = best_bd_block; + brg->bdb_tail = 0; + brg->bdb = min_bdb; + return true; + }; - bdb = 0; - bdb_tail = 0; - for (int i = 0; i < brg->bcast_dim;) { - if (brg->brgattr.bd_mask_level == 2 - && brg->brgattr.bd_mask[i] == 0) { - i++; - } else { - i += bd_block; - if (i > brg->bcast_dim) { - bdb_tail = brg->bcast_dim - i + bd_block; - if (brg->brgattr.use_uker) bdb++; - } else - bdb++; - } - } - }; + auto set_decomposition_by_ld = [&]() { + if (brg->bd_block2 == 1 && brg->ldb > 0 && brg->ldb_tail == 0) { + if (brg->ldb % 3 == 0) + brg->ld_block2 = 3; + else if (brg->ldb % 2 == 0) + brg->ld_block2 = 2; + else + brg->ld_block2 = 1; + } else { + brg->ld_block2 + = (brg->ldb > 0 && brg->ldb % 2 == 0 && brg->ldb_tail == 0 + && brg->bd_block2 < 3) + ? 2 + : 1; + } + brg->ldb2 = brg->ldb / brg->ld_block2; + brg->ldb2_tail = brg->ldb % brg->ld_block2; - auto find_bd_block_for_bd_mask = [&]() { - if (brg->brgattr.bd_mask_level != 2 || brg->bcast_dim == 0) - return false; - - auto min_bdb = INT_MAX; - const auto start_bd_block = nstl::min(max_width, brg->bcast_dim); - auto best_bd_block = start_bd_block; - for (auto bd_block = start_bd_block; bd_block > 0; bd_block--) { - int bdb = 0; - int bdb_tail = 0; - find_bdb_bd_mask(bd_block, bdb, bdb_tail); - // bcast_dim should be divided by bd_block - if (bdb < min_bdb && bdb_tail == 0) { - min_bdb = bdb; - best_bd_block = bd_block; - } - } - brg->bd_block = best_bd_block; - brg->bdb_tail = 0; - brg->bdb = min_bdb; - return true; - }; + // Re-adjust the bd_block2 if possible + if (brg->ld_block2 == 1 && !brg->is_M_tail && brg->ldb_tail == 0) { + brg->bd_block2 = (brg->bdb >= 3) ? 3 : (brg->bdb >= 2) ? 2 : 1; + brg->bdb2 = brg->bdb / brg->bd_block2; + brg->bdb2_tail = (brg->bd_block2 == 1) ? brg->bdb + : brg->bdb % brg->bd_block2; + } + }; - auto set_decomposition_by_ld = [&]() { - if (brg->bd_block2 == 1 && brg->ldb > 0 && brg->ldb_tail == 0) { - if (brg->ldb % 3 == 0) - brg->ld_block2 = 3; - else if (brg->ldb % 2 == 0) - brg->ld_block2 = 2; - else - brg->ld_block2 = 1; - } else { - brg->ld_block2 - = (brg->ldb > 0 && brg->ldb % 2 == 0 - && brg->ldb_tail == 0 && brg->bd_block2 < 3) - ? 2 - : 1; - } - brg->ldb2 = brg->ldb / brg->ld_block2; - brg->ldb2_tail = brg->ldb % brg->ld_block2; - - // Re-adjust the bd_block2 if possible - if (brg->ld_block2 == 1 && !brg->is_M_tail && brg->ldb_tail == 0) { - brg->bd_block2 = (brg->bdb >= 3) ? 3 : (brg->bdb >= 2) ? 2 : 1; - brg->bdb2 = brg->bdb / brg->bd_block2; - brg->bdb2_tail = (brg->bd_block2 == 1) - ? 
brg->bdb - : brg->bdb % brg->bd_block2; + auto try_3x1_decomposition = [&](int width_step) { + brg->is_M_tail = false; + if (BD > (width_step - 1) * max_width && BD < width_step * max_width + && brg->ldb_tail == 0) { + if (!find_bd_block_for_bd_mask()) { + brg->bd_block = max_width; + brg->bdb = div_up(BD, brg->bd_block); + brg->bdb_tail = BD % brg->bd_block; + brg->is_M_tail = true; } - }; + brg->bd_block2 = width_step; + brg->bdb2 = brg->bdb / brg->bd_block2; + brg->bdb2_tail = brg->bdb % brg->bd_block2; + set_decomposition_by_ld(); + return true; + } + return false; + }; - auto try_3x1_decomposition = [&](int width_step) { - brg->is_M_tail = false; - if (brg->bcast_dim > (width_step - 1) * max_width - && brg->bcast_dim < width_step * max_width - && brg->ldb_tail == 0) { - if (!find_bd_block_for_bd_mask()) { - brg->bd_block = max_width; - brg->bdb = div_up(brg->bcast_dim, brg->bd_block); - brg->bdb_tail = brg->bcast_dim % brg->bd_block; - brg->is_M_tail = true; + auto try_2x2_decomposition = [&]() { + if (!find_bd_block_for_bd_mask()) { + for (int m_block = max_width; m_block >= min_width; m_block--) { + if (BD % m_block == 0) { + brg->bd_block = m_block; + break; } - brg->bd_block2 = width_step; - brg->bdb2 = brg->bdb / brg->bd_block2; - brg->bdb2_tail = brg->bdb % brg->bd_block2; - set_decomposition_by_ld(); - return true; } - return false; - }; - - auto try_2x2_decomposition = [&]() { - if (!find_bd_block_for_bd_mask()) { - for (int m_block = max_width; m_block >= min_width; m_block--) { - if (brg->bcast_dim % m_block == 0) { - brg->bd_block = m_block; - break; + if (brg->bd_block == 1) { + brg->bd_block = nstl::min(max_width, BD); + brg->bdb_tail = BD % max_width; + for (int i = max_width; i >= min_width; i--) { + const auto i_tail = BD % i; + if (i_tail > brg->bdb_tail || i_tail == 0) { + brg->bd_block = i; + brg->bdb_tail = i_tail; + if (i_tail == 0) break; } } - if (brg->bd_block == 1) { - brg->bd_block = nstl::min(max_width, brg->bcast_dim); - brg->bdb_tail = brg->bcast_dim % max_width; - for (int i = max_width; i >= min_width; i--) { - const auto i_tail = brg->bcast_dim % i; - if (i_tail > brg->bdb_tail || i_tail == 0) { - brg->bd_block = i; - brg->bdb_tail = i_tail; - if (i_tail == 0) break; - } - } - } - brg->bdb = brg->bcast_dim / brg->bd_block; - brg->bdb_tail = brg->bcast_dim % brg->bd_block; } + brg->bdb = BD / brg->bd_block; + brg->bdb_tail = BD % brg->bd_block; + } - brg->bd_block2 = (brg->bdb >= 2) ? 2 : 1; - brg->bdb2 = brg->bdb / brg->bd_block2; - brg->bdb2_tail = (brg->bd_block2 == 1) ? brg->bdb - : brg->bdb % brg->bd_block2; - - brg->is_M_tail = false; + brg->bd_block2 = (brg->bdb >= 2) ? 2 : 1; + brg->bdb2 = brg->bdb / brg->bd_block2; + brg->bdb2_tail + = (brg->bd_block2 == 1) ? 
brg->bdb : brg->bdb % brg->bd_block2; - set_decomposition_by_ld(); + brg->is_M_tail = false; - return !(brg->ld_block2 == 1 || brg->bd_block2 == 1 - || brg->bd_block < 8); - }; + set_decomposition_by_ld(); - bool is_decomposition_defined = false; - for (int i = decomposition_2x2; i != undefined; i++) { - switch (i) { - case decomposition_2x2: - is_decomposition_defined = try_2x2_decomposition(); - break; - case decomposition_3x1_3: - is_decomposition_defined = try_3x1_decomposition(3); - break; - case decomposition_3x1_2: - is_decomposition_defined = try_3x1_decomposition(2); - break; - default: assert(!"invalid value"); break; - }; - if (is_decomposition_defined) break; - } - if (!is_decomposition_defined) try_2x2_decomposition(); + return !(brg->ld_block2 == 1 || brg->bd_block2 == 1 + || brg->bd_block < 8); + }; - auto recalc_bd_block = [&](int new_bd_block) { - if (new_bd_block == 0) return; + auto recalc_blocking = [&](int new_bd_block, int new_ld_block, + int new_bd_block2, int new_ld_block2) { + if (new_bd_block != 0) { brg->bd_block = new_bd_block; find_bdb_bd_mask(brg->bd_block, brg->bdb, brg->bdb_tail); brg->is_M_tail = (brg->bdb_tail != 0); - }; + } - auto recalc_bd_block2 = [&](int new_bd_block2) { - if (new_bd_block2 == 0) return; + if (new_ld_block != 0) { + brg->ld_block = new_ld_block; + brg->ldb = div_up(LD, brg->ld_block); + brg->ldb_tail = LD % brg->ld_block; + } + + if (new_bd_block2 != 0) { brg->bd_block2 = new_bd_block2; if (can_dispatch_uker(brg)) { brg->bdb2 = div_up(brg->bdb, brg->bd_block2); @@ -460,17 +424,9 @@ status_t brgemm_blocking(brgemm_desc_t *brg) { brg->bdb2 = full_bd_blocks / brg->bd_block2; brg->bdb2_tail = full_bd_blocks % brg->bd_block2; } - }; - - auto recalc_ld_block = [&](int new_ld_block) { - if (new_ld_block == 0) return; - brg->ld_block = new_ld_block; - brg->ldb = div_up(brg->load_dim, brg->ld_block); - brg->ldb_tail = brg->load_dim % brg->ld_block; - }; + } - auto recalc_ld_block2 = [&](int new_ld_block2) { - if (new_ld_block2 == 0) return; + if (new_ld_block2 != 0) { brg->ld_block2 = new_ld_block2; if (can_dispatch_uker(brg)) { brg->ldb2 = div_up(brg->ldb, brg->ld_block2); @@ -481,217 +437,184 @@ status_t brgemm_blocking(brgemm_desc_t *brg) { brg->ldb2 = full_ld_blocks / brg->ld_block2; brg->ldb2_tail = full_ld_blocks % brg->ld_block2; } - }; + } + }; - const bool try_load_nt_A - = (brg->innermost_loop == brgemm_bd_loop_innermost); - const bool try_load_nt_B - = (brg->innermost_loop == brgemm_ld_loop_innermost); - const bool try_load_nt - = (static_cast(brg->typesize_A) - * brg->brgattr.hint_expected_A_size - + static_cast(brg->typesize_B) - * brg->brgattr.hint_expected_B_size - + static_cast(brg->typesize_C) - * brg->brgattr.hint_expected_C_size) - >= platform::get_per_core_cache_size(1); - brg->load_nt_A = try_load_nt_A && try_load_nt; - brg->load_nt_B = try_load_nt_B && try_load_nt; - - recalc_bd_block(brg->bd_block); - recalc_bd_block2(brg->bd_block2); - recalc_ld_block(brg->ld_block); - recalc_ld_block2(brg->ld_block2); - - if (can_dispatch_uker(brg)) { - // Blocking heuristics for some shapes - // TODO: Review these criterias - size_t eff_K - = brg->reduce_dim * brg->typesize_A * brg->brgattr.K_koef; - auto L1 = platform::get_per_core_cache_size(1); - auto low_K = (L1 - 4 * 1024) / (6 * 16); - - // TODO: if rdb_tail != 0 then we should limit - // blocking because we need extra tiles for A and B to load rdb_tail - // if bd_mask_level != 0 it means it aligned to 16 - - bool bdb_block_tail = !(brg->bd_block > 12 - && (brg->bcast_dim % 
brg->bd_block == 0 - && brg->brgattr.bd_mask_level == 0)); - bool ldb_tail_16 = (brg->load_dim % 16 != 0); - if (everyone_is(false, bdb_block_tail, ldb_tail_16)) { - // try to use 1x(4|5) or (4|5)x1 decomposition for specific - // range of K - auto upper_K5 = (L1 - 5 * 1024) / (5 * 16); - auto upper_K4 = (L1 - 4 * 1024) / (4 * 16); - bool K5_fit_L1 = (low_K <= eff_K && eff_K < upper_K5); - bool K4_fit_L1 = (low_K <= eff_K && eff_K < upper_K4); - bool bd_big = (brg->bcast_dim > 32); - bool ld_big = (brg->load_dim > 32); - if (brg->load_dim % 80 == 0 && K5_fit_L1 && bd_big) { - - recalc_ld_block(16); - recalc_bd_block2(1); - recalc_ld_block2(5); - brg->load_nt_A = true; - brg->load_nt_B = false; - brg->innermost_loop = brgemm_bd_loop_innermost; - } else if (brg->load_dim % 64 == 0 && K4_fit_L1 && bd_big) { - - recalc_ld_block(16); - recalc_bd_block2(1); - recalc_ld_block2(4); - brg->load_nt_A = true; - brg->load_nt_B = false; - brg->innermost_loop = brgemm_bd_loop_innermost; - } else if ((brg->bcast_dim % 80 == 0 - || (brg->brgattr.bd_mask_level != 0 - && brg->bdb % 4 == 0)) - && K5_fit_L1 && ld_big) { - - recalc_ld_block(16); - recalc_bd_block2(5); - recalc_ld_block2(1); - brg->load_nt_A = false; - brg->load_nt_B = true; - brg->innermost_loop = brgemm_ld_loop_innermost; - } else if ((brg->bcast_dim % 64 == 0 - || (brg->brgattr.bd_mask_level != 0 - && brg->bdb % 4 == 0)) - && K4_fit_L1 && ld_big) { - - recalc_bd_block(16); - recalc_ld_block(16); - recalc_bd_block2(4); - recalc_ld_block2(1); - brg->load_nt_A = false; - brg->load_nt_B = true; - brg->innermost_loop = brgemm_ld_loop_innermost; - } - } - // Tile decomposition for shapes with small dimensions - // or dimensions with tails - if (ldb_tail_16 && !bdb_block_tail && brg->load_dim > 64 - && brg->ld_block < 8) { - recalc_ld_block(16); - recalc_bd_block2(2); - recalc_ld_block2(1); - } else if (ldb_tail_16 && !bdb_block_tail - && rnd_up(brg->load_dim, 16) == 64 - && (brg->ld_block < 8 || brg->ldb_tail > 0)) { - recalc_ld_block(16); - recalc_bd_block2(1); - recalc_ld_block2(4); - } else if (ldb_tail_16 && !bdb_block_tail - && rnd_up(brg->load_dim, 16) == 48 - && (brg->ld_block < 8 || brg->ldb_tail > 0)) { - recalc_ld_block(16); - recalc_bd_block2(1); - recalc_ld_block2(3); - } else if (ldb_tail_16 && !bdb_block_tail - && rnd_up(brg->load_dim, 16) == 32 - && (brg->ld_block < 8 || brg->ldb_tail > 0)) { - recalc_ld_block(16); - recalc_bd_block2(2); - recalc_ld_block2(2); - } else if (brg->bcast_dim <= 16) { - recalc_bd_block(brg->bcast_dim); - recalc_ld_block(16); - recalc_bd_block2(1); - recalc_ld_block2( - nstl::min(ldb_tail_16 ? ((brg->ldb > 4) ? 
3 : 4) : 5, - div_up(brg->load_dim, 16))); - } else if (bdb_block_tail && !ldb_tail_16 && brg->bcast_dim > 64 - && (brg->bd_block < 8 || brg->bdb_tail > 0)) { - - recalc_bd_block(16); - recalc_ld_block(16); - recalc_bd_block2(1); - recalc_ld_block2(2); - } else if (bdb_block_tail && !ldb_tail_16 - && rnd_up(brg->bcast_dim, 16) == 64 - && (brg->bd_block < 8 || brg->bdb_tail > 0)) { - recalc_bd_block(16); - recalc_ld_block(16); - recalc_bd_block2(4); - recalc_ld_block2(1); - } else if (bdb_block_tail && !ldb_tail_16 - && rnd_up(brg->bcast_dim, 16) == 48 - && (brg->bd_block < 8 || brg->bdb_tail > 0)) { - recalc_bd_block(16); - recalc_ld_block(16); - recalc_bd_block2(3); - recalc_ld_block2(1); - } else if (bdb_block_tail && !ldb_tail_16 - && rnd_up(brg->bcast_dim, 16) == 32 - && (brg->bd_block < 8 || brg->bdb_tail > 0) - && (brg->load_dim % 32 == 0)) { - - recalc_bd_block(16); - recalc_ld_block(16); - recalc_bd_block2(2); - recalc_ld_block2(2); - } else if (brg->load_dim <= 16) { - recalc_bd_block(16); - recalc_ld_block(16); // we can't use ld_block other than 16 - recalc_bd_block2( - nstl::min(brg->bdb_tail ? (brg->bdb > 4 ? 3 : 4) : 5, - div_up(brg->bcast_dim, 16))); - recalc_ld_block2(1); - } else if (bdb_block_tail && ldb_tail_16 - && rnd_up(brg->bcast_dim, 16) == 32 - && rnd_up(brg->load_dim, 16) == 32 - && (brg->ld_block < 8 || brg->ldb_tail > 0 - || brg->bd_block < 8 || brg->bdb_tail > 0)) { - recalc_bd_block(16); - recalc_ld_block(16); - recalc_bd_block2(2); - recalc_ld_block2(2); - } - // if interleave stores and small number of iterations then - // try to increase them - auto n_iterations = brg->bdb2 * brg->bdb2; - if (false && brg->brgattr.use_interleave_stores - && n_iterations < 4) { - int k_it = div_up(4, n_iterations); - if (brg->bdb2 > brg->ldb2) - recalc_bd_block2(div_up(brg->bdb2, k_it)); - else - recalc_ld_block2(div_up(brg->ldb2, k_it)); + auto recalc_blocking_ext + = [&](int new_bd_block, int new_ld_block, int new_bd_block2, + int new_ld_block2, bool load_nt_A, bool load_nt_B, + brgemm_kernel_innermost_loop_t innermost_loop) { + recalc_blocking(new_bd_block, new_ld_block, new_bd_block2, + new_ld_block2); + brg->load_nt_A = load_nt_A; + brg->load_nt_B = load_nt_B; + brg->innermost_loop = innermost_loop; + }; + + bool is_decomposition_defined = false; + for (int i = decomposition_2x2; i != undefined; i++) { + switch (i) { + case decomposition_2x2: + is_decomposition_defined = try_2x2_decomposition(); + break; + case decomposition_3x1_3: + is_decomposition_defined = try_3x1_decomposition(3); + break; + case decomposition_3x1_2: + is_decomposition_defined = try_3x1_decomposition(2); + break; + default: assert(!"invalid value"); break; + }; + if (is_decomposition_defined) break; + } + if (!is_decomposition_defined) try_2x2_decomposition(); + + const bool try_load_nt_A + = (brg->innermost_loop == brgemm_bd_loop_innermost); + const bool try_load_nt_B + = (brg->innermost_loop == brgemm_ld_loop_innermost); + const bool try_load_nt + = (static_cast(brg->typesize_A) + * brg->brgattr.hint_expected_A_size + + static_cast(brg->typesize_B) + * brg->brgattr.hint_expected_B_size + + static_cast(brg->typesize_C) + * brg->brgattr.hint_expected_C_size) + >= L1; + brg->load_nt_A = try_load_nt_A && try_load_nt; + brg->load_nt_B = try_load_nt_B && try_load_nt; + + recalc_blocking( + brg->bd_block, brg->ld_block, brg->bd_block2, brg->ld_block2); + + if (can_dispatch_uker(brg)) { + // Blocking heuristics for some shapes + // TODO: Review these criteria + const size_t eff_K + = brg->reduce_dim * 
brg->typesize_A * brg->brgattr.K_koef; + const auto low_K = (L1 - 4 * 1024) / (6 * 16); + + // TODO: if rdb_tail != 0 then we should limit + // blocking because we need extra tiles for A and B to load rdb_tail + // if bd_mask_level != 0 it means it aligned to 16 + + const bool bdb_block_tail = !(brg->bd_block > 12 + && (BD % brg->bd_block == 0 + && brg->brgattr.bd_mask_level == 0)); + const bool ldb_tail_16 = (LD % 16 != 0); + if (everyone_is(false, bdb_block_tail, ldb_tail_16)) { + // try to use 1x(4|5) or (4|5)x1 decomposition for specific + // range of K + const auto upper_K5 = (L1 - 5 * 1024) / (5 * 16); + const auto upper_K4 = (L1 - 4 * 1024) / (4 * 16); + const bool K5_fit_L1 = (low_K <= eff_K && eff_K < upper_K5); + const bool K4_fit_L1 = (low_K <= eff_K && eff_K < upper_K4); + const bool bd_big = (BD > 32); + const bool ld_big = (LD > 32); + const bool aligned_bd_mask + = brg->brgattr.bd_mask_level != 0 && brg->bdb % 4 == 0; + if (LD % 80 == 0 && K5_fit_L1 && bd_big) { + recalc_blocking_ext( + 0, 16, 1, 5, true, false, brgemm_bd_loop_innermost); + } else if (LD % 64 == 0 && K4_fit_L1 && bd_big) { + recalc_blocking_ext( + 0, 16, 1, 4, true, false, brgemm_bd_loop_innermost); + } else if ((BD % 80 == 0 || aligned_bd_mask) && K5_fit_L1 + && ld_big) { + + recalc_blocking_ext( + 0, 16, 5, 1, false, true, brgemm_ld_loop_innermost); + } else if ((BD % 64 == 0 || aligned_bd_mask) && K4_fit_L1 + && ld_big) { + recalc_blocking_ext( + 16, 16, 4, 1, false, true, brgemm_ld_loop_innermost); } } + // Tile decomposition for shapes with small dimensions + // or dimensions with tails + const bool weak_ldb = brg->ld_block < 8 || brg->ldb_tail > 0; + const bool weak_bdb = brg->bd_block < 8 || brg->bdb_tail > 0; + const bool ldb_tail_only = ldb_tail_16 && !bdb_block_tail; + const bool bdb_tail_only = bdb_block_tail && !ldb_tail_16; + if (ldb_tail_only && LD > 64 && brg->ld_block < 8) { + recalc_blocking(0, 16, 2, 1); + } else if (ldb_tail_only && weak_ldb && LD_R16 == 64) { + recalc_blocking(0, 16, 1, 4); + } else if (ldb_tail_only && weak_ldb && LD_R16 == 48) { + recalc_blocking(0, 16, 1, 3); + } else if (ldb_tail_only && weak_ldb && LD_R16 == 32) { + recalc_blocking(0, 16, 2, 2); + } else if (BD <= 16) { + // Have to call recalc_blocking twice to calculate ldb + recalc_blocking(BD, 16, 0, 0); + const auto ld_block2 = nstl::min( + ldb_tail_16 ? ((brg->ldb > 4) ? 3 : 4) : 5, div_up(LD, 16)); + recalc_blocking(0, 0, 1, ld_block2); + } else if (bdb_tail_only && weak_bdb && BD > 64) { + recalc_blocking(16, 16, 1, 2); + } else if (bdb_tail_only && weak_bdb && BD_R16 == 64) { + recalc_blocking(16, 16, 4, 1); + } else if (bdb_tail_only && weak_bdb && BD_R16 == 48) { + recalc_blocking(16, 16, 3, 1); + } else if (bdb_tail_only && weak_bdb && BD_R16 == 32 + && (LD % 32 == 0)) { + recalc_blocking(16, 16, 2, 2); + } else if (LD <= 16) { + // Have to call recalc_blocking twice to calculate bdb + // we can't use ld_block other than 16 + recalc_blocking(16, 16, 0, 0); + const auto bd_block2 = nstl::min( + brg->bdb_tail ? (brg->bdb > 4 ? 
3 : 4) : 5, div_up(BD, 16)); + recalc_blocking(0, 0, bd_block2, 1); + } else if (bdb_block_tail && ldb_tail_16 && BD_R16 == 32 && LD_R16 == 32 + && (weak_ldb || weak_bdb)) { + recalc_blocking(16, 16, 2, 2); + } - if (brg->get_num_A_tiles() + brg->get_num_B_tiles() - + brg->get_num_C_tiles() - > brgemm_desc_t::AMX_TILES_NUM) { - assert(!"brgemm internal error: invalid blocking"); - return status::runtime_error; + // The code below is a draft for the future optimization of interleave + // stores and small number of iterations. + // TODO: review and enable if needed +#if 0 + // if interleave stores and small number of iterations then + // try to increase them + const auto n_iterations = brg->bdb2 * brg->bdb2; + if (brg->brgattr.use_interleave_stores && n_iterations < 4) { + int k_it = div_up(4, n_iterations); + if (brg->bdb2 > brg->ldb2) + recalc_blocking(0, 0, div_up(brg->bdb2, k_it), 0); + else + recalc_blocking(0, 0, 0, div_up(brg->ldb2, k_it)); } +#endif + } - // check hints for blocking parameters - recalc_bd_block(brg->brgattr.hint_bd_block); - recalc_bd_block2(brg->brgattr.hint_bd_block2 - ? brg->brgattr.hint_bd_block2 - : brg->bd_block2); - recalc_ld_block(brg->brgattr.hint_ld_block); - recalc_ld_block2(brg->brgattr.hint_ld_block2 - ? brg->brgattr.hint_ld_block2 - : brg->ld_block2); - - if (brg->brgattr.hint_load_nt_A != brgemm_hint_nt_undef) - brg->load_nt_A - = (brg->brgattr.hint_load_nt_A == brgemm_hint_nt_true); - if (brg->brgattr.hint_load_nt_B != brgemm_hint_nt_undef) - brg->load_nt_B - = (brg->brgattr.hint_load_nt_B == brgemm_hint_nt_true); - - const bool reduce_by_words = brg->is_bf16_tmm || brg->is_f16_tmm - || brg->is_input_convert(); - const auto max_rd_block = reduce_by_words ? 32 : 64; - const auto rd_block_step = (reduce_by_words && !brg->is_fp8) ? 2 : 4; - // TODO: if rd_block calculated is very small then maybe it makes - // sense to use 1x2 or 2x1 blocking with supporting rd_block - // and rdb_tail + if (brg->get_num_A_tiles() + brg->get_num_B_tiles() + brg->get_num_C_tiles() + > brgemm_desc_t::AMX_TILES_NUM) { + assert(!"brgemm internal error: invalid blocking"); + return status::runtime_error; + } + + // check hints for blocking parameters + recalc_blocking(brg->brgattr.hint_bd_block, brg->brgattr.hint_ld_block, + brg->brgattr.hint_bd_block2 ? brg->brgattr.hint_bd_block2 + : brg->bd_block2, + brg->brgattr.hint_ld_block2 ? 
+
+ // check hints for blocking parameters
+ recalc_blocking(brg->brgattr.hint_bd_block, brg->brgattr.hint_ld_block,
+ brg->brgattr.hint_bd_block2 ? brg->brgattr.hint_bd_block2
+ : brg->bd_block2,
+ brg->brgattr.hint_ld_block2 ? brg->brgattr.hint_ld_block2
+ : brg->ld_block2);
+
+ if (brg->brgattr.hint_load_nt_A != brgemm_hint_nt_undef)
+ brg->load_nt_A = (brg->brgattr.hint_load_nt_A == brgemm_hint_nt_true);
+ if (brg->brgattr.hint_load_nt_B != brgemm_hint_nt_undef)
+ brg->load_nt_B = (brg->brgattr.hint_load_nt_B == brgemm_hint_nt_true);
+
+ // TODO: if the calculated rd_block is very small, it may make sense
+ // to use 1x2 or 2x1 blocking with support for rd_block and rdb_tail
+ const auto rd_block_step = brg->rd_block_step();
+ const auto max_rd_block = brg->max_rd_block();
+ if (brg->amx_may_extend_k()) {
+ brg->rd_block = nstl::min(
+ rnd_up(brg->reduce_dim, brg->rd_step), max_rd_block);
+ } else {
 brg->rd_block = rd_block_step;
 for (int i = max_rd_block; i > 0; i -= rd_block_step) {
 if (brg->reduce_dim % i == 0) {
@@ -699,33 +622,186 @@ status_t brgemm_blocking(brgemm_desc_t *brg) {
 break;
 }
 }
- brg->rdb = brg->reduce_dim / brg->rd_block;
- brg->rdb_tail = brg->reduce_dim % brg->rd_block;
-
- // Remove these guards in the future (add tail processing by reduction
- // dimension)
- // TODO: these checks do not work for fp8-f16 and f16-fp8 cfgs
- if (!IMPLICATION(
- brg->rdb > 0 && brg->rdb_tail, brg->is_input_convert())) {
- return status::unimplemented;
+ }
+
+ brg->rdb = brg->reduce_dim / brg->rd_block;
+ brg->rdb_tail = brg->reduce_dim % brg->rd_block;
+
+ // Remove these guards in the future (add tail processing by reduction
+ // dimension)
+ // TODO: these checks do not work for fp8-f16 and f16-fp8 cfgs
+ if (!IMPLICATION(brg->rdb > 0 && brg->rdb_tail,
+ brg->is_input_convert() || brg->amx_wary_k_tail())) {
+ return status::unimplemented;
+ }
+
+ if (!IMPLICATION((brg->rdb_tail
+ % ((brg->is_bf16_tmm || brg->is_f16_tmm) ? 2 : 4))
+ != 0,
+ brg->is_input_convert() || brg->amx_wary_k_tail())) {
+ return status::unimplemented;
+ }
+
+ // TODO: check this condition
+ brg->interleave_tilestores_ = brg->beta == 0
+ && brg->brgattr.use_interleave_stores
+ && (brg->bd_block2 * brg->ld_block2 == 4)
+ && !brg->brgattr.var_bs;
+ return status::success;
+}
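// [Editor's note] Not part of the patch: the rd_block search a few lines
// up (split across the two hunks), restated standalone. It picks the
// largest rd_block <= max_rd_block that is a multiple of rd_block_step
// and divides reduce_dim exactly, so no reduction tail is left.
inline int pick_rd_block(int reduce_dim, int max_rd_block, int step) {
    for (int i = max_rd_block; i > 0; i -= step)
        if (reduce_dim % i == 0) return i;
    return step; // fallback: a tail will remain
}
// E.g. pick_rd_block(96, 64, 4) == 48, and pick_rd_block(62, 64, 2) == 62.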
+
+status_t brgemm_blocking_vmm(brgemm_desc_t *brg) {
+ const auto L1 = platform::get_per_core_cache_size(1);
+
+ const int simd_w = is_superset(brg->isa_impl, avx512_core) ? 16 : 8;
+ brg->ld_block = simd_w;
+ brg->ldb = brg->load_dim / brg->ld_block;
+ brg->ldb_tail = brg->load_dim % brg->ld_block;
+
+ const int max_vpad = nstl::max(
+ brg->brgattr.max_top_vpad, brg->brgattr.max_bottom_vpad);
+
+ int max_bcast_block {0}, min_bcast_block {0}, adj_ld_block2 {0};
+ if (brg->with_src_dyn_quant) {
+ adj_ld_block2 = calculate_ldb_params(brg, 4);
+ max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2);
+ // reduce 'ld_block2' to allow a larger 'bd_block'
+ if (is_superset(brg->isa_impl, avx2) && max_bcast_block < max_vpad) {
+ for (int try_ld_block2 = 2; try_ld_block2 > 0; --try_ld_block2) {
+ adj_ld_block2 = calculate_ldb_params(brg, try_ld_block2);
+ max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2);
+ if (max_bcast_block >= max_vpad) break;
+ }
+ // the bcast block in the brgemm kernel should be at least as
+ // large as the virtual padding to avoid functional issues
+ if (max_bcast_block < max_vpad) return status::unimplemented;
+ }
+ } else {
+ // iterate ld_block2 starting from 4 to allow bd_block larger than
+ // virtual padding
+ bool few_regs = utils::one_of(brg->isa_impl, avx2, avx2_vnni, avx2_vnni_2);
+ bool hint_n_bcast_1_load
+ = brg->brgattr.hint_loop_order == brgemm_lo_bl_1load;
+ for (int try_ld_block2 = 4; try_ld_block2 > 0; --try_ld_block2) {
+ adj_ld_block2 = calculate_ldb_params(brg, try_ld_block2);
+ brg->n_bcast_1_load
+ = (few_regs && adj_ld_block2 == 4) || hint_n_bcast_1_load;
+ max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2);
+ const auto bdb_tail = brg->bcast_dim % max_bcast_block;
+ min_bcast_block = bdb_tail > 0 ? bdb_tail : max_bcast_block;
+ if (min_bcast_block >= max_vpad) break;
 }
- if (!IMPLICATION(
- (brg->rdb_tail
- % ((brg->is_bf16_tmm || brg->is_f16_tmm) ? 2 : 4))
- != 0,
- brg->is_input_convert())) {
- return status::unimplemented;
+ // the bcast block in the brgemm kernel should be at least as
+ // large as the virtual padding to avoid functional issues
+ if (min_bcast_block < max_vpad) return status::unimplemented;
+ }
+
+ const int min_block = nstl::max(1, max_vpad);
+
+ float best_bd_block_eff = 0.f;
+ if (max_bcast_block == 0) max_bcast_block = 1;
+ brg->bd_block = max_bcast_block;
+ for (int bd_block = max_bcast_block; bd_block >= min_block; bd_block--) {
+ const auto bd_block_disb = static_cast<float>(brg->bcast_dim)
+ / rnd_up(brg->bcast_dim, bd_block);
+ const auto brgemm_microkernel_eff
+ = (static_cast<float>(adj_ld_block2) * bd_block)
+ / (((adj_ld_block2) + bd_block) * max_bcast_block);
+ const auto bd_block_eff = bd_block_disb * brgemm_microkernel_eff;
+
+ float block_foot_print = static_cast<float>(brg->typesize_A)
+ * (bd_block * brg->reduce_dim);
+ if (block_foot_print <= static_cast<float>(L1)
+ && (bd_block_eff > best_bd_block_eff)) {
+ brg->bd_block = bd_block;
+ best_bd_block_eff = bd_block_eff;
 }
+ }
+ brg->bdb = brg->bcast_dim / brg->bd_block;
+ brg->bdb_tail = brg->bcast_dim % brg->bd_block;
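// [Editor's note] Not part of the patch: the score maximized by the
// bd_block loop above, written out. The first factor penalizes the
// padding that rnd_up() introduces; the second is the FMA-to-load ratio
// of an (ld_block2 x bd_block) microkernel, normalized by the largest
// candidate block.
inline float bd_block_score(
        int bcast_dim, int bd_block, int ld_block2, int max_bcast_block) {
    const int padded = (bcast_dim + bd_block - 1) / bd_block * bd_block;
    const float disb = static_cast<float>(bcast_dim) / padded;
    const float mk_eff = static_cast<float>(ld_block2) * bd_block
            / ((ld_block2 + bd_block) * max_bcast_block);
    return disb * mk_eff;
}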
+
+ const data_type_t rd_block_dt = get_mac_emu_data_type(
+ brg->dt_a, brg->isa_impl, brg->isa_impl != avx2_vnni_2);
+ if (rd_block_dt == dnnl_data_type_undef) return status::unimplemented;
+ const int vnni_granularity
+ = (brg->is_f16 && brg->isa_impl == avx512_core_fp16)
+ ? 1
+ : data_type_vnni_granularity(brg->dt_a);
+ int rd_unroll = one_of(brg->dt_b, data_type::nf4, data_type::u4,
+ data_type::s4, data_type::f4_e2m1) ? 32 : 4;
+ if (brg->with_grouped_wei_decomp && !brg->with_src_dyn_quant) {
+ auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size,
+ brg->wei_decomp_zero_points_group_size);
+ min_group_size = nstl::min(min_group_size, brg->src_scales_group_size);
+ rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity);
+ brg->rd_block = rd_unroll * vnni_granularity;
+ } else if (brg->with_src_dyn_quant) {
+ brg->rd_block = brg->src_scales_group_size;
+ auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size,
+ brg->wei_decomp_zero_points_group_size);
+ brg->rd_block = nstl::min(brg->rd_block, min_group_size);
+ } else {
+ brg->rd_block = rd_unroll * vnni_granularity;
+ }
+
+ brg->rdb = brg->reduce_dim / brg->rd_block;
+ brg->rdb_tail = brg->reduce_dim % brg->rd_block;
+
+ brg->is_M_tail = false;
+ // avx2_vnni_2 kernel with xf16 data type requires blocked weights.
+ if (brg->isa_impl == avx2_vnni_2 && brg->is_xf16()
+ && brg->LDB % brg->ld_block > 0)
+ return status::unimplemented;
+
+ return status::success;
+}
-
- //TODO: check this condition
- brg->interleave_tilestores_ = brg->beta == 0
- && (brg->brgattr.use_interleave_stores
- && (brg->bd_block2 * brg->ld_block2 == 4)
- && !brg->brgattr.var_bs)
- ? true
- : false;
+status_t brgemm_blocking(brgemm_desc_t *brg) {
+ const bool is_b_in_vnni_format
+ = !(brg->dt_b == data_type::f16 && brg->isa_impl == avx512_core_fp16)
+ && !(one_of(brg->dt_a, data_type::f32, data_type::bf16)
+ && one_of(brg->dt_b, data_type::u8, data_type::s8))
+ && !(one_of(brg->dt_a, data_type::f32)
+ && one_of(brg->dt_b, data_type::bf16, data_type::f16));
+ brg->ld_step
+ = is_b_in_vnni_format ? data_type_vnni_granularity(brg->dt_b) : 1;
+ const bool has_no_vnni_compute_instruction
+ = (brg->is_f16 && one_of(brg->isa_impl, avx2_vnni_2, avx512_core_fp16))
+ || (brg->is_bf16 && brg->isa_impl == avx2_vnni_2)
+ || (one_of(brg->dt_a, data_type::f32, data_type::bf16)
+ && one_of(brg->dt_b, data_type::u8, data_type::s8,
+ data_type::nf4, data_type::s4, data_type::u4,
+ data_type::f4_e2m1))
+ || (one_of(brg->dt_a, data_type::f32)
+ && one_of(brg->dt_b, data_type::bf16, data_type::f16));
+ brg->rd_step = has_no_vnni_compute_instruction
+ ? 1
+ : data_type_vnni_granularity(brg->dt_b);
+
+ if (brg->with_src_dyn_quant && one_of(brg->dt_b, data_type::u4)) {
+ brg->ld_step = 8;
+ brg->rd_step = 4;
 }
+ set_isa_impl(brg);
+ if (brg->isa_impl == isa_undef) return status::unimplemented;
+ assert(!brg->is_dgmm); // should not be called from brdgmm
+ if (brg->is_dgmm) return status::unimplemented;
+ set_brg_vmm(brg);
+ if (!(brg->is_tmm || brg->is_zmm || brg->is_ymm))
+ return status::unimplemented;
+
+ if (brg->is_tmm)
+ CHECK(brgemm_blocking_tmm(brg));
+ else
+ CHECK(brgemm_blocking_vmm(brg));
+
+ if (!IMPLICATION(brg->brgattr.LDB2 == 0, brg->load_dim <= brg->LDB))
+ return status::invalid_arguments;
+
+ brg->LDA2 = (brg->brgattr.LDA2 != 0) ? brg->brgattr.LDA2 : brg->LDA;
+ brg->LDB2 = (brg->brgattr.LDB2 != 0) ? brg->brgattr.LDB2 : brg->LDB;
+ brg->LDC2_M = (brg->brgattr.LDC2_M != 0) ? brg->brgattr.LDC2_M : brg->LDC;
+ brg->LDC2_N
+ = (brg->brgattr.LDC2_N != 0) ? brg->brgattr.LDC2_N : brg->ld_block;
+
+ brg->is_blocked = (brg->LDA2 != brg->LDA || brg->LDB2 != brg->LDB
+ || brg->LDC2_M != brg->LDC || brg->LDC2_N != brg->ld_block);
+
+ if (!IMPLICATION(brg->is_blocked, brg->layout == brgemm_row_major))
+ return status::invalid_arguments;
+ return status::success;
 }
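// [Editor's note] Not part of the patch: what the new LDA2/LDB2/LDC2
// fields express. A "blocked" matrix stores ld_block-wide column panels
// contiguously, so an element is addressed in two levels; the names here
// are illustrative, mirroring the C_offset() change further down.
inline size_t blocked_c_offset(size_t row, size_t col, size_t ldc,
        size_t ldc2_m, size_t ldc2_n, size_t typesize) {
    const size_t panel = col / ldc; // which column panel
    const size_t in_panel = col % ldc; // offset inside the panel
    return row * ldc2_m + panel * ldc2_n + in_panel * typesize;
}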
@@ -839,7 +915,8 @@ void init_brgemm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
 brg->isa_user = isa;
 set_isa_impl(brg);
- brg->is_int8_tmm = brg->is_int8 && brg->isa_impl == avx512_core_amx;
+ brg->is_int8_tmm
+ = brg->is_int8 && is_superset(brg->isa_impl, avx512_core_amx);
 brg->is_bf16_tmm = brg->is_bf16 && brg->isa_impl == avx512_core_amx;
 brg->is_f16_tmm = brg->is_f16 && brg->isa_impl == avx512_core_amx_fp16;
 brg->is_bf32 = is_bf32
@@ -851,8 +928,9 @@ void init_brgemm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
 brg->has_int8_vnni = isa_has_int8_vnni(brg->isa_impl);
 set_brg_vmm(brg); // TODO: Investigate if it is really needed here.
- brg->req_s8s8_compensation = brg->is_int8 && brg->dt_a == data_type::s8
- && !isa_has_s8s8(brg->isa_impl);
+ brg->req_s8s8_compensation = brg->is_int8 && !brg->is_int8_tmm
+ && !isa_has_s8s8(brg->isa_impl) && brg->dt_a == data_type::s8
+ && !brg->with_src_dyn_quant;
 brg->LDA = (brg->is_row_major()) ? static_cast<int>(LDA)
 : static_cast<int>(LDB);
@@ -875,15 +953,6 @@ void init_brgemm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
 brg->bd_block2 = 0;
 brg->bdb2 = 0;
 brg->bdb2_tail = 0;
-
- const data_type_t ld_step_compute_dt
- = get_mac_emu_data_type(brg->dt_b, brg->isa_impl,
- brg->isa_impl != avx2_vnni_2 && !brg->is_fp8_via_convert());
- brg->ld_step = data_type_vnni_granularity(ld_step_compute_dt);
-
- const data_type_t rd_step_compute_dt = get_mac_emu_data_type(
- brg->dt_b, brg->isa_impl, !brg->is_fp8_via_convert());
- brg->rd_step = data_type_vnni_granularity(rd_step_compute_dt);
 }
 void init_brdgmm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
diff --git a/src/cpu/x64/brgemm/brgemm_utils.hpp b/src/cpu/x64/brgemm/brgemm_utils.hpp
index db2fc9a2a8d..2ff100d2351 100644
--- a/src/cpu/x64/brgemm/brgemm_utils.hpp
+++ b/src/cpu/x64/brgemm/brgemm_utils.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2022-2024 Intel Corporation
+* Copyright 2022-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -28,12 +28,17 @@ namespace impl {
 namespace cpu {
 namespace x64 {
+void init_kernel_datatype(
+ brgemm_desc_t *brg, data_type_t dt_a, data_type_t dt_b);
+
 namespace brgemm_utils {
 bool can_dispatch_uker(const brgemm_desc_t *brg);
 void maybe_try_bf32(brgemm_desc_t *brg);
+void set_isa_impl(brgemm_desc_t *brg);
+
 status_t brgemm_blocking(brgemm_desc_t *brg);
 status_t brdgmm_blocking(brgemm_desc_t *brg);
diff --git a/src/cpu/x64/brgemm/capi/brgemm_api.cpp b/src/cpu/x64/brgemm/capi/brgemm_api.cpp
deleted file mode 100644
index 79d5dcc73b8..00000000000
--- a/src/cpu/x64/brgemm/capi/brgemm_api.cpp
+++ /dev/null
@@ -1,688 +0,0 @@
-/*******************************************************************************
-* Copyright 2024 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "oneapi/dnnl/dnnl_ukernel.h" - -#include "common/c_types_map.hpp" -#include "common/memory_desc_wrapper.hpp" -#include "common/verbose.hpp" - -#include "cpu/ref_io_helper.hpp" - -#include "cpu/x64/amx_tile_configure.hpp" - -#include "cpu/x64/brgemm/brgemm.hpp" - -#include "cpu/x64/brgemm/capi/brgemm_api.hpp" - -#ifdef DNNL_EXPERIMENTAL_UKERNEL - -using namespace dnnl::impl; -using namespace dnnl::impl::format_tag; -using namespace dnnl::impl::status; -using namespace dnnl::impl::cpu::x64; - -using brgemm_t = dnnl_brgemm; -using transform_t = dnnl_transform; - -#define VCHECK_BRGEMM(cond, msg, ...) \ - VCONDCHECK(ukernel, create, check, brgemm, (cond), \ - status::invalid_arguments, msg, ##__VA_ARGS__) - -#define VCHECK_BRGEMM_STATUS(status, cond, msg, ...) \ - VCONDCHECK(ukernel, create, check, brgemm, (cond), (status), msg, \ - ##__VA_ARGS__) - -status_t attr_params_t::set_post_ops_args(const void **post_ops_args) { - post_ops_args_ = post_ops_args; - return status::success; -} - -status_t attr_params_t::set_scales(const void *scales, int arg) { - switch (arg) { - case DNNL_ARG_SRC: a_scales_ = scales; break; - case DNNL_ARG_WEIGHTS: b_scales_ = scales; break; - case DNNL_ARG_DST: d_scales_ = scales; break; - default: assert(!"unsupported arg"); - } - return status::success; -} - -const void *attr_params_t::get_scales(int arg) const { - switch (arg) { - case DNNL_ARG_SRC: return a_scales_; - case DNNL_ARG_WEIGHTS: return b_scales_; - case DNNL_ARG_DST: return d_scales_; - default: assert(!"unsupported arg"); - } - return nullptr; -} - -dnnl_brgemm::~dnnl_brgemm() { - brgemm_kernel_destroy(brgemm_kernel_); -} - -// Typical usage is either `1.f` to append to previous result, or `0.f` to write -// C from scratch. 
-status_t brgemm_t::set_add_C(int add_C) { - if (add_C == 0) - beta_ = 0.f; - else if (add_C == 1) - beta_ = 1.f; - return status::success; -} - -status_t brgemm_t::set_post_ops( - dim_t ldd, data_type_t d_dt, const post_ops_t *post_ops) { - ldd_ = ldd; - d_dt_ = d_dt; - CHECK(attr_.set_post_ops(*post_ops)); - return status::success; -} - -status_t brgemm_t::set_scales(int mask, int arg) { - if (mask < 0) return status::invalid_arguments; - CHECK(attr_.scales_.set(arg, mask)); - return status::success; -} - -status_t brgemm_t::finalize() { - brgemm_batch_kind_t batch_kind = brgemm_batch_kind_t::brgemm_offs; - - auto status = brgemm_desc_init(&brgemm_desc_, cpu_isa_t::isa_undef, - batch_kind, a_dt_, b_dt_, /* transA = */ false, - /* trans_B = */ false, brgemm_row_major, /* alpha = */ 1.f, beta_, - lda_, ldb_, ldc_, M_, N_, K_, - /* strides = */ nullptr); - if (status != status::success) { - VCHECK_BRGEMM_STATUS(status, false, "brgemm_desc_init failed"); - } - - memory_desc_t D_md; - dims_t dims {M_, N_}; - dims_t strides {ldc_, 1}; - status = memory_desc_init_by_strides( - D_md, /* ndims = */ 2, dims, d_dt_, strides); - if (status != status::success) { - VCHECK_BRGEMM_STATUS(status, false, "D_md creation failed"); - } - - status = brgemm_desc_set_postops( - &brgemm_desc_, &attr_, &D_md, ldd_, data_type::undef); - if (status != status::success) { - VCHECK_BRGEMM_STATUS(status, false, "brgemm_desc_set_postops failed"); - } - - brgemm_attr_t brgemm_attr; - brgemm_attr.max_bs = batch_size_; - if (mayiuse(avx512_core_amx)) { - brgemm_attr.use_uker = true; - brgemm_attr.use_interleave_stores = true; - brgemm_attr.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0; - } - - status = brgemm_desc_set_attr(&brgemm_desc_, brgemm_attr); - if (status != status::success) { - VCHECK_BRGEMM_STATUS(status, false, "brgemm_desc_set_attr failed"); - } - - // Note: API can't take a compensation buffer externally. Users must add - // compensation on their own as a binary post-op. - brgemm_desc_.req_s8s8_compensation = false; - - return status::success; -} - -pack_type_t brgemm_t::get_B_pack_type() const { - if (brgemm_desc_.is_b_data_layout_vnni()) return pack_type::pack32; - return pack_type::no_trans; -} - -size_t brgemm_t::get_scratchpad_size() const { - return brgemm_desc_.get_wsp_buffer_size(); -} - -status_t brgemm_t::set_hw_context() const { - char palette[AMX_PALETTE_SIZE] = {}; - auto status = brgemm_init_tiles(brgemm_desc_, palette); - // If status isn't successful, it means tiles configuration is not required. - if (status == status::success) { - status = amx_tile_lazy_configure(palette); - VCHECK_BRGEMM_STATUS( - status, status == status::success, "amx_tile_configure failed"); - } - return status::success; -} - -status_t brgemm_t::generate() { - // Re-generation won't take any effect. - if (brgemm_kernel_ != nullptr) return status::success; - - auto status = brgemm_kernel_create(&brgemm_kernel_, brgemm_desc_); - VCHECK_BRGEMM_STATUS( - status, status == status::success, "brgemm_kernel_create failed"); - - // Generate a verbose info string at the point where configuration is done. 
- if (get_verbose(verbose_t::exec_profile, component_t::ukernel)) { - create_verbose_info(); - } - return status::success; -} - -status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr, - const dim_t *A_B_offsets, void *C_ptr, void *scratchpad_ptr) const { - const auto batch_size = brgemm_desc_.brgattr.max_bs; - std::vector v_batch_element(batch_size); - for (int i = 0; i < batch_size; i++) { - v_batch_element[i].offset.A = A_B_offsets[2 * i]; - v_batch_element[i].offset.B = A_B_offsets[2 * i + 1]; - } - - if (get_verbose(verbose_t::exec_profile, component_t::ukernel)) { - double start_ms = get_msec(); - brgemm_kernel_execute(brgemm_kernel_, batch_size, A_ptr, B_ptr, - v_batch_element.data(), C_ptr, scratchpad_ptr, - /* dynamic_values = */ nullptr); - double duration_ms = get_msec() - start_ms; - - std::stringstream ss; - ss << "cpu,brgemm,,undef," << verbose_info_; - VPROF(start_ms, ukernel, exec, VERBOSE_profile, ss.str().c_str(), - duration_ms); - } else { - brgemm_kernel_execute(brgemm_kernel_, batch_size, A_ptr, B_ptr, - v_batch_element.data(), C_ptr, scratchpad_ptr, - /* dynamic_values = */ nullptr); - } - return status::success; -} - -status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr, - const dim_t *A_B_offsets, const void *C_ptr, void *D_ptr, - void *scratchpad_ptr, const attr_params_t *attr_params) const { - if (attr_params == nullptr) return status::invalid_arguments; - - const auto batch_size = brgemm_desc_.brgattr.max_bs; - std::vector v_batch_element(batch_size); - for (int i = 0; i < batch_size; i++) { - v_batch_element[i].offset.A = A_B_offsets[2 * i]; - v_batch_element[i].offset.B = A_B_offsets[2 * i + 1]; - } - - brgemm_post_ops_data_t post_ops_data; - // Note: this member is used to compute an offset from the base DST address. - // Thus, it's not a C buffer that should be passed, but D buffer. - post_ops_data.data_C_ptr_ = reinterpret_cast(D_ptr); - // This member expects a pointer to a vector of pointers to binary_po args. - // It's exactly what `attr_params` stores when gets a pointer from the user. - post_ops_data.binary_post_ops_rhs = attr_params->get_post_ops_args(); - - // Scales (quantization case, happens after accumulation). Require manual - // combining when both are present, and extending to full simd broadcast, - // when single values are provided. - // Note: this piece is pretty close to what `precompute_scales` does. - // TODO: switch to `precompute_scales` directly. - alignas(64) float scales_buf[16] = {0}; - // TODO: delegate extra memory to scratchpad? - std::vector wei_scales_v(N_); - - const bool has_src_scales - = !attr_.scales_.get(DNNL_ARG_SRC).has_default_values(); - const bool has_wei_scales - = !attr_.scales_.get(DNNL_ARG_WEIGHTS).has_default_values(); - - // Save src scale value to re-use it. - float src_scale_val = 1.f; - if (has_src_scales) { - const void *src_scales_ptr = attr_params->get_scales(DNNL_ARG_SRC); - if (src_scales_ptr == nullptr) return status::invalid_arguments; - - src_scale_val - = cpu::io::load_float_value(data_type::f32, src_scales_ptr, 0); - } - if (has_wei_scales) { - // Handle weights entirely here to avoid duplicating the logic. 
- - const void *wei_scales_ptr = attr_params->get_scales(DNNL_ARG_WEIGHTS); - if (wei_scales_ptr == nullptr) return status::invalid_arguments; - - int wei_mask = attr_.scales_.get(DNNL_ARG_WEIGHTS).mask_; - if (wei_mask > 0) { - for (dim_t i = 0; i < N_; i++) { - const float wei_scale_val = cpu::io::load_float_value( - data_type::f32, wei_scales_ptr, i); - wei_scales_v[i] = wei_scale_val * src_scale_val; - } - post_ops_data.scales = wei_scales_v.data(); - } else { - const float s = cpu::io::load_float_value( - data_type::f32, wei_scales_ptr, 0); - utils::array_set(scales_buf, s * src_scale_val, 16); - post_ops_data.scales = scales_buf; - } - } else if (has_src_scales) { - utils::array_set(scales_buf, src_scale_val, 16); - post_ops_data.scales = scales_buf; - } - - // Destination scales. Require manual extending to full simd broadcast. - alignas(64) float dst_scales_buf[16] = {0}; - if (!attr_.scales_.get(DNNL_ARG_DST).has_default_values()) { - const void *dst_scales_ptr = attr_params->get_scales(DNNL_ARG_DST); - if (dst_scales_ptr == nullptr) return status::invalid_arguments; - - const float s - = cpu::io::load_float_value(data_type::f32, dst_scales_ptr, 0); - utils::array_set(dst_scales_buf, 1.f / s, 16); - post_ops_data.dst_scales = dst_scales_buf; - } - - if (D_ptr && c_dt_ == d_dt_ - && attr_.has_default_values( - primitive_attr_t::skip_mask_t::fpmath_mode)) { - C_ptr = D_ptr; - } - - if (get_verbose(verbose_t::exec_profile, component_t::ukernel)) { - double start_ms = get_msec(); - brgemm_kernel_execute_postops(brgemm_kernel_, batch_size, A_ptr, B_ptr, - v_batch_element.data(), const_cast(C_ptr), D_ptr, - post_ops_data, scratchpad_ptr, - /* dynamic_values = */ nullptr); - double duration_ms = get_msec() - start_ms; - - std::stringstream ss; - ss << "cpu,brgemm,,undef," << verbose_info_; - VPROF(start_ms, ukernel, exec, VERBOSE_profile, ss.str().c_str(), - duration_ms); - } else { - brgemm_kernel_execute_postops(brgemm_kernel_, batch_size, A_ptr, B_ptr, - v_batch_element.data(), const_cast(C_ptr), D_ptr, - post_ops_data, scratchpad_ptr, - /* dynamic_values = */ nullptr); - } - return status::success; -} - -status_t brgemm_t::create_verbose_info() { -#if defined(DISABLE_VERBOSE) - return status::success; -#else - const auto &d = brgemm_desc_; - std::stringstream ss; - - memory_desc_t src_md; - const dims_t src_dims = {M_, K_}; - const dims_t src_strides = {lda_, 1}; - CHECK(memory_desc_init_by_strides(src_md, 2, src_dims, a_dt_, src_strides)); - - memory_desc_t wei_md; - const dims_t wei_dims = {K_, N_}; - const dims_t wei_strides = {ldb_, 1}; - CHECK(memory_desc_init_by_strides(wei_md, 2, wei_dims, b_dt_, wei_strides)); - - memory_desc_t dst_md; - const dims_t dst_dims = {M_, N_}; - const dims_t dst_strides = {ldd_, 1}; - CHECK(memory_desc_init_by_strides(dst_md, 2, dst_dims, d_dt_, dst_strides)); - - ss << md2fmt_str("src", &src_md, format_kind::undef) << " "; - ss << md2fmt_str("wei", &wei_md, format_kind::undef) << " "; - ss << md2fmt_str("dst", &dst_md, format_kind::undef); - ss << "," << attr2str(&attr_) << ","; - ss << "bs:" << d.brgattr.max_bs << " beta:" << beta_; - ss << "," << md2dim_str(&src_md) << ":" << md2dim_str(&wei_md); - - verbose_info_ = ss.str(); - return status::success; -#endif -} - -dnnl_transform::dnnl_transform(dim_t K, dim_t N, pack_type_t in_pack_type, - dim_t in_ld, dim_t out_ld, data_type_t in_dt, data_type_t out_dt) - : K_(K) - , N_(N) - , in_ld_(in_ld) - , out_ld_(out_ld) - , in_dt_(in_dt) - , out_dt_(out_dt) { - // Check for a valid in_ld depending on a 
pack type. - assert(in_pack_type == pack_type::no_trans ? in_ld_ >= N_ : in_ld_ >= K_); - // Only special N_blk sizes are supported by matmul copy routines. Rest - // will crash. - assert(utils::one_of(out_ld_, 16, 32, 48, 64)); - - const auto in_tag = in_pack_type == pack_type::trans ? format_tag::ba - : format_tag::ab; - auto status = matmul::init_conf(bmc_, /* batch = */ 1, K_, N_, in_ld_, - out_ld_, in_dt_, out_dt_, in_tag); - assert(status == status::success); - if (status != status::success) return; - - if (in_pack_type == pack_type::trans) { - strides_[0] = 1; - strides_[1] = in_ld_; - } else if (in_pack_type == pack_type::no_trans) { - strides_[0] = in_ld_; - strides_[1] = 1; - } else { - assert(!"Unsupported pack type"); - } -} - -status_t transform_t::generate() { - // Re-generation won't take any effect. - if (pack_B_kernel_ != nullptr) return status::success; - - CHECK(matmul::create_brgemm_matmul_copy_b(pack_B_kernel_, &bmc_)); - - // Generate a verbose info string at the point where configuration is done. - if (get_verbose(verbose_t::exec_profile, component_t::ukernel)) { - CHECK(create_verbose_info()); - } - return status::success; -} - -status_t transform_t::execute(const void *src, void *dst) const { - double start_ms = 0; - if (get_verbose(verbose_t::exec_profile, component_t::ukernel)) - start_ms = get_msec(); - - const uint8_t *src_ptr = reinterpret_cast(src); - uint8_t *dst_ptr = reinterpret_cast(dst); - - const auto &kernel_conf = bmc_; - const dim_t n_blks = utils::div_up(kernel_conf.N, kernel_conf.N_blk); - const dim_t k_blks = utils::div_up(kernel_conf.K, kernel_conf.K_blk); - const auto blk_size = kernel_conf.K_blk * kernel_conf.N_blk; - - const auto i_dt_sz = kernel_conf.b_dt_sz; - const auto o_dt_sz = kernel_conf.a_dt_sz; - - for (dim_t n_blk_idx = 0; n_blk_idx < n_blks; n_blk_idx++) { - const auto n = n_blk_idx * kernel_conf.N_blk; - const bool is_N_tail = (kernel_conf.N - n) < kernel_conf.N_blk; - auto ker_exec_ctx = matmul::jit_brgemm_matmul_copy_b_t::ctx_t(); - ker_exec_ctx.current_N_blk - = is_N_tail ? 
kernel_conf.N_tail : kernel_conf.N_blk; - - int k_blk_idx = 0; - for (; k_blk_idx < kernel_conf.K / kernel_conf.K_blk; k_blk_idx++) { - const auto k = k_blk_idx * kernel_conf.K_blk; - const auto src_offset - = i_dt_sz * (k * strides_[0] + n * strides_[1]); - const auto dst_offset - = o_dt_sz * (k_blk_idx * blk_size + n_blk_idx * k_blks); - ker_exec_ctx.src = &src_ptr[src_offset]; - ker_exec_ctx.tr_src = &dst_ptr[dst_offset]; - ker_exec_ctx.current_K_start = k; - ker_exec_ctx.current_K_iters = kernel_conf.K_blk; - (*pack_B_kernel_)(&ker_exec_ctx); - } - if (kernel_conf.K_tail > 0) { - const auto k = k_blk_idx * kernel_conf.K_blk; - const auto src_offset - = i_dt_sz * (k * strides_[0] + n * strides_[1]); - const auto dst_offset - = o_dt_sz * (k_blk_idx * blk_size + n_blk_idx * k_blks); - ker_exec_ctx.src = &src_ptr[src_offset]; - ker_exec_ctx.tr_src = &dst_ptr[dst_offset]; - ker_exec_ctx.current_K_start = k; - ker_exec_ctx.current_K_iters = kernel_conf.K_tail; - (*pack_B_kernel_)(&ker_exec_ctx); - } - } - - if (get_verbose(verbose_t::exec_profile, component_t::ukernel)) { - double duration_ms = get_msec() - start_ms; - - std::stringstream ss; - ss << "cpu,transform,pack_B,undef," << verbose_info_; - VPROF(start_ms, ukernel, exec, VERBOSE_profile, ss.str().c_str(), - duration_ms); - } - return status::success; -} - -status_t transform_t::create_verbose_info() { -#if defined(DISABLE_VERBOSE) - return status::success; -#else - std::stringstream ss; - - memory_desc_t src_md; - const dims_t dims = {K_, N_}; - CHECK(memory_desc_init_by_strides(src_md, 2, dims, in_dt_, strides_)); - - memory_desc_t dst_md; - const dims_t dst_strides = {out_ld_, 1}; - CHECK(memory_desc_init_by_strides(dst_md, 2, dims, out_dt_, dst_strides)); - - ss << md2fmt_str("src", &src_md, format_kind::undef) << " "; - ss << md2fmt_str("dst", &dst_md, format_kind::undef); - ss << ",,," << md2dim_str(&src_md); - - verbose_info_ = ss.str(); - return status::success; -#endif -} - -//////////////// -// Public API // -//////////////// - -///////////////////////// -// Attribute arguments // -///////////////////////// - -status_t dnnl_ukernel_attr_params_create(attr_params_t **attr_params) { - *attr_params = new attr_params_t(); - return status::success; -} - -status_t dnnl_ukernel_attr_params_set_post_ops_args( - attr_params_t *attr_params, const void **post_ops_args) { - if (attr_params == nullptr) return status::invalid_arguments; - - CHECK(attr_params->set_post_ops_args(post_ops_args)); - return status::success; -} - -status_t dnnl_ukernel_attr_params_set_A_scales( - attr_params_t *attr_params, const void *a_scales) { - if (attr_params == nullptr) return status::invalid_arguments; - - CHECK(attr_params->set_scales(a_scales, DNNL_ARG_SRC)); - return status::success; -} - -status_t dnnl_ukernel_attr_params_set_B_scales( - attr_params_t *attr_params, const void *b_scales) { - if (attr_params == nullptr) return status::invalid_arguments; - - CHECK(attr_params->set_scales(b_scales, DNNL_ARG_WEIGHTS)); - return status::success; -} - -status_t dnnl_ukernel_attr_params_set_D_scales( - attr_params_t *attr_params, const void *d_scales) { - if (attr_params == nullptr) return status::invalid_arguments; - - CHECK(attr_params->set_scales(d_scales, DNNL_ARG_DST)); - return status::success; -} - -status_t dnnl_ukernel_attr_params_destroy(attr_params_t *attr_params) { - delete attr_params; - return status::success; -} - -//////////// -// BRGeMM // -//////////// - -status_t dnnl_brgemm_create(brgemm_t **brgemm, dim_t M, dim_t N, dim_t K, - dim_t 
batch_size, dim_t lda, dim_t ldb, dim_t ldc, data_type_t a_dt, - data_type_t b_dt, data_type_t c_dt) { - if (batch_size <= 0) { - VCHECK_BRGEMM_STATUS( - status::invalid_arguments, false, "batch size is non-positive"); - } - - *brgemm = new brgemm_t( - M, N, K, batch_size, lda, ldb, ldc, a_dt, b_dt, c_dt); - return status::success; -} - -status_t dnnl_brgemm_set_add_C(brgemm_t *brgemm, int add_C) { - if (brgemm == nullptr) return invalid_arguments; - - CHECK(brgemm->set_add_C(add_C)); - return status::success; -} - -status_t dnnl_brgemm_set_post_ops(brgemm_t *brgemm, dim_t ldd, data_type_t d_dt, - const post_ops_t *post_ops) { - if (brgemm == nullptr) return invalid_arguments; - - CHECK(brgemm->set_post_ops(ldd, d_dt, post_ops)); - return status::success; -} - -status_t dnnl_brgemm_set_A_scales(brgemm_t *brgemm, int a_scale_mask) { - if (brgemm == nullptr) return invalid_arguments; - - CHECK(brgemm->set_scales(a_scale_mask, DNNL_ARG_SRC)); - return status::success; -} - -status_t dnnl_brgemm_set_B_scales(brgemm_t *brgemm, int b_scale_mask) { - if (brgemm == nullptr) return invalid_arguments; - - CHECK(brgemm->set_scales(b_scale_mask, DNNL_ARG_WEIGHTS)); - return status::success; -} - -status_t dnnl_brgemm_set_D_scales(brgemm_t *brgemm, int d_scale_mask) { - if (brgemm == nullptr) return invalid_arguments; - - CHECK(brgemm->set_scales(d_scale_mask, DNNL_ARG_DST)); - return status::success; -} - -status_t dnnl_brgemm_finalize(brgemm_t *brgemm) { - if (brgemm == nullptr) return invalid_arguments; - - CHECK(brgemm->finalize()); - return status::success; -} - -status_t dnnl_brgemm_get_B_pack_type( - const brgemm_t *brgemm, dnnl_pack_type_t *pack_type) { - if (brgemm == nullptr) return invalid_arguments; - - if (pack_type) *pack_type = brgemm->get_B_pack_type(); - return status::success; -} - -status_t dnnl_brgemm_get_scratchpad_size(const brgemm_t *brgemm, size_t *size) { - if (brgemm == nullptr) return invalid_arguments; - - if (size) *size = brgemm->get_scratchpad_size(); - return status::success; -} - -status_t dnnl_brgemm_set_hw_context(const brgemm_t *brgemm) { - if (brgemm == nullptr) return invalid_arguments; - - CHECK(brgemm->set_hw_context()); - return status::success; -} - -status_t dnnl_brgemm_release_hw_context() { - if (mayiuse(avx512_core_amx)) { - VCHECK_BRGEMM(amx_tile_release() == status::success, - "amx_tile_release failed"); - } - - return status::success; -} - -status_t dnnl_brgemm_generate(brgemm_t *brgemm) { - if (brgemm == nullptr) return invalid_arguments; - - CHECK(brgemm->generate()); - return status::success; -} - -status_t dnnl_brgemm_execute(const brgemm_t *brgemm, const void *A_ptr, - const void *B_ptr, const dim_t *A_B_offsets, void *C_ptr, - void *scratchpad_ptr) { - CHECK(brgemm->execute(A_ptr, B_ptr, A_B_offsets, C_ptr, scratchpad_ptr)); - return status::success; -} - -status_t dnnl_brgemm_execute_postops(const brgemm_t *brgemm, const void *A_ptr, - const void *B_ptr, const dim_t *A_B_offsets, const void *C_ptr, - void *D_ptr, void *scratchpad_ptr, const attr_params_t *attr_params) { - CHECK(brgemm->execute(A_ptr, B_ptr, A_B_offsets, C_ptr, D_ptr, - scratchpad_ptr, attr_params)); - return status::success; -} - -status_t dnnl_brgemm_destroy(brgemm_t *brgemm) { - delete brgemm; - return status::success; -} - -/////////////// -// Transform // -/////////////// - -status_t dnnl_transform_create(transform_t **transform, dim_t K, dim_t N, - pack_type_t in_pack_type, dim_t in_ld, dim_t out_ld, data_type_t in_dt, - data_type_t out_dt) { - if (transform == nullptr) 
return status::invalid_arguments; - - *transform - = new transform_t(K, N, in_pack_type, in_ld, out_ld, in_dt, out_dt); - return status::success; -} - -status_t dnnl_transform_generate(transform_t *transform) { - if (transform == nullptr) return status::invalid_arguments; - - CHECK(transform->generate()); - return status::success; -} - -status_t dnnl_transform_execute( - const transform_t *transform, const void *in_ptr, void *out_ptr) { - if (utils::any_null(transform, in_ptr, out_ptr)) - return status::invalid_arguments; - - CHECK(transform->execute(in_ptr, out_ptr)); - return status::success; -} - -status_t dnnl_transform_destroy(transform_t *transform) { - delete transform; - return status::success; -} - -#endif - -//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s diff --git a/src/cpu/x64/brgemm/capi/brgemm_api.hpp b/src/cpu/x64/brgemm/capi/brgemm_api.hpp deleted file mode 100644 index 0a2604e9520..00000000000 --- a/src/cpu/x64/brgemm/capi/brgemm_api.hpp +++ /dev/null @@ -1,173 +0,0 @@ -/******************************************************************************* -* Copyright 2024 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_X64_BRGEMM_CAPI_BRGEMM_API_HPP -#define CPU_X64_BRGEMM_CAPI_BRGEMM_API_HPP - -#include - -#include "cpu/x64/matmul/brgemm_matmul_copy_utils.hpp" -#include "cpu/x64/matmul/brgemm_matmul_utils.hpp" - -#include "cpu/x64/brgemm/brgemm_types.hpp" - -#ifdef DNNL_EXPERIMENTAL_UKERNEL - -// A section identical to c_map_types.hpp but just for brgemm ukernel so far. 
-namespace dnnl { -namespace impl { -namespace cpu { -namespace x64 { - -using pack_type_t = dnnl_pack_type_t; -namespace pack_type { -const pack_type_t undef = dnnl_pack_type_undef; -const pack_type_t no_trans = dnnl_pack_type_no_trans; -const pack_type_t trans = dnnl_pack_type_trans; -const pack_type_t pack32 = dnnl_pack_type_pack32; -} // namespace pack_type - -using attr_params_t = dnnl_ukernel_attr_params; - -} // namespace x64 -} // namespace cpu -} // namespace impl -} // namespace dnnl - -struct dnnl_ukernel_attr_params : public dnnl::impl::c_compatible { - dnnl_ukernel_attr_params() = default; - - dnnl::impl::status_t set_post_ops_args(const void **post_ops_args); - const void *get_post_ops_args() const { return post_ops_args_; } - - dnnl::impl::status_t set_scales(const void *scales, int arg); - const void *get_scales(int arg) const; - -private: - const void *post_ops_args_; - const void *a_scales_; - const void *b_scales_; - const void *d_scales_; -}; - -struct dnnl_brgemm : public dnnl::impl::c_compatible { - dnnl_brgemm(dnnl::impl::dim_t M, dnnl::impl::dim_t N, dnnl::impl::dim_t K, - dnnl::impl::dim_t batch_size, dnnl::impl::dim_t lda, - dnnl::impl::dim_t ldb, dnnl::impl::dim_t ldc, - dnnl::impl::data_type_t a_dt, dnnl::impl::data_type_t b_dt, - dnnl::impl::data_type_t c_dt) - : M_(M) - , N_(N) - , K_(K) - , batch_size_(batch_size) - , lda_(lda) - , ldb_(ldb) - , ldc_(ldc) - , ldd_(ldc) // User may overwrite with set_post_ops(). - , a_dt_(a_dt) - , b_dt_(b_dt) - , c_dt_(c_dt) - , d_dt_(c_dt) // User may overwrite with set_post_ops(). - , beta_(0.f) // User may overwrite with set_add_C(). - , brgemm_kernel_(nullptr) {} - - ~dnnl_brgemm(); - - dnnl::impl::status_t set_add_C(int add_C); - - dnnl::impl::status_t set_post_ops(dnnl::impl::dim_t ldd, - dnnl::impl::data_type_t d_dt, - const dnnl::impl::post_ops_t *post_ops); - - dnnl::impl::status_t set_scales(int mask, int arg); - - dnnl::impl::status_t finalize(); - - dnnl::impl::cpu::x64::pack_type_t get_B_pack_type() const; - - size_t get_scratchpad_size() const; - - dnnl::impl::status_t set_hw_context() const; - - dnnl::impl::status_t generate(); - - dnnl::impl::status_t execute(const void *A_ptr, const void *B_ptr, - const dnnl::impl::dim_t *A_B_offsets, void *C_ptr, - void *scratchpad_ptr) const; - dnnl::impl::status_t execute(const void *A_ptr, const void *B_ptr, - const dnnl::impl::dim_t *A_B_offsets, const void *C_ptr, - void *D_ptr, void *scratchpad_ptr, - const dnnl::impl::cpu::x64::attr_params_t *attr_params) const; - -private: - // User's inputs. - dnnl::impl::dim_t M_, N_, K_, batch_size_; - dnnl::impl::dim_t lda_, ldb_, ldc_, ldd_; - dnnl::impl::data_type_t a_dt_, b_dt_, c_dt_, d_dt_; - float beta_; - // A copy of attributes to avoid dependency on user's attributes lifetime. - dnnl::impl::primitive_attr_t attr_; - - // A main kernel. - dnnl::impl::cpu::x64::brgemm_desc_t brgemm_desc_; - dnnl::impl::cpu::x64::brgemm_kernel_t *brgemm_kernel_; - - // Creates a `verbose_info_` string once during `generate()` call, and calls - // it during execute(). This is done to avoid string re-creation. - dnnl::impl::status_t create_verbose_info(); - std::string verbose_info_; -}; - -struct dnnl_transform : public dnnl::impl::c_compatible { - // Ctor that follows a call to initialize matmul conf struct. 
- dnnl_transform(dnnl::impl::dim_t K, dnnl::impl::dim_t N, - dnnl::impl::cpu::x64::pack_type_t in_pack_type, - dnnl::impl::dim_t in_ld, dnnl::impl::dim_t out_ld, - dnnl::impl::data_type_t in_dt, dnnl::impl::data_type_t out_dt); - - // Generates a transform kernel. - dnnl::impl::status_t generate(); - - // Executes a transform kernel. - dnnl::impl::status_t execute(const void *src, void *dst) const; - -private: - // User's inputs. - dnnl::impl::dim_t K_, N_; - dnnl::impl::dim_t in_ld_, out_ld_; - dnnl::impl::data_type_t in_dt_, out_dt_; - // Save `strides_` for `execute` to get proper source offset. - dnnl::impl::dims_t strides_; - - // A transform kernel. - // Note: though it's a generic class for any kind of transformation, so far - // it's only matmul's copy_B. - dnnl::impl::cpu::x64::matmul::brgemm_matmul_conf_t bmc_; - // `unique_ptr` is required by API that generates a kernel. - std::unique_ptr - pack_B_kernel_; - - // Creates a `verbose_info_` string once during `generate()` call, and calls - // it during execute(). This is done to avoid string re-creation. - dnnl::impl::status_t create_verbose_info(); - std::string verbose_info_; -}; - -#endif - -#endif - -//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s diff --git a/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp b/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp index 39c5a990b05..7a8aaf66445 100644 --- a/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ #include "cpu/x64/brgemm/jit_brdgmm_kernel.hpp" #include "cpu/x64/cpu_barrier.hpp" #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" -#include "cpu/x64/jit_generator.hpp" #define GET_OFF(field) offsetof(brgemm_kernel_params_t, field) #define GET_OFF_BATCH_ELEMENT(field) offsetof(brgemm_batch_element_t, field) @@ -39,12 +38,13 @@ using namespace Xbyak; template jit_brdgmm_kernel_base_t::jit_brdgmm_kernel_base_t( const brgemm_desc_t &abrd) - : jit_generator(jit_name(), abrd.isa_impl) + : jit_base_brgemm_kernel_t(jit_name(), abrd.isa_impl) , brg(abrd) - , simd_w_(vreg_traits::vlen / brg.typesize_C) + , simd_w_(vreg_traits_t::vlen / brg.typesize_C) , max_vmms_(isa_num_vregs(brg.isa_impl)) , compute_dst_zp_(brg.zp_type_c != brgemm_broadcast_t::none) , compute_src_zp_(brg.zp_type_a != brgemm_broadcast_t::none) + , is_src_zp_bcast_(brg.zp_type_a == brgemm_broadcast_t::per_tensor) , compute_compensation_(compute_src_zp_ || brg.req_s8s8_compensation) , has_vpad_(brg.brgattr.max_top_vpad > 0 || brg.brgattr.max_bottom_vpad > 0) , has_bpad_(brg.brgattr.max_top_bpad > 0 || brg.brgattr.max_bottom_bpad > 0) @@ -147,7 +147,7 @@ void jit_brdgmm_kernel_base_t::read_params() { } if (compute_src_zp_) { - mov(reg_tmp, ptr[param1 + GET_OFF(zp_a_val)]); + mov(reg_tmp, ptr[param1 + GET_OFF(a_zp_values)]); mov(ptr[rsp + src_zp_value_], reg_tmp); mov(reg_tmp, ptr[param1 + GET_OFF(a_zp_compensations)]); @@ -238,8 +238,8 @@ void jit_brdgmm_kernel_base_t::cvt2ps(data_type_t type_in, bool store) { const int tail_size = tail_length(); const bool is_load_tail = op.isMEM() && mask_flag && tail_size > 0 - && (tail_size - < static_cast(vreg_traits::vlen / sizeof(float))); + && (tail_size < static_cast( + vreg_traits_t::vlen / sizeof(float))); if 
(IMPLICATION(is_load_tail, isa_has_masks(brg.isa_impl))) { const Vmm vmm = maybe_mask(vmm_in, is_load_tail, store); switch (type_in) { @@ -473,9 +473,9 @@ void jit_brdgmm_kernel_base_t::store_accumulators_apply_post_ops( const bool dt_requires_saturation = one_of(brg.dt_d, data_type::u8, data_type::s8, data_type::s32); - auto vmm_lbound = vmm_tmp(0); - auto vmm_ubound = vmm_tmp(1); if (dt_requires_saturation) { + auto vmm_lbound = vmm_tmp(0); + auto vmm_ubound = vmm_tmp(1); init_saturate_f32( vmm_lbound, vmm_ubound, reg_tmp, data_type::f32, brg.dt_d); } @@ -484,6 +484,8 @@ void jit_brdgmm_kernel_base_t::store_accumulators_apply_post_ops( for (int m = 0; m < m_blocks; m++) { if (dt_requires_saturation) { + auto vmm_lbound = vmm_tmp(0); + auto vmm_ubound = vmm_tmp(1); for_(int n = 0; n < n_blocks; n++) for (int v_i = 0; v_i < v_substep; ++v_i) { if (get_substep_simd(n, v_i, has_n_tail) <= 0) continue; @@ -511,10 +513,7 @@ void jit_brdgmm_kernel_base_t::store_accumulators_apply_post_ops( if (brg.is_bf16_emu) bf16_emu_->vcvtneps2bf16(vmm_low, vmm); else - vcvtneps2bf16(vmm_low, vmm, - brg.isa_impl == avx2_vnni_2 - ? Xbyak::VexEncoding - : Xbyak::EvexEncoding); + vcvtneps2bf16(vmm_low, vmm, get_encoding()); if (mask_flag) vmovdqu16(addr, r_vmm_low); else @@ -553,9 +552,9 @@ void jit_brdgmm_kernel_base_t::store_accumulators_without_post_ops( const bool dt_requires_saturation = brg.is_int8 && brg.dt_c != data_type::s32; - auto vmm_lbound = vmm_tmp(0); - auto vmm_ubound = vmm_tmp(1); if (dt_requires_saturation) { + auto vmm_lbound = vmm_tmp(0); + auto vmm_ubound = vmm_tmp(1); init_saturate_f32( vmm_lbound, vmm_ubound, reg_tmp, data_type::f32, brg.dt_d); } @@ -567,8 +566,11 @@ void jit_brdgmm_kernel_base_t::store_accumulators_without_post_ops( if (substep_simd <= 0) continue; const bool mask_flag = substep_simd < simd_w_; auto vmm_acc = accm(m_blocks, n_blocks, m, n, v_i); - if (dt_requires_saturation) + if (dt_requires_saturation) { + auto vmm_lbound = vmm_tmp(0); + auto vmm_ubound = vmm_tmp(1); saturate_cvt_f32(vmm_acc, vmm_lbound, vmm_ubound, brg.dt_d); + } const auto offset = C_offset(m, n, v_i); if (IMPLICATION(mask_flag, isa_has_masks(brg.isa_impl))) { auto vmm_acc_masked = maybe_mask(vmm_acc, mask_flag, true); @@ -604,6 +606,17 @@ void jit_brdgmm_kernel_base_t::maybe_transpose_interleaved_vnni_to_plain( } } +template +void jit_brdgmm_kernel_base_t::load_src_zp() { + mov(reg_src_zero_point, ptr[rsp + src_zp_value_]); + lea(reg_src_zero_point, + is_src_zp_bcast_ + ? 
ptr_b[reg_src_zero_point] + : ptr[reg_src_zero_point + reg_aux_N * sizeof(int32_t)]); + if (!is_superset(brg.isa_impl, avx512_core) && is_src_zp_bcast_) + uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]); +} + template void jit_brdgmm_kernel_base_t::compute_int8_compensation( int m_blocks, int n_blocks, bool has_n_tail) { @@ -615,12 +628,10 @@ void jit_brdgmm_kernel_base_t::compute_int8_compensation( lea(reg_s8s8_comp, ptr[reg_s8s8_comp + reg_aux_N * sizeof(int32_t)]); } if (compute_src_zp_) { - lea(reg_src_zero_point, ptr[rsp + src_zp_value_]); + load_src_zp(); mov(reg_zp_compensation, ptr[rsp + zp_compensation_]); lea(reg_zp_compensation, ptr[reg_zp_compensation + reg_aux_N * sizeof(int32_t)]); - if (!is_superset(brg.isa_impl, avx512_core)) - uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]); } for_(int v_i = 0; v_i < v_substep; ++v_i) @@ -635,16 +646,35 @@ void jit_brdgmm_kernel_base_t::compute_int8_compensation( } if (compute_src_zp_) { // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32 - const Vmm vmm_zp = vmm_zp_comp(); - vmovups(vmm_zp, - maybe_EVEX_compress_addr(reg_zp_compensation, offset)); - if (is_superset(brg.isa_impl, avx512_core)) { - const bool src_zp_is_common = true; - vpmulld(vmm_zp, vmm_zp, - maybe_EVEX_compress_addr( - reg_src_zero_point, 0, src_zp_is_common)); + const bool is_tail + = n + 1 == n_blocks && has_n_tail && substep_simd < simd_w_; + const Vmm vmm_zp = isa_has_masks(brg.isa_impl) + ? maybe_mask(vmm_zp_comp(), is_tail, false) + : vmm_zp_comp(); + if (IMPLICATION(is_tail, isa_has_masks(brg.isa_impl))) { + vmovups(vmm_zp, + maybe_EVEX_compress_addr(reg_zp_compensation, offset)); + if (is_src_zp_bcast_) { + if (is_superset(brg.isa_impl, avx512_core)) + vpmulld(vmm_zp, vmm_zp, + maybe_EVEX_compress_addr( + reg_src_zero_point, 0, true)); + else + vpmulld(vmm_zp, vmm_zp, vmm_bcast()); + } else + vpmulld(vmm_zp, vmm_zp, + maybe_EVEX_compress_addr( + reg_src_zero_point, offset)); } else { - vpmulld(vmm_zp, vmm_zp, vmm_bcast()); + const int tail_size = tail_length(); + const Vmm ymm_tmp + = vmm_bcast(); // used for bcast or tail processing in avx2 + load_data(data_type::s32, vmm_zp, + ptr[reg_zp_compensation + offset], tail_size); + if (!is_src_zp_bcast_) + load_data(data_type::s32, ymm_tmp, + ptr[reg_src_zero_point + offset], tail_size); + vpmulld(vmm_zp, vmm_zp, ymm_tmp); } } for (int m = 0; m < m_blocks; m++) { @@ -696,9 +726,9 @@ void jit_brdgmm_kernel_base_t::load_a( + is_tail_block * v_i * simd_w_ * brg.typesize_A]; if (IMPLICATION(mask_flag, isa_has_masks(brg.isa_impl))) { vmma = maybe_mask(vmma, mask_flag, false); - if (brg.is_f32) { + if (brg.dt_a == data_type::f32) { vmovups(vmma, addr); - } else if (brg.is_bf16) { + } else if (brg.dt_a == data_type::bf16) { if (brg.isa_impl == avx2_vnni_2) { if (is_tail_block) { vpmovzxwd(vmma, addr); @@ -711,7 +741,7 @@ void jit_brdgmm_kernel_base_t::load_a( vpmovzxwd(vmma, addr); if (is_slow_bf16_vnni()) vpslld(vmma, vmma, 16); } - } else if (brg.is_f16) { + } else if (brg.dt_b == data_type::f16) { if (brg.isa_impl == avx2_vnni_2) { if (is_tail_block) vcvtph2ps(vmma, addr); @@ -721,7 +751,7 @@ void jit_brdgmm_kernel_base_t::load_a( vcvtneoph2ps(vmma, addr); } else vcvtph2ps(vmma, addr); - } else if (brg.is_int8) { + } else if (utils::one_of(brg.dt_a, data_type::s8, data_type::u8)) { if (is_fast_vnni_int8()) { assert(!mask_flag); vbroadcasti32x4(vmma, addr); @@ -747,9 +777,9 @@ void jit_brdgmm_kernel_base_t::load_b( const bool is_tail_block = has_n_tail && (n_i + 1 == n_blocks); const auto 
addr = ptr[reg_aux_B + B_offset(n_i) + is_tail_block * v_i * simd_w_ * brg.typesize_B]; - if (brg.is_f32) { + if (brg.dt_b == data_type::f32) { vmovups(vmmb, addr); - } else if (brg.is_int8) { + } else if (brg.dt_b == data_type::s8) { if (wei_zp) { // load weights for zero-point computation vpmovsxbd(vmmb, addr); if (is_fast_vnni_int8()) vpermd(vmmb, vmm_permute(), vmmb); @@ -762,7 +792,7 @@ void jit_brdgmm_kernel_base_t::load_b( vpmovsxbd(vmmb, addr); } } - } else if (brg.is_f16) { + } else if (brg.dt_b == data_type::f16) { if (brg.isa_impl == avx2_vnni_2) { if (is_tail_block) vcvtph2ps(vmmb, addr); @@ -772,7 +802,7 @@ void jit_brdgmm_kernel_base_t::load_b( vcvtneoph2ps(vmmb, addr); } else vcvtph2ps(vmmb, addr); - } else if (brg.is_bf16) { + } else if (brg.dt_b == data_type::bf16) { if (brg.isa_impl == avx2_vnni_2) { if (is_tail_block) { vpmovzxwd(vmmb, addr); @@ -783,31 +813,52 @@ void jit_brdgmm_kernel_base_t::load_b( vcvtneobf162ps(vmmb, addr); } else { vpmovzxwd(vmmb, addr); - if (is_slow_bf16_vnni()) vpslld(vmmb, vmmb, 16); + if (is_slow_bf16_vnni() || brg.is_f32) vpslld(vmmb, vmmb, 16); } } } template void jit_brdgmm_kernel_base_t::comp_dot_product( - compute_pad_kernel_t kernel_type, Vmm vmm_acc, Vmm vmmb) { + compute_pad_kernel_t kernel_type, Vmm vmm_acc, Vmm vmmb, int n, + bool is_tail_block) { switch (kernel_type) { case compute_pad_kernel_t::s8s8_kernel: - vpdpbusd(vmm_acc, vmm_shift(), vmmb, - is_superset(brg.isa_impl, avx512_core) - ? Xbyak::EvexEncoding - : Xbyak::VexEncoding); + vpdpbusd(vmm_acc, vmm_shift(), vmmb, get_encoding()); break; - case compute_pad_kernel_t::zero_point_kernel: - if (is_superset(brg.isa_impl, avx512_core)) { - vpmulld(vmm_zp_comp(), vmmb, - maybe_EVEX_compress_addr(reg_src_zero_point, 0, true)); + case compute_pad_kernel_t::zero_point_kernel: { + const Vmm vmm_zp = isa_has_masks(brg.isa_impl) + ? 
maybe_mask(vmm_zp_comp(), is_tail_block, false) + : vmm_zp_comp(); + const size_t offset = comp_offset(n); + if (IMPLICATION(is_tail_block, isa_has_masks(brg.isa_impl))) { + if (is_src_zp_bcast_) { + if (is_superset(brg.isa_impl, avx512_core)) + vpmulld(vmm_zp, vmmb, + maybe_EVEX_compress_addr( + reg_src_zero_point, 0, true)); + else + vpmulld(vmm_zp, vmmb, vmm_bcast()); + } else { + const Xbyak::Address src_zp_addr = maybe_EVEX_compress_addr( + reg_src_zero_point, offset); + if (is_fast_vnni_int8()) { + vmovups(vmm_zp, src_zp_addr); + vpermd(vmm_zp, vmm_permute(), vmm_zp); + vpmulld(vmm_zp, vmmb, vmm_zp); + } else + vpmulld(vmm_zp, vmmb, src_zp_addr); + } } else { - uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]); - vpmulld(vmm_zp_comp(), vmmb, vmm_bcast()); + const Vmm ymm_tmp + = vmm_bcast(); // used for bcast or tail processing in avx2 + if (!is_src_zp_bcast_) + load_data(data_type::s32, ymm_tmp, + ptr[reg_src_zero_point + offset], tail_length()); + vpmulld(vmm_zp, vmmb, ymm_tmp); } vpaddd(vmm_acc, vmm_acc, vmm_zp_comp()); - break; + } break; default: assert(!"unsupported comp_kernel type"); } } @@ -848,21 +899,25 @@ void jit_brdgmm_kernel_base_t::pad_comp_kernel( for (int pad_i = max_m_unroll; pad_i > 0; --pad_i) { L(jmp_table_labels[pad_i]); - if (is_zero_point_kernel) - lea(reg_src_zero_point, ptr[rsp + src_zp_value_]); + if (is_zero_point_kernel) load_src_zp(); if (pad_i > m_blocks) continue; const int m_i = get_mi(pad_i); int p_b_i = 0; for (int n_i = 0; n_i < n_blocks; ++n_i, ++p_b_i) { - if (get_substep_simd(n_i, 0, has_tail) <= 0) continue; + const int substep_simd = get_substep_simd(n_i, 0, has_tail); + if (substep_simd <= 0) continue; const Vmm vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, 0); + const bool is_tail_block + = n_i + 1 == n_blocks && has_tail && substep_simd < simd_w_; if (p_b_i < n_preload_b_vmms) { - comp_dot_product(kernel_type, vmm_acc, vmm_b(p_b_i)); + comp_dot_product( + kernel_type, vmm_acc, vmm_b(p_b_i), n_i, is_tail_block); } else { // preloaded vmm_b not available const Vmm vmm_wei = vmm_b(max_bvmms - 1); load_b(vmm_wei, n_i, 0, has_tail, load_broadcast_wei); - comp_dot_product(kernel_type, vmm_acc, vmm_wei); + comp_dot_product( + kernel_type, vmm_acc, vmm_wei, n_i, is_tail_block); } } } @@ -880,8 +935,7 @@ void jit_brdgmm_kernel_base_t::batch_pad_kernel( auto kernel_body = [&](compute_pad_kernel_t kernel_type) { const bool is_zero_point_kernel = kernel_type == compute_pad_kernel_t::zero_point_kernel; - if (is_zero_point_kernel) - lea(reg_src_zero_point, ptr[rsp + src_zp_value_]); + if (is_zero_point_kernel) load_src_zp(); for (int nb_i = 0; nb_i < n_blocks; nb_i += max_bvmms) { const int n_e = nstl::min(nb_i + max_bvmms, n_blocks) - nb_i; for (int i = 0; i < n_e; ++i) { @@ -893,9 +947,13 @@ void jit_brdgmm_kernel_base_t::batch_pad_kernel( for_(int m_i = 0; m_i < m_blocks; ++m_i) for (int i = 0; i < n_e; ++i) { const int n_i = nb_i + i; - if (get_substep_simd(n_i, 0, has_tail) <= 0) continue; + const int substep_simd = get_substep_simd(n_i, 0, has_tail); + if (substep_simd <= 0) continue; const Vmm vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, 0); - comp_dot_product(kernel_type, vmm_acc, vmm_b(i)); + const bool is_tail_block + = n_i + 1 == n_e && has_tail && substep_simd < simd_w_; + comp_dot_product( + kernel_type, vmm_acc, vmm_b(i), n_i, is_tail_block); } } }; @@ -938,10 +996,7 @@ void jit_brdgmm_kernel_base_t::brdgmm_microkernel(int m_blocks, if (brg.dt_a == data_type::s8 && isa_has_s8s8(brg.isa_impl)) vpdpbssd(vmm_acc, vmma, vmmb); else - 
vpdpbusd(vmm_acc, vmma, vmmb, - is_superset(brg.isa_impl, avx512_core) - ? Xbyak::EvexEncoding - : Xbyak::VexEncoding); + vpdpbusd(vmm_acc, vmma, vmmb, get_encoding()); } }; @@ -1007,8 +1062,8 @@ void jit_brdgmm_kernel_base_t::brdgmm_microkernel(int m_blocks, align(64); L(jmp_table_base); - for (int m_i = 0; m_i < m_blocks; ++m_i) { - putL(jmp_table_labels[m_i]); + for (const auto &label : jmp_table_labels) { + putL(label); } } @@ -1384,7 +1439,7 @@ void brdgmm_kernel_t::operator()(brgemm_kernel_params_t *params) const { } template -const jit_generator *brdgmm_kernel_t::get_jit_generator() const { +const jit_generator_t *brdgmm_kernel_t::get_jit_generator() const { return brgemm_kernel_; } diff --git a/src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp b/src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp index e3d6138dd5e..236d027de56 100644 --- a/src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp +++ b/src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ namespace cpu { namespace x64 { template -struct jit_brdgmm_kernel_base_t : public jit_generator { +struct jit_brdgmm_kernel_base_t : public jit_base_brgemm_kernel_t { jit_brdgmm_kernel_base_t(const brgemm_desc_t &abrd); DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brdgmm_kernel_base_t) @@ -160,13 +160,15 @@ struct jit_brdgmm_kernel_base_t : public jit_generator { return vmm_alloc.get_compute_vmm_count(); } + const brgemm_desc_t &get_brg() const override { return brg; } + private: // note: this kernel doesn't yet support TMM's. We differentiate Wmm and Vmm // just to follow same template style as brgemm_kernel. using Vmm = typename utils::conditional::value, Xbyak::Zmm, Wmm>::type; - using Vmm_low_t = typename vreg_traits::Vmm_lower_t; + using Vmm_low_t = typename vreg_traits_t::Vmm_lower_t; using po_injector_t = injector::jit_uni_postops_injector_base_t; std::unique_ptr postops_injector_; std::unique_ptr bf16_emu_; @@ -230,6 +232,7 @@ struct jit_brdgmm_kernel_base_t : public jit_generator { const int simd_w_; const int max_vmms_; const bool compute_dst_zp_, compute_src_zp_; + const bool is_src_zp_bcast_; const bool compute_compensation_; // code-path for either s8s8 or src_zp const bool has_vpad_; // vertical padding w.r.t. 
M dimension const bool has_bpad_; // batch pad is computed for the overlap between the @@ -341,7 +344,8 @@ struct jit_brdgmm_kernel_base_t : public jit_generator { void load_b( Vmm vmmb, int n_i, int v_i, bool has_n_tail, bool wei_zp = false); void comp_dot_product(compute_pad_kernel_t kernel_type, Vmm vmm_acc, - Vmm vmmb); // int8 compensation dot_product (zp and s8s8) + Vmm vmmb, int n, + bool is_tail_block); // int8 compensation dot_product (zp and s8s8) void pad_comp_kernel(compute_pad_kernel_t kernel_type, int m_blocks, int n_blocks, int padding, const Xbyak::Reg64 reg_pad, const std::function &get_mi, bool has_tail = false); @@ -360,6 +364,7 @@ struct jit_brdgmm_kernel_base_t : public jit_generator { void apply_post_ops(int m_blocks, int n_blocks, bool has_n_tail); void maybe_transpose_interleaved_vnni_to_plain( int m_blocks, int n_blocks, bool has_n_tail); + void load_src_zp(); void compute_int8_compensation(int m_blocks, int n_blocks, bool has_n_tail); void store_accumulators(int m_blocks, int n_blocks, bool has_n_tail); void store_accumulators_without_post_ops( diff --git a/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp b/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp index 388e8c01742..fef520a0929 100644 --- a/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,7 +26,6 @@ #include "cpu/x64/cpu_isa_traits.hpp" #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" #include "cpu/x64/jit_avx512_core_fp8cvt.hpp" -#include "cpu/x64/jit_generator.hpp" #define GET_OFF(field) offsetof(brgemm_kernel_params_t, field) #define GET_OFF_BATCH_ELEMENT(field) offsetof(brgemm_batch_element_t, field) @@ -39,9 +38,9 @@ namespace x64 { using namespace dnnl::impl::utils; using namespace Xbyak; -struct jit_brgemm_amx_uker_base_t : public jit_generator { +struct jit_brgemm_amx_uker_base_t : public jit_base_brgemm_kernel_t { jit_brgemm_amx_uker_base_t(const brgemm_desc_t &abrg) - : jit_generator(jit_name(), abrg.isa_impl) + : jit_base_brgemm_kernel_t(jit_name(), abrg.isa_impl) , brg(abrg) , postops_injector_(nullptr) { @@ -135,6 +134,8 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator { brgemm_desc_t brg; + const brgemm_desc_t &get_brg() const override { return brg; } + private: using po_injector_t = injector::jit_uni_postops_injector_base_t; std::unique_ptr postops_injector_; @@ -145,8 +146,7 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator { using reg64_t = const Xbyak::Reg64; enum { simd_w = 16, - zmm_width_in_bytes = cpu_isa_traits::vlen, - tile_size = 1024 + zmm_width_in_bytes = cpu_isa_traits_t::vlen, }; // Register decomposition @@ -259,10 +259,13 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator { struct dim_iteration_t { size_t idx = 0; - std ::vector blocks; + std::vector blocks; virtual bool operator==(const dim_iteration_t &rhs) const { return blocks == rhs.blocks; } + virtual bool operator!=(const dim_iteration_t &rhs) const { + return !operator==(rhs); + } size_t pos(size_t b) const { assert(b < blocks.size()); @@ -279,12 +282,12 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator { return blocks[b].block; } - int is_tail(size_t b) const { + bool is_tail(size_t b) const { assert(b < blocks.size()); 
return blocks[b].is_tail; } - int block2() const { return blocks.size(); } + int block2() const { return static_cast(blocks.size()); } int length() const { if (blocks.empty()) return 0; @@ -307,13 +310,20 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator { bd_iteration_t *similar {nullptr}; Label lstart; - virtual bool operator==(const bd_iteration_t &rhs) const { + bool operator==(const dim_iteration_t &_rhs) const override { + // `downcast` will catch a type mismatch in debug mode. + // Note: it supports only a pointer type so far. + const bd_iteration_t &rhs + = *utils::downcast(&_rhs); bool res = dim_iteration_t::operator==(rhs) && A_shift == rhs.A_shift && C_shift == rhs.C_shift && D_shift == rhs.D_shift && bd_mask == rhs.bd_mask && zp_comp_pad_a_shift == rhs.zp_comp_pad_a_shift; return res; } + bool operator!=(const dim_iteration_t &_rhs) const override { + return !operator==(_rhs); + } }; struct bs_iteration_t { @@ -398,6 +408,7 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator { Xbyak::Opmask ld_full_mask = Xbyak::Opmask(2); Xbyak::Opmask ld_tail_mask = Xbyak::Opmask(3); Xbyak::Opmask fp_col_mask = Xbyak::Opmask(4); + Xbyak::Opmask rd_tail_mask = Xbyak::Opmask(5); // Zmm map below const Xbyak::Zmm &zmm_tmp_1() const noexcept { return this->zmm0; } @@ -518,6 +529,10 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator { reg64_t reg_base, size_t offset, reg64_t reg_stride, matrix_kind_t mk); + bool maybe_pre_process_k_tail(brgemm_iteration_t &bi, int bdb, + const Tmm &t1, reg64_t reg_base, size_t offset, reg64_t reg_stride, + matrix_kind_t mk); + void maybe_tileloadd_nt( brgemm_iteration_t &bi, matrix_kind_t mk, int xdb, size_t offset); @@ -709,11 +724,12 @@ size_t jit_brgemm_amx_uker_base_t::B_offset( const auto rdb_B_offset = bi.rdi->pos(0) * brg.rd_block * LDB_size_; - const auto ldb_B_offset = bi.ldi->pos(0) * ld_block_B_size_ * brg.ld_step; + const auto ldb_offs = bi.ldi->pos(ldb) * brg.ld_block; + const auto ldb_B_offset = brg.typesize_B + * ((ldb_offs / brg.LDB) * brg.brgattr.LDB2 + + (ldb_offs % brg.LDB) * brg.rd_step); - return rdb_B_offset + ldb_B_offset - + (brg.is_blocked ? 1 : brg.rd_step) * ldb * ld_block_B_size_ - + bs_offs; + return rdb_B_offset + ldb_B_offset + bs_offs; } size_t jit_brgemm_amx_uker_base_t::C_offset(const brgemm_iteration_t &bi, @@ -721,7 +737,12 @@ size_t jit_brgemm_amx_uker_base_t::C_offset(const brgemm_iteration_t &bi, const auto bi_bd_start = get_out_bd(bi.bdi, 0, 0); const auto bd = get_out_bd(bi.bdi, bdb, inp_bd); const auto bd_shift = bd - (ununroll_bd_loop ? bi_bd_start : 0); - return (size_t)bd_shift * LDC2_size_M_ + (size_t)ldb * LDC2_size_N_; + size_t ldc_elem = (size_t)ldb * brg.ld_block; + size_t bloc_idx = ldc_elem / brg.LDC; + size_t in_block = ldc_elem % brg.LDC; + + return (size_t)bd_shift * LDC2_size_M_ + (size_t)bloc_idx * LDC2_size_N_ + + in_block * brg.typesize_C; } size_t jit_brgemm_amx_uker_base_t::D_offset(const brgemm_iteration_t &bi, @@ -1103,9 +1124,14 @@ void jit_brgemm_amx_uker_base_t::prefetch_CD_range(brgemm_iteration_t &bi, auto ptr_D = EVEX_compress_addr(reg_D, d_offset); uni_prefetch(ptr_D, pft, true); } else if (are_post_ops_applicable_) { - const auto c_offset = C_offset(bi, bdb, bd, ldb_pos); - auto ptr_C = EVEX_compress_addr(reg_C, c_offset); - uni_prefetch(ptr_C, pft, true); + // TODO: split hints C and D hints + // Using prefetchw for the C matrix is generally harmful + // because the C matrix is frequently reused and remains in the cache. 
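// (Editor's aside, not part of the patch.) For context on the trade-off this
// comment describes: prefetchw requests the cache line in exclusive
// (write-intent) state, which pays off for a write-once destination like D
// but can evict C lines that later iterations still re-read. A hedged Xbyak
// sketch of the two hint flavors, assuming a generator and these operands
// are in scope:
//     prefetcht1(ptr[reg_D + d_offset]); // read hint, stage into L2
//     prefetchw(ptr[reg_D + d_offset]);  // write-intent hint (RFO)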
+ // However, it is very necessary for the D matrix + + // const auto c_offset = C_offset(bi, bdb, bd, ldb_pos); + // auto ptr_C = EVEX_compress_addr(reg_C, c_offset); + // uni_prefetch(ptr_C, pft, true); } else { const auto d_offset = D_offset(bi, bdb, bd, ldb_pos); auto ptr_D = EVEX_compress_addr(reg_D, d_offset); @@ -1666,11 +1692,17 @@ void jit_brgemm_amx_uker_base_t::maybe_tileloadd_nt( auto reg_base = is_A ? reg_A : reg_B; auto reg_stride = is_A ? reg_stride_lda : reg_stride_ldb; - if (brg.is_input_convert()) + if (brg.is_input_convert()) { // try_load_nt is not supported in maybe_pre_process_data as there is // no guarantee that the data is cache line aligned. maybe_pre_process_data(bi, t1, reg_base, offset, reg_stride, mk); - else if (load_nt) + return; + } + + if (maybe_pre_process_k_tail(bi, xdb, t1, reg_base, offset, reg_stride, mk)) + return; + + if (load_nt) tileloaddt1(t1, ptr[reg_base + offset + reg_stride]); else tileloadd(t1, ptr[reg_base + offset + reg_stride]); @@ -1771,10 +1803,9 @@ void jit_brgemm_amx_uker_base_t::fp8_to_f16_upconvert(brgemm_iteration_t &bi, assert(max_num_cols > 0); if (col_tail) { - const int tail_mask = (1 << col_tail) - 1; - auto reg_tmp_32 = reg_tmp_gpr.cvt32(); - mov(reg_tmp_32, tail_mask); - kmovd(fp_col_mask, reg_tmp_32); + const auto tail_mask = (static_cast(1) << col_tail) - 1; + mov(reg_tmp_gpr, tail_mask); + kmovq(fp_col_mask, reg_tmp_gpr); } // Note: using the same register used in col_tail, so order is important @@ -1810,10 +1841,9 @@ void jit_brgemm_amx_uker_base_t::bf32_downconvert(brgemm_iteration_t &bi, assert(max_num_cols > 0); if (col_tail) { - const int tail_mask = (1 << col_tail) - 1; - auto reg_tmp_32 = reg_tmp_gpr.cvt32(); - mov(reg_tmp_32, tail_mask); - kmovw(fp_col_mask, reg_tmp_32); + const auto tail_mask = (static_cast(1) << col_tail) - 1; + mov(reg_tmp_gpr, tail_mask); + kmovq(fp_col_mask, reg_tmp_gpr); } // Note: using the same register used in col_tail, so order is important @@ -1898,10 +1928,9 @@ void jit_brgemm_amx_uker_base_t::bf32_downconvert_to_vnni( }; if (col_tail) { - const int tail_mask = (1 << col_tail) - 1; - auto reg_tmp_32 = reg_tmp_gpr.cvt32(); - mov(reg_tmp_32, tail_mask); - kmovw(fp_col_mask, reg_tmp_32); + const auto tail_mask = (static_cast(1) << col_tail) - 1; + mov(reg_tmp_gpr, tail_mask); + kmovq(fp_col_mask, reg_tmp_gpr); } // Note: using the same register used in col_tail, so order is important @@ -1962,12 +1991,12 @@ void jit_brgemm_amx_uker_base_t::maybe_pre_process_data(brgemm_iteration_t &bi, auto &transform_buf = is_A ? transform_buf_map_A_ : transform_buf_map_B_; const auto transform_offset - = use_ils_ ? brg.get_num_C_tiles() * tile_size : 0; + = use_ils_ ? 
brg.get_num_C_tiles() * brgemm_desc_t::tilesize : 0; const auto max_bdb2 = tloop.bdis[0].block2(); const auto max_rdb = tloop.rdis.size(); const auto matrix_a_offset = transform_offset; const auto matrix_b_offset = transform_offset - + tile_size + + brgemm_desc_t::tilesize * (nstl::max(should_save_transform(mk), should_save_transform(matrix_A) * brg.brgattr.max_bs * max_bdb2 * max_rdb)); @@ -1977,7 +2006,7 @@ void jit_brgemm_amx_uker_base_t::maybe_pre_process_data(brgemm_iteration_t &bi, if (transform_buf.find(key) != transform_buf.end()) { auto buf_idx = transform_buf[key]; - auto offt = matrix_offset + buf_idx * tile_size; + auto offt = matrix_offset + buf_idx * brgemm_desc_t::tilesize; tileloadd(t1, ptr[reg_buf + reg_converted_stride + offt]); return; } @@ -1986,7 +2015,7 @@ void jit_brgemm_amx_uker_base_t::maybe_pre_process_data(brgemm_iteration_t &bi, // save offset of the transformation if required. if (should_save_transform(mk)) { auto buf_idx = transform_buf.size(); - buf_offt = matrix_offset + buf_idx * tile_size; + buf_offt = matrix_offset + buf_idx * brgemm_desc_t::tilesize; transform_buf[key] = buf_idx; } @@ -2020,6 +2049,72 @@ void jit_brgemm_amx_uker_base_t::maybe_pre_process_data(brgemm_iteration_t &bi, if (buf_offt) sub(reg_buf, buf_offt); } +bool jit_brgemm_amx_uker_base_t::maybe_pre_process_k_tail( + brgemm_iteration_t &bi, int bdb, const Tmm &t1, reg64_t reg_base, + size_t offset, reg64_t reg_stride, matrix_kind_t mk) { + const auto &tloop = imap_[bi.apply_postops]; + + const auto need_k_tail_processing = mk == matrix_A && brg.amx_wary_k_tail() + && brg.rdb_tail != 0 && bi.bdi->idx == tloop.bdis.size() - 1 + && bdb == bi.bdi->block2() - 1 && bi.last_bsi + && tloop.is_last_rdi(bi.rdi); + + if (!need_k_tail_processing) return false; + + auto transform_offset = brg.get_num_C_tiles() * brgemm_desc_t::tilesize + + brg.get_convert_wsp_buffer_size(); + + if (transform_offset) add(reg_buf, transform_offset); + mov(reg_converted_stride, zmm_width_in_bytes); + + // reuse transformed data from matrix A for ldi > 0 + if (bi.ldi->idx == 0) { + const auto num_rows = palette_.rows[t1.getIdx()]; + const auto num_col_bytes = palette_.cols[t1.getIdx()]; + + const auto max_num_cols + = nstl::min(num_col_bytes / brg.typesize_A, brg.rdb_tail); + const size_t col_tail + = max_num_cols % (zmm_width_in_bytes / brg.typesize_A); + if (col_tail) { + const auto tail_mask = (static_cast(1) << col_tail) - 1; + mov(reg_tmp_gpr, tail_mask); + kmovq(rd_tail_mask, reg_tmp_gpr); + } + auto zmm_1 = zmm_tmp_1(); + auto zmm_1_masked = col_tail ? zmm_1 | rd_tail_mask | T_z : zmm_1; + + assert(max_num_cols > 0); + + const auto reg_data_aux = reg_tmp_gpr; + lea(reg_data_aux, ptr[reg_base + offset]); + + for (int r = 0; r < num_rows; ++r) { + switch (brg.dt_a) { + case data_type::bf16: + case data_type::f16: + vmovdqu16(zmm_1_masked, ptr[reg_data_aux]); + break; + case data_type::f8_e5m2: + case data_type::f8_e4m3: + case data_type::s8: + case data_type::u8: + vmovdqu8(zmm_1_masked, ptr[reg_data_aux]); + break; + default: assert(!"unsupported data type"); + } + vmovups(ptr[reg_buf + r * zmm_width_in_bytes], zmm_1); + add(reg_data_aux, reg_stride); + } + } + // load into tmm from the transformed data. 
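// (Editor's sketch, not from the patch.) Scalar model of what the masked row
// loop above stages: each rd-tail row of A is copied into the scratch buffer
// and the columns beyond the tail are zeroed, which is what the
// `| rd_tail_mask | T_z` masked loads achieve. The mask itself is the usual
// low-bits idiom: (1 << n) - 1 sets exactly the n tail lanes (hence the
// switch to a 64-bit value and kmovq, which keeps the idiom valid for the
// byte-granular masks of 64-byte ZMM loads). Helper names are illustrative.

#include <cstddef>
#include <cstdint>
#include <cstring>

inline uint64_t low_bits_mask(unsigned n) { // valid for 0 < n < 64
    return (uint64_t(1) << n) - 1;
}

void stage_k_tail_rows(const char *A, size_t lda_bytes, size_t num_rows,
        size_t row_bytes, size_t tail_bytes, char *scratch) {
    for (size_t r = 0; r < num_rows; ++r) {
        std::memcpy(scratch + r * row_bytes, A + r * lda_bytes, tail_bytes);
        std::memset(scratch + r * row_bytes + tail_bytes, 0,
                row_bytes - tail_bytes); // zero the masked-out K columns
    }
}

// (The kernel then tile-loads from this staged buffer, as the next line does.)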
+ tileloadd(t1, ptr[reg_buf + reg_converted_stride]); + + // reset buf pointer + if (transform_offset) sub(reg_buf, transform_offset); + return true; +} + void jit_brgemm_amx_uker_base_t::gemm_microkernel_amx(brgemm_iteration_t &bi) { prf0A.reset(); prf1A.reset(); @@ -2064,8 +2159,8 @@ void jit_brgemm_amx_uker_base_t::gemm_microkernel_amx(brgemm_iteration_t &bi) { void jit_brgemm_amx_uker_base_t::rdb_loop(brgemm_iteration_t &bi) { const auto &tloop = imap_[bi.apply_postops]; - for (size_t irdi = 0; irdi < tloop.rdis.size(); irdi++) { - bi.rdi = &(tloop.rdis[irdi]); + for (auto &rdi : tloop.rdis) { + bi.rdi = &rdi; gemm_microkernel_amx(bi); } } @@ -2195,8 +2290,8 @@ void jit_brgemm_amx_uker_base_t::ldb_loop(brgemm_iteration_t &bi) { // we move to next bdb2 block. const auto &tloop = imap_[bi.apply_postops]; transform_buf_map_A_.clear(); - for (size_t ildi = 0; ildi < tloop.ldis.size(); ildi++) { - bi.ldi = &(tloop.ldis[ildi]); + for (auto &ldi : tloop.ldis) { + bi.ldi = &ldi; ldb_loop_body(bi); } } @@ -2206,6 +2301,9 @@ jit_brgemm_amx_uker_base_t::find_similar( const bd_iteration_t *bdi, bool apply_postops) { auto &tloop = imap_[apply_postops]; const auto cidx = bdi->idx; + // if wary_k_tail is true then last iteration is unique + if (brg.amx_wary_k_tail() && cidx == tloop.bdis.size() - 1) return nullptr; + for (size_t i = (actual_ils(apply_postops) ? 1 : 0); i < cidx; i++) { if (*bdi == tloop.bdis[i] && IMPLICATION(actual_ils(apply_postops), @@ -2253,8 +2351,8 @@ void jit_brgemm_amx_uker_base_t::bdb_loop(brgemm_iteration_t &bi) { mov(ptr[rsp + reg_iter_labels_list_offs_], reg_iter_labels_list); } - for (size_t ibdi = 0; ibdi < tloop.bdis.size(); ibdi++) { - bi.bdi = &(tloop.bdis[ibdi]); + for (auto &bdi : tloop.bdis) { + bi.bdi = &bdi; bdb_loop_body(bi); } if (ununroll_bd_loop) { @@ -2263,8 +2361,8 @@ void jit_brgemm_amx_uker_base_t::bdb_loop(brgemm_iteration_t &bi) { align(64); L(iteration_pointers); - for (size_t ibdi = 0; ibdi < tloop.bdis.size(); ibdi++) { - putL(tloop.bdis[ibdi].lstart); + for (const auto &bdi : tloop.bdis) { + putL(bdi.lstart); } putL(loop_end); L(loop_end); @@ -2326,11 +2424,9 @@ void jit_brgemm_amx_uker_base_t::fill_imap() { auto abdb = bdb + ibdb; if (abdb >= brg.bdb) break; if (brg.bdb_tail && abdb == brg.bdb - 1) - bdi.blocks.emplace_back( - iteration_block_t(bdi_pos, brg.bdb_tail, true)); + bdi.blocks.emplace_back(bdi_pos, brg.bdb_tail, true); else - bdi.blocks.emplace_back( - iteration_block_t(bdi_pos, brg.bd_block, false)); + bdi.blocks.emplace_back(bdi_pos, brg.bd_block, false); bdi_pos += brg.bd_block; if (bdi_pos >= brg.bcast_dim) break; bdi_pos = skipped_bd_mask(bdi_pos); @@ -2371,11 +2467,9 @@ void jit_brgemm_amx_uker_base_t::fill_imap() { auto aldb = ldb + ildb; if (aldb >= brg.ldb) break; if (brg.ldb_tail && aldb == brg.ldb - 1) - ldi.blocks.emplace_back( - iteration_block_t(ldi_pos, brg.ldb_tail, true)); + ldi.blocks.emplace_back(ldi_pos, brg.ldb_tail, true); else - ldi.blocks.emplace_back( - iteration_block_t(ldi_pos, brg.ld_block, false)); + ldi.blocks.emplace_back(ldi_pos, brg.ld_block, false); ldi_pos++; } ldi.idx = tloop.ldis.size(); @@ -2387,15 +2481,14 @@ void jit_brgemm_amx_uker_base_t::fill_imap() { rdi.blocks.reserve(1); for (int rdb = 0; rdb < brg.rdb; rdb++) { rdi.blocks.clear(); - rdi.blocks.emplace_back(iteration_block_t(rdi_pos, brg.rd_block)); + rdi.blocks.emplace_back(rdi_pos, brg.rd_block); rdi.idx = tloop.rdis.size(); tloop.rdis.push_back(rdi); rdi_pos++; } if (brg.rdb_tail > 0) { rdi.blocks.clear(); - rdi.blocks.emplace_back( - 
iteration_block_t(rdi_pos, brg.rdb_tail, true)); + rdi.blocks.emplace_back(rdi_pos, brg.rdb_tail, true); rdi.idx = tloop.rdis.size(); tloop.rdis.push_back(rdi); } @@ -2571,23 +2664,17 @@ void jit_brgemm_amx_uker_base_t::generate() { && brg.brgattr.bd_mask_level == 0; need_to_apply_alpha_beta_ = (brg.beta != 0.f && !may_load_accumulators_) || brg.alpha != 1.f; - const bool has_zero_points = !everyone_is(brgemm_broadcast_t::none, - brg.zp_type_a, brg.zp_type_b, brg.zp_type_c); - are_post_ops_applicable_ = one_of(true, brg.with_eltwise, brg.with_binary, - brg.with_scales, brg.with_bias, brg.with_sum, brg.dt_d != brg.dt_c, - has_zero_points, brg.with_dst_scales); - - // second level blocking eligible only if we don't use store by vectors for now - assert(IMPLICATION(are_post_ops_applicable_ || need_to_apply_alpha_beta_ - || brg.brgattr.bd_mask_level, - !brg.is_blocked && !brg.brgattr.var_bs)); + are_post_ops_applicable_ = brg.are_post_ops_applicable(); + + assert(IMPLICATION(brg.brgattr.LDB2 == 0, brg.load_dim <= brg.LDB)); + assert(IMPLICATION(brg.brgattr.var_bs, IMPLICATION(brg.is_input_convert(), brg.is_fp8_via_convert()))); read_params(); prepare_bd_mask(); Label permute_index_table; - if (brg.is_input_convert()) { + if (brg.is_input_convert() || brg.amx_wary_k_tail()) { // save tiles description for later use brgemm_init_tiles(brg, (char *)(&palette_)); // load permute indices @@ -2669,7 +2756,7 @@ void brgemm_amx_uker_t::operator()(brgemm_kernel_params_t *params) const { (*brgemm_kernel_)(params); } -const jit_generator *brgemm_amx_uker_t::get_jit_generator() const { +const jit_generator_t *brgemm_amx_uker_t::get_jit_generator() const { return brgemm_kernel_; } diff --git a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp index 0d26602aafd..d81e3f959db 100644 --- a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ #include "common/utils.hpp" #include "cpu/platform.hpp" +#include "cpu/x64/brgemm/brgemm.hpp" #include "cpu/x64/brgemm/brgemm_types.hpp" #include "cpu/x64/cpu_barrier.hpp" #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" @@ -40,21 +41,18 @@ namespace x64 { using namespace dnnl::impl::utils; using namespace Xbyak; template -struct jit_brgemm_kernel_t : public jit_generator { +struct jit_brgemm_kernel_t : public jit_base_brgemm_kernel_t { jit_brgemm_kernel_t(const brgemm_desc_t &abrg) - : jit_generator(jit_name(), abrg.isa_impl) + : jit_base_brgemm_kernel_t(jit_name(), abrg.isa_impl) , brg(abrg) , postops_injector_(nullptr) - , max_effective_vregs(isa_num_vregs(brg.isa_impl) - - (brg.is_int8 && !brg.has_int8_vnni - ? 2 - : (brg.is_fp8_via_convert() ? 5 : 0))) { + , max_effective_vregs(get_max_effective_vregs(brg)) { // The implementation uses is_superset(), is_subset() utilities. // So avoid isa_all, isa_undef in these comparisions. assert(!utils::one_of(brg.isa_impl, isa_all, isa_undef)); - const int is_ldb2_tail = brg.ldb2_tail ? 1 : 0; - const int is_ldb_tail = brg.ldb_tail ? 1 : 0; + const dim_t is_ldb2_tail = brg.ldb2_tail ? 1 : 0; + const dim_t is_ldb_tail = brg.ldb_tail ? 
1 : 0; is_ldb_loop_ = brg.ldb2 + is_ldb2_tail + is_ldb_tail > 1; bool has_f8_e5m2_binary_postops = false; @@ -82,15 +80,15 @@ struct jit_brgemm_kernel_t : public jit_generator { // 'fp8_to_f16_upconvert()' param and would collision with these // emulation vmms f8_e5m2_emulator_ = utils::make_unique( - this, xmm_fp8_emu_aux1, xmm_fp8_emu_aux2, - xmm_fp8_emu_aux3, kmask_fp8_aux, reg64_fp8_aux); + this, vmm_fp8_emu_aux1(), vmm_fp8_emu_aux2(), + vmm_fp8_emu_aux3(), kmask_fp8_aux, reg64_fp8_aux); if (one_of(data_type::f8_e4m3, brg.dt_a, brg.dt_b, brg.dt_c, brg.dt_d) || has_f8_e4m3_binary_postops) f8_e4m3_emulator_ = utils::make_unique( - this, xmm_fp8_emu_aux1, xmm_fp8_emu_aux2, - xmm_fp8_emu_aux3, xmm_fp8_emu_aux4, xmm_fp8_emu_aux5, - reg64_fp8_aux); + this, vmm_fp8_emu_aux1(), vmm_fp8_emu_aux2(), + vmm_fp8_emu_aux3(), vmm_fp8_emu_aux4(), + vmm_fp8_emu_aux5(), reg64_fp8_aux); } if (brg.with_eltwise || brg.with_binary || brg.with_sum) { @@ -131,16 +129,18 @@ struct jit_brgemm_kernel_t : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_kernel_t) - brgemm_desc_t brg; + const brgemm_desc_t &get_brg() const override { return brg; } private: + brgemm_desc_t brg; + enum matrix_kind_t { matrix_A, matrix_B }; static constexpr int zmm_width_in_bytes_ - = cpu_isa_traits::vlen; + = cpu_isa_traits_t::vlen; using Vmm = typename utils::conditional::value, Xbyak::Zmm, Wmm>::type; - using Vmm_lower_t = typename vreg_traits::Vmm_lower_t; + using Vmm_lower_t = typename vreg_traits_t::Vmm_lower_t; using po_injector_t = injector::jit_uni_postops_injector_base_t; std::unique_ptr postops_injector_; std::unique_ptr bf16_emu_; @@ -149,6 +149,8 @@ struct jit_brgemm_kernel_t : public jit_generator { Xbyak::Label avx_tail_mask_; Xbyak::Label sum_zp_scale_data_; + Xbyak::Label f16_perm_even_table_; + Xbyak::Label f16_perm_odd_table_; using reg64_t = const Xbyak::Reg64; // Register decomposition @@ -198,6 +200,13 @@ struct jit_brgemm_kernel_t : public jit_generator { const reg64_t reg_aux_zp_comp_b = reg_rdb_loop; const reg64_t reg_zp_c_values = reg_rdb_loop; const reg64_t reg_aux_zp_c_values = reg_rdb_loop; + const reg64_t reg_wei_scales = reg_rdb_loop; + const reg64_t reg_aux_wei_scales = reg_rdb_loop; + const reg64_t reg_wei_zp = reg_rdb_loop; + const reg64_t reg_aux_wei_zp = reg_rdb_loop; + const reg64_t reg_ic = reg_rdb_loop; + const reg64_t reg_src_scales = reg_rdb_loop; + const reg64_t reg_src_grouped_sum = reg_rdb_loop; const reg64_t reg_tmp_read_values = reg_rdb_loop; const reg64_t reg_aux_scales = reg_aux_B; @@ -262,10 +271,27 @@ struct jit_brgemm_kernel_t : public jit_generator { constexpr static int reg_aux_D_backup_offs_ = 232; constexpr static int reg_aux_D_bdb_loop_backup_offs_ = 240; constexpr static int reg_aux_D_bdb_loop_shift_offs_ = 248; + constexpr static int reg_wei_scales_offs_ = 256; + constexpr static int reg_aux_wei_scales_offs_ = 264; + constexpr static int reg_wei_zero_points_offs_ = 272; + constexpr static int reg_aux_wei_zero_points_offs_ = 280; + constexpr static int reg_ic_offs_ = 288; + constexpr static int reg_aux2_D_offs_ = 296; + constexpr static int reg_aux2_wei_scales_offs_ = 304; + constexpr static int reg_aux2_wei_zero_points_offs_ = 312; + constexpr static int reg_aux_ic_offs_ = 320; + constexpr static int reg_reg_a_offset_offs_ = 328; + constexpr static int reg_src_scales_offs_ = 336; + constexpr static int reg_aux_src_scales_offs_ = 344; + constexpr static int reg_aux2_src_scales_offs_ = 352; + constexpr static int reg_src_grouped_sum_offs_ = 360; + constexpr static 
int reg_aux_src_grouped_sum_offs_ = 368; + constexpr static int reg_aux2_src_grouped_sum_offs_ = 376; // these are used for FP8 as temporary push/pop spaces - constexpr static int reg_val_tmp_1_ = 256; - constexpr static int reg_val_tmp_2_ = 264; - constexpr static int stack_space_needed_ = 272; + constexpr static int reg_val_tmp_1_ = 384; + constexpr static int reg_val_tmp_2_ = 392; + constexpr static int stack_space_needed_ = 400; + bool is_ldb_loop_ = false; bool with_binary_non_scalar_bcast_ = false; @@ -275,14 +301,46 @@ struct jit_brgemm_kernel_t : public jit_generator { Xbyak::Opmask ld_tail_mask = Xbyak::Opmask(3); Xbyak::Opmask fp8_col_mask = Xbyak::Opmask(4); Xbyak::Opmask kmask_fp8_aux = Xbyak::Opmask(5); + Xbyak::Opmask rd_tail_mask = Xbyak::Opmask(6); + + static int get_max_effective_vregs(const brgemm_desc_t &brg) { + auto used_vregs = 0; + if (brg.is_int8 && !brg.has_int8_vnni) + used_vregs = 2; + else if (brg.is_fp8_via_convert()) + used_vregs = 5; + else if (brg.is_f16_b_non_amx_vnni()) + used_vregs = 2; + + if (one_of(brg.dt_b, data_type::nf4) && brg.isa_impl == avx2) { + used_vregs += 5; + } + + if (one_of(brg.dt_b, data_type::f4_e2m1) && brg.isa_impl == avx2) { + used_vregs += 2; + } + + if (one_of(brg.dt_b, data_type::nf4, data_type::f4_e2m1) && brg.isa_impl != avx2) { + used_vregs += 1; + } + + if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride == 0 && !brg.with_src_dyn_quant) { + used_vregs += 1; + } - Vmm accm(int ld_block, int bd, int ld) { + if (brg.with_src_dyn_quant) { + used_vregs += 1; + } + return isa_num_vregs(brg.isa_impl) - used_vregs; + } + + Vmm accm(dim_t ld_block, dim_t bd, dim_t ld) { return Vmm(max_effective_vregs - 1 - (bd * ld_block + ld)); } - Vmm bcst(int bd = 0) { - if (n_bcast_1_load) { - int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block) + Vmm bcst(dim_t bd = 0) { + if (brg.n_bcast_1_load) { + dim_t idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block) - bd; assert(idx > 0); return Vmm(idx); @@ -290,18 +348,18 @@ struct jit_brgemm_kernel_t : public jit_generator { return Vmm(0); } - Vmm load(int ld = 0) { - if (n_bcast_1_load) { + Vmm load(dim_t ld = 0) { + if (brg.n_bcast_1_load) { return Vmm(0); } else { - int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block) + dim_t idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block) - ld; assert(idx > 0); return Vmm(idx); } } - Vmm vmm_tmp(int i) { + Vmm vmm_tmp(dim_t i) { assert(IMPLICATION(!brg.is_tmm, i >= 0 && i < max_effective_vregs @@ -310,6 +368,10 @@ struct jit_brgemm_kernel_t : public jit_generator { } Vmm vmm_tail_mask() { return vmm_tmp(1); } + Vmm vmm_beta() { return vmm_tmp(1); } + Vmm vmm_lbound() { return vmm_tmp(1); } + Vmm vmm_ubound() { return vmm_tmp(0); } + Vmm vmm_one_bytes() const noexcept { return Vmm(3); } Vmm vmm_zp_a_shift() const noexcept { return Vmm(2); } Vmm vmm_inp_shift() const noexcept { return Vmm(1); } @@ -322,11 +384,13 @@ struct jit_brgemm_kernel_t : public jit_generator { // note: zmm reserv_5 is not necessary since it's only used for 'vdpbf16ps' // fp8 emulation convert - Vmm xmm_fp8_emu_aux1 = Vmm(1); - Vmm xmm_fp8_emu_aux2 = Vmm(2); - Vmm xmm_fp8_emu_aux3 = Vmm(3); - Vmm xmm_fp8_emu_aux4 = Vmm(4); - Vmm xmm_fp8_emu_aux5 = Vmm(5); + Vmm vmm_fp8_emu_aux1() const noexcept { return Vmm(1); } + Vmm vmm_fp8_emu_aux2() const noexcept { return Vmm(2); } + Vmm vmm_fp8_emu_aux3() const noexcept { return Vmm(3); } + Vmm vmm_fp8_emu_aux4() const noexcept { return Vmm(4); } + Vmm vmm_fp8_emu_aux5() const 
noexcept { return Vmm(5); } + + Zmm zmm_tmp_1() const noexcept { return Zmm(1); } // Required in every dot product for INT8 non-VNNI computation. Vmm int8_ones_words() const noexcept { @@ -336,6 +400,13 @@ struct jit_brgemm_kernel_t : public jit_generator { return Vmm(isa_num_vregs(brg.isa_impl) - 2); } + Vmm f16_perm_even_vreg() const noexcept { + return Vmm(isa_num_vregs(brg.isa_impl) - 1); + } + Vmm f16_perm_odd_vreg() const noexcept { + return Vmm(isa_num_vregs(brg.isa_impl) - 2); + } + Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store, Xbyak::Opmask ktail_mask) const; Vmm_lower_t vmm_lower_mask(const Vmm_lower_t vmm_lower_in, bool mask_flag, @@ -344,263 +415,287 @@ struct jit_brgemm_kernel_t : public jit_generator { void cvt2ps(data_type_t type_in, const Vmm vmm_in, const Xbyak::Operand &op, bool mask_flag, bool store, Xbyak::Opmask ktail_mask, - int tail_size); + dim_t tail_size); void advance_ldb_post_op_regs(); - void restore_ldb_post_op_regs(int ld_block2); - void advance_bdb_post_op_regs(int adj_bd_block); - void restore_bdb_post_op_regs(int bd_block2); - void ldb_regs_shift(int ld_block2, bool is_tail = false); - void advance_bd_block2_post_op_regs(int bd_block2); + void restore_ldb_post_op_regs(dim_t ld_block2); + void advance_bdb_post_op_regs(dim_t adj_bd_block); + void restore_bdb_post_op_regs(dim_t bd_block2); + void ldb_regs_shift(dim_t ld_block2, bool is_tail = false); + void advance_bd_block2_post_op_regs(dim_t bd_block2); void copy_post_ops_stack_values_to_aux(bool is_reg_tail); void read_params(); - void zero_accumulators(int bd_block2, bool is_bdb_tail, int ld_block, + void zero_accumulators(dim_t bd_block2, bool is_bdb_tail, dim_t ld_block, bool is_ld_tail, bool skip_accumulation); - void fp8_to_f16_upconvert(int num_rows, int tile_num_col_bytes, - reg64_t reg_base, int offset, reg64_t reg_data_stride, + void fp8_to_f16_upconvert(dim_t num_rows, dim_t tile_num_col_bytes, + reg64_t reg_base, dim_t offset, reg64_t reg_data_stride, data_type_t dt, bool is_rd_tail); - void fp8_to_f16_upconvert_to_vnni(int num_rows, int tile_num_col_bytes, - reg64_t reg_base, int offset, reg64_t reg_data_stride, + void fp8_to_f16_upconvert_to_vnni(dim_t num_rows, dim_t tile_num_col_bytes, + reg64_t reg_base, dim_t offset, reg64_t reg_data_stride, data_type_t dt, bool is_rd_tail); - void store_accumulators(int bd_block2, bool is_bdb_tail, int ld_block, + void store_accumulators(dim_t bd_block2, bool is_bdb_tail, dim_t ld_block, bool is_ld_tail, bool skip_accumulation); void store_accumulators_without_post_ops( - int bd_block, int ld_block, bool is_ld_tail); - void store_accumulators_apply_post_ops(int bd_block, int ld_block, - int ldb_and_bdb_offset, bool is_ld_tail); - void apply_compensation(int bd_block, int ld_block, bool is_ld_tail); - void apply_alpha_beta(int bd_block, int ld_block, bool is_ld_tail); - void apply_post_ops(int bd_block, int ld_block2, int ldb_and_bdb_offset, - bool is_ld_tail); + dim_t bd_block, dim_t ld_block, bool is_ld_tail); + void store_accumulators_apply_post_ops(dim_t bd_block, dim_t ld_block, + dim_t ldb_and_bdb_offset, bool is_ld_tail); + void apply_compensation(dim_t bd_block, dim_t ld_block, bool is_ld_tail); + void apply_alpha_beta(dim_t bd_block, dim_t ld_block, bool is_ld_tail); + void apply_post_ops(dim_t bd_block, dim_t ld_block2, + dim_t ldb_and_bdb_offset, bool is_ld_tail); void restore_A_B_matrices(); void set_A_B_matrices(); - void compute_int8_compensation(int rd_loop, int bd_b, int bd_e, - int bd_block, int ld_block2, bool is_ld_tail, int 
vpad); + void compute_int8_compensation(dim_t rd_loop, dim_t bd_b, dim_t bd_e, + dim_t bd_block, dim_t ld_block2, bool is_ld_tail, dim_t vpad); void maybe_pre_process_data(matrix_kind_t matrix_kind, const Tmm &t1, - reg64_t reg_base, size_t offset, reg64_t reg_stride, int num_rows, - int num_col_bytes, bool is_rd_tail); - void maybe_tileloadd_nt(matrix_kind_t matrix_kind, int idx, int offset, - bool is_rd_tail, bool is_tail); + reg64_t reg_base, dim_t offset, reg64_t reg_stride, dim_t num_rows, + dim_t num_col_bytes, bool is_rd_tail); + bool maybe_pre_process_k_tail(bool last_bdb, bool is_rd_tail, const Tmm &t1, + reg64_t reg_base, dim_t offset, reg64_t reg_stride, + matrix_kind_t mk); + void maybe_tileloadd_nt(matrix_kind_t matrix_kind, dim_t idx, dim_t offset, + bool is_rd_tail, bool is_tail, bool last_bdb); void dot_product(Vmm v1, Vmm v2, Vmm v3); - void gemm_microkernel(int bd_block2, bool is_bdb_tail, int ld_block, - bool is_rd_tail, bool is_ld_tail, int vpad, int rows_for_rd_tail); - void gemm_microkernel_amx(int bd_block2, bool is_bdb_tail, int ld_block, - bool is_rd_tail, bool is_ld_tail); - - void ldb_loop(int bd_block2, bool is_bdb_tail, int ld_block, - int ldb_loop_length, bool is_reg_tail, bool is_ld_tail, - bool check_top_vpad, bool check_bottom_vpad, int rows_for_rd_tail, + void gemm_microkernel(dim_t bd_block2, bool is_bdb_tail, dim_t ld_block, + bool is_rd_tail, bool is_ld_tail, dim_t vpad, + dim_t rows_for_rd_tail); + void gemm_microkernel_amx(dim_t bd_block2, bool is_bdb_tail, + dim_t ld_block2, bool is_rd_tail, bool is_ld_tail, bool last_bdb); + void gemm_microkernel_dyn_quant(dim_t bd_block2, bool is_bdb_tail, dim_t ld_block, + bool is_rd_tail, bool is_ld_tail, dim_t vpad, dim_t rows_for_rd_tail); + + void ldb_loop(dim_t bd_block2, bool is_bdb_tail, dim_t ld_block, + dim_t ldb_loop_length, bool is_reg_tail, bool is_ld_tail, + bool first_bdb, bool last_bdb, dim_t rows_for_rd_tail, bool skip_accumulation); void bdb_loop(); void generate() override; - int A_offset(int bd, int rd, bool is_amx = false) const noexcept; - int B_offset(int ld, int rd, bool is_amx = false) const noexcept; - int C_offset(int bd, int ld) const noexcept; - int D_offset(int bd, int ld) const noexcept; - - int rdb_A_offset() const noexcept; - int rdb_B_offset() const noexcept; - - int ldb_B_offset(int ld_block2, bool is_tail = false) const noexcept; - int ldb_C_offset(int ld_block2, bool is_tail = false) const noexcept; - int ldb_D_offset(int ld_block2, bool is_tail = false) const noexcept; - int ldb_po_offset(int ld_block2, bool is_tail = false) const noexcept; - - int bdb_A_offset(int bd_block2) const noexcept; - int bdb_C_offset(int bd_block2) const noexcept; - int bdb_D_offset(int bd_block2) const noexcept; - int bdb_po_offset(int bd_block2) const noexcept; - - int bias_offset(int ld, bool is_tail = false) const noexcept; - int oc_logical_offset(int ld, bool is_tail = false) const noexcept; - - int compensations_offset(int ld, bool is_tail = false) const noexcept; - int bdb_compensation_offset(int bd_block2) const noexcept; - int bd_compensation_offset(int ld, int bd) const noexcept; - int scales_offset(int ld, bool is_tail = false) const noexcept; - int zp_comp_a_offset(int ld, bool is_tail = false) const noexcept; - int bd_zp_comp_a_offset(int ld, int bd) const noexcept; - int bdb_zp_comp_a_offset(int bd_block2) const noexcept; - int zp_comp_b_offset(int bd) const noexcept; - int bdb_zp_comp_b_offset(int bd_block2) const noexcept; - int zp_c_values_offset(int ld, bool is_tail = false) const 
noexcept; - - bool n_bcast_1_load = false; + dim_t A_offset(dim_t bd, dim_t rd, bool is_amx = false) const noexcept; + dim_t B_offset(dim_t ld, dim_t rd, bool is_amx = false) const noexcept; + dim_t C_offset(dim_t bd, dim_t ld) const noexcept; + dim_t D_offset(dim_t bd, dim_t ld) const noexcept; + + dim_t rdb_A_offset() const noexcept; + dim_t rdb_B_offset() const noexcept; + + dim_t ldb_B_offset(dim_t ld_block2, bool is_tail = false) const noexcept; + dim_t ldb_C_offset(dim_t ld_block2, bool is_tail = false) const noexcept; + dim_t ldb_D_offset(dim_t ld_block2, bool is_tail = false) const noexcept; + dim_t ldb_po_offset(dim_t ld_block2, bool is_tail = false) const noexcept; + + dim_t bdb_A_offset(dim_t bd_block2) const noexcept; + dim_t bdb_C_offset(dim_t bd_block2) const noexcept; + dim_t bdb_D_offset(dim_t bd_block2) const noexcept; + dim_t bdb_po_offset(dim_t bd_block2) const noexcept; + + dim_t bias_offset(dim_t ld, bool is_tail = false) const noexcept; + dim_t oc_logical_offset(dim_t ld, bool is_tail = false) const noexcept; + + dim_t compensations_offset(dim_t ld, bool is_tail = false) const noexcept; + dim_t bdb_compensation_offset(dim_t bd_block2) const noexcept; + dim_t bd_compensation_offset(dim_t ld, dim_t bd) const noexcept; + dim_t scales_offset(dim_t ld, bool is_tail = false) const noexcept; + dim_t zp_comp_a_offset(dim_t ld, bool is_tail = false) const noexcept; + dim_t bd_zp_comp_a_offset(dim_t ld, dim_t bd) const noexcept; + dim_t bdb_zp_comp_a_offset(dim_t bd_block2) const noexcept; + dim_t zp_comp_b_offset(dim_t bd) const noexcept; + dim_t bdb_zp_comp_b_offset(dim_t bd_block2) const noexcept; + dim_t zp_c_values_offset(dim_t ld, bool is_tail = false) const noexcept; + dim_t wei_scales_offset(dim_t ld, bool is_tail = false) const noexcept; + dim_t wei_zp_offset(dim_t ld, bool is_tail = false) const noexcept; bool vpad_exist = false; bool need_comp_pads = false; + palette_config_t palette_; }; template -int jit_brgemm_kernel_t::A_offset( - int bd, int rd, bool is_amx) const noexcept { +dim_t jit_brgemm_kernel_t::A_offset( + dim_t bd, dim_t rd, bool is_amx) const noexcept { return (is_amx) ? brg.typesize_A * (bd * brg.bd_block * brg.LDA) : brg.typesize_A * (bd * brg.LDA + rd); } template -int jit_brgemm_kernel_t::B_offset( - int ld, int rd, bool is_amx) const noexcept { +dim_t jit_brgemm_kernel_t::B_offset( + dim_t ld, dim_t rd, bool is_amx) const noexcept { + int typesize_scale = one_of(brg.dt_b, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 2 : 1; if (is_amx) { - return brg.typesize_B * (brg.rd_step * ld * brg.ld_block); + return brg.typesize_B * (brg.rd_step * ld * brg.ld_block) / typesize_scale; } else { - const int data_vnni_granularity = brg.ld_step; - const int rdb0 = rd / data_vnni_granularity; + const dim_t rdb0 = rd / brg.ld_step; // Note: Offsets for elements within vnni_granularity are expected to be // handled within gemm_microkernel (for ex: odd-even converts). - // hence no `rd % data_vnni_granularity` + // hence no `rd % brg.ld_step` return brg.typesize_B - * (rdb0 * data_vnni_granularity * brg.LDB - + data_vnni_granularity * ld * brg.ld_block); + * (rdb0 * brg.ld_step * brg.LDB + + brg.ld_step * ld * brg.ld_block) / typesize_scale; } } template -int jit_brgemm_kernel_t::C_offset(int bd, int ld) const noexcept { +dim_t jit_brgemm_kernel_t::C_offset(dim_t bd, dim_t ld) const noexcept { const auto bd_shift = brg.is_runtime_ldc ? 
0 : bd * brg.LDC; return brg.typesize_C * (bd_shift + ld * brg.ld_block); } template -int jit_brgemm_kernel_t::D_offset(int bd, int ld) const noexcept { +dim_t jit_brgemm_kernel_t::D_offset(dim_t bd, dim_t ld) const noexcept { const auto bd_shift = brg.is_runtime_ldd ? 0 : bd * brg.LDD; return brg.typesize_D * (bd_shift + ld * brg.ld_block); } template -int jit_brgemm_kernel_t::rdb_A_offset() const noexcept { +dim_t jit_brgemm_kernel_t::rdb_A_offset() const noexcept { return brg.typesize_A * brg.rd_block; } template -int jit_brgemm_kernel_t::rdb_B_offset() const noexcept { - return brg.typesize_B * brg.rd_block * brg.LDB; +dim_t jit_brgemm_kernel_t::rdb_B_offset() const noexcept { + int typesize_scale = one_of(brg.dt_b, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 2 : 1; + return brg.typesize_B * brg.rd_block * brg.LDB / typesize_scale; } template -int jit_brgemm_kernel_t::ldb_B_offset( - int ld_block2, bool is_tail) const noexcept { - return (is_tail) ? brg.typesize_B * brg.ldb_tail * brg.ld_step - : brg.typesize_B * ld_block2 * brg.ld_block * brg.ld_step; +dim_t jit_brgemm_kernel_t::ldb_B_offset( + dim_t ld_block2, bool is_tail) const noexcept { + int typesize_scale = one_of(brg.dt_b, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 2 : 1; + return (is_tail) ? brg.typesize_B * brg.ldb_tail * brg.ld_step / typesize_scale + : brg.typesize_B * ld_block2 * brg.ld_block * brg.ld_step / typesize_scale; } template -int jit_brgemm_kernel_t::ldb_C_offset( - int ld_block2, bool is_tail) const noexcept { +dim_t jit_brgemm_kernel_t::ldb_C_offset( + dim_t ld_block2, bool is_tail) const noexcept { return (is_tail) ? brg.typesize_C * brg.ldb_tail : brg.typesize_C * ld_block2 * brg.ld_block; } template -int jit_brgemm_kernel_t::ldb_D_offset( - int ld_block2, bool is_tail) const noexcept { +dim_t jit_brgemm_kernel_t::ldb_D_offset( + dim_t ld_block2, bool is_tail) const noexcept { return (is_tail) ? brg.typesize_D * brg.ldb_tail : brg.typesize_D * ld_block2 * brg.ld_block; } template -int jit_brgemm_kernel_t::ldb_po_offset( - int ld_block2, bool is_tail) const noexcept { +dim_t jit_brgemm_kernel_t::ldb_po_offset( + dim_t ld_block2, bool is_tail) const noexcept { return (is_tail) ? brg.ldb_tail : ld_block2 * brg.ld_block; } template -int jit_brgemm_kernel_t::bdb_A_offset(int bd_block2) const noexcept { +dim_t jit_brgemm_kernel_t::bdb_A_offset(dim_t bd_block2) const noexcept { return brg.typesize_A * bd_block2 * brg.bd_block * brg.LDA; } template -int jit_brgemm_kernel_t::bdb_C_offset(int bd_block2) const noexcept { +dim_t jit_brgemm_kernel_t::bdb_C_offset(dim_t bd_block2) const noexcept { return bd_block2 * brg.bd_block * (brg.is_runtime_ldc ? 1 : brg.typesize_C * brg.LDC); } template -int jit_brgemm_kernel_t::bdb_D_offset(int bd_block2) const noexcept { +dim_t jit_brgemm_kernel_t::bdb_D_offset(dim_t bd_block2) const noexcept { return bd_block2 * brg.bd_block * (brg.is_runtime_ldd ? 1 : brg.typesize_D * brg.LDD); } template -int jit_brgemm_kernel_t::bdb_po_offset(int bd_block2) const noexcept { +dim_t jit_brgemm_kernel_t::bdb_po_offset(dim_t bd_block2) const noexcept { return bd_block2 * brg.bd_block * brg.LDD; } template -int jit_brgemm_kernel_t::bias_offset(int ld, bool is_tail) const noexcept { +dim_t jit_brgemm_kernel_t::bias_offset( + dim_t ld, bool is_tail) const noexcept { return (is_tail) ? 
brg.typesize_bias * brg.ldb_tail : brg.typesize_bias * ld * brg.ld_block; } template -int jit_brgemm_kernel_t::oc_logical_offset( - int ld, bool is_tail) const noexcept { +dim_t jit_brgemm_kernel_t::oc_logical_offset( + dim_t ld, bool is_tail) const noexcept { return (is_tail) ? brg.ldb_tail : ld * brg.ld_block; } template -int jit_brgemm_kernel_t::compensations_offset( - int ld, bool is_tail) const noexcept { +dim_t jit_brgemm_kernel_t::compensations_offset( + dim_t ld, bool is_tail) const noexcept { return (is_tail) ? sizeof(int32_t) * brg.ldb_tail : sizeof(int32_t) * ld * brg.ld_block; } template -int jit_brgemm_kernel_t::bdb_compensation_offset( - int bd_block2) const noexcept { +dim_t jit_brgemm_kernel_t::bdb_compensation_offset( + dim_t bd_block2) const noexcept { return sizeof(int32_t) * bd_block2 * brg.bd_block * brg.LDB; } template -int jit_brgemm_kernel_t::bd_compensation_offset( - int ld, int bd) const noexcept { +dim_t jit_brgemm_kernel_t::bd_compensation_offset( + dim_t ld, dim_t bd) const noexcept { return sizeof(int32_t) * (ld * brg.ld_block + bd * brg.LDB); } template -int jit_brgemm_kernel_t::scales_offset( - int ld, bool is_tail) const noexcept { +dim_t jit_brgemm_kernel_t::scales_offset( + dim_t ld, bool is_tail) const noexcept { return (is_tail) ? brg.is_oc_scale * sizeof(float) * brg.ldb_tail : brg.is_oc_scale * sizeof(float) * ld * brg.ld_block; } template -int jit_brgemm_kernel_t::zp_comp_a_offset( - int ld, bool is_tail) const noexcept { +dim_t jit_brgemm_kernel_t::zp_comp_a_offset( + dim_t ld, bool is_tail) const noexcept { return (is_tail) ? sizeof(int32_t) * brg.ldb_tail : sizeof(int32_t) * ld * brg.ld_block; } template -int jit_brgemm_kernel_t::bdb_zp_comp_a_offset( - int bd_block2) const noexcept { +dim_t jit_brgemm_kernel_t::wei_scales_offset( + dim_t ld, bool is_tail) const noexcept { + return (is_tail) ? types::data_type_size(brg.wei_decomp_scales_dt) * brg.ldb_tail + : types::data_type_size(brg.wei_decomp_scales_dt) * ld * brg.ld_block; +} + +template +dim_t jit_brgemm_kernel_t::wei_zp_offset( + dim_t ld, bool is_tail) const noexcept { + return (is_tail) ? types::data_type_size(brg.wei_decomp_zero_points_dt) * brg.ldb_tail + : types::data_type_size(brg.wei_decomp_zero_points_dt) * ld * brg.ld_block; +} + +template +dim_t jit_brgemm_kernel_t::bdb_zp_comp_a_offset( + dim_t bd_block2) const noexcept { return sizeof(int32_t) * bd_block2 * brg.bd_block * brg.LDB; } template -int jit_brgemm_kernel_t::bd_zp_comp_a_offset( - int ld, int bd) const noexcept { +dim_t jit_brgemm_kernel_t::bd_zp_comp_a_offset( + dim_t ld, dim_t bd) const noexcept { return sizeof(int32_t) * (ld * brg.ld_block + bd * brg.LDB); } template -int jit_brgemm_kernel_t::zp_comp_b_offset(int bd) const noexcept { +dim_t jit_brgemm_kernel_t::zp_comp_b_offset(dim_t bd) const noexcept { return sizeof(int32_t) * bd; } template -int jit_brgemm_kernel_t::bdb_zp_comp_b_offset( - int bd_block2) const noexcept { +dim_t jit_brgemm_kernel_t::bdb_zp_comp_b_offset( + dim_t bd_block2) const noexcept { return zp_comp_b_offset(bd_block2 * brg.bd_block); } template -int jit_brgemm_kernel_t::zp_c_values_offset( - int ld, bool is_tail) const noexcept { +dim_t jit_brgemm_kernel_t::zp_c_values_offset( + dim_t ld, bool is_tail) const noexcept { if (brg.zp_type_c == brgemm_broadcast_t::per_n) { return (is_tail) ? 
sizeof(int32_t) * brg.ldb_tail : sizeof(int32_t) * ld * brg.ld_block; @@ -636,10 +731,10 @@ void jit_brgemm_kernel_t::maybe_set_avx_mask(bool is_ld_tail) { template void jit_brgemm_kernel_t::cvt2ps(data_type_t type_in, const Vmm vmm_in, const Xbyak::Operand &op, bool mask_flag, bool store, - Xbyak::Opmask ktail_mask, int tail_size) { + Xbyak::Opmask ktail_mask, dim_t tail_size) { Vmm vmm = vmm_in; - const bool has_tail - = op.isMEM() && tail_size != vreg_traits::vlen / sizeof(float); + const bool has_tail = op.isMEM() + && tail_size != vreg_traits_t::vlen / sizeof(float); if (IMPLICATION(has_tail, is_superset(brg.isa_impl, avx512_core))) { vmm = vmm_mask(vmm_in, mask_flag, store, ktail_mask); } else { @@ -700,7 +795,7 @@ void jit_brgemm_kernel_t::advance_ldb_post_op_regs() { } template -void jit_brgemm_kernel_t::restore_ldb_post_op_regs(int ld_block2) { +void jit_brgemm_kernel_t::restore_ldb_post_op_regs(dim_t ld_block2) { if (brg.with_bias) { mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]); sub(reg_aux_bias, bias_offset(ld_block2 - 1)); @@ -724,7 +819,7 @@ void jit_brgemm_kernel_t::restore_ldb_post_op_regs(int ld_block2) { } template -void jit_brgemm_kernel_t::advance_bdb_post_op_regs(int adj_bd_block) { +void jit_brgemm_kernel_t::advance_bdb_post_op_regs(dim_t adj_bd_block) { if (brg.zp_type_b != brgemm_broadcast_t::none) { mov(reg_aux_zp_comp_b, ptr[rsp + reg_aux_zp_comp_b_offs_]); add(reg_aux_zp_comp_b, bdb_zp_comp_b_offset(1)); @@ -739,7 +834,7 @@ void jit_brgemm_kernel_t::advance_bdb_post_op_regs(int adj_bd_block) { } template -void jit_brgemm_kernel_t::restore_bdb_post_op_regs(int bd_block2) { +void jit_brgemm_kernel_t::restore_bdb_post_op_regs(dim_t bd_block2) { bool post_processed = false; if (bd_block2 > 1) { if (brg.zp_type_b != brgemm_broadcast_t::none) { @@ -759,14 +854,16 @@ void jit_brgemm_kernel_t::restore_bdb_post_op_regs(int bd_block2) { } template -void jit_brgemm_kernel_t::ldb_regs_shift(int ld_block2, bool is_tail) { - int C_offset = (is_tail) ? ldb_C_offset(1, true) : ldb_C_offset(ld_block2); - int D_offset = (is_tail) ? ldb_D_offset(1, true) : ldb_D_offset(ld_block2); +void jit_brgemm_kernel_t::ldb_regs_shift(dim_t ld_block2, bool is_tail) { + dim_t C_offset + = (is_tail) ? ldb_C_offset(1, true) : ldb_C_offset(ld_block2); + dim_t D_offset + = (is_tail) ? ldb_D_offset(1, true) : ldb_D_offset(ld_block2); add(reg_aux_C, C_offset); add(reg_aux_D, D_offset); add(reg_b_offset, - (is_tail) ? ldb_B_offset(1, true) : ldb_B_offset(ld_block2)); + (is_tail) ? ldb_B_offset(0, true) : ldb_B_offset(ld_block2)); if (brg.with_bias) { mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]); @@ -787,6 +884,23 @@ void jit_brgemm_kernel_t::ldb_regs_shift(int ld_block2, bool is_tail) { (is_tail) ? scales_offset(1, true) : scales_offset(ld_block2)); mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales); } + + if (brg.with_wei_decomp) { + if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0 ) { + mov(reg_aux_wei_scales, ptr[rsp + reg_aux_wei_scales_offs_]); + add(reg_aux_wei_scales, (is_tail) ? wei_scales_offset(1, true) : wei_scales_offset(ld_block2)); + mov(ptr[rsp + reg_aux_wei_scales_offs_], reg_aux_wei_scales); + mov(ptr[rsp + reg_aux2_wei_scales_offs_], reg_aux_wei_scales); + } + + if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0 ) { + mov(reg_aux_wei_zp, ptr[rsp + reg_aux_wei_zero_points_offs_]); + add(reg_aux_wei_zp, (is_tail) ? 
wei_zp_offset(1, true) : wei_zp_offset(ld_block2)); + mov(ptr[rsp + reg_aux_wei_zero_points_offs_], reg_aux_wei_zp); + mov(ptr[rsp + reg_aux2_wei_zero_points_offs_], reg_aux_wei_zp); + } + } + if (brg.zp_type_a != brgemm_broadcast_t::none) { mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]); add(reg_aux_zp_comp_a, @@ -804,7 +918,7 @@ void jit_brgemm_kernel_t::ldb_regs_shift(int ld_block2, bool is_tail) { } template -void jit_brgemm_kernel_t::advance_bd_block2_post_op_regs(int bd_block2) { +void jit_brgemm_kernel_t::advance_bd_block2_post_op_regs(dim_t bd_block2) { if (brg.req_comp_pads_with_bcast && brg.req_s8s8_compensation) { mov(reg_compensation, ptr[rsp + reg_comp_offs_]); add(reg_compensation, bdb_compensation_offset(bd_block2)); @@ -854,6 +968,29 @@ void jit_brgemm_kernel_t::copy_post_ops_stack_values_to_aux( mov(reg_zp_c_values, ptr[rsp + reg_zp_c_values_offs_]); mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_zp_c_values); } + + if (brg.with_wei_decomp_scales) { + mov(reg_wei_scales, ptr[rsp + reg_wei_scales_offs_]); + mov(ptr[rsp + reg_aux_wei_scales_offs_], reg_wei_scales); + mov(ptr[rsp + reg_aux2_wei_scales_offs_], reg_wei_scales); + } + if (brg.with_wei_decomp_zero_points) { + mov(reg_wei_zp, ptr[rsp + reg_wei_zero_points_offs_]); + mov(ptr[rsp + reg_aux_wei_zero_points_offs_], reg_wei_zp); + mov(ptr[rsp + reg_aux2_wei_zero_points_offs_], reg_wei_zp); + } + + } + if (brg.with_src_dyn_quant) { + mov(reg_src_scales, ptr[rsp + reg_src_scales_offs_]); + mov(ptr[rsp + reg_aux_src_scales_offs_], reg_src_scales); + mov(ptr[rsp + reg_aux2_src_scales_offs_], reg_src_scales); + + if (brg.with_wei_decomp_zero_points) { + mov(reg_src_grouped_sum, ptr[rsp + reg_src_grouped_sum_offs_]); + mov(ptr[rsp + reg_aux_src_grouped_sum_offs_], reg_src_grouped_sum); + mov(ptr[rsp + reg_aux2_src_grouped_sum_offs_], reg_src_grouped_sum); + } } if (brg.zp_type_b != brgemm_broadcast_t::none) { mov(reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]); @@ -917,6 +1054,25 @@ void jit_brgemm_kernel_t::read_params() { mov(ptr[rsp + reg_zp_comp_b_offs_], reg_zp_comp_b); } + if (brg.with_wei_decomp) { + mov(reg_wei_scales, ptr[param1 + GET_OFF(ptr_wei_scales)]); + mov(ptr[rsp + reg_wei_scales_offs_], reg_wei_scales); + + mov(reg_wei_zp, ptr[param1 + GET_OFF(ptr_wei_zero_points)]); + mov(ptr[rsp + reg_wei_zero_points_offs_], reg_wei_zp); + + mov(reg_ic, ptr[param1 + GET_OFF(ic)]); + mov(ptr[rsp + reg_ic_offs_], reg_ic); + } + + if (brg.with_src_dyn_quant) { + mov(reg_src_scales, ptr[param1 + GET_OFF(ptr_src_scales)]); + mov(ptr[rsp + reg_src_scales_offs_], reg_src_scales); + + mov(reg_src_grouped_sum, ptr[param1 + GET_OFF(ptr_src_grouped_sum)]); + mov(ptr[rsp + reg_src_grouped_sum_offs_], reg_src_grouped_sum); + } + if (brg.zp_type_c != brgemm_broadcast_t::none) { mov(reg_zp_c_values, ptr[param1 + GET_OFF(c_zp_values)]); mov(ptr[rsp + reg_zp_c_values_offs_], reg_zp_c_values); @@ -953,21 +1109,21 @@ void jit_brgemm_kernel_t::read_params() { } template -void jit_brgemm_kernel_t::zero_accumulators(int bd_block2, - bool is_bdb_tail, int ld_block2, bool is_ld_tail, +void jit_brgemm_kernel_t::zero_accumulators(dim_t bd_block2, + bool is_bdb_tail, dim_t ld_block2, bool is_ld_tail, bool skip_accumulation) { if (brg.is_tmm) { // avoid usage of tile registers if there is no accumulation if (skip_accumulation) return; - for_(int bdb = 0; bdb < bd_block2; bdb++) - for (int ldb = 0; ldb < ld_block2; ldb++) { - int idx = (is_ld_tail) ? 
brg.ld_block2 : ldb; + for_(dim_t bdb = 0; bdb < bd_block2; bdb++) + for (dim_t ldb = 0; ldb < ld_block2; ldb++) { + dim_t idx = (is_ld_tail) ? brg.ld_block2 : ldb; tilezero(Tmm(brg.get_C_tensor(bdb, idx, is_bdb_tail, is_ld_tail))); } } else { - int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block; - for_(int bd = 0; bd < bd_block; bd++) - for (int ld = 0; ld < ld_block2; ld++) { + dim_t bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block; + for_(dim_t bd = 0; bd < bd_block; bd++) + for (dim_t ld = 0; ld < ld_block2; ld++) { auto vmm = accm(ld_block2, bd, ld); uni_vpxor(vmm, vmm, vmm); } @@ -977,30 +1133,30 @@ void jit_brgemm_kernel_t::zero_accumulators(int bd_block2, // This method up-converts the data from bf8 to f16 and saves at reg_buf. // Generally used by matrix_A, where no vnni transformation of data is needed. template -void jit_brgemm_kernel_t::fp8_to_f16_upconvert(int num_rows, - int tile_num_col_bytes, reg64_t reg_base, int offset, +void jit_brgemm_kernel_t::fp8_to_f16_upconvert(dim_t num_rows, + dim_t tile_num_col_bytes, reg64_t reg_base, dim_t offset, reg64_t reg_data_stride, data_type_t dt, bool is_rd_tail) { - int rd_block = is_rd_tail ? brg.rdb_tail : brg.rd_block; + dim_t rd_block = is_rd_tail ? brg.rdb_tail : brg.rd_block; - const int max_num_cols = rd_block; //tile_num_col_bytes / sizeof(float16_t); - const int col_tail = max_num_cols % 32; + const dim_t max_num_cols + = rd_block; //tile_num_col_bytes / sizeof(float16_t); + const dim_t col_tail = max_num_cols % 32; auto zmm_1 = vmm_tmp(0); auto zmm_1_masked = col_tail ? zmm_1 | fp8_col_mask | T_z : zmm_1; assert(max_num_cols > 0); if (col_tail) { - const int tail_mask = (1 << col_tail) - 1; - auto reg_tmp_32 = reg_tmp_gpr.cvt32(); - mov(reg_tmp_32, tail_mask); - kmovd(fp8_col_mask, reg_tmp_32); + const auto tail_mask = (static_cast(1) << col_tail) - 1; + mov(reg_tmp_gpr, tail_mask); + kmovq(fp8_col_mask, reg_tmp_gpr); } // Note: using the same register used in col_tail, so order is important const auto reg_data_aux = reg_tmp_gpr; lea(reg_data_aux, ptr[reg_base + offset]); - for (int r = 0; r < num_rows; ++r) { + for (dim_t r = 0; r < num_rows; ++r) { if (dt == data_type::f8_e5m2) f8_e5m2_emulator_->vcvt_f8_to_f16(zmm_1_masked, ptr[reg_data_aux]); else if (dt == data_type::f8_e4m3) @@ -1016,11 +1172,11 @@ void jit_brgemm_kernel_t::fp8_to_f16_upconvert(int num_rows, // This method up-converts and transforms the data from fp8_vnni to f16_vnni // format. Generally used by matrix_B. template -void jit_brgemm_kernel_t::fp8_to_f16_upconvert_to_vnni(int num_rows, - int tile_num_col_bytes, reg64_t reg_base, int offset, +void jit_brgemm_kernel_t::fp8_to_f16_upconvert_to_vnni(dim_t num_rows, + dim_t tile_num_col_bytes, reg64_t reg_base, dim_t offset, reg64_t reg_data_stride, data_type_t dt, bool is_rd_tail) { - const int num_cols_ele = tile_num_col_bytes / 2; // 32 for full tile - const int num_N = num_cols_ele / 2; // 16 for full tile + const dim_t num_cols_ele = tile_num_col_bytes / 2; // 32 for full tile + const dim_t num_N = num_cols_ele / 2; // 16 for full tile const auto zmm_2 = vmm_tmp(2); assert(num_N > 0 && "bad tile parameters"); @@ -1029,9 +1185,9 @@ void jit_brgemm_kernel_t::fp8_to_f16_upconvert_to_vnni(int num_rows, const auto reg_data_aux = reg_tmp_gpr; lea(reg_data_aux, ptr[reg_base + offset]); - int rd_block = is_rd_tail ? 
brg.rdb_tail : brg.rd_block; - const int vnni_granularity = data_type_vnni_granularity(data_type::f16); - const int r_end = utils::div_up(rd_block, vnni_granularity); + dim_t rd_block = is_rd_tail ? brg.rdb_tail : brg.rd_block; + const dim_t vnni_granularity = data_type_vnni_granularity(data_type::f16); + const dim_t r_end = utils::div_up(rd_block, vnni_granularity); assert(r_end <= num_rows && "bad tile parameters"); if (dt == data_type::f8_e5m2) @@ -1046,16 +1202,16 @@ void jit_brgemm_kernel_t::fp8_to_f16_upconvert_to_vnni(int num_rows, // zero rest of the tile data if (r_end < num_rows) { vpxord(zmm_2, zmm_2, zmm_2); - for (int r = r_end; r < num_rows; ++r) + for (dim_t r = r_end; r < num_rows; ++r) vmovups(ptr[reg_buf_aux + r * zmm_width_in_bytes_], zmm_2); } } template void jit_brgemm_kernel_t::apply_alpha_beta( - int bd_block, int ld_block2, bool is_ld_tail) { + dim_t bd_block, dim_t ld_block2, bool is_ld_tail) { const bool apply_alpha = brg.alpha != 1.f; - const bool dq2ps_required = brg.is_int8 && (apply_alpha || brg.beta != 1.f); + const bool dq2ps_required = brg.is_int8 && (apply_alpha || brg.beta != 1.f) && !brg.with_src_dyn_quant; auto vmm_alpha = vmm_tmp(0); if (apply_alpha) { @@ -1063,8 +1219,8 @@ void jit_brgemm_kernel_t::apply_alpha_beta( uni_vmovq(Xmm(vmm_alpha.getIdx()), reg_tmp_gpr); uni_vbroadcastss(vmm_alpha, Xmm(vmm_alpha.getIdx())); } - for_(int bd = 0; bd < bd_block; bd++) - for (int ld = 0; ld < ld_block2; ld++) { + for_(dim_t bd = 0; bd < bd_block; bd++) + for (dim_t ld = 0; ld < ld_block2; ld++) { auto vmm = accm(ld_block2, bd, ld); if (dq2ps_required) uni_vcvtdq2ps(vmm, vmm); if (apply_alpha) uni_vmulps(vmm, vmm, vmm_alpha); @@ -1074,11 +1230,10 @@ void jit_brgemm_kernel_t::apply_alpha_beta( const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required; const bool need_init_beta_vmm = brg.beta != 1.f; auto vmm_prev_dst = vmm_tmp(0); - auto vmm_beta = vmm_tail_mask(); if (need_init_beta_vmm) { mov(reg_tmp_gpr, float2int(static_cast(brg.beta))); - uni_vmovq(Xmm(vmm_beta.getIdx()), reg_tmp_gpr); - uni_vbroadcastss(vmm_beta, Xmm(vmm_beta.getIdx())); + uni_vmovq(Xmm(vmm_beta().getIdx()), reg_tmp_gpr); + uni_vbroadcastss(vmm_beta(), Xmm(vmm_beta().getIdx())); } if (brg.is_runtime_ldc && bd_block > 1) @@ -1086,8 +1241,8 @@ void jit_brgemm_kernel_t::apply_alpha_beta( if (brg.is_fp8_via_convert()) mov(ptr[rsp + reg_val_tmp_1_], reg64_fp8_aux); - for_(int bd = 0; bd < bd_block; bd++) - for (int ld = 0; ld < ld_block2; ld++) { + for_(dim_t bd = 0; bd < bd_block; bd++) + for (dim_t ld = 0; ld < ld_block2; ld++) { const bool is_tail = is_ld_tail && ld + 1 == ld_block2; const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask; auto vmm = accm(ld_block2, bd, ld); @@ -1095,25 +1250,25 @@ void jit_brgemm_kernel_t::apply_alpha_beta( if (use_vadd_for_beta) { if (IMPLICATION(is_tail, is_superset(brg.isa_impl, avx512_core))) { auto vmm_masked = vmm_mask(vmm, is_tail, false, k_mask); - if (brg.is_int8) + if (brg.is_int8 && !brg.with_src_dyn_quant) uni_vpaddd(vmm_masked, vmm, ptr_C); else uni_vaddps(vmm_masked, vmm, ptr_C); } else { vmaskmovps(vmm_prev_dst, vmm_tail_mask(), ptr_C); - if (brg.is_int8) + if (brg.is_int8 && !brg.with_src_dyn_quant) uni_vpaddd(vmm, vmm, vmm_prev_dst); else uni_vaddps(vmm, vmm, vmm_prev_dst); } } else { - const int ld_size = is_tail ? brg.ldb_tail : brg.ld_block; + const dim_t ld_size = is_tail ? 
brg.ldb_tail : brg.ld_block; cvt2ps(brg.dt_c, vmm_prev_dst, ptr_C, is_tail, false, k_mask, ld_size); if (brg.beta == 1.f) uni_vaddps(vmm, vmm, vmm_prev_dst); else - uni_vfmadd231ps(vmm, vmm_prev_dst, vmm_beta); + uni_vfmadd231ps(vmm, vmm_prev_dst, vmm_beta()); } if (brg.is_runtime_ldc && bd_block > 1 && ld == ld_block2 - 1) add(reg_aux_C, ptr[rsp + reg_C_shift_bytes_offs_]); @@ -1128,8 +1283,8 @@ void jit_brgemm_kernel_t::apply_alpha_beta( } template -void jit_brgemm_kernel_t::apply_post_ops( - int bd_block, int ld_block2, int ldb_and_bdb_offset, bool is_ld_tail) { +void jit_brgemm_kernel_t::apply_post_ops(dim_t bd_block, dim_t ld_block2, + dim_t ldb_and_bdb_offset, bool is_ld_tail) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; @@ -1143,16 +1298,16 @@ void jit_brgemm_kernel_t::apply_post_ops( if (brg.is_runtime_ldd && bd_block > 1) mov(ptr[rsp + reg_aux_D_backup_offs_], reg_aux_D); - const int bd_block_shift = brg.is_runtime_ldd ? 1 : bd_block; - for (int bd_block_idx = 0; bd_block_idx < bd_block; + const dim_t bd_block_shift = brg.is_runtime_ldd ? 1 : bd_block; + for (dim_t bd_block_idx = 0; bd_block_idx < bd_block; bd_block_idx += bd_block_shift) { - int bd_start = bd_block_idx; - int bd_end = bd_start + bd_block_shift; + dim_t bd_start = bd_block_idx; + dim_t bd_end = bd_start + bd_block_shift; const auto set_binary_injecotr_params = [&] { if (!brg.with_binary || !with_binary_non_scalar_bcast_) return; - for_(int bd = bd_start; bd < bd_end; bd++) - for (int ld = 0; ld < ld_block2; ld++) { + for_(dim_t bd = bd_start; bd < bd_end; bd++) + for (dim_t ld = 0; ld < ld_block2; ld++) { const auto vmm_idx = accm(ld_block2, bd, ld).getIdx(); rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_aux_D); @@ -1205,14 +1360,14 @@ void jit_brgemm_kernel_t::apply_post_ops( // objects above that use push/pop if (brg.is_fp8_via_convert()) push(reg64_fp8_aux); - for_(int bd = bd_start; bd < bd_end; bd++) - for (int ld = 0; ld < ld_block2; ld++) { + for_(dim_t bd = bd_start; bd < bd_end; bd++) + for (dim_t ld = 0; ld < ld_block2; ld++) { const auto vmm = accm(ld_block2, bd, ld); const auto addr = ptr[reg_aux_D + D_offset(bd, ld)]; const auto vmm_prev_dst = vmm_tmp(0); const bool is_tail = is_ld_tail && ld + 1 == ld_block2; const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask; - const int ld_size = is_tail ? brg.ldb_tail : brg.ld_block; + const dim_t ld_size = is_tail ? brg.ldb_tail : brg.ld_block; cvt2ps(brg.sum_dt, vmm_prev_dst, addr, is_tail, false, k_mask, ld_size); if (p_sum_zp_reg_set) @@ -1253,8 +1408,8 @@ void jit_brgemm_kernel_t::apply_post_ops( } template -void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( - int bd_block, int ld_block2, int ldb_and_bdb_offset, bool is_ld_tail) { +void jit_brgemm_kernel_t::store_accumulators_apply_post_ops(dim_t bd_block, + dim_t ld_block2, dim_t ldb_and_bdb_offset, bool is_ld_tail) { auto k_mask = (!is_ld_tail) ? 
ld_full_mask : ld_tail_mask; // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) -> @@ -1263,11 +1418,12 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( const bool beta_uses_vadd = brg.beta == 1.f && IMPLICATION(brg.is_int8, brg.alpha == 1.0f); const bool dq2ps_required = brg.is_int8 - && IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd); + && IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd) + && !brg.with_src_dyn_quant; if (brg.with_scales) { mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]); - for (int ld = 0; ld < ld_block2; ld++) { + for (dim_t ld = 0; ld < ld_block2; ld++) { const auto addr = ptr[reg_aux_scales + scales_offset(ld)]; const bool is_tail = is_ld_tail && ld + 1 == ld_block2; auto vmm_scales = vmm_tmp(0); @@ -1279,7 +1435,7 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( auto vmm_scales = vmm_tmp(0); vmaskmovps(vmm_scales, vmm_tail_mask(), addr); } - for (int bd = 0; bd < bd_block; bd++) { + for (dim_t bd = 0; bd < bd_block; bd++) { auto vmm = accm(ld_block2, bd, ld); if (dq2ps_required) uni_vcvtdq2ps(vmm, vmm); uni_vmulps(vmm, vmm, vmm_scales); @@ -1290,7 +1446,7 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( if (brg.with_bias) { mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]); } if (brg.is_fp8_via_convert()) mov(ptr[rsp + reg_val_tmp_1_], reg64_fp8_aux); - for (int ld = 0; ld < ld_block2; ld++) { + for (dim_t ld = 0; ld < ld_block2; ld++) { auto vmm_bias = vmm_tmp(0); if (brg.with_bias) { auto ptr_bias = ptr[reg_aux_bias + bias_offset(ld)]; @@ -1298,7 +1454,7 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( cvt2ps(brg.dt_bias, vmm_bias, ptr_bias, is_tail, false, k_mask, is_tail ? brg.ldb_tail : brg.ld_block); } - for (int bd = 0; bd < bd_block; bd++) { + for (dim_t bd = 0; bd < bd_block; bd++) { auto vmm = accm(ld_block2, bd, ld); if (dq2ps_required && !brg.with_scales) uni_vcvtdq2ps(vmm, vmm); if (brg.with_bias) uni_vaddps(vmm, vmm, vmm_bias); @@ -1314,8 +1470,8 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( auto vmm_dst_scales = vmm_tmp(0); vbroadcastss(vmm_dst_scales, ptr[reg_aux_dst_scales]); - for (int ld = 0; ld < ld_block2; ld++) { - for (int bd = 0; bd < bd_block; bd++) { + for (dim_t ld = 0; ld < ld_block2; ld++) { + for (dim_t bd = 0; bd < bd_block; bd++) { auto vmm = accm(ld_block2, bd, ld); vmulps(vmm, vmm, vmm_dst_scales); } @@ -1337,10 +1493,10 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( if (brg.is_fp8_via_convert()) mov(ptr[rsp + reg_val_tmp_1_], reg64_fp8_aux); - for (int ld = 0; ld < ld_block2; ld++) { + for (dim_t ld = 0; ld < ld_block2; ld++) { const bool is_tail = is_ld_tail && ld + 1 == ld_block2; if (brg.zp_type_c == brgemm_broadcast_t::per_n) { - int zp_c_off = zp_c_values_offset(ld); + dim_t zp_c_off = zp_c_values_offset(ld); if (is_superset(brg.isa_impl, avx512_core)) { auto zp_c_addr = EVEX_compress_addr(reg_aux_zp_c_values, zp_c_off); @@ -1352,7 +1508,7 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( k_mask, is_tail ? 
brg.ldb_tail : brg.ld_block); } } - for (int bd = 0; bd < bd_block; bd++) { + for (dim_t bd = 0; bd < bd_block; bd++) { auto vmm = accm(ld_block2, bd, ld); uni_vaddps(vmm, vmm, vmm_zp_c); } @@ -1363,16 +1519,14 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( const bool dt_requires_saturation = one_of(brg.dt_d, data_type::u8, data_type::s8, data_type::s32); - auto vmm_lbound = vmm_tail_mask(); - auto vmm_ubound = vmm_tmp(0); - assert(vmm_lbound.getIdx() != vmm_ubound.getIdx()); + assert(vmm_lbound().getIdx() != vmm_ubound().getIdx()); if (dt_requires_saturation) { - init_saturate_f32( - vmm_lbound, vmm_ubound, reg_tmp_gpr, data_type::f32, brg.dt_d); - for (int bd = 0; bd < bd_block; bd++) { - for (int ld = 0; ld < ld_block2; ld++) { + init_saturate_f32(vmm_lbound(), vmm_ubound(), reg_tmp_gpr, + data_type::f32, brg.dt_d); + for (dim_t bd = 0; bd < bd_block; bd++) { + for (dim_t ld = 0; ld < ld_block2; ld++) { auto vmm = accm(ld_block2, bd, ld); - saturate_cvt_f32(vmm, vmm_lbound, vmm_ubound, brg.dt_d); + saturate_cvt_f32(vmm, vmm_lbound(), vmm_ubound(), brg.dt_d); } } // below call is not required as s32 doesn't use vmm_lbound @@ -1385,8 +1539,8 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( mov(ptr[rsp + reg_aux_D_backup_offs_], reg_aux_D); if (brg.is_fp8_via_convert()) mov(ptr[rsp + reg_val_tmp_1_], reg64_fp8_aux); - for_(int bd = 0; bd < bd_block; bd++) - for (int ld = 0; ld < ld_block2; ld++) { + for_(dim_t bd = 0; bd < bd_block; bd++) + for (dim_t ld = 0; ld < ld_block2; ld++) { auto addr = ptr[reg_aux_D + D_offset(bd, ld)]; auto vmm = accm(ld_block2, bd, ld); auto vmm_lower = Vmm_lower_t(vmm.getIdx()); @@ -1431,7 +1585,7 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( default: assert(!"unknown dst_dt"); } } else { - const int ld_block = is_tail ? brg.ldb_tail : brg.ld_block; + const dim_t ld_block = is_tail ? brg.ldb_tail : brg.ld_block; if (is_tail && types::data_type_size(brg.dt_b) == sizeof(float)) vmaskmovps(addr, vmm_tail_mask(), vmm); else @@ -1449,7 +1603,7 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( template void jit_brgemm_kernel_t::apply_compensation( - int bd_block, int ld_block2, bool is_ld_tail) { + dim_t bd_block, dim_t ld_block2, bool is_ld_tail) { // apply compensation to accumulated values // to avoid the loss of accuracy when converting s32 to f32 auto k_mask = (!is_ld_tail) ? 
ld_full_mask : ld_tail_mask; @@ -1461,9 +1615,9 @@ void jit_brgemm_kernel_t::apply_compensation( mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]); const auto vmm_zp_comp_a = vmm_tmp(0); - for (int ld = 0; ld < ld_block2; ld++) { + for (dim_t ld = 0; ld < ld_block2; ld++) { const bool is_tail = is_ld_tail && ld + 1 == ld_block2; - for (int bd = 0; bd < bd_block; bd++) { + for (dim_t bd = 0; bd < bd_block; bd++) { if (IMPLICATION(!brg.req_comp_pads_with_bcast, bd == 0)) { const auto zp_comp_a_addr = ptr[reg_aux_zp_comp_a + bd_zp_comp_a_offset(ld, bd)]; @@ -1488,9 +1642,9 @@ void jit_brgemm_kernel_t::apply_compensation( if (brg.zp_type_b != brgemm_broadcast_t::none) { mov(reg_aux_zp_comp_b, ptr[rsp + reg_aux_zp_comp_b_offs_]); - for (int bd = 0; bd < bd_block; bd++) { - int zp_comp_b_off = zp_comp_b_offset(bd); - for (int ld = 0; ld < ld_block2; ld++) { + for (dim_t bd = 0; bd < bd_block; bd++) { + dim_t zp_comp_b_off = zp_comp_b_offset(bd); + for (dim_t ld = 0; ld < ld_block2; ld++) { auto vmm = accm(ld_block2, bd, ld); if (is_superset(brg.isa_impl, avx512_core)) { const auto zp_comp_b_addr = EVEX_compress_addr( @@ -1509,9 +1663,9 @@ void jit_brgemm_kernel_t::apply_compensation( if (!brg.req_cal_comp_pads && brg.req_s8s8_compensation) { mov(reg_aux_compensation, ptr[rsp + reg_aux_comp_offs_]); auto vmm_comp = vmm_tmp(0); - for (int ld = 0; ld < ld_block2; ld++) { + for (dim_t ld = 0; ld < ld_block2; ld++) { const bool is_tail = is_ld_tail && ld + 1 == ld_block2; - for (int bd = 0; bd < bd_block; bd++) { + for (dim_t bd = 0; bd < bd_block; bd++) { if (IMPLICATION(!brg.req_comp_pads_with_bcast, bd == 0)) { const auto comp_addr = ptr[reg_aux_compensation + bd_compensation_offset(ld, bd)]; @@ -1532,7 +1686,7 @@ void jit_brgemm_kernel_t::apply_compensation( template void jit_brgemm_kernel_t::store_accumulators_without_post_ops( - int bd_block, int ld_block2, bool is_ld_tail) { + dim_t bd_block, dim_t ld_block2, bool is_ld_tail) { // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) -> // accumulated values are converted to ps in apply_alpha_beta() @@ -1543,14 +1697,12 @@ void jit_brgemm_kernel_t::store_accumulators_without_post_ops( && !IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd); if (dt_requires_saturation) { - auto vmm_ubound = vmm_tmp(0); - auto vmm_lbound = vmm_tmp(1); - init_saturate_f32( - vmm_lbound, vmm_ubound, reg_tmp_gpr, data_type::f32, brg.dt_d); - for (int bd = 0; bd < bd_block; bd++) { - for (int ld = 0; ld < ld_block2; ld++) { + init_saturate_f32(vmm_lbound(), vmm_ubound(), reg_tmp_gpr, + data_type::f32, brg.dt_d); + for (dim_t bd = 0; bd < bd_block; bd++) { + for (dim_t ld = 0; ld < ld_block2; ld++) { auto vmm = accm(ld_block2, bd, ld); - saturate_cvt_f32(vmm, vmm_lbound, vmm_ubound, brg.dt_d); + saturate_cvt_f32(vmm, vmm_lbound(), vmm_ubound(), brg.dt_d); } } // below call is not required as s32 doesn't use vmm_lbound @@ -1560,8 +1712,8 @@ void jit_brgemm_kernel_t::store_accumulators_without_post_ops( if (brg.is_runtime_ldc && bd_block > 1) mov(ptr[rsp + reg_aux_C_backup_offs_], reg_aux_C); - for_(int bd = 0; bd < bd_block; bd++) - for (int ld = 0; ld < ld_block2; ld++) { + for_(dim_t bd = 0; bd < bd_block; bd++) + for (dim_t ld = 0; ld < ld_block2; ld++) { auto vmm = accm(ld_block2, bd, ld); const auto addr_c = ptr[reg_aux_C + C_offset(bd, ld)]; const bool is_tail = is_ld_tail && ld + 1 == ld_block2; @@ -1581,15 +1733,12 @@ void jit_brgemm_kernel_t::store_accumulators_without_post_ops( } template -void jit_brgemm_kernel_t::store_accumulators(int 
bd_block2, - bool is_bdb_tail, int ld_block2, bool is_ld_tail, +void jit_brgemm_kernel_t::store_accumulators(dim_t bd_block2, + bool is_bdb_tail, dim_t ld_block2, bool is_ld_tail, bool skip_accumulation) { const bool has_zero_points = !everyone_is(brgemm_broadcast_t::none, brg.zp_type_a, brg.zp_type_b, brg.zp_type_c); - const bool are_post_ops_applicable = one_of(true, brg.with_eltwise, - brg.with_binary, brg.with_scales, brg.with_bias, brg.with_sum, - brg.dt_d != brg.dt_c, brg.req_s8s8_compensation, has_zero_points, - brg.with_dst_scales); + const bool are_post_ops_applicable = brg.are_post_ops_applicable(); const bool need_to_apply_alpha_beta = brg.beta != 0.f || brg.alpha != 1.f; const bool need_generate_zp_a_compensation = brg.is_int8 && (brg.req_s8s8_compensation || has_zero_points); @@ -1627,16 +1776,16 @@ void jit_brgemm_kernel_t::store_accumulators(int bd_block2, } mov(reg_buf, ptr[rsp + reg_buf_offs_]); - for (int bdb = 0; bdb < bd_block2; bdb++) { - int adj_bd_block = (brg.is_M_tail && is_bdb_tail) + for (dim_t bdb = 0; bdb < bd_block2; bdb++) { + dim_t adj_bd_block = (brg.is_M_tail && is_bdb_tail) ? brg.bdb_tail : brg.bd_block; - for (int ldb = 0; ldb < ld_block2; ldb++) { - int idx = (is_ld_tail) ? brg.ld_block2 : ldb; + for (dim_t ldb = 0; ldb < ld_block2; ldb++) { + dim_t idx = (is_ld_tail) ? brg.ld_block2 : ldb; if (need_to_apply_alpha_beta || are_post_ops_applicable || apply_zp_a_compensation) { if (skip_accumulation) { - for (int bd = 0; bd < adj_bd_block; bd++) { + for (dim_t bd = 0; bd < adj_bd_block; bd++) { auto vreg_acc = accm(1, bd, 0); uni_vpxor(vreg_acc, vreg_acc, vreg_acc); } @@ -1644,7 +1793,7 @@ void jit_brgemm_kernel_t::store_accumulators(int bd_block2, tilestored(ptr[reg_buf + reg_stride_ld_block], Tmm(brg.get_C_tensor(bdb, idx, is_bdb_tail, is_ld_tail))); - for (int bd = 0; bd < adj_bd_block; bd++) { + for (dim_t bd = 0; bd < adj_bd_block; bd++) { size_t buf_offset = (bd * brg.ld_block) * brg.typesize_C; auto vreg_acc = is_ld_tail @@ -1765,7 +1914,7 @@ void jit_brgemm_kernel_t::store_accumulators(int bd_block2, store_accumulators_amx(false); L_aligned(label_done); } else { - int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block; + dim_t bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block; if (need_generate_zp_a_compensation) { Label label_store_without_comp; @@ -1870,11 +2019,10 @@ void jit_brgemm_kernel_t::set_A_B_matrices() { template void jit_brgemm_kernel_t::maybe_pre_process_data(matrix_kind_t matrix_kind, - const Tmm &t1, reg64_t reg_base, size_t offset, reg64_t reg_stride, - int num_rows, int num_col_bytes, bool is_rd_tail) { - constexpr int tile_size = 1024; + const Tmm &t1, reg64_t reg_base, dim_t offset, reg64_t reg_stride, + dim_t num_rows, dim_t num_col_bytes, bool is_rd_tail) { const auto transform_offset = brg.brgattr.use_interleave_stores - ? brg.get_num_C_tiles() * tile_size + ? brg.get_num_C_tiles() * brgemm_desc_t::tilesize : 0; add(reg_buf_aux, transform_offset); @@ -1897,12 +2045,12 @@ void jit_brgemm_kernel_t::maybe_pre_process_data(matrix_kind_t matrix_kind, template void jit_brgemm_kernel_t::maybe_tileloadd_nt(matrix_kind_t matrix_kind, - int idx, int offset, bool is_rd_tail, bool is_tail) { + dim_t idx, dim_t offset, bool is_rd_tail, bool is_tail, bool last_bdb) { const bool is_A = matrix_kind == matrix_kind_t::matrix_A; - const int tmm_idx = is_A ? brg.get_A_tensor(idx, is_tail) - : brg.get_B_tensor(idx, is_tail); + const dim_t tmm_idx = is_A ? 
brg.get_A_tensor(idx, is_tail)
+            : brg.get_B_tensor(idx, is_tail);
     auto t1 = Tmm(tmm_idx);
 
     auto reg_base = is_A ? reg_aux_A : reg_aux_B;
@@ -1912,24 +2060,25 @@ void jit_brgemm_kernel_t<Wmm>::maybe_tileloadd_nt(matrix_kind_t matrix_kind,
             == (is_A ? brgemm_bd_loop_innermost : brgemm_ld_loop_innermost);
 
     if (brg.is_fp8_via_convert()) {
-        const int typesize_A
+        const dim_t typesize_A
                 = brg.is_input_convert() ? sizeof(int16_t) : brg.typesize_A;
-        const int typesize_B
+        const dim_t typesize_B
                 = brg.is_input_convert() ? sizeof(int16_t) : brg.typesize_B;
-        int rd_step = 4 / typesize_A;
-        int rd_block = (!brg.rdb && brg.rdb_tail) ? brg.rdb_tail : brg.rd_block;
+        dim_t rd_step = 4 / typesize_A;
+        dim_t rd_block
+                = (!brg.rdb && brg.rdb_tail) ? brg.rdb_tail : brg.rd_block;
         if (brg.is_input_convert()) {
             const int vnni_granularity
                     = data_type_vnni_granularity(data_type::f16);
             rd_block = utils::rnd_up(rd_block, vnni_granularity);
         }
-        int A_col = typesize_A * rd_block;
-        int A_row = is_tail ? brg.bdb_tail : brg.bd_block;
+        dim_t A_col = typesize_A * rd_block;
+        dim_t A_row = is_tail ? brg.bdb_tail : brg.bd_block;
 
-        int B_col = (is_tail ? brg.ldb_tail : brg.ld_block) * typesize_B
+        dim_t B_col = (is_tail ? brg.ldb_tail : brg.ld_block) * typesize_B
                 * rd_step;
-        int B_row = brg.typesize_C != 0 ? A_col / brg.typesize_C : 0;
+        dim_t B_row = brg.typesize_C != 0 ? A_col / brg.typesize_C : 0;
 
         mov(ptr[rsp + reg_val_tmp_1_], reg64_fp8_aux);
         mov(ptr[rsp + reg_val_tmp_2_], reg_buf_aux);
@@ -1940,6 +2089,10 @@ void jit_brgemm_kernel_t<Wmm>::maybe_tileloadd_nt(matrix_kind_t matrix_kind,
         mov(reg64_fp8_aux, ptr[rsp + reg_val_tmp_1_]);
         mov(reg_buf_aux, ptr[rsp + reg_val_tmp_2_]);
     } else {
+        if (maybe_pre_process_k_tail(last_bdb || is_tail, is_rd_tail, t1,
+                    reg_base, offset, reg_stride, matrix_kind))
+            return;
+
         const size_t cache_footprint = static_cast<size_t>(brg.typesize_A)
                         * brg.brgattr.hint_expected_A_size
                 + static_cast<size_t>(brg.typesize_B)
@@ -1955,8 +2108,72 @@
 }
 
 template <typename Wmm>
-void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_amx(int bd_block2,
-        bool is_bdb_tail, int ld_block2, bool is_rd_tail, bool is_ld_tail) {
+bool jit_brgemm_kernel_t<Wmm>::maybe_pre_process_k_tail(bool last_bdb,
+        bool is_rd_tail, const Tmm &t1, reg64_t reg_base, dim_t offset,
+        reg64_t reg_stride, matrix_kind_t mk) {
+
+    // TODO: check whether this is the last bs when computing
+    // need_k_tail_processing
+    const auto need_k_tail_processing = mk == matrix_A && brg.amx_wary_k_tail()
+            && brg.rdb_tail != 0 && last_bdb && is_rd_tail;
+    if (!need_k_tail_processing) return false;
+
+    const auto zmm_width_in_bytes = cpu_isa_traits_t<avx512_core>::vlen;
+
+    auto transform_offset = brg.get_num_C_tiles() * brgemm_desc_t::tilesize
+            + brg.get_convert_wsp_buffer_size();
+
+    // TODO: reuse transformed data from matrix A for ldi > 0
+    const dim_t num_rows = palette_.rows[t1.getIdx()];
+    const dim_t num_col_bytes = palette_.cols[t1.getIdx()];
+
+    const auto max_num_cols
+            = nstl::min(num_col_bytes / brg.typesize_A, brg.rdb_tail);
+    const size_t col_tail
+            = max_num_cols % (zmm_width_in_bytes / brg.typesize_A);
+    if (col_tail) {
+        const auto tail_mask = (static_cast<size_t>(1) << col_tail) - 1;
+        mov(reg_tmp_gpr, tail_mask);
+        kmovq(rd_tail_mask, reg_tmp_gpr);
+    }
+    auto zmm_1 = zmm_tmp_1();
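+    // When col_tail is non-zero, the masked loads below keep only the
+    // valid lanes of each row and zero the rest (T_z), so the rows staged
+    // into the buffer are zero-padded before tileloadd consumes them.
+    auto zmm_1_masked = col_tail ?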
zmm_1 | rd_tail_mask | T_z : zmm_1; + + assert(max_num_cols > 0); + + mov(ptr[rsp + reg_val_tmp_2_], reg_buf_aux); + + mov(reg_buf_aux, ptr[rsp + reg_buf_offs_]); + if (transform_offset) add(reg_buf_aux, transform_offset); + + for (dim_t r = 0; r < num_rows; ++r) { + const auto row_offset = offset + r * brg.typesize_A * brg.LDA; + switch (brg.dt_a) { + case data_type::bf16: + case data_type::f16: + vmovdqu16(zmm_1_masked, ptr[reg_base + row_offset]); + break; + case data_type::f8_e5m2: + case data_type::f8_e4m3: + case data_type::s8: + case data_type::u8: + vmovdqu8(zmm_1_masked, ptr[reg_base + row_offset]); + break; + default: assert(!"unsupported data type"); + } + vmovups(ptr[reg_buf_aux + r * zmm_width_in_bytes], zmm_1); + } + // load into tmm from the transformed data. + mov(reg_converted_stride, zmm_width_in_bytes); + tileloadd(t1, ptr[reg_buf_aux + reg_converted_stride]); + mov(reg_buf_aux, ptr[rsp + reg_val_tmp_2_]); + + return true; +} + +template +void jit_brgemm_kernel_t::gemm_microkernel_amx(dim_t bd_block2, + bool is_bdb_tail, dim_t ld_block2, bool is_rd_tail, bool is_ld_tail, + bool last_bdb) { auto tdpbxxd = [this](const Tmm &x1, const Tmm &x2, const Tmm &x3) { if (brg.is_fp8) { if (brg.is_fp8_via_convert()) @@ -1979,20 +2196,20 @@ void jit_brgemm_kernel_t::gemm_microkernel_amx(int bd_block2, assert(!"unsupported combination"); } }; - int rbd_block = (is_rd_tail) ? 1 : brg.rdb; - for (int rdb = 0; rdb < rbd_block; rdb++) { - for (int bdb = 0; bdb < bd_block2; bdb++) { + dim_t rbd_block = (is_rd_tail) ? 1 : brg.rdb; + for (dim_t rdb = 0; rdb < rbd_block; rdb++) { + for (dim_t bdb = 0; bdb < bd_block2; bdb++) { maybe_tileloadd_nt(matrix_kind_t::matrix_A, bdb, rdb * rdb_A_offset() + A_offset(bdb, 0, true), is_rd_tail, - is_bdb_tail); + is_bdb_tail, last_bdb && bdb == bd_block2 - 1); } - for (int ldb = 0; ldb < ld_block2; ldb++) { + for (dim_t ldb = 0; ldb < ld_block2; ldb++) { - const int idx = (is_ld_tail) ? brg.ld_block2 : ldb; + const dim_t idx = (is_ld_tail) ? brg.ld_block2 : ldb; maybe_tileloadd_nt(matrix_kind_t::matrix_B, idx, rdb * rdb_B_offset() + B_offset(ldb, 0, true), is_rd_tail, - is_ld_tail); - for (int bdb = 0; bdb < bd_block2; bdb++) { + is_ld_tail, false); + for (dim_t bdb = 0; bdb < bd_block2; bdb++) { tdpbxxd(Tmm(brg.get_C_tensor( bdb, idx, is_bdb_tail, is_ld_tail)), Tmm(brg.get_A_tensor(bdb, is_bdb_tail)), @@ -2017,9 +2234,7 @@ void jit_brgemm_kernel_t::dot_product(Vmm v1, Vmm v2, Vmm v3) { if (brg.dt_a == data_type::s8 && isa_has_s8s8(brg.isa_impl)) vpdpbssd(v1, v3, v2); else if (brg.has_int8_vnni) - vpdpbusd(v1, v3, v2, - is_superset(brg.isa_impl, avx512_core) ? EvexEncoding - : VexEncoding); + vpdpbusd(v1, v3, v2, get_encoding()); else { vpmaddubsw(int8_dot_product_temp(), v3, v2); vpmaddwd(int8_dot_product_temp(), int8_dot_product_temp(), @@ -2030,12 +2245,13 @@ void jit_brgemm_kernel_t::dot_product(Vmm v1, Vmm v2, Vmm v3) { } template -void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b, - int bd_e, int bd_block, int ld_block2, bool is_ld_tail, int vpad) { +void jit_brgemm_kernel_t::compute_int8_compensation(dim_t rd_loop, + dim_t bd_b, dim_t bd_e, dim_t bd_block, dim_t ld_block2, + bool is_ld_tail, dim_t vpad) { assert(brg.is_int8); auto compensation_padding = [this, ld_block2](Vmm vmm_load, Vmm vmm_tmp, - int ld, int bd_b, int bd_e) { + dim_t ld, dim_t bd_b, dim_t bd_e) { // req_cal_comp_pads -> only calculate compensation along with // computation and do not use pre-calculated compensation. 
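         // (i.e. the padding term is computed on the fly from the current B
         // block: for s8s8 it is a dot product of B with the broadcast 128
         // input shift, and for a-side zero points a ones-vector dot product
         // scaled by zp_a, instead of being read from a precomputed buffer)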
// Calculate comp padding as: @@ -2046,7 +2262,7 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b, dot_product(vmm_tmp, vmm_load, vmm_inp_shift()); } - for (int bd = bd_b; bd < bd_e; bd++) { + for (dim_t bd = bd_b; bd < bd_e; bd++) { auto vmm = accm(ld_block2, bd, ld); if (brg.req_cal_comp_pads) { uni_vpsubd(vmm, vmm, vmm_tmp); @@ -2061,7 +2277,7 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b, dot_product(vmm_tmp, vmm_load, vmm_one_bytes()); uni_vpmulld(vmm_tmp, vmm_tmp, vmm_zp_a_shift()); - for (int bd = bd_b; bd < bd_e; bd++) { + for (dim_t bd = bd_b; bd < bd_e; bd++) { auto vmm = accm(ld_block2, bd, ld); if (brg.req_cal_comp_pads) { uni_vpsubd(vmm, vmm, vmm_tmp); @@ -2072,7 +2288,7 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b, } }; - if (n_bcast_1_load && brg.zp_type_a != brgemm_broadcast_t::none) { + if (need_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) { mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); const auto reg32_scratch = reg_zp_a_input_shift.cvt32(); mov(reg32_scratch, 0x1010101); @@ -2082,16 +2298,15 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b, mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); } - for_(int rd = 0; rd < rd_loop; rd += brg.rd_step) - for (int ld = 0; ld < ld_block2; ++ld) { + for_(dim_t rd = 0; rd < rd_loop; rd += brg.rd_step) + for (dim_t ld = 0; ld < ld_block2; ++ld) { const auto addr = ptr[reg_aux_B + B_offset(ld, rd)]; const bool is_tail = is_ld_tail && ld + 1 == ld_block2; if (IMPLICATION(is_tail, is_superset(brg.isa_impl, avx512_core))) { auto vmm_store = vmm_mask(load(), is_tail, false, ld_tail_mask); uni_vmovups(vmm_store, addr); } else { - load_bytes( - load(), addr, brg.typesize_B * brg.ldb_tail * brg.ld_step); + load_bytes(load(), addr, ldb_B_offset(0, true)); } if (brg.req_cal_comp_pads) { @@ -2105,14 +2320,11 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b, } template -void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, - int ld_block2, bool is_rd_tail, bool is_ld_tail, int vpad, - int rows_for_rd_tail) { - assert(!brg.is_fp8_via_convert() && "No non-AMX path for fp8"); - - MAYBE_UNUSED(bd_block2); - int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block; - const auto bd_b = nstl::max(0, vpad); +void jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(dim_t bd_block2, + bool is_bdb_tail, dim_t ld_block2, bool is_rd_tail, bool is_ld_tail, + dim_t vpad, dim_t rows_for_rd_tail) { + dim_t bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block; + const auto bd_b = nstl::max((dim_t)0, vpad); const auto bd_e = nstl::min(bd_block, bd_block + vpad); const auto is_valid_bd = need_comp_pads && vpad != 0 ? 
bd_b <= bd_e : bd_b < bd_e; @@ -2132,9 +2344,13 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, } else rd_loop = brg.rd_block; + bool maybe_load_bytes = (rows_for_rd_tail > 0 || brg.brgattr.wary_A_k_tail_read) + && is_rd_tail && rd_tail_size != 0 && (brg.is_bf16 || brg.is_int8); + auto broadcast = [this, rd_tail_size](Vmm v1, size_t offset, bool is_tail, data_type_t dt) { if (is_tail) { + uni_vpxor(v1, v1, v1); Xmm xmm_tmp = Xmm(v1.getIdx()); load_bytes( xmm_tmp, reg_aux_A, offset, rd_tail_size * brg.typesize_A); @@ -2160,6 +2376,285 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift()); }; + static const int8_t mask_low_half[64] = { + 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, + 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, + 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, + 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F + }; + + mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); + mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop); + + auto reg_ptr = reg_bdb_loop; + auto vmm_mask_low_half = Vmm(isa_num_vregs(brg.isa_impl) - 1); + mov(reg_ptr, (size_t)mask_low_half); + uni_vmovups(vmm_mask_low_half, ptr[reg_ptr]); + + const int vec_size = vreg_traits_t::vlen; + auto accums_stack_space = bd_e * ld_block2 * vec_size; + sub(rsp, accums_stack_space); + for (int bd = bd_b; bd < bd_e; bd++) { + for (int ld = 0; ld < ld_block2; ld++) { + auto vmm_accm = accm(ld_block2, bd, ld); + vmovups(ptr[rsp + (bd * ld_block2 + ld) * vec_size], vmm_accm); + + uni_vxorps(vmm_accm, vmm_accm, vmm_accm); + } + } + + for (int rd = 0; rd < rd_loop; rd += brg.rd_step) { + int prefetch_count_B = 0; + for (int ld = 0; ld < ld_block2; ld++) { + const auto addr = ptr[reg_aux_B + B_offset(ld, rd)]; + const Vmm vmm_load = vmm_mask(load(ld), is_ld_tail, false, ld_tail_mask); + if (brg.dt_b == data_type::u8) { + uni_vmovups(vmm_load, addr); + } else if (brg.dt_b == data_type::u4) { + uni_vmovups(vmm_load, addr); + if (rd % 8 == 0) + uni_vpsrld(vmm_load, vmm_load, 4); + uni_vandps(vmm_load, vmm_load, vmm_mask_low_half); + } else { + assert(!"unsupported combination"); + } + } + + bool have_to_load_bytes + = maybe_load_bytes && (rd == rd_loop - brg.rd_step); + + auto rows_by_load_bytes = have_to_load_bytes ? rows_for_rd_tail : 0; + for (int bd = bd_b; bd < bd_e; bd++) { + if (!is_emdbd) { + const auto bd_by_load_bytes + = (bd >= bd_e - rows_by_load_bytes + || brg.brgattr.wary_A_k_tail_read); + broadcast(bcst(), A_offset(bd, rd), + have_to_load_bytes && bd_by_load_bytes, brg.dt_a); + } + if (prefetch_count_B < ld_block2) { + int typesize_scale = brg.dt_b == data_type::u4 ? 2 : 1; + prefetcht0(ptr[reg_aux_B + B_offset(prefetch_count_B++, rd) + + brg.LDB * brg.rd_block * brg.typesize_B / typesize_scale]); + } + for (int ld = 0; ld < ld_block2; ld++) { + auto vmm = accm(ld_block2, bd, ld); + vpdpbusd(vmm, load(ld), bcst(), is_superset(brg.isa_impl, avx512_core) ? 
EvexEncoding : VexEncoding); + } + } + } + + auto vmm_zero_point = [&](int ld) { + return load(ld); + }; + + auto reg_local_wei_zp = reg_ldb_loop; + auto reg_local_src_grouped_sum = reg_bdb_loop; + auto vmm_tmp = Vmm(isa_num_vregs(brg.isa_impl) - 1); + auto vmm_src_grouped_sum = bcst(); + + if (brg.with_wei_decomp_zero_points) { + mov(reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_ + accums_stack_space]); + if (brg.wei_decomp_zero_points_stride == 0) { + Vmm vmm_zp = vmm_zero_point(0); + auto reg_ptr_32 = Reg32(reg_ptr.getIdx()); + movzx(reg_ptr_32, ptr[reg_local_wei_zp]); + uni_vmovq(Xmm(vmm_zp.getIdx()), reg_ptr); + uni_vbroadcastss(vmm_zp, Xmm(vmm_zp.getIdx())); + } + + mov(reg_local_src_grouped_sum, ptr[rsp + reg_aux2_src_grouped_sum_offs_ + accums_stack_space]); + for (int bd = bd_b; bd < bd_e; bd++) { + uni_vbroadcastss(vmm_src_grouped_sum, ptr[reg_local_src_grouped_sum + bd * brg.src_grouped_sum_stride * sizeof(int32_t)]); + for (int ld = 0; ld < ld_block2; ld++) { + Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld); + if (bd == bd_b && brg.wei_decomp_zero_points_stride != 0) { + uni_vpmovzxbd(vmm_zp, ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_zero_points_dt)]); + } + + auto vmm_accm = accm(ld_block2, bd, ld); + uni_vpmulld(vmm_tmp, vmm_src_grouped_sum, vmm_zp); + uni_vpsubd(vmm_accm, vmm_accm, vmm_tmp); + } + } + } + + auto wei_scale = [&](int ld) { + return load(ld); + }; + + auto reg_local_src_scales = reg_ldb_loop; + auto reg_local_wei_scales = reg_bdb_loop; + auto vmm_src_scales = bcst(); + + mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_ + accums_stack_space]); + mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]); + if (brg.wei_decomp_scales_stride == 0) { + uni_vbroadcastss(wei_scale(0), ptr[reg_local_wei_scales]); + } + + for (int bd = bd_b; bd < bd_e; bd++) { + uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]); + for (int ld = 0; ld < ld_block2; ld++) { + auto vmm_wei_scale = brg.wei_decomp_scales_stride == 0 ? wei_scale(0) : wei_scale(ld); + if (bd == bd_b && brg.wei_decomp_scales_stride != 0) { + uni_vmovups(vmm_wei_scale, ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]); + } + + auto vmm_accm = accm(ld_block2, bd, ld); + uni_vcvtdq2ps(vmm_accm, vmm_accm); + uni_vmulps(vmm_tmp, vmm_accm, vmm_src_scales); + uni_vmovups(vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]); + uni_vfmadd231ps(vmm_accm, vmm_tmp, vmm_wei_scale); + } + } + + add(rsp, accums_stack_space); + mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]); + mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); + + return; +} + +template +void jit_brgemm_kernel_t::gemm_microkernel(dim_t bd_block2, + bool is_bdb_tail, dim_t ld_block2, bool is_rd_tail, bool is_ld_tail, + dim_t vpad, dim_t rows_for_rd_tail) { + assert(!brg.is_fp8_via_convert() && "No non-AMX path for fp8"); + MAYBE_UNUSED(bd_block2); + + if (brg.with_src_dyn_quant) { + gemm_microkernel_dyn_quant(bd_block2, is_bdb_tail, ld_block2, is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail); + return; + } + + dim_t bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block; + const auto bd_b = nstl::max(dim_t(0), vpad); + const auto bd_e = nstl::min(bd_block, bd_block + vpad); + const auto is_valid_bd + = need_comp_pads && vpad != 0 ? 
bd_b <= bd_e : bd_b < bd_e; + if (!is_valid_bd) return; + + bool is_emdbd = brg.embd_bcst; + + dim_t rd_loop = 0, rd_tail_size = 0; + if (is_rd_tail) { + if (brg.is_bf16 || brg.is_int8) { + rd_tail_size = brg.rdb_tail % brg.rd_step; + rd_loop = (rd_tail_size != 0) + ? ((brg.rdb_tail / brg.rd_step) + 1) * brg.rd_step + : brg.rdb_tail; + } else + rd_loop = brg.rdb_tail; + } else + rd_loop = brg.rd_block; + + if (brg.req_s8s8_compensation) { + mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); + mov(reg_s8_input_shift, 128); + uni_vpbroadcastb(vmm_inp_shift(), reg_s8_input_shift.cvt8()); + mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); + } + + auto broadcast_A = [this, rd_tail_size, is_rd_tail, rd_loop, + rows_for_rd_tail, + bd_e](Vmm vmm_bcast, dim_t bd, dim_t rd) { + const auto offset = A_offset(bd, rd); + const auto dt = brg.dt_a; + const bool maybe_load_bytes + = (rows_for_rd_tail > 0 || brg.brgattr.wary_A_k_tail_read) + && is_rd_tail && rd_tail_size != 0 + && (brg.is_bf16 || brg.is_int8); + const bool have_to_load_bytes + = maybe_load_bytes && (rd == rd_loop - brg.rd_step); + const auto rows_by_load_bytes + = have_to_load_bytes ? rows_for_rd_tail : 0; + const auto bd_by_load_bytes = (bd >= bd_e - rows_by_load_bytes + || brg.brgattr.wary_A_k_tail_read); + const auto is_tail = have_to_load_bytes && bd_by_load_bytes; + if (is_tail) { + Xmm xmm_tmp = Xmm(vmm_bcast.getIdx()); + load_bytes( + xmm_tmp, reg_aux_A, offset, rd_tail_size * brg.typesize_A); + uni_vpbroadcastd(vmm_bcast, xmm_tmp); + } else { + if (dt == data_type::f32) { + uni_vbroadcastss(vmm_bcast, ptr[reg_aux_A + offset]); + } else if (dt == data_type::bf16) { + if (brg.isa_impl == avx2_vnni_2) + vbcstnebf162ps(vmm_bcast, ptr[reg_aux_A + offset]); + else + uni_vpbroadcastd(vmm_bcast, ptr[reg_aux_A + offset]); + } else if (one_of(dt, data_type::s8, data_type::u8)) { + uni_vpbroadcastd(vmm_bcast, ptr[reg_aux_A + offset]); + } else if (dt == data_type::f16) { + if (brg.isa_impl == avx2_vnni_2) { + vbcstnesh2ps(vmm_bcast, ptr[reg_aux_A + offset]); + } else if (is_superset(brg.isa_impl, avx512_core_fp16)) { + // Broadcast is not supported for legacy f16-conversions. + vcvtph2psx(vmm_bcast, ptr_b[reg_aux_A + offset]); + } + } + } + + if (brg.req_s8s8_compensation) + uni_vpaddb(vmm_bcast, vmm_bcast, vmm_inp_shift()); + }; + + auto load_B = [this, is_ld_tail](dim_t vmm_load_idx, dim_t rd, dim_t ld) { + const Vmm vmm_load + = vmm_mask(load(vmm_load_idx), is_ld_tail, false, ld_tail_mask); + const auto addr = ptr[reg_aux_B + B_offset(ld, rd)]; + // Note: Assuming the tails are properly padded/blocked for + // avx2_vnni_2 with xf16 data type, as the B matrix is generally + // at least double-blocked. 
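+        // For f16 B without AMX or native VNNI support, the even/odd
+        // halves of the VNNI-packed block are separated with a precomputed
+        // vpermw permutation (f16_perm_even/odd_vreg) and then upconverted
+        // with vcvtph2psx; on avx2_vnni_2 the paired
+        // vcvtneeph2ps/vcvtneoph2ps instructions do both steps at once.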
+ if (brg.dt_b == data_type::f16) { + if (brg.isa_impl == avx2_vnni_2) { + if (rd % 2 == 0) + vcvtneeph2ps(vmm_load, addr); + else + vcvtneoph2ps(vmm_load, addr); + } else if (brg.is_f16_b_non_amx_vnni()) { + const auto actual_B_offset = B_offset(ld, utils::rnd_dn(rd, 2)); + const auto vnni_addr = ptr[reg_aux_B + actual_B_offset]; + vmovups(vmm_load, vnni_addr); + if (rd % 2 == 0) + vpermw(vmm_load, f16_perm_even_vreg(), vmm_load); + else + vpermw(vmm_load, f16_perm_odd_vreg(), vmm_load); + vcvtph2psx(vmm_load, Vmm_lower_t(vmm_load.getIdx())); + } else if (is_ld_tail && !is_superset(brg.isa_impl, avx512_core)) { + load_bytes(vmm_load, addr, ldb_B_offset(0, true)); + vcvtph2ps(vmm_load, Xmm(vmm_load.getIdx())); + } else { + uni_vcvtph2psx(vmm_load, addr); + } + } else if (brg.dt_b == data_type::bf16) { + if (brg.isa_impl == avx2_vnni_2) { + if (rd % 2 == 0) + vcvtneebf162ps(vmm_load, addr); + else + vcvtneobf162ps(vmm_load, addr); + } else if (utils::one_of(brg.isa_impl, avx512_core, avx2) && brg.is_f32) { + // Upconvert: load 16 bits and move them 16 bits left. + uni_vpmovzxwd(vmm_load, addr); + uni_vpslld(vmm_load, vmm_load, 16); + } else if (is_ld_tail && !is_superset(brg.isa_impl, avx512_core)) { + load_bytes(vmm_load, addr, ldb_B_offset(0, true)); + } else { + uni_vmovups(vmm_load, addr); + } + } else if (is_ld_tail) { + if (is_superset(brg.isa_impl, avx512_core)) { + uni_vmovups(vmm_load, addr); + } else { + load_bytes(vmm_load, addr, ldb_B_offset(0, true)); + } + } else { + uni_vmovups(vmm_load, addr); + } + }; + const bool comp_vpad = vpad != 0 && (brg.req_s8s8_compensation || brg.zp_type_a != brgemm_broadcast_t::none); @@ -2167,52 +2662,14 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, compute_int8_compensation( rd_loop, bd_b, bd_e, bd_block, ld_block2, is_ld_tail, vpad); - bool maybe_load_bytes = (rows_for_rd_tail > 0 || brg.brgattr.wary_tail_read) - && is_rd_tail && rd_tail_size != 0 && (brg.is_bf16 || brg.is_int8); - if (n_bcast_1_load) { - for (int rd = 0; rd < rd_loop; rd += brg.rd_step) { - bool have_to_load_bytes - = maybe_load_bytes && (rd == rd_loop - brg.rd_step); - - auto rows_by_load_bytes = have_to_load_bytes ? rows_for_rd_tail : 0; - for (int bd = bd_b; bd < bd_e && !is_emdbd; bd++) { - const auto bd_by_load_bytes = (bd >= bd_e - rows_by_load_bytes - || brg.brgattr.wary_tail_read); - broadcast(bcst(bd), A_offset(bd, rd), - have_to_load_bytes && bd_by_load_bytes, brg.dt_a); - } - for (int ld = 0; ld < ld_block2; ld++) { - const auto addr = ptr[reg_aux_B + B_offset(ld, rd)]; - const Vmm vmm_load - = vmm_mask(load(), is_ld_tail, false, ld_tail_mask); - // Note: Assuming the tails are properly padded/blocked for - // avx2_vnni_2 with xf16 data type, as the B matrix is generally - // at least double-blocked. 
- if (brg.dt_b == data_type::f16) { - if (brg.isa_impl == avx2_vnni_2) { - if (rd % 2 == 0) - vcvtneeph2ps(vmm_load, addr); - else - vcvtneoph2ps(vmm_load, addr); - } else - vcvtph2psx(vmm_load, addr); - } else if (brg.dt_b == data_type::bf16 - && brg.isa_impl == avx2_vnni_2) { - if (rd % 2 == 0) - vcvtneebf162ps(vmm_load, addr); - else - vcvtneobf162ps(vmm_load, addr); - } else if (is_ld_tail) { - if (is_superset(brg.isa_impl, avx512_core)) { - uni_vmovups(vmm_load, addr); - } else { - load_bytes(vmm_load, addr, - brg.typesize_B * brg.ldb_tail * brg.ld_step); - } - } else { - uni_vmovups(vmm_load, addr); - } - for (int bd = bd_b; bd < bd_e; bd++) { + for (dim_t rd = 0; rd < rd_loop; rd += brg.rd_step) { + if (brg.n_bcast_1_load) { + for (dim_t bd = bd_b; bd < bd_e && !is_emdbd; bd++) + broadcast_A(bcst(bd), bd, rd); + for (dim_t ld = 0; ld < ld_block2; ld++) { + load_B(0, rd, ld); + + for (dim_t bd = bd_b; bd < bd_e; bd++) { auto vmm = accm(ld_block2, bd, ld); if (is_emdbd) uni_vfmadd231ps(vmm, load(), @@ -2221,61 +2678,337 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, dot_product(vmm, load(), bcst(bd)); } } - } - } else { - for (int rd = 0; rd < rd_loop; rd += brg.rd_step) { - int prefetch_count_B = 0; - for (int ld = 0; ld < ld_block2; ld++) { - const auto addr = ptr[reg_aux_B + B_offset(ld, rd)]; - const Vmm vmm_load - = vmm_mask(load(ld), is_ld_tail, false, ld_tail_mask); - // Note: Assuming the tails are properly padded/blocked for - // avx2_vnni_2, as the B matrix is generally - // at least double-blocked. - if (brg.dt_b == data_type::f16) { - if (brg.isa_impl == avx2_vnni_2) { - if (rd % 2 == 0) - vcvtneeph2ps(vmm_load, addr); + } else { + if (brg.with_wei_decomp) { + auto reg_local_wei_scales = reg_bdb_loop; + auto reg_local_wei_zp = reg_ldb_loop; + auto reg_ptr = reg_local_wei_zp; + + auto accm_tmp = [&](int ld_block, int bd, int ld) { + int idx = max_effective_vregs - 1 - 2 * (brg.ld_block2 * brg.bd_block) - ld; + return Vmm(idx); + }; + + auto load_zero_points = [&](Vmm vmm_zp, Xbyak::Address addr) { + if (brg.wei_decomp_zero_points_stride == 0) { + switch (brg.wei_decomp_zero_points_dt) { + case data_type::f32: { + uni_vbroadcastss(vmm_zp, addr); + break; + } + case data_type::u8: { + auto xmm_zp = Xmm(vmm_zp.getIdx()); + auto reg_ptr_32 = Reg32(reg_ptr.getIdx()); + movzx(reg_ptr_32, addr); + uni_vmovq(xmm_zp, reg_ptr); + uni_vcvtdq2ps(xmm_zp, xmm_zp); + uni_vbroadcastss(vmm_zp, xmm_zp); + break; + } + default: assert(!"unsupported data type"); + } + } else { + switch (brg.wei_decomp_zero_points_dt) { + case data_type::f32: { + uni_vmovups(vmm_zp, addr); + break; + } + case data_type::u8: { + uni_vpmovzxbd(vmm_zp, addr); + uni_vcvtdq2ps(vmm_zp, vmm_zp); + break; + } + default: assert(!"unsupported data type"); + } + } + }; + + auto load_scales = [&](Vmm vmm_scales, Xbyak::Address addr) { + if (brg.wei_decomp_scales_stride == 0) { + switch (brg.wei_decomp_scales_dt) { + case data_type::f32: { + uni_vbroadcastss(vmm_scales, addr); + break; + } + case data_type::e8m0: { + auto xmm_scales = Xmm(vmm_scales.getIdx()); + auto reg_ptr_32 = Reg32(reg_ptr.getIdx()); + movzx(reg_ptr_32, addr); + uni_vmovq(xmm_scales, reg_ptr); + uni_vpslld(xmm_scales, xmm_scales, 23); + uni_vbroadcastss(vmm_scales, xmm_scales); + break; + } + default: assert(!"unsupported data type"); + } + } else { + switch (brg.wei_decomp_scales_dt) { + case data_type::f32: { + uni_vmovups(vmm_scales, addr); + break; + } + case data_type::e8m0: { + uni_vpmovzxbd(vmm_scales, addr); + 
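+                        // e8m0 stores just an 8-bit exponent: shifting it
+                        // into the f32 exponent field (bits 30..23) turns
+                        // each byte e into the scale 2^(e - 127) directly,
+                        // with no multiply required.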
uni_vpslld(vmm_scales, vmm_scales, 23); + break; + } + default: assert(!"unsupported data type"); + } + } + }; + + mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); + mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop); + + auto vmm_zero_points = Vmm(isa_num_vregs(brg.isa_impl) - 1); + auto vmm_mask8 = Vmm(isa_num_vregs(brg.isa_impl) - 1); + auto vmm_mask7 = Vmm(isa_num_vregs(brg.isa_impl) - 2); + auto vmm_lookup = Vmm(isa_num_vregs(brg.isa_impl) - 1); + auto vmm_lookup_low = Vmm(isa_num_vregs(brg.isa_impl) - 3); + auto vmm_lookup_high = Vmm(isa_num_vregs(brg.isa_impl) - 4); + auto vmm_mask_signed_bit = Vmm(isa_num_vregs(brg.isa_impl) - 2); + if (brg.dt_b == data_type::nf4) { + static const float lookup[16] = { + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0}; + + static const int32_t mask8[16] = { + 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8 + }; + static const int32_t mask7[16] = { + 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7 + }; + + if (brg.isa_impl == avx2) { + mov(reg_ptr, (size_t)lookup); + uni_vmovups(vmm_lookup_low, ptr[reg_ptr]); + mov(reg_ptr, (size_t)lookup); + uni_vmovups(vmm_lookup_high, ptr[reg_ptr + 8 * sizeof(float)]); + mov(reg_ptr, (size_t)mask8); + uni_vmovups(vmm_mask8, ptr[reg_ptr]); + mov(reg_ptr, (size_t)mask7); + uni_vmovups(vmm_mask7, ptr[reg_ptr]); + if (brg.wei_decomp_zero_points_stride == 0) + vmm_zero_points = Vmm(isa_num_vregs(brg.isa_impl) - 6); else - vcvtneoph2ps(vmm_load, addr); + vmm_zero_points = Vmm(isa_num_vregs(brg.isa_impl) - 5); } else { - vcvtph2psx(vmm_load, addr); + mov(reg_ptr, (size_t)lookup); + uni_vmovups(vmm_lookup, ptr[reg_ptr]); + vmm_zero_points = Vmm(isa_num_vregs(brg.isa_impl) - 2); } - } else if (brg.dt_b == data_type::bf16 - && brg.isa_impl == avx2_vnni_2) { - if (rd % 2 == 0) - vcvtneebf162ps(vmm_load, addr); - else - vcvtneobf162ps(vmm_load, addr); - } else if (is_ld_tail) { - if (is_superset(brg.isa_impl, avx512_core)) { - uni_vmovups(vmm_load, addr); + } else if (brg.dt_b == data_type::f4_e2m1) { + static const float lookup[16] = { + 0.0f, 0.5f, + 1.0f, 1.5f, + 2.0f, 3.0f, + 4.0f, 6.0f, + -0.0f, -0.5f, + -1.0f, -1.5f, + -2.0f, -3.0f, + -4.0f, -6.0f + }; + + static const uint32_t mask_signed_bit[8] = { + 0x80000000, 0x80000000, 0x80000000, 0x80000000, + 0x80000000, 0x80000000, 0x80000000, 0x80000000, + }; + + if (brg.isa_impl == avx2) { + mov(reg_ptr, (size_t)lookup); + uni_vmovups(vmm_lookup, ptr[reg_ptr]); + mov(reg_ptr, (size_t)mask_signed_bit); + uni_vmovups(vmm_mask_signed_bit, ptr[reg_ptr]); + vmm_zero_points = Vmm(isa_num_vregs(brg.isa_impl) - 3); } else { - load_bytes(vmm_load, addr, - brg.typesize_B * brg.ldb_tail * brg.ld_step); + mov(reg_ptr, (size_t)lookup); + uni_vmovups(vmm_lookup, ptr[reg_ptr]); + vmm_zero_points = Vmm(isa_num_vregs(brg.isa_impl) - 2); } - } else { - uni_vmovups(vmm_load, addr); } - } - bool have_to_load_bytes - = maybe_load_bytes && (rd == rd_loop - brg.rd_step); + mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]); + mov(reg_local_wei_zp, ptr[rsp + reg_aux2_wei_zero_points_offs_]); - auto rows_by_load_bytes = have_to_load_bytes ? 
rows_for_rd_tail : 0; - for (int bd = bd_b; bd < bd_e; bd++) { - if (!is_emdbd) { - const auto bd_by_load_bytes - = (bd >= bd_e - rows_by_load_bytes - || brg.brgattr.wary_tail_read); - broadcast(bcst(), A_offset(bd, rd), - have_to_load_bytes && bd_by_load_bytes, brg.dt_a); + if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride == 0) { + load_zero_points(vmm_zero_points, ptr[reg_local_wei_zp]); + } + + for (int rd = 0; rd < rd_loop; rd += brg.rd_step) { + int prefetch_count_B = 0; + for (int ld = 0; ld < ld_block2; ld++) { + const auto addr = ptr[reg_aux_B + B_offset(ld, rd)]; + const Vmm vmm_load = vmm_mask(load(ld), is_ld_tail, false, ld_tail_mask); + if (brg.dt_b == data_type::u8) { + uni_vpmovzxbd(vmm_load, addr); + uni_vcvtdq2ps(vmm_load, vmm_load); + } else if (brg.dt_b == data_type::s8) { + uni_vpmovsxbd(vmm_load, addr); + uni_vcvtdq2ps(vmm_load, vmm_load); + } else if (brg.dt_b == data_type::u4) { + uni_vpmovzxbd(vmm_load, addr); + if (rd % 2 == 0) { + uni_vpsrld(vmm_load, vmm_load, 4); + } else { + uni_vpslld(vmm_load, vmm_load, 28); + uni_vpsrld(vmm_load, vmm_load, 28); + } + uni_vcvtdq2ps(vmm_load, vmm_load); + } else if (brg.dt_b == data_type::s4) { + if (rd % 2 == 0) { + uni_vpmovsxbd(vmm_load, addr); + vpsrad(vmm_load, vmm_load, 4); + } else { + uni_vpmovsxbd(vmm_load, addr); + uni_vpslld(vmm_load, vmm_load, 28); + vpsrad(vmm_load, vmm_load, 28); + } + uni_vcvtdq2ps(vmm_load, vmm_load); + } else if (brg.dt_b == data_type::nf4) { + uni_vpmovzxbd(vmm_load, addr); + if (rd % 2 == 0) { + uni_vpsrld(vmm_load, vmm_load, 4); + } else { + uni_vpslld(vmm_load, vmm_load, 28); + uni_vpsrld(vmm_load, vmm_load, 28); + } + + if (brg.isa_impl == avx2) { + auto res = bcst(); + auto mask = Vmm(isa_num_vregs(brg.isa_impl) - 5); + vpcmpgtd(mask, vmm_load, vmm_mask7); + vpermd(res, vmm_load, vmm_lookup_low); + vpsubd(vmm_load, vmm_load, vmm_mask8); + vpermd(vmm_load, vmm_load, vmm_lookup_high); + vblendvps(vmm_load, res, vmm_load, mask); + } else { + vpermd(vmm_load, vmm_load, vmm_lookup); + } + } else if (brg.dt_b == data_type::f4_e2m1) { + if (brg.isa_impl == avx2) { + uni_vpmovsxbd(vmm_load, addr); + if (rd % 2 == 0) { + vpsrad(vmm_load, vmm_load, 4); + } else { + uni_vpslld(vmm_load, vmm_load, 28); + vpsrad(vmm_load, vmm_load, 28); + } + auto mask = bcst(); + uni_vpand(mask, vmm_load, vmm_mask_signed_bit); + vpermd(vmm_load, vmm_load, vmm_lookup); + uni_vorps(vmm_load, vmm_load, mask); + } else { + uni_vpmovzxbd(vmm_load, addr); + if (rd % 2 == 0) { + uni_vpsrld(vmm_load, vmm_load, 4); + } else { + uni_vpslld(vmm_load, vmm_load, 28); + uni_vpsrld(vmm_load, vmm_load, 28); + } + vpermd(vmm_load, vmm_load, vmm_lookup); + } + } else { + assert(!"unsupported combination"); + } + + if (brg.with_wei_decomp_zero_points) { + if (brg.wei_decomp_zero_points_stride == 0) { + uni_vsubps(vmm_load, vmm_load, vmm_zero_points); + } else { + load_zero_points(bcst(), ptr[reg_local_wei_zp + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_zero_points_dt)]); + uni_vsubps(vmm_load, vmm_load, bcst()); + } + } + + if (brg.with_wei_decomp_scales && brg.bd_block != 1) { + if (brg.wei_decomp_scales_stride == 0) { + load_scales(bcst(), ptr[reg_local_wei_scales]); + } else { + load_scales(bcst(), ptr[reg_local_wei_scales + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_scales_dt)]); + } + uni_vmulps(vmm_load, vmm_load, bcst()); + } + } + + for (int bd = bd_b; bd < bd_e; bd++) { + if (!is_emdbd) { + if (brg.dt_a == data_type::bf16) { + vpbroadcastw(bcst(), ptr[reg_aux_A + A_offset(bd, 
rd)]); + uni_vpmovzxwd(bcst(), bcst()); + uni_vpslld(bcst(), bcst(), 16); + } else { + broadcast_A(bcst(bd), bd, rd); + } + } + if (prefetch_count_B < ld_block2) { + prefetcht0(ptr[reg_aux_B + B_offset(prefetch_count_B++, rd) + + brg.LDB * brg.rd_block * brg.typesize_B]); + } + for (int ld = 0; ld < ld_block2; ld++) { + auto vmm = brg.bd_block != 1 ? accm(ld_block2, bd, ld) + : accm_tmp(ld_block2, bd, ld); + if (brg.bd_block == 1 && rd == 0) { + if (is_emdbd) + uni_vmulps(vmm, load(ld), ptr_b[reg_aux_A + A_offset(bd, rd)]); + else + uni_vmulps(vmm, load(ld), bcst()); + } else { + if (is_emdbd) + uni_vfmadd231ps(vmm, load(ld), ptr_b[reg_aux_A + A_offset(bd, rd)]); + else + uni_vfmadd231ps(vmm, load(ld), bcst()); + } + } + } + } + + if (brg.with_wei_decomp_scales && brg.bd_block == 1) { + for (int ld = 0; ld < ld_block2; ld++) { + auto vmm_accm_tmp = accm_tmp(ld_block2, 0, ld); + auto vmm_accm = accm(ld_block2, 0, ld); + if (brg.wei_decomp_scales_stride == 0) { + load_scales(bcst(), ptr[reg_local_wei_scales]); + } else { + load_scales(bcst(), ptr[reg_local_wei_scales + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_scales_dt)]); + } + uni_vfmadd231ps(vmm_accm, vmm_accm_tmp, bcst()); + } } + + mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]); + mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); + + return; + } + + dim_t prefetch_count_B = 0; + for (dim_t ld = 0; ld < ld_block2; ld++) { + load_B(ld, rd, ld); + } + + for (dim_t bd = bd_b; bd < bd_e; bd++) { + if (!is_emdbd) broadcast_A(bcst(), bd, rd); if (prefetch_count_B < ld_block2) { prefetcht0(ptr[reg_aux_B + B_offset(prefetch_count_B++, rd) + brg.LDB * brg.rd_block * brg.typesize_B]); } - for (int ld = 0; ld < ld_block2; ld++) { + for (dim_t ld = 0; ld < ld_block2; ld++) { auto vmm = accm(ld_block2, bd, ld); if (is_emdbd) uni_vfmadd231ps(vmm, load(ld), @@ -2289,21 +3022,118 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, } template -void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, - int ld_block2, int ldb_loop_length, bool is_reg_tail, bool is_ld_tail, - bool check_top_vpad, bool check_bottom_vpad, int rows_for_rd_tail, +void jit_brgemm_kernel_t::ldb_loop(dim_t bd_block2, bool is_bdb_tail, + dim_t ld_block2, dim_t ldb_loop_length, bool is_reg_tail, + bool is_ld_tail, bool first_bdb, bool last_bdb, dim_t rows_for_rd_tail, bool skip_accumulation) { + auto ic_group_shift_generic = [&]() { + if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || brg.wei_decomp_zero_points_stride != 0)) + || brg.with_src_dyn_quant) { + auto reg_local_ic = reg_aux_D; + auto reg_local_wei_params = reg_bdb_loop; + auto reg_local_ic_group = reg_ldb_loop; + + auto ic_group_shift = [&](int src_offs, int dst_offs, int group_size, int stride) { + mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]); + mov(reg_local_ic_group, group_size); + xor_(rdx, rdx); + idiv(reg_local_ic_group); + imul(reg_local_ic, reg_local_ic, stride); + + mov(reg_local_wei_params, ptr[rsp + src_offs]); + add(reg_local_wei_params, reg_local_ic); + mov(ptr[rsp + dst_offs], reg_local_wei_params); + }; + + mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); + mov(ptr[rsp + reg_aux2_D_offs_], reg_aux_D); + mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop); + mov(ptr[rsp + reg_reg_a_offset_offs_], reg_a_offset); // preserve rdx for idiv + + if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) { + ic_group_shift(reg_aux_wei_scales_offs_, reg_aux2_wei_scales_offs_, + brg.wei_decomp_scales_group_size, 
brg.wei_decomp_scales_stride * types::data_type_size(brg.wei_decomp_scales_dt)); + } + + if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) { + ic_group_shift(reg_aux_wei_zero_points_offs_, reg_aux2_wei_zero_points_offs_, + brg.wei_decomp_zero_points_group_size, brg.wei_decomp_zero_points_stride * types::data_type_size(brg.wei_decomp_zero_points_dt)); + } + + if (brg.with_src_dyn_quant) { + ic_group_shift(reg_aux_src_scales_offs_, reg_aux2_src_scales_offs_, + brg.src_scales_group_size, sizeof(float)); + + if (brg.with_wei_decomp_zero_points) { + ic_group_shift(reg_aux_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_, + brg.src_sum_group_size, sizeof(int32_t)); + } + } + + mov(reg_local_ic, ptr[rsp + reg_aux_ic_offs_]); + add(reg_local_ic, brg.rd_block); + mov(ptr[rsp + reg_aux_ic_offs_], reg_local_ic); + + mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); + mov(reg_aux_D, ptr[rsp + reg_aux2_D_offs_]); + mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]); + mov(reg_a_offset, ptr[rsp + reg_reg_a_offset_offs_]); + } + }; + + auto ic_group_shift_opt = [&](int rb) { + if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || brg.wei_decomp_zero_points_stride != 0)) + || brg.with_src_dyn_quant) { + mov(ptr[rsp + reg_bdb_loop_offs_], reg_rdb_loop); + auto reg_ptr = reg_rdb_loop; + + auto ic_group_shift = [&](int src_offs, int dst_offs, int group_size, int stride) { + if ((rb + 1) * brg.rd_block % group_size == 0) { + mov(reg_ptr, ptr[rsp + src_offs]); + add(reg_ptr, stride); + mov(ptr[rsp + dst_offs], reg_ptr); + } + }; + + if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) { + ic_group_shift(reg_aux2_wei_scales_offs_, reg_aux2_wei_scales_offs_, + brg.wei_decomp_scales_group_size, brg.wei_decomp_scales_stride * types::data_type_size(brg.wei_decomp_scales_dt)); + } + + if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) { + ic_group_shift(reg_aux2_wei_zero_points_offs_, reg_aux2_wei_zero_points_offs_, + brg.wei_decomp_zero_points_group_size, brg.wei_decomp_zero_points_stride * types::data_type_size(brg.wei_decomp_zero_points_dt)); + } + + if (brg.with_src_dyn_quant) { + ic_group_shift(reg_aux2_src_scales_offs_, reg_aux2_src_scales_offs_, + brg.src_scales_group_size, sizeof(float)); + + if (brg.with_wei_decomp_zero_points) { + ic_group_shift(reg_aux2_src_grouped_sum_offs_, reg_aux2_src_grouped_sum_offs_, + brg.src_sum_group_size, sizeof(int32_t)); + } + } + + mov(reg_rdb_loop, ptr[rsp + reg_bdb_loop_offs_]); + } + }; Label ldb_loop_label; Label BS_loop_label; copy_post_ops_stack_values_to_aux(is_reg_tail); - auto ld_loop_body = [&](int vpad) { + auto ld_loop_body = [&](dim_t vpad, bool last_bdb) { + if (brg.with_grouped_wei_decomp) { + mov(reg_ic, ptr[rsp + reg_ic_offs_]); + mov(ptr[rsp + reg_aux_ic_offs_], reg_ic); + } + set_A_B_matrices(); - int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block; - const auto bd_b = nstl::max(0, vpad); + dim_t bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block; + const auto bd_b = nstl::max(dim_t(0), vpad); const auto bd_e = nstl::min(bd_block, bd_block + vpad); const auto is_valid_bd = need_comp_pads && vpad != 0 ? 
bd_b <= bd_e : bd_b < bd_e; @@ -2311,32 +3141,87 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, if (brg.is_tmm) { const bool is_rd_tail = false; - gemm_microkernel_amx( - bd_block2, is_bdb_tail, ld_block2, is_rd_tail, is_ld_tail); + gemm_microkernel_amx(bd_block2, is_bdb_tail, ld_block2, is_rd_tail, + is_ld_tail, last_bdb); } else { - if (brg.rdb > 0) { - Label rdb_loop_label; - mov(reg_rdb_loop, brg.rdb); - L_aligned(rdb_loop_label, 64); - { - const bool is_rd_tail = false; - gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, - is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail); + ic_group_shift_generic(); + + auto rdb_group = brg.rd_block; + auto rd_size = brg.rdb * brg.rd_block + brg.rdb_tail; + if (brg.wei_decomp_scales_group_size < rd_size) + rdb_group = nstl::max(rdb_group, brg.wei_decomp_scales_group_size); + if (brg.wei_decomp_zero_points_group_size < rd_size) + rdb_group = nstl::max(rdb_group, brg.wei_decomp_zero_points_group_size); + if (brg.with_src_dyn_quant) { + rdb_group = nstl::max(rdb_group, brg.src_scales_group_size); + if (brg.with_wei_decomp_zero_points) { + rdb_group = nstl::max(rdb_group, brg.src_sum_group_size); + } + } + rdb_group = rdb_group / brg.rd_block; + auto rbd_blocks = brg.rdb / rdb_group; + auto max_rdb_unroll = 8; + + if (brg.with_wei_decomp && rdb_group <= max_rdb_unroll) { + if (rbd_blocks > 0) { + Label rdb_loop_label; + mov(reg_rdb_loop, rbd_blocks); + L_aligned(rdb_loop_label, 64); + { + for (int rb = 0; rb < rdb_group; rb++) { + gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, false, + is_ld_tail, vpad, rows_for_rd_tail); + + add(reg_aux_A, rdb_A_offset()); + add(reg_aux_B, rdb_B_offset()); + + ic_group_shift_opt(rb); + } + + dec(reg_rdb_loop); + cmp(reg_rdb_loop, 0); + } + jg(rdb_loop_label, T_NEAR); + } + + for (int rb = rbd_blocks * rdb_group; rb < brg.rdb; rb++) { + gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, false, + is_ld_tail, vpad, rows_for_rd_tail); add(reg_aux_A, rdb_A_offset()); add(reg_aux_B, rdb_B_offset()); - dec(reg_rdb_loop); - cmp(reg_rdb_loop, 0); + ic_group_shift_opt(rb); + + mov(reg_rdb_loop, ptr[rsp + reg_bdb_loop_offs_]); + } + } else { + if (brg.rdb > 0) { + Label rdb_loop_label; + mov(reg_rdb_loop, brg.rdb); + L_aligned(rdb_loop_label, 64); + { + const bool is_rd_tail = false; + gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, + is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail); + + add(reg_aux_A, rdb_A_offset()); + add(reg_aux_B, rdb_B_offset()); + + ic_group_shift_generic(); + + dec(reg_rdb_loop); + cmp(reg_rdb_loop, 0); + } + jg(rdb_loop_label, T_NEAR); } - jg(rdb_loop_label, T_NEAR); } } if (brg.rdb_tail != 0) { const bool is_rd_tail = true; if (brg.is_tmm) { gemm_microkernel_amx(bd_block2, is_bdb_tail, ld_block2, - is_rd_tail, is_ld_tail); + is_rd_tail, is_ld_tail, last_bdb); } else { gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail); @@ -2368,28 +3253,14 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, mov(reg_stride_ldb, brg.rd_step * brg.typesize_B * brg.LDB); } - if (brg.req_s8s8_compensation) { - mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); - mov(reg_s8_input_shift, 128); - uni_vpbroadcastb(vmm_inp_shift(), reg_s8_input_shift.cvt8()); - mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); - } - if (need_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) { - mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); - const auto reg32_scratch = reg_zp_a_input_shift.cvt32(); - mov(reg32_scratch, 0x1010101); - 
uni_vpbroadcastd(vmm_one_bytes(), reg32_scratch); - mov(reg32_scratch, ptr[rsp + reg_zp_a_val_offs_]); - uni_vpbroadcastd(vmm_zp_a_shift(), reg32_scratch); - mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); - } - if (brg.brgattr.max_bs > 1) mov(reg_BS_loop, reg_BS); L_aligned(BS_loop_label, 64); { - if (check_top_vpad || check_bottom_vpad) { - const auto vpad_first = -brg.brgattr.max_bottom_vpad; - const auto vpad_last = brg.brgattr.max_top_vpad; + if (first_bdb || last_bdb) { + const auto vpad_first + = last_bdb ? (-brg.brgattr.max_bottom_vpad) : 1; + const auto vpad_last + = first_bdb ? brg.brgattr.max_top_vpad : -1; const auto n_vpads = vpad_last - vpad_first + 2; constexpr auto MAX_N_VPADS = 2 * brgemm_desc_t::MAX_VPAD; assert(n_vpads < MAX_N_VPADS); @@ -2414,13 +3285,13 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, } else xor_(reg_aux_A_vpad, reg_aux_A_vpad); - for (int vpad = vpad_first; vpad <= vpad_last; vpad++) { + for (dim_t vpad = vpad_first; vpad <= vpad_last; vpad++) { const auto label_vpad = vpad - vpad_first; L(Vpad_loop_iter_label[label_vpad]); - if (!check_top_vpad && vpad > 0) continue; - if (!check_bottom_vpad && vpad < 0) continue; + if (!first_bdb && vpad > 0) continue; + if (!last_bdb && vpad < 0) continue; auto real_vpad = vpad; - if (check_bottom_vpad && brg.bdb_tail && vpad < 0) { + if (last_bdb && brg.bdb_tail && vpad < 0) { if (!is_bdb_tail) { // for last full block before // bdb_tail && -vpad greater than bdb_tail @@ -2440,14 +3311,14 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, } cmp(reg_aux_A_vpad, vpad); jne(Vpad_loop_iter_label[label_vpad + 1], T_NEAR); - ld_loop_body(real_vpad); + ld_loop_body(real_vpad, last_bdb); jmp(Vpad_loop_end_label, T_NEAR); } L(Vpad_loop_iter_label[n_vpads - 1]); - ld_loop_body(0); + ld_loop_body(0, last_bdb); L(Vpad_loop_end_label); } else { - ld_loop_body(0); + ld_loop_body(0, last_bdb); } if (brg.brgattr.max_bs > 1) { dec(reg_BS_loop); @@ -2483,68 +3354,77 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, template void jit_brgemm_kernel_t::bdb_loop() { - auto do_ldb_loop = [this](int bd_block2, bool is_bdb_tail, - bool check_top_vpad, bool check_bottom_vpad, - int rows_for_rd_tail, bool skip_accumulation) { + auto do_ldb_loop = [this](dim_t bd_block2, bool is_bdb_tail, bool first_bdb, + bool last_bdb, dim_t rows_for_rd_tail, + bool skip_accumulation) { if (brg.ldb2 > 0) { const bool is_ld_reg_tail = false; const bool is_ld_tail = false; ldb_loop(bd_block2, is_bdb_tail, brg.ld_block2, brg.ldb2, - is_ld_reg_tail, is_ld_tail, check_top_vpad, - check_bottom_vpad, rows_for_rd_tail, skip_accumulation); + is_ld_reg_tail, is_ld_tail, first_bdb, last_bdb, + rows_for_rd_tail, skip_accumulation); } if (brg.ldb2_tail > 0) { const bool is_ld_reg_tail = (brg.ldb2 == 0) ? false : true; const bool is_ld_tail = false; ldb_loop(bd_block2, is_bdb_tail, brg.ldb2_tail, 1, is_ld_reg_tail, - is_ld_tail, check_top_vpad, check_bottom_vpad, - rows_for_rd_tail, skip_accumulation); + is_ld_tail, first_bdb, last_bdb, rows_for_rd_tail, + skip_accumulation); } if (brg.ldb_tail > 0) { const bool is_ld_reg_tail = (brg.ldb2 == 0 && brg.ldb2_tail == 0) ? 
false : true; const bool is_ld_tail = true; ldb_loop(bd_block2, is_bdb_tail, 1, 1, is_ld_reg_tail, is_ld_tail, - check_top_vpad, check_bottom_vpad, rows_for_rd_tail, - skip_accumulation); - } - }; - - auto bdb_loop_body = [this, do_ldb_loop](int bd_block2, bool is_bdb_tail, - bool check_top_vpad, bool check_bottom_vpad, - int rows_for_rd_tail, bool skip_accumulation) { - do_ldb_loop(bd_block2, is_bdb_tail, check_top_vpad, check_bottom_vpad, - rows_for_rd_tail, skip_accumulation); - - if (brg.is_runtime_ldc) { - mov(ptr[rsp + reg_aux_C_bdb_loop_backup_offs_], reg_C); - xor_(reg_C, reg_C); - imul(reg_C, ptr[rsp + reg_C_shift_bytes_offs_], - bdb_C_offset(bd_block2)); - add(reg_C, ptr[rsp + reg_aux_C_bdb_loop_backup_offs_]); - } else { - add(reg_C, bdb_C_offset(bd_block2)); - } - if (brg.is_runtime_ldd) { - mov(ptr[rsp + reg_aux_D_bdb_loop_backup_offs_], reg_D); - xor_(reg_D, reg_D); - imul(reg_D, ptr[rsp + reg_D_shift_bytes_offs_], - bdb_D_offset(bd_block2)); - add(reg_D, ptr[rsp + reg_aux_D_bdb_loop_backup_offs_]); - } else { - add(reg_D, bdb_D_offset(bd_block2)); + first_bdb, last_bdb, rows_for_rd_tail, skip_accumulation); } - add(reg_a_offset, bdb_A_offset(bd_block2)); - - advance_bd_block2_post_op_regs(bd_block2); }; - int rows_for_rd_tail, bd_blocks_for_rd_tail; + auto bdb_loop_body + = [this, do_ldb_loop](dim_t bd_block2, bool is_bdb_tail, + bool first_bdb, bool last_bdb, dim_t rows_for_rd_tail, + bool skip_accumulation) { + do_ldb_loop(bd_block2, is_bdb_tail, first_bdb, last_bdb, + rows_for_rd_tail, skip_accumulation); + + if (brg.is_runtime_ldc) { + mov(ptr[rsp + reg_aux_C_bdb_loop_backup_offs_], reg_C); + xor_(reg_C, reg_C); + imul(reg_C, ptr[rsp + reg_C_shift_bytes_offs_], + bdb_C_offset(bd_block2)); + add(reg_C, ptr[rsp + reg_aux_C_bdb_loop_backup_offs_]); + } else { + add(reg_C, bdb_C_offset(bd_block2)); + } + if (brg.is_runtime_ldd) { + mov(ptr[rsp + reg_aux_D_bdb_loop_backup_offs_], reg_D); + xor_(reg_D, reg_D); + imul(reg_D, ptr[rsp + reg_D_shift_bytes_offs_], + bdb_D_offset(bd_block2)); + add(reg_D, ptr[rsp + reg_aux_D_bdb_loop_backup_offs_]); + } else { + add(reg_D, bdb_D_offset(bd_block2)); + } + add(reg_a_offset, bdb_A_offset(bd_block2)); + + if (brg.with_src_dyn_quant) { + mov(reg_src_scales, ptr[rsp + reg_src_scales_offs_]); + add(reg_src_scales, bd_block2 * brg.bd_block * brg.src_scales_stride * sizeof(float)); + mov(ptr[rsp + reg_src_scales_offs_], reg_src_scales); + + mov(reg_src_grouped_sum, ptr[rsp + reg_src_grouped_sum_offs_]); + add(reg_src_grouped_sum, bd_block2 * brg.bd_block * brg.src_grouped_sum_stride * sizeof(int32_t)); + mov(ptr[rsp + reg_src_grouped_sum_offs_], reg_src_grouped_sum); + } + + advance_bd_block2_post_op_regs(bd_block2); + }; + + dim_t rows_for_rd_tail, bd_blocks_for_rd_tail; if (brg.is_tmm) { rows_for_rd_tail = 0; bd_blocks_for_rd_tail = 0; - n_bcast_1_load = false; } else { rows_for_rd_tail = 0; if (brg.rdb_tail != 0 && (brg.is_bf16 || brg.is_int8)) { @@ -2554,23 +3434,10 @@ void jit_brgemm_kernel_t::bdb_loop() { : 0; } bd_blocks_for_rd_tail - = div_up(nstl::max(0, + = div_up(nstl::max(dim_t(0), rows_for_rd_tail - brg.bdb_tail + brg.brgattr.max_bottom_vpad), brg.bd_block); - - auto ld_block2 = (brg.ldb2 > 0) - ? brg.ld_block2 - : ((brg.ldb2_tail > 0) ? 
brg.ldb2_tail : 1); - const int free_vregs = max_effective_vregs - brg.req_s8s8_compensation; - n_bcast_1_load = brg.is_int8 - && ((brg.bd_block * (ld_block2 + 1) < free_vregs) - && (bd_blocks_for_rd_tail == 0) - && (rows_for_rd_tail == 0)); - if (brg.brgattr.hint_loop_order != brgemm_lo_default) - n_bcast_1_load = (brg.brgattr.hint_loop_order == brgemm_lo_bl_1load) - ? true - : false; } auto bdb_loop_avx512 = [&](bool skip_accumulation) { @@ -2656,26 +3523,60 @@ void jit_brgemm_kernel_t::bdb_loop() { L_aligned(bdb_loop_end_label, 64); }; auto bdb_loop_amx = [&](bool skip_accumulation) { - Label bdb_loop_label; - if (brg.bd_block2 >= 1) { - mov(reg_bdb_loop, brg.bdb2); - mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); - L_aligned(bdb_loop_label, 64); - { - bdb_loop_body(brg.bd_block2, false, false, false, 0, + if (brg.amx_wary_k_tail()) { + Label bdb_loop_label; + auto bdblocks = brg.bdb2; + if (bdblocks > 1) { + mov(reg_bdb_loop, brg.bdb2); + mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); + L_aligned(bdb_loop_label, 64); + { + bdb_loop_body(brg.bd_block2, false, false, false, 0, + skip_accumulation); + mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); + dec(reg_bdb_loop); + cmp(reg_bdb_loop, 1); + mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); + } + jg(bdb_loop_label, T_NEAR); + bdblocks = 1; + } + if (bdblocks == 1) { + const bool last_bdb = brg.bdb2_tail == 0 && brg.bdb_tail == 0; + bdb_loop_body(brg.bd_block2, false, false, last_bdb, 0, skip_accumulation); - mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); - dec(reg_bdb_loop); - cmp(reg_bdb_loop, 0); + } + + if (brg.bdb2_tail > 0) { + const bool last_bdb = brg.bdb_tail == 0; + bdb_loop_body(brg.bdb2_tail, false, false, last_bdb, 0, + skip_accumulation); + } + if (brg.bdb_tail > 0) + do_ldb_loop(1, true, false, false, 0, skip_accumulation); + + } else { + Label bdb_loop_label; + if (brg.bd_block2 >= 1) { + mov(reg_bdb_loop, brg.bdb2); mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); + L_aligned(bdb_loop_label, 64); + { + bdb_loop_body(brg.bd_block2, false, false, false, 0, + skip_accumulation); + mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); + dec(reg_bdb_loop); + cmp(reg_bdb_loop, 0); + mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); + } + jg(bdb_loop_label, T_NEAR); } - jg(bdb_loop_label, T_NEAR); + if (brg.bdb2_tail > 0) + bdb_loop_body(brg.bdb2_tail, false, false, false, 0, + skip_accumulation); + if (brg.bdb_tail > 0) + do_ldb_loop(1, true, false, false, 0, skip_accumulation); } - if (brg.bdb2_tail > 0) - bdb_loop_body( - brg.bdb2_tail, false, false, false, 0, skip_accumulation); - if (brg.bdb_tail > 0) - do_ldb_loop(1, true, false, false, 0, skip_accumulation); }; auto bdb_loop_general = [&](bool skip_accumulation) { @@ -2736,7 +3637,26 @@ void jit_brgemm_kernel_t::generate() { if (brg.is_int8 && !brg.has_int8_vnni) { mov(reg_tmp_gpr.cvt16(), 0x1); - vpbroadcastw(int8_ones_words(), reg_tmp_gpr.cvt16()); + + if (is_superset(brg.isa_impl, avx512_core)) + vpbroadcastw(int8_ones_words(), reg_tmp_gpr.cvt16()); + else if (is_superset(brg.isa_impl, avx2)) { + movq(Xmm(int8_ones_words().getIdx()), reg_tmp_gpr); + vpbroadcastw(int8_ones_words(), Xmm(int8_ones_words().getIdx())); + } else + assert(!"unsupported isa"); + } + + if (brg.is_f16_b_non_amx_vnni()) { + mov(reg_tmp_gpr, f16_perm_even_table_); + vmovups(f16_perm_even_vreg(), ptr[reg_tmp_gpr]); + mov(reg_tmp_gpr, f16_perm_odd_table_); + vmovups(f16_perm_odd_vreg(), ptr[reg_tmp_gpr]); + } + + if (brg.is_tmm && brg.amx_wary_k_tail()) { + // save tiles description for 
later use
+        brgemm_init_tiles(brg, (char *)(&palette_));
+    }
 
     read_params();
@@ -2748,19 +3668,19 @@ void jit_brgemm_kernel_t<Wmm>::generate() {
     postamble();
 
     align(32);
-    const int simd = vreg_traits<Vmm>::vlen / sizeof(float);
+    const dim_t simd = vreg_traits_t<Vmm>::vlen / sizeof(float);
     if (!isa_has_masks(brg.isa_impl) && brg.ldb_tail > 0) {
         L(avx_tail_mask_);
-        for (int i = 0; i < brg.ldb_tail; ++i)
+        for (dim_t i = 0; i < brg.ldb_tail; ++i)
             dd(0xffffffff);
-        for (int i = brg.ldb_tail; i < simd; ++i)
+        for (dim_t i = brg.ldb_tail; i < simd; ++i)
             dd(0);
     }
     if (!is_superset(brg.isa_impl, avx512_core) && brg.with_sum
             && brg.sum_scale != 1.f) {
         L(sum_zp_scale_data_);
-        const int scale_int = float2int(brg.sum_scale);
-        for (int i = 0; i < simd; ++i)
+        const dim_t scale_int = float2int(brg.sum_scale);
+        for (dim_t i = 0; i < simd; ++i)
             dd(scale_int);
     }
 
@@ -2771,6 +3691,25 @@ void jit_brgemm_kernel_t<Wmm>::generate() {
 
     if (brg.with_eltwise)
         postops_injector_->prepare_table(/* generate = */ true);
+
+    if (brg.is_f16_b_non_amx_vnni()) {
+        // convert interleaved vnni data with holes to packed.
+        align(64);
+        L(f16_perm_even_table_);
+        for (dim_t i = 0; i < 32; ++i) {
+            if (i < 16)
+                dw(uint16_t(2 * i));
+            else
+                dw(uint16_t(0));
+        }
+        align(64);
+        L(f16_perm_odd_table_);
+        for (dim_t i = 0; i < 32; ++i)
+            if (i < 16)
+                dw(uint16_t(2 * i + 1));
+            else
+                dw(uint16_t(0));
+    }
 }
 
 brgemm_attr_t::brgemm_attr_t()
@@ -2785,7 +3724,8 @@ brgemm_attr_t::brgemm_attr_t()
     , hint_innermost_loop(brgemm_ld_loop_innermost)
     , hint_loop_order(brgemm_kernel_loop_order_t::brgemm_lo_default)
     , hint_prefetching(brgemm_kernel_prefetching_t::brgemm_prf_default)
-    , wary_tail_read(true)
+    , wary_A_k_tail_read(true)
+    , extendable_k(false)
     , generate_skip_accumulation(false)
     , bd_mask_level(0)
     , use_uker(false)
@@ -2814,7 +3754,7 @@ void brgemm_kernel_common_t<isa, Wmm>::operator()(
 }
 
 template <cpu_isa_t isa, typename Wmm>
-const jit_generator *brgemm_kernel_common_t<isa, Wmm>::get_jit_generator() const {
+const jit_generator_t *brgemm_kernel_common_t<isa, Wmm>::get_jit_generator() const {
     return brgemm_kernel_;
 }
 
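The two 32-entry word tables emitted at the end of generate() above are shuffle indices: a vpermw-style gather with the even table pulls words 0, 2, ..., 30 of a VNNI-interleaved B register into packed form, and the odd table pulls words 1, 3, ..., 31. A minimal scalar sketch of that transform, assuming a 32-lane vector of raw 16-bit f16 words (all names here are illustrative, not library API):

```cpp
#include <array>
#include <cstdint>

// Scalar model of the even/odd extraction the permutation tables encode.
// vnni[] holds 32 f16 values as raw 16-bit words, two rows interleaved.
std::array<uint16_t, 16> extract_even(const std::array<uint16_t, 32> &vnni) {
    std::array<uint16_t, 16> packed {};
    for (int i = 0; i < 16; ++i)
        packed[i] = vnni[2 * i]; // same indices dw(uint16_t(2 * i)) emits
    return packed;
}

std::array<uint16_t, 16> extract_odd(const std::array<uint16_t, 32> &vnni) {
    std::array<uint16_t, 16> packed {};
    for (int i = 0; i < 16; ++i)
        packed[i] = vnni[2 * i + 1]; // mirrors dw(uint16_t(2 * i + 1))
    return packed;
}
```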
diff --git a/src/cpu/x64/cpu_barrier.cpp b/src/cpu/x64/cpu_barrier.cpp
index 24ab6515b02..2ab3bb5c4a5 100644
--- a/src/cpu/x64/cpu_barrier.cpp
+++ b/src/cpu/x64/cpu_barrier.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2017-2022 Intel Corporation
+* Copyright 2017-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -26,7 +26,7 @@ namespace x64 {
 namespace simple_barrier {
 
 void generate(
-        jit_generator &code, Xbyak::Reg64 reg_ctx, Xbyak::Reg64 reg_nthr) {
+        jit_generator_t &code, Xbyak::Reg64 reg_ctx, Xbyak::Reg64 reg_nthr) {
 #define BAR_CTR_OFF offsetof(ctx_t, ctr)
 #define BAR_SENSE_OFF offsetof(ctx_t, sense)
     using namespace Xbyak;
@@ -81,7 +81,7 @@ void generate(
 }
 
 /** jit barrier generator */
-struct jit_t : public jit_generator {
+struct jit_t : public jit_generator_t {
 
     void generate() override {
         simple_barrier::generate(*this, abi_param1, abi_param2);
@@ -89,7 +89,7 @@ struct jit_t : public jit_generator {
     }
 
     // TODO: Need to check status
-    jit_t() : jit_generator(jit_name()) { create_kernel(); }
+    jit_t() : jit_generator_t(jit_name()) { create_kernel(); }
     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_t)
 };
 
diff --git a/src/cpu/x64/cpu_barrier.hpp b/src/cpu/x64/cpu_barrier.hpp
index c76d57911af..f5cd7966ac9 100644
--- a/src/cpu/x64/cpu_barrier.hpp
+++ b/src/cpu/x64/cpu_barrier.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2017-2020 Intel Corporation
+* Copyright 2017-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -68,11 +68,12 @@ void barrier(ctx_t *ctx, int nthr);
 
 /** injects actual barrier implementation into another jitted code
  * @params:
- *   code      -- jit_generator object where the barrier is to be injected
+ *   code      -- jit_generator_t object where the barrier is to be injected
 *   reg_ctx   -- read-only register with pointer to the barrier context
 *   reg_nnthr -- read-only register with the # of synchronizing threads
 */
-void generate(jit_generator &code, Xbyak::Reg64 reg_ctx, Xbyak::Reg64 reg_nthr);
+void generate(
+        jit_generator_t &code, Xbyak::Reg64 reg_ctx, Xbyak::Reg64 reg_nthr);
 
 } // namespace simple_barrier
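For readers unfamiliar with the scheme the jitted barrier above implements: ctx_t carries a counter and a sense word (see BAR_CTR_OFF and BAR_SENSE_OFF), each arriving thread increments the counter, and the last arrival resets it and flips the shared sense. A portable sense-reversing sketch under those assumptions; field names and memory orderings here are illustrative, not the library's implementation:

```cpp
#include <atomic>
#include <cstddef>

struct barrier_ctx_t {
    std::atomic<std::size_t> ctr {0}; // threads arrived so far
    std::atomic<std::size_t> sense {0}; // flips once per barrier episode
};

void barrier_wait(
        barrier_ctx_t &ctx, std::size_t nthr, std::size_t &local_sense) {
    local_sense = 1 - local_sense; // each thread tracks its own phase
    if (ctx.ctr.fetch_add(1, std::memory_order_acq_rel) + 1 == nthr) {
        ctx.ctr.store(0, std::memory_order_relaxed); // last thread resets
        ctx.sense.store(local_sense, std::memory_order_release); // release all
    } else {
        while (ctx.sense.load(std::memory_order_acquire) != local_sense) {
            // spin until the last thread flips the shared sense
        }
    }
}
```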
diff --git a/src/cpu/x64/cpu_isa_traits.cpp b/src/cpu/x64/cpu_isa_traits.cpp
index 931f13a8c2b..c9d718e1132 100644
--- a/src/cpu/x64/cpu_isa_traits.cpp
+++ b/src/cpu/x64/cpu_isa_traits.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2024 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -46,7 +46,7 @@ cpu_isa_t init_max_cpu_isa() {
 
     if (!isa_val.empty()) {
 #define IF_HANDLE_CASE(cpu_isa) \
-    if (isa_val.compare(cpu_isa_traits<cpu_isa>::user_option_env) == 0) \
+    if (isa_val.compare(cpu_isa_traits_t<cpu_isa>::user_option_env) == 0) \
     max_cpu_isa_val = cpu_isa
 #define ELSEIF_HANDLE_CASE(cpu_isa) else IF_HANDLE_CASE(cpu_isa)
 
@@ -206,7 +206,9 @@ status_t set_max_cpu_isa(dnnl_cpu_isa_t isa) {
     cpu_isa_t isa_to_set = isa_undef;
 #define HANDLE_CASE(cpu_isa) \
-    case cpu_isa_traits<cpu_isa>::user_option_val: isa_to_set = cpu_isa; break;
+    case cpu_isa_traits_t<cpu_isa>::user_option_val: \
+        isa_to_set = cpu_isa; \
+        break;
     switch (isa) {
         HANDLE_CASE(isa_all);
         HANDLE_CASE(sse41);
diff --git a/src/cpu/x64/cpu_isa_traits.hpp b/src/cpu/x64/cpu_isa_traits.hpp
index 89233c48d4e..3fda4777f73 100644
--- a/src/cpu/x64/cpu_isa_traits.hpp
+++ b/src/cpu/x64/cpu_isa_traits.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2024 Intel Corporation
+* Copyright 2018-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -28,17 +28,25 @@
 
 #include "cpu/platform.hpp"
 
+#if !defined(XBYAK64)
 #define XBYAK64
+#endif
+
+#if !defined(XBYAK_NO_OP_NAMES)
 #define XBYAK_NO_OP_NAMES
+#endif
+
 /* in order to make selinux happy memory that would be marked with X-bit should
  * be obtained with mmap */
+#if !defined(XBYAK_USE_MMAP_ALLOCATOR)
 #define XBYAK_USE_MMAP_ALLOCATOR
+#endif
 
+#ifdef DNNL_XBYAK_NO_EXCEPTION
+#if defined(NDEBUG) && !defined(XBYAK_NO_EXCEPTION)
 #define XBYAK_NO_EXCEPTION
-#ifndef NDEBUG
-#undef XBYAK_NO_EXCEPTION
 #endif
-
+#endif
 #if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
 /* turn off `size_t to other-type implicit casting` warning
  * currently we have a lot of jit-generated instructions that
@@ -47,8 +55,8 @@
 #pragma warning(disable : 4267)
 #endif
 #include "common/compiler_workarounds.hpp"
-#include "cpu/x64/xbyak/xbyak.h"
-#include "cpu/x64/xbyak/xbyak_util.h"
+#include "xbyak/xbyak.h"
+#include "xbyak/xbyak_util.h"
 
 namespace dnnl {
 namespace impl {
@@ -79,6 +87,7 @@ enum cpu_isa_bit_t : unsigned {
     amx_int8_bit = 1u << 15,
     amx_bf16_bit = 1u << 16,
     amx_fp16_bit = 1u << 17,
+    avx512_vpopcnt_bit = 1u << 18,
 
     // Fill in hints from most significant bit to least significant bit
     prefer_ymm_bit = 1u << (cpu_isa_total_bits - 1),
@@ -109,7 +118,11 @@ inline unsigned cvt2mask(dnnl_cpu_isa_hints_t hints) {
 };
 
 inline bool is_hints_bit_set(cpu_isa_bit_t hint_bit, bool soft) {
+#if DNNL_X64
     const dnnl_cpu_isa_hints_t hints = get_cpu_isa_hints(soft);
+#else
+    const dnnl_cpu_isa_hints_t hints = dnnl_cpu_isa_no_hints;
+#endif
     const unsigned cur_hints_mask = cpu_isa_hints_utils::cvt2mask(hints);
     return (cur_hints_mask & hint_bit) == hint_bit;
 }
@@ -136,6 +149,7 @@ enum cpu_isa_t : unsigned {
     avx512_core_amx = avx10_1_512_amx,
     avx10_1_512_amx_fp16 = avx10_1_512_amx | amx_fp16,
     avx512_core_amx_fp16 = avx10_1_512_amx_fp16,
+    avx512_vpopcnt = avx512_vpopcnt_bit,
 
     // NOTES: 1. isa_all by default has no isa specific hints
     isa_all = ~0u & ~cpu_isa_hints_utils::hints_mask,
 };
@@ -209,28 +223,28 @@ static inline bool is_superset(cpu_isa_t isa_1, cpu_isa_t isa_2) {
 }
 
 template <typename Vmm>
-struct vreg_traits {};
+struct vreg_traits_t {};
 
 template <>
-struct vreg_traits<Xbyak::Zmm> {
-    typedef Xbyak::Ymm Vmm_lower_t;
+struct vreg_traits_t<Xbyak::Zmm> {
+    using Vmm_lower_t = Xbyak::Ymm;
     static constexpr size_t vlen = 64;
 };
 
 template <>
-struct vreg_traits<Xbyak::Ymm> {
-    typedef Xbyak::Xmm Vmm_lower_t;
+struct vreg_traits_t<Xbyak::Ymm> {
+    using Vmm_lower_t = Xbyak::Xmm;
     static constexpr size_t vlen = 32;
 };
 
 template <>
-struct vreg_traits<Xbyak::Xmm> {
-    typedef Xbyak::Xmm Vmm_lower_t;
+struct vreg_traits_t<Xbyak::Xmm> {
+    using Vmm_lower_t = Xbyak::Xmm;
     static constexpr size_t vlen = 16;
 };
 
 template <cpu_isa_t>
-struct cpu_isa_traits {}; /* ::vlen -> 32 (for avx2) */
+struct cpu_isa_traits_t {}; /* ::vlen -> 32 (for avx2) */
 
 // pack struct so it can fit into a single 64-byte cache line
 #pragma pack(push, 1)
@@ -244,96 +258,105 @@ struct palette_config_t {
 #pragma pack(pop)
 
 template <>
-struct cpu_isa_traits<isa_all> {
+struct cpu_isa_traits_t<isa_all> {
     static constexpr dnnl_cpu_isa_t user_option_val = dnnl_cpu_isa_default;
     static constexpr const char *user_option_env = "default";
 };
 
 template <>
-struct cpu_isa_traits<sse41> {
-    typedef Xbyak::Xmm Vmm;
+struct cpu_isa_traits_t<sse41> {
+    using Vmm = Xbyak::Xmm;
     static constexpr int vlen_shift = 4;
-    static constexpr int vlen = vreg_traits<Vmm>::vlen;
+    static constexpr int vlen = vreg_traits_t<Vmm>::vlen;
     static constexpr int n_vregs = 16;
     static constexpr dnnl_cpu_isa_t user_option_val = dnnl_cpu_isa_sse41;
     static constexpr const char *user_option_env = "sse41";
 };
 
 template <>
-struct cpu_isa_traits<avx> {
-    typedef Xbyak::Ymm Vmm;
+struct cpu_isa_traits_t<avx> {
+    using Vmm = Xbyak::Ymm;
     static constexpr int vlen_shift = 5;
-    static constexpr int vlen = vreg_traits<Vmm>::vlen;
+    static constexpr int vlen = vreg_traits_t<Vmm>::vlen;
     static constexpr int n_vregs = 16;
     static constexpr dnnl_cpu_isa_t user_option_val = dnnl_cpu_isa_avx;
     static constexpr const char *user_option_env = "avx";
 };
 
 template <>
-struct cpu_isa_traits<avx2> : public cpu_isa_traits<avx> {
+struct cpu_isa_traits_t<avx2> : public cpu_isa_traits_t<avx> {
     static constexpr dnnl_cpu_isa_t user_option_val = dnnl_cpu_isa_avx2;
     static constexpr const char *user_option_env = "avx2";
 };
 
 template <>
-struct cpu_isa_traits<avx2_vnni> : public cpu_isa_traits<avx2> {
+struct cpu_isa_traits_t<avx2_vnni> : public cpu_isa_traits_t<avx2> {
     static constexpr dnnl_cpu_isa_t user_option_val = dnnl_cpu_isa_avx2_vnni;
     static constexpr const char *user_option_env = "avx2_vnni";
 };
 
 template <>
-struct cpu_isa_traits<avx2_vnni_2> : public cpu_isa_traits<avx2_vnni> {
+struct cpu_isa_traits_t<avx2_vnni_2> : public cpu_isa_traits_t<avx2_vnni> {
     static constexpr dnnl_cpu_isa_t user_option_val = dnnl_cpu_isa_avx2_vnni_2;
     static constexpr const char *user_option_env = "avx2_vnni_2";
 };
 
 template <>
-struct cpu_isa_traits<avx512_core> {
-    typedef Xbyak::Zmm Vmm;
+struct cpu_isa_traits_t<avx512_core> {
+    using Vmm = Xbyak::Zmm;
     static constexpr int vlen_shift = 6;
-    static constexpr int vlen = vreg_traits<Vmm>::vlen;
+    static constexpr int vlen = vreg_traits_t<Vmm>::vlen;
     static constexpr int n_vregs = 32;
     static constexpr dnnl_cpu_isa_t user_option_val = dnnl_cpu_isa_avx512_core;
     static constexpr const char *user_option_env = "avx512_core";
 };
 
 template <>
-struct cpu_isa_traits<avx512_core_vnni> : public cpu_isa_traits<avx512_core> {
+struct cpu_isa_traits_t<avx512_core_vnni>
+    : public cpu_isa_traits_t<avx512_core> {
     static constexpr dnnl_cpu_isa_t user_option_val
             = dnnl_cpu_isa_avx512_core_vnni;
     static constexpr const char *user_option_env = "avx512_core_vnni";
 };
 
 template <>
-struct cpu_isa_traits<avx512_core_bf16> : public cpu_isa_traits<avx512_core_vnni> {
+struct cpu_isa_traits_t<avx512_core_bf16>
+    : public cpu_isa_traits_t<avx512_core_vnni> {
     static constexpr dnnl_cpu_isa_t user_option_val
             = dnnl_cpu_isa_avx512_core_bf16;
     static constexpr const char *user_option_env = "avx512_core_bf16";
 };
 
 template <>
-struct cpu_isa_traits<avx10_1_512_amx> {
-    typedef Xbyak::Zmm Vmm;
-    static constexpr int vlen = vreg_traits<Vmm>::vlen;
+struct cpu_isa_traits_t<avx10_1_512_amx> {
+    using Vmm = Xbyak::Zmm;
+    static constexpr int vlen = vreg_traits_t<Vmm>::vlen;
     static constexpr dnnl_cpu_isa_t user_option_val
             = dnnl_cpu_isa_avx10_1_512_amx;
     static constexpr const char *user_option_env = "avx10_1_512_amx";
 };
 
 template <>
-struct cpu_isa_traits<avx10_1_512> : public cpu_isa_traits<avx512_core> {
+struct cpu_isa_traits_t<avx10_1_512> : public cpu_isa_traits_t<avx512_core> {
     static constexpr dnnl_cpu_isa_t user_option_val = dnnl_cpu_isa_avx10_1_512;
     static constexpr const char *user_option_env = "avx10_1_512";
 };
 
 template <>
-struct cpu_isa_traits<avx10_1_512_amx_fp16> {
-    typedef Xbyak::Zmm Vmm;
+struct cpu_isa_traits_t<avx10_1_512_amx_fp16> {
+    using Vmm = Xbyak::Zmm;
     static constexpr dnnl_cpu_isa_t user_option_val
             = dnnl_cpu_isa_avx10_1_512_amx_fp16;
     static constexpr const char *user_option_env = "avx10_1_512_amx_fp16";
 };
 
+template <>
+struct cpu_isa_traits_t<avx512_vpopcnt> {
+    static constexpr dnnl_cpu_isa_t user_option_val
+            = dnnl_cpu_isa_avx512_vpopcnt;
+    static constexpr const char *user_option_env = "AVX512_VPOPCNT";
+};
+
 inline const Xbyak::util::Cpu &cpu() {
     const static Xbyak::util::Cpu cpu_;
     return cpu_;
@@ -354,11 +377,16 @@ bool DNNL_API is_available();
 
 namespace {
 
-static inline bool mayiuse(const cpu_isa_t cpu_isa, bool soft = false) {
+inline bool mayiuse(const cpu_isa_t cpu_isa, bool soft = false) {
     using namespace Xbyak::util;
-
-    unsigned cpu_isa_mask = x64::get_max_cpu_isa_mask(soft);
-    unsigned cpu_isa_no_hints = cpu_isa & ~cpu_isa_hints_utils::hints_mask;
+#if DNNL_X64
+    const unsigned cpu_isa_mask = x64::get_max_cpu_isa_mask(soft);
+#elif DNNL_X86
+    const unsigned cpu_isa_mask = isa_undef;
+#else
+    const unsigned cpu_isa_mask = isa_all;
+#endif
+    const unsigned cpu_isa_no_hints = cpu_isa & ~cpu_isa_hints_utils::hints_mask;
 
     if ((cpu_isa_mask & cpu_isa_no_hints) != cpu_isa_no_hints) return false;
 
@@ -412,29 +440,31 @@ static inline bool mayiuse(const cpu_isa_t cpu_isa, bool soft = false) {
         case avx512_core_amx_fp16:
             REG_AMX_ISA(return mayiuse(avx512_core_amx, soft)
                     && mayiuse(amx_fp16, soft));
+        case avx512_vpopcnt:
+            REG_AVX512_ISA(return cpu().has(Cpu::tAVX512_VPOPCNTDQ));
         case isa_all: return false;
         case isa_undef: return true;
     }
     return false;
 }
 
-static inline bool isa_has_int8_vnni(cpu_isa_t isa) {
+inline bool isa_has_int8_vnni(cpu_isa_t isa) {
     return is_superset(isa, avx512_core_vnni) || is_superset(isa, avx2_vnni);
 }
 
-static inline bool isa_has_s8s8(cpu_isa_t isa) {
+inline bool isa_has_s8s8(cpu_isa_t isa) {
     return is_superset(isa, amx_int8) || is_superset(isa, avx2_vnni_2);
 }
 
-static inline bool isa_has_bf16(cpu_isa_t isa) {
+inline bool isa_has_bf16(cpu_isa_t isa) {
    return is_superset(isa, avx512_core_bf16);
 }
 
-static inline bool isa_has_masks(cpu_isa_t isa) {
+inline bool isa_has_masks(cpu_isa_t isa) {
     return is_superset(isa, avx512_core);
 }
 
-static inline int isa_max_vlen(cpu_isa_t isa) {
+inline int isa_max_vlen(cpu_isa_t isa) {
     const bool is_avx512 = is_superset(isa, avx512_core);
     const bool is_avx = is_superset(isa, avx);
     const bool is_sse41 = is_superset(isa, sse41);
@@ -443,14 +473,14 @@ static inline int isa_max_vlen(cpu_isa_t isa) {
     MAYBE_UNUSED(is_sse41);
 
     if (is_avx512)
-        return cpu_isa_traits<avx512_core>::vlen;
+        return cpu_isa_traits_t<avx512_core>::vlen;
     else if (is_avx)
-        return cpu_isa_traits<avx>::vlen;
+        return cpu_isa_traits_t<avx>::vlen;
     else
-        return cpu_isa_traits<sse41>::vlen;
 }
 
-static inline int isa_num_vregs(cpu_isa_t isa) {
+        return cpu_isa_traits_t<sse41>::vlen;
+}
+
+inline int isa_num_vregs(cpu_isa_t isa) {
     const bool is_avx512 = is_superset(isa, avx512_core);
     const bool is_avx = is_superset(isa, avx);
     const bool is_sse41 = is_superset(isa, sse41);
@@ -459,11 +489,11 @@ static inline int isa_num_vregs(cpu_isa_t isa) {
     MAYBE_UNUSED(is_sse41);
 
     if (is_avx512)
-        return cpu_isa_traits<avx512_core>::n_vregs;
+        return cpu_isa_traits_t<avx512_core>::n_vregs;
     else if (is_avx)
-        return cpu_isa_traits<avx>::n_vregs;
+        return cpu_isa_traits_t<avx>::n_vregs;
     else
-        return cpu_isa_traits<sse41>::n_vregs;
+        return cpu_isa_traits_t<sse41>::n_vregs;
 }
 
 } // namespace
@@ -494,10 +524,12 @@ inline data_type_t get_mac_emu_data_type(const data_type_t data_type,
     using namespace data_type;
     if (req_emulation) switch (data_type) {
             case bf16:
-                if (isa == avx2_vnni_2) return f32;
+                if (utils::one_of(isa, avx2, avx2_vnni_2, avx512_core))
+                    return f32;
                 break;
             case f16:
-                if (utils::one_of(isa, avx2_vnni_2, avx512_core_fp16))
+                if (utils::one_of(isa, avx2, avx2_vnni_2, avx512_core,
+                            avx512_core_fp16))
                     return f32;
                 break;
             case f8_e5m2:
@@ -520,7 +552,11 @@ inline size_t data_type_vnni_granularity(const data_type_t data_type) {
         case f32:
         case s32: return size_t(1);
         case f16:
-        case bf16: return size_t(2);
+        case bf16:
+        case s4:
+        case u4:
+        case nf4:
+        case f4_e2m1: return size_t(2);
         case f8_e5m2:
         case f8_e4m3:
        case s8:
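The traits renamed above are how kernels avoid hard-coding vector widths: code templated on an ISA (or a register type) reads vlen and n_vregs from cpu_isa_traits_t, and runtime dispatch goes through mayiuse(). A hypothetical usage sketch; the helper itself is not part of the library:

```cpp
#include <cstddef>

using namespace dnnl::impl::cpu::x64;

// How many f32 lanes one vector register holds for a given ISA.
template <cpu_isa_t isa>
constexpr std::size_t f32_lanes() {
    return cpu_isa_traits_t<isa>::vlen / sizeof(float);
}

std::size_t pick_lanes() {
    if (mayiuse(avx512_core)) return f32_lanes<avx512_core>(); // 16
    if (mayiuse(avx2)) return f32_lanes<avx2>(); // 8
    return f32_lanes<sse41>(); // 4
}
```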
diff --git a/src/cpu/x64/cpu_reducer.cpp b/src/cpu/x64/cpu_reducer.cpp
index a000d8b5fca..3bcb5ed0b5c 100644
--- a/src/cpu/x64/cpu_reducer.cpp
+++ b/src/cpu/x64/cpu_reducer.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2017-2023 Intel Corporation
+* Copyright 2017-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -97,12 +97,12 @@ void reduce_balancer_t::balance() {
 using namespace Xbyak;
 
 template <impl::data_type_t data_type>
-struct reducer_2d_driver_t : public jit_generator {
-    using data_t = typename prec_traits<data_type>::type;
+struct reducer_2d_driver_t : public jit_generator_t {
+    using data_t = typename prec_traits_t<data_type>::type;
 
     reducer_2d_driver_t(int n_src, size_t src_ld, size_t src_step,
             size_t dst_step, bool nullify_dst, const char *name)
-        : jit_generator(name)
+        : jit_generator_t(name)
         , n_src_(n_src)
         , src_ld_(src_ld)
         , src_step_(src_step)
@@ -122,11 +122,11 @@ template <impl::data_type_t data_type, cpu_isa_t isa>
 struct reducer_2d_driver_f_s_32_t : public reducer_2d_driver_t<data_type> {
     DECLARE_CPU_JIT_AUX_FUNCTIONS(reducer_2d_driver_f_s_32_t)
 
-    using data_t = typename prec_traits<data_type>::type;
+    using data_t = typename prec_traits_t<data_type>::type;
 
     void operator()(
             data_t *dst, const data_t *srcs, size_t ny, size_t nx) override {
-        jit_generator::operator()(dst, srcs, ny, nx);
+        jit_generator_t::operator()(dst, srcs, ny, nx);
     }
 
     /* cpu specific part */
@@ -145,9 +145,9 @@ struct reducer_2d_driver_f_s_32_t : public reducer_2d_driver_t<data_type> {
             this->paddd(x1, op);
     }
 
-    const int vlen = cpu_isa_traits<isa>::vlen;
+    const int vlen = cpu_isa_traits_t<isa>::vlen;
     const int typesize
-            = sizeof(typename dnnl::impl::prec_traits<data_type>::type);
+            = sizeof(typename dnnl::impl::prec_traits_t<data_type>::type);
     Xbyak::Reg64 reg_dst = abi_param1;
     Xbyak::Reg64 reg_src = abi_param2;
     Xbyak::Reg64 reg_ny = abi_param3;
@@ -195,17 +195,32 @@ struct reducer_2d_driver_f_s_32_t : public reducer_2d_driver_t<data_type> {
 
         for (int i = 0; i < nloads; ++i) {
             size_t off = base_off + i * load_len;
-            if (load_len == typesize)
-                this->uni_add(Xmm(i), this->ptr[reg_src + off]);
-            else if (load_len == vlen)
-                this->uni_vadd(Vmm(i), Vmm(i), vmmword[reg_src + off]);
+            if (load_len == typesize) {
+                assert(nloads == 1);
+                if (off > static_cast<size_t>(INT_MAX)) {
+                    this->mov(reg_long_offt, off);
+                    this->movd(Xmm(nloads + i),
+                            this->ptr[reg_src + reg_long_offt]);
+                    this->uni_add(Xmm(i), Xmm(nloads + i));
+                } else {
+                    this->movd(Xmm(nloads + i), this->ptr[reg_src + off]);
+                    this->uni_add(Xmm(i), Xmm(nloads + i));
+                }
+            } else if (load_len == vlen)
+                if (off > static_cast<size_t>(INT_MAX)) {
+                    this->mov(reg_long_offt, off);
+                    this->uni_vadd(
+                            Vmm(i), Vmm(i), vmmword[reg_src + reg_long_offt]);
+                } else {
+                    this->uni_vadd(Vmm(i), Vmm(i), vmmword[reg_src + off]);
+                }
             else
                 assert(!"unsupported");
         }
     }
 
     void loop_x() {
-        const int nloads[] = {cpu_isa_traits<isa>::n_vregs, 1, 1};
+        const int nloads[] = {cpu_isa_traits_t<isa>::n_vregs, 1, 1};
         const int nbranches = sizeof(nloads) / sizeof(nloads[0]);
 
         const int load_len[nbranches] = {vlen, vlen, typesize};
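The INT_MAX guards added above exist because an x86-64 address operand encodes its displacement as a signed 32-bit immediate, so offsets past that range must first be materialized in a register. A standalone Xbyak sketch of the same pattern; register choices and names are illustrative only:

```cpp
#include <climits>
#include <cstddef>
#include "xbyak/xbyak.h"

// Emits acc += *(base + off) for an offset that may not fit in a disp32.
struct long_offset_add_t : Xbyak::CodeGenerator {
    explicit long_offset_add_t(std::size_t off) {
        using namespace Xbyak;
        const Reg64 base = rdi; // first integer argument (SysV ABI)
        const Reg64 scratch = r10;
        if (off > static_cast<std::size_t>(INT_MAX)) {
            mov(scratch, off); // 64-bit immediate into a scratch register
            vaddps(xmm0, xmm0, ptr[base + scratch]);
        } else {
            vaddps(xmm0, xmm0, ptr[base + off]); // fits in disp32
        }
        ret();
    }
};
```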
diff --git a/src/cpu/x64/cpu_reducer.hpp b/src/cpu/x64/cpu_reducer.hpp
index 2ecf022b859..d07e7545b7d 100644
--- a/src/cpu/x64/cpu_reducer.hpp
+++ b/src/cpu/x64/cpu_reducer.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2017-2020 Intel Corporation
+* Copyright 2017-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -168,7 +168,7 @@ struct reducer_2d_driver_t;
  */
 template <impl::data_type_t data_type>
 struct cpu_reducer_t {
-    typedef typename prec_traits<data_type>::type data_t;
+    using data_t = typename prec_traits_t<data_type>::type;
 
     struct conf_t {
         conf_t() = default;
@@ -248,7 +248,7 @@ struct cpu_reducer_t {
 
 template <impl::data_type_t data_type>
 struct cpu_reducer_2d_t {
-    typedef typename prec_traits<data_type>::type data_t;
+    using data_t = typename prec_traits_t<data_type>::type;
 
     struct conf_t {
         conf_t() = default;
@@ -333,7 +333,7 @@ struct cpu_reducer_2d_t {
 /** simple 1d accumulator: y[:] += x[:] */
 template <impl::data_type_t data_type>
 struct cpu_accumulator_1d_t {
-    typedef typename prec_traits<data_type>::type data_t;
+    using data_t = typename prec_traits_t<data_type>::type;
 
     cpu_accumulator_1d_t();
     ~cpu_accumulator_1d_t();
diff --git a/src/cpu/x64/gemm/amx/jit_avx512_core_amx_copy_kern.cpp b/src/cpu/x64/gemm/amx/jit_avx512_core_amx_copy_kern.cpp
index 7f9b09824d2..81771dab6a8 100644
--- a/src/cpu/x64/gemm/amx/jit_avx512_core_amx_copy_kern.cpp
+++ b/src/cpu/x64/gemm/amx/jit_avx512_core_amx_copy_kern.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -42,7 +42,7 @@ static inline Zmm make_zmm(const Xmm &v) {
     return Zmm(v.getIdx());
 }
 
-void jit_avx512_core_amx_copy_kern::transpose(int s, const Ymm &dst1,
+void jit_avx512_core_amx_copy_kern_t::transpose(int s, const Ymm &dst1,
         const Ymm &dst2, const Ymm &src1, const Ymm &src2) {
     switch (s) {
         case 32:
@@ -91,8 +91,9 @@ void jit_avx512_core_amx_copy_kern::transpose(int s, const Ymm &dst1,
     }
 }
 
-void jit_avx512_core_amx_copy_kern::amxtrans8(const Ymm &dst1, const Ymm &dst2,
-        const Ymm &src1, const Ymm &src2, const Ymm &src3, const Ymm &src4) {
+void jit_avx512_core_amx_copy_kern_t::amxtrans8(const Ymm &dst1,
+        const Ymm &dst2, const Ymm &src1, const Ymm &src2, const Ymm &src3,
+        const Ymm &src4) {
     vpunpcklbw(dst1, src1, src2);
     vpunpckhbw(dst2, src1, src2);
     vpunpcklbw(src1, src3, src4);
@@ -107,7 +108,7 @@ void jit_avx512_core_amx_copy_kern::amxtrans8(const Ymm &dst1, const Ymm &dst2,
     vshufi32x4(src4, dst1, dst2, 0x03);
 }
 
-void jit_avx512_core_amx_copy_kern::amxtrans16(
+void jit_avx512_core_amx_copy_kern_t::amxtrans16(
         const Ymm &dst1, const Ymm &dst2, const Ymm &src1, const Ymm &src2) {
     vpunpcklwd(dst1, src1, src2);
     vpunpckhwd(dst2, src1, src2);
@@ -117,7 +118,7 @@ void jit_avx512_core_amx_copy_kern::amxtrans16(
     vshufi32x4(src2, src2, src2, 0xd8);
 }
 
-void jit_avx512_core_amx_copy_kern::load(
+void jit_avx512_core_amx_copy_kern_t::load(
         const Xmm &dst, const Address &src, bool corner) {
     if (!corner && isize_ == 1)
         vmovdqu8(dst, src);
@@ -129,14 +130,15 @@ void jit_avx512_core_amx_copy_kern::load(
         vmovdqu16(dst | k1 | T_z, src);
 }
 
-void jit_avx512_core_amx_copy_kern::store(const Address &dst, const Xmm &src) {
+void jit_avx512_core_amx_copy_kern_t::store(
+        const Address &dst, const Xmm &src) {
     if (size_ == 1)
         vmovdqu8(dst, src);
     else
         vmovdqu16(dst, src);
 }
 
-void jit_avx512_core_amx_copy_kern::kernel_AN(
+void jit_avx512_core_amx_copy_kern_t::kernel_AN(
         int unroll_x, int unroll_y, int step, Reg64 A, Reg64 B, bool corner) {
     // Transpose data.
     int u[] = {32, 16, 8, 4};
@@ -170,7 +172,7 @@ void jit_avx512_core_amx_copy_kern::kernel_AN(
     }
 }
 
-void jit_avx512_core_amx_copy_kern::kernel_BN(
+void jit_avx512_core_amx_copy_kern_t::kernel_BN(
         int unroll_x, int unroll_y, int step, Reg64 A, Reg64 B, bool corner) {
     // Store data.
for (int i = 0; i < 16; i++) @@ -179,7 +181,7 @@ void jit_avx512_core_amx_copy_kern::kernel_BN( src_[i]); } -void jit_avx512_core_amx_copy_kern::kernel_AT( +void jit_avx512_core_amx_copy_kern_t::kernel_AT( int unroll_x, int unroll_y, int step, Reg64 A, Reg64 B, bool corner) { Ymm v[16]; @@ -258,7 +260,7 @@ void jit_avx512_core_amx_copy_kern::kernel_AT( } } -void jit_avx512_core_amx_copy_kern::kernel_BT( +void jit_avx512_core_amx_copy_kern_t::kernel_BT( int unroll_x, int unroll_y, int step, Reg64 A, Reg64 B, bool corner) { // Transpose data. int u[] = {16, 8, 4, 2, 1}; @@ -297,7 +299,7 @@ void jit_avx512_core_amx_copy_kern::kernel_BT( L(store_end); } -void jit_avx512_core_amx_copy_kern::kernel( +void jit_avx512_core_amx_copy_kern_t::kernel( int unroll_x, int unroll_y, int step, Reg64 A, Reg64 B, bool corner) { // Load matrix. @@ -326,7 +328,7 @@ void jit_avx512_core_amx_copy_kern::kernel( kernel_BT(unroll_x, unroll_y, step, A, B, corner); } -void jit_avx512_core_amx_copy_kern::copy_m(int unroll_m, int unroll_n) { +void jit_avx512_core_amx_copy_kern_t::copy_m(int unroll_m, int unroll_n) { if (is_trans_) { mov(B1_, B_); add(B_, unroll_m * unroll_n * size_); @@ -378,7 +380,7 @@ void jit_avx512_core_amx_copy_kern::copy_m(int unroll_m, int unroll_n) { L_aligned(kernel_tail_end); } -void jit_avx512_core_amx_copy_kern::copy_ns(int unroll_n, Label &epilogue) { +void jit_avx512_core_amx_copy_kern_t::copy_ns(int unroll_n, Label &epilogue) { if (unroll_n > 0) { copy_ns(unroll_n - 1, epilogue); @@ -393,7 +395,7 @@ void jit_avx512_core_amx_copy_kern::copy_ns(int unroll_n, Label &epilogue) { } } -void jit_avx512_core_amx_copy_kern::copy_n(int unroll_n, Label &epilogue) { +void jit_avx512_core_amx_copy_kern_t::copy_n(int unroll_n, Label &epilogue) { Label copy_m_loop, copy_m_end; @@ -422,7 +424,7 @@ void jit_avx512_core_amx_copy_kern::copy_n(int unroll_n, Label &epilogue) { copy_ns(unroll_n - 1, epilogue); } -void jit_avx512_core_amx_copy_kern::generate() { +void jit_avx512_core_amx_copy_kern_t::generate() { // Prologue preamble(); sub(rsp, stack_alloc_size_); @@ -494,9 +496,9 @@ void jit_avx512_core_amx_copy_kern::generate() { postamble(); } -jit_avx512_core_amx_copy_kern::jit_avx512_core_amx_copy_kern( +jit_avx512_core_amx_copy_kern_t::jit_avx512_core_amx_copy_kern_t( bool is_a, bool is_trans, int isize) - : jit_generator(jit_name()) + : jit_generator_t(jit_name()) , is_a_(is_a) , is_trans_(is_trans) , size_(isize) diff --git a/src/cpu/x64/gemm/amx/jit_avx512_core_amx_copy_kern.hpp b/src/cpu/x64/gemm/amx/jit_avx512_core_amx_copy_kern.hpp index 76d830f9750..db74267baef 100644 --- a/src/cpu/x64/gemm/amx/jit_avx512_core_amx_copy_kern.hpp +++ b/src/cpu/x64/gemm/amx/jit_avx512_core_amx_copy_kern.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2021 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,10 +24,10 @@ namespace impl { namespace cpu { namespace x64 { -class jit_avx512_core_amx_copy_kern : public jit_generator { +class jit_avx512_core_amx_copy_kern_t : public jit_generator_t { public: - jit_avx512_core_amx_copy_kern(bool is_a, bool is_trans, int isize); - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_copy_kern); + jit_avx512_core_amx_copy_kern_t(bool is_a, bool is_trans, int isize); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_copy_kern_t); protected: bool is_a_; @@ -66,7 +66,7 @@ class jit_avx512_core_amx_copy_kern : public jit_generator { void copy_n(int unroll_n, Xbyak::Label &epilogue); void copy_ns(int unroll_n, Xbyak::Label &epilogue); - void generate() override ATTRIBUTE_OPTIMIZE; + void generate() override; private: static const int offset_a_ = 0, offset_b_ = 0; diff --git a/src/cpu/x64/gemm/amx/jit_avx512_core_amx_gemm_kern.cpp b/src/cpu/x64/gemm/amx/jit_avx512_core_amx_gemm_kern.cpp index f9005d6ea6e..c92560cb70c 100644 --- a/src/cpu/x64/gemm/amx/jit_avx512_core_amx_gemm_kern.cpp +++ b/src/cpu/x64/gemm/amx/jit_avx512_core_amx_gemm_kern.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,7 +59,7 @@ namespace x64 { #define TILED(X) dword[rsp + ((X) + 0xc0)] #define TILEQ(X) qword[rsp + ((X) + 0xc0)] -void jit_avx512_core_amx_gemm_kern::generate() { +void jit_avx512_core_amx_gemm_kern_t::generate() { int kerneltype = ((typea << 1) | typeb); @@ -455,9 +455,9 @@ void jit_avx512_core_amx_gemm_kern::generate() { ret(); } -jit_avx512_core_amx_gemm_kern::jit_avx512_core_amx_gemm_kern( +jit_avx512_core_amx_gemm_kern_t::jit_avx512_core_amx_gemm_kern_t( int typea, int typeb, int typec, int betaZero) - : jit_generator(jit_name(), avx512_core_amx) + : jit_generator_t(jit_name(), avx512_core_amx) , typea(typea) , typeb(typeb) , typec(typec) diff --git a/src/cpu/x64/gemm/amx/jit_avx512_core_amx_gemm_kern.hpp b/src/cpu/x64/gemm/amx/jit_avx512_core_amx_gemm_kern.hpp index 08987d8afc1..fab208e61cf 100644 --- a/src/cpu/x64/gemm/amx/jit_avx512_core_amx_gemm_kern.hpp +++ b/src/cpu/x64/gemm/amx/jit_avx512_core_amx_gemm_kern.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2021 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,14 +24,14 @@ namespace impl { namespace cpu { namespace x64 { -class jit_avx512_core_amx_gemm_kern : public jit_generator { +class jit_avx512_core_amx_gemm_kern_t : public jit_generator_t { public: - jit_avx512_core_amx_gemm_kern( + jit_avx512_core_amx_gemm_kern_t( int typea, int typeb, int typec, int betaZero); - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_gemm_kern); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_gemm_kern_t); protected: - void generate() override ATTRIBUTE_OPTIMIZE; + void generate() override; const int typea; const int typeb; const int typec; diff --git a/src/cpu/x64/gemm/bf16/common_s16.hpp b/src/cpu/x64/gemm/bf16/common_s16.hpp index 28eed475e01..c61e44190eb 100644 --- a/src/cpu/x64/gemm/bf16/common_s16.hpp +++ b/src/cpu/x64/gemm/bf16/common_s16.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,68 +24,68 @@ namespace impl { namespace cpu { namespace x64 { -class jit_avx512_core_s16_48x8_copy_an_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_48x8_copy_an_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_s16_48x8_copy_an_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_48x8_copy_an_kern_t); + void generate() override; public: - jit_avx512_core_s16_48x8_copy_an_kern(); + jit_avx512_core_s16_48x8_copy_an_kern_t(); }; -class jit_avx512_core_s16_48x8_copy_at_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_48x8_copy_at_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_s16_48x8_copy_at_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_48x8_copy_at_kern_t); + void generate() override; public: - jit_avx512_core_s16_48x8_copy_at_kern(); + jit_avx512_core_s16_48x8_copy_at_kern_t(); }; -class jit_avx512_core_s16_48x8_copy_bn_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_48x8_copy_bn_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_s16_48x8_copy_bn_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_48x8_copy_bn_kern_t); + void generate() override; public: - jit_avx512_core_s16_48x8_copy_bn_kern(); + jit_avx512_core_s16_48x8_copy_bn_kern_t(); }; -class jit_avx512_core_s16_48x8_copy_bt_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_48x8_copy_bt_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_s16_48x8_copy_bt_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_48x8_copy_bt_kern_t); + void generate() override; public: - jit_avx512_core_s16_48x8_copy_bt_kern(); + jit_avx512_core_s16_48x8_copy_bt_kern_t(); }; -class jit_avx512_core_s16_24x8_copy_an_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_24x8_copy_an_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_s16_24x8_copy_an_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_24x8_copy_an_kern_t); + void generate() override; public: - jit_avx512_core_s16_24x8_copy_an_kern(); + jit_avx512_core_s16_24x8_copy_an_kern_t(); }; -class jit_avx512_core_s16_24x8_copy_at_kern : 
public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_24x8_copy_at_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_s16_24x8_copy_at_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_24x8_copy_at_kern_t); + void generate() override; public: - jit_avx512_core_s16_24x8_copy_at_kern(); + jit_avx512_core_s16_24x8_copy_at_kern_t(); }; -class jit_avx512_core_s16_24x8_copy_bn_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_24x8_copy_bn_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_s16_24x8_copy_bn_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_24x8_copy_bn_kern_t); + void generate() override; public: - jit_avx512_core_s16_24x8_copy_bn_kern(); + jit_avx512_core_s16_24x8_copy_bn_kern_t(); }; -class jit_avx512_core_s16_24x8_copy_bt_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_24x8_copy_bt_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_s16_24x8_copy_bt_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_s16_24x8_copy_bt_kern_t); + void generate() override; public: - jit_avx512_core_s16_24x8_copy_bt_kern(); + jit_avx512_core_s16_24x8_copy_bt_kern_t(); }; } // namespace x64 diff --git a/src/cpu/x64/gemm/bf16/jit_avx512_core_gemm_bf16bf16f32_kern.cpp b/src/cpu/x64/gemm/bf16/jit_avx512_core_gemm_bf16bf16f32_kern.cpp index 124f6c441b2..17f1a27c19d 100644 --- a/src/cpu/x64/gemm/bf16/jit_avx512_core_gemm_bf16bf16f32_kern.cpp +++ b/src/cpu/x64/gemm/bf16/jit_avx512_core_gemm_bf16bf16f32_kern.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ static inline Zmm make_zmm(const Xmm &v) { } // Load from or store to C. -void jit_avx512_core_gemm_bf16bf16f32_kern::c_load( +void jit_avx512_core_gemm_bf16bf16f32_kern_t::c_load( const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems) { switch (nelems) { case 1: vmovss(make_xmm(dst), src); break; @@ -60,7 +60,7 @@ void jit_avx512_core_gemm_bf16bf16f32_kern::c_load( } } -void jit_avx512_core_gemm_bf16bf16f32_kern::c_store( +void jit_avx512_core_gemm_bf16bf16f32_kern_t::c_store( const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems) { switch (nelems) { case 1: vmovss(dst, make_xmm(src)); break; @@ -76,7 +76,7 @@ void jit_avx512_core_gemm_bf16bf16f32_kern::c_store( // Perform length-2 dot product accumulations of bfloat16 in parallel. // Use vdpbf16ps if available, otherwise emulate. -void jit_avx512_core_gemm_bf16bf16f32_kern::dot_product( +void jit_avx512_core_gemm_bf16bf16f32_kern_t::dot_product( const Xmm &dst, const Xmm &src1, const Xmm &src2) { if (bfloat16_) vdpbf16ps(dst, src1, src2); @@ -85,7 +85,7 @@ void jit_avx512_core_gemm_bf16bf16f32_kern::dot_product( } // Inner kernel. -void jit_avx512_core_gemm_bf16bf16f32_kern::kernel_loop( +void jit_avx512_core_gemm_bf16bf16f32_kern_t::kernel_loop( int unroll_m, int unroll_n, bool cfetch) { int um_vecs = utils::div_up(unroll_m, c_nelems_); Label label_kernel_loop; @@ -147,7 +147,7 @@ void jit_avx512_core_gemm_bf16bf16f32_kern::kernel_loop( } // k remainder loop for kernel. 
-void jit_avx512_core_gemm_bf16bf16f32_kern::remainder_kernel(
+void jit_avx512_core_gemm_bf16bf16f32_kern_t::remainder_kernel(
         int unroll_m, int unroll_n, int unroll_k, int bwidth) {
     int um_vecs = utils::div_up(unroll_m, c_nelems_);
@@ -181,7 +181,7 @@ void jit_avx512_core_gemm_bf16bf16f32_kern::remainder_kernel(
 }
 
 // Inner loop.
-void jit_avx512_core_gemm_bf16bf16f32_kern::innerloop(
+void jit_avx512_core_gemm_bf16bf16f32_kern_t::innerloop(
         int unroll_m, int unroll_n) {
     int um_vecs = utils::div_up(unroll_m, c_nelems_);
     int stage1 = unroll_n, stage2 = unroll_n;
@@ -311,7 +311,7 @@ void jit_avx512_core_gemm_bf16bf16f32_kern::innerloop(
 }
 
 // Outer loop.
-void jit_avx512_core_gemm_bf16bf16f32_kern::outerloop(
+void jit_avx512_core_gemm_bf16bf16f32_kern_t::outerloop(
         int unroll_x, int unroll_y, Label *&cur_outerloop_label) {
 
     Label label_m_loop, label_n_loop, label_n_remainder_loops[6];
@@ -375,7 +375,7 @@ void jit_avx512_core_gemm_bf16bf16f32_kern::outerloop(
     align(16);
 }
 
-void jit_avx512_core_gemm_bf16bf16f32_kern::generate() {
+void jit_avx512_core_gemm_bf16bf16f32_kern_t::generate() {
     // Prologue
     preamble();
     sub(rsp, stack_alloc_size_);
@@ -423,9 +423,10 @@ void jit_avx512_core_gemm_bf16bf16f32_kern::generate() {
     postamble();
 }
 
-jit_avx512_core_gemm_bf16bf16f32_kern::jit_avx512_core_gemm_bf16bf16f32_kern(
-        bool beta_zero, bool alpha_one, bool use_zmm)
-    : jit_generator(jit_name())
+jit_avx512_core_gemm_bf16bf16f32_kern_t::
+        jit_avx512_core_gemm_bf16bf16f32_kern_t(
+                bool beta_zero, bool alpha_one, bool use_zmm)
+    : jit_generator_t(jit_name())
     , beta_zero_(beta_zero)
     , alpha_one_(alpha_one)
     , bfloat16_(mayiuse(avx512_core_bf16))
@@ -502,17 +503,14 @@ jit_avx512_core_gemm_bf16bf16f32_kern::jit_avx512_core_gemm_bf16bf16f32_kern(
     zmm_tmp0_ = zmm6;
     zmm_tmp1_ = zmm3;
 
-    bf16_emu_ = nullptr;
     if (!bfloat16_ && use_zmm)
-        bf16_emu_ = new bf16_emulation_t(
+        bf16_emu_ = utils::make_unique<bf16_emulation_t>(
                 this, one_, even_, selector_, scratch_, zmm_tmp0_, zmm_tmp1_);
 }
 
-jit_avx512_core_gemm_bf16bf16f32_kern::
-        ~jit_avx512_core_gemm_bf16bf16f32_kern() {
-    delete bf16_emu_;
-}
-
+jit_avx512_core_gemm_bf16bf16f32_kern_t::
+        ~jit_avx512_core_gemm_bf16bf16f32_kern_t()
+        = default;
 } // namespace x64
 } // namespace cpu
 } // namespace impl
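The dot_product() helpers in these bf16 GEMM/GEMV kernels key off bfloat16_: with avx512_core_bf16 they emit a single vdpbf16ps, otherwise bf16_emu_ synthesizes the same math. Per lane pair the operation is acc += a0*b0 + a1*b1 with each bf16 input widened to f32. A scalar reference model, for illustration only (the emulation path does this with integer shifts and FMAs, not with this helper):

```cpp
#include <cstdint>
#include <cstring>

// bf16 is the upper 16 bits of an IEEE-754 f32.
static float bf16_to_f32(uint16_t b) {
    uint32_t u = uint32_t(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

// One lane of vdpbf16ps: accumulate a length-2 bf16 dot product into f32.
float dpbf16ps_lane(float acc, const uint16_t a[2], const uint16_t b[2]) {
    return acc + bf16_to_f32(a[0]) * bf16_to_f32(b[0])
            + bf16_to_f32(a[1]) * bf16_to_f32(b[1]);
}
```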
diff --git a/src/cpu/x64/gemm/bf16/jit_avx512_core_gemm_bf16bf16f32_kern.hpp b/src/cpu/x64/gemm/bf16/jit_avx512_core_gemm_bf16bf16f32_kern.hpp
index 5362409a44d..bc176fa9467 100644
--- a/src/cpu/x64/gemm/bf16/jit_avx512_core_gemm_bf16bf16f32_kern.hpp
+++ b/src/cpu/x64/gemm/bf16/jit_avx512_core_gemm_bf16bf16f32_kern.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2021 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -25,12 +25,12 @@ namespace impl {
 namespace cpu {
 namespace x64 {
 
-class jit_avx512_core_gemm_bf16bf16f32_kern : public jit_generator {
+class jit_avx512_core_gemm_bf16bf16f32_kern_t : public jit_generator_t {
 public:
-    jit_avx512_core_gemm_bf16bf16f32_kern(
+    jit_avx512_core_gemm_bf16bf16f32_kern_t(
             bool beta_zero, bool alpha_one, bool use_zmm);
-    ~jit_avx512_core_gemm_bf16bf16f32_kern();
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemm_bf16bf16f32_kern);
+    ~jit_avx512_core_gemm_bf16bf16f32_kern_t() override;
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemm_bf16bf16f32_kern_t);
 
 protected:
     bool beta_zero_;
@@ -58,7 +58,7 @@ class jit_avx512_core_gemm_bf16bf16f32_kern : public jit_generator {
     void innerloop(int unroll_m, int unroll_n);
     void outerloop(int unroll_x, int unroll_y, Xbyak::Label *&outerloop_label);
 
-    void generate() override ATTRIBUTE_OPTIMIZE;
+    void generate() override;
 
 private:
     static const int UNROLL_N_ = 8;
@@ -90,13 +90,15 @@ class jit_avx512_core_gemm_bf16bf16f32_kern : public jit_generator {
             arg_coffset_r_;
 
     // For bfloat16 emulation on avx512 and avx512_vnni ISAs
-    bf16_emulation_t *bf16_emu_;
+    std::unique_ptr<bf16_emulation_t> bf16_emu_;
     Xbyak::Reg64 scratch_;
     Xbyak::Zmm one_;
     Xbyak::Zmm even_;
     Xbyak::Zmm selector_;
     Xbyak::Zmm zmm_tmp0_;
     Xbyak::Zmm zmm_tmp1_;
+
+    DNNL_DISALLOW_COPY_AND_ASSIGN(jit_avx512_core_gemm_bf16bf16f32_kern_t);
 };
 
 } // namespace x64
diff --git a/src/cpu/x64/gemm/bf16/jit_avx512_core_gemv_bf16bf16f32_kern.cpp b/src/cpu/x64/gemm/bf16/jit_avx512_core_gemv_bf16bf16f32_kern.cpp
index 4d77805f9ff..42b0430e9a1 100644
--- a/src/cpu/x64/gemm/bf16/jit_avx512_core_gemv_bf16bf16f32_kern.cpp
+++ b/src/cpu/x64/gemm/bf16/jit_avx512_core_gemv_bf16bf16f32_kern.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -49,7 +49,7 @@ static inline Zmm make_zmm(const Xmm &v) {
 
 // Perform length-2 dot product accumulations of bfloat16 in parallel.
 // Use vdpbf16ps if available, otherwise emulate.
-void jit_avx512_core_gemv_bf16bf16f32_kern::dot_product(
+void jit_avx512_core_gemv_bf16bf16f32_kern_t::dot_product(
         const Xmm &dst, const Xmm &src1, const Xmm &src2) {
     if (bfloat16_)
         vdpbf16ps(dst, src1, src2);
@@ -58,7 +58,7 @@ void jit_avx512_core_gemv_bf16bf16f32_kern::dot_product(
 }
 
 // Vector load for 16-bit values.
-void jit_avx512_core_gemv_bf16bf16f32_kern::v_load( +void jit_avx512_core_gemv_bf16bf16f32_kern_t::v_load( const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems) { if (nelems >= 32) vmovdqu16(dst, src); @@ -82,7 +82,7 @@ void jit_avx512_core_gemv_bf16bf16f32_kern::v_load( vmovdqu16(make_xmm(dst) | k1 | T_z, src); } -void jit_avx512_core_gemv_bf16bf16f32_kern::y_load( +void jit_avx512_core_gemv_bf16bf16f32_kern_t::y_load( const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems) { if (nelems >= 16) vmovups(dst, src); @@ -102,7 +102,7 @@ void jit_avx512_core_gemv_bf16bf16f32_kern::y_load( vmovss(make_xmm(dst), src); } -void jit_avx512_core_gemv_bf16bf16f32_kern::y_store( +void jit_avx512_core_gemv_bf16bf16f32_kern_t::y_store( const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems) { if (nelems >= 16) vmovups(dst, src); @@ -122,7 +122,7 @@ void jit_avx512_core_gemv_bf16bf16f32_kern::y_store( vmovss(dst, make_xmm(src)); } -void jit_avx512_core_gemv_bf16bf16f32_kern::kernel_loop_n( +void jit_avx512_core_gemv_bf16bf16f32_kern_t::kernel_loop_n( int unroll_m, int unroll_n, bool fetch, bool last) { int zmm_vecs = utils::div_up(unroll_m, 32); @@ -203,7 +203,7 @@ void jit_avx512_core_gemv_bf16bf16f32_kern::kernel_loop_n( } // Inner loop for A non-transposed. -void jit_avx512_core_gemv_bf16bf16f32_kern::innerloop_n(int unroll_n) { +void jit_avx512_core_gemv_bf16bf16f32_kern_t::innerloop_n(int unroll_n) { mov(A1_, A_); if (unroll_n > 4) { lea(A2_, ptr[A1_ + LDA_ * 4]); @@ -283,7 +283,7 @@ void jit_avx512_core_gemv_bf16bf16f32_kern::innerloop_n(int unroll_n) { L_aligned(label_m_tail_end); } -void jit_avx512_core_gemv_bf16bf16f32_kern::kernel_loop_t( +void jit_avx512_core_gemv_bf16bf16f32_kern_t::kernel_loop_t( int unroll_m, int unroll_n, bool fetch, bool last) { // Load x. @@ -312,7 +312,7 @@ void jit_avx512_core_gemv_bf16bf16f32_kern::kernel_loop_t( } // Inner loop for A transposed. -void jit_avx512_core_gemv_bf16bf16f32_kern::innerloop_t(int unroll_n) { +void jit_avx512_core_gemv_bf16bf16f32_kern_t::innerloop_t(int unroll_n) { mov(A1_, A_); if (unroll_n > 4) { lea(A2_, ptr[A1_ + LDA_ * 4]); @@ -431,7 +431,7 @@ void jit_avx512_core_gemv_bf16bf16f32_kern::innerloop_t(int unroll_n) { } // Outer loop. 
-void jit_avx512_core_gemv_bf16bf16f32_kern::outerloop(int unroll_y, +void jit_avx512_core_gemv_bf16bf16f32_kern_t::outerloop(int unroll_y, Label *&cur_outerloop_label, Label *&outerloop_end_label) { bool is_tail = unroll_y < UNROLL_N_; @@ -464,7 +464,7 @@ void jit_avx512_core_gemv_bf16bf16f32_kern::outerloop(int unroll_y, } } -void jit_avx512_core_gemv_bf16bf16f32_kern::generate() { +void jit_avx512_core_gemv_bf16bf16f32_kern_t::generate() { // Prologue preamble(); @@ -513,9 +513,9 @@ void jit_avx512_core_gemv_bf16bf16f32_kern::generate() { } // Function signature: gemv(*m, *n, *alpha, *a, *lda, *x, *incx, *y, *incy) -jit_avx512_core_gemv_bf16bf16f32_kern::jit_avx512_core_gemv_bf16bf16f32_kern( - bool trans) - : jit_generator(jit_name()) +jit_avx512_core_gemv_bf16bf16f32_kern_t:: + jit_avx512_core_gemv_bf16bf16f32_kern_t(bool trans) + : jit_generator_t(jit_name()) , trans_(trans) , bfloat16_(mayiuse(avx512_core_bf16)) , arg_lda_(0) @@ -605,8 +605,8 @@ jit_avx512_core_gemv_bf16bf16f32_kern::jit_avx512_core_gemv_bf16bf16f32_kern( this, one_, even_, selector_, gpr_, zmm_tmp0_, zmm_tmp1_); } -jit_avx512_core_gemv_bf16bf16f32_kern:: - ~jit_avx512_core_gemv_bf16bf16f32_kern() { +jit_avx512_core_gemv_bf16bf16f32_kern_t:: + ~jit_avx512_core_gemv_bf16bf16f32_kern_t() { delete bf16_emu_; } diff --git a/src/cpu/x64/gemm/bf16/jit_avx512_core_gemv_bf16bf16f32_kern.hpp b/src/cpu/x64/gemm/bf16/jit_avx512_core_gemv_bf16bf16f32_kern.hpp index c7418ce7642..c108d6afc83 100644 --- a/src/cpu/x64/gemm/bf16/jit_avx512_core_gemv_bf16bf16f32_kern.hpp +++ b/src/cpu/x64/gemm/bf16/jit_avx512_core_gemv_bf16bf16f32_kern.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2021 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,11 +25,11 @@ namespace impl { namespace cpu { namespace x64 { -class jit_avx512_core_gemv_bf16bf16f32_kern : public jit_generator { +class jit_avx512_core_gemv_bf16bf16f32_kern_t : public jit_generator_t { public: - jit_avx512_core_gemv_bf16bf16f32_kern(bool trans); - ~jit_avx512_core_gemv_bf16bf16f32_kern(); - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemv_bf16bf16f32_kern); + jit_avx512_core_gemv_bf16bf16f32_kern_t(bool trans); + ~jit_avx512_core_gemv_bf16bf16f32_kern_t() override; + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemv_bf16bf16f32_kern_t); protected: bool trans_; @@ -52,7 +52,7 @@ class jit_avx512_core_gemv_bf16bf16f32_kern : public jit_generator { void outerloop(int unroll_y, Xbyak::Label *&cur_outerloop_label, Xbyak::Label *&outerloop_end_label); - void generate() override ATTRIBUTE_OPTIMIZE; + void generate() override; private: static const int UNROLL_M_ = 64; diff --git a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_an_kern_autogen.cpp b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_an_kern_autogen.cpp index 22f089dc8b0..491a2a51c52 100644 --- a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_an_kern_autogen.cpp +++ b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_an_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,10 +23,11 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx512_core_s16_24x8_copy_an_kern::jit_avx512_core_s16_24x8_copy_an_kern() - : jit_generator(jit_name()) {} +jit_avx512_core_s16_24x8_copy_an_kern_t:: + jit_avx512_core_s16_24x8_copy_an_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx512_core_s16_24x8_copy_an_kern::generate() { +void jit_avx512_core_s16_24x8_copy_an_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_at_kern_autogen.cpp b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_at_kern_autogen.cpp index 9a6032745f7..69f0d00e129 100644 --- a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_at_kern_autogen.cpp +++ b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_at_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,11 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx512_core_s16_24x8_copy_at_kern::jit_avx512_core_s16_24x8_copy_at_kern() - : jit_generator(jit_name()) {} +jit_avx512_core_s16_24x8_copy_at_kern_t:: + jit_avx512_core_s16_24x8_copy_at_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx512_core_s16_24x8_copy_at_kern::generate() { +void jit_avx512_core_s16_24x8_copy_at_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_bn_kern_autogen.cpp b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_bn_kern_autogen.cpp index be61df11e29..01db091bf68 100644 --- a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_bn_kern_autogen.cpp +++ b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_bn_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,11 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx512_core_s16_24x8_copy_bn_kern::jit_avx512_core_s16_24x8_copy_bn_kern() - : jit_generator(jit_name()) {} +jit_avx512_core_s16_24x8_copy_bn_kern_t:: + jit_avx512_core_s16_24x8_copy_bn_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx512_core_s16_24x8_copy_bn_kern::generate() { +void jit_avx512_core_s16_24x8_copy_bn_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_bt_kern_autogen.cpp b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_bt_kern_autogen.cpp index cd62ed88dbd..5164dff7cb8 100644 --- a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_bt_kern_autogen.cpp +++ b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_24x8_copy_bt_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,10 +23,11 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx512_core_s16_24x8_copy_bt_kern::jit_avx512_core_s16_24x8_copy_bt_kern() - : jit_generator(jit_name()) {} +jit_avx512_core_s16_24x8_copy_bt_kern_t:: + jit_avx512_core_s16_24x8_copy_bt_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx512_core_s16_24x8_copy_bt_kern::generate() { +void jit_avx512_core_s16_24x8_copy_bt_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_an_kern_autogen.cpp b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_an_kern_autogen.cpp index 3a936e6a280..c6d3c901c04 100644 --- a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_an_kern_autogen.cpp +++ b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_an_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,11 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx512_core_s16_48x8_copy_an_kern::jit_avx512_core_s16_48x8_copy_an_kern() - : jit_generator(jit_name()) {} +jit_avx512_core_s16_48x8_copy_an_kern_t:: + jit_avx512_core_s16_48x8_copy_an_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx512_core_s16_48x8_copy_an_kern::generate() { +void jit_avx512_core_s16_48x8_copy_an_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_at_kern_autogen.cpp b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_at_kern_autogen.cpp index ced7abdd837..815d72b437b 100644 --- a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_at_kern_autogen.cpp +++ b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_at_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,11 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx512_core_s16_48x8_copy_at_kern::jit_avx512_core_s16_48x8_copy_at_kern() - : jit_generator(jit_name()) {} +jit_avx512_core_s16_48x8_copy_at_kern_t:: + jit_avx512_core_s16_48x8_copy_at_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx512_core_s16_48x8_copy_at_kern::generate() { +void jit_avx512_core_s16_48x8_copy_at_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_bn_kern_autogen.cpp b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_bn_kern_autogen.cpp index 196039ad816..da6d516438d 100644 --- a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_bn_kern_autogen.cpp +++ b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_bn_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,10 +23,11 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx512_core_s16_48x8_copy_bn_kern::jit_avx512_core_s16_48x8_copy_bn_kern() - : jit_generator(jit_name()) {} +jit_avx512_core_s16_48x8_copy_bn_kern_t:: + jit_avx512_core_s16_48x8_copy_bn_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx512_core_s16_48x8_copy_bn_kern::generate() { +void jit_avx512_core_s16_48x8_copy_bn_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_bt_kern_autogen.cpp b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_bt_kern_autogen.cpp index d448a2e121a..2f5918a5748 100644 --- a/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_bt_kern_autogen.cpp +++ b/src/cpu/x64/gemm/bf16/jit_avx512_core_s16_48x8_copy_bt_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,11 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx512_core_s16_48x8_copy_bt_kern::jit_avx512_core_s16_48x8_copy_bt_kern() - : jit_generator(jit_name()) {} +jit_avx512_core_s16_48x8_copy_bt_kern_t:: + jit_avx512_core_s16_48x8_copy_bt_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx512_core_s16_48x8_copy_bt_kern::generate() { +void jit_avx512_core_s16_48x8_copy_bt_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/common_f32.hpp b/src/cpu/x64/gemm/f32/common_f32.hpp index 953aa9481e1..ed632c06c06 100644 --- a/src/cpu/x64/gemm/f32/common_f32.hpp +++ b/src/cpu/x64/gemm/f32/common_f32.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,176 +24,173 @@ namespace impl { namespace cpu { namespace x64 { -class jit_avx512_core_f32_copy_an_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_copy_an_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_f32_copy_an_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_copy_an_kern_t); + void generate() override; public: - jit_avx512_core_f32_copy_an_kern(); + jit_avx512_core_f32_copy_an_kern_t(); }; -class jit_avx512_core_f32_copy_at_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_copy_at_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_f32_copy_at_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_copy_at_kern_t); + void generate() override; void generate_part1(const Xbyak::Label &, const Xbyak::Label &, - const Xbyak::Label &, const Xbyak::Label &) ATTRIBUTE_OPTIMIZE; - void generate_part2(Xbyak::Label, Xbyak::Label, Xbyak::Label, - Xbyak::Label) ATTRIBUTE_OPTIMIZE; + const Xbyak::Label &, const Xbyak::Label &); + void generate_part2(Xbyak::Label, Xbyak::Label, Xbyak::Label, Xbyak::Label); public: - jit_avx512_core_f32_copy_at_kern(); + jit_avx512_core_f32_copy_at_kern_t(); }; -class jit_avx512_core_f32_copy_bn_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_copy_bn_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_f32_copy_bn_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_copy_bn_kern_t); + void generate() override; public: - jit_avx512_core_f32_copy_bn_kern(); + jit_avx512_core_f32_copy_bn_kern_t(); }; -class jit_avx512_core_f32_copy_bt_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_copy_bt_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_f32_copy_bt_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_copy_bt_kern_t); + void generate() override; public: - jit_avx512_core_f32_copy_bt_kern(); + jit_avx512_core_f32_copy_bt_kern_t(); }; -class jit_avx2_f32_copy_an_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_f32_copy_an_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_f32_copy_an_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_f32_copy_an_kern_t); + void generate() override; public: - jit_avx2_f32_copy_an_kern(); + jit_avx2_f32_copy_an_kern_t(); }; -class jit_avx2_f32_copy_at_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_f32_copy_at_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_f32_copy_at_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_f32_copy_at_kern_t); + void generate() override; public: - jit_avx2_f32_copy_at_kern(); + jit_avx2_f32_copy_at_kern_t(); }; -class jit_avx2_f32_copy_bn_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_f32_copy_bn_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_f32_copy_bn_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_f32_copy_bn_kern_t); + void generate() override; public: - jit_avx2_f32_copy_bn_kern(); + jit_avx2_f32_copy_bn_kern_t(); }; -class jit_avx2_f32_copy_bt_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_f32_copy_bt_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_f32_copy_bt_kern_t : public 
jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_f32_copy_bt_kern_t); + void generate() override; public: - jit_avx2_f32_copy_bt_kern(); + jit_avx2_f32_copy_bt_kern_t(); }; -class jit_avx_f32_copy_an_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_f32_copy_an_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx_f32_copy_an_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_f32_copy_an_kern_t); + void generate() override; public: - jit_avx_f32_copy_an_kern(); + jit_avx_f32_copy_an_kern_t(); }; -class jit_avx_f32_copy_at_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_f32_copy_at_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx_f32_copy_at_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_f32_copy_at_kern_t); + void generate() override; public: - jit_avx_f32_copy_at_kern(); + jit_avx_f32_copy_at_kern_t(); }; -class jit_avx_f32_copy_bn_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_f32_copy_bn_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx_f32_copy_bn_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_f32_copy_bn_kern_t); + void generate() override; public: - jit_avx_f32_copy_bn_kern(); + jit_avx_f32_copy_bn_kern_t(); }; -class jit_avx_f32_copy_bt_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_f32_copy_bt_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx_f32_copy_bt_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_f32_copy_bt_kern_t); + void generate() override; public: - jit_avx_f32_copy_bt_kern(); + jit_avx_f32_copy_bt_kern_t(); }; -class jit_avx_kernel_b0_sgemm_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_b0_sgemm_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx_kernel_b0_sgemm_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_b0_sgemm_kern_t); + void generate() override; void generate_part1(const Xbyak::Label &, const Xbyak::Label &, - const Xbyak::Label &, const Xbyak::Label &) ATTRIBUTE_OPTIMIZE; - void generate_part2(Xbyak::Label, Xbyak::Label, Xbyak::Label, - Xbyak::Label) ATTRIBUTE_OPTIMIZE; + const Xbyak::Label &, const Xbyak::Label &); + void generate_part2(Xbyak::Label, Xbyak::Label, Xbyak::Label, Xbyak::Label); public: - jit_avx_kernel_b0_sgemm_kern(); + jit_avx_kernel_b0_sgemm_kern_t(); }; -class jit_avx_kernel_sgemm_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_sgemm_kern); - void generate() override ATTRIBUTE_OPTIMIZE; - void generate_part1(const Xbyak::Label &, const Xbyak::Label &, - const Xbyak::Label &) ATTRIBUTE_OPTIMIZE; - void generate_part2( - Xbyak::Label &, Xbyak::Label &, Xbyak::Label &) ATTRIBUTE_OPTIMIZE; +class jit_avx_kernel_sgemm_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_sgemm_kern_t); + void generate() override; + void generate_part1( + const Xbyak::Label &, const Xbyak::Label &, const Xbyak::Label &); + void generate_part2(Xbyak::Label &, Xbyak::Label &, Xbyak::Label &); public: - jit_avx_kernel_sgemm_kern(); + jit_avx_kernel_sgemm_kern_t(); }; -class jit_sse41_f32_copy_an_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_f32_copy_an_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_sse41_f32_copy_an_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_f32_copy_an_kern_t); + void 
generate() override; public: - jit_sse41_f32_copy_an_kern(); + jit_sse41_f32_copy_an_kern_t(); }; -class jit_sse41_f32_copy_at_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_f32_copy_at_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_sse41_f32_copy_at_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_f32_copy_at_kern_t); + void generate() override; public: - jit_sse41_f32_copy_at_kern(); + jit_sse41_f32_copy_at_kern_t(); }; -class jit_sse41_f32_copy_bn_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_f32_copy_bn_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_sse41_f32_copy_bn_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_f32_copy_bn_kern_t); + void generate() override; public: - jit_sse41_f32_copy_bn_kern(); + jit_sse41_f32_copy_bn_kern_t(); }; -class jit_sse41_f32_copy_bt_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_f32_copy_bt_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_sse41_f32_copy_bt_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_f32_copy_bt_kern_t); + void generate() override; public: - jit_sse41_f32_copy_bt_kern(); + jit_sse41_f32_copy_bt_kern_t(); }; -class jit_sse41_kernel_b0_sgemm_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_b0_sgemm_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_sse41_kernel_b0_sgemm_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_b0_sgemm_kern_t); + void generate() override; public: - jit_sse41_kernel_b0_sgemm_kern(); + jit_sse41_kernel_b0_sgemm_kern_t(); }; -class jit_sse41_kernel_sgemm_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_sgemm_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_sse41_kernel_sgemm_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_sgemm_kern_t); + void generate() override; public: - jit_sse41_kernel_sgemm_kern(); + jit_sse41_kernel_sgemm_kern_t(); }; } // namespace x64 diff --git a/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_an_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_an_kern_autogen.cpp index 3b14fe68440..ba136908bfa 100644 --- a/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_an_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_an_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx2_f32_copy_an_kern::jit_avx2_f32_copy_an_kern() - : jit_generator(jit_name()) {} +jit_avx2_f32_copy_an_kern_t::jit_avx2_f32_copy_an_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx2_f32_copy_an_kern::generate() { +void jit_avx2_f32_copy_an_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_at_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_at_kern_autogen.cpp index 8f9205dfca5..daa3ece4b9c 100644 --- a/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_at_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_at_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx2_f32_copy_at_kern::jit_avx2_f32_copy_at_kern() - : jit_generator(jit_name()) {} +jit_avx2_f32_copy_at_kern_t::jit_avx2_f32_copy_at_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx2_f32_copy_at_kern::generate() { +void jit_avx2_f32_copy_at_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_bn_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_bn_kern_autogen.cpp index 1b086a5e4de..f3e17a76a87 100644 --- a/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_bn_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_bn_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx2_f32_copy_bn_kern::jit_avx2_f32_copy_bn_kern() - : jit_generator(jit_name()) {} +jit_avx2_f32_copy_bn_kern_t::jit_avx2_f32_copy_bn_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx2_f32_copy_bn_kern::generate() { +void jit_avx2_f32_copy_bn_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_bt_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_bt_kern_autogen.cpp index 9fd7218234b..461d24d51e4 100644 --- a/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_bt_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx2_f32_copy_bt_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx2_f32_copy_bt_kern::jit_avx2_f32_copy_bt_kern() - : jit_generator(jit_name()) {} +jit_avx2_f32_copy_bt_kern_t::jit_avx2_f32_copy_bt_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx2_f32_copy_bt_kern::generate() { +void jit_avx2_f32_copy_bt_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp b/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp index d0fb52fa6c3..0a8dd0ddbaf 100644 --- a/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,14 +27,14 @@ namespace impl { namespace cpu { namespace x64 { -int jit_avx2_kernel_sgemm_kern::next_acc(int idx, int um, int un) const { +int jit_avx2_kernel_sgemm_kern_t::next_acc(int idx, int um, int un) const { while (!(((idx / unroll_n_) < std::max(1, um / nelt_per_vecreg_)) || ((idx % unroll_n_) < un))) idx++; return idx; } -void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload( +void jit_avx2_kernel_sgemm_kern_t::prefetchB_beforeBload( int um, int un, int k_idx, int n_idx) { if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if ((n_idx == 0) && (k_idx == 0) && (un == unroll_n_) && (um != 16)) { @@ -44,7 +44,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload( } } -void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA( +void jit_avx2_kernel_sgemm_kern_t::prefetchB_beforeFMA( int um, int un, int k_idx, int n_idx, int m_idx) { if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if ((um == 16) || (un < unroll_n_)) { @@ -61,7 +61,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA( } } -void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA( +void jit_avx2_kernel_sgemm_kern_t::prefetchA_afterFMA( int um, int un, int k_idx, int n_idx, int m_idx) { if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { if ((um < unroll_m_) && (m_idx == 0)) { @@ -85,7 +85,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA( } } -void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload( +void jit_avx2_kernel_sgemm_kern_t::prefetchA_afterBload( int um, int un, int k_idx, int n_idx) { if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if ((um == unroll_m_) && (un == 2)) { @@ -109,7 +109,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload( } } -void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA( +void jit_avx2_kernel_sgemm_kern_t::prefetchB_afterFMA( int k_idx, int n_idx, int m_idx) { if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { if (((m_idx + (k_idx % (nb_zmm_a_ / unroll_m_reg_)) * unroll_m_reg_) @@ -124,7 +124,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA( } } -void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA( +void jit_avx2_kernel_sgemm_kern_t::prefetchA_beforeFMA( int um, int un, int k_idx, int n_idx, int m_idx) { if (!(mayiuse(avx512_core) && __BUILD_GEMM_AVX512)) { if ((um == unroll_m_) && (un == unroll_n_)) { @@ -158,7 +158,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA( } } -void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload( +void jit_avx2_kernel_sgemm_kern_t::prefetchC_afterBload( int um, int un, int k_idx, int n_idx) { if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { if (um == unroll_m_) { 
@@ -172,7 +172,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload( } } -void jit_avx2_kernel_sgemm_kern::prefetchC_beforeKloop(int um) { +void jit_avx2_kernel_sgemm_kern_t::prefetchC_beforeKloop(int um) { if (mayiuse(avx512_core) && __BUILD_GEMM_AVX512) { if (um < unroll_m_) { prefetchw(ptr[CO2_ + elt_size_ * 0]); @@ -199,7 +199,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchC_beforeKloop(int um) { } } -void jit_avx2_kernel_sgemm_kern::generate() { +void jit_avx2_kernel_sgemm_kern_t::generate() { int i, unroll_x, unroll_y, uy_bin, ux_bin; int C_off = is_windows ? 56 : 8; @@ -435,8 +435,8 @@ void jit_avx2_kernel_sgemm_kern::generate() { postamble(); } -jit_avx2_kernel_sgemm_kern::jit_avx2_kernel_sgemm_kern(bool beta_zero) - : jit_generator(jit_name()), beta_zero_(beta_zero) {} +jit_avx2_kernel_sgemm_kern_t::jit_avx2_kernel_sgemm_kern_t(bool beta_zero) + : jit_generator_t(jit_name()), beta_zero_(beta_zero) {} } // namespace x64 } // namespace cpu diff --git a/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp b/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp index c51d429c3e8..60b97371367 100644 --- a/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp +++ b/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,9 +29,9 @@ namespace impl { namespace cpu { namespace x64 { -class jit_avx2_kernel_sgemm_kern : public jit_generator { +class jit_avx2_kernel_sgemm_kern_t : public jit_generator_t { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_kernel_sgemm_kern); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_kernel_sgemm_kern_t); const int elt_size_ = 4; const int elt_size_bin_ = 2; int nelt_per_vecreg_ = mayiuse(avx512_core) && __BUILD_GEMM_AVX512 ? 16 : 8; @@ -79,7 +79,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { void prefetchA_beforeFMA(int um, int un, int k_idx, int n_idx, int m_idx); void prefetchC_afterBload(int um, int un, int k_idx, int n_idx); void prefetchC_beforeKloop(int um); - void generate() override ATTRIBUTE_OPTIMIZE; + void generate() override; template void loadA_betweenFMAs(int um, int un, int k_idx, int n_idx, int m_idx, @@ -701,7 +701,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } public: - jit_avx2_kernel_sgemm_kern(bool beta_zero); + jit_avx2_kernel_sgemm_kern_t(bool beta_zero); }; } // namespace x64 } // namespace cpu diff --git a/src/cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.cpp b/src/cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.cpp index 007bc74fc2e..85e9e4aec69 100644 --- a/src/cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2024 Intel Corporation +* Copyright 2017-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -59,18 +59,18 @@ namespace x64 { namespace avx512_common_gemm_f32 { using namespace gemm_utils; -struct xbyak_gemm_t : public jit_generator { +struct xbyak_gemm_t : public jit_generator_t { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm) xbyak_gemm_t(char isTransA, char isTransB, float beta, bool hasBias = false) - : jit_generator(jit_name()) + : jit_generator_t(jit_name()) , isTransA(isTransA) , isTransB(isTransB) , beta(beta) , hasBias(hasBias) , STACK_K_CAPACITY((STACK_CAPACITY - 256) / (SIZE * UNROLL_M)) {} - void generate() override ATTRIBUTE_OPTIMIZE { + void generate() override { using namespace Xbyak; bool isBeta0 = (beta == 0.0); bool isBetaN = (!isBeta0 && beta != 1.0); diff --git a/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_an_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_an_kern_autogen.cpp index bca29715498..75b38090dcb 100644 --- a/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_an_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_an_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx512_core_f32_copy_an_kern::jit_avx512_core_f32_copy_an_kern() - : jit_generator(jit_name()) {} +jit_avx512_core_f32_copy_an_kern_t::jit_avx512_core_f32_copy_an_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx512_core_f32_copy_an_kern::generate() { +void jit_avx512_core_f32_copy_an_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_at_kern_part1_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_at_kern_part1_autogen.cpp index 63bb212c563..d7230690c63 100644 --- a/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_at_kern_part1_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_at_kern_part1_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,10 +24,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx512_core_f32_copy_at_kern::jit_avx512_core_f32_copy_at_kern() - : jit_generator(jit_name()) {} +jit_avx512_core_f32_copy_at_kern_t::jit_avx512_core_f32_copy_at_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx512_core_f32_copy_at_kern::generate() { +void jit_avx512_core_f32_copy_at_kern_t::generate() { Xbyak::Label l1f80; Xbyak::Label l22b8; Xbyak::Label l2a5c; @@ -48,9 +48,9 @@ void jit_avx512_core_f32_copy_at_kern::generate() { postamble(); } -void jit_avx512_core_f32_copy_at_kern::generate_part1(const Xbyak::Label &l4000, - const Xbyak::Label &l2a5c, const Xbyak::Label &l22b8, - const Xbyak::Label &l1f80) { +void jit_avx512_core_f32_copy_at_kern_t::generate_part1( + const Xbyak::Label &l4000, const Xbyak::Label &l2a5c, + const Xbyak::Label &l22b8, const Xbyak::Label &l1f80) { Xbyak::Label l1d30; Xbyak::Label l1d0c; Xbyak::Label l1cfc; diff --git a/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_at_kern_part2_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_at_kern_part2_autogen.cpp index 51c776f1989..379a632bb1a 100644 --- a/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_at_kern_part2_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_at_kern_part2_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2021 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ namespace impl { namespace cpu { namespace x64 { -void jit_avx512_core_f32_copy_at_kern::generate_part2(Xbyak::Label l4000, +void jit_avx512_core_f32_copy_at_kern_t::generate_part2(Xbyak::Label l4000, Xbyak::Label l2a5c, Xbyak::Label l22b8, Xbyak::Label l1f80) { std::vector<Xbyak::Label> labels(62); L(l1f80); diff --git a/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_bn_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_bn_kern_autogen.cpp index c49dbb2f743..ab581f6a2ad 100644 --- a/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_bn_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_bn_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx512_core_f32_copy_bn_kern::jit_avx512_core_f32_copy_bn_kern() - : jit_generator(jit_name()) {} +jit_avx512_core_f32_copy_bn_kern_t::jit_avx512_core_f32_copy_bn_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx512_core_f32_copy_bn_kern::generate() { +void jit_avx512_core_f32_copy_bn_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_bt_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_bt_kern_autogen.cpp index 24d3145349f..99e101a7525 100644 --- a/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_bt_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_bt_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx512_core_f32_copy_bt_kern::jit_avx512_core_f32_copy_bt_kern() - : jit_generator(jit_name()) {} +jit_avx512_core_f32_copy_bt_kern_t::jit_avx512_core_f32_copy_bt_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx512_core_f32_copy_bt_kern::generate() { +void jit_avx512_core_f32_copy_bt_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.cpp b/src/cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.cpp index 6c675d189d3..3787430d2bf 100644 --- a/src/cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,13 +43,13 @@ static inline Xbyak::Ymm make_ymm(const Xbyak::Zmm &v) { namespace avx512_core_gemm_smalln_tn_f32 { -struct xbyak_gemm_smalln_tn_t : public jit_generator { +struct xbyak_gemm_smalln_tn_t : public jit_generator_t { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemm_smalln_tn_xbyak_gemm) xbyak_gemm_smalln_tn_t(int N, float beta, float alpha) - : jit_generator(jit_name()), N(N), beta(beta), alpha(alpha) {} + : jit_generator_t(jit_name()), N(N), beta(beta), alpha(alpha) {} - void generate() override ATTRIBUTE_OPTIMIZE { + void generate() override { using namespace Xbyak; /** * numN = 1 : 16 rows of A, 1x16 accumulators diff --git a/src/cpu/x64/gemm/f32/jit_avx_f32_copy_an_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx_f32_copy_an_kern_autogen.cpp index 117de225946..4354e22db58 100644 --- a/src/cpu/x64/gemm/f32/jit_avx_f32_copy_an_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx_f32_copy_an_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx_f32_copy_an_kern::jit_avx_f32_copy_an_kern() - : jit_generator(jit_name()) {} +jit_avx_f32_copy_an_kern_t::jit_avx_f32_copy_an_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx_f32_copy_an_kern::generate() { +void jit_avx_f32_copy_an_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_avx_f32_copy_at_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx_f32_copy_at_kern_autogen.cpp index 20e8c67d6be..700ff542285 100644 --- a/src/cpu/x64/gemm/f32/jit_avx_f32_copy_at_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx_f32_copy_at_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx_f32_copy_at_kern::jit_avx_f32_copy_at_kern() - : jit_generator(jit_name()) {} +jit_avx_f32_copy_at_kern_t::jit_avx_f32_copy_at_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx_f32_copy_at_kern::generate() { +void jit_avx_f32_copy_at_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_avx_f32_copy_bn_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx_f32_copy_bn_kern_autogen.cpp index 277144c5fbd..ed0494c469b 100644 --- a/src/cpu/x64/gemm/f32/jit_avx_f32_copy_bn_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx_f32_copy_bn_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx_f32_copy_bn_kern::jit_avx_f32_copy_bn_kern() - : jit_generator(jit_name()) {} +jit_avx_f32_copy_bn_kern_t::jit_avx_f32_copy_bn_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx_f32_copy_bn_kern::generate() { +void jit_avx_f32_copy_bn_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_avx_f32_copy_bt_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx_f32_copy_bt_kern_autogen.cpp index a7d9fe4fa04..e59bb0a5d8b 100644 --- a/src/cpu/x64/gemm/f32/jit_avx_f32_copy_bt_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx_f32_copy_bt_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx_f32_copy_bt_kern::jit_avx_f32_copy_bt_kern() - : jit_generator(jit_name()) {} +jit_avx_f32_copy_bt_kern_t::jit_avx_f32_copy_bt_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx_f32_copy_bt_kern::generate() { +void jit_avx_f32_copy_bt_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_avx_gemm_f32.cpp b/src/cpu/x64/gemm/f32/jit_avx_gemm_f32.cpp index 8740d81d8c0..38a01ce662e 100644 --- a/src/cpu/x64/gemm/f32/jit_avx_gemm_f32.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx_gemm_f32.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2024 Intel Corporation +* Copyright 2016-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,11 +58,11 @@ namespace avx_gemm_f32 { using namespace gemm_utils; using namespace Xbyak; -struct xbyak_gemm_t : public jit_generator { +struct xbyak_gemm_t : public jit_generator_t { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm) xbyak_gemm_t(char isTransA, char isTransB, float beta, bool hasBias = false) - : jit_generator(jit_name()) + : jit_generator_t(jit_name()) , isTransA(isTransA) , isTransB(isTransB) , hasBias(hasBias) @@ -1966,7 +1966,7 @@ struct xbyak_gemm_t : public jit_generator { if (hasBias) { add(BIAS, unroll_m * SIZE); } } - void generate() override ATTRIBUTE_OPTIMIZE { + void generate() override { assert(IMPLICATION(!is_avx2, mayiuse(avx))); preamble(); diff --git a/src/cpu/x64/gemm/f32/jit_avx_gemv_t_f32_kern.cpp b/src/cpu/x64/gemm/f32/jit_avx_gemv_t_f32_kern.cpp index d85f65fb581..394eb40f2e7 100644 --- a/src/cpu/x64/gemm/f32/jit_avx_gemv_t_f32_kern.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx_gemv_t_f32_kern.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ static inline Xmm make_xmm(const Xmm &v) { } // Load vector register data for x, y or A. -void jit_avx_gemv_t_f32_kern::v_load( +void jit_avx_gemv_t_f32_kern_t::v_load( const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems) { switch (nelems) { case 1: vmovss(make_xmm(dst), src); break; @@ -52,7 +52,7 @@ void jit_avx_gemv_t_f32_kern::v_load( } // Store vector register data for x, y or A. -void jit_avx_gemv_t_f32_kern::v_store( +void jit_avx_gemv_t_f32_kern_t::v_store( const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems) { switch (nelems) { case 1: vmovss(dst, make_xmm(src)); break; @@ -67,7 +67,7 @@ void jit_avx_gemv_t_f32_kern::v_store( // Perform Hadamard product of 2 vectors and accumulate. // Use FMA instruction, otherwise emulate. -void jit_avx_gemv_t_f32_kern::dot_product( +void jit_avx_gemv_t_f32_kern_t::dot_product( const Xmm &dst, const Xmm &src1, const Xmm &src2) { if (is_avx2_) vfmadd231ps(dst, src1, src2); @@ -78,7 +78,7 @@ void jit_avx_gemv_t_f32_kern::dot_product( } // Inner loop. 
-void jit_avx_gemv_t_f32_kern::innerloop(int unroll_m, int unroll_n) { +void jit_avx_gemv_t_f32_kern_t::innerloop(int unroll_m, int unroll_n) { if ((unroll_m > M_UNROLL_) || (unroll_n > N_UNROLL_) || (unroll_m < 0) || (unroll_n < 0)) return; @@ -119,7 +119,7 @@ void jit_avx_gemv_t_f32_kern::innerloop(int unroll_m, int unroll_n) { } // Outer loop. -void jit_avx_gemv_t_f32_kern::outerloop( +void jit_avx_gemv_t_f32_kern_t::outerloop( int unroll_x, int unroll_y, Label *&cur_outerloop_label) { if ((unroll_x > M_UNROLL_) || (unroll_y > N_UNROLL_) || (unroll_y < 0) || (unroll_x < 0)) @@ -259,7 +259,7 @@ void jit_avx_gemv_t_f32_kern::outerloop( align(16); } -void jit_avx_gemv_t_f32_kern::generate() { +void jit_avx_gemv_t_f32_kern_t::generate() { // Prologue preamble(); @@ -301,8 +301,8 @@ void jit_avx_gemv_t_f32_kern::generate() { } // Function signature: gemv(*m, *n, *alpha, *a, *lda, *x, *incx, *y, *incy) -jit_avx_gemv_t_f32_kern::jit_avx_gemv_t_f32_kern() - : jit_generator(jit_name()) +jit_avx_gemv_t_f32_kern_t::jit_avx_gemv_t_f32_kern_t() + : jit_generator_t(jit_name()) , is_avx2_(mayiuse(avx2)) , LDA_(is_windows ? rdi : r8) , X_(is_windows ? rsi : r9) diff --git a/src/cpu/x64/gemm/f32/jit_avx_gemv_t_f32_kern.hpp b/src/cpu/x64/gemm/f32/jit_avx_gemv_t_f32_kern.hpp index 1ed21b708ff..d4b07183ed5 100644 --- a/src/cpu/x64/gemm/f32/jit_avx_gemv_t_f32_kern.hpp +++ b/src/cpu/x64/gemm/f32/jit_avx_gemv_t_f32_kern.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2021 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,10 +24,10 @@ namespace impl { namespace cpu { namespace x64 { -class jit_avx_gemv_t_f32_kern : public jit_generator { +class jit_avx_gemv_t_f32_kern_t : public jit_generator_t { public: - jit_avx_gemv_t_f32_kern(void); - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemv_t_f32_kern); + jit_avx_gemv_t_f32_kern_t(void); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemv_t_f32_kern_t); protected: bool is_avx2_; @@ -40,7 +40,7 @@ class jit_avx_gemv_t_f32_kern : public jit_generator { void innerloop(int unroll_m, int unroll_n); void outerloop(int unroll_x, int unroll_y, Xbyak::Label *&outerloop_label); - void generate() override ATTRIBUTE_OPTIMIZE; + void generate() override; private: static const int M_UNROLL_ = 16; diff --git a/src/cpu/x64/gemm/f32/jit_avx_kernel_b0_sgemm_kern_part1_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx_kernel_b0_sgemm_kern_part1_autogen.cpp index 52fccd21619..32a2f5860dd 100644 --- a/src/cpu/x64/gemm/f32/jit_avx_kernel_b0_sgemm_kern_part1_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx_kernel_b0_sgemm_kern_part1_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,10 +24,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx_kernel_b0_sgemm_kern::jit_avx_kernel_b0_sgemm_kern() - : jit_generator(jit_name()) {} +jit_avx_kernel_b0_sgemm_kern_t::jit_avx_kernel_b0_sgemm_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx_kernel_b0_sgemm_kern::generate() { +void jit_avx_kernel_b0_sgemm_kern_t::generate() { Xbyak::Label l259c; Xbyak::Label l2774; Xbyak::Label l2834; @@ -52,7 +52,7 @@ void jit_avx_kernel_b0_sgemm_kern::generate() { postamble(); } -void jit_avx_kernel_b0_sgemm_kern::generate_part1(const Xbyak::Label &l2cf4, +void jit_avx_kernel_b0_sgemm_kern_t::generate_part1(const Xbyak::Label &l2cf4, const Xbyak::Label &l2834, const Xbyak::Label &l2774, const Xbyak::Label &l259c) { std::vector<Xbyak::Label> labels(55); diff --git a/src/cpu/x64/gemm/f32/jit_avx_kernel_b0_sgemm_kern_part2_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx_kernel_b0_sgemm_kern_part2_autogen.cpp index 74d2c82cbc4..35a9ea2f626 100644 --- a/src/cpu/x64/gemm/f32/jit_avx_kernel_b0_sgemm_kern_part2_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx_kernel_b0_sgemm_kern_part2_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2021 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ namespace impl { namespace cpu { namespace x64 { -void jit_avx_kernel_b0_sgemm_kern::generate_part2(Xbyak::Label l2cf4, +void jit_avx_kernel_b0_sgemm_kern_t::generate_part2(Xbyak::Label l2cf4, Xbyak::Label l2834, Xbyak::Label l2774, Xbyak::Label l259c) { std::vector<Xbyak::Label> labels(57); L(labels[56]); diff --git a/src/cpu/x64/gemm/f32/jit_avx_kernel_sgemm_kern_part1_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx_kernel_sgemm_kern_part1_autogen.cpp index 8ea5bd9a729..daeba0781ea 100644 --- a/src/cpu/x64/gemm/f32/jit_avx_kernel_sgemm_kern_part1_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx_kernel_sgemm_kern_part1_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_avx_kernel_sgemm_kern::jit_avx_kernel_sgemm_kern() - : jit_generator(jit_name()) {} +jit_avx_kernel_sgemm_kern_t::jit_avx_kernel_sgemm_kern_t() + : jit_generator_t(jit_name()) {} -void jit_avx_kernel_sgemm_kern::generate() { +void jit_avx_kernel_sgemm_kern_t::generate() { Xbyak::Label l1efc; Xbyak::Label l1f44; Xbyak::Label l1f48; @@ -40,13 +40,13 @@ void jit_avx_kernel_sgemm_kern::generate() { mov(C, ptr[OLD_C]); mov(LDC, ptr[OLD_LDC]); - jit_avx_kernel_sgemm_kern::generate_part1(l1efc, l1f44, l1f48); - jit_avx_kernel_sgemm_kern::generate_part2(l1efc, l1f44, l1f48); + jit_avx_kernel_sgemm_kern_t::generate_part1(l1efc, l1f44, l1f48); + jit_avx_kernel_sgemm_kern_t::generate_part2(l1efc, l1f44, l1f48); postamble(); } -void jit_avx_kernel_sgemm_kern::generate_part1(const Xbyak::Label &l1efc, +void jit_avx_kernel_sgemm_kern_t::generate_part1(const Xbyak::Label &l1efc, const Xbyak::Label &l1f44, const Xbyak::Label &l1f48) { std::vector<Xbyak::Label> labels(44); diff --git a/src/cpu/x64/gemm/f32/jit_avx_kernel_sgemm_kern_part2_autogen.cpp b/src/cpu/x64/gemm/f32/jit_avx_kernel_sgemm_kern_part2_autogen.cpp index a8154c7a1c8..e1ff79875f0 100644 --- a/src/cpu/x64/gemm/f32/jit_avx_kernel_sgemm_kern_part2_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx_kernel_sgemm_kern_part2_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2021 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ namespace impl { namespace cpu { namespace x64 { -void jit_avx_kernel_sgemm_kern::generate_part2( +void jit_avx_kernel_sgemm_kern_t::generate_part2( Xbyak::Label &l1efc, Xbyak::Label &l1f44, Xbyak::Label &l1f48) { std::vector<Xbyak::Label> labels(69); L(l1efc); diff --git a/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_an_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_an_kern_autogen.cpp index 57039cba5b0..9fe1c3a386b 100644 --- a/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_an_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_an_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_sse41_f32_copy_an_kern::jit_sse41_f32_copy_an_kern() - : jit_generator(jit_name()) {} +jit_sse41_f32_copy_an_kern_t::jit_sse41_f32_copy_an_kern_t() + : jit_generator_t(jit_name()) {} -void jit_sse41_f32_copy_an_kern::generate() { +void jit_sse41_f32_copy_an_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_at_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_at_kern_autogen.cpp index b1381469d1b..d52a86a0726 100644 --- a/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_at_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_at_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_sse41_f32_copy_at_kern::jit_sse41_f32_copy_at_kern() - : jit_generator(jit_name()) {} +jit_sse41_f32_copy_at_kern_t::jit_sse41_f32_copy_at_kern_t() + : jit_generator_t(jit_name()) {} -void jit_sse41_f32_copy_at_kern::generate() { +void jit_sse41_f32_copy_at_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_bn_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_bn_kern_autogen.cpp index f095bf750e9..36c56697f43 100644 --- a/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_bn_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_bn_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_sse41_f32_copy_bn_kern::jit_sse41_f32_copy_bn_kern() - : jit_generator(jit_name()) {} +jit_sse41_f32_copy_bn_kern_t::jit_sse41_f32_copy_bn_kern_t() + : jit_generator_t(jit_name()) {} -void jit_sse41_f32_copy_bn_kern::generate() { +void jit_sse41_f32_copy_bn_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_bt_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_bt_kern_autogen.cpp index 3f509e5dcef..b985134391e 100644 --- a/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_bt_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_sse41_f32_copy_bt_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_sse41_f32_copy_bt_kern::jit_sse41_f32_copy_bt_kern() - : jit_generator(jit_name()) {} +jit_sse41_f32_copy_bt_kern_t::jit_sse41_f32_copy_bt_kern_t() + : jit_generator_t(jit_name()) {} -void jit_sse41_f32_copy_bt_kern::generate() { +void jit_sse41_f32_copy_bt_kern_t::generate() { #ifndef _WIN32 #define M rdi diff --git a/src/cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.cpp b/src/cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.cpp index cb195f55006..83b171466aa 100644 --- a/src/cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.cpp +++ b/src/cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2024 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,7 +58,7 @@ static inline int log2_of_pow2(int n) { } // Load vector register data for x, y or A. -void jit_sse41_gemv_n_f32_kern::v_load( +void jit_sse41_gemv_n_f32_kern_t::v_load( const Xmm &dst, const Address &src, int nelems) { if (nelems >= v_nelems_) { uni_vmovups(dst, src); @@ -82,7 +82,7 @@ void jit_sse41_gemv_n_f32_kern::v_load( } // Store vector register data for x, y or A. -void jit_sse41_gemv_n_f32_kern::v_store( +void jit_sse41_gemv_n_f32_kern_t::v_store( const Address &dst, const Xmm &src, int nelems) { if (nelems >= v_nelems_) { uni_vmovups(dst, src); @@ -107,7 +107,7 @@ void jit_sse41_gemv_n_f32_kern::v_store( // Perform Hadamard product of 2 vectors and accumulate. // Use FMA instruction, otherwise emulate. -void jit_sse41_gemv_n_f32_kern::dot_product( +void jit_sse41_gemv_n_f32_kern_t::dot_product( const Xmm &dst, const Xmm &src1, const Xmm &src2) { if (has_avx2_) vfmadd231ps(dst, src1, src2); @@ -120,7 +120,7 @@ void jit_sse41_gemv_n_f32_kern::dot_product( } } -void jit_sse41_gemv_n_f32_kern::kernel_loop( +void jit_sse41_gemv_n_f32_kern_t::kernel_loop( int unroll_m, int unroll_n, bool fetch, bool last) { int um_vecs = utils::div_up(unroll_m, v_nelems_); @@ -168,7 +168,7 @@ void jit_sse41_gemv_n_f32_kern::kernel_loop( } // Inner loop for A non-transposed. -void jit_sse41_gemv_n_f32_kern::innerloop(int unroll_m, int unroll_n) { +void jit_sse41_gemv_n_f32_kern_t::innerloop(int unroll_m, int unroll_n) { mov(Y1_, Y_); // Load x and scale by alpha. 
@@ -237,7 +237,7 @@ void jit_sse41_gemv_n_f32_kern::innerloop(int unroll_m, int unroll_n) { L_aligned(label_m_loop_end); } -void jit_sse41_gemv_n_f32_kern::outerloop(int unroll_x, int unroll_y, +void jit_sse41_gemv_n_f32_kern_t::outerloop(int unroll_x, int unroll_y, Label *&cur_outerloop_label, Label *&outerloop_end_label) { bool is_tail = unroll_y < unroll_n_; @@ -270,7 +270,7 @@ void jit_sse41_gemv_n_f32_kern::outerloop(int unroll_x, int unroll_y, } } -void jit_sse41_gemv_n_f32_kern::generate() { +void jit_sse41_gemv_n_f32_kern_t::generate() { // Prologue preamble(); @@ -313,8 +313,8 @@ void jit_sse41_gemv_n_f32_kern::generate() { } // Function signature: gemv(*m, *n, *alpha, *a, *lda, *x, *incx, *y, *incy) -jit_sse41_gemv_n_f32_kern::jit_sse41_gemv_n_f32_kern(void) - : jit_generator(jit_name()) +jit_sse41_gemv_n_f32_kern_t::jit_sse41_gemv_n_f32_kern_t(void) + : jit_generator_t(jit_name()) , has_avx512_(mayiuse(avx512_core) && __BUILD_GEMM_AVX512) , has_avx2_(mayiuse(avx2) && __BUILD_GEMM_AVX2) , has_avx_(mayiuse(avx) && __BUILD_GEMM_AVX2) diff --git a/src/cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.hpp b/src/cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.hpp index 89886aed939..8058122ea18 100644 --- a/src/cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.hpp +++ b/src/cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021 Intel Corporation +* Copyright 2021-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,10 +24,10 @@ namespace impl { namespace cpu { namespace x64 { -class jit_sse41_gemv_n_f32_kern : public jit_generator { +class jit_sse41_gemv_n_f32_kern_t : public jit_generator_t { public: - jit_sse41_gemv_n_f32_kern(); - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_gemv_n_f32_kern); + jit_sse41_gemv_n_f32_kern_t(); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_gemv_n_f32_kern_t); protected: bool has_avx512_; @@ -61,7 +61,7 @@ class jit_sse41_gemv_n_f32_kern : public jit_generator { Xbyak::Label *&cur_outerloop_label, Xbyak::Label *&outerloop_end_label); - void generate() override ATTRIBUTE_OPTIMIZE; + void generate() override; private: static const int max_um_vecs_ = 16; diff --git a/src/cpu/x64/gemm/f32/jit_sse41_gemv_t_f32_kern.cpp b/src/cpu/x64/gemm/f32/jit_sse41_gemv_t_f32_kern.cpp index b3b578975fc..da67ca31e78 100644 --- a/src/cpu/x64/gemm/f32/jit_sse41_gemv_t_f32_kern.cpp +++ b/src/cpu/x64/gemm/f32/jit_sse41_gemv_t_f32_kern.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ namespace x64 { using namespace Xbyak; // Load vector register data for x, y or A. -void jit_sse41_gemv_t_f32_kern::v_load( +void jit_sse41_gemv_t_f32_kern_t::v_load( const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems) { switch (nelems) { case 1: movss(dst, src); break; @@ -43,7 +43,7 @@ void jit_sse41_gemv_t_f32_kern::v_load( } // Store vector register data for x, y or A. 
-void jit_sse41_gemv_t_f32_kern::v_store( +void jit_sse41_gemv_t_f32_kern_t::v_store( const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems) { switch (nelems) { case 1: movss(dst, src); break; @@ -56,14 +56,14 @@ void jit_sse41_gemv_t_f32_kern::v_store( } // Perform Hadamard product of 2 vectors and accumulate. -void jit_sse41_gemv_t_f32_kern::dot_product( +void jit_sse41_gemv_t_f32_kern_t::dot_product( const Xmm &dst, const Xmm &src1, const Xmm &src2) { mulps(src2, src1); addps(dst, src2); } // Inner loop. -void jit_sse41_gemv_t_f32_kern::innerloop(int unroll_m, int unroll_n) { +void jit_sse41_gemv_t_f32_kern_t::innerloop(int unroll_m, int unroll_n) { if ((unroll_m > M_UNROLL_) || (unroll_n > N_UNROLL_) || (unroll_m < 0) || (unroll_n < 0)) return; @@ -104,7 +104,7 @@ void jit_sse41_gemv_t_f32_kern::innerloop(int unroll_m, int unroll_n) { } // Outer loop. -void jit_sse41_gemv_t_f32_kern::outerloop( +void jit_sse41_gemv_t_f32_kern_t::outerloop( int unroll_x, int unroll_y, Label *&cur_outerloop_label) { if ((unroll_x > M_UNROLL_) || (unroll_y > N_UNROLL_) || (unroll_y < 0) || unroll_x < 0) @@ -230,7 +230,7 @@ void jit_sse41_gemv_t_f32_kern::outerloop( align(16); } -void jit_sse41_gemv_t_f32_kern::generate() { +void jit_sse41_gemv_t_f32_kern_t::generate() { // Prologue preamble(); @@ -272,8 +272,8 @@ void jit_sse41_gemv_t_f32_kern::generate() { } // Function signature: gemv(*m, *n, *alpha, *a, *lda, *x, *incx, *y, *incy) -jit_sse41_gemv_t_f32_kern::jit_sse41_gemv_t_f32_kern() - : jit_generator(jit_name()) +jit_sse41_gemv_t_f32_kern_t::jit_sse41_gemv_t_f32_kern_t() + : jit_generator_t(jit_name()) , LDA_(is_windows ? rdi : r8) , X_(is_windows ? rsi : r9) , INCY_(r10) diff --git a/src/cpu/x64/gemm/f32/jit_sse41_gemv_t_f32_kern.hpp b/src/cpu/x64/gemm/f32/jit_sse41_gemv_t_f32_kern.hpp index 8a32fb4beff..9f79643ab8e 100644 --- a/src/cpu/x64/gemm/f32/jit_sse41_gemv_t_f32_kern.hpp +++ b/src/cpu/x64/gemm/f32/jit_sse41_gemv_t_f32_kern.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2021 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,10 +24,10 @@ namespace impl { namespace cpu { namespace x64 { -class jit_sse41_gemv_t_f32_kern : public jit_generator { +class jit_sse41_gemv_t_f32_kern_t : public jit_generator_t { public: - jit_sse41_gemv_t_f32_kern(void); - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_gemv_t_f32_kern); + jit_sse41_gemv_t_f32_kern_t(void); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_gemv_t_f32_kern_t); protected: void v_load(const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems); @@ -38,7 +38,7 @@ class jit_sse41_gemv_t_f32_kern : public jit_generator { void innerloop(int unroll_m, int unroll_n); void outerloop(int unroll_x, int unroll_y, Xbyak::Label *&outerloop_label); - void generate() override ATTRIBUTE_OPTIMIZE; + void generate() override; private: static const int M_UNROLL_ = 8; diff --git a/src/cpu/x64/gemm/f32/jit_sse41_kernel_b0_sgemm_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_sse41_kernel_b0_sgemm_kern_autogen.cpp index ae734d720b7..a2d2934f144 100644 --- a/src/cpu/x64/gemm/f32/jit_sse41_kernel_b0_sgemm_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_sse41_kernel_b0_sgemm_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_sse41_kernel_b0_sgemm_kern::jit_sse41_kernel_b0_sgemm_kern() - : jit_generator(jit_name()) {} +jit_sse41_kernel_b0_sgemm_kern_t::jit_sse41_kernel_b0_sgemm_kern_t() + : jit_generator_t(jit_name()) {} -void jit_sse41_kernel_b0_sgemm_kern::generate() { +void jit_sse41_kernel_b0_sgemm_kern_t::generate() { #ifndef _WIN32 diff --git a/src/cpu/x64/gemm/f32/jit_sse41_kernel_sgemm_kern_autogen.cpp b/src/cpu/x64/gemm/f32/jit_sse41_kernel_sgemm_kern_autogen.cpp index ba6a36882ed..6d900e70fb4 100644 --- a/src/cpu/x64/gemm/f32/jit_sse41_kernel_sgemm_kern_autogen.cpp +++ b/src/cpu/x64/gemm/f32/jit_sse41_kernel_sgemm_kern_autogen.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2024 Intel Corporation +* Copyright 2019-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ namespace impl { namespace cpu { namespace x64 { -jit_sse41_kernel_sgemm_kern::jit_sse41_kernel_sgemm_kern() - : jit_generator(jit_name()) {} +jit_sse41_kernel_sgemm_kern_t::jit_sse41_kernel_sgemm_kern_t() + : jit_generator_t(jit_name()) {} -void jit_sse41_kernel_sgemm_kern::generate() { +void jit_sse41_kernel_sgemm_kern_t::generate() { #ifndef _WIN32 diff --git a/src/cpu/x64/gemm/gemm_driver.cpp b/src/cpu/x64/gemm/gemm_driver.cpp index dae0d417f46..aaacd0931e2 100644 --- a/src/cpu/x64/gemm/gemm_driver.cpp +++ b/src/cpu/x64/gemm/gemm_driver.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2024 Intel Corporation +* Copyright 2018-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -80,15 +80,15 @@ int get_vector_length() {
         //dummy if
 #if __BUILD_GEMM_AVX512
     } else if (mayiuse(avx512_core)) {
-        v_bytes = cpu_isa_traits<avx512_core>::vlen;
+        v_bytes = cpu_isa_traits_t<avx512_core>::vlen;
 #endif
 #if __BUILD_GEMM_AVX2
     } else if (mayiuse(avx)) {
-        v_bytes = cpu_isa_traits<avx>::vlen;
+        v_bytes = cpu_isa_traits_t<avx>::vlen;
 #endif
 #if __BUILD_GEMM_SSE41
     } else if (mayiuse(sse41)) {
-        v_bytes = cpu_isa_traits<sse41>::vlen;
+        v_bytes = cpu_isa_traits_t<sse41>::vlen;
 #endif
     } else {
         assert(!"not supposed to be reached.");
@@ -115,7 +115,7 @@ static inline void add_results(const dim_t m, const dim_t n, const float alpha,
         c_type *c_data, const dim_t ldc, const c_type *co,
         offset_type offsetc) {
-    constexpr bool is_int8 = data_traits<c_type>::data_type == data_type::s32;
+    constexpr bool is_int8 = data_traits_t<c_type>::data_type == data_type::s32;
 
     for (dim_t j = 0; j < n; ++j) {
         for (dim_t i = 0; i < m; ++i) {
@@ -254,7 +254,7 @@ static inline void *align(void *ptr, size_t alignment) {
 template <typename mat_t, typename scale_t>
 void scale_matrix(
         dim_t m, dim_t n, scale_t alpha, mat_t *__restrict p_mat, dim_t ld) {
-    if (data_traits<mat_t>::data_type == data_type::f32) {
+    if (data_traits_t<mat_t>::data_type == data_type::f32) {
         for (dim_t j = 0; j < n; j++) {
             for (dim_t i = 0; i < m; i++) {
                 p_mat[i + j * ld] = (mat_t)((scale_t)p_mat[i + j * ld] * alpha);
@@ -400,8 +400,8 @@ void gemm_kernel(dim_t m, dim_t n, const dim_t k, const float alpha,
     bool row_req = false;
 
     constexpr bool is_int8 = utils::one_of(
-            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
-    constexpr bool is_f32 = data_traits<a_type>::data_type == data_type::f32;
+            data_traits_t<a_type>::data_type, data_type::s8, data_type::u8);
+    constexpr bool is_f32 = data_traits_t<a_type>::data_type == data_type::f32;
     bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX;
 
     dim_t m_stk = col_offset_ws ? 1 : m;
@@ -547,8 +547,9 @@ static dnnl_status_t gemm_kernel_driver(int ithr, dim_t m, dim_t n, dim_t k,
     float alpha = arg->alpha;
 
     constexpr bool is_int8 = utils::one_of(
-            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
-    constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
+            data_traits_t<a_type>::data_type, data_type::s8, data_type::u8);
+    constexpr bool is_bf16
+            = data_traits_t<a_type>::data_type == data_type::bf16;
     bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX;
     bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX;
@@ -826,8 +827,9 @@ static dnnl_status_t kernel_driver_parallel_acopiedbcopy(int ithr, dim_t m,
     size_t b_buf_nelems = k * n_padd;
     size_t b_col_sum_nelems = n_padd;
     constexpr bool is_int8 = utils::one_of(
-            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
-    constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
+            data_traits_t<a_type>::data_type, data_type::s8, data_type::u8);
+    constexpr bool is_bf16
+            = data_traits_t<a_type>::data_type == data_type::bf16;
     bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX;
     bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX;
@@ -1050,7 +1052,7 @@ template <typename a_type, typename b_type, typename c_type>
 static inline bool nocopy_checker(
         int nthr, const gemm_info_t<a_type, b_type, c_type> *arg) {
-    if (data_traits<a_type>::data_type != data_type::f32) return false;
+    if (data_traits_t<a_type>::data_type != data_type::f32) return false;
 
     if (!(mayiuse(avx) && __BUILD_GEMM_AVX2)) return false;
@@ -1089,8 +1091,8 @@ static inline void set_thread_opts_nopack(int nthrs, int nthrs_spawn,
     static constexpr dim_t M2D_MIN = 384;
 
     constexpr bool is_int8 = utils::one_of(
-            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
-    bool isSgemm = data_traits<a_type>::data_type == data_type::f32;
+            data_traits_t<a_type>::data_type, data_type::s8, data_type::u8);
+    bool isSgemm = data_traits_t<a_type>::data_type == data_type::f32;
     dim_t m = arg->m;
     dim_t n = arg->n;
@@ -1247,8 +1249,9 @@ static inline void set_thread_opts_pack(int nthrs,
         bool do_n_blocking = true) {
     constexpr bool is_int8 = utils::one_of(
-            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
-    constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
+            data_traits_t<a_type>::data_type, data_type::s8, data_type::u8);
+    constexpr bool is_bf16
+            = data_traits_t<a_type>::data_type == data_type::bf16;
     bool do_m_blocking_only = do_m_blocking && !do_n_blocking;
@@ -1362,8 +1365,9 @@ static inline int set_thread_opts(int nthrs, int nthrs_spawn,
     thread_info.thread_m = thread_info.thread_n = thread_info.thread_k = -1;
     constexpr bool is_int8 = utils::one_of(
-            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
-    constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
+            data_traits_t<a_type>::data_type, data_type::s8, data_type::u8);
+    constexpr bool is_bf16
+            = data_traits_t<a_type>::data_type == data_type::bf16;
     if (nocopy_checker(nthrs, arg)) {
         thread_info.copy = copy_type::no_copy;
@@ -1452,8 +1456,9 @@ static dnnl_status_t parallel_a_copy(const int ithr, const int nthrs,
     float alpha = arg->alpha;
     constexpr bool is_int8 = utils::one_of(
-            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
-    constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
+            data_traits_t<a_type>::data_type, data_type::s8, data_type::u8);
+    constexpr bool is_bf16
+            = data_traits_t<a_type>::data_type == data_type::bf16;
     bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX;
     bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX;
     bool is_amx = is_int8_amx || is_bf16_amx;
@@ -1608,7 +1613,7 @@
 static inline void adjust_thread_count(dim_t m, dim_t n, dim_t k, int *nthrs) {
     auto veclen = get_vector_length<c_type>();
     const double fp_per_cycle = 2.0 * 2.0 * veclen;
-    const bool is_f32 = data_traits<a_type>::data_type == data_type::f32;
+    const bool is_f32 = data_traits_t<a_type>::data_type == data_type::f32;
     const bool is_avx512 = mayiuse(avx512_core) && __BUILD_GEMM_AVX512;
     const bool is_avx = mayiuse(avx) && __BUILD_GEMM_AVX2;
@@ -1729,8 +1734,9 @@ static dnnl_status_t gemm_threading_driver(
     auto is_a_packed = (arg->transa == packed);
     auto is_b_packed = (arg->transb == packed);
     constexpr bool is_int8 = utils::one_of(
-            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
-    constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
+            data_traits_t<a_type>::data_type, data_type::s8, data_type::u8);
+    constexpr bool is_bf16
+            = data_traits_t<a_type>::data_type == data_type::bf16;
 
     if ((arg->m <= 0) || (arg->n <= 0)) return dnnl_success;
@@ -1971,7 +1977,7 @@ static dnnl_status_t gemm_threading_driver(
                     // This route is taken only if we realize we need no-copy
                     // after launching the parallel section, due to less
                     // threads being spawned than expected.
-                    assert(data_traits<a_type>::data_type
+                    assert(data_traits_t<a_type>::data_type
                             == data_type::f32);
                     assert(arg->packing == pack_type::none);
@@ -2048,13 +2054,13 @@ dnnl_status_t gemm_driver(const char *transA, const char *transB,
         pack_type packing, gemm_pack_storage_t *pack_dst, bool measure_only) {
     constexpr bool is_int8 = utils::one_of(
-            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
+            data_traits_t<a_type>::data_type, data_type::s8, data_type::u8);
     MAYBE_UNUSED(is_int8);
 
 #if __BUILD_GEMM_AVX512
     // gemm_driver supports bfloat16 gemm for Intel AVX512 and
     // Intel AVX512 BF16.
-    assert(IMPLICATION(data_traits<a_type>::data_type == data_type::bf16,
+    assert(IMPLICATION(data_traits_t<a_type>::data_type == data_type::bf16,
             mayiuse(avx512_core) && !force_nocopy));
 #endif
@@ -2067,8 +2073,8 @@ dnnl_status_t gemm_driver(const char *transA, const char *transB,
 #if __BUILD_GEMM_SSE41
     // gemm_driver supports sgemm for Intel AVX512, Intel AVX2, Intel AVX,
     // and Intel SSE4.1
-    assert(IMPLICATION(
-            data_traits<a_type>::data_type == data_type::f32, mayiuse(sse41)));
+    assert(IMPLICATION(data_traits_t<a_type>::data_type == data_type::f32,
+            mayiuse(sse41)));
 #endif
 
     // 8-bit integer gemm doesn't support nocopy kernels.
diff --git a/src/cpu/x64/gemm/gemm_driver.hpp b/src/cpu/x64/gemm/gemm_driver.hpp
index 650d1775a01..163349b1101 100644
--- a/src/cpu/x64/gemm/gemm_driver.hpp
+++ b/src/cpu/x64/gemm/gemm_driver.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2020 Intel Corporation
+* Copyright 2018-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ dnnl_status_t gemm_driver(const char *transA, const char *transB,
         const b_type *b, const dim_t *ldb, const b_type *ob, const float *beta,
         c_type *c, const dim_t *ldc, const c_type *oc,
         const bool force_jit_nocopy_gemm, pack_type packing = pack_type::none,
-        gemm_pack_storage_t *pack_dst = NULL, bool measure_only = false);
+        gemm_pack_storage_t *pack_dst = nullptr, bool measure_only = false);
 
 void prep_ref_gemm_s8u8s32_pack(
         bool do_a, dim_t rows, dim_t cols, gemm_pack_storage_t *pack_dst);
diff --git a/src/cpu/x64/gemm/gemm_info.cpp b/src/cpu/x64/gemm/gemm_info.cpp
index cd227f30306..05837b0293e 100644
--- a/src/cpu/x64/gemm/gemm_info.cpp
+++ b/src/cpu/x64/gemm/gemm_info.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2024 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
 #include "common/bfloat16.hpp"
 #include "common/dnnl_traits.hpp"
+#include "common/dnnl_sel_build.hpp"
 
 #include "cpu/gemm/gemm.hpp"
@@ -139,7 +140,7 @@ gemm_info_t<a_t, b_t, c_t>::gemm_info_t(const char *transA, const char *transB,
     }
 
     constexpr bool is_int8 = utils::one_of(
-            data_traits<a_t>::data_type, data_type::s8, data_type::u8);
+            data_traits_t<a_t>::data_type, data_type::s8, data_type::u8);
     if (is_int8) this->ao = oa ? *oa : a_t(0);
     prepare_bo(this->bo, ob);
@@ -155,7 +156,7 @@
         this->co = oc;
     }
 
-    bool is_sgemm = data_traits<a_t>::data_type == data_type::f32;
+    bool is_sgemm = data_traits_t<a_t>::data_type == data_type::f32;
     bool is_gemv = this->m == 1 || this->n == 1;
 
     // Copy-based sgemm doesn't support force-nocopy for ISAs older
@@ -213,7 +214,8 @@ void gemm_info_t<a_t, b_t, c_t>::jit_init(void) {
     // TODO: Add dispatching for 1-fma SKUs with support to bf16
     // instructions for AMX kernel.
{ - constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; + constexpr bool is_bf16 + = data_traits_t::data_type == data_type::bf16; const bool max_isa_supports_bf16_ymm = mayiuse(avx512_core_bf16_ymm) && __BUILD_GEMM_AVX512 && !(mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX); @@ -221,7 +223,7 @@ void gemm_info_t::jit_init(void) { use_bf16_ymm = is_bf16 && max_isa_supports_bf16_ymm; } - switch (data_traits::data_type) { + switch (data_traits_t::data_type) { case data_type::s8: if (false) { // dummy if @@ -391,145 +393,158 @@ void gemm_info_t::jit_init(void) { static std::once_flag initialized; static std::atomic st(dnnl_success); std::call_once(initialized, [&, um] { - const bool b_is_s8 = data_traits::data_type == data_type::s8; + const bool b_is_s8 = data_traits_t::data_type == data_type::s8; UNUSED(b_is_s8); constexpr bool is_int8 = utils::one_of( - data_traits::data_type, data_type::s8, data_type::u8); - constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; + data_traits_t::data_type, data_type::s8, data_type::u8); + constexpr bool is_bf16 + = data_traits_t::data_type == data_type::bf16; bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX; bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx) && __BUILD_GEMM_AMX; bool is_amx = is_int8_amx || is_bf16_amx; - static maybe_unique_ptr copy_a[2][2] = {{nullptr}}; - static maybe_unique_ptr copy_b[2][2] = {{nullptr}}; + static maybe_unique_ptr copy_a[2][2] = {{nullptr}}; + static maybe_unique_ptr copy_b[2][2] = {{nullptr}}; - switch (data_traits::data_type) { + switch (data_traits_t::data_type) { case data_type::s8: if (false) { // dummy if #if __BUILD_GEMM_AMX } else if (mayiuse(amx_int8)) { + DNNL_CSCOPE(jit_init_copy_kern_s8_amx_int8) { for (int isTrans : {no_trans, do_trans}) { copy_a[isTrans][no_sum].reset( - new jit_avx512_core_amx_copy_kern( + new jit_avx512_core_amx_copy_kern_t( true, !isTrans, sizeof(a_t))); copy_b[isTrans][no_sum].reset( - new jit_avx512_core_amx_copy_kern( + new jit_avx512_core_amx_copy_kern_t( false, isTrans, sizeof(b_t))); + } } #endif #if __BUILD_GEMM_AVX512 } else if (mayiuse(avx512_core)) { + DNNL_CSCOPE(jit_init_copy_kern_s8_avx512_core) { copy_a[no_trans][no_sum].reset( - new jit_avx512_core_u8_copy_an_kern()); + new jit_avx512_core_u8_copy_an_kern_t()); copy_a[do_trans][no_sum].reset( - new jit_avx512_core_u8_copy_at_kern()); + new jit_avx512_core_u8_copy_at_kern_t()); copy_b[no_trans][no_sum].reset( - new jit_avx512_core_u8_copy_bn_kern(b_is_s8)); + new jit_avx512_core_u8_copy_bn_kern_t(b_is_s8)); copy_b[do_trans][no_sum].reset( - new jit_avx512_core_u8_copy_bt_kern(b_is_s8)); + new jit_avx512_core_u8_copy_bt_kern_t(b_is_s8)); copy_a[no_trans][do_sum].reset( - new jit_avx512_core_u8_copy_sum_an_kern()); + new jit_avx512_core_u8_copy_sum_an_kern_t()); copy_a[do_trans][do_sum].reset( - new jit_avx512_core_u8_copy_sum_at_kern()); + new jit_avx512_core_u8_copy_sum_at_kern_t()); copy_b[no_trans][do_sum].reset( - new jit_avx512_core_u8_copy_sum_bn_kern(b_is_s8)); + new jit_avx512_core_u8_copy_sum_bn_kern_t(b_is_s8)); copy_b[do_trans][do_sum].reset( - new jit_avx512_core_u8_copy_sum_bt_kern(b_is_s8)); + new jit_avx512_core_u8_copy_sum_bt_kern_t(b_is_s8)); + } #endif #if __BUILD_GEMM_AVX2 } else if (mayiuse(avx2_vnni)) { + DNNL_CSCOPE(jit_init_copy_kern_s8_avx2_vnni) { copy_a[no_trans][no_sum].reset( - new jit_avx2_vnni_u8_copy_an_kern()); + new jit_avx2_vnni_u8_copy_an_kern_t()); copy_a[do_trans][no_sum].reset( - new jit_avx2_vnni_u8_copy_at_kern()); + new 
jit_avx2_vnni_u8_copy_at_kern_t()); copy_b[no_trans][no_sum].reset( - new jit_avx2_vnni_u8_copy_bn_kern()); + new jit_avx2_vnni_u8_copy_bn_kern_t()); copy_b[do_trans][no_sum].reset( - new jit_avx2_vnni_u8_copy_bt_kern()); + new jit_avx2_vnni_u8_copy_bt_kern_t()); copy_a[no_trans][do_sum].reset( - new jit_avx2_vnni_u8_copy_sum_an_kern()); + new jit_avx2_vnni_u8_copy_sum_an_kern_t()); copy_a[do_trans][do_sum].reset( - new jit_avx2_vnni_u8_copy_sum_at_kern()); + new jit_avx2_vnni_u8_copy_sum_at_kern_t()); copy_b[no_trans][do_sum].reset( - new jit_avx2_vnni_u8_copy_sum_bn_kern()); + new jit_avx2_vnni_u8_copy_sum_bn_kern_t()); copy_b[do_trans][do_sum].reset( - new jit_avx2_vnni_u8_copy_sum_bt_kern()); + new jit_avx2_vnni_u8_copy_sum_bt_kern_t()); + } #endif #if __BUILD_GEMM_AVX2 } else if (mayiuse(avx2)) { + DNNL_CSCOPE(jit_init_copy_kern_s8_avx2) { copy_a[no_trans][no_sum].reset( - new jit_avx2_u8_copy_an_kern()); + new jit_avx2_u8_copy_an_kern_t()); copy_a[do_trans][no_sum].reset( - new jit_avx2_u8_copy_at_kern()); + new jit_avx2_u8_copy_at_kern_t()); copy_b[no_trans][no_sum].reset( - new jit_avx2_u8_copy_bn_kern()); + new jit_avx2_u8_copy_bn_kern_t()); copy_b[do_trans][no_sum].reset( - new jit_avx2_u8_copy_bt_kern()); + new jit_avx2_u8_copy_bt_kern_t()); copy_a[no_trans][do_sum].reset( - new jit_avx2_u8_copy_sum_an_kern()); + new jit_avx2_u8_copy_sum_an_kern_t()); copy_a[do_trans][do_sum].reset( - new jit_avx2_u8_copy_sum_at_kern()); + new jit_avx2_u8_copy_sum_at_kern_t()); copy_b[no_trans][do_sum].reset( - new jit_avx2_u8_copy_sum_bn_kern()); + new jit_avx2_u8_copy_sum_bn_kern_t()); copy_b[do_trans][do_sum].reset( - new jit_avx2_u8_copy_sum_bt_kern()); + new jit_avx2_u8_copy_sum_bt_kern_t()); + } #endif #if __BUILD_GEMM_AVX2 } else if (mayiuse(avx)) { + DNNL_CSCOPE(jit_init_copy_kern_s8_avx) { copy_a[no_trans][no_sum].reset( - new jit_avx_u8_copy_an_kern()); + new jit_avx_u8_copy_an_kern_t()); copy_a[do_trans][no_sum].reset( - new jit_avx_u8_copy_at_kern()); + new jit_avx_u8_copy_at_kern_t()); copy_b[no_trans][no_sum].reset( - new jit_avx_u8_copy_bn_kern()); + new jit_avx_u8_copy_bn_kern_t()); copy_b[do_trans][no_sum].reset( - new jit_avx_u8_copy_bt_kern()); + new jit_avx_u8_copy_bt_kern_t()); copy_a[no_trans][do_sum].reset( - new jit_avx_u8_copy_sum_an_kern()); + new jit_avx_u8_copy_sum_an_kern_t()); copy_a[do_trans][do_sum].reset( - new jit_avx_u8_copy_sum_at_kern()); + new jit_avx_u8_copy_sum_at_kern_t()); copy_b[no_trans][do_sum].reset( - new jit_avx_u8_copy_sum_bn_kern()); + new jit_avx_u8_copy_sum_bn_kern_t()); copy_b[do_trans][do_sum].reset( - new jit_avx_u8_copy_sum_bt_kern()); + new jit_avx_u8_copy_sum_bt_kern_t()); + } #endif #if __BUILD_GEMM_SSE41 } else if (mayiuse(sse41)) { + DNNL_CSCOPE(jit_init_copy_kern_s8_sse41) { copy_a[no_trans][no_sum].reset( - new jit_sse41_u8_copy_an_kern()); + new jit_sse41_u8_copy_an_kern_t()); copy_a[do_trans][no_sum].reset( - new jit_sse41_u8_copy_at_kern()); + new jit_sse41_u8_copy_at_kern_t()); copy_b[no_trans][no_sum].reset( - new jit_sse41_u8_copy_bn_kern()); + new jit_sse41_u8_copy_bn_kern_t()); copy_b[do_trans][no_sum].reset( - new jit_sse41_u8_copy_bt_kern()); + new jit_sse41_u8_copy_bt_kern_t()); copy_a[no_trans][do_sum].reset( - new jit_sse41_u8_copy_sum_an_kern()); + new jit_sse41_u8_copy_sum_an_kern_t()); copy_a[do_trans][do_sum].reset( - new jit_sse41_u8_copy_sum_at_kern()); + new jit_sse41_u8_copy_sum_at_kern_t()); copy_b[no_trans][do_sum].reset( - new jit_sse41_u8_copy_sum_bn_kern()); + new jit_sse41_u8_copy_sum_bn_kern_t()); 
copy_b[do_trans][do_sum].reset( - new jit_sse41_u8_copy_sum_bt_kern()); + new jit_sse41_u8_copy_sum_bt_kern_t()); + } #endif } break; @@ -539,39 +554,45 @@ void gemm_info_t::jit_init(void) { // dummy if #if __BUILD_GEMM_AMX } else if (mayiuse(amx_bf16)) { + DNNL_CSCOPE(jit_init_copy_kern_bf16_amx_bf16) { for (int isTrans : {no_trans, do_trans}) { copy_a[isTrans][no_sum].reset( - new jit_avx512_core_amx_copy_kern( + new jit_avx512_core_amx_copy_kern_t( true, !isTrans, sizeof(a_t))); copy_b[isTrans][no_sum].reset( - new jit_avx512_core_amx_copy_kern( + new jit_avx512_core_amx_copy_kern_t( false, isTrans, sizeof(b_t))); + } } #endif #if __BUILD_GEMM_AVX512 } else if (mayiuse(avx512_core) && !use_bf16_ymm) { + DNNL_CSCOPE(jit_init_copy_kern_bf16_avx512_core_not_use_bf16_ymm) { copy_a[no_trans][no_sum].reset( - new jit_avx512_core_s16_48x8_copy_an_kern()); + new jit_avx512_core_s16_48x8_copy_an_kern_t()); copy_a[do_trans][no_sum].reset( - new jit_avx512_core_s16_48x8_copy_at_kern()); + new jit_avx512_core_s16_48x8_copy_at_kern_t()); copy_b[no_trans][no_sum].reset( - new jit_avx512_core_s16_48x8_copy_bn_kern()); + new jit_avx512_core_s16_48x8_copy_bn_kern_t()); copy_b[do_trans][no_sum].reset( - new jit_avx512_core_s16_48x8_copy_bt_kern()); + new jit_avx512_core_s16_48x8_copy_bt_kern_t()); + } #endif #if __BUILD_GEMM_AVX512 } else if (mayiuse(avx512_core) && use_bf16_ymm) { + DNNL_CSCOPE(jit_init_copy_kern_bf16_avx512_core_use_bf16_ymm) { copy_a[no_trans][no_sum].reset( - new jit_avx512_core_s16_24x8_copy_an_kern()); + new jit_avx512_core_s16_24x8_copy_an_kern_t()); copy_a[do_trans][no_sum].reset( - new jit_avx512_core_s16_24x8_copy_at_kern()); + new jit_avx512_core_s16_24x8_copy_at_kern_t()); copy_b[no_trans][no_sum].reset( - new jit_avx512_core_s16_24x8_copy_bn_kern()); + new jit_avx512_core_s16_24x8_copy_bn_kern_t()); copy_b[do_trans][no_sum].reset( - new jit_avx512_core_s16_24x8_copy_bt_kern()); + new jit_avx512_core_s16_24x8_copy_bt_kern_t()); + } #endif } break; @@ -581,51 +602,59 @@ void gemm_info_t::jit_init(void) { // dummy if #if __BUILD_GEMM_AVX512 } else if (mayiuse(avx512_core)) { + DNNL_CSCOPE(jit_init_copy_kern_f32_avx512_core) { copy_a[no_trans][no_sum].reset( - new jit_avx512_core_f32_copy_an_kern()); + new jit_avx512_core_f32_copy_an_kern_t()); copy_a[do_trans][no_sum].reset( - new jit_avx512_core_f32_copy_at_kern()); + new jit_avx512_core_f32_copy_at_kern_t()); copy_b[no_trans][no_sum].reset( - new jit_avx512_core_f32_copy_bn_kern()); + new jit_avx512_core_f32_copy_bn_kern_t()); copy_b[do_trans][no_sum].reset( - new jit_avx512_core_f32_copy_bt_kern()); + new jit_avx512_core_f32_copy_bt_kern_t()); + } #endif #if __BUILD_GEMM_AVX2 } else if (mayiuse(avx2)) { + DNNL_CSCOPE(jit_init_copy_kern_f32_avx2) { copy_a[no_trans][no_sum].reset( - new jit_avx2_f32_copy_an_kern()); + new jit_avx2_f32_copy_an_kern_t()); copy_a[do_trans][no_sum].reset( - new jit_avx2_f32_copy_at_kern()); + new jit_avx2_f32_copy_at_kern_t()); copy_b[no_trans][no_sum].reset( - new jit_avx2_f32_copy_bn_kern()); + new jit_avx2_f32_copy_bn_kern_t()); copy_b[do_trans][no_sum].reset( - new jit_avx2_f32_copy_bt_kern()); + new jit_avx2_f32_copy_bt_kern_t()); + } #endif #if __BUILD_GEMM_AVX2 } else if (mayiuse(avx)) { + DNNL_CSCOPE(jit_init_copy_kern_f32_avx) { copy_a[no_trans][no_sum].reset( - new jit_avx_f32_copy_an_kern()); + new jit_avx_f32_copy_an_kern_t()); copy_a[do_trans][no_sum].reset( - new jit_avx_f32_copy_at_kern()); + new jit_avx_f32_copy_at_kern_t()); copy_b[no_trans][no_sum].reset( - new 
jit_avx_f32_copy_bn_kern()); + new jit_avx_f32_copy_bn_kern_t()); copy_b[do_trans][no_sum].reset( - new jit_avx_f32_copy_bt_kern()); + new jit_avx_f32_copy_bt_kern_t()); + } #endif #if __BUILD_GEMM_SSE41 } else if (mayiuse(sse41)) { + DNNL_CSCOPE(jit_init_copy_kern_f32_sse41) { copy_a[no_trans][no_sum].reset( - new jit_sse41_f32_copy_an_kern()); + new jit_sse41_f32_copy_an_kern_t()); copy_a[do_trans][no_sum].reset( - new jit_sse41_f32_copy_at_kern()); + new jit_sse41_f32_copy_at_kern_t()); copy_b[no_trans][no_sum].reset( - new jit_sse41_f32_copy_bn_kern()); + new jit_sse41_f32_copy_bn_kern_t()); copy_b[do_trans][no_sum].reset( - new jit_sse41_f32_copy_bt_kern()); + new jit_sse41_f32_copy_bt_kern_t()); + } #endif } break; @@ -633,87 +662,98 @@ void gemm_info_t::jit_init(void) { default: break; } - constexpr bool is_a_s8 = data_traits::data_type == data_type::s8; - constexpr bool is_b_s8 = data_traits::data_type == data_type::s8; - constexpr bool is_c_s32 = data_traits::data_type == data_type::s32; + constexpr bool is_a_s8 = data_traits_t::data_type == data_type::s8; + constexpr bool is_b_s8 = data_traits_t::data_type == data_type::s8; + constexpr bool is_c_s32 + = data_traits_t::data_type == data_type::s32; UNUSED(is_a_s8); UNUSED(is_b_s8); UNUSED(is_c_s32); - static maybe_unique_ptr kernel[2][2][2][2] + static maybe_unique_ptr kernel[2][2][2][2] = {{{{nullptr}}}}; - switch (data_traits::data_type) { + switch (data_traits_t::data_type) { case data_type::s8: if (false) { // dummy if #if __BUILD_GEMM_AMX } else if (mayiuse(avx512_core_amx)) { + DNNL_CSCOPE(jit_init_gemm_kern_s8_avx512_core_bf16_amx_int8) { for (int isBeta0 : {no_beta0, do_beta0}) { kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx512_core_amx_gemm_kern( + new jit_avx512_core_amx_gemm_kern_t( is_a_s8, is_b_s8, is_c_s32, isBeta0)); } + } #endif #if __BUILD_GEMM_AVX512 } else if (mayiuse(avx512_core)) { + DNNL_CSCOPE(jit_init_gemm_kern_s8_avx512_core) { for (int isBeta0 : {no_beta0, do_beta0}) for (int doColSum : {no_sum, do_sum}) for (int doRowSum : {no_sum, do_sum}) { kernel[isBeta0][do_alpha1][doColSum][doRowSum].reset( - new jit_avx512_core_gemm_s8u8s32_kern( + new jit_avx512_core_gemm_s8u8s32_kern_t( isBeta0, doColSum, doRowSum)); } + } #endif #if __BUILD_GEMM_AVX2 } else if (mayiuse(avx2)) { + DNNL_CSCOPE(jit_init_gemm_kern_s8_avx2) { for (int isBeta0 : {no_beta0, do_beta0}) for (int doColSum : {no_sum, do_sum}) for (int doRowSum : {no_sum, do_sum}) { kernel[isBeta0][do_alpha1][doColSum][doRowSum] - .reset(new jit_avx2_gemm_s8u8s32_kern( + .reset(new jit_avx2_gemm_s8u8s32_kern_t( isBeta0, doColSum, doRowSum, um)); } + } #endif #if __BUILD_GEMM_AVX2 } else if (mayiuse(avx)) { + DNNL_CSCOPE(jit_init_gemm_kern_s8_avx) { kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx_kernel_gemm_s8u8s32_kern()); + new jit_avx_kernel_gemm_s8u8s32_kern_t()); kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( - new jit_avx_kernel_c_gemm_s8u8s32_kern()); + new jit_avx_kernel_c_gemm_s8u8s32_kern_t()); kernel[no_beta0][do_alpha1][no_sum][do_sum].reset( - new jit_avx_kernel_r_gemm_s8u8s32_kern()); + new jit_avx_kernel_r_gemm_s8u8s32_kern_t()); kernel[no_beta0][do_alpha1][do_sum][do_sum].reset( - new jit_avx_kernel_b_gemm_s8u8s32_kern()); + new jit_avx_kernel_b_gemm_s8u8s32_kern_t()); kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx_kernel_b0_gemm_s8u8s32_kern()); + new jit_avx_kernel_b0_gemm_s8u8s32_kern_t()); kernel[do_beta0][do_alpha1][do_sum][no_sum].reset( - new 
jit_avx_kernel_b0_c_gemm_s8u8s32_kern()); + new jit_avx_kernel_b0_c_gemm_s8u8s32_kern_t()); kernel[do_beta0][do_alpha1][no_sum][do_sum].reset( - new jit_avx_kernel_b0_r_gemm_s8u8s32_kern()); + new jit_avx_kernel_b0_r_gemm_s8u8s32_kern_t()); kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( - new jit_avx_kernel_b0_b_gemm_s8u8s32_kern()); + new jit_avx_kernel_b0_b_gemm_s8u8s32_kern_t()); + } #endif #if __BUILD_GEMM_SSE41 } else if (mayiuse(sse41)) { + DNNL_CSCOPE(jit_init_gemm_kern_s8_sse41) { kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_sse41_kernel_gemm_s8u8s32_kern()); + new jit_sse41_kernel_gemm_s8u8s32_kern_t()); kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( - new jit_sse41_kernel_c_gemm_s8u8s32_kern()); + new jit_sse41_kernel_c_gemm_s8u8s32_kern_t()); kernel[no_beta0][do_alpha1][no_sum][do_sum].reset( - new jit_sse41_kernel_r_gemm_s8u8s32_kern()); + new jit_sse41_kernel_r_gemm_s8u8s32_kern_t()); kernel[no_beta0][do_alpha1][do_sum][do_sum].reset( - new jit_sse41_kernel_b_gemm_s8u8s32_kern()); + new jit_sse41_kernel_b_gemm_s8u8s32_kern_t()); kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_sse41_kernel_b0_gemm_s8u8s32_kern()); + new jit_sse41_kernel_b0_gemm_s8u8s32_kern_t()); kernel[do_beta0][do_alpha1][do_sum][no_sum].reset( - new jit_sse41_kernel_b0_c_gemm_s8u8s32_kern()); + new jit_sse41_kernel_b0_c_gemm_s8u8s32_kern_t()); kernel[do_beta0][do_alpha1][no_sum][do_sum].reset( - new jit_sse41_kernel_b0_r_gemm_s8u8s32_kern()); + new jit_sse41_kernel_b0_r_gemm_s8u8s32_kern_t()); kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( - new jit_sse41_kernel_b0_b_gemm_s8u8s32_kern()); + new jit_sse41_kernel_b0_b_gemm_s8u8s32_kern_t()); + } #endif } break; @@ -723,20 +763,24 @@ void gemm_info_t::jit_init(void) { // dummy if #if __BUILD_GEMM_AMX } else if (mayiuse(avx512_core_amx)) { + DNNL_CSCOPE(jit_init_gemm_kern_bf16_avx512_core_bf16_amx_bf16) { for (int isBeta0 : {no_beta0, do_beta0}) { kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx512_core_amx_gemm_kern( + new jit_avx512_core_amx_gemm_kern_t( is_a_s8, is_b_s8, is_c_s32, isBeta0)); } + } #endif #if __BUILD_GEMM_AVX512 } else if (mayiuse(avx512_core)) { + DNNL_CSCOPE(jit_init_gemm_kern_bf16_avx512_core) { for (int isBeta0 : {no_beta0, do_beta0}) for (int isAlpha1 : {no_alpha1, do_alpha1}) { kernel[isBeta0][isAlpha1][no_sum][no_sum].reset( - new jit_avx512_core_gemm_bf16bf16f32_kern( + new jit_avx512_core_gemm_bf16bf16f32_kern_t( isBeta0, isAlpha1, !use_bf16_ymm)); } + } #endif } break; @@ -746,24 +790,30 @@ void gemm_info_t::jit_init(void) { // dummy if #if __BUILD_GEMM_AVX2 } else if (mayiuse(avx2)) { + DNNL_CSCOPE(jit_init_gemm_kern_f32_avx2) { for (int isBeta0 : {no_beta0, do_beta0}) { kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx2_kernel_sgemm_kern(isBeta0)); + new jit_avx2_kernel_sgemm_kern_t(isBeta0)); + } } #endif #if __BUILD_GEMM_AVX2 } else if (mayiuse(avx)) { + DNNL_CSCOPE(jit_init_gemm_kern_f32_avx) { kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx_kernel_sgemm_kern()); + new jit_avx_kernel_sgemm_kern_t()); kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx_kernel_b0_sgemm_kern()); + new jit_avx_kernel_b0_sgemm_kern_t()); + } #endif #if __BUILD_GEMM_SSE41 } else if (mayiuse(sse41)) { + DNNL_CSCOPE(jit_init_gemm_kern_f32_sse41) { kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_sse41_kernel_sgemm_kern()); + new jit_sse41_kernel_sgemm_kern_t()); kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( - new 
jit_sse41_kernel_b0_sgemm_kern()); + new jit_sse41_kernel_b0_sgemm_kern_t()); + } #endif } break; @@ -771,22 +821,27 @@ void gemm_info_t::jit_init(void) { default: break; } - static maybe_unique_ptr gemv_kernel[2] = {nullptr}; - static maybe_unique_ptr gemv_s8s8s32_kernel = nullptr; - static maybe_unique_ptr gemv_s8u8s32_kernel = nullptr; - static maybe_unique_ptr gemv_u8s8s32_kernel = nullptr; - switch (data_traits::data_type) { + static maybe_unique_ptr gemv_kernel[2] = {nullptr}; + static maybe_unique_ptr gemv_s8s8s32_kernel = nullptr; + static maybe_unique_ptr gemv_s8u8s32_kernel = nullptr; + static maybe_unique_ptr gemv_u8s8s32_kernel = nullptr; + switch (data_traits_t::data_type) { case data_type::s8: if (false) { // dummy if #if __BUILD_GEMM_AVX512 } else if (mayiuse(avx512_core)) { + DNNL_CSCOPE(jit_init_gemv_kern_s8_avx512_core) { gemv_s8s8s32_kernel.reset( - new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8s8)); + new jit_avx512_core_gemv_s8x8s32_kern_t( + ver_t::s8s8)); gemv_s8u8s32_kernel.reset( - new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8u8)); + new jit_avx512_core_gemv_s8x8s32_kern_t( + ver_t::s8u8)); gemv_u8s8s32_kernel.reset( - new jit_avx512_core_gemv_s8x8s32_kern(ver_t::u8s8)); + new jit_avx512_core_gemv_s8x8s32_kern_t( + ver_t::u8s8)); + } #endif } break; @@ -796,10 +851,12 @@ void gemm_info_t::jit_init(void) { // dummy if #if __BUILD_GEMM_AVX512 } else if (mayiuse(avx512_core)) { + DNNL_CSCOPE(jit_init_gemv_kern_bf16_avx512_core) { for (int isTrans : {no_trans, do_trans}) gemv_kernel[isTrans].reset( - new jit_avx512_core_gemv_bf16bf16f32_kern( + new jit_avx512_core_gemv_bf16bf16f32_kern_t( isTrans)); + } #endif } break; @@ -809,16 +866,21 @@ void gemm_info_t::jit_init(void) { // dummy if #if __BUILD_GEMM_AVX2 } else if (mayiuse(avx)) { + DNNL_CSCOPE(jit_init_gemv_kern_f32_avx) { gemv_kernel[no_trans].reset( - new jit_sse41_gemv_n_f32_kern()); - gemv_kernel[do_trans].reset(new jit_avx_gemv_t_f32_kern()); + new jit_sse41_gemv_n_f32_kern_t()); + gemv_kernel[do_trans].reset( + new jit_avx_gemv_t_f32_kern_t()); + } #endif #if __BUILD_GEMM_SSE41 } else if (mayiuse(sse41)) { + DNNL_CSCOPE(jit_init_gemv_kern_f32_sse41) { gemv_kernel[no_trans].reset( - new jit_sse41_gemv_n_f32_kern()); + new jit_sse41_gemv_n_f32_kern_t()); gemv_kernel[do_trans].reset( - new jit_sse41_gemv_t_f32_kern()); + new jit_sse41_gemv_t_f32_kern_t()); + } #endif } break; @@ -882,7 +944,7 @@ void gemm_info_t::jit_init(void) { } // Set gemv floating point kernels - if (utils::one_of(data_traits::data_type, data_type::f32, + if (utils::one_of(data_traits_t::data_type, data_type::f32, data_type::bf16)) { for (int isTrans : {no_trans, do_trans}) { auto *p_gemv_kernel = gemv_kernel[isTrans].get(); @@ -895,7 +957,7 @@ void gemm_info_t::jit_init(void) { } // Set gemv integer gemm kernels - if (data_traits::data_type == data_type::s8) { + if (data_traits_t::data_type == data_type::s8) { if (gemv_s8s8s32_kernel != nullptr) { auto *kern = gemv_s8s8s32_kernel.get(); st = kern->create_kernel(); @@ -927,7 +989,7 @@ void gemm_info_t::jit_init(void) { int copy_trans_a = (this->transa == do_trans) ? do_trans : no_trans; int copy_trans_b = (this->transb == do_trans) ? do_trans : no_trans; - constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; + constexpr bool is_bf16 = data_traits_t::data_type == data_type::bf16; bool doAlpha1 = this->alpha != 1.0f && is_bf16 ? 
no_alpha1 : do_alpha1;
 
     {
@@ -950,7 +1012,7 @@ void gemm_info_t<a_t, b_t, c_t>::jit_init(void) {
     this->gemv_s8s8s32_kernel = nullptr;
     this->gemv_s8u8s32_kernel = nullptr;
     this->gemv_u8s8s32_kernel = nullptr;
-    if (data_traits<a_t>::data_type == data_type::s8) {
+    if (data_traits_t<a_t>::data_type == data_type::s8) {
         this->gemv_s8s8s32_kernel = gemv_s8s8s32_kern;
         this->gemv_s8u8s32_kernel = gemv_s8u8s32_kern;
         this->gemv_u8s8s32_kernel = gemv_u8s8s32_kern;
@@ -965,7 +1027,7 @@ template <typename a_t, typename b_t, typename c_t>
 bool gemm_info_t<a_t, b_t, c_t>::hasKernels(void) {
-    switch (data_traits<a_t>::data_type) {
+    switch (data_traits_t<a_t>::data_type) {
         case data_type::s8:
             if (mayiuse(sse41)) {
                 for (int isBeta0 : {no_beta0, do_beta0})
diff --git a/src/cpu/x64/gemm/gemm_pack.cpp b/src/cpu/x64/gemm/gemm_pack.cpp
index 091a4a69c60..1b98e6e26c5 100644
--- a/src/cpu/x64/gemm/gemm_pack.cpp
+++ b/src/cpu/x64/gemm/gemm_pack.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2024 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -79,8 +79,8 @@ static inline CBLAS_OFFSET cblas_offset(const char *offset) {
 template <typename a_dt, typename b_dt>
 static inline bool use_reference_igemm(void) {
     constexpr bool is_s8u8 = true
-            && data_traits<a_dt>::data_type == data_type::s8
-            && data_traits<b_dt>::data_type == data_type::u8;
+            && data_traits_t<a_dt>::data_type == data_type::s8
+            && data_traits_t<b_dt>::data_type == data_type::u8;
 
     if (is_s8u8)
         return !mayiuse(sse41);
     else
@@ -241,8 +241,8 @@ dnnl_status_t gemm_x8x8s32_pack_get_size(const char *identifier,
 #if USE_MKL_PACKED_GEMM
     constexpr bool is_s8u8 = true
-            && data_traits<a_dt>::data_type == data_type::s8
-            && data_traits<b_dt>::data_type == data_type::u8;
+            && data_traits_t<a_dt>::data_type == data_type::s8
+            && data_traits_t<b_dt>::data_type == data_type::u8;
 
     if (is_s8u8) {
         *size = cblas_gemm_s8u8s32_pack_get_size(
@@ -356,8 +356,8 @@ dnnl_status_t gemm_x8x8s32_pack(const char *identifier, const char *transa,
 #if USE_MKL_PACKED_GEMM
     constexpr bool is_s8u8 = true
-            && data_traits<a_dt>::data_type == data_type::s8
-            && data_traits<b_dt>::data_type == data_type::u8;
+            && data_traits_t<a_dt>::data_type == data_type::s8
+            && data_traits_t<b_dt>::data_type == data_type::u8;
 
     if (is_s8u8) {
         auto cblas_id = cblas_identifier(identifier);
@@ -459,8 +459,8 @@ dnnl_status_t gemm_x8x8s32_compute(const char *transa, const char *transb,
 #if USE_MKL_PACKED_GEMM
     constexpr bool is_s8u8 = true
-            && data_traits<a_dt>::data_type == data_type::s8
-            && data_traits<b_dt>::data_type == data_type::u8;
+            && data_traits_t<a_dt>::data_type == data_type::s8
+            && data_traits_t<b_dt>::data_type == data_type::u8;
 
     if (is_s8u8) {
         if (utils::any_null(transa, transb, offsetc, M, N, K, alpha, A, lda, ao,
diff --git a/src/cpu/x64/gemm/gemm_pack_storage.hpp b/src/cpu/x64/gemm/gemm_pack_storage.hpp
index 2f92e445c0a..73111f73c7d 100644
--- a/src/cpu/x64/gemm/gemm_pack_storage.hpp
+++ b/src/cpu/x64/gemm/gemm_pack_storage.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2023 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -108,14 +108,14 @@ struct gemm_pack_storage_t {
     template <typename data_type>
     data_type *row_sums(int ithr, dim_t r0, dim_t cblock) const {
-        if (!has_row_sums()) return NULL;
+        if (!has_row_sums()) return nullptr;
         auto id = thread_to_slice(ithr);
         return get_block<data_type>(sums_header->slice[id], r0, cblock);
     }
 
     template <typename data_type>
     data_type *col_sums(int ithr, dim_t rblock, dim_t c0) const {
-        if (!has_col_sums()) return NULL;
+        if (!has_col_sums()) return nullptr;
         auto id = thread_to_slice(ithr);
         return get_block<data_type>(sums_header->slice[id], rblock, c0);
     }
diff --git a/src/cpu/x64/gemm/gemm_threading.hpp b/src/cpu/x64/gemm/gemm_threading.hpp
index 3915dd54f12..b0af2760095 100644
--- a/src/cpu/x64/gemm/gemm_threading.hpp
+++ b/src/cpu/x64/gemm/gemm_threading.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2022 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -43,7 +43,7 @@ struct gemm_slice_t {
 };
 
 struct gemm_threading_t {
-    gemm_threading_t() {};
+    gemm_threading_t() = default;
 
     int nthrs_m, nthrs_n, nthrs_k;
     dim_t block_m, block_n, block_k; // Blocking sizes (-1 = default)
diff --git a/src/cpu/x64/gemm/gemm_utils.hpp b/src/cpu/x64/gemm/gemm_utils.hpp
index 76462b5b150..739ed73d812 100644
--- a/src/cpu/x64/gemm/gemm_utils.hpp
+++ b/src/cpu/x64/gemm/gemm_utils.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2023 Intel Corporation
+* Copyright 2020-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -171,7 +171,7 @@ dnnl_status_t pack_no_copy(const T *src, dim_t ld_src, dim_t nrows, dim_t ncols,
     dim_t nrows_dst, ncols_dst;
     dim_t ld_dst, td_dst;
 
-    constexpr bool is_f32 = data_traits<T>::data_type == data_type::f32;
+    constexpr bool is_f32 = data_traits_t<T>::data_type == data_type::f32;
 
     if (!dst_pack->get_nocopy(0, trans_dst, ld_dst, td_dst))
         return dnnl_invalid_arguments;
diff --git a/src/cpu/x64/gemm/gemv_driver.cpp b/src/cpu/x64/gemm/gemv_driver.cpp
index 7b6ab72945f..241b84ca802 100644
--- a/src/cpu/x64/gemm/gemv_driver.cpp
+++ b/src/cpu/x64/gemm/gemv_driver.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2023 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -231,7 +231,7 @@ template <typename a_t, typename b_t, typename c_t>
 static inline int thread_checker(
         int nthr, const dim_t m, const dim_t n, int trans) {
     constexpr bool is_f32
-            = utils::one_of(data_traits<a_t>::data_type, data_type::f32);
+            = utils::one_of(data_traits_t<a_t>::data_type, data_type::f32);
 
     if (is_f32) {
         // Threshold based on performance measurement with warm and cold cache
@@ -317,7 +317,7 @@ template <typename T>
 static inline void part_1d(const dim_t m, const int ithr, const int nthr,
         T *addr, dim_t &off, dim_t &size) {
     constexpr bool is_f32
-            = utils::one_of(data_traits<T>::data_type, data_type::f32);
+            = utils::one_of(data_traits_t<T>::data_type, data_type::f32);
 
     if (ithr >= nthr) {
         size = 0;
@@ -397,9 +397,9 @@ static inline void gemv_threading_driver(const int trans, const dim_t m,
         const b_t *x, const dim_t incx, const float beta, c_t *y,
         const dim_t incy, const gemm_info_t<a_t, b_t, c_t> *arg) {
     constexpr bool is_f32
-            = utils::one_of(data_traits<a_t>::data_type, data_type::f32);
+            = utils::one_of(data_traits_t<a_t>::data_type, data_type::f32);
     constexpr bool is_bf16
-            = utils::one_of(data_traits<a_t>::data_type, data_type::bf16);
+            = utils::one_of(data_traits_t<a_t>::data_type, data_type::bf16);
 
     // Quick return if possible.
     if (m <= 0 || n <= 0) return;
diff --git a/src/cpu/x64/gemm/s8x8s32/common_u8.hpp b/src/cpu/x64/gemm/s8x8s32/common_u8.hpp
index 386821cfb4e..575f73a52a9 100644
--- a/src/cpu/x64/gemm/s8x8s32/common_u8.hpp
+++ b/src/cpu/x64/gemm/s8x8s32/common_u8.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2024 Intel Corporation
+* Copyright 2019-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -22,464 +22,464 @@
 #include "cpu/x64/jit_generator.hpp"
 
 #define PADD_BYTESIZE_ONPAGE(x, size) \
-    (((x) * (size) + PAGE_4K - 1) / PAGE_4K) * PAGE_4K
-#define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, size)) / size
+    ((((x) * (size) + PAGE_4K - 1) / PAGE_4K) * PAGE_4K)
+#define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, (size)) / (size))
 
 namespace dnnl {
 namespace impl {
 namespace cpu {
 namespace x64 {
 
-class jit_avx512_core_u8_copy_an_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx512_core_u8_copy_an_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern_t);
+    void generate() override;
 
 public:
-    jit_avx512_core_u8_copy_an_kern();
+    jit_avx512_core_u8_copy_an_kern_t();
 };
 
-class jit_avx512_core_u8_copy_at_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx512_core_u8_copy_at_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern_t);
+    void generate() override;
 
 public:
-    jit_avx512_core_u8_copy_at_kern();
+    jit_avx512_core_u8_copy_at_kern_t();
 };
 
-class jit_avx512_core_u8_copy_bn_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx512_core_u8_copy_bn_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern_t);
+    void generate() override;
 
     bool s8_case;
 
 public:
-    jit_avx512_core_u8_copy_bn_kern(bool s8 = false);
+    jit_avx512_core_u8_copy_bn_kern_t(bool s8 = false);
 };
 
-class jit_avx512_core_u8_copy_bt_kern : public jit_generator {
-
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_u8_copy_bt_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern_t); + void generate() override; bool s8_case; public: - jit_avx512_core_u8_copy_bt_kern(bool s8 = false); + jit_avx512_core_u8_copy_bt_kern_t(bool s8 = false); }; -class jit_avx512_core_u8_copy_sum_an_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_u8_copy_sum_an_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern_t); + void generate() override; public: - jit_avx512_core_u8_copy_sum_an_kern(); + jit_avx512_core_u8_copy_sum_an_kern_t(); }; -class jit_avx512_core_u8_copy_sum_at_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_u8_copy_sum_at_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern_t); + void generate() override; public: - jit_avx512_core_u8_copy_sum_at_kern(); + jit_avx512_core_u8_copy_sum_at_kern_t(); }; -class jit_avx512_core_u8_copy_sum_bn_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_u8_copy_sum_bn_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern_t); + void generate() override; bool s8_case; public: - jit_avx512_core_u8_copy_sum_bn_kern(bool s8 = false); + jit_avx512_core_u8_copy_sum_bn_kern_t(bool s8 = false); }; -class jit_avx512_core_u8_copy_sum_bt_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx512_core_u8_copy_sum_bt_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern_t); + void generate() override; bool s8_case; public: - jit_avx512_core_u8_copy_sum_bt_kern(bool s8 = false); + jit_avx512_core_u8_copy_sum_bt_kern_t(bool s8 = false); }; -class jit_avx2_vnni_u8_copy_an_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_an_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_vnni_u8_copy_an_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_an_kern_t); + void generate() override; public: - jit_avx2_vnni_u8_copy_an_kern(); + jit_avx2_vnni_u8_copy_an_kern_t(); }; -class jit_avx2_vnni_u8_copy_at_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_at_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_vnni_u8_copy_at_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_at_kern_t); + void generate() override; public: - jit_avx2_vnni_u8_copy_at_kern(); + jit_avx2_vnni_u8_copy_at_kern_t(); }; -class jit_avx2_vnni_u8_copy_bn_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_bn_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_vnni_u8_copy_bn_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_bn_kern_t); + void generate() override; public: - jit_avx2_vnni_u8_copy_bn_kern(); + jit_avx2_vnni_u8_copy_bn_kern_t(); }; 
-class jit_avx2_vnni_u8_copy_bt_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_bt_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_vnni_u8_copy_bt_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_bt_kern_t); + void generate() override; public: - jit_avx2_vnni_u8_copy_bt_kern(); + jit_avx2_vnni_u8_copy_bt_kern_t(); }; -class jit_avx2_vnni_u8_copy_sum_an_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_sum_an_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_vnni_u8_copy_sum_an_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_sum_an_kern_t); + void generate() override; public: - jit_avx2_vnni_u8_copy_sum_an_kern(); + jit_avx2_vnni_u8_copy_sum_an_kern_t(); }; -class jit_avx2_vnni_u8_copy_sum_at_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_sum_at_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_vnni_u8_copy_sum_at_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_sum_at_kern_t); + void generate() override; public: - jit_avx2_vnni_u8_copy_sum_at_kern(); + jit_avx2_vnni_u8_copy_sum_at_kern_t(); }; -class jit_avx2_vnni_u8_copy_sum_bn_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_sum_bn_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_vnni_u8_copy_sum_bn_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_sum_bn_kern_t); + void generate() override; public: - jit_avx2_vnni_u8_copy_sum_bn_kern(); + jit_avx2_vnni_u8_copy_sum_bn_kern_t(); }; -class jit_avx2_vnni_u8_copy_sum_bt_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_sum_bt_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_vnni_u8_copy_sum_bt_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_vnni_u8_copy_sum_bt_kern_t); + void generate() override; public: - jit_avx2_vnni_u8_copy_sum_bt_kern(); + jit_avx2_vnni_u8_copy_sum_bt_kern_t(); }; -class jit_avx2_u8_copy_an_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_an_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_u8_copy_an_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_an_kern_t); + void generate() override; public: - jit_avx2_u8_copy_an_kern(); + jit_avx2_u8_copy_an_kern_t(); }; -class jit_avx2_u8_copy_at_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_at_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_u8_copy_at_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_at_kern_t); + void generate() override; public: - jit_avx2_u8_copy_at_kern(); + jit_avx2_u8_copy_at_kern_t(); }; -class jit_avx2_u8_copy_bn_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_bn_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_u8_copy_bn_kern_t : public jit_generator_t { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_bn_kern_t); + void generate() override; public: - jit_avx2_u8_copy_bn_kern(); + jit_avx2_u8_copy_bn_kern_t(); }; -class jit_avx2_u8_copy_bt_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_bt_kern); - void generate() override ATTRIBUTE_OPTIMIZE; +class jit_avx2_u8_copy_bt_kern_t : public 
jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_bt_kern_t);
+    void generate() override;
 
 public:
-    jit_avx2_u8_copy_bt_kern();
+    jit_avx2_u8_copy_bt_kern_t();
 };
 
-class jit_avx2_u8_copy_sum_an_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_sum_an_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx2_u8_copy_sum_an_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_sum_an_kern_t);
+    void generate() override;
 
 public:
-    jit_avx2_u8_copy_sum_an_kern();
+    jit_avx2_u8_copy_sum_an_kern_t();
 };
 
-class jit_avx2_u8_copy_sum_at_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_sum_at_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx2_u8_copy_sum_at_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_sum_at_kern_t);
+    void generate() override;
 
 public:
-    jit_avx2_u8_copy_sum_at_kern();
+    jit_avx2_u8_copy_sum_at_kern_t();
 };
 
-class jit_avx2_u8_copy_sum_bn_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_sum_bn_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx2_u8_copy_sum_bn_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_sum_bn_kern_t);
+    void generate() override;
 
 public:
-    jit_avx2_u8_copy_sum_bn_kern();
+    jit_avx2_u8_copy_sum_bn_kern_t();
 };
 
-class jit_avx2_u8_copy_sum_bt_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_sum_bt_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx2_u8_copy_sum_bt_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_u8_copy_sum_bt_kern_t);
+    void generate() override;
 
 public:
-    jit_avx2_u8_copy_sum_bt_kern();
+    jit_avx2_u8_copy_sum_bt_kern_t();
 };
 
-class jit_avx_u8_copy_an_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_an_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_u8_copy_an_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_an_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_u8_copy_an_kern();
+    jit_avx_u8_copy_an_kern_t();
 };
 
-class jit_avx_u8_copy_at_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_at_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_u8_copy_at_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_at_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_u8_copy_at_kern();
+    jit_avx_u8_copy_at_kern_t();
 };
 
-class jit_avx_u8_copy_bn_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_bn_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_u8_copy_bn_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_bn_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_u8_copy_bn_kern();
+    jit_avx_u8_copy_bn_kern_t();
 };
 
-class jit_avx_u8_copy_bt_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_bt_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_u8_copy_bt_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_bt_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_u8_copy_bt_kern();
+    jit_avx_u8_copy_bt_kern_t();
 };
 
-class jit_avx_u8_copy_sum_an_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_sum_an_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_u8_copy_sum_an_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_sum_an_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_u8_copy_sum_an_kern();
+    jit_avx_u8_copy_sum_an_kern_t();
 };
 
-class jit_avx_u8_copy_sum_at_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_sum_at_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_u8_copy_sum_at_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_sum_at_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_u8_copy_sum_at_kern();
+    jit_avx_u8_copy_sum_at_kern_t();
 };
 
-class jit_avx_u8_copy_sum_bn_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_sum_bn_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_u8_copy_sum_bn_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_sum_bn_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_u8_copy_sum_bn_kern();
+    jit_avx_u8_copy_sum_bn_kern_t();
 };
 
-class jit_avx_u8_copy_sum_bt_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_sum_bt_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_u8_copy_sum_bt_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_u8_copy_sum_bt_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_u8_copy_sum_bt_kern();
+    jit_avx_u8_copy_sum_bt_kern_t();
 };
 
-class jit_avx_kernel_b0_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_b0_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_kernel_b0_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_b0_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_kernel_b0_gemm_s8u8s32_kern();
+    jit_avx_kernel_b0_gemm_s8u8s32_kern_t();
 };
 
-class jit_avx_kernel_b0_b_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_b0_b_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_kernel_b0_b_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_b0_b_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_kernel_b0_b_gemm_s8u8s32_kern();
+    jit_avx_kernel_b0_b_gemm_s8u8s32_kern_t();
 };
 
-class jit_avx_kernel_b0_r_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_b0_r_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_kernel_b0_r_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_b0_r_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_kernel_b0_r_gemm_s8u8s32_kern();
+    jit_avx_kernel_b0_r_gemm_s8u8s32_kern_t();
 };
 
-class jit_avx_kernel_b0_c_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_b0_c_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_kernel_b0_c_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_b0_c_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_kernel_b0_c_gemm_s8u8s32_kern();
+    jit_avx_kernel_b0_c_gemm_s8u8s32_kern_t();
 };
 
-class jit_avx_kernel_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_kernel_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_kernel_gemm_s8u8s32_kern();
+    jit_avx_kernel_gemm_s8u8s32_kern_t();
 };
 
-class jit_avx_kernel_b_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_b_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_kernel_b_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_b_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_kernel_b_gemm_s8u8s32_kern();
+    jit_avx_kernel_b_gemm_s8u8s32_kern_t();
 };
 
-class jit_avx_kernel_r_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_r_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_kernel_r_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_r_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_kernel_r_gemm_s8u8s32_kern();
+    jit_avx_kernel_r_gemm_s8u8s32_kern_t();
 };
 
-class jit_avx_kernel_c_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_c_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_avx_kernel_c_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_kernel_c_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_avx_kernel_c_gemm_s8u8s32_kern();
+    jit_avx_kernel_c_gemm_s8u8s32_kern_t();
 };
 
-class jit_sse41_u8_copy_an_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_an_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_u8_copy_an_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_an_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_u8_copy_an_kern();
+    jit_sse41_u8_copy_an_kern_t();
 };
 
-class jit_sse41_u8_copy_at_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_at_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_u8_copy_at_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_at_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_u8_copy_at_kern();
+    jit_sse41_u8_copy_at_kern_t();
 };
 
-class jit_sse41_u8_copy_bn_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_bn_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_u8_copy_bn_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_bn_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_u8_copy_bn_kern();
+    jit_sse41_u8_copy_bn_kern_t();
 };
 
-class jit_sse41_u8_copy_bt_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_bt_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_u8_copy_bt_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_bt_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_u8_copy_bt_kern();
+    jit_sse41_u8_copy_bt_kern_t();
 };
 
-class jit_sse41_u8_copy_sum_an_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_sum_an_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_u8_copy_sum_an_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_sum_an_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_u8_copy_sum_an_kern();
+    jit_sse41_u8_copy_sum_an_kern_t();
 };
 
-class jit_sse41_u8_copy_sum_at_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_sum_at_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_u8_copy_sum_at_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_sum_at_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_u8_copy_sum_at_kern();
+    jit_sse41_u8_copy_sum_at_kern_t();
 };
 
-class jit_sse41_u8_copy_sum_bn_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_sum_bn_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_u8_copy_sum_bn_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_sum_bn_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_u8_copy_sum_bn_kern();
+    jit_sse41_u8_copy_sum_bn_kern_t();
 };
 
-class jit_sse41_u8_copy_sum_bt_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_sum_bt_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_u8_copy_sum_bt_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_u8_copy_sum_bt_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_u8_copy_sum_bt_kern();
+    jit_sse41_u8_copy_sum_bt_kern_t();
 };
 
-class jit_sse41_kernel_b0_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_b0_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_kernel_b0_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_b0_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_kernel_b0_gemm_s8u8s32_kern();
+    jit_sse41_kernel_b0_gemm_s8u8s32_kern_t();
 };
 
-class jit_sse41_kernel_b0_b_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_b0_b_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_kernel_b0_b_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_b0_b_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_kernel_b0_b_gemm_s8u8s32_kern();
+    jit_sse41_kernel_b0_b_gemm_s8u8s32_kern_t();
 };
 
-class jit_sse41_kernel_b0_r_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_b0_r_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_kernel_b0_r_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_b0_r_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_kernel_b0_r_gemm_s8u8s32_kern();
+    jit_sse41_kernel_b0_r_gemm_s8u8s32_kern_t();
 };
 
-class jit_sse41_kernel_b0_c_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_b0_c_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_kernel_b0_c_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_b0_c_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_kernel_b0_c_gemm_s8u8s32_kern();
+    jit_sse41_kernel_b0_c_gemm_s8u8s32_kern_t();
 };
 
-class jit_sse41_kernel_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_kernel_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_kernel_gemm_s8u8s32_kern();
+    jit_sse41_kernel_gemm_s8u8s32_kern_t();
 };
 
-class jit_sse41_kernel_b_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_b_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_kernel_b_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_b_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_kernel_b_gemm_s8u8s32_kern();
+    jit_sse41_kernel_b_gemm_s8u8s32_kern_t();
 };
 
-class jit_sse41_kernel_r_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_r_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_kernel_r_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_r_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_kernel_r_gemm_s8u8s32_kern();
+    jit_sse41_kernel_r_gemm_s8u8s32_kern_t();
 };
 
-class jit_sse41_kernel_c_gemm_s8u8s32_kern : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_c_gemm_s8u8s32_kern);
-    void generate() override ATTRIBUTE_OPTIMIZE;
+class jit_sse41_kernel_c_gemm_s8u8s32_kern_t : public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_kernel_c_gemm_s8u8s32_kern_t);
+    void generate() override;
 
 public:
-    jit_sse41_kernel_c_gemm_s8u8s32_kern();
+    jit_sse41_kernel_c_gemm_s8u8s32_kern_t();
 };
 
 } // namespace x64
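Every declaration above follows the same jit_generator_t scheme: a private generate() override emits the machine code for one kernel variant, and the public constructor selects that variant. A minimal sketch of how such a kernel is built, assuming the common_u8.hpp header name and the create_kernel() entry point inherited from jit_generator_t; the wrapper function below is hypothetical, and the real gemm driver dispatches the generated code through its own function-pointer tables:

    // Hypothetical wrapper: construct one of the renamed kernels and let
    // create_kernel() run generate() once and finalize the code buffer.
    #include "cpu/x64/gemm/s8x8s32/common_u8.hpp" // assumed header for these classes

    dnnl::impl::status_t build_copy_kernel_sketch() {
        using namespace dnnl::impl::cpu::x64;
        jit_avx2_u8_copy_bt_kern_t copy_kern;
        return copy_kern.create_kernel(); // status::success once code is ready
    }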
diff --git a/src/cpu/x64/gemm/s8x8s32/jit_avx2_gemm_s8u8s32_kern.cpp b/src/cpu/x64/gemm/s8x8s32/jit_avx2_gemm_s8u8s32_kern.cpp
index d15e3cc71b6..a214221860a 100644
--- a/src/cpu/x64/gemm/s8x8s32/jit_avx2_gemm_s8u8s32_kern.cpp
+++ b/src/cpu/x64/gemm/s8x8s32/jit_avx2_gemm_s8u8s32_kern.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2024 Intel Corporation
+* Copyright 2018-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ static inline Xmm make_xmm(const Xmm &v) {
 }
 
 // Load from or store to C.
-void jit_avx2_gemm_s8u8s32_kern::c_load(
+void jit_avx2_gemm_s8u8s32_kern_t::c_load(
         const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems) {
     switch (nelems) {
         case 1: vmovss(make_xmm(dst), src); break;
@@ -51,7 +51,7 @@
     }
 }
 
-void jit_avx2_gemm_s8u8s32_kern::c_store(
+void jit_avx2_gemm_s8u8s32_kern_t::c_store(
         const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems) {
     switch (nelems) {
         case 1: vmovss(dst, make_xmm(src)); break;
@@ -67,7 +67,7 @@
 // Perform length-4 dot product accumulations of unsigned and signed bytes
 // in parallel.
 // Use VEX vpdpbusd if avx2-vnni available, otherwise emulate.
-void jit_avx2_gemm_s8u8s32_kern::dot_product(
+void jit_avx2_gemm_s8u8s32_kern_t::dot_product(
         const Xmm &dst, const Xmm &src1, const Xmm &src2) {
     if (vnni_) {
         vpdpbusd(dst, src1, src2, VexEncoding);
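The dot_product hunk above is the one behavioral branch in this file: with avx2-vnni, a single vpdpbusd performs a length-4 u8 x s8 multiply-accumulate into each 32-bit lane; without it, the same result is built from wider intermediate operations. A scalar model of the per-lane semantics, purely as an illustration (the function name is hypothetical):

    #include <cstdint>

    // One 32-bit lane of vpdpbusd: dst += sum over i = 0..3 of
    // (unsigned byte i of src1) * (signed byte i of src2), with products
    // sign-extended to 32 bits and no intermediate saturation.
    static int32_t vpdpbusd_lane_model(
            int32_t dst, uint32_t src1, uint32_t src2) {
        for (int i = 0; i < 4; ++i) {
            uint8_t u = uint8_t(src1 >> (8 * i)); // unsigned byte of src1
            int8_t s = int8_t(src2 >> (8 * i)); // signed byte of src2
            dst += int32_t(u) * int32_t(s);
        }
        return dst;
    }

A common emulation on plain AVX2 is vpmaddubsw (u8 x s8 to saturated 16-bit pair sums) followed by vpmaddwd against a vector of 16-bit ones and a final vpaddd; unlike vpdpbusd, that sequence can saturate in the 16-bit intermediate step.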
@@ -79,7 +79,7 @@
 }
 
 // Inner kernel.
-void jit_avx2_gemm_s8u8s32_kern::kernel_loop(
+void jit_avx2_gemm_s8u8s32_kern_t::kernel_loop(
         int unroll_m, int unroll_n, bool cfetch) {
     int um_vecs = (unroll_m + 7) >> 3;
     Label label_kernel_loop;
@@ -137,7 +137,7 @@
 }
 
 // k remainder loop for kernel.
-void jit_avx2_gemm_s8u8s32_kern::remainder_kernel(
+void jit_avx2_gemm_s8u8s32_kern_t::remainder_kernel(
         int unroll_m, int unroll_n, int unroll_k, int bwidth) {
     Ymm b = b_regs_[0];
 
@@ -165,7 +165,7 @@
 }
 
 // Inner loop.
-void jit_avx2_gemm_s8u8s32_kern::innerloop(int unroll_m, int unroll_n) {
+void jit_avx2_gemm_s8u8s32_kern_t::innerloop(int unroll_m, int unroll_n) {
     int um_vecs = (unroll_m + 7) >> 3;
     int stage1 = unroll_n, stage2 = mayiuse(avx2_vnni) ? 32 : 16;
 
@@ -308,7 +308,7 @@
 }
 
 // Outer loop.
-void jit_avx2_gemm_s8u8s32_kern::outerloop(
+void jit_avx2_gemm_s8u8s32_kern_t::outerloop(
         int unroll_x, int unroll_y, Label *&cur_outerloop_label) {
     Label label_m_loop, label_n_loop;
     std::vector