diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index bb712da992884..487336428ad3b 100644 --- a/.github/workflows/publish-c-apidocs.yml +++ b/.github/workflows/publish-c-apidocs.yml @@ -40,7 +40,7 @@ jobs: - name: Create Pull Request uses: peter-evans/create-pull-request@v3 with: - branch: gh-pages-pr + branch: gh-pages-pr-c-docs base: gh-pages title: '[Automated]: Update C/C++ API docs' commit-message: 'Update C/C++ API docs to commit ${{ steps.vars.outputs.sha_short }}' diff --git a/.github/workflows/publish-rust-apidocs.yml b/.github/workflows/publish-rust-apidocs.yml new file mode 100644 index 0000000000000..4e48a52e4a517 --- /dev/null +++ b/.github/workflows/publish-rust-apidocs.yml @@ -0,0 +1,43 @@ +name: Update Rust API Docs +on: + push: + branches: + - main + paths: + - rust + workflow_dispatch: + + +jobs: + publish: + name: Generate Rust docs + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Install tools + run: | + sudo apt-get update + sudo apt-get install rustc + - name: Run rust docs generation + run: | + cd rust + cargo doc --no-deps + - name: Set commit ID + id: vars + run: echo "::set-output name=sha_short::$(git rev-parse --short HEAD)" + - uses: actions/checkout@v2 + with: + ref: gh-pages + clean: false + - name: Move API docs into target area + run: | + rm -rf docs/api/rust + mv rust/target/doc docs/api/rust + - name: Create Pull Request + uses: peter-evans/create-pull-request@v3 + with: + branch: gh-pages-rustdocs-pr + base: gh-pages + title: '[Automated]: Update Rust API docs' + commit-message: 'Update Rust API docs to commit ${{ steps.vars.outputs.sha_short }}' + add-paths: docs/api/rust diff --git a/.gitignore b/.gitignore index 26620d1bd5214..739ec17ca2fce 100644 --- a/.gitignore +++ b/.gitignore @@ -57,3 +57,7 @@ onnxruntime/python/version_info.py # clangd .cache/ compile_commands.json +# Rust specific +rust/**/target +rust/**/Cargo.lock +rust/onnxruntime/synset.txt diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 03366fa133b0d..c2fc62ea936e4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -54,6 +54,8 @@ The html docs are generated from markdown using Jekyll and published using GitHu To update the docs, create a Pull Request against the [gh-pages](https://github.com/microsoft/onnxruntime/tree/gh-pages) branch of the [ONNX Runtime repo](https://github.com/microsoft/onnxruntime). +To preview your changes, you can push to the gh-pages branch in your fork and this will publish a staged version of your changes to .github.io/onnxruntime/docs. + Once your PR is approved and merged, your changes will be automatically published to https://onnxruntime.ai/docs. 
Note: technical reference docs for developers of ONNX Runtime source code can be found [here](https://github.com/microsoft/onnxruntime/docs) diff --git a/README.md b/README.md index c0505428caeae..68850f4be8ec1 100644 --- a/README.md +++ b/README.md @@ -23,14 +23,15 @@ ## Build Pipeline Status -|System|CPU|GPU|EPs| -|---|---|---|---| -|Windows|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20CPU%20CI%20Pipeline?label=Windows+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=9)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20CI%20Pipeline?label=Windows+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=10)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20TensorRT%20CI%20Pipeline?label=Windows+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=47)| -|Linux|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20CI%20Pipeline?label=Linux+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=11)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20Minimal%20Build%20E2E%20CI%20Pipeline?label=Linux+CPU+Minimal+Build)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=64)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20x64%20NoContribops%20CI%20Pipeline?label=Linux+CPU+x64+No+Contrib+Ops)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=110)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/centos7_cpu?label=Linux+CentOS7)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=78)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-ci-pipeline?label=Linux+CPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=86)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20CI%20Pipeline?label=Linux+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=12)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20TensorRT%20CI%20Pipeline?label=Linux+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=45)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-distributed?label=Distributed+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=140)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-gpu-ci-pipeline?label=Linux+GPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=84)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20OpenVINO%20CI%20Pipeline?label=Linux+OpenVINO)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=55)| -|Mac|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/MacOS%20CI%20Pipeline?label=MacOS+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=13)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/MacOS%20NoContribops%20CI%20Pipeline?label=MacOS+NoContribops)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=65)||| -|Android|||[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Android%20CI%20Pipeline?label=Android)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=53)| -|iOS|||[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/iOS%20CI%20Pipeline?label=iOS)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=134)| -|WebAssembly|||[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20WebAssembly%20CI%20Pipeline?label=WASM)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=161)| +|System|Inference|Training| +|---|---|---| +|Windows|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20CPU%20CI%20Pipeline?label=Windows+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=9)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20CI%20Pipeline?label=Windows+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=10)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20TensorRT%20CI%20Pipeline?label=Windows+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=47)|| +|Linux|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20CI%20Pipeline?label=Linux+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=11)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20Minimal%20Build%20E2E%20CI%20Pipeline?label=Linux+CPU+Minimal+Build)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=64)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20CI%20Pipeline?label=Linux+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=12)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20TensorRT%20CI%20Pipeline?label=Linux+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=45)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20OpenVINO%20CI%20Pipeline?label=Linux+OpenVINO)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=55)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-ci-pipeline?label=Linux+CPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=86)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-gpu-ci-pipeline?label=Linux+GPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=84)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining/orttraining-ortmodule-distributed?label=Training+Distributed)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=148)| +|Mac|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/MacOS%20CI%20Pipeline?label=MacOS+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=13)|| +|Android|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Android%20CI%20Pipeline?label=Android)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=53)|| +|iOS|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/iOS%20CI%20Pipeline?label=iOS)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=134)|| +|Web|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/ONNX%20Runtime%20Web%20CI%20Pipeline?label=Web)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=161)|| +|Other|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime-binary-size-checks-ci-pipeline?repoName=microsoft%2Fonnxruntime&label=Binary+Size+Check)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=187&repoName=microsoft%2Fonnxruntime)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime-python-checks-ci-pipeline?label=Python+Checks)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=164)|| ## Data/Telemetry diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index d1aeed4f51a16..e925f75090a46 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -5239,34 +5239,6 @@ PERFORMANCE OF THIS SOFTWARE. _____ -microsoft/vcpkg, https://github.com/microsoft/vcpkg - -Copyright (c) Microsoft Corporation - -All rights reserved. - -MIT License - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies -of the Software, and to permit persons to whom the Software is furnished to do -so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -_____ - openssl/openssl, https://github.com/openssl/openssl Apache License diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 567fe2255df46..378647f273ab9 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -282,7 +282,7 @@ "component": { "type": "git", "git": { - "commitHash": "28cf67e5b64c704cad993c71f29a24e781bee544", + "commitHash": "f412df7a2b64421e1f1d61fde6055a6ea288e8f5", "repositoryUrl": "https://github.com/microsoft/mimalloc.git" }, "comments": "mimalloc" @@ -408,16 +408,6 @@ "comments": "cutlass" } }, - { - "component": { - "type": "git", - "git": { - "commitHash": "6f7ffeb18f99796233b958aaaf14ec7bd4fb64b2", - "repositoryUrl": "https://github.com/microsoft/vcpkg.git" - }, - "comments": "vcpkg" - } - }, { "component": { "type": "git", diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 5c088aa8cddc4..e24046fb2b8d5 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -245,7 +245,7 @@ if (onnxruntime_USE_ROCM) endif() if (NOT CMAKE_HIP_ARCHITECTURES) - set(CMAKE_HIP_ARCHITECTURES "gfx906;gfx908;gfx90a") + set(CMAKE_HIP_ARCHITECTURES "gfx906;gfx908;gfx90a;gfx1030") endif() file(GLOB rocm_cmake_components ${onnxruntime_ROCM_HOME}/lib/cmake/*) @@ -603,6 +603,7 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1) endif() + endif() if (onnxruntime_USE_VITISAI) list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1) @@ -1338,97 +1339,105 @@ if (onnxruntime_ENABLE_TRAINING) add_compile_definitions(ENABLE_STRIDED_TENSORS) add_compile_definitions(ENABLE_TRAINING) - if (UNIX) - if (EXISTS "${onnxruntime_MPI_HOME}") - set(MPI_HOME "${onnxruntime_MPI_HOME}") - elseif (EXISTS "/bert_ort/openmpi") - set(MPI_HOME "/bert_ort/openmpi") - endif() + add_subdirectory(tensorboard 
EXCLUDE_FROM_ALL) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES tensorboard) +endif() - find_package(MPI) +if (UNIX AND onnxruntime_USE_MPI) + if (EXISTS "${onnxruntime_MPI_HOME}") + set(MPI_HOME "${onnxruntime_MPI_HOME}") + elseif (EXISTS "/bert_ort/openmpi") + set(MPI_HOME "/bert_ort/openmpi") + endif() - if (MPI_CXX_FOUND) - message( STATUS "MPI Version: ${MPI_CXX_VERSION}") - message( STATUS "MPI (include: ${MPI_CXX_INCLUDE_DIRS}, library: ${MPI_CXX_LIBRARIES})" ) - mark_as_advanced(MPI_CXX_INCLUDE_DIRS MPI_CXX_LIBRARIES) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${MPI_CXX_LIBRARIES} ${MPI_CXX_LINK_FLAGS}) - else () - set(onnxruntime_USE_NCCL OFF) - set(onnxruntime_USE_MPI OFF) - message( WARNING "MPI is not found. Please define onnxruntime_MPI_HOME to specify the path of MPI. Otherwise, NCCL will be disabled." ) - endif() + find_package(MPI) - # Find NCCL and MPI - if (onnxruntime_USE_NCCL AND MPI_CXX_FOUND) - if (onnxruntime_USE_CUDA) - set(NCCL_LIBNAME "nccl") - elseif (onnxruntime_USE_ROCM) - set(NCCL_LIBNAME "rccl") + if (MPI_CXX_FOUND) + message( STATUS "MPI Version: ${MPI_CXX_VERSION}") + message( STATUS "MPI (include: ${MPI_CXX_INCLUDE_DIRS}, library: ${MPI_CXX_LIBRARIES})" ) + mark_as_advanced(MPI_CXX_INCLUDE_DIRS MPI_CXX_LIBRARIES) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${MPI_CXX_LIBRARIES} ${MPI_CXX_LINK_FLAGS}) + else () + message( + FATAL_ERROR + "MPI is not found. Please define onnxruntime_MPI_HOME to specify the path of MPI. Otherwise, NCCL will be disabled." + ) + endif() + + # Find NCCL and MPI + if (onnxruntime_USE_NCCL AND MPI_CXX_FOUND) + if (onnxruntime_USE_CUDA) + set(NCCL_LIBNAME "nccl") + elseif (onnxruntime_USE_ROCM) + set(NCCL_LIBNAME "rccl") + endif() + find_path(NCCL_INCLUDE_DIR + NAMES ${NCCL_LIBNAME}.h + HINTS + ${onnxruntime_NCCL_HOME}/include + $ENV{CUDA_ROOT}/include) + + find_library(NCCL_LIBRARY + NAMES ${NCCL_LIBNAME} + HINTS + ${onnxruntime_NCCL_HOME}/lib/x86_64-linux-gnu + ${onnxruntime_NCCL_HOME}/lib + $ENV{CUDA_ROOT}/lib64) + + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARY) + + if (NCCL_FOUND) + set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIR}/${NCCL_LIBNAME}.h") + message( STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}" ) + file (STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED + REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$" LIMIT_COUNT 1) + if (NCCL_MAJOR_VERSION_DEFINED) + string (REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" "" + NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED}) + message( STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}" ) + endif() + file (STRINGS ${NCCL_HEADER_FILE} NCCL_MINOR_VERSION_DEFINED + REGEX "^[ \t]*#define[ \t]+NCCL_MINOR[ \t]+[0-9]+.*$" LIMIT_COUNT 1) + if (NCCL_MINOR_VERSION_DEFINED) + string (REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MINOR[ \t]+" "" + NCCL_MINOR_VERSION ${NCCL_MINOR_VERSION_DEFINED}) + message(STATUS "NCCL_MINOR_VERSION: ${NCCL_MINOR_VERSION}") endif() - find_path(NCCL_INCLUDE_DIR - NAMES ${NCCL_LIBNAME}.h - HINTS - ${onnxruntime_NCCL_HOME}/include - $ENV{CUDA_ROOT}/include) - - find_library(NCCL_LIBRARY - NAMES ${NCCL_LIBNAME} - HINTS - ${onnxruntime_NCCL_HOME}/lib/x86_64-linux-gnu - ${onnxruntime_NCCL_HOME}/lib - $ENV{CUDA_ROOT}/lib64) - - include(FindPackageHandleStandardArgs) - find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARY) - - if (NCCL_FOUND) - set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIR}/${NCCL_LIBNAME}.h") - message( STATUS 
"Determining NCCL version from the header file: ${NCCL_HEADER_FILE}" ) - file (STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED - REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$" LIMIT_COUNT 1) - if (NCCL_MAJOR_VERSION_DEFINED) - string (REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" "" - NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED}) - message( STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}" ) - endif() - file (STRINGS ${NCCL_HEADER_FILE} NCCL_MINOR_VERSION_DEFINED - REGEX "^[ \t]*#define[ \t]+NCCL_MINOR[ \t]+[0-9]+.*$" LIMIT_COUNT 1) - if (NCCL_MINOR_VERSION_DEFINED) - string (REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MINOR[ \t]+" "" - NCCL_MINOR_VERSION ${NCCL_MINOR_VERSION_DEFINED}) - message(STATUS "NCCL_MINOR_VERSION: ${NCCL_MINOR_VERSION}") + if (NCCL_MAJOR_VERSION_DEFINED AND NCCL_MINOR_VERSION_DEFINED) + if ("${NCCL_MAJOR_VERSION}.${NCCL_MINOR_VERSION}" VERSION_GREATER_EQUAL "2.7") + add_definitions(-DUSE_NCCL_P2P=1) + message( STATUS "NCCL P2P is enabled for supporting ncclSend and ncclRecv." ) endif() + endif() - if (NCCL_MAJOR_VERSION_DEFINED AND NCCL_MINOR_VERSION_DEFINED) - if ("${NCCL_MAJOR_VERSION}.${NCCL_MINOR_VERSION}" VERSION_GREATER_EQUAL "2.7") - add_definitions(-DUSE_NCCL_P2P=1) - message( STATUS "NCCL P2P is enabled for supporting ncclSend and ncclRecv." ) - endif() - endif() + set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR}) + set(NCCL_LIBRARIES ${NCCL_LIBRARY}) + message( STATUS "NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})" ) + mark_as_advanced(NCCL_INCLUDE_DIRS NCCL_LIBRARIES) - set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR}) - set(NCCL_LIBRARIES ${NCCL_LIBRARY}) - message( STATUS "NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})" ) - mark_as_advanced(NCCL_INCLUDE_DIRS NCCL_LIBRARIES) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${NCCL_LIBRARIES}) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${NCCL_LIBRARIES}) - add_definitions(-DORT_USE_NCCL=1) - message( STATUS "NCCL is enabled in Linux GPU Build." ) - else () - set(onnxruntime_USE_NCCL OFF) - message( WARNING "NCCL is not found. Please use --nccl_home to specify the path of NCCL. Otherwise, NCCL is disabled." ) - endif() + add_definitions(-DORT_USE_NCCL=1) + message( STATUS "NCCL is enabled in Linux GPU Build." ) + else () + set(onnxruntime_USE_NCCL OFF) + message( + FATAL_ERROR + "NCCL is not found. Please use --nccl_home to specify the path of NCCL. Otherwise, NCCL is disabled." + ) endif() endif() +else() + set(onnxruntime_USE_NCCL OFF) + set(onnxruntime_USE_MPI OFF) +message( WARNING "MPI and NCCL disabled on Win build." 
) +endif() - if (onnxruntime_USE_MPI AND MPI_CXX_FOUND) - add_definitions(-DUSE_MPI=1) - endif() - - add_subdirectory(tensorboard EXCLUDE_FROM_ALL) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES tensorboard) +if (onnxruntime_USE_MPI) + add_definitions(-DUSE_MPI=1) endif() # Default version parts for Microsoft.AI.MachineLearning.dll, onnxruntime.dll, onnxruntime_providers_openvino.dll and onnxruntime_providers_shared.dll in non-ADO pipeline local builds diff --git a/cmake/deps.txt b/cmake/deps.txt index 3a1a691985ea1..d16245ba833cb 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -21,7 +21,7 @@ googlexnnpack;https://github.com/google/XNNPACK/archive/003c580e696a774afdc98499 json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 microsoft_wil;https://github.com/microsoft/wil/archive/5f4caba4e7a9017816e47becdd918fcc872039ba.zip;fd119887d0d17c37adf1fc227b054befa28158ad -mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.0.9.zip;9d4205c93805b5525de57c6c7ed7f60e770ffdac +mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.0.3.zip;e4f37b93b2da78a5816c2495603a4188d316214b mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.79.0.zip;c8f04e378535ededbe5af52c8f969d2dedbe73d5 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.13.0.zip;8dda5079cdb5a134b08b0c73f4592a6404fc2dc6 #use the commit where it's several commits after 8.5-GA branch (https://github.com/onnx/onnx-tensorrt/commit/369d6676423c2a6dbf4a5665c4b5010240d99d3c) @@ -36,7 +36,6 @@ safeint;https://github.com/dcleblanc/SafeInt/archive/ff15c6ada150a5018c5ef217240 tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v2.11.0.zip;be70c559f07251ba7f33c789dba98872b444c10f # below are deps introduced by triton client, might remove after 1.14 release -vcpkg;https://github.com/microsoft/vcpkg/archive/refs/tags/2022.11.14.zip;3f983141351af5db2d6c3ca965959845f27d5d51 openssl;https://github.com/openssl/openssl/archive/refs/tags/openssl-3.0.7.zip;dda8fc81308555410505eb4a9eab3e1da0436a1d rapidjson;https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.zip;0fe7b4f7b83df4b3d517f4a202f3a383af7a0818 boost;https://github.com/boostorg/boost/archive/refs/tags/boost-1.81.0.zip;f6ab0da855f825b4eb1abd949967d01a4c5e4e1b diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 0c419457787b6..3e713b69671e7 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -462,7 +462,7 @@ if (onnxruntime_USE_CUDA) endif() if(onnxruntime_USE_SNPE) - include(find_snpe.cmake) + include(external/find_snpe.cmake) list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${SNPE_NN_LIBS}) endif() diff --git a/cmake/external/tvm.cmake b/cmake/external/tvm.cmake index 1e224a2dad4af..93049c8b85853 100644 --- a/cmake/external/tvm.cmake +++ b/cmake/external/tvm.cmake @@ -21,4 +21,4 @@ if (onnxruntime_USE_TVM) set(tvm_INCLUDE_DIRS ${tvm_SOURCE_DIR}/include) -endif() \ No newline at end of file +endif() diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index 6a33100d23bde..5c947a52b7838 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -75,9 
+75,9 @@ if (onnxruntime_ENABLE_TRAINING_OPS) onnxruntime_add_include_to_target(onnxruntime_framework Python::Module) target_include_directories(onnxruntime_framework PRIVATE ${dlpack_SOURCE_DIR}/include) endif() - if (onnxruntime_USE_NCCL OR onnxruntime_USE_MPI) - target_include_directories(onnxruntime_framework PUBLIC ${MPI_CXX_INCLUDE_DIRS}) - endif() +endif() +if (onnxruntime_USE_MPI) + target_include_directories(onnxruntime_framework PUBLIC ${MPI_CXX_INCLUDE_DIRS}) endif() if (onnxruntime_ENABLE_ATEN) diff --git a/cmake/onnxruntime_kernel_explorer.cmake b/cmake/onnxruntime_kernel_explorer.cmake index f30b518bf9a1d..d4ae88a1f65df 100644 --- a/cmake/onnxruntime_kernel_explorer.cmake +++ b/cmake/onnxruntime_kernel_explorer.cmake @@ -51,6 +51,7 @@ if (onnxruntime_USE_CUDA) "${KERNEL_EXPLORER_ROOT}/kernels/cuda/*.cuh" ) target_sources(kernel_explorer PRIVATE ${kernel_explorer_cuda_kernel_srcs}) + target_include_directories(kernel_explorer PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) elseif (onnxruntime_USE_ROCM) file(GLOB kernel_explorer_rocm_kernel_srcs CONFIGURE_DEPENDS "${KERNEL_EXPLORER_ROOT}/kernels/rocm/*.cc" diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 1f9b7129943e6..80a65c6787eb9 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -19,6 +19,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/platform.cpp ${MLAS_SRC_DIR}/threading.cpp ${MLAS_SRC_DIR}/sgemm.cpp + ${MLAS_SRC_DIR}/halfgemm.cpp ${MLAS_SRC_DIR}/qgemm.cpp ${MLAS_SRC_DIR}/qdwconv.cpp ${MLAS_SRC_DIR}/convolve.cpp @@ -59,6 +60,7 @@ function(setup_mlas_source_for_windows) if(onnxruntime_target_platform STREQUAL "ARM64") target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp @@ -73,6 +75,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm + ${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm @@ -305,6 +308,7 @@ else() ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S + ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S @@ -314,10 +318,13 @@ else() ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp ) + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs}) set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index fe9e83db6b27c..0b9faf8849e06 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -383,6 
+383,11 @@ if (onnxruntime_USE_CUDA) "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/aten_ops/aten_op.cc" ) endif() + if (NOT onnxruntime_USE_NCCL) + list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/nccl_kernels.cc" + ) + endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs}) list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs}) @@ -507,14 +512,15 @@ if (onnxruntime_USE_CUDA) if (onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_providers_cuda PRIVATE ${ORTTRAINING_ROOT} ${MPI_CXX_INCLUDE_DIRS}) - if(onnxruntime_USE_MPI) - target_link_libraries(onnxruntime_providers_cuda PRIVATE ${MPI_LIBRARIES} ${MPI_CXX_LINK_FLAGS}) - endif() + endif() - if (onnxruntime_USE_NCCL) - target_include_directories(onnxruntime_providers_cuda PRIVATE ${NCCL_INCLUDE_DIRS}) - target_link_libraries(onnxruntime_providers_cuda PRIVATE ${NCCL_LIBRARIES}) - endif() + if(onnxruntime_USE_MPI) + target_link_libraries(onnxruntime_providers_cuda PRIVATE ${MPI_LIBRARIES} ${MPI_CXX_LINK_FLAGS}) + endif() + + if (onnxruntime_USE_NCCL) + target_include_directories(onnxruntime_providers_cuda PRIVATE ${NCCL_INCLUDE_DIRS}) + target_link_libraries(onnxruntime_providers_cuda PRIVATE ${NCCL_LIBRARIES}) endif() if (WIN32) @@ -683,10 +689,11 @@ if (onnxruntime_USE_TENSORRT) target_compile_options(nvonnxparser_static PRIVATE /FIio.h /wd4100) target_compile_options(nvonnxparser PRIVATE /FIio.h /wd4100) endif() - include_directories(${TENSORRT_INCLUDE_DIR}) set(onnxparser_link_libs nvonnxparser_static) endif() + include_directories(${TENSORRT_INCLUDE_DIR}) + set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS @@ -704,11 +711,10 @@ if (onnxruntime_USE_TENSORRT) add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS}) - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TENSORRT_INCLUDE_DIR} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) else() target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS}) - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) endif() + target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) if(onnxruntime_CUDNN_HOME) target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include) endif() @@ -1352,12 +1358,6 @@ if (onnxruntime_USE_ROCM) # disable contrib ops conditionally if(NOT onnxruntime_DISABLE_CONTRIB_OPS) - if (NOT onnxruntime_ENABLE_ATEN) - list(REMOVE_ITEM 
onnxruntime_rocm_contrib_ops_cc_srcs - "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/aten_ops/aten_op.cc" - ) - endif() - hipify("onnxruntime/contrib_ops" contrib_ops_excluded_files onnxruntime_rocm_generated_contrib_ops_cc_srcs onnxruntime_rocm_generated_contrib_ops_cu_srcs) # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio @@ -1460,6 +1460,7 @@ if (onnxruntime_USE_ROCM) device_gemm_add_fastgelu_instance device_gemm_fastgelu_instance device_batched_gemm_instance + device_softmax_instance ) target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_COMPOSABLE_KERNEL) endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 809a076443609..c24b6b9be548a 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -467,12 +467,21 @@ file(GLOB onnxruntime_python_quantization_cal_table_flatbuffers_src CONFIGURE_DE file(GLOB onnxruntime_python_transformers_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/*.py" ) +file(GLOB onnxruntime_python_transformers_models_bart_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bart/*.py" +) +file(GLOB onnxruntime_python_transformers_models_bert_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bert/*.py" +) file(GLOB onnxruntime_python_transformers_models_gpt2_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/gpt2/*.py" ) file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py" ) +file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py" +) file(GLOB onnxruntime_python_transformers_models_t5_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/t5/*.py" ) @@ -526,8 +535,11 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/tools/ort_format_model/ort_flatbuffers_py COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/bart + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/bert COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/gpt2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/longformer + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/stable_diffusion COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/t5 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/operators @@ -606,12 +618,21 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_src} $/onnxruntime/transformers/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_bart_src} + $/onnxruntime/transformers/models/bart/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_bert_src} + $/onnxruntime/transformers/models/bert/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_gpt2_src} $/onnxruntime/transformers/models/gpt2/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_longformer_src} $/onnxruntime/transformers/models/longformer/ 
+ COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_stable_diffusion_src} + $/onnxruntime/transformers/models/stable_diffusion/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_t5_src} $/onnxruntime/transformers/models/t5/ diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index d3b8f5ebfcc26..2c13b5cbb56eb 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -11,14 +11,14 @@ set(contrib_ops_excluded_files "bert/attention_softmax.h" "bert/multihead_attention.cc" "bert/multihead_attention.h" - "bert/embed_layer_norm.cc" - "bert/embed_layer_norm.h" - "bert/embed_layer_norm_impl.cu" - "bert/embed_layer_norm_impl.h" "bert/fast_gelu_impl.cu" "bert/fast_gelu_impl.h" "bert/fast_gelu.cc" "bert/fast_gelu.h" + "bert/relative_attn_bias.cc" + "bert/relative_attn_bias.h" + "bert/relative_attn_bias_impl.cu" + "bert/relative_attn_bias_impl.h" "bert/skip_layer_norm.cc" "bert/skip_layer_norm.h" "bert/skip_layer_norm_impl.cu" @@ -27,6 +27,15 @@ set(contrib_ops_excluded_files "bert/tensorrt_fused_multihead_attention/*" "bert/transformer_common.h" "bert/transformer_common.cc" + "diffusion/group_norm.h" + "diffusion/group_norm.cc" + "diffusion/group_norm_impl.cu" + "diffusion/group_norm_impl.h" + "diffusion/bias_split_gelu_impl.h" + "diffusion/bias_split_gelu_impl.cu" + "diffusion/bias_split_gelu.h" + "diffusion/bias_split_gelu.cc" + "diffusion/nhwc_conv.cc" "math/complex_mul.cc" "math/complex_mul.h" "math/complex_mul_impl.cu" @@ -76,17 +85,8 @@ set(contrib_ops_excluded_files "tensor/image_scaler_impl.h" "transformers/beam_search.cc" "transformers/beam_search.h" - "transformers/generation_device_helper.cc" - "transformers/generation_device_helper.h" - "transformers/generation_cuda_impl.cu" - "transformers/generation_cuda_impl.h" "transformers/greedy_search.cc" "transformers/greedy_search.h" - "transformers/sampling.cc" - "transformers/sampling.h" - "transformers/sampling_cuda_helper.h" - "transformers/dump_cuda_tensor.cc" - "transformers/dump_cuda_tensor.h" "conv_transpose_with_dynamic_pads.cc" "conv_transpose_with_dynamic_pads.h" "cuda_contrib_kernels.cc" @@ -95,6 +95,13 @@ set(contrib_ops_excluded_files "fused_conv.cc" ) +if (NOT onnxruntime_ENABLE_ATEN) + list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc") +endif() +if (NOT onnxruntime_USE_NCCL) + list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc") +endif() + set(provider_excluded_files "atomic/common.cuh" "controlflow/if.cc" @@ -115,7 +122,9 @@ set(provider_excluded_files "math/softmax_impl.cu" "math/softmax_warpwise_impl.cuh" "math/softmax_common.cc" + "math/softmax_common.h" "math/softmax.cc" + "math/softmax.h" "nn/conv.cc" "nn/conv.h" "nn/conv_transpose.cc" diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.shared.cs index 4e18b22f2ad13..7cab9e89b67c7 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.shared.cs @@ -44,7 +44,11 @@ public enum OrtLanguageProjection } /// - /// This class initializes the process-global ONNX Runtime environment instance (OrtEnv) + /// This class initializes the process-global ONNX Runtime environment instance (OrtEnv). + /// The singleton class OrtEnv contains the process-global ONNX Runtime environment. 
+ /// It sets up logging, creates system wide thread-pools (if Thread Pool options are provided) + /// and other necessary things for OnnxRuntime to function. Create or access OrtEnv by calling + /// the Instance() method. Call this method before doing anything else in your application. /// public sealed class OrtEnv : SafeHandle { diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 1e6d46963cd21..c1d12a1d5cba6 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -6,9 +6,11 @@ Do not modify directly.* * com.microsoft.Attention * com.microsoft.AttnLSTM * com.microsoft.BeamSearch + * com.microsoft.BiasAdd * com.microsoft.BiasDropout * com.microsoft.BiasGelu * com.microsoft.BiasSoftmax + * com.microsoft.BiasSplitGelu * com.microsoft.BifurcationDetector * com.microsoft.BitmaskBiasDropout * com.microsoft.BitmaskDropout @@ -29,11 +31,13 @@ Do not modify directly.* * com.microsoft.FusedConv * com.microsoft.FusedGemm * com.microsoft.FusedMatMul + * com.microsoft.GatedRelativePositionBias * com.microsoft.GatherND * com.microsoft.Gelu * com.microsoft.GemmFastGelu * com.microsoft.GreedySearch * com.microsoft.GridSample + * com.microsoft.GroupNorm * com.microsoft.Inverse * com.microsoft.Irfft * com.microsoft.LongformerAttention @@ -150,7 +154,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size)
past (optional) : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size). When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
-
extra_add (optional) : T
+
relative_position_bias (optional) : T
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
past_sequence_length (optional) : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
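As a reading aid for this rename: the optional `relative_position_bias` input (previously `extra_add`) is simply added to the QxK' scores before the softmax. A minimal NumPy sketch of that placement, assuming the default 1/sqrt(head_size) scaling (illustrative only, not the kernel implementation):

```python
import numpy as np

def attention_probs(q, k, relative_position_bias=None):
    """Sketch of where the optional relative_position_bias input
    (formerly named extra_add) enters the computation: it is added
    to QxK' before the softmax.

    q: (batch, num_heads, sequence_length, head_size)
    k: (batch, num_heads, total_sequence_length, head_size)
    relative_position_bias: (batch, num_heads, sequence_length, total_sequence_length) or None
    """
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    if relative_position_bias is not None:
        scores = scores + relative_position_bias  # the "additional add to QxK'"
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)
```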
@@ -465,6 +469,40 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.BiasAdd** + + Add input with bias, then add residual inputs. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Inputs + +
+
X : T
+
Input tensor. Dimensions are (N, S, C), where N is the batch size, S is the image size (H*W), and C is the number of channels
+
bias : T
+
Bias tensor. Dimensions are (C)
+
skip : T
+
Residual tensor. Dimensions are (N, S, C)
+
+ +#### Outputs + +
+
Y : T
+
The output tensor with dimensions (N, S, C)
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float)
+
Constrain input and output types to float tensors.
+
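A minimal NumPy sketch of the computation described above (the helper name is illustrative; broadcasting the per-channel bias over N and S is the only subtlety):

```python
import numpy as np

def bias_add_reference(x, bias, skip):
    """Rough reference for the BiasAdd fusion: add the per-channel bias,
    then add the residual ("skip") input.

    x:    (N, S, C) input
    bias: (C,)      per-channel bias, broadcast over N and S
    skip: (N, S, C) residual input
    """
    return x + bias + skip

# Example shapes: batch 2, a 4x4 "image" flattened to S=16, 8 channels.
x = np.random.randn(2, 16, 8).astype(np.float32)
b = np.random.randn(8).astype(np.float32)
s = np.random.randn(2, 16, 8).astype(np.float32)
assert bias_add_reference(x, b, s).shape == (2, 16, 8)
```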
+ + ### **com.microsoft.BiasDropout** output, dropout_mask = Dropout(data + bias, ratio) + residual, Intended to specialize the dropout pattern commonly found in transformer models. @@ -590,6 +628,39 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.BiasSplitGelu** + + A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left + tensor multiplies the Gelu activation result of right tensor. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Inputs + +
+
X : T
+
Input tensor. Dimensions are (N, S, D), where N is the batch size, S is the image size, and D is the hidden dimension
+
bias : T
+
Bias tensor. Dimensions are (D), where D is the same hidden dimension as input tensor
+
+ +#### Outputs + +
+
Y : T
+
The output tensor with dimensions (N, S, D/2)
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float)
+
Constrain input X and output Y types to float tensors.
+
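A rough NumPy sketch of the fusion described above. The tanh-based GELU below is the common approximation and is an assumption here; the actual kernel may use the exact erf form:

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU (assumption; exact erf form may be used instead)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def bias_split_gelu_reference(x, bias):
    """Rough reference for BiasSplitGelu.

    x:    (N, S, D) hidden states
    bias: (D,)      bias added before the split
    Returns (N, S, D // 2): left half multiplied by GELU of the right half.
    """
    h = x + bias
    left, right = np.split(h, 2, axis=-1)
    return left * gelu(right)

x = np.random.randn(1, 4, 8).astype(np.float32)
b = np.zeros(8, dtype=np.float32)
assert bias_split_gelu_reference(x, b).shape == (1, 4, 4)
```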
+ + ### **com.microsoft.BifurcationDetector** Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens, @@ -1573,6 +1644,58 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GatedRelativePositionBias** + + query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2) + gate_u, gate_r = torch.sigmoid( + self.gate_ur_linear(query_layer).view(batch_size, num_head, seq_len, 2, D/2).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0 + rel_pos_bias = gate_u_1 * rel_pos + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
num_heads : int (required)
+
Number of attention heads
+
+ +#### Inputs + +
+
query_layer : T
+
tensor with shape (batch_size, seq_len, num_heads x head_size)
+
query_bias : T
+
1-d tensor with shape (num_heads x head_size)
+
rel_pos : T
+
tensor with shape (1, num_heads, seq_len, seq_len)
+
weight : T
+
gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2
+
bias : T
+
bias for the gated_ur_linear, shape (D)
+
eco_a : T
+
tensor of shape (1, num_heads, 1, 1)
+
+ +#### Outputs + +
+
output : T
+
output tensor with shape (batch_size, num_heads, seq_len, seq_len)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float tensors.
+
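The pseudocode above maps to NumPy roughly as follows (a transcription for readability; function names are illustrative and shapes follow the input descriptions):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_relative_position_bias_reference(query_layer, query_bias, rel_pos,
                                           weight, bias, eco_a, num_heads):
    """Sketch of GatedRelativePositionBias following the pseudocode above.

    query_layer: (B, S, num_heads * head_size)
    query_bias:  (num_heads * head_size,)
    rel_pos:     (1, num_heads, S, S)
    weight:      (head_size, D), bias: (D,)  -- the gate_ur_linear projection
    eco_a:       (1, num_heads, 1, 1)
    Returns:     (B, num_heads, S, S)
    """
    batch, seq_len, _ = query_layer.shape
    head_size = query_layer.shape[-1] // num_heads
    d = weight.shape[-1]

    q = (query_layer + query_bias).reshape(batch, seq_len, num_heads, head_size)
    q = q.transpose(0, 2, 1, 3)                       # (B, num_heads, S, head_size)

    gates = q @ weight + bias                         # gate_ur_linear: (B, num_heads, S, D)
    gates = sigmoid(gates.reshape(batch, num_heads, seq_len, 2, d // 2).sum(-1))
    gate_u, gate_r = gates[..., 0:1], gates[..., 1:2]  # each (B, num_heads, S, 1)

    gate_u_1 = gate_u * (gate_r * eco_a - 1.0) + 2.0
    return gate_u_1 * rel_pos                         # broadcasts over the batch dim
```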
+ + ### **com.microsoft.GatherND** Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather @@ -1811,6 +1934,61 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GroupNorm** + + Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization (https://arxiv.org/abs/1803.08494). + + This operator transforms input according to + y = gamma * (x - mean) / sqrt(variance + epsilon) + beta + + The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. num_channels must be divisible by num_groups. The mean and standard-deviation are calculated separately over the each group. + The weight and bias are per-channel affine transform parameter vectors of size num_channels. + + The activation attribute can be used to enable activation after group normalization. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : int (required)
+
Activation after group normalization: 0 for None, 1 for Swish
+
epsilon : float
+
The epsilon value to use to avoid division by zero
+
groups : int (required)
+
The number of groups of channels. It should be a divisor of the number of channels C
+
+ +#### Inputs + +
+
X : T
+
Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data
+
gamma : M
+
1D gamma tensor for normalization with shape (C), where C is number of channels
+
beta : M
+
1D beta tensor for normalization with shape (C), where C is number of channels
+
+ +#### Outputs + +
+
Y : T
+
The output tensor of the same shape as X
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float)
+
Constrain input X and output Y types to float tensors.
+
M : tensor(float)
+
Constrain gamma and beta to float tensors.
+
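A compact NumPy reference for the channels-last GroupNorm described above (illustrative; the epsilon default and helper name are assumptions):

```python
import numpy as np

def group_norm_nhwc_reference(x, gamma, beta, groups, epsilon=1e-5, activation=0):
    """Rough reference for the GroupNorm contrib op.

    x:      (N, H, W, C) input in channels-last layout
    gamma:  (C,) per-channel scale, beta: (C,) per-channel shift
    groups: number of channel groups; C must be divisible by groups
    activation: 0 for none, 1 for Swish (y * sigmoid(y)) after normalization
    """
    n, h, w, c = x.shape
    assert c % groups == 0
    g = x.reshape(n, h, w, groups, c // groups)
    # statistics are computed per sample, per group, over H, W and the group's channels
    mean = g.mean(axis=(1, 2, 4), keepdims=True)
    var = g.var(axis=(1, 2, 4), keepdims=True)
    y = ((g - mean) / np.sqrt(var + epsilon)).reshape(n, h, w, c)
    y = y * gamma + beta
    if activation == 1:
        y = y * (1.0 / (1.0 + np.exp(-y)))  # Swish / SiLU
    return y
```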
+ + ### **com.microsoft.Inverse** #### Version @@ -2132,19 +2310,21 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
-#### Inputs (4 - 5) +#### Inputs (1 - 6)
query : T
-
Query with shape (batch_size, sequence_length, hidden_size)
-
key : T
-
Key with shape (batch_size, kv_sequence_length, hidden_size)
-
value : T
+
Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)
+
key (optional) : T
+
Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)
+
value (optional) : T
Value with shape (batch_size, kv_sequence_length, v_hidden_size)
-
bias : T
+
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)
+
relative_position_bias (optional) : T
+
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
#### Outputs @@ -3131,7 +3311,7 @@ This version of the operator has been available since version 1 of the 'com.micr left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past and present state are optional. Present state could appear in output even when past state is not in input. - Current version does not support past/present, extra_add and qkv_hidden_sizes. + Current version does not support past/present, relative_position_bias and qkv_hidden_sizes. TODO: Support them if needed in the future. #### Version @@ -3196,7 +3376,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).
past (optional) : Q
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).
-
extra_add (optional) : S
+
relative_position_bias (optional) : S
additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 2dc4fbfb790b2..08178f206568e 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -417,7 +417,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| |AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)| |BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| @@ -768,9 +768,9 @@ Do not modify directly.* |||1+|**T** = tensor(double), tensor(float), tensor(float16)| |Tile|*in* input:**T**
*in* repeats:**T1**
*out* output:**T**

or

*in* input:**T**
*in* tiles:**T**
*in* axis:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| -|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||10|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[1, 9]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[1, 9]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| |Transpose|*in* data:**T**
*out* transposed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -785,11 +785,13 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float), tensor(float16)| +|BiasAdd|*in* X:**T**
*in* bias:**T**
*in* skip:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |BiasSoftmax|*in* data:**T**
*in* bias:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|BiasSplitGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |BitmaskBiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)
**T3** = tensor(uint32)| |BitmaskDropout|*in* data:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)
**T3** = tensor(uint32)| |ComplexMul|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -802,16 +804,19 @@ Do not modify directly.* |FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| +|GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| +|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| -|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* extra_add:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| +|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* relative_position_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLongformerAttention|*in* input:**Q**
*in* scale_input:**S**
*in* weight:**Q**
*in* scale_weight:**S**
*in* bias:**S**
*in* scale_bias:**S**
*in* scale_qkv_gemm:**S**
*in* mask:**F**
*in* global_weight:**Q**
*in* scale_global_weight:**S**
*in* global_bias:**S**
*in* scale_global_gemm:**S**
*in* global:**G**
*in* scale_output:**S**
*out* output:**Q**|1+|**F** = tensor(float16)
**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| @@ -1087,7 +1092,7 @@ Do not modify directly.* |Scatter|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1156,7 +1161,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 09146ccda5e1e..4cc7b144332f0 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -30,6 +30,7 @@ class Node; #include "core/framework/func_api.h" #include "core/framework/provider_options.h" #include "core/framework/stream_handles.h" +#include "core/framework/tuning_context.h" namespace onnxruntime { @@ -77,7 +78,7 @@ class IExecutionProvider { /** * Get an allocator with specified device id and MemType. Return nullptr if it doesn't exist */ - virtual AllocatorPtr GetAllocator(int device_id, OrtMemType mem_type) const; + virtual AllocatorPtr GetAllocator(OrtMemType mem_type) const; /** * Returns a data transfer object that implements methods to copy to and @@ -300,6 +301,13 @@ class IExecutionProvider { */ virtual bool ConcurrentRunSupported() const { return true; } + /** + * Return the tuning context which holds all TunableOp state. + */ + virtual ITuningContext* GetTuningContext() const { + return nullptr; + } + private: const std::string type_; diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index e7e4ab9464d9d..392057cf03b19 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -121,7 +121,7 @@ class OpKernel { return Status::OK(); } - const OrtMemoryInfo& Allocator(int id, OrtMemType mem_type) const; + const OrtMemoryInfo& Allocator(OrtMemType mem_type) const; const OpKernelInfo& Info() const { return *op_kernel_info_; } diff --git a/include/onnxruntime/core/framework/op_kernel_info.h b/include/onnxruntime/core/framework/op_kernel_info.h index d30f8fead55c5..4695c0a1c9284 100644 --- a/include/onnxruntime/core/framework/op_kernel_info.h +++ b/include/onnxruntime/core/framework/op_kernel_info.h @@ -31,9 +31,9 @@ class OpKernelInfo : public OpNodeProtoHelper { OpKernelInfo(const OpKernelInfo& other); - const OrtMemoryInfo& GetMemoryInfo(int device_id, OrtMemType mem_type) const; + const OrtMemoryInfo& GetMemoryInfo(OrtMemType mem_type) const; - AllocatorPtr GetAllocator(int device_id, OrtMemType mem_type) const; + AllocatorPtr GetAllocator(OrtMemType mem_type) const; const KernelDef& GetKernelDef() const; diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h index e5a84e7aa79f3..5444c825d7991 100644 --- a/include/onnxruntime/core/framework/run_options.h +++ b/include/onnxruntime/core/framework/run_options.h @@ -27,10 +27,6 @@ struct OrtRunOptions { // So it is possible that only some of the nodes are executed. bool only_execute_path_to_fetches = false; - // Set to 'true' to synchronize execution providers with CPU at the end of session run. - // Taking CUDA EP as an example, it will trigger cudaStreamSynchronize on the compute stream. - bool synchronize_execution_providers = true; - #ifdef ENABLE_TRAINING // Used by onnxruntime::training::TrainingSession. This class is now deprecated. // Delete training_mode when TrainingSession is deleted. 
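For reviewers tracking the IExecutionProvider changes above (GetAllocator loses its device_id parameter and a GetTuningContext hook with a nullptr default is added), here is a minimal sketch of what a derived provider looks like after this patch. The MyExecutionProvider class, its allocator member, and the constructor argument are hypothetical illustrations for this note, not code from the diff:

#include "core/framework/execution_provider.h"

namespace onnxruntime {

// Hypothetical provider used only to illustrate the updated overrides.
class MyExecutionProvider : public IExecutionProvider {
 public:
  MyExecutionProvider() : IExecutionProvider("MyExecutionProvider") {}

  // Allocators are now looked up by memory type alone; the device id argument is gone.
  AllocatorPtr GetAllocator(OrtMemType mem_type) const override {
    return (mem_type == OrtMemTypeDefault) ? default_allocator_ : nullptr;
  }

  // New hook for TunableOp state; the base-class default already returns nullptr,
  // so providers without tuning support do not need to override it at all.
  ITuningContext* GetTuningContext() const override { return nullptr; }

 private:
  AllocatorPtr default_allocator_;  // assumed to be set up elsewhere in this sketch
};

}  // namespace onnxruntime

Call sites in this diff change accordingly, e.g. provider->GetAllocator(OrtMemTypeDefault) instead of provider->GetAllocator(0, OrtMemTypeDefault).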
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index ed97c512c0c7f..c0bac21aa74d7 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -87,7 +87,7 @@ struct Global { template #ifdef ORT_API_MANUAL_INIT const OrtApi* Global::api_{}; -inline void InitApi() { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } +inline void InitApi() noexcept { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } // Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is // required by C++ APIs. @@ -103,7 +103,7 @@ inline void InitApi() { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VER // // ... // } // -inline void InitApi(const OrtApi* api) { Global::api_ = api; } +inline void InitApi(const OrtApi* api) noexcept { Global::api_ = api; } #else #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) @@ -118,7 +118,7 @@ const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); #endif /// This returns a reference to the OrtApi interface in use -inline const OrtApi& GetApi() { return *Global::api_; } +inline const OrtApi& GetApi() noexcept { return *Global::api_; } /// /// This is a C++ wrapper for OrtApi::GetAvailableProviders() and @@ -580,7 +580,7 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl - SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); + SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports SNPE and XNNPACK. SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, const std::unordered_map& provider_options = {}); diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h index 49b46ca077b75..1f5fcd50e185c 100644 --- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h @@ -25,3 +25,8 @@ // Example usage: "cpu:0;gpu:0" (or) "gpu:0" // By default, the value for this key is empty (i.e.) no memory arenas are shrunk static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage"; + +// Set to '1' to not synchronize execution providers with CPU at the end of session run. +// By default it is set to '0'. +// Taking CUDA EP as an example, this omits triggering cudaStreamSynchronize on the compute stream.
+static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 72fbdd07bffc6..3ae7d9a814c5b 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -3018,9 +3018,9 @@ } }, "node_modules/http-cache-semantics": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/http-cache-semantics/-/http-cache-semantics-4.1.0.tgz", - "integrity": "sha512-carPklcUh7ROWRK7Cv27RPtdhYhUsela/ue5/jKzjegVvXDqM2ILE9Q2BGn9JZJh1g87cp56su/FgQSzcWS8cQ==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/http-cache-semantics/-/http-cache-semantics-4.1.1.tgz", + "integrity": "sha512-er295DKPVsV82j5kw1Gjt+ADA/XYHsajl82cGNQG2eyoPkvgUhX+nDIyelzhIWbbsXP39EHcI6l5tYs2FYqYXQ==", "dev": true }, "node_modules/http-errors": { @@ -3654,9 +3654,9 @@ } }, "node_modules/jszip": { - "version": "3.7.1", - "resolved": "https://registry.npmjs.org/jszip/-/jszip-3.7.1.tgz", - "integrity": "sha512-ghL0tz1XG9ZEmRMcEN2vt7xabrDdqHHeykgARpmZ0BiIctWxM47Vt63ZO2dnp4QYt/xJVLLy5Zv1l/xRdh2byg==", + "version": "3.8.0", + "resolved": "https://registry.npmjs.org/jszip/-/jszip-3.8.0.tgz", + "integrity": "sha512-cnpQrXvFSLdsR9KR5/x7zdf6c3m8IhZfZzSblFEHSqBaVwD2nvJ4CuCKLyvKvwBgZm08CgfSoiTBQLm5WW9hGw==", "dev": true, "dependencies": { "lie": "~3.3.0", @@ -9501,9 +9501,9 @@ } }, "http-cache-semantics": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/http-cache-semantics/-/http-cache-semantics-4.1.0.tgz", - "integrity": "sha512-carPklcUh7ROWRK7Cv27RPtdhYhUsela/ue5/jKzjegVvXDqM2ILE9Q2BGn9JZJh1g87cp56su/FgQSzcWS8cQ==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/http-cache-semantics/-/http-cache-semantics-4.1.1.tgz", + "integrity": "sha512-er295DKPVsV82j5kw1Gjt+ADA/XYHsajl82cGNQG2eyoPkvgUhX+nDIyelzhIWbbsXP39EHcI6l5tYs2FYqYXQ==", "dev": true }, "http-errors": { @@ -9947,9 +9947,9 @@ } }, "jszip": { - "version": "3.7.1", - "resolved": "https://registry.npmjs.org/jszip/-/jszip-3.7.1.tgz", - "integrity": "sha512-ghL0tz1XG9ZEmRMcEN2vt7xabrDdqHHeykgARpmZ0BiIctWxM47Vt63ZO2dnp4QYt/xJVLLy5Zv1l/xRdh2byg==", + "version": "3.8.0", + "resolved": "https://registry.npmjs.org/jszip/-/jszip-3.8.0.tgz", + "integrity": "sha512-cnpQrXvFSLdsR9KR5/x7zdf6c3m8IhZfZzSblFEHSqBaVwD2nvJ4CuCKLyvKvwBgZm08CgfSoiTBQLm5WW9hGw==", "dev": true, "requires": { "lie": "~3.3.0", diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 47db3fe558ce8..6aa0e726afe1b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -198,7 +198,7 @@ Status Attention::Compute(OpKernelContext* context) const { const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* relative_position_bias = context->Input(5); const TensorShape& weights_shape = (weights ? 
weights->Shape() : weight_shape_); @@ -208,7 +208,7 @@ Status Attention::Compute(OpKernelContext* context) const { bias->Shape(), mask_index, past, - extra_add_qk, + relative_position_bias, ¶meters)); const int batch_size = parameters.batch_size; @@ -331,7 +331,7 @@ Status Attention::Compute(OpKernelContext* context) const { return ApplyAttention(Q, K, V, mask_index, past, output, batch_size, sequence_length, parameters.head_size, parameters.v_head_size, parameters.v_hidden_size, - extra_add_qk, context); + relative_position_bias, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index affe7cab1d858..e75f68ea53c7c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -12,7 +12,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const Tensor* past_seq_len) const { // Abbreviation and Meanings: @@ -37,7 +37,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // bias (Q/K/V) : (D + D + D_v) // mask_index : see below // past (K/V) : (2, B, N, P, H) or NULL - // extra_add_qk : (B, N, S, T) or NULL + // relative_position_bias : (B, N, S, T) or NULL // For mask_index, the following shapes are supported: // NULL, (B, 1), (1, 1) @@ -49,9 +49,9 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger // than hidden dimension of Q, K and V. - if (past != nullptr && extra_add_qk != nullptr) { - // past is used on GPT-2 model with past state, we don't have a case for extra add qk yet - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and extra_add_qk"); + if (past != nullptr && relative_position_bias != nullptr) { + // past is used on GPT-2 model with past state, we don't have a case for relative position bias yet + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and relative_position_bias"); } const auto& dims = input_shape.GetDims(); @@ -191,34 +191,34 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, } } - if (extra_add_qk != nullptr) { - const auto& extra_add_qk_dims = extra_add_qk->Shape().GetDims(); + if (relative_position_bias != nullptr) { + const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - if (extra_add_qk_dims.size() != 4) { + if (relative_position_bias_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' is expected to have 4 dimensions, got ", - extra_add_qk_dims.size()); + "Input 'relative_position_bias' is expected to have 4 dimensions, got ", + relative_position_bias_dims.size()); } - if (extra_add_qk_dims[0] != batch_size) { + if (relative_position_bias_dims[0] != batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ", - extra_add_qk_dims[0]); + "Input 'relative_position_bias' dimension 0 should be same as batch_size, got ", + relative_position_bias_dims[0]); } - if (extra_add_qk_dims[1] != num_heads_) { + if (relative_position_bias_dims[1] != num_heads_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - 
"Input 'extra_add_qk' dimension 1 should be same as number of heads, got ", - extra_add_qk_dims[1]); + "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", + relative_position_bias_dims[1]); } - if (extra_add_qk_dims[2] != sequence_length) { + if (relative_position_bias_dims[2] != sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ", - extra_add_qk_dims[2]); + "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", + relative_position_bias_dims[2]); } - if (extra_add_qk_dims[3] != total_sequence_length) { + if (relative_position_bias_dims[3] != total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 3 should be same as total_sequence_length, got ", - extra_add_qk_dims[3]); + "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", + relative_position_bias_dims[3]); } } @@ -320,7 +320,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) const { @@ -328,7 +328,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, extra_add_qk, parameters, past_seq_len); + return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, relative_position_bias, parameters, past_seq_len); } Tensor* AttentionBase::GetPresent(OpKernelContext* context, diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index 2c49f196d52d8..2e077da2853d0 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -18,7 +18,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr. const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, // for CUDA const Tensor* past_seq_len = nullptr) const; @@ -61,7 +61,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr. const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const Tensor* past_seq_len = nullptr) const; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 0185fa9ea09a0..70d71ffb6ee40 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -19,18 +19,18 @@ class AttentionCPUBase : public AttentionBase { : AttentionBase(info, require_same_hidden_size) {} template - Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH - const T* K, // K data with shape BxNxSxH - const T* V, // V value with size BxNxSxH_v - const Tensor* mask_index, // mask index. 
nullptr if no mask or its size is B - const Tensor* past, // past state - Tensor* output, // output tensor - int batch_size, // batch size (B) - int sequence_length, // sequence length (S) - int qk_head_size, // head size of Q or K (H) - int v_head_size, // head size of V (H_v) - int v_hidden_size, // hidden size of V (D_v) - const Tensor* extra_add_qk, // extra add in QK. Its size is BxNxSxT + Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxNxSxH + const T* V, // V value with size BxNxSxH_v + const Tensor* mask_index, // mask index. nullptr if no mask or its size is B + const Tensor* past, // past state + Tensor* output, // output tensor + int batch_size, // batch size (B) + int sequence_length, // sequence length (S) + int qk_head_size, // head size of Q or K (H) + int v_head_size, // head size of V (H_v) + int v_hidden_size, // hidden size of V (D_v) + const Tensor* relative_position_bias, // bias addition in QK. Its size is BxNxSxT OpKernelContext* context) const { const int kv_sequence_length = sequence_length; @@ -67,16 +67,16 @@ class AttentionCPUBase : public AttentionBase { const T* past_data = past != nullptr ? past->Data() : nullptr; T* present_data = present != nullptr ? present->MutableData() : nullptr; - const T* extra_add_qk_data = nullptr; - if (extra_add_qk != nullptr) { - extra_add_qk_data = extra_add_qk->Data(); + const T* relative_position_bias_data = nullptr; + if (relative_position_bias != nullptr) { + relative_position_bias_data = relative_position_bias->Data(); } ComputeAttentionProbs(static_cast(attention_probs), Q, K, mask_index_data, mask_index_dims, static_cast(mask_data), has_unidirectional, batch_size, sequence_length, past_sequence_length, qk_head_size == 0 ? v_head_size : qk_head_size, - past_data, present_data, tp, extra_add_qk_data); + past_data, present_data, tp, relative_position_bias_data); // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) auto out_tmp_data = @@ -112,7 +112,7 @@ class AttentionCPUBase : public AttentionBase { const T* past, // past state T* present, // present state ThreadPool* tp, // thread pool - const T* extra_add_qk_data // extra add matrix with shape BxNxSxT + const T* relative_position_bias_data // bias addition matrix with shape BxNxSxT ) const { const int total_sequence_length = past_sequence_length + sequence_length; // T = P + L const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // P x H @@ -175,9 +175,9 @@ class AttentionCPUBase : public AttentionBase { } } - if (extra_add_qk_data != nullptr) { + if (relative_position_bias_data != nullptr) { for (int j = 0; j < sequence_length * total_sequence_length; j++) { - output[j] += extra_add_qk_data[output_offset + j]; + output[j] += relative_position_bias_data[output_offset + j]; } } } diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index ce109a83720b9..34a615a880594 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -17,41 +17,108 @@ Status CheckInputs(const T* query, const T* value, const T* bias, const T* key_padding_mask, + const T* relative_position_bias, void* parameters, int num_heads, float mask_filter_value, int max_threads_per_block) { - // query (Q) : (B, S, D) - // key (K) : (B, L, D) - // value (V) : (B, L, D_v) - // bias (Q/K/V) : (D + D + D_v) - // 
key_padding_mask (K/V) : (B, L) or (L) + // key_padding_mask (K/V) : (B) or (B, L) or None + // relative_position_bias : (B, 1, S, L) + // When no packing for q/k/v: + // query (Q) : (B, S, D) + // key (K) : (B, L, D) + // value (V) : (B, L, D_v) + // bias (Q/K/V) : (D + D + D_v) + // When packed kv is used: + // query (Q) : (B, S, D) + // key (K) : (B, L, N, 2, H) + // value (V) : None + // bias (Q/K/V) : None + // When packed qkv is used: + // query (Q) : (B, L, N, 3, H) + // key (K) : None + // value (V) : None + // bias (Q/K/V) : None const auto& query_dims = query->Shape().GetDims(); - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", + if (query_dims.size() != 3 && query_dims.size() != 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions, got ", query_dims.size()); } - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); + int batch_size = static_cast(query_dims[0]); + int sequence_length = static_cast(query_dims[1]); + int hidden_size = query_dims.size() == 3 ? static_cast(query_dims[2]) : (num_heads * static_cast(query_dims[4])); + int head_size = static_cast(hidden_size) / num_heads; + int kv_sequence_length = sequence_length; + + if (key != nullptr) { + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions when key is given, got ", + query_dims.size()); + } + + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 3 && key_dims.size() != 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ", + key_dims.size()); + } + if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } + + if (key_dims.size() == 3) { + if (key_dims[2] != query_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); + } + } else // if (key_dims.size() == 5) + { + if (static_cast(key_dims[2]) != num_heads || static_cast(key_dims[3]) != 2 || static_cast(key_dims[4]) != head_size) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format."); + } + } + + kv_sequence_length = static_cast(key_dims[1]); + } else { // packed QKV + if (query_dims.size() != 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 5 dimensions when key is empty, got ", + query_dims.size()); + } + if (static_cast(query_dims[2]) != num_heads || static_cast(query_dims[3]) != 3) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'query' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv"); + } } - const auto& bias_dims = bias->Shape().GetDims(); - if (bias_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ", - bias_dims.size()); + if (bias != nullptr) { + const auto& bias_dims = bias->Shape().GetDims(); + 
if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ", + bias_dims.size()); + } + + // Currently, bias is not allowed for packed KV. This constraint can be removed later. + // Here we assume that fusion tool will not include bias for packed KV. + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed qkv or kv. "); + } } AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; if (key_padding_mask != nullptr) { mask_type = AttentionMaskType::MASK_UNKNOWN; const auto& mask_dims = key_padding_mask->Shape().GetDims(); - if (mask_dims.size() == 1 && mask_dims[0] == key_dims[0]) { + if (mask_dims.size() == 1 && mask_dims[0] == static_cast(batch_size)) { mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - } else if (mask_dims.size() == 2 && mask_dims[0] == key_dims[0] && mask_dims[1] == key_dims[1]) { + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && mask_dims[1] == static_cast(kv_sequence_length)) { mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; } @@ -61,47 +128,69 @@ Status CheckInputs(const T* query, } } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } + int v_hidden_size = hidden_size; + if (value != nullptr) { + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } - int64_t batch_size = query_dims[0]; - int64_t sequence_length = query_dims[1]; - int64_t kv_sequence_length = key_dims[1]; - int64_t q_hidden_size = query_dims[2]; - int64_t v_hidden_size = 0; + if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch_size)"); + } - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); + if (static_cast(kv_sequence_length) != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall have same same dim 1 (kv_sequence_length)"); + } + v_hidden_size = static_cast(value_dims[2]); } - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } + if (relative_position_bias != nullptr) { + const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - if (key_dims[1] != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have same same dim 1 (sequence_length)"); + if (relative_position_bias_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' is expected to have 4 dimensions, got ", + relative_position_bias_dims.size()); + } + if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", + relative_position_bias_dims[0]); + } + if (relative_position_bias_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 
'relative_position_bias' dimension 1 should be same as number of heads, got ", + relative_position_bias_dims[1]); + } + if (relative_position_bias_dims[2] != sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", + relative_position_bias_dims[2]); + } + if (relative_position_bias_dims[3] != kv_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", + relative_position_bias_dims[3]); + } } - v_hidden_size = value_dims[2]; if (parameters != nullptr) { AttentionParameters* output_parameters = reinterpret_cast(parameters); - output_parameters->batch_size = static_cast(batch_size); - output_parameters->sequence_length = static_cast(sequence_length); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; output_parameters->past_sequence_length = 0; - output_parameters->kv_sequence_length = static_cast(kv_sequence_length); - output_parameters->total_sequence_length = static_cast(kv_sequence_length); + output_parameters->kv_sequence_length = kv_sequence_length; + output_parameters->total_sequence_length = kv_sequence_length; output_parameters->max_sequence_length = 0; output_parameters->input_hidden_size = 0; - output_parameters->hidden_size = static_cast(q_hidden_size); - output_parameters->v_hidden_size = static_cast(v_hidden_size); - output_parameters->head_size = static_cast(q_hidden_size) / num_heads; - output_parameters->v_head_size = static_cast(v_hidden_size) / num_heads; + output_parameters->hidden_size = hidden_size; + output_parameters->v_hidden_size = v_hidden_size; + output_parameters->head_size = hidden_size / num_heads; + output_parameters->v_head_size = v_hidden_size / num_heads; output_parameters->num_heads = num_heads; output_parameters->is_unidirectional = false; output_parameters->past_present_share_buffer = false; diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc index 7c4ee548ccaca..c16aaca5e71ea 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc @@ -144,7 +144,7 @@ Status ReorderOutput::Compute(OpKernelContext* context) const { if (channels_last_) { MlasReorderOutputNhwc(Y_shape.data(), x_data, y_data); } else { - MlasReorderOutputNchw(Y_shape.data(), x_data, y_data); + MlasReorderOutputNchw(Y_shape.data(), x_data, y_data, context->GetOperatorThreadPool()); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 64c17b7767e4f..e7df84c1b0066 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -160,7 +160,7 @@ Status QAttention::Compute(OpKernelContext* context) const { bias->Shape(), mask_index, past_tensor, - nullptr, // extra_add_qk + nullptr, // relative_position_bias nullptr // parameters )); diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_global_average_pool.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_global_average_pool.cc index 7eab6986930e3..e9924bf616eb5 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_global_average_pool.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_global_average_pool.cc @@ -55,6 +55,38 @@ Status ComputeQLinearGlobalAvgPool( return Status::OK(); } 
+// GCC's unexplained behavior: +// GCC does not generate the corresponding symbols for the function instantiations below when "--disable-exceptions" +// and "--minimal-build" are combined in a Linux build. +// But these two symbols are required by qlinear_pool.cc. +// Other compilers do not hit this and work fine, and we have not seen it on other platforms, such as Android. +// So we do explicit instantiation here to keep every compiler/platform happy. +template Status ComputeQLinearGlobalAvgPool( + const int8_t* x, + float x_scale, + int8_t x_zero_point, + int8_t* y, + float y_scale, + int8_t y_zero_point, + int64_t N, + int64_t C, + int64_t image_size, + bool channels_last, + concurrency::ThreadPool* tp); + +template Status ComputeQLinearGlobalAvgPool( + const uint8_t* x, + float x_scale, + uint8_t x_zero_point, + uint8_t* y, + float y_scale, + uint8_t y_zero_point, + int64_t N, + int64_t C, + int64_t image_size, + bool channels_last, + concurrency::ThreadPool* tp); + Status QLinearGlobalAveragePool::Compute(OpKernelContext* context) const { const auto tensor_x_scale = context->Input(1); const auto tensor_x_zero_point = context->Input(2); diff --git a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h index 950ddf6d27309..c6e267d26e6df 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h @@ -88,7 +88,7 @@ class GenerateBase { device_copy_func_(device_copy_func) { cpu_allocator_ = decoder_session_state.GetExecutionProviders() .Get(onnxruntime::kCpuExecutionProvider) - ->GetAllocator(0, OrtMemTypeDefault); + ->GetAllocator(OrtMemTypeDefault); } virtual ~GenerateBase() = default; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index cf1d99688546a..630c533c47323 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -9,10 +9,6 @@ #include "core/framework/allocator.h" #include "core/framework/ort_value.h" -#ifndef NDEBUG -//#define DEBUG_GENERATION 1 // uncomment it for debugging beam search -#endif - namespace onnxruntime { namespace concurrency { @@ -57,14 +53,14 @@ struct IBeamSearchCpuState { template struct IGreedySearchState { - gsl::span sequences_space; // shape (2, batch_size, max_length) - gsl::span sequence_lengths; // shape (batch_size) - gsl::span next_positions; // shape (batch_size, num_beams). Next position value for position_ids. - gsl::span eos_meet; // shape (batch_size) - gsl::span next_token_scores; // shape (batch_size, vocab_size) - gsl::span next_tokens; // shape (batch_size) - gsl::span temp_topk_scores_buffer; // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1 (GPU only) - gsl::span temp_topk_tokens_buffer; // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1(GPU only) + gsl::span sequences_space; // shape (2, batch_size, max_length) + gsl::span sequence_lengths; // shape (batch_size) + gsl::span next_positions; // shape (batch_size, num_beams). Next position value for position_ids.
+ gsl::span eos_meet; // shape (batch_size) + gsl::span next_token_scores; // shape (batch_size, vocab_size) + gsl::span next_tokens; // shape (batch_size) + gsl::span temp_topk_scores_buffer; // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1 (GPU only) + gsl::span temp_topk_tokens_buffer; // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1(GPU only) gsl::span topk_scores_buffer; // shape (batch_size), output buffer for topk stage 2 (GPU only) gsl::span topk_tokens_buffer; // shape (batch_size), output buffer for topk stage 2 (GPU only) }; @@ -167,6 +163,26 @@ struct IGenerationParameters { bool custom_sampling = false; }; +// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc) +#ifdef DEBUG_GENERATION +#define DUMP_TENSOR_LEVEL 2 +#else +#define DUMP_TENSOR_LEVEL 0 // change it to 1 or 2 if want to enable dumping for code not in generation. +#endif + +#if DUMP_TENSOR_LEVEL > 0 +#define DUMP_TENSOR_INIT() transformers::CudaTensorConsoleDumper dumper +#define DUMP_TENSOR(...) dumper.Print(__VA_ARGS__) +#else +#define DUMP_TENSOR_INIT() +#define DUMP_TENSOR(...) +#endif +#if DUMP_TENSOR_LEVEL > 1 +#define DUMP_TENSOR_D(...) dumper.Print(__VA_ARGS__) +#else +#define DUMP_TENSOR_D(...) +#endif + class IConsoleDumper { public: IConsoleDumper() : is_enabled_(true) {} diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc index c7a2b8f0c0fc1..c8be36a41e944 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc @@ -116,7 +116,9 @@ const IExecutionProvider* Subgraph::GetProvider() const { const ExecutionProviders& providers = session_state_->GetExecutionProviders(); const IExecutionProvider* cpu_provider = providers.Get(onnxruntime::kCpuExecutionProvider); const IExecutionProvider* cuda_provider = providers.Get(onnxruntime::kCudaExecutionProvider); - const IExecutionProvider* provider = cuda_provider ? cuda_provider : cpu_provider; + const IExecutionProvider* rocm_provider = providers.Get(onnxruntime::kRocmExecutionProvider); + const IExecutionProvider* gpu_provider = cuda_provider ? cuda_provider : rocm_provider; + const IExecutionProvider* provider = gpu_provider ? 
gpu_provider : cpu_provider; return provider; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc index 085968209ef5c..dc026ef71b6c8 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc @@ -46,7 +46,7 @@ Status GptSubgraph::CreateInitialFeeds( AllocatorPtr cpu_allocator = session_state_->GetAllocator(input_ids.Location()); // Store allocator, which will be used in remaining feeds - auto default_allocator = provider->GetAllocator(0, OrtMemTypeDefault); + auto default_allocator = provider->GetAllocator(OrtMemTypeDefault); allocator_ = default_allocator; // The ordering is the same as used in Setup diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc index 3c76ad2e5797e..79532f79f4ef4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc @@ -116,7 +116,7 @@ Status T5EncoderSubgraph::CreateInitialFeeds( AllocatorPtr cpu_allocator = session_state_->GetAllocator(original_encoder_input_ids.Location()); if (cpu_allocator == nullptr) { const IExecutionProvider* provider = GetProvider(); - cpu_allocator = provider->GetAllocator(0, OrtMemTypeDefault); + cpu_allocator = provider->GetAllocator(OrtMemTypeDefault); } ORT_RETURN_IF(cpu_allocator == nullptr, "cpu_allocator shouldn't be nullptr"); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index b7eebb9d48785..8f271ecfcbfa8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -366,6 +366,39 @@ __global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* outp } } +template +__global__ void AddBiasUnpack(int M, const T* input, const T* biases, T* output) { + // Format 4 to unpack TRT packed input format for memory efficient attention. 
+ // Input: BxSxNxMxH + // Output: MxBxSxNxH + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; // matrix id + + const int head_size = blockDim.x; + const int num_heads = blockDim.y; + + const int sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int H = head_size; + const int NH = num_heads * head_size; + const int NHS = NH * sequence_length; + + int in_offset = m * head_size + n * M * H + (s * NH + b * NHS) * M; + const int out_offset = n * head_size + s * NH + b * NHS + m * NHS * batch_size; + + const int h = threadIdx.x; + if (h < head_size) { + if (biases != nullptr) { + output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h]; + } else { + output[out_offset + h] = input[in_offset + h]; + } + } +} + template __global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) { // Format 3 for cutlass memory efficient attention @@ -481,12 +514,12 @@ __global__ void AddBiasTransposeLarge(const int head_size, const T* input, const } } - template void InvokeAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count) { + assert(num_heads <= max_threads_per_block); const dim3 grid(sequence_length, batch_size, num_matrices); if (qk_head_size * num_heads <= max_threads_per_block) { const dim3 block(qk_head_size, num_heads, 1); @@ -506,11 +539,13 @@ void InvokeAddBiasTranspose( ORT_ENFORCE(total_matrix_count == 3); AddBiasTransposeCutlass<<>>(input, biases, output, v_head_size); } - } else { // format == 0 + } else if (format == 4) { // format == 4 + AddBiasUnpack<<>>(total_matrix_count, input, biases, output); + } else { // format == 0 AddBiasTranspose<<>>(input, biases, output); } } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); if (format == 2) { AddBiasTransposeTrtLarge<<>>(qk_head_size, input, biases, output); } else if (format == 1) { @@ -528,6 +563,8 @@ void InvokeAddBiasTranspose( } else { ORT_THROW("AddBiasTranspose (format 3) not implemented for hidden_size > max_threads_per_block when qk_head_size != v_head_size"); } + } else if (format == 4) { // format == 4 + ORT_THROW("AddBiasTranspose (format 4) not implemented for hidden_size > max_threads_per_block"); } else { // format 0 AddBiasTransposeLarge<<>>(qk_head_size, input, biases, output); } @@ -541,7 +578,7 @@ void LaunchAddBiasTranspose( const half* input, const half* biases, half* output, bool enable_half4, const int v_head_size, half* qkv_add_bias, int total_matrix_count) { total_matrix_count = std::max(num_matrices, total_matrix_count); - if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) { + if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4))) { const int H = qk_head_size / 4; const int H_v = v_head_size / 4; const Half4* input2 = reinterpret_cast(input); @@ -551,7 +588,7 @@ void LaunchAddBiasTranspose( InvokeAddBiasTranspose(stream, num_matrices, format, max_threads_per_block, batch_size, sequence_length, num_heads, H, input2, biases2, output2, qkv_add_bias2, H_v, total_matrix_count); - } else if (0 == 
(qk_head_size & 1) && 0 == (v_head_size % 1)) { + } else if (0 == (qk_head_size & 1) && (v_head_size == -1 || 0 == (v_head_size & 1))) { const int H = qk_head_size / 2; const int H_v = v_head_size / 2; const half2* input2 = reinterpret_cast(input); @@ -576,7 +613,7 @@ void LaunchAddBiasTranspose( const float* input, const float* biases, float* output, bool /*enable_half4*/, const int v_head_size, float* qkv_add_bias, int total_matrix_count) { total_matrix_count = std::max(num_matrices, total_matrix_count); - if (0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) { + if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4))) { const int H = qk_head_size / 4; const float4* input2 = reinterpret_cast(input); const float4* biases2 = reinterpret_cast(biases); @@ -586,7 +623,7 @@ void LaunchAddBiasTranspose( stream, num_matrices, format, max_threads_per_block, batch_size, sequence_length, num_heads, H, input2, biases2, output2, qkv_add_bias2, v_head_size / 4, total_matrix_count); - } else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) { + } else if (0 == (qk_head_size & 1) && (v_head_size == -1 || 0 == (v_head_size & 1))) { const int H = qk_head_size / 2; const float2* input2 = reinterpret_cast(input); const float2* biases2 = reinterpret_cast(biases); @@ -610,7 +647,6 @@ void InvokeAddBiasTransposeTrt( const int batch_size, const int sequence_length, const int num_heads, const int head_size, const T* biases, const T* query, const T* key, const T* value, T* output, bool is_cross_attention, int kv_sequence_length) { - if (!is_cross_attention) { ORT_ENFORCE(sequence_length == kv_sequence_length); constexpr int num_matrices = 3; @@ -619,7 +655,7 @@ void InvokeAddBiasTransposeTrt( const dim3 block(head_size, num_heads, 1); AddBiasTransposeTrt<<>>(query, key, value, biases, output); } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); AddBiasTransposeTrtLarge<<>>(head_size, query, key, value, biases, output); } } else { // cross attention @@ -631,7 +667,7 @@ void InvokeAddBiasTransposeTrt( const dim3 block(head_size, num_heads, 1); AddBiasTransposeTrt<<>>(query, biases, output); } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); AddBiasTransposeTrtLarge<<>>(head_size, query, biases, output); } } @@ -645,7 +681,7 @@ void InvokeAddBiasTransposeTrt( const dim3 block(head_size, num_heads, 1); AddBiasTransposeTrtKV<<>>(key, value, biases, packed_kv); } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); AddBiasTransposeTrtKVLarge<<>>(head_size, key, value, biases, packed_kv); } } @@ -696,52 +732,52 @@ void LaunchAddBiasTransposeTrt( } } - template void InvokeAddBias( cudaStream_t stream, const int max_threads_per_block, const int batch_size, const int sequence_length, const int kv_sequence_length, const int num_heads, const int head_size, const int v_head_size, const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v) { - constexpr int num_matrices = 1; - // Q - { - const dim3 grid(sequence_length, batch_size, num_matrices); - if (head_size * num_heads <= max_threads_per_block) { - const dim3 block(head_size, num_heads, 1); - AddBiasTransposeTrt<<>>(query, biases, q); - } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); - 
AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); - } + assert(num_heads <= max_threads_per_block); + constexpr int num_matrices = 1; + // Q + { + const dim3 grid(sequence_length, batch_size, num_matrices); + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(query, biases, q); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); } - // K - { - const dim3 grid(kv_sequence_length, batch_size, num_matrices); - const T* biases_k = biases + num_heads * head_size; + } + // K + { + const dim3 grid(kv_sequence_length, batch_size, num_matrices); + const T* biases_k = biases + num_heads * head_size; - if (head_size * num_heads <= max_threads_per_block) { - const dim3 block(head_size, num_heads, 1); - AddBiasTransposeTrt<<>>(key, biases_k, k); - } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); - AddBiasTransposeTrtLarge<<>>(head_size, key, biases_k, k); - } + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(key, biases_k, k); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + AddBiasTransposeTrtLarge<<>>(head_size, key, biases_k, k); } + } - // V - { - const dim3 grid(kv_sequence_length, batch_size, num_matrices); + // V + { + const dim3 grid(kv_sequence_length, batch_size, num_matrices); - const T* biases_v = biases + 2 * num_heads * head_size; - if (v_head_size * num_heads <= max_threads_per_block) { - const dim3 block(v_head_size, num_heads, 1); - AddBiasTransposeTrt<<>>(value, biases_v, v); - } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); - AddBiasTransposeTrtLarge<<>>(v_head_size, value, biases_v, v); - } + const T* biases_v = biases + 2 * num_heads * head_size; + if (v_head_size * num_heads <= max_threads_per_block) { + const dim3 block(v_head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(value, biases_v, v); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + AddBiasTransposeTrtLarge<<>>(v_head_size, value, biases_v, v); } + } } template <> @@ -750,7 +786,7 @@ void LaunchAddBias( const int batch_size, const int sequence_length, const int kv_sequence_length, const int num_heads, const int head_size, const int v_head_size, const float* biases, const float* query, const float* key, const float* value, float* q, float* k, float* v) { -if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { + if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { const int H = head_size / 4; const int H_v = v_head_size / 4; const float4* query2 = reinterpret_cast(query); @@ -761,8 +797,8 @@ if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { float4* k2 = reinterpret_cast(k); float4* v2 = reinterpret_cast(v); InvokeAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v, - biases2, query2, key2, value2, q2, k2, v2); + batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v, + biases2, query2, key2, value2, q2, k2, v2); } else if (0 == (head_size & 1) && 0 == (v_head_size & 1)) { const int H = head_size / 2; const int H_v = v_head_size / 2; @@ -774,14 +810,13 @@ if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { float2* k2 = reinterpret_cast(k); float2* v2 = reinterpret_cast(v); InvokeAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, 
num_heads, H, H_v, - biases2, query2, key2, value2, q2, k2, v2); + batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v, + biases2, query2, key2, value2, q2, k2, v2); } else { InvokeAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, num_heads, head_size, v_head_size, - biases, query, key, value, q, k, v); + batch_size, sequence_length, kv_sequence_length, num_heads, head_size, v_head_size, + biases, query, key, value, q, k, v); } - } template <> @@ -790,8 +825,7 @@ void LaunchAddBias( const int batch_size, const int sequence_length, const int kv_sequence_length, const int num_heads, const int head_size, const int v_head_size, const half* biases, const half* query, const half* key, const half* value, half* q, half* k, half* v) { - - if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { + if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { const int H = head_size / 4; const int H_v = v_head_size / 4; const Half4* query2 = reinterpret_cast(query); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h index 8cc36637054e7..a2c3265284a4d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h @@ -24,6 +24,10 @@ namespace cuda { // format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3) // input: (batch_size, sequence_length, num_matrices, num_heads, head_size) // output: (num_matrices, batch_size, sequence_length, num_heads, head_size) +// format 4: (requires qk_head_size = v_head_size) +// input: (batch_size, sequence_length, num_heads, num_matrices, head_size) +// output: (num_matrices, batch_size, sequence_length, num_heads, head_size) + template void LaunchAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 4a6d2dc137139..04cac1962f37b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -59,7 +59,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(kPastInputIndex); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* relative_position_bias = context->Input(5); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); auto& device_prop = GetDeviceProp(); @@ -69,7 +69,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bias->Shape(), mask_index, past, - extra_add_qk, + relative_position_bias, ¶meters, device_prop.maxThreadsPerBlock, past_seq_len)); @@ -105,7 +105,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; bool use_causal_fused_runner = !disable_fused_runner_ && (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == extra_add_qk && + nullptr == relative_position_bias && parameters.past_sequence_length == 0 && parameters.hidden_size == parameters.v_hidden_size && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, @@ -125,7 +125,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { (nullptr == mask_index || 
is_mask_1d_seq_len) && nullptr == past && nullptr == present && - nullptr == extra_add_qk && + nullptr == relative_position_bias && parameters.hidden_size == parameters.v_hidden_size && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, enable_trt_flash_attention_, false); @@ -151,7 +151,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == mask_index && // TODO: support 1D mask nullptr == past && nullptr == present && - nullptr == extra_add_qk && + nullptr == relative_position_bias && (sizeof(T) == 2 || // sequence length threshold is 0 in FP16 parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) && has_memory_efficient_attention(sm, sizeof(T) == 2); @@ -181,6 +181,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); constexpr size_t element_size = sizeof(T); + constexpr bool use_fused_cross_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, parameters.batch_size, parameters.num_heads, @@ -190,6 +191,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.kv_sequence_length, parameters.total_sequence_length, fused_runner, + use_fused_cross_attention, use_memory_efficient_attention); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -203,13 +205,16 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data()); - data.extra_add_qk = (nullptr == extra_add_qk) ? nullptr : reinterpret_cast(extra_add_qk->Data()); + data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); + data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData()); data.fused_runner = reinterpret_cast(fused_runner); data.fused_cross_attention_kernel = nullptr; data.use_memory_efficient_attention = use_memory_efficient_attention; + data.cumulated_sequence_length_q_cache = nullptr; + data.cumulated_sequence_length_kv_cache = nullptr; return QkvToContext(device_prop, cublas, Stream(context), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 187f1bb37edc5..41f19f460e80c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -28,6 +28,7 @@ limitations under the License. #include #include +#include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" @@ -45,37 +46,47 @@ limitations under the License. using namespace onnxruntime::cuda; using namespace cub; -#define CHECK_CUDA(expr) CUDA_RETURN_IF_ERROR(expr) -#define CUDA_MEMORY_ALIGNMENT 256 - -#define DUMP_ATTENTION_LEVEL 0 -#if DUMP_ATTENTION_LEVEL > 1 -#define DUMP_ATTENTION_INIT() transformers::CudaTensorConsoleDumper dumper -#define DUMP_ATTENTION(...) dumper.Print(__VA_ARGS__) -#define DUMP_ATTENTION_D(...) 
dumper.Print(__VA_ARGS__) -#elif DUMP_ATTENTION_LEVEL > 0 -#define DUMP_ATTENTION_INIT() transformers::CudaTensorConsoleDumper dumper -#define DUMP_ATTENTION(...) dumper.Print(__VA_ARGS__) -#define DUMP_ATTENTION_D(...) -#else -#define DUMP_ATTENTION_INIT() -#define DUMP_ATTENTION(...) -#define DUMP_ATTENTION_D(...) -#endif - namespace onnxruntime { namespace contrib { namespace cuda { +constexpr size_t kMemoryAlignment = 256; + static size_t AlignTo(size_t a, size_t b) { return CeilDiv(a, b) * b; } size_t AlignSize(size_t bytes) { - const size_t bytesAligned = AlignTo(bytes, CUDA_MEMORY_ALIGNMENT); + const size_t bytesAligned = AlignTo(bytes, kMemoryAlignment); return bytesAligned; } +void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) { + if (this->sequence_length != sequence_length) { + ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); + LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, this->max_batch_size, sequence_length, stream); + this->sequence_length = sequence_length; + } +} + +int* GetCumulatedSequenceLength(CumulatedSequenceLengthCache* cache, + const int* mask_index, + int batch_size, + int sequence_length, + cudaStream_t stream, + void* scratch_buffer) { + if (mask_index == nullptr && cache != nullptr) { + if (batch_size <= cache->max_batch_size) { + cache->Initialize(sequence_length, stream); + return reinterpret_cast(cache->buffer.get()); + } + } + + int* sequence_offset = reinterpret_cast(scratch_buffer); + LaunchTrtSequenceOffset(sequence_offset, mask_index, batch_size, sequence_length, stream); + return sequence_offset; +} + size_t GetAttentionScratchSize( size_t element_size, size_t batch_size, @@ -103,6 +114,7 @@ size_t GetAttentionWorkspaceSize( size_t kv_sequence_length, size_t total_sequence_length, void* fused_runner, + bool use_fused_cross_attention, bool use_memory_efficient_attention) { // Note that q, k and v might need alignment for fused attention kernels. const size_t qkv_bytes = element_size * batch_size * num_heads * @@ -122,8 +134,11 @@ size_t GetAttentionWorkspaceSize( #endif if (fused_runner != nullptr) { - size_t sequence_offset_bytes = GetSequenceOffsetSize(static_cast(batch_size), true); - return qkv_bytes + sequence_offset_bytes; + return qkv_bytes + GetSequenceOffsetSize(static_cast(batch_size), true); + } + + if (use_fused_cross_attention) { + return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast(batch_size), true); } return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, @@ -278,14 +293,15 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, T* qkv = data.workspace; - bool use_fused_kernel = (nullptr != fused_runner && data.bias != nullptr && !parameters.is_unidirectional); + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); // Default format for memory efficient attention. // When there is past state, the format shal be BxNxSxH, so we disable memory efficient attention when there is past. - DUMP_ATTENTION_INIT(); + DUMP_TENSOR_INIT(); if (nullptr != data.gemm_buffer) { if (data.bias == nullptr) { + assert(nullptr == fused_runner); // For quantized attention, bias has been added so only need transpose here. 
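// Illustration of the cumulated sequence length ("sequence offset") buffer that
// GetCumulatedSequenceLength and CumulatedSequenceLengthCache above prepare for the fused
// TRT kernels: offsets[b] is the index of the first token of batch b and offsets[batch_size]
// is the total token count. A minimal host-side sketch under the assumption that, without a
// padding mask, every sequence has the same length; the function name is illustrative only.
#include <cstdint>
#include <vector>

std::vector<int32_t> BuildSequenceOffsets(int batch_size, int sequence_length,
                                          const int32_t* valid_lengths /* nullptr when there is no padding */) {
  std::vector<int32_t> offsets(static_cast<size_t>(batch_size) + 1, 0);
  for (int b = 0; b < batch_size; ++b) {
    const int32_t len = (valid_lengths == nullptr) ? sequence_length : valid_lengths[b];
    offsets[b + 1] = offsets[b] + len;  // running sum of token counts
  }
  return offsets;  // batch_size + 1 entries
}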
// gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH assert(qk_head_size == v_head_size); @@ -317,15 +333,67 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, 3); } - } else { // gemm_buffer == nullptr + } else if (data.key == nullptr) { // gemm_buffer == nullptr and packed qkv + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); + + if (use_memory_efficient_attention) { + // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, qkv, + true, v_head_size, qkv_add_bias, 3); + DUMP_TENSOR_D("k(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (!use_fused_kernel) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); + } + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } + } else if (data.value == nullptr) { // gemm_buffer == nullptr and packed kv + // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. + // CheckInputs verified this constraint. + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); + + if (use_memory_efficient_attention) { + // unpack kv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); + LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, kv_bias, k, + true, v_head_size, qkv_add_bias, 2); + DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (data.fused_cross_attention_kernel == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed KV format is not implemented for current GPU. 
Please disable packed kv in fusion options."); + } + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + } + } else { // gemm_buffer == nullptr and not packed assert(data.query != nullptr && data.key != nullptr && data.value != nullptr && data.bias != nullptr); - DUMP_ATTENTION_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("query_bias", data.bias, num_heads, qk_head_size); - DUMP_ATTENTION_D("key", data.key, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); - DUMP_ATTENTION_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size); - DUMP_ATTENTION_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + DUMP_TENSOR_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); + DUMP_TENSOR_D("key", data.key, batch_size * kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); + DUMP_TENSOR_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size); + DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); if (data.fused_cross_attention_kernel != nullptr) { assert(qk_head_size == v_head_size); @@ -347,9 +415,9 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, num_heads, qk_head_size, v_head_size, data.bias, data.query, data.key, data.value, q, k, v); - DUMP_ATTENTION_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); + DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); qkv_format = AttentionQkvFormat::Q_K_V_BSNH; } #endif @@ -362,7 +430,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, batch_size, sequence_length, num_heads, qk_head_size, data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); - DUMP_ATTENTION_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); + DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); qkv_format = AttentionQkvFormat::QKV_BSN3H; } else { // unfused kernel @@ -387,9 +455,9 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, data.value, data.bias + 2 * num_heads * qk_head_size, v, true, -1); - DUMP_ATTENTION_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size); - DUMP_ATTENTION_D("k(BNSH)", k, batch_size * num_heads, kv_sequence_length, qk_head_size); - DUMP_ATTENTION_D("v(BNSH)", v, batch_size * num_heads, kv_sequence_length, v_head_size); + DUMP_TENSOR_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k, batch_size * num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v, batch_size * num_heads, kv_sequence_length, v_head_size); qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } } @@ -419,22 +487,28 @@ Status QkvToContext( void* fused_runner = data.fused_runner; // At most one fused kernel is 
enabled. - assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + - int(data.fused_cross_attention_kernel != nullptr) <= 1); + assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1); const int batches = batch_size * num_heads; - const int size_per_batch_q = sequence_length * qk_head_size; - const int size_per_batch_k = kv_sequence_length * qk_head_size; - const int size_per_batch_v = kv_sequence_length * v_head_size; - const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); - const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); - const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); - - // Q, K and V pointers when fused attention is not used - T* qkv = data.workspace; - T* q = qkv; - T* k = q + elements_q; - T* v = k + elements_k; + + T* qkv = nullptr; + T* q = nullptr; + T* k = nullptr; + T* v = nullptr; + T* scratch1 = data.workspace; + if (data.has_qkv_workspace) { + const int size_per_batch_q = sequence_length * qk_head_size; + const int size_per_batch_k = kv_sequence_length * qk_head_size; + const int size_per_batch_v = kv_sequence_length * v_head_size; + const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); + const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); + const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); + qkv = data.workspace; + q = qkv; + k = q + elements_q; + v = k + elements_k; + scratch1 = v + elements_v; + } bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); @@ -442,8 +516,6 @@ Status QkvToContext( AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - T* scratch1 = qkv + elements_q + elements_k + elements_v; - int present_size_per_batch_k = 0; int present_size_per_batch_v = 0; if (!past_present_share_buffer) { @@ -470,6 +542,7 @@ Status QkvToContext( assert(!use_fused_kernel); assert(data.gemm_buffer != nullptr); assert(!data.use_memory_efficient_attention); + assert(data.has_qkv_workspace); if (data.present != data.past) { // For easy testing. Production should better avoid this path. @@ -481,7 +554,7 @@ Status QkvToContext( ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, max_threads_per_block, - use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer + use_fused_causal ? 
nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer data.gemm_buffer, data.present)); present_size_per_batch_k = parameters.max_sequence_length * qk_head_size; @@ -491,44 +564,55 @@ Status QkvToContext( } // Q, K and V are ready now - DUMP_ATTENTION_INIT(); + DUMP_TENSOR_INIT(); if (data.fused_cross_attention_kernel != nullptr) { assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); - int* q_sequence_offset = reinterpret_cast(scratch1); - LaunchTrtSequenceOffset(q_sequence_offset, nullptr, batch_size, sequence_length, stream); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); - - DUMP_ATTENTION_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. assert(data.mask_index == nullptr); + int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, sequence_length, stream, + scratch1); + + DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); + int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); - LaunchTrtSequenceOffset(kv_sequence_offset, data.mask_index, batch_size, kv_sequence_length, stream); + kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, + data.mask_index, batch_size, kv_sequence_length, stream, + kv_sequence_offset); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_ATTENTION_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); + DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = reinterpret_cast(data.fused_cross_attention_kernel); + // When there is no bias, we can directly use q and packed kv from inputs. 
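// Illustration of the packed KV layout consumed by the fused cross-attention path above:
// the key input holds K and V interleaved per token as (batch_size, kv_sequence_length,
// num_heads, 2, head_size), with matrix index 0 = K and 1 = V (the Q_KV_BSNH_BSN2H format).
// A hypothetical indexing helper, written only to make the layout explicit; it is not a
// function used by these kernels.
#include <cstdint>

inline int64_t PackedKvOffset(int64_t b, int64_t s, int64_t n, int64_t m /* 0 = K, 1 = V */, int64_t h,
                              int64_t kv_sequence_length, int64_t num_heads, int64_t head_size) {
  // Row-major linearization of (b, s, n, m, h).
  return (((b * kv_sequence_length + s) * num_heads + n) * 2 + m) * head_size + h;
}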
+ void const* query = q; + void const* packed_kv = k; + if (data.value == nullptr && data.bias == nullptr) { + query = data.query; + packed_kv = data.key; + } + run_fused_cross_attention( - q, // Q - k, // packed KV - q_sequence_offset, // cumulated sequence length of Q - kv_sequence_offset, // cumulated sequence length of KV - data.output, // output - cross_attention_kernel, // kernels - batch_size, // batch size - num_heads, // number of heads - qk_head_size, // head size of Q/K/V - sequence_length, // sequence length of Q - kv_sequence_length, // sequence length of KV + query, // Q + packed_kv, // packed KV + q_sequence_offset, // cumulated sequence length of Q + kv_sequence_offset, // cumulated sequence length of KV + data.output, // output + cross_attention_kernel, // kernels + batch_size, // batch size + num_heads, // number of heads + qk_head_size, // head size of Q/K/V + sequence_length, // sequence length of Q + kv_sequence_length, // sequence length of KV stream); - DUMP_ATTENTION("trt cross output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("trt cross output", data.output, batch_size * sequence_length, num_heads, v_head_size); return Status::OK(); } @@ -538,7 +622,9 @@ Status QkvToContext( if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); } else { - LaunchTrtSequenceOffset(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, sequence_length, stream, + sequence_offset); } CUDA_RETURN_IF_ERROR(cudaGetLastError()); @@ -553,12 +639,19 @@ Status QkvToContext( if (use_fused_kernel) { assert(qkv_format == AttentionQkvFormat::QKV_BSN3H); - fused_fp16_runner->run(qkv, sequence_offset, data.output, stream); - DUMP_ATTENTION("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size); + + // When there is no bias, we can directly use packed qkv from inputs. + void const* packed_qkv = qkv; + if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { + packed_qkv = data.query; + } + + fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); + DUMP_TENSOR("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size); } else { assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_ATTENTION("fused causal output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("fused causal output", data.output, batch_size * sequence_length, num_heads, v_head_size); } return Status::OK(); } @@ -570,6 +663,15 @@ Status QkvToContext( assert(data.mask_index == nullptr); assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + const void* query = q; + const void* key = k; + const void* value = v; + // For packed KV, we can use query input directly. 
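// Illustration of the scaling convention used further below in the unfused path: when
// parameters.scale is 0, the kernels fall back to 1/sqrt(qk_head_size) before softmax.
// A scalar reference for a single query row; names and signature are illustrative only.
#include <cmath>
#include <vector>

std::vector<float> ScaledAttentionScores(const float* q, const float* k, int total_sequence_length,
                                         int head_size, float scale /* 0.f selects the default */) {
  const float s = (scale == 0.0f) ? 1.0f / std::sqrt(static_cast<float>(head_size)) : scale;
  std::vector<float> scores(total_sequence_length, 0.0f);
  for (int t = 0; t < total_sequence_length; ++t) {
    float dot = 0.0f;
    for (int h = 0; h < head_size; ++h) {
      dot += q[h] * k[t * head_size + h];  // one row of Q times K^T
    }
    scores[t] = dot * s;  // softmax over t is applied afterwards
  }
  return scores;
}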
+ if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { + assert(data.bias == nullptr); + query = data.query; + } + MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_half = sizeof(T) == 2; @@ -582,15 +684,15 @@ Status QkvToContext( p.causal = parameters.is_unidirectional; p.cu_seqlens_q = nullptr; p.cu_seqlens_k = nullptr; - p.query = q; - p.key = k; - p.value = v; + p.query = query; + p.key = key; + p.value = value; p.output = data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? scratch1 : nullptr; p.stream = stream; run_memory_efficient_attention(p); - DUMP_ATTENTION("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size); return Status::OK(); } #endif @@ -610,7 +712,7 @@ Status QkvToContext( // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) - : parameters.scale; + : parameters.scale; float alpha = use_raw_attention_mask ? one : scale; cublasSetStream(cublas, stream); @@ -622,7 +724,7 @@ Status QkvToContext( q, qk_head_size, sequence_length * qk_head_size, &zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); - DUMP_ATTENTION_D("QK", scratch1, batch_size * num_heads, sequence_length, total_sequence_length); + DUMP_TENSOR_D("QK", scratch1, batch_size * num_heads, sequence_length, total_sequence_length); const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, total_sequence_length); @@ -639,7 +741,7 @@ Status QkvToContext( T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask(stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, nullptr, data.extra_add_qk, scratch1, scratch2, + mask_index, nullptr, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension, parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, mask_filter_value)); @@ -649,14 +751,14 @@ Status QkvToContext( const int* mask_start = (mask_index_dims[0] > batch_size) ? 
mask_index + batch_size : nullptr; ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D( stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, mask_start, data.extra_add_qk, scratch1, scratch2, parameters.is_unidirectional)); + mask_index, mask_start, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR( - ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.extra_add_qk, + ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional)); } - DUMP_ATTENTION_D("Softmax", scratch2, batch_size * num_heads, sequence_length, total_sequence_length); + DUMP_TENSOR_D("Softmax", scratch2, batch_size * num_heads, sequence_length, total_sequence_length); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v T* temp_output = qkv; @@ -670,7 +772,7 @@ Status QkvToContext( // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, max_threads_per_block, false, temp_output, data.output); - DUMP_ATTENTION("unfused output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("unfused output", data.output, batch_size * sequence_length, num_heads, v_head_size); return result; } @@ -754,15 +856,15 @@ Status DecoderQkvToContext( if (has_layer_state) { if (use_past && static_kv) { - CHECK_CUDA(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CHECK_CUDA(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); } else { - CHECK_CUDA(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CHECK_CUDA(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index d98a0380c479b..ec7371db4c14d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -6,18 +6,33 @@ #include #include #include "contrib_ops/cpu/bert/attention_common.h" +#include "core/framework/allocator.h" namespace onnxruntime { namespace contrib { namespace cuda { -size_t GetAttentionScratchSize( +constexpr int kCumulatedSequenceLengthCacheMaxBatchSize = 128; + +struct CumulatedSequenceLengthCache { + onnxruntime::IAllocatorUniquePtr buffer; + int32_t max_batch_size; + int32_t sequence_length; + + CumulatedSequenceLengthCache() : max_batch_size(0), sequence_length(0) {} + void Initialize(int32_t sequence_length, cudaStream_t stream); +}; + +size_t +GetAttentionScratchSize( size_t element_size, size_t batch_size, 
size_t num_heads, size_t sequence_length, size_t all_sequence_length); +size_t GetSequenceOffsetSize(int batch_size, bool has_padding); + size_t GetAttentionWorkspaceSize( size_t element_size, size_t batchsize, @@ -28,7 +43,8 @@ size_t GetAttentionWorkspaceSize( size_t kv_sequence_length, size_t total_sequence_length, void* fused_runner, - bool use_memory_efficient_attention = false); + bool use_fused_cross_attention, + bool use_memory_efficient_attention); template struct AttentionData { @@ -41,9 +57,11 @@ struct AttentionData { const int* mask_index; gsl::span mask_index_dims; const T* past; - const T* extra_add_qk; + const T* relative_position_bias; + bool has_qkv_workspace; T* workspace; + T* output; T* present; @@ -51,6 +69,9 @@ struct AttentionData { const void* fused_cross_attention_kernel; bool use_memory_efficient_attention; + + mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache; + mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h index 953a45e15b32e..92851c446d48f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h @@ -377,11 +377,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, float thread_data = -CUDART_INF_F; if (threadIdx.x < all_sequence_length) { - if (add_before_softmax == nullptr) { - thread_data = float(input[index]) * rsqrt_head_size; - } else { - thread_data = float(input[index] + add_before_softmax[index]) * rsqrt_head_size; - } + thread_data = float(input[index]) * rsqrt_head_size; const int sequence_index = blockIdx.x % sequence_length; if (is_unidirectional) { @@ -412,6 +408,10 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, thread_data = -CUDART_INF_F; } } + + if (add_before_softmax != nullptr) { + thread_data += float(add_before_softmax[index]); + } } if (skip_softmax) { @@ -714,12 +714,12 @@ Status ComputeSoftmaxWithRawMask(cudaStream_t stream, } if (use_persistent_softmax) { - dispatch_warpwise_softmax_forward(stream, - output, - persistent_softmax_workspace, - all_sequence_length, - all_sequence_length, - batch_size * num_heads * sequence_length); + return dispatch_warpwise_softmax_forward(stream, + output, + persistent_softmax_workspace, + all_sequence_length, + all_sequence_length, + batch_size * num_heads * sequence_length); } return CUDA_CALL(cudaGetLastError()); diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index 299c21df94ff1..d48716109eaaa 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -363,7 +363,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { return LaunchDecoderAttentionKernel( device_prop, #ifdef USE_ROCM - IsTunableOpEnabled(), + GetTuningContext(), #endif stream, cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index c7e5d34e1691b..321c2a1df0df2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -31,9 +31,11 @@ REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) template - MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) - : CudaKernel(info), 
fused_fp16_cross_attention_kernel_(nullptr) { + : CudaKernel(info), + fused_fp16_cross_attention_kernel_(nullptr), + cumulated_sequence_length_q_cache_(), + cumulated_sequence_length_kv_cache_() { int64_t num_heads = 0; ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); num_heads_ = static_cast(num_heads); @@ -52,7 +54,15 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) disable_memory_efficient_attention_ = true; #endif - disable_fused_cross_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault(attention::kDisableFusedCrossAttention, false); + disable_fused_cross_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFusedCrossAttention, false); + + // Allocate cache buffers + constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast(kCumulatedSequenceLengthCacheMaxBatchSize) + 1); + cumulated_sequence_length_q_cache_.buffer = GetTransientScratchBuffer(cache_bytes); + cumulated_sequence_length_q_cache_.max_batch_size = kCumulatedSequenceLengthCacheMaxBatchSize; + cumulated_sequence_length_kv_cache_.buffer = GetTransientScratchBuffer(cache_bytes); + cumulated_sequence_length_kv_cache_.max_batch_size = kCumulatedSequenceLengthCacheMaxBatchSize; } template @@ -62,6 +72,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* value = context->Input(2); const Tensor* bias = context->Input(3); const Tensor* key_padding_mask = context->Input(4); + const Tensor* relative_position_bias = context->Input(5); auto& device_prop = GetDeviceProp(); AttentionParameters parameters; @@ -70,6 +81,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { value, bias, key_padding_mask, + relative_position_bias, ¶meters, num_heads_, mask_filter_value_, @@ -94,6 +106,9 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_cross_attention = !disable_fused_cross_attention_ && nullptr == key_padding_mask && + nullptr == relative_position_bias && + key != nullptr && + (value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV parameters.hidden_size == parameters.v_hidden_size && has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length); @@ -111,6 +126,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_runner = !disable_fused_runner_ && fused_cross_attention_kernel == nullptr && + nullptr == relative_position_bias && + (value != nullptr || key == nullptr) && (nullptr == key_padding_mask || is_mask_1d_seq_len) && parameters.hidden_size == parameters.v_hidden_size && parameters.sequence_length == parameters.kv_sequence_length && @@ -141,41 +158,58 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { !disable_memory_efficient_attention_ && is_long_sequence && nullptr == key_padding_mask && // TODO: support 1D mask + nullptr == relative_position_bias && has_memory_efficient_attention(sm, sizeof(T) == 2); #else constexpr bool use_memory_efficient_attention = false; #endif - constexpr size_t element_size = sizeof(T); - size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, - parameters.batch_size, - parameters.num_heads, - parameters.head_size, - parameters.v_head_size, - parameters.sequence_length, - parameters.kv_sequence_length, - parameters.total_sequence_length, - fused_runner, - use_memory_efficient_attention); - auto work_space = GetScratchBuffer(workSpaceSize, 
context->GetComputeStream()); + // When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace. + bool no_qkv_workspace = nullptr == value && + (use_fused_cross_attention || (nullptr != fused_runner && nullptr == key)) && + nullptr == key_padding_mask && + nullptr == bias; + + size_t workspace_bytes; + if (no_qkv_workspace) { + workspace_bytes = (parameters.batch_size > kCumulatedSequenceLengthCacheMaxBatchSize) ? 2 * GetSequenceOffsetSize(parameters.batch_size, true) : 0; + } else { + constexpr size_t element_size = sizeof(T); + workspace_bytes = GetAttentionWorkspaceSize(element_size, + parameters.batch_size, + parameters.num_heads, + parameters.head_size, + parameters.v_head_size, + parameters.sequence_length, + parameters.kv_sequence_length, + parameters.total_sequence_length, + fused_runner, + use_fused_cross_attention, + use_memory_efficient_attention); + } + + auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.gemm_buffer = nullptr; - data.bias = reinterpret_cast(bias->Data()); + data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); - data.key = reinterpret_cast(key->Data()); - data.value = reinterpret_cast(value->Data()); + data.key = (nullptr == key) ? nullptr : reinterpret_cast(key->Data()); + data.value = (nullptr == value) ? nullptr : reinterpret_cast(value->Data()); data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); data.past = nullptr; - data.extra_add_qk = nullptr; + data.relative_position_bias = (nullptr == relative_position_bias) ? 
nullptr : reinterpret_cast(relative_position_bias->Data()); + data.has_qkv_workspace = !no_qkv_workspace; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = nullptr; data.fused_runner = reinterpret_cast(fused_runner); data.fused_cross_attention_kernel = fused_cross_attention_kernel; data.use_memory_efficient_attention = use_memory_efficient_attention; + data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); + data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); cublasHandle_t cublas = GetCublasHandle(context); return QkvToContext( diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index b4ac7f19597ea..928dbd1c4a0f4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -7,6 +7,7 @@ #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" +#include "contrib_ops/cuda/bert/attention_impl.h" namespace onnxruntime { namespace contrib { @@ -29,6 +30,8 @@ class MultiHeadAttention final : public CudaKernel { bool disable_memory_efficient_attention_; mutable std::unique_ptr fused_fp16_runner_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; + mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; + mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index af13efe0e2fbc..111fed04639e7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -3,7 +3,15 @@ #include "core/providers/cuda/cuda_common.h" #include "relative_attn_bias.h" +#include "core/common/safeint.h" #include "relative_attn_bias_impl.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/add_bias_transpose.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + namespace onnxruntime { namespace contrib { @@ -20,7 +28,16 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 1) \ .InputMemoryType(OrtMemTypeCPUInput, 2) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - RelPosAttnBias); + RelPosAttnBias); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GatedRelativePositionBias, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + GatedRelativePositionBias); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) @@ -69,6 +86,108 @@ Status RelPosAttnBias::ComputeInternal(OpKernelContext* context) const { device_prop.maxThreadsPerBlock); } +template +GatedRelativePositionBias::GatedRelativePositionBias(const OpKernelInfo& info) : CudaKernel(info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = SafeInt(num_heads); +} + +template +Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) const { + const Tensor& query_tensor = *context->Input(0); + const Tensor& query_bias_tensor = 
*context->Input(1); + const Tensor& rel_pos_tensor = *context->Input(2); + const Tensor& weight_tensor = *context->Input(3); + const Tensor& bias_tensor = *context->Input(4); + const Tensor& eco_a_tensor = *context->Input(5); + + const auto& query_dims = query_tensor.Shape().GetDims(); + ORT_ENFORCE(query_dims.size() == 3); + ORT_ENFORCE(query_dims[2] > 0); + ORT_ENFORCE(query_dims[2] % num_heads_ == 0); + const auto batch_size = SafeInt(query_dims[0]); + const auto seq_len = SafeInt(query_dims[1]); + const auto head_size = SafeInt(query_dims[2] / num_heads_); + + ORT_ENFORCE(query_bias_tensor.Shape().NumDimensions() == 1); + ORT_ENFORCE(query_bias_tensor.Shape()[0] == query_dims[2]); + + const auto& rel_pos_dims = rel_pos_tensor.Shape().GetDims(); + ORT_ENFORCE(rel_pos_dims.size() == 4); + ORT_ENFORCE(rel_pos_dims[0] == 1); + ORT_ENFORCE(rel_pos_dims[1] == num_heads_); + ORT_ENFORCE(rel_pos_dims[2] == seq_len); + ORT_ENFORCE(rel_pos_dims[3] == seq_len); + + const auto& weight_dims = weight_tensor.Shape().GetDims(); + ORT_ENFORCE(weight_dims.size() == 2); + ORT_ENFORCE(weight_dims[0] == head_size); + ORT_ENFORCE((weight_dims[1] > 0) && (weight_dims[1] % 2 == 0)); + + ORT_ENFORCE(bias_tensor.Shape().NumDimensions() == 1); + ORT_ENFORCE(bias_tensor.Shape()[0] == weight_dims[1]); + + const auto D = SafeInt(weight_dims[1]); + + const auto& eco_a_dims = eco_a_tensor.Shape().GetDims(); + ORT_ENFORCE(eco_a_dims.size() == 4); + ORT_ENFORCE(eco_a_dims[0] == 1); + ORT_ENFORCE(eco_a_dims[1] == num_heads_); + ORT_ENFORCE(eco_a_dims[2] == 1); + ORT_ENFORCE(eco_a_dims[3] == 1); + + Tensor* output = context->Output(0, {batch_size, num_heads_, seq_len, seq_len}); + + auto& device_prop = GetDeviceProp(); + cublasHandle_t cublas = GetCublasHandle(context); + + typedef typename ToCudaType::MappedType CudaT; + const auto BNS = batch_size * num_heads_ * seq_len; + const size_t elements_in_query = (size_t)BNS * (size_t)head_size; + const size_t elements_after_gemm = (size_t)BNS *(size_t)D; + size_t workspace_size = sizeof(T) * (elements_in_query + (seq_len < D) ? elements_after_gemm : (size_t)0); + auto workspace = GetScratchBuffer(workspace_size, context->GetComputeStream()); + + // format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH) + constexpr int format = 1; + constexpr int total_maxtrix = 1; + constexpr int num_matrix_to_transpose = 1; + LaunchAddBiasTranspose(Stream(context), num_matrix_to_transpose, format, device_prop.maxThreadsPerBlock, + batch_size, seq_len, num_heads_, head_size, + reinterpret_cast(query_tensor.template Data()), + reinterpret_cast(query_bias_tensor.template Data()), + reinterpret_cast(workspace.get()), + false, head_size, reinterpret_cast(static_cast(nullptr)), total_maxtrix); + + // reuse output if possible + CudaT* gemm_output = (seq_len < D) ? 
(reinterpret_cast(workspace.get()) + elements_in_query) + : reinterpret_cast(output->template MutableData()); + int ld_gemm_output = max(seq_len, D); + + const CudaT one = ToCudaType::FromFloat(1.0f); + const CudaT zero = ToCudaType::FromFloat(0.0f); + + // ([b*n*s, h] * [h, D]), CUDA assumes col-major + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + D, BNS, head_size, &one, + reinterpret_cast(weight_tensor.template Data()), (int)D, + reinterpret_cast(workspace.get()), (int)head_size, + &zero, gemm_output, ld_gemm_output, device_prop)); + + auto status = LaunchGatedRelativePositionBiasKernel( + device_prop, Stream(context), + reinterpret_cast(output->template MutableData()), + reinterpret_cast(rel_pos_tensor.template Data()), + reinterpret_cast(gemm_output), + reinterpret_cast(bias_tensor.template Data()), + reinterpret_cast(eco_a_tensor.template Data()), + batch_size, num_heads_, seq_len, D, ld_gemm_output); + + return status; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h index b9674f6f35091..3bf4e730e29f9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h @@ -22,6 +22,18 @@ class RelPosAttnBias final : public CudaKernel { bool is_bidirectional_; }; +template +class GatedRelativePositionBias final : public CudaKernel { + public: + GatedRelativePositionBias(const OpKernelInfo& op_kernel_info); + + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + int num_heads_; +}; + + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu index e333152cb5bcf..938496b058025 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu @@ -36,7 +36,7 @@ __global__ void buildRelativeAttentionBias(T* relative_attention_bias, const bool is_bidirectional, const int max_distance) { const int head_id = blockIdx.x; - for (int seq_id = threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x * gridDim.y) { + for (int seq_id = blockDim.x * blockIdx.y + threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x * gridDim.y) { int row_id = seq_id / seq_len; int col_id = seq_id % seq_len; @@ -149,6 +149,122 @@ template Status LaunchRelPosAttnBiasKernel(cudaStream_t stream, const bool is_bidirectional, const int max_threads_per_block); +template +__global__ void GatedRelativePositionBiasKernelSmallD( + T* output, // (batch_size, num_heads, seq_len, seq_len) + const T* rel_pos, // (1, num_heads, seq_len, seq_len) + const T* qw, // (batch_size, num_heads, seq_len, D) + const T* bias, // (D) + const T* eco_a, // (1, num_heads, 1, 1) + const int D, + const int ldqw) { + __shared__ float gate[1]; + + const int seq_len = gridDim.x; + const int num_heads = gridDim.y; + const int s = blockIdx.x; + const int n = blockIdx.y; + const int b = blockIdx.z; + + rel_pos += ((int64_t)n * seq_len + s) * seq_len; + output += ((int64_t)b * num_heads * seq_len + (int64_t)n * seq_len + s) * seq_len; + qw += ((int64_t)b * num_heads * seq_len + (int64_t)n * seq_len + s) * ldqw; + + float val = 0.0f; + if (threadIdx.x < D) { + val = (float)qw[threadIdx.x] + (bias ? 
(float)bias[threadIdx.x] : 0.0f); + } + + float u = (threadIdx.x < D / 2) ? val : 0.0f; +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + u += __shfl_down_sync(0xffffffff, u, offset); + } + + float r = (threadIdx.x >= D / 2) ? val : 0.0f; +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + r += __shfl_down_sync(0xffffffff, r, offset); + } + + if (threadIdx.x == 0) { + u = 1.0f / (1.0f + expf(-u)); + r = 1.0f / (1.0f + expf(-r)); + gate[0] = u * (r * (float)eco_a[n] - 1.0f) + 2.0f; + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < seq_len; idx += blockDim.x) { + output[idx] = (T)(gate[0] * (float)rel_pos[idx]); + } +} + +template +Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + T* output, + const T* rel_pos, + const T* qw, // query * weight + const T* bias, + const T* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw) { + ORT_ENFORCE(D <= 32 && D > 0 && (D % 2 == 0)); + ORT_ENFORCE(ldqw == seq_len || ldqw == D); + + int tpb = std::max(32, std::max(D, seq_len)); + tpb = std::min(tpb, device_prop.maxThreadsPerBlock); + + // round up tpb to power of 2 + --tpb; + tpb |= (tpb >> 1); + tpb |= (tpb >> 2); + tpb |= (tpb >> 4); + tpb |= (tpb >> 8); + tpb |= (tpb >> 16); + tpb++; + + dim3 block(tpb); + dim3 grid(seq_len, num_heads, batch_size); + + GatedRelativePositionBiasKernelSmallD<<>>( + output, rel_pos, qw, bias, eco_a, D, ldqw); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + float* output, + const float* rel_pos, + const float* qw, + const float* bias, + const float* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw); + +template Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + half* output, + const half* rel_pos, + const half* qw, + const half* bias, + const half* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h index 5a1a229ab6077..5c7c98f55f3f5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h @@ -22,6 +22,21 @@ Status LaunchRelPosAttnBiasKernel( const int max_threads_per_block ); +template +Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + T* output, + const T* rel_pos, + const T* qw, // from query * weight + const T* bias, + const T* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h index 45a17a03d82f2..23bab06fe46ca 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h +++ 
b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h @@ -186,7 +186,7 @@ static Fused_multihead_attention_params_mhca getMHCAParams( void const* q_packed_d, void const* kv_packed_d, void* cu_seqlens_q_d, void* cu_seqlens_kv_d, void* o_packed_d) { Fused_multihead_attention_params_mhca params{}; - int32_t const d_padded = static_cast(std::pow(2, std::ceil(std::log(d) / std::log(2)))); + int32_t const d_padded = d <= 64 ? 64 : static_cast(std::pow(2, std::ceil(std::log(d) / std::log(2)))); // Set the pointers. params.o_ptr = o_packed_d; @@ -269,11 +269,11 @@ using FusedMHACrossKernelFactory = TSharedCubinKernelFactory min_head_size) && (head_size <= max_head_size) && - (kv_sequence_length <= 128); // TODO: shall we remove this constraint on kv_sequence_length? + (kv_sequence_length <= 128); // TODO: shall we remove this constraint on kv_sequence_length? } inline FusedMultiHeadCrossAttentionKernel const* get_fused_cross_attention_kernels(int32_t sm) { diff --git a/onnxruntime/contrib_ops/cuda/collective/mpi_include.h b/onnxruntime/contrib_ops/cuda/collective/mpi_include.h new file mode 100644 index 0000000000000..ee560bdf4207a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/mpi_include.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(USE_MPI) +#define OMPI_SKIP_MPICXX 1 // See https://github.com/open-mpi/ompi/issues/5157 +#include +#undef OMPI_SKIP_MPICXX + +namespace onnxruntime { + +#if defined(USE_MPI) +#define MPI_CHECK(condition) \ + do { \ + int error = (condition); \ + ORT_ENFORCE( \ + error == MPI_SUCCESS, \ + "MPI Error at: ", \ + __FILE__, \ + ":", \ + __LINE__, \ + ": ", \ + error); \ + } while (0) +#endif +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc new file mode 100644 index 0000000000000..3122f070fd57f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
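// Illustration of what GatedRelativePositionBiasKernelSmallD (in relative_attn_bias_impl.cu
// above) computes for one (batch, head, query position): the projected query row plus bias is
// split in half, each half is summed and passed through a sigmoid to form gates u and r, and
// the relative position bias row is scaled by u * (r * eco_a - 1) + 2. A scalar host-side
// reference under those assumptions; the function name is illustrative only.
#include <cmath>
#include <vector>

std::vector<float> GatedRelPosBiasRow(const float* qw /* length D, one row of query*weight */,
                                      const float* bias /* length D, may be nullptr */,
                                      float eco_a /* per-head scalar */,
                                      const float* rel_pos /* length seq_len */,
                                      int D, int seq_len) {
  float u = 0.0f;
  float r = 0.0f;
  for (int j = 0; j < D / 2; ++j) u += qw[j] + (bias ? bias[j] : 0.0f);
  for (int j = D / 2; j < D; ++j) r += qw[j] + (bias ? bias[j] : 0.0f);
  u = 1.0f / (1.0f + std::exp(-u));  // sigmoid of the first half
  r = 1.0f / (1.0f + std::exp(-r));  // sigmoid of the second half
  const float gate = u * (r * eco_a - 1.0f) + 2.0f;
  std::vector<float> out(static_cast<size_t>(seq_len));
  for (int i = 0; i < seq_len; ++i) {
    out[i] = gate * rel_pos[i];
  }
  return out;
}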
+#include "nccl_kernels.h" +#include "mpi_include.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define NCCL_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(NCCL_CALL(expr)) + +static ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) { + if (type == DataTypeImpl::GetType()) { + return ncclUint8; + } else if (type == DataTypeImpl::GetType()) { + return ncclInt8; + } else if (type == DataTypeImpl::GetType()) { + return ncclInt32; + } else if (type == DataTypeImpl::GetType()) { + return ncclInt64; + } else if (type == DataTypeImpl::GetType()) { + return ncclFloat16; + } else if (type == DataTypeImpl::GetType()) { + return ncclFloat32; + } else if (type == DataTypeImpl::GetType()) { + return ncclFloat64; + } else { + ORT_THROW("Tensor type not supported in NCCL."); + } +} + +#ifdef USE_MPI +static Status CreateNcclCommByMPI(int world_size, int rank, ncclComm_t* comm) { + // Create new NCCL communicator + ncclUniqueId nccl_id; + if (rank == 0) { + NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id)); + } + MPI_CHECK(MPI_Bcast(&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD)); + NCCL_RETURN_IF_ERROR(ncclCommInitRank(comm, world_size, nccl_id, rank)); + + return Status::OK(); +} +#endif + +NcclContext::NcclContext() { +#ifdef USE_MPI + int is_mpi_initialized = 0; + MPI_Initialized(&is_mpi_initialized); + if (!is_mpi_initialized) { + int mpi_threads_provided = 0; + MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &mpi_threads_provided); + } + + // get world_size and rank from MPI + MPI_Comm_size(MPI_COMM_WORLD, &world_size_); + + MPI_Comm_rank(MPI_COMM_WORLD, &rank_); + + // Initialize global Parallel Group NCCL Communicator + auto ret = CreateNcclCommByMPI(world_size_, rank_, &comm_); + ORT_ENFORCE(ret.IsOK()); + +#else + ORT_THROW("ORT must be built with MPI to use NCCL."); +#endif +} + +NcclContext::~NcclContext() { + if (comm_ != nullptr) { + ncclCommDestroy(comm_); + } + +#ifdef USE_MPI + int is_mpi_finalized = 0; + MPI_Finalized(&is_mpi_finalized); + if (!is_mpi_finalized) { + MPI_Finalize(); + } +#endif +} + +NcclKernel::NcclKernel(const OpKernelInfo& info) : CudaKernel(info) { + static NcclContext context; + nccl_ = &context; +} + +AllReduce::AllReduce(const OpKernelInfo& info) : NcclKernel(info) { +} + +Status AllReduce::ComputeInternal(OpKernelContext* context) const { + ncclComm_t comm = nccl_->Comm(); + + auto input_tensor = context->Input(0); + const void* input_data = input_tensor->DataRaw(); + const auto in_shape = input_tensor->Shape(); + int64_t input_count = in_shape.Size(); + + void* output_data = context->Output(0, in_shape)->MutableDataRaw(); + + ncclDataType_t dtype = GetNcclDataType(input_tensor->DataType()); +#ifdef ORT_USE_NCCL + NCCL_RETURN_IF_ERROR(ncclAllReduce(input_data, output_data, input_count, dtype, ncclSum, comm, Stream(context))); +#endif + return Status::OK(); +} + +AllGather::AllGather(const OpKernelInfo& info) : NcclKernel(info) { + info.GetAttrOrDefault("group_size", &group_size_, static_cast(1)); +} + +Status AllGather::ComputeInternal(OpKernelContext* context) const { + ncclComm_t comm = nccl_->Comm(); + + auto input_tensor = context->Input(0); + const void* input_data = input_tensor->DataRaw(); + const auto in_shape = input_tensor->Shape(); + int64_t input_count = in_shape.Size(); + // construct output shape + TensorShape out_shape(in_shape); + out_shape[0] = group_size_ * out_shape[0]; + + void* output_data = context->Output(0, out_shape)->MutableDataRaw(); + + ncclDataType_t dtype = 
GetNcclDataType(input_tensor->DataType()); +#ifdef ORT_USE_NCCL + NCCL_RETURN_IF_ERROR(ncclAllGather(input_data, output_data, input_count, dtype, comm, Stream(context))); +#endif + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX(AllReduce, kMSDomain, 1, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .VariadicAlias(0, 0) // outputs and inputs are mapped one to one + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()), + AllReduce); + +ONNX_OPERATOR_KERNEL_EX( + AllGather, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()), + AllGather); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h new file mode 100644 index 0000000000000..1576f674106e2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" + +#if defined(ORT_USE_NCCL) +#include +#endif + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// ----------------------------------------------------------------------- +// Defines a new version of nccl classes +// that independent with training::DistributedRunContext, only rely on MPI +// ----------------------------------------------------------------------- +class NcclContext final { + public: + NcclContext(); + ~NcclContext(); + + ncclComm_t Comm() { + return comm_; + } + + int Rank() const { + return rank_; + } + + int Size() const { + return world_size_; + } + + private: + ncclComm_t comm_; + int rank_; + int world_size_; +}; + +class NcclKernel : public ::onnxruntime::cuda::CudaKernel { + public: + explicit NcclKernel(const OpKernelInfo& info); + + protected: + NcclContext* nccl_ = nullptr; +}; + +/* + * Defines new version of Nccl classes that independent with training::DistributedContext + * only rely on MPI + */ +class AllReduce final : public NcclKernel { + public: + explicit AllReduce(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +class AllGather final : public NcclKernel { + public: + explicit AllGather(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t group_size_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 38bcbc298b939..1cefd44844f39 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -19,6 +19,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, 
BiasSplitGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasAdd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu); @@ -30,6 +34,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding); @@ -71,6 +77,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GroupNorm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, NhwcConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler); @@ -125,6 +134,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrd class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); #endif +#if defined(USE_MPI) && defined(ORT_USE_NCCL) +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather); +#endif + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -133,124 +147,139 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // These ops were experimental ops in onnx domain which have been removed now. 
We add them here as - // contrib ops to maintain backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // These ops were experimental ops in onnx domain which have been removed now. We add them here as + // contrib ops to maintain backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // TransposedMatMul is still here for backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // TransposedMatMul is still here for backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN - BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif + +#if defined(USE_MPI) && defined(ORT_USE_NCCL) + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif + }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc new file mode 100644 index 0000000000000..5d5183221eda4 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/diffusion/bias_add.h" +#include "contrib_ops/cuda/diffusion/bias_add_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + BiasAdd, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + BiasAdd); + +REGISTER_KERNEL_TYPED(MLFloat16); +REGISTER_KERNEL_TYPED(float); + +using namespace ONNX_NAMESPACE; + +template +BiasAdd::BiasAdd(const OpKernelInfo& op_info) : CudaKernel(op_info) { +} + +template +Status BiasAdd::ComputeInternal(OpKernelContext* context) const { + // Input: [batch_size, height*width, channels] + // Bias: [channels] + // Skip: [batch_size, height*width, channels] + // Output: [batch_size, height*width, channels] + + const Tensor* input = context->Input(0); + + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The input is expected to have 3 dimensions, got ", input_dims.size()); + } + + if (input_dims[2] != 320 && input_dims[2] != 640 && input_dims[2] != 1280) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels should be 320, 640 or 1280, got ", input_dims[2]); + } + + const Tensor* bias = context->Input(1); + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The bias is expected to have 1 dimensions, got ", bias_dims.size()); + } + if (bias_dims[0] != input_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in the last dimension of input and bias are not the same"); + } + + const Tensor* skip = context->Input(2); + if (skip->Shape() != input->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Shape of input and skip (residual) shall be the same"); + } + + Tensor* output = context->Output(0, input->Shape()); + + typedef typename ToCudaType::MappedType CudaT; + const int32_t grid_size = static_cast(input_dims[0] * input_dims[1]); + LaunchBiasAddKernel(Stream(context), grid_size, static_cast(input_dims[2]), + reinterpret_cast(input->Data()), + reinterpret_cast(bias->Data()), + reinterpret_cast(skip->Data()), + 
reinterpret_cast(output->MutableData())); + + CUDA_RETURN_IF_ERROR(cudaPeekAtLastError()); + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.h new file mode 100644 index 0000000000000..6f4904f4c8de9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class BiasAdd final : public CudaKernel { + public: + BiasAdd(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu new file mode 100644 index 0000000000000..2983cc99e30b1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// The CUDA kernel is modified from SeqLen2Spatial plugin of TensorRT 8.5. +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/diffusion/bias_add_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void BiasAddKernel(T const* input, T const* bias, T const* residual, T* output) { + int32_t base_offset = blockIdx.x * C + threadIdx.x; + int32_t bias_offset = threadIdx.x; + +#pragma unroll + for (int32_t i = 0; i < C / TPB; ++i) { + output[base_offset] = input[base_offset] + bias[bias_offset] + residual[base_offset]; + base_offset += TPB; + bias_offset += TPB; + } +} + +template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); +template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); +template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); +template __global__ void BiasAddKernel(half const*, half const*, half const*, half*); +template __global__ void BiasAddKernel(half const*, half const*, half const*, half*); +template __global__ void BiasAddKernel(half const*, half const*, half const*, half*); + +template +void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels, + T const* input, T const* bias, T const* residual, T* output) { + constexpr int32_t TPB = 320; // thread per block + switch (num_channels) { + case 320: + (BiasAddKernel)<<>>(input, bias, residual, output); + break; + case 640: + (BiasAddKernel)<<>>(input, bias, residual, output); + break; + case 1280: + (BiasAddKernel)<<>>(input, bias, residual, output); + break; + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +template void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels, + float const* input, float const* bias, float const* residual, float* output); + +template void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels, + half const* input, half const* bias, half const* residual, half* output); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.h new file mode 100644 index 0000000000000..d3397ea035959 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/common/status.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels, + T const* input, T const* bias, T const* residual, T* output); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc new file mode 100644 index 0000000000000..2b13cdbd803ef --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
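As a plain-C++ cross-check of what the BiasAdd contrib op above computes, the sketch below applies output = input + bias + skip over a [batch, height*width, channels] layout, with the bias broadcast along the channel dimension; the CUDA kernel performs the same per-element work with one block per (batch, spatial) position. The sizes and values here are made up for illustration (the real op restricts the channel count to 320, 640 or 1280).

#include <cstdio>
#include <vector>

// Reference BiasAdd: x and skip are [N, HW, C]; bias is [C]; out = x + bias[c] + skip.
void BiasAddRef(const std::vector<float>& x, const std::vector<float>& bias,
                const std::vector<float>& skip, std::vector<float>& out,
                int n, int hw, int c) {
  for (int i = 0; i < n * hw; ++i) {
    for (int j = 0; j < c; ++j) {
      out[i * c + j] = x[i * c + j] + bias[j] + skip[i * c + j];
    }
  }
}

int main() {
  const int n = 1, hw = 2, c = 4;  // illustrative sizes only
  std::vector<float> x(n * hw * c, 1.0f), skip(n * hw * c, 0.5f), out(n * hw * c);
  std::vector<float> bias = {0.f, 1.f, 2.f, 3.f};
  BiasAddRef(x, bias, skip, out, n, hw, c);
  for (float v : out) printf("%g ", v);  // 1.5 2.5 3.5 4.5  1.5 2.5 3.5 4.5
  printf("\n");
  return 0;
}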
+ +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/diffusion/bias_split_gelu.h" +#include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + BiasSplitGelu, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + BiasSplitGelu); + +REGISTER_KERNEL_TYPED(MLFloat16); +REGISTER_KERNEL_TYPED(float); + +using namespace ONNX_NAMESPACE; + +template +BiasSplitGelu::BiasSplitGelu(const OpKernelInfo& op_info) : CudaKernel(op_info) { +} + +template +Status BiasSplitGelu::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 dimensions, got ", input_dims.size()); + } + + if (input_dims[2] != 2560 && input_dims[2] != 5120 && input_dims[2] != 10240) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "hidden size should be 2560, 5120 or 10240, got ", input_dims[2]); + } + + const Tensor* bias = context->Input(1); + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimensions, got ", bias_dims.size()); + } + if (bias_dims[0] != input_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "last dimension of input and bias are not the same"); + } + + TensorShapeVector output_shape = input->Shape().AsShapeVector(); + output_shape[2] = input_dims[2] / 2; + Tensor* output = context->Output(0, output_shape); + + typedef typename ToCudaType::MappedType CudaT; + const int32_t grid_size = static_cast(input_dims[0] * input_dims[1]); + const int32_t half_hidden_size = static_cast(input_dims[2] / 2); + LaunchBiasSplitGeluKernel(Stream(context), grid_size, half_hidden_size, + reinterpret_cast(input->Data()), + reinterpret_cast(bias->Data()), + reinterpret_cast(output->MutableData())); + + CUDA_RETURN_IF_ERROR(cudaPeekAtLastError()); + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.h new file mode 100644 index 0000000000000..feec45600bbce --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class BiasSplitGelu final : public CudaKernel { + public: + BiasSplitGelu(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu new file mode 100644 index 0000000000000..8069cbc0a1e0e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
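The BiasSplitGelu op above fuses three steps that would otherwise be separate graph nodes: add a bias to a [batch, height*width, 2*half_hidden] tensor, split it in half along the last axis, and multiply the left half by GELU of the right half, where GELU(x) = x * 0.5 * (erf(x / sqrt(2)) + 1). Below is a small CPU sketch of that math, using made-up sizes rather than the 2560/5120/10240 hidden sizes the kernel dispatches on.

#include <cmath>
#include <cstdio>
#include <vector>

// Reference BiasSplitGelu: input is [rows, 2*h], bias is [2*h], output is [rows, h].
// out[r][i] = (in[r][i] + bias[i]) * gelu(in[r][i + h] + bias[i + h])
void BiasSplitGeluRef(const std::vector<float>& in, const std::vector<float>& bias,
                      std::vector<float>& out, int rows, int h) {
  for (int r = 0; r < rows; ++r) {
    for (int i = 0; i < h; ++i) {
      float left = in[r * 2 * h + i] + bias[i];
      float right = in[r * 2 * h + h + i] + bias[h + i];
      float gelu_right = right * 0.5f * (std::erf(right / std::sqrt(2.0f)) + 1.0f);
      out[r * h + i] = left * gelu_right;
    }
  }
}

int main() {
  const int rows = 1, h = 3;
  std::vector<float> in = {1.f, 2.f, 3.f, /* right half */ -1.f, 0.f, 1.f};
  std::vector<float> bias(2 * h, 0.0f);
  std::vector<float> out(rows * h);
  BiasSplitGeluRef(in, bias, out, rows, h);
  for (float v : out) printf("%g ", v);  // 1*gelu(-1), 2*gelu(0) = 0, 3*gelu(1)
  printf("\n");
  return 0;
}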
+// Licensed under the MIT License. + +// The CUDA kernel is modified from SplitGelu plugin of TensorRT 8.5. +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void biasSplitGeluKernel(T const* input, T const* bias, T* output) { + int32_t index_input = blockIdx.x * HHS * 2 + threadIdx.x; + int32_t index_output = blockIdx.x * HHS + threadIdx.x; + int32_t index_bias = threadIdx.x; + +#pragma unroll + for (int32_t i = 0; i < HHS / TPB; ++i) { + auto value_left = (float)(input[index_input] + bias[index_bias]); + auto value_right = (float)(input[index_input + HHS] + bias[index_bias + HHS]); + + // Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1.0) + float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f); + float result = value_left * gelu_right; + output[index_output] = static_cast(result); + index_input += TPB; + index_output += TPB; + index_bias += TPB; + } + return; +} + +template +void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, + T const* input, T const* bias, T* output) { + constexpr int32_t TPB = 256; // thread per block + switch (half_hidden_size) { + case 1280: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; + case 2560: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; + case 5120: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); + +template void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, + float const* input, float const* bias, float* output); + +template void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, + half const* input, half const* bias, half* output); +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h new file mode 100644 index 0000000000000..a04201bd12e3c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h @@ -0,0 
+1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/common/status.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, + T const* input, T const* bias, T* output); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc new file mode 100644 index 0000000000000..36a2bd11257d6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/diffusion/group_norm.h" +#include "contrib_ops/cuda/diffusion/group_norm_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define GROUP_NORM_TYPES float, MLFloat16 + +ONNX_OPERATOR_KERNEL_EX( + GroupNorm, kMSDomain, 1, kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); + +using namespace ONNX_NAMESPACE; + +namespace { +template +struct DispatchGroupNorm { + Status operator()(cudaStream_t stream, + Tensor* output, + const Tensor* input, + const Tensor* gamma, + const Tensor* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_swish_activation) { + typedef typename ToCudaType::MappedType CudaT; + return LaunchGroupNormKernel( + stream, + reinterpret_cast(output->MutableData()), + reinterpret_cast(input->Data()), + gamma->Data(), + beta->Data(), + workspace, + epsilon, + batch_size, + num_channels, + height, + width, + num_groups, + use_swish_activation); + } +}; + +} // namespace + +GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { + epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); + ORT_ENFORCE(epsilon_ >= 0); + + int64_t num_groups; + ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK()); + ORT_ENFORCE(num_groups >= 0); + num_groups_ = static_cast(num_groups); + + int64_t activation; + ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK()); + ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish + use_swish_activation_ = (activation == 1); +} + +Status GroupNorm::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* gamma = context->Input(1); + const Tensor* beta = context->Input(2); + Tensor* output = context->Output(0, input->Shape()); + + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 4 dimensions, got ", input_dims.size()); + } + + const auto& gamma_dims = gamma->Shape().GetDims(); + if (gamma_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "gamma is expected to have 1 dimension, got ", gamma_dims.size()); + } + if (gamma_dims[0] != input_dims[3]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in gamma and input does not match"); + } + + const auto& beta_dims = beta->Shape().GetDims(); + if (beta_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "beta is expected to 
have 1 dimension, got ", beta_dims.size()); + } + if (beta_dims[0] != input_dims[3]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in beta and input does not match"); + } + + // Input and output format is NHWC + int batch_size = static_cast(input_dims[0]); + int num_channels = static_cast(input_dims[3]); + int height = static_cast(input_dims[1]); + int width = static_cast(input_dims[2]); + + if (num_channels % num_groups_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "number of channels should be divisiable by num_groups"); + } + + auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); + + utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); + return dispatcher.InvokeRet(Stream(context), output, input, gamma, beta, workspace.get(), + epsilon_, + batch_size, + num_channels, + height, + width, + num_groups_, + use_swish_activation_); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h new file mode 100644 index 0000000000000..8578a1642198f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +class GroupNorm final : public CudaKernel { + public: + GroupNorm(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + bool use_swish_activation_; + float epsilon_; + int num_groups_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu new file mode 100644 index 0000000000000..01ba078b4be77 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -0,0 +1,475 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +#include +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/diffusion/group_norm_impl.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +static inline int32_t divUp(int32_t m, int32_t n) { + return (m + n - 1) / n; +} + +static inline __device__ __host__ float sigmoid(float x) { + return 1.F / (1.F + expf(-x)); +} + +struct GroupSums { + // Is it the 1st element of the group? + int32_t flag; + // The sum. 
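GroupSums carries only a running sum and sum of squares because those two quantities are enough to recover the group statistics later: mean = sum / count and variance = sumSq / count - mean^2 (the E[x^2] - E[x]^2 identity), with invHWC playing the role of 1/count in the kernels that follow. A scalar illustration with made-up values:

#include <cstdio>
#include <vector>

int main() {
  // One-pass accumulation of sum and sum of squares, as the GroupNorm sum kernel does per group.
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
  float sum = 0.f, sum_sq = 0.f;
  for (float v : x) { sum += v; sum_sq += v * v; }

  // Recover mean and variance afterwards (inv_count corresponds to params.invHWC).
  float inv_count = 1.0f / static_cast<float>(x.size());
  float mean = sum * inv_count;
  float var = sum_sq * inv_count - mean * mean;  // E[x^2] - E[x]^2
  printf("mean=%g var=%g\n", mean, var);         // mean=2.5 var=1.25
  return 0;
}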
+ float sum; + // The sum of squares. + float sumSq; +}; + +struct GroupSumsOp { + inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { + GroupSums dst; + dst.sum = b.flag ? b.sum : (a.sum + b.sum); + dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); + dst.flag = a.flag + b.flag; + return dst; + } +}; + +template +struct GroupNormNHWCParams { + // The output buffer. Layout NHWC. + T* dst; + // The input buffer. Layout NHWC. + T const* src; + // The gamma scaling factor. + float const* gamma; + // The beta term to add in GN. + float const* beta; + // The temporary buffer to do the global parallel reduction. Size: + // BLOCKS_PER_BATCH x C x 2. + float* redBuffer; + + // The number of instances in the batch. + int32_t n; + // The height and width of each activation map. + int32_t h; + int32_t w; + // The number of channels. + int32_t c; + // The number of groups. + int32_t groups; + // Do we apply the Swish activation function? + bool withSwish; + + // Precomputed values and parameters to control the execution of the kernels. + + // The number of activations per instance (h * w) and the number of + // activations per block. + int32_t hw; + int32_t hwPerBlock; + // The number of channels per group and blocks per activation in the C + // dimension. + int32_t cPerBlock; + int32_t cPerGroup; + + // The precomputed stride between instances. + int32_t hwc; + // The inverse of hwc in floats (to compute mean/var). + float invHWC; + // The precomputed number of groups per block. + int32_t groupsPerBlock; +}; + +template +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq); + +template <> +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + float2 f2 = __half22float2(h2); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sumSq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sumSq += f2.x * f2.x + f2.y * f2.y; +} + +template +__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { + // The object in charge of doing the sums for the different blocks. + typedef cub::BlockScan BlockScan; + + // Allocate shared memory for BlockScan. + __shared__ typename BlockScan::TempStorage tempStorage; + // Allocate shared memory for the groups. We could reduce the amount of shared + // memory reserved. + __shared__ float2 smem[tTHREADS_PER_BLOCK]; + + // The instance in the batch. + int32_t ni = blockIdx.z; + // The channel loaded by that thread (2 channels per thread for F16x2). + int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + + // The first activation loaded by that block. + int32_t hwBegin = blockIdx.y * params.hwPerBlock; + // The last activation loaded by that block. + int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + + // The sums. + float sum = 0.F; + float sumSq = 0.F; + + // Iterate over the activations to compute the sums. + if (ci < params.c) { + for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + // The offset. 
+ int64_t offset = static_cast(ni) * params.hwc + static_cast(hwi) * params.c + ci; + UpdateSum(params.src, offset, sum, sumSq); + } + } + + // The group that thread works on and the channel in the group (modulus). + int32_t gi = threadIdx.x * 2 / params.cPerGroup; + int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi; + + // The data for the summations. + GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; + + // Do the segmented scan. + GroupSums out; + BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); + + // Store the results for the groups in shared memory (to produce coalesced + // stores later). + if (cj == params.cPerGroup - 2) { //2 channels per thread + smem[gi] = make_float2(out.sum, out.sumSq); + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The global group index. + int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x; + + // Threads that have nothing left to do, exit. + if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) { + return; + } + + // The first threads (those storing to global memory, load the values). + float2 sums = smem[threadIdx.x]; + + // Store to global memory. + atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x); + atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); +} + +template +void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { + // Make sure the values are as we expect. + ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0); + // Make sure a group does not span multiple blocks. + ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); + + dim3 grid; + + // The number of blocks to compute all the channels. + grid.x = params.c / params.cPerBlock; + // The number of blocks to compute all the activations in a given instance. + grid.y = divUp(params.hw, params.hwPerBlock); + // The number of instances. + grid.z = params.n; + + switch (params.cPerBlock) { + case 320: + groupNormNHWCSumKernel<<>>(params); + break; + case 480: + groupNormNHWCSumKernel<<>>(params); + break; + case 256: + groupNormNHWCSumKernel<<>>(params); + break; + case 128: + groupNormNHWCSumKernel<<>>(params); + break; + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +template +__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish); + +template <> +__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev, + float2& gammaF2, float2& betaF2, bool swish) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + // Extract the two half values. + float2 f2 = __half22float2(h2); + + // Normalize the channels. + f2.x = (f2.x - mean) * invStdDev; + f2.y = (f2.y - mean) * invStdDev; + + // Scale by gamma and add beta. + f2.x = gammaF2.x * f2.x + betaF2.x; + f2.y = gammaF2.y * f2.y + betaF2.y; + + // Apply Swish if needed. + if (swish) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); +} + +template <> +__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev, + float2& gammaF2, float2& betaF2, bool swish) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Normalize the channels. + f2.x = (f2.x - mean) * invStdDev; + f2.y = (f2.y - mean) * invStdDev; + + // Scale by gamma and add beta. 
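Stripped of the half2/float2 vectorization, the per-channel update that computeGroupNorm applies is an affine normalization followed by an optional Swish (x * sigmoid(x)) activation. A scalar version with illustrative inputs (the mean and inverse standard deviation would come from the group statistics):

#include <cmath>
#include <cstdio>

// y = gamma * (x - mean) * inv_std_dev + beta, optionally followed by Swish: y * sigmoid(y).
float GroupNormElement(float x, float mean, float inv_std_dev,
                       float gamma, float beta, bool with_swish) {
  float y = gamma * (x - mean) * inv_std_dev + beta;
  if (with_swish) {
    y = y * (1.0f / (1.0f + std::exp(-y)));  // Swish / SiLU
  }
  return y;
}

int main() {
  float y = GroupNormElement(/*x=*/3.0f, /*mean=*/2.5f, /*inv_std_dev=*/0.894f,
                             /*gamma=*/1.0f, /*beta=*/0.0f, /*with_swish=*/true);
  printf("%g\n", y);
  return 0;
}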
+ f2.x = gammaF2.x * f2.x + betaF2.x; + f2.y = gammaF2.y * f2.y + betaF2.y; + + // Apply Swish if needed. + if (swish) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast(&dst[offset]) = f2; +} + +template +__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { + // The channel loaded by that thread (2 channels per thread for F16x2). + int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + if (ci >= params.c) { + return; + } + + // The instance in the batch. + int32_t ni = blockIdx.z; + + // The group that thread works on and the channel in the group (modulus). + int32_t gi = ci / params.cPerGroup; + + // Load the sum and sum of squares for the group. + float sum = 0.F, sumSq = 0.F; + if (gi < params.groups) { + sum = params.redBuffer[(2 * ni + 0) * params.groups + gi]; + sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; + } + + // Load gamma/beta. + float2 gammaF2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 betaF2 = *reinterpret_cast(¶ms.beta[ci]); + + // Compute the mean. + float mean = sum * params.invHWC; + // Compute the variance. + float var = sumSq * params.invHWC - (mean * mean); + // Compute the inverse of the stddev. + float invStdDev = var <= 0.F ? 1.F : rsqrtf(var); + + // The first activation loaded by that block. + int32_t hwBegin = blockIdx.y * params.hwPerBlock; + // The last activation loaded by that block. + int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + + // Iterate over the activations to compute the sums. + for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + // The src/dst offset. + int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; + + // Fetch two channels per thread. + computeGroupNorm(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish); + } +} + +template +void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { + // Make sure the dimensions are aligned with what we expect. + ORT_ENFORCE(params.c % params.cPerBlock == 0); + // Make sure a group does not span multiple blocks. + ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); + + dim3 grid; + + // The number of blocks to compute all the channels. + grid.x = params.c / params.cPerBlock; + // The number of blocks to compute all the activations in a given instance. + grid.y = divUp(params.hw, params.hwPerBlock); + // The number of instances. 
+ grid.z = params.n; + + switch (params.cPerBlock) { + case 320: + groupNormNHWCScaleKernel<<>>(params); + break; + case 480: + groupNormNHWCScaleKernel<<>>(params); + break; + case 256: + groupNormNHWCScaleKernel<<>>(params); + break; + case 128: + groupNormNHWCScaleKernel<<>>(params); + break; + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { + int32_t maxDivisor = -1; + for (int32_t i = 1; i <= std::sqrt(n); i++) { + if (n % i == 0) { + int32_t divisor1 = n / i; + int32_t divisor2 = i; + + if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { + maxDivisor = divisor1; + } + if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { + maxDivisor = divisor2; + } + } + } + return maxDivisor; +} + +template +Status LaunchGroupNormKernel( + cudaStream_t stream, + T* output, + const T* input, + const float* gamma, + const float* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_swish_activation) { + if (batch_size > static_cast(kMaxGroupNormBatchSize)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "only support batch_size <= 32. Got", batch_size); + } + + if (num_groups != static_cast(kGroupNormNumberOfGroups)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "only num_groups=32 is supported. Got", num_groups); + } + + GroupNormNHWCParams params; + int32_t cPerBlock = 320; + int32_t maxBlocksPerHW = 1024; + switch (num_channels) { + case 960: + case 1920: + cPerBlock = 480; + break; + case 512: + case 256: + cPerBlock = 256; + break; + case 128: + cPerBlock = 128; + break; + default: + cPerBlock = 320; + } + + params.withSwish = use_swish_activation; + params.dst = output; + params.src = input; + params.gamma = gamma; + params.beta = beta; + params.redBuffer = reinterpret_cast(workspace); + params.n = batch_size; + params.h = height; + params.w = width; + params.c = num_channels; + params.groups = num_groups; + params.hw = params.h * params.w; + const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW); + params.hwPerBlock = divUp(params.hw, blocksPerHW); + params.cPerBlock = cPerBlock; + params.cPerGroup = params.c / params.groups; + params.hwc = params.hw * params.c; + params.invHWC = 1.F / (float)(params.hw * params.cPerGroup); + params.groupsPerBlock = cPerBlock / params.cPerGroup; + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("input", input, batch_size, num_channels, height * width); + DUMP_TENSOR("gamma", gamma, 1, num_channels); + DUMP_TENSOR("beta", beta, 1, num_channels); + cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream); + groupNormNHWCSum(params, stream); + DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + groupNormNHWCScale(params, stream); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + DUMP_TENSOR("output", output, batch_size, num_channels, height * width); + return Status::OK(); +} + +template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, + const half* input, const float* gamma, const float* beta, void* workspace, + float epsilon, int batch_size, int num_channels, + int height, int width, int num_groups, bool swish); + +template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, + const float* input, const float* gamma, const float* beta, void* workspace, + float epsilon, int batch_size, int num_channels, + int height, int width, int 
num_groups, bool swish); +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h new file mode 100644 index 0000000000000..c7e9245050ee6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/common/status.h" +#include +#include +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +constexpr size_t kMaxGroupNormBatchSize = 32; +constexpr size_t kGroupNormNumberOfGroups = 32; + +constexpr size_t GetGroupNormWorkspaceSizeInBytes() { + // Two buffers for sum and squared sum + return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; +} + +template +Status LaunchGroupNormKernel( + cudaStream_t stream, + T* output, // normalized output tensor + const T* input, // input tensor + const float* gamma, // gamma (also known as weight or scale) + const float* beta, // beta (also known as bias) + void* workspace, // Work space + float epsilon, // epsilon used normalization + int batch_size, // N + int num_channels, // C + int height, // H + int width, // W + int num_groups, // number of groups + bool use_swish_activation // Whether there is Swish activation after group normalization +); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc b/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc new file mode 100644 index 0000000000000..79f0a18ba515f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
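Putting the two CUDA passes together, the GroupNorm contrib op is essentially the following NHWC reference: for each (batch, group), compute mean and variance over all H*W positions and the channels in that group, then normalize, apply the per-channel gamma/beta, and optionally Swish. This CPU sketch uses small made-up sizes and adds epsilon to the variance in the usual way; the real op additionally requires NHWC layout, num_groups equal to 32, batch size at most 32, and fp16/fp32 inputs.

#include <cmath>
#include <cstdio>
#include <vector>

// Reference GroupNorm over NHWC data: x is [N, H, W, C]; gamma and beta are [C].
void GroupNormNhwcRef(const std::vector<float>& x, const std::vector<float>& gamma,
                      const std::vector<float>& beta, std::vector<float>& y,
                      int n, int h, int w, int c, int groups, float epsilon, bool swish) {
  const int c_per_group = c / groups;
  for (int ni = 0; ni < n; ++ni) {
    for (int g = 0; g < groups; ++g) {
      // First pass: statistics over H*W positions and the channels of this group.
      double sum = 0.0, sum_sq = 0.0;
      for (int hw = 0; hw < h * w; ++hw) {
        for (int cj = g * c_per_group; cj < (g + 1) * c_per_group; ++cj) {
          float v = x[(ni * h * w + hw) * c + cj];
          sum += v;
          sum_sq += static_cast<double>(v) * v;
        }
      }
      double count = static_cast<double>(h * w * c_per_group);
      float mean = static_cast<float>(sum / count);
      float var = static_cast<float>(sum_sq / count - mean * mean);
      float inv_std = 1.0f / std::sqrt(var + epsilon);
      // Second pass: normalize, scale/shift per channel, optional Swish.
      for (int hw = 0; hw < h * w; ++hw) {
        for (int cj = g * c_per_group; cj < (g + 1) * c_per_group; ++cj) {
          int idx = (ni * h * w + hw) * c + cj;
          float v = gamma[cj] * (x[idx] - mean) * inv_std + beta[cj];
          if (swish) v = v * (1.0f / (1.0f + std::exp(-v)));
          y[idx] = v;
        }
      }
    }
  }
}

int main() {
  const int n = 1, h = 2, w = 2, c = 4, groups = 2;
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  std::vector<float> gamma(c, 1.0f), beta(c, 0.0f), y(x.size());
  GroupNormNhwcRef(x, gamma, beta, y, n, h, w, c, groups, 1e-5f, /*swish=*/false);
  for (float v : y) printf("%g ", v);
  printf("\n");
  return 0;
}

The two-kernel GPU version exists so that the per-group sums can be accumulated in a small workspace (two floats per batch-group pair) by many blocks in parallel before the scale pass reads them back.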
+ +#include "core/common/span_utils.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/nn/conv.h" + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + NhwcConv, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Conv); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/fused_conv.cc b/onnxruntime/contrib_ops/cuda/fused_conv.cc index 39c3bb282d912..48881ddca4063 100644 --- a/onnxruntime/contrib_ops/cuda/fused_conv.cc +++ b/onnxruntime/contrib_ops/cuda/fused_conv.cc @@ -9,10 +9,10 @@ namespace contrib { namespace cuda { template -class FusedConv : public onnxruntime::cuda::Conv { +class FusedConv : public onnxruntime::cuda::Conv { public: - using Base = onnxruntime::cuda::Conv; - FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::Conv(info) { + using Base = onnxruntime::cuda::Conv; + FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::Conv(info) { std::string activation; if (info.GetAttr("activation", &activation) == Status::OK() && MapMode(activation) == Status::OK() && diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index e5ea47a6a2a5b..90ec1a35ac63a 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -52,7 +52,7 @@ Status QAttention::CheckInputs(const Tensor* input, auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past_tensor, - nullptr, // extra_add_qk + nullptr, // relative_position_bias parameters, device_prop.maxThreadsPerBlock)); @@ -174,6 +174,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present = context->Output(1, present_shape); void* fused_runner = nullptr; // TODO(tianleiwu): use fused kernel to speed up + bool use_fused_cross_attention = false; bool use_memory_efficient_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, @@ -184,6 +185,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { parameters.kv_sequence_length, parameters.total_sequence_length, fused_runner, + use_fused_cross_attention, use_memory_efficient_attention); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -198,13 +200,16 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); data.past = (nullptr == past_tensor) ? nullptr : reinterpret_cast(past_tensor->Data()); - data.extra_add_qk = nullptr; // add_qk is not supported in quantized attention + data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention + data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = (nullptr == present) ? 
nullptr : reinterpret_cast(present->MutableData()); data.fused_runner = fused_runner; data.fused_cross_attention_kernel = nullptr; data.use_memory_efficient_attention = use_memory_efficient_attention; + data.cumulated_sequence_length_q_cache = nullptr; + data.cumulated_sequence_length_kv_cache = nullptr; return QkvToContext(GetDeviceProp(), cublas, Stream(context), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 204c786cc2c5d..8122b2de5916b 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -212,7 +212,7 @@ Status QOrderedAttention::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), merged_weights_shape, merged_bias_shape, mask_index, nullptr, // past - nullptr, // extra_add_qk + nullptr, // relative_position_bias nullptr, // parameters device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h index 5fe62ef127800..5fb31be5fe86f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h @@ -17,4 +17,4 @@ DefineQOrderedAttentionInput(Scale_QK_Softmax, scale_QKT_softmax, 15), DefineQOrderedAttentionInput(Scale_Values_Gemm, scale_values_gemm, 16), DefineQOrderedAttentionInput(Mask_Index, mask_index, 17), DefineQOrderedAttentionInput(Past, past, 18), -DefineQOrderedAttentionInput(Extra_Add, extra_add, 19) +DefineQOrderedAttentionInput(relative_position_bias, relative_position_bias, 19) diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu index 5c54c03a05d1a..dcbc733f2acb2 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu @@ -291,7 +291,7 @@ void LaunchBatchTopKKernel(const T* topk_scores, int32_t num_beams, int32_t k, cudaStream_t stream) { - ORT_ENFORCE(k <= 256, "LaunchBatchTopKKernel doesn't support k >= 256"); + ORT_ENFORCE(k <= 64, "LaunchBatchTopKKernel doesn't support k >= 64"); #define BatchTopKKernelLauncher(K) \ BatchTopKKernel<<>>(topk_scores, \ @@ -311,12 +311,8 @@ void LaunchBatchTopKKernel(const T* topk_scores, BatchTopKKernelLauncher(16); } else if (k <= 32) { BatchTopKKernelLauncher(32); - } else if (k <= 64) { - BatchTopKKernelLauncher(64); - } else if (k <= 128) { - BatchTopKKernelLauncher(128); } else { - BatchTopKKernelLauncher(256); + BatchTopKKernelLauncher(64); } } @@ -330,36 +326,6 @@ template void LaunchBatchTopKKernel(const float* topk_scores, int32_t k, cudaStream_t stream); -template void LaunchBatchTopKKernel(const float* topk_scores, - const int64_t* topk_tokens, - int32_t* next_indices, - int32_t* next_tokens, - float* next_scores, - int32_t batch_size, - int32_t num_beams, - int32_t k, - cudaStream_t stream); - -template void LaunchBatchTopKKernel(const half* topk_scores, - const int32_t* topk_tokens, - int32_t* next_indices, - int32_t* next_tokens, - half* next_scores, - int32_t batch_size, - int32_t num_beams, - int32_t k, - cudaStream_t stream); - -template void 
LaunchBatchTopKKernel(const half* topk_scores, - const int64_t* topk_tokens, - int32_t* next_indices, - int32_t* next_tokens, - half* next_scores, - int32_t batch_size, - int32_t num_beams, - int32_t k, - cudaStream_t stream); - template void BeamSearchTopK( const T* input, @@ -426,21 +392,6 @@ template void BeamSearchTopK( int32_t* output_indices, cudaStream_t stream); -template void BeamSearchTopK( - const half* input, - int32_t batch_size, - int32_t num_beams, - int32_t vocab_size, - int32_t k, - half* tmp_values_1st_stage, - int32_t* tmp_indices_1st_stage, - half* tmp_values_2st_stage, - int32_t* tmp_indices_2st_stage, - half* output_values, - int32_t* output_tokens, - int32_t* output_indices, - cudaStream_t stream); - } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h index 5e338b417e8a5..096448c002e36 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h @@ -11,18 +11,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template -void LaunchBatchTopKKernel( - const T* topk_scores, - const I* topk_indices, - int32_t* next_indices, - int32_t* next_tokens, - T* next_scores, - int32_t batch_size, - int32_t num_beams, - int32_t k, - cudaStream_t stream); - template void BeamSearchTopK( const T* input, diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc index 6c0f7f69c58a1..3046a58040635 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc @@ -11,18 +11,18 @@ namespace contrib { namespace cuda { namespace transformers { -#ifdef DEBUG_GENERATION +#if DUMP_TENSOR_LEVEL > 0 template class PinnedHostBuffer { public: PinnedHostBuffer(size_t length) : buffer_(nullptr) { - cudaHostAlloc(&buffer_, length * sizeof(T), cudaHostAllocDefault); + CUDA_CALL_THROW(cudaHostAlloc((void**)&buffer_, length * sizeof(T), cudaHostAllocDefault)); } virtual ~PinnedHostBuffer() { if (buffer_) { - cudaFreeHost(buffer_); + CUDA_CALL_THROW(cudaFreeHost(buffer_)); } } @@ -46,8 +46,9 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, bool i // In that case, we copy tensor data as well. It is not needed, but it keeps code simple. int num_items = dim0 * dim1; auto data = std::make_shared>(num_items); - cudaDeviceSynchronize(); - cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost); + CUDA_CALL_THROW(cudaDeviceSynchronize()); + CUDA_CALL_THROW(cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost)); + if (nullptr != name) { std::cout << std::string(name) << std::endl; @@ -64,8 +65,8 @@ template void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, bool is_gpu_tensor) { int num_items = dim0 * dim1 * dim2; auto data = std::make_shared>(num_items); - cudaDeviceSynchronize(); - cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost); + CUDA_CALL_THROW(cudaDeviceSynchronize()); + CUDA_CALL_THROW(cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? 
cudaMemcpyDeviceToHost : cudaMemcpyHostToHost)); if (nullptr != name) { std::cout << std::string(name) << std::endl; @@ -82,8 +83,8 @@ template void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, int dim3, bool is_gpu_tensor) { int num_items = dim0 * dim1 * dim2 * dim3; auto data = std::make_shared>(num_items); - cudaDeviceSynchronize(); - cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost); + CUDA_CALL_THROW(cudaDeviceSynchronize()); + CUDA_CALL_THROW(cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost)); if (nullptr != name) { std::cout << std::string(name) << std::endl; diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 523603a550be9..90c91228204b6 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -320,33 +320,33 @@ void GetTempStorageSize(const T* d_keys_in, bool is_descending, size_t& temp_storage_bytes) { if (is_descending) { - cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, - temp_storage_bytes, - d_keys_in, - (T*)nullptr, - d_values_in, - (int*)nullptr, - num_items, - num_segments, - d_offsets, - d_offsets + 1, - 0, - sizeof(T) * 8, - stream); + CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, + temp_storage_bytes, + d_keys_in, + (T*)nullptr, + d_values_in, + (int*)nullptr, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream)); } else { - cub::DeviceSegmentedRadixSort::SortPairs(nullptr, - temp_storage_bytes, - d_keys_in, - (T*)nullptr, - d_values_in, - (int*)nullptr, - num_items, - num_segments, - d_offsets, - d_offsets + 1, - 0, - sizeof(T) * 8, - stream); + CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairs(nullptr, + temp_storage_bytes, + d_keys_in, + (T*)nullptr, + d_values_in, + (int*)nullptr, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream)); } } @@ -412,33 +412,33 @@ void LaunchSortPairs(void* d_temp_storage, cudaStream_t stream, bool is_descending) { if (is_descending) { - cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, - temp_storage_bytes, - d_keys_in, - d_keys_out, - d_values_in, - d_values_out, - num_items, - num_segments, - d_offsets, - d_offsets + 1, - 0, - sizeof(T) * 8, - stream); + CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream)); } else { - cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, - temp_storage_bytes, - d_keys_in, - d_keys_out, - d_values_in, - d_values_out, - num_items, - num_segments, - d_offsets, - d_offsets + 1, - 0, - sizeof(T) * 8, - stream); + CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream)); } } @@ -721,9 +721,9 @@ void TorchMultinomialKernelLauncher(float* d_input, cudaStream_t stream) { // Store the props in class variables int device; - cudaGetDevice(&device); + CUDA_CALL_THROW(cudaGetDevice(&device)); cudaDeviceProp props; - 
cudaGetDeviceProperties(&props, device); + CUDA_CALL_THROW(cudaGetDeviceProperties(&props, device)); int numSM = props.multiProcessorCount; int maxThreads = props.maxThreadsPerBlock; diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 1a5a9ac5d97b2..e4846e86d1eb2 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -17,14 +17,23 @@ #include "contrib_ops/cpu/transformers/subgraph_gpt.h" #include "contrib_ops/cuda/transformers/beam_search_topk.h" #include "contrib_ops/cuda/transformers/greedy_search_top_one.h" + +// the includes would be dummy for ROCm, we will ignore them for now +#ifdef ENABLE_NVTX_PROFILE #include "core/providers/cuda/nvtx_profile.h" #include "core/providers/cuda/nvtx_profile_context.h" +#endif + #include "sampling_cuda_helper.h" #ifdef DEBUG_GENERATION #include #endif +using onnxruntime::cuda::ToCudaType; +using onnxruntime::cuda::TArray; +using onnxruntime::cuda::TopKImpl; + namespace onnxruntime { namespace concurrency { class ThreadPool; @@ -131,7 +140,7 @@ Status AddToFeeds(const IExecutionProvider* execution_provider, ORT_ENFORCE(total_bytes > 0); - AllocatorPtr pinned_allocator = provider->GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPU); + AllocatorPtr pinned_allocator = provider->GetAllocator(OrtMemTypeCPU); cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; auto pinned_buffer = IAllocator::MakeUniquePtr(pinned_allocator, total_bytes); char* pinned_data = static_cast(pinned_buffer.get()); @@ -168,7 +177,7 @@ Status AddToFeeds(const IExecutionProvider* execution_provider, CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, stream)); CUDA_RETURN_IF_ERROR(cudaEventSynchronize(isCopyDone)); // TODO(tianleiwu): allocate a buffer for subgraph inputs so that we can reuse the buffer in each subgraph call. - const OrtMemoryInfo& location = provider->GetAllocator(0, OrtMemTypeDefault)->Info(); + const OrtMemoryInfo& location = provider->GetAllocator(OrtMemTypeDefault)->Info(); for (auto& input : inputs) { if (input.IsAllocated()) { const Tensor& tensor = input.Get(); @@ -203,12 +212,13 @@ void InitBeamState(transformers::IBeamSearchState* beam_state, // TODO(tianleiwu): we can use another stream to avoid blocking subgraph execution. cudaStream_t cuda_stream = ort_stream ? 
static_cast(ort_stream->GetHandle()) : nullptr; - cudaMemsetAsync(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes(), cuda_stream); - cudaMemsetAsync(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes(), cuda_stream); - cudaMemsetAsync(beam_state->next_tokens.data(), 0, beam_state->next_tokens.size_bytes(), cuda_stream); - cudaMemsetAsync(beam_state->next_indices.data(), 0, beam_state->next_indices.size_bytes(), cuda_stream); - cudaMemsetAsync(beam_state->next_scores.data(), 0, beam_state->next_scores.size_bytes(), cuda_stream); - cudaMemsetAsync(beam_state->topk_buffer.data(), 0, beam_state->topk_buffer.size_bytes(), cuda_stream); + CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes(), cuda_stream)); + CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes(), cuda_stream)); + CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_tokens.data(), 0, beam_state->next_tokens.size_bytes(), cuda_stream)); + CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_indices.data(), 0, beam_state->next_indices.size_bytes(), cuda_stream)); + CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_scores.data(), 0, beam_state->next_scores.size_bytes(), cuda_stream)); + CUDA_CALL_THROW(cudaMemsetAsync(beam_state->topk_buffer.data(), 0, beam_state->topk_buffer.size_bytes(), cuda_stream)); + // Initialize score of first beam of each group with 0 and the rest with -1e9. cuda::LaunchInitKernel(beam_state->beam_scores.data(), batch_size, num_beams, cuda_stream); @@ -216,8 +226,8 @@ void InitBeamState(transformers::IBeamSearchState* beam_state, // copy sequence lengths to GPU // since next_positions is only needed to update feeds after subgraph execution, so it is fine to use Async here. if (!beam_state->next_positions.empty()) { // next_positions is empty for T5 - cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(), - cudaMemcpyHostToDevice, cuda_stream); + CUDA_CALL_THROW(cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(), + cudaMemcpyHostToDevice, cuda_stream)); } #ifdef ENABLE_NVTX_PROFILE @@ -234,12 +244,12 @@ void InitGreedyState(transformers::IGreedySearchState* greedy_state, initStateRange.Begin(); #endif - cudaStream_t cuda_stream = ort_stream ? reinterpret_cast(ort_stream->GetHandle()) : nullptr; - cudaMemsetAsync(greedy_state->next_token_scores.data(), 0, greedy_state->next_token_scores.size_bytes(), cuda_stream); - cudaMemsetAsync(greedy_state->next_positions.data(), 0, greedy_state->next_positions.size_bytes(), cuda_stream); + cudaStream_t cuda_stream = ort_stream ? 
reinterpret_cast(ort_stream->GetHandle()) : nullptr; + CUDA_CALL_THROW(cudaMemsetAsync(greedy_state->next_token_scores.data(), 0, greedy_state->next_token_scores.size_bytes(), cuda_stream)); + CUDA_CALL_THROW(cudaMemsetAsync(greedy_state->next_positions.data(), 0, greedy_state->next_positions.size_bytes(), cuda_stream)); - cudaMemcpyAsync(greedy_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(), - cudaMemcpyHostToDevice, cuda_stream); + CUDA_CALL_THROW(cudaMemcpyAsync(greedy_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(), + cudaMemcpyHostToDevice, cuda_stream)); #ifdef ENABLE_NVTX_PROFILE initStateRange.End(); @@ -337,11 +347,11 @@ Status ProcessLogits(const OrtValue& logits, // const CudaT* X_data = is_reuse_logits_buffer ? logits_data : reinterpret_cast(next_token_logits.data()); - dispatch_blockwise_softmax_forward( + ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward( cuda_stream, Y_data, X_data, vocab_size, is_reuse_logits_buffer ? padded_vocab_size : vocab_size, vocab_size, - batch_size * num_beams); + batch_size * num_beams))); #ifdef DEBUG_GENERATION dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size); @@ -430,12 +440,16 @@ Status ProcessLogits(const OrtValue& logits, // dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams); dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams); #endif + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), + beam_state->next_scores.data(), + beam_state->next_scores.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); } else { // Apply top-k selection like the following: // next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) // next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True) - // int64_t next_token_scores_dims[] = {batch_size, num_beams * vocab_size}; - int64_t next_token_scores_dims[] = {batch_size * num_beams, vocab_size}; + int64_t next_token_scores_dims[] = {batch_size, num_beams * vocab_size}; TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); auto element_type = DataTypeImpl::GetType(); @@ -450,31 +464,36 @@ Status ProcessLogits(const OrtValue& logits, // constexpr bool sorted = true; // results returned in sorted order. 
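The top-k path in this hunk views next_token_scores as a (batch_size, num_beams * vocab_size) tensor, takes the 2 * num_beams best candidates per batch entry with a single TopK call, and then splits each flat index into a source-beam index and a vocabulary token id (next_indices = index / vocab_size, next_tokens = index % vocab_size), work that the diff delegates to LaunchNextTokenKernel on the GPU. A rough host-side illustration of that split only; SplitTopKIndices and BeamCandidates are made-up names, not ONNX Runtime APIs:

    // Illustrative sketch: split flat top-k indices taken over a
    // (batch_size, num_beams * vocab_size) score view into the beam each
    // candidate came from and its token id within the vocabulary.
    #include <cstdint>
    #include <vector>

    struct BeamCandidates {
      std::vector<int32_t> next_indices;  // source beam per candidate
      std::vector<int32_t> next_tokens;   // token id per candidate
    };

    BeamCandidates SplitTopKIndices(const std::vector<int64_t>& flat_indices,
                                    int batch_size, int top_k, int vocab_size) {
      BeamCandidates out;
      out.next_indices.resize(static_cast<size_t>(batch_size) * top_k);
      out.next_tokens.resize(static_cast<size_t>(batch_size) * top_k);
      for (int b = 0; b < batch_size; ++b) {
        for (int k = 0; k < top_k; ++k) {
          const int64_t idx = flat_indices[static_cast<size_t>(b) * top_k + k];
          out.next_indices[b * top_k + k] = static_cast<int32_t>(idx / vocab_size);
          out.next_tokens[b * top_k + k] = static_cast<int32_t>(idx % vocab_size);
        }
      }
      return out;
    }
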
std::unique_ptr topk_scores = Tensor::CreateDefault(); - std::unique_ptr topk_tokens = Tensor::CreateDefault(); + std::unique_ptr topk_indices = Tensor::CreateDefault(); ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, ort_stream, thread_pool, - *topk_scores, *topk_tokens)); + *topk_scores, *topk_indices)); #ifdef DEBUG_GENERATION dumper->Print("topk_scores", *(topk_scores.get())); - dumper->Print("topk_tokens", *(topk_tokens.get())); + dumper->Print("topk_indices", *(topk_indices.get())); #endif - cuda::LaunchBatchTopKKernel(topk_scores->Data(), - topk_tokens->Data(), - beam_state->next_indices.data(), - beam_state->next_tokens.data(), - beam_state->next_scores.data(), - batch_size, - num_beams, - 2 * num_beams, - cuda_stream); + // Convert indices in range [0, num_beams * vocab_size) to token ID of range [0, vocab_size) like the following: + // next_indices = (next_tokens / vocab_size).long() + // next_tokens = next_tokens % vocab_size + const int64_t* next_token_indices = topk_indices->Data(); + cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(), + batch_size, top_k, vocab_size, cuda_stream); + + const float* data = topk_scores->Data(); +#ifdef DEBUG_GENERATION + dumper->Print("next_scores before scorer", data, batch_size, top_k); + dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, top_k); + dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k); +#endif + + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), + data, + topk_scores->SizeInBytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); } - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), - beam_state->next_scores.data(), - beam_state->next_scores.size_bytes(), - cudaMemcpyDeviceToHost, - cuda_stream)); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(), beam_state->next_tokens.data(), beam_state->next_tokens.size_bytes(), diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h index d82648890f94f..2a5875aba5fa1 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h @@ -11,6 +11,9 @@ #include #endif +using onnxruntime::cuda::ToCudaType; +using onnxruntime::cuda::dispatch_blockwise_softmax_forward; + namespace onnxruntime { namespace contrib { namespace SamplingCudaHelper { @@ -88,14 +91,14 @@ Status Sample(AllocatorPtr& allocator, #endif gsl::span& d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score; - dispatch_blockwise_softmax_forward(cuda_stream, - d_sorted_softmaxed_score.data(), - reinterpret_cast(d_sorted_score.data()), - parameters->vocab_size, - parameters->vocab_size, - parameters->vocab_size, - parameters->batch_size); - + ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(cuda_stream, + d_sorted_softmaxed_score.data(), + reinterpret_cast(d_sorted_score.data()), + parameters->vocab_size, + parameters->vocab_size, + parameters->vocab_size, + parameters->batch_size))); + #ifdef DEBUG_GENERATION dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score.data(), @@ -122,13 +125,13 @@ Status Sample(AllocatorPtr& allocator, #endif gsl::span& d_softmaxed_score = sampling_state->d_softmaxed_score; - dispatch_blockwise_softmax_forward(cuda_stream, - d_softmaxed_score.data(), - 
reinterpret_cast(next_token_scores.data()), - parameters->vocab_size, - parameters->vocab_size, - parameters->vocab_size, - parameters->batch_size); + ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(cuda_stream, + d_softmaxed_score.data(), + reinterpret_cast(next_token_scores.data()), + parameters->vocab_size, + parameters->vocab_size, + parameters->vocab_size, + parameters->batch_size))); #ifdef DEBUG_GENERATION dumper->Print("d_softmaxed_score_buffer", diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cc b/onnxruntime/contrib_ops/rocm/bert/attention.cc index 756919834aef8..1210442580e8b 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cc +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cc @@ -15,6 +15,10 @@ namespace onnxruntime { namespace contrib { namespace rocm { +constexpr int kPastSequenceLengthInputIndex = 6; +constexpr int kPastInputIndex = 4; +constexpr int kPresentOutputIndex = 1; + #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Attention, \ @@ -22,8 +26,10 @@ namespace rocm { 1, \ T, \ kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + (*KernelDefBuilder::Create()) \ + .MayInplace(kPastInputIndex, kPresentOutputIndex) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \ Attention); REGISTER_KERNEL_TYPED(float) @@ -39,48 +45,42 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* relative_position_bias = context->Input(5); + const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); auto& device_prop = GetDeviceProp(); + AttentionParameters parameters; ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past, - extra_add_qk, - nullptr, - device_prop.maxThreadsPerBlock)); - - // input shape (batch_size, sequence_length, input_hidden_size) - const auto& shape = input->Shape(); - int batch_size = static_cast(shape[0]); - int sequence_length = static_cast(shape[1]); - int input_hidden_size = static_cast(shape[2]); - - // Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in ROCM EP - // bias shape (3 * hidden_size) - const auto& bias_shape = bias->Shape(); - int hidden_size = static_cast(bias_shape[0]) / 3; - - int head_size = hidden_size / num_heads_; + relative_position_bias, + ¶meters, + device_prop.maxThreadsPerBlock, + past_seq_len)); + ORT_ENFORCE(parameters.sequence_length == parameters.kv_sequence_length); // self attention TensorShapeVector output_shape(3); - output_shape[0] = shape[0]; - output_shape[1] = shape[1]; - output_shape[2] = static_cast(hidden_size); + output_shape[0] = static_cast(parameters.batch_size); + output_shape[1] = static_cast(parameters.sequence_length); + output_shape[2] = static_cast(parameters.v_hidden_size); Tensor* output = context->Output(0, output_shape); - int past_sequence_length = 0; - Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length); + std::vector present_dims{ + 2, parameters.batch_size, parameters.num_heads, + parameters.past_present_share_buffer ? 
parameters.max_sequence_length : parameters.total_sequence_length, + parameters.head_size}; + TensorShape present_shape(present_dims); + Tensor* present = context->Output(kPresentOutputIndex, present_shape); rocblas_handle rocblas = GetRocblasHandle(context); constexpr size_t element_size = sizeof(T); - // Use GEMM for fully connection. - int m = batch_size * sequence_length; - int n = 3 * hidden_size; - int k = input_hidden_size; - auto gemm_buffer = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size * element_size, context->GetComputeStream()); + int m = parameters.batch_size * parameters.sequence_length; + int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); + int k = parameters.input_hidden_size; + auto gemm_buffer = GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); typedef typename ToHipType::MappedType HipT; namespace blas = rocm::tunable::blas; @@ -88,7 +88,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B. // TODO: use custom kernel of expand to improve the performance. ORT_RETURN_IF_ERROR(blas::column_major::Gemm( - IsTunableOpEnabled(), Stream(context), rocblas, + GetTuningContext(), Stream(context), rocblas, blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, n, m, 1, /*alpha=*/1.0f, @@ -99,7 +99,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // result(N, M) = 1 * weights x input + 1 x B. ORT_RETURN_IF_ERROR(blas::column_major::Gemm( - IsTunableOpEnabled(), Stream(context), rocblas, + GetTuningContext(), Stream(context), rocblas, blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, n, m, k, /*alpha=*/1.0f, @@ -108,28 +108,32 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { /*beta=*/1.0f, reinterpret_cast(gemm_buffer.get()), n)); - size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, - sequence_length, past_sequence_length); + size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, + parameters.batch_size, + parameters.num_heads, + parameters.head_size, + parameters.sequence_length, + parameters.past_sequence_length); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); return LaunchAttentionKernel( device_prop, - IsTunableOpEnabled(), + GetTuningContext(), Stream(context), rocblas, element_size, - batch_size, - sequence_length, - num_heads_, - head_size, - past_sequence_length, - is_unidirectional_, + parameters.batch_size, + parameters.sequence_length, + parameters.num_heads, + parameters.head_size, + parameters.past_sequence_length, + parameters.is_unidirectional, reinterpret_cast(gemm_buffer.get()), nullptr == mask_index ? nullptr : mask_index->Data(), nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(), - mask_filter_value_, + parameters.mask_filter_value, nullptr == past ? nullptr : past->Data(), - nullptr == extra_add_qk ? nullptr : extra_add_qk->Data(), + nullptr == relative_position_bias ? nullptr : relative_position_bias->Data(), work_space.get(), output->MutableData(), nullptr == present ? 
nullptr : present->MutableData()); diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index 954a129be1c65..e42fb2b2eb9dd 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -72,7 +72,7 @@ size_t GetAttentionWorkspaceSize( template Status QkvToContext( const hipDeviceProp_t& prop, - bool tuning, + RocmTuningContext* tuning_ctx, rocblas_handle& rocblas, hipStream_t stream, const int batch_size, @@ -89,7 +89,7 @@ Status QkvToContext( bool is_unidirectional, int past_sequence_length, const T* past, - const T* extra_add_qk, + const T* relative_position_bias, T* present, bool use_persistent_softmax) { const int all_sequence_length = past_sequence_length + sequence_length; @@ -139,7 +139,7 @@ Status QkvToContext( const int temp_matrix_size = sequence_length * all_sequence_length; ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning, stream, rocblas, + tuning_ctx, stream, rocblas, blas::BlasOp::Trans, blas::BlasOp::NonTrans, all_sequence_length, sequence_length, head_size, // For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation. @@ -158,7 +158,7 @@ Status QkvToContext( T* persistent_softmax_workspace = scratch1; // replace Q*K' in place if persistent softmax is selected. ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, - mask_index, nullptr, extra_add_qk, scratch1, scratch2, + mask_index, nullptr, relative_position_bias, scratch1, scratch2, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, mask_filter_value)); } else if (nullptr != mask_index) { // 1d mask index @@ -166,15 +166,15 @@ Status QkvToContext( // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. const int* mask_start = (mask_index_dims[0] > batch_size) ? 
mask_index + batch_size : nullptr; ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D(stream, all_sequence_length, sequence_length, batch_size, num_heads, - mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)); + mask_index, mask_start, relative_position_bias, scratch1, scratch2, is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, all_sequence_length, sequence_length, batch_size, num_heads, - extra_add_qk, scratch1, scratch2, is_unidirectional)); + relative_position_bias, scratch1, scratch2, is_unidirectional)); } // compute P*V (as V*P), and store in scratch3: BxNxSxH ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning, stream, rocblas, + tuning_ctx, stream, rocblas, blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, head_size, sequence_length, all_sequence_length, /*alpha=*/1.0f, @@ -191,7 +191,7 @@ Status QkvToContext( Status LaunchAttentionKernel( const hipDeviceProp_t& prop, - bool tuning, + RocmTuningContext* tuning_ctx, hipStream_t stream, rocblas_handle& rocblas, const size_t element_size, @@ -206,7 +206,7 @@ Status LaunchAttentionKernel( gsl::span mask_index_dims, const float mask_filter_value, const void* past, - const void* extra_add_qk, + const void* relative_position_bias, void* workspace, void* output, void* present) { @@ -215,7 +215,7 @@ Status LaunchAttentionKernel( bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); if (element_size == 2) { return QkvToContext( - prop, tuning, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size, + prop, tuning_ctx, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size, reinterpret_cast(input), reinterpret_cast<__half*>(output), reinterpret_cast<__half*>(workspace), @@ -225,12 +225,12 @@ Status LaunchAttentionKernel( is_unidirectional, past_sequence_length, reinterpret_cast(past), - reinterpret_cast(extra_add_qk), + reinterpret_cast(relative_position_bias), reinterpret_cast<__half*>(present), use_persistent_softmax); } else { return QkvToContext( - prop, tuning, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size, + prop, tuning_ctx, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size, reinterpret_cast(input), reinterpret_cast(output), reinterpret_cast(workspace), @@ -240,7 +240,7 @@ Status LaunchAttentionKernel( is_unidirectional, past_sequence_length, reinterpret_cast(past), - reinterpret_cast(extra_add_qk), + reinterpret_cast(relative_position_bias), reinterpret_cast(present), use_persistent_softmax); } @@ -249,7 +249,7 @@ Status LaunchAttentionKernel( template Status DecoderQkvToContext( const hipDeviceProp_t& prop, - bool tuning, + RocmTuningContext* tuning_ctx, hipStream_t stream, rocblas_handle& rocblas, const size_t element_size, @@ -352,7 +352,7 @@ Status DecoderQkvToContext( const int strideB = sequence_length * head_size; if (use_past && static_kv) { ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning, stream, rocblas, + tuning_ctx, stream, rocblas, blas::BlasOp::Trans, blas::BlasOp::NonTrans, kv_sequence_length, sequence_length, head_size, /*alpha=*/rsqrt_head_size, @@ -363,7 +363,7 @@ Status DecoderQkvToContext( BN)); } else { ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning, stream, rocblas, + tuning_ctx, stream, rocblas, blas::BlasOp::Trans, blas::BlasOp::NonTrans, kv_sequence_length, sequence_length, head_size, /*alpha=*/rsqrt_head_size, @@ -386,7 +386,7 @@ Status 
DecoderQkvToContext( // compute P*V (as V*P), and store in scratch3: BxNxSxH if (use_past && static_kv) { ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning, stream, rocblas, + tuning_ctx, stream, rocblas, blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, head_size, sequence_length, kv_sequence_length, /*alpha=*/1.0f, @@ -397,7 +397,7 @@ Status DecoderQkvToContext( BN)); } else { ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning, stream, rocblas, + tuning_ctx, stream, rocblas, blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, head_size, sequence_length, kv_sequence_length, /*alpha=*/1.0f, @@ -415,7 +415,7 @@ Status DecoderQkvToContext( Status LaunchDecoderAttentionKernel( const hipDeviceProp_t& prop, - bool tuning, + RocmTuningContext* tuning_ctx, hipStream_t stream, rocblas_handle& rocblas, const size_t element_size, @@ -442,7 +442,7 @@ Status LaunchDecoderAttentionKernel( if (element_size == 2) { return DecoderQkvToContext( prop, - tuning, + tuning_ctx, stream, rocblas, element_size, @@ -469,7 +469,7 @@ Status LaunchDecoderAttentionKernel( } else { return DecoderQkvToContext( prop, - tuning, + tuning_ctx, stream, rocblas, element_size, diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 7db692083f5e5..3fcfeb51752c3 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -6,6 +6,7 @@ #include #include #include "core/providers/rocm/shared_inc/rocm_utils.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" namespace onnxruntime { namespace contrib { @@ -27,7 +28,7 @@ size_t GetAttentionWorkspaceSize( Status LaunchAttentionKernel( const hipDeviceProp_t& prop, // Device Properties - bool tuning, // Whether to enable tuning + RocmTuningContext* tuning_ctx, // context for tuning hipStream_t stream, // Hip stream rocblas_handle& rocblas, // Rocblas handle const size_t element_size, // Element size of input tensor @@ -42,7 +43,7 @@ Status LaunchAttentionKernel( gsl::span mask_index_dims, // Mask index shape const float mask_filter_value, // Mask value for filtered out positions const void* past, // Past state input - const void* extra_add_qk, // Additional Add + const void* relative_position_bias, // Additional Add void* workspace, // Temporary buffer void* output, // Output tensor void* present // Present state output @@ -50,7 +51,7 @@ Status LaunchAttentionKernel( Status LaunchDecoderAttentionKernel( const hipDeviceProp_t& prop, // Device Properties - bool tuning, // Whether to enable tuning + RocmTuningContext* tuning_ctx, // context for tuning hipStream_t stream, // Hip stream rocblas_handle& rocblas, // Rocblas handle const size_t element_size, // Element size of input tensor diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h index 27ecdf253ecdb..7c99fc05ec9ee 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h @@ -513,12 +513,12 @@ Status ComputeSoftmaxWithRawMask(hipStream_t stream, } if (use_persistent_softmax) { - dispatch_warpwise_softmax_forward(stream, - output, - persistent_softmax_workspace, - all_sequence_length, - all_sequence_length, - batch_size * num_heads * sequence_length); + return dispatch_warpwise_softmax_forward(stream, + output, + persistent_softmax_workspace, + all_sequence_length, + all_sequence_length, + batch_size * num_heads * sequence_length); } 
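In the ROCm Attention changes above, the present key/value cache shape is now derived from the checked AttentionParameters rather than recomputed by hand: 2 x batch_size x num_heads x (max_sequence_length when past and present share a buffer, otherwise total_sequence_length) x head_size. A small sketch of that shape computation, assuming a trimmed stand-in struct (AttentionParametersSketch) rather than the real onnxruntime::contrib::AttentionParameters:

    // Illustrative sketch: assemble the present-state output dims from the
    // checked attention parameters, mirroring the present_dims logic in the diff.
    #include <cstdint>
    #include <vector>

    struct AttentionParametersSketch {
      int batch_size;
      int num_heads;
      int head_size;
      int total_sequence_length;      // past + current sequence length
      int max_sequence_length;        // buffer capacity when past/present share a buffer
      bool past_present_share_buffer;
    };

    std::vector<int64_t> PresentDims(const AttentionParametersSketch& p) {
      // Leading 2 packs key and value into one tensor: 2 x B x N x S x H.
      return {2, p.batch_size, p.num_heads,
              p.past_present_share_buffer ? p.max_sequence_length : p.total_sequence_length,
              p.head_size};
    }
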
return HIP_CALL(hipPeekAtLastError()); diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc index 8942db99be5b8..b81c511124c0d 100644 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc @@ -48,7 +48,7 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); typedef typename ToHipType::MappedType HipT; - return LaunchFastGeluKernel(IsTunableOpEnabled(), + return LaunchFastGeluKernel(GetTuningContext(), Stream(context), static_cast(input_length), static_cast(bias_length), diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu index 009ea9e0faef8..f0ec070663d64 100644 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu @@ -40,25 +40,27 @@ namespace contrib { namespace rocm { template -Status LaunchFastGeluKernel(bool tuning, hipStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, int input_length, int bias_length, const T* input, const T* bias, T* output) { - FastGeluParams params(stream, input, bias, output, input_length, bias_length); - if (tuning) { + FastGeluParams params(tuning_ctx, stream, input, bias, output, input_length, bias_length); + if (tuning_ctx->IsTunableOpEnabled()) { static FastGeluTunableOp op; - op.EnableTuning(); return op(¶ms); } return FastGeluStaticSelection(¶ms); } -template Status LaunchFastGeluKernel(bool tuning, hipStream_t stream, int input_length, int bias_length, +template Status LaunchFastGeluKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, + int input_length, int bias_length, const float* input, const float* bias, float* output); -template Status LaunchFastGeluKernel(bool tuning, hipStream_t stream, int input_length, int bias_length, +template Status LaunchFastGeluKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, + int input_length, int bias_length, const BFloat16* input, const BFloat16* bias, BFloat16* output); -template Status LaunchFastGeluKernel(bool tuning, hipStream_t stream, int input_length, int bias_length, +template Status LaunchFastGeluKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, + int input_length, int bias_length, const half* input, const half* bias, half* output); } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.h index 7f6475e3d6d31..f120c0559f9b0 100644 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.h @@ -6,14 +6,16 @@ // Licensed under the MIT License. 
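The FastGelu and GemmFastGelu changes in this region replace the old bool tuning flag with a RocmTuningContext* carried inside the op parameters; each launcher asks the context whether tunable ops are enabled and otherwise falls back to the static kernel selection. A simplified sketch of that dispatch pattern under assumed names (TuningContextSketch, MyOpParams, TunedImpl, DefaultImpl), not the actual onnxruntime::rocm::tunable types:

    // Illustrative sketch: op parameters carry a tuning context instead of a bare
    // bool, and the launcher consults it at dispatch time.
    struct TuningContextSketch {
      bool tunable_op_enabled = false;
      bool IsTunableOpEnabled() const { return tunable_op_enabled; }
    };

    struct MyOpParams {
      TuningContextSketch* tuning_ctx = nullptr;
      int n = 0;  // stands in for problem sizes, stream handles, data pointers, ...
    };

    int TunedImpl(const MyOpParams& p)   { return p.n; }  // placeholder for a tunable-op dispatch
    int DefaultImpl(const MyOpParams& p) { return p.n; }  // placeholder for the static selection

    int LaunchMyOp(TuningContextSketch* tuning_ctx, int n) {
      MyOpParams params{tuning_ctx, n};
      if (tuning_ctx != nullptr && tuning_ctx->IsTunableOpEnabled()) {
        return TunedImpl(params);   // tunable path: profiled kernel picked at runtime
      }
      return DefaultImpl(params);   // default path: fixed heuristic kernel
    }
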
#pragma once + #include "core/common/common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" namespace onnxruntime { namespace contrib { namespace rocm { template -Status LaunchFastGeluKernel(bool tuning, hipStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, int input_length, int bias_length, const T* input, const T* bias, T* output); } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h index 0498927fed65e..72c3ee01038c7 100644 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h @@ -19,9 +19,9 @@ namespace contrib { namespace rocm { template -struct FastGeluParams : onnxruntime::rocm::tunable::OpParams { - FastGeluParams(hipStream_t stream, const T* input, const T* bias, T* output, int input_length, int bias_length) : - OpParams(stream), input(input), bias(bias), output(output), input_length(input_length), bias_length(bias_length) {} +struct FastGeluParams : OpParams { + FastGeluParams(RocmTuningContext* tuning_ctx, hipStream_t stream, const T* input, const T* bias, T* output, int input_length, int bias_length) : + OpParams(tuning_ctx, stream), input(input), bias(bias), output(output), input_length(input_length), bias_length(bias_length) {} std::string Signature() const override { std::string sig = std::to_string(input_length) + "_" + std::to_string(bias_length); @@ -119,7 +119,7 @@ Status FastGeluStaticSelection(const FastGeluParams* params) { this->RegisterOp(FastGeluOp{}); template -class FastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp> { +class FastGeluTunableOp : public TunableOp> { public: FastGeluTunableOp() { this->RegisterOp(FastGeluStaticSelection); diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc index 453c82a2ed6f6..8b0cc98964581 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc @@ -58,7 +58,7 @@ Status GemmFastGelu::ComputeInternal(OpKernelContext* ctx) const { using onnxruntime::rocm::tunable::blas::BlasOp; return blas::row_major::GemmFastGelu( - IsTunableOpEnabled(), + GetTuningContext(), Stream(ctx), GetRocblasHandle(ctx), transa ? BlasOp::Trans : BlasOp::NonTrans, transb ? BlasOp::Trans : BlasOp::NonTrans, diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh index 3cafa7d0dfa2d..f1fa3a7f5c848 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh @@ -19,7 +19,6 @@ #include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" using onnxruntime::rocm::ToHipType; -using onnxruntime::rocm::tunable::Op; namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h index 9bb095110174e..dd98b76153cc2 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h @@ -19,7 +19,7 @@ namespace rocm { namespace blas { template -struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams { +struct GemmFastGeluParams : OpParams { std::string Signature() const override { bool has_bias = (nullptr != bias) ? 
0 : 1; return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias); @@ -39,7 +39,6 @@ struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams { T beta; T* c; int64_t ldc; - bool tuning{false}; }; } // namespace blas diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu index 039573e585c7d..294e7be91e883 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu @@ -20,6 +20,7 @@ namespace row_major { template inline GEMMFASTGELU(T, ScalarT) { GemmFastGeluParams params; + params.tuning_ctx = tuning_ctx; params.stream = stream; params.handle = handle; @@ -46,23 +47,18 @@ inline GEMMFASTGELU(T, ScalarT) { params.c = c; params.ldc = ldc; - if (tunable) { - params.tuning = true; + if (tuning_ctx->IsTunableOpEnabled()) { if (opa == BlasOp::N && opb == BlasOp::N) { static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - gemm_fast_gelu.EnableTuning(); return gemm_fast_gelu(¶ms); } else if (opa == BlasOp::T && opb == BlasOp::N) { static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - gemm_fast_gelu.EnableTuning(); return gemm_fast_gelu(¶ms); } else if (opa == BlasOp::N && opb == BlasOp::T) { static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - gemm_fast_gelu.EnableTuning(); return gemm_fast_gelu(¶ms); } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - gemm_fast_gelu.EnableTuning(); return gemm_fast_gelu(¶ms); } } @@ -71,7 +67,7 @@ inline GEMMFASTGELU(T, ScalarT) { } #define CALL_GEMMFASTGELU(T, ScalarT) \ - GemmFastGelu(tunable, stream, handle, \ + GemmFastGelu(tuning_ctx, stream, handle, \ opa, opb, \ m, n, k, \ alpha, a, lda, b, ldb, bias, \ diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h index 3daeea07b62da..637572e51b315 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h @@ -14,7 +14,7 @@ namespace blas { #define GEMMFASTGELU(T, ScalarT) \ common::Status GemmFastGelu( \ - bool tunable, hipStream_t stream, rocblas_handle handle, \ + RocmTuningContext* tuning_ctx, hipStream_t stream, rocblas_handle handle, \ BlasOp opa, BlasOp opb, \ std::int64_t m, std::int64_t n, std::int64_t k, \ ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh index 230951dc45f7c..24058f80db19b 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh @@ -21,7 +21,7 @@ namespace internal { template Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { namespace column_major = onnxruntime::rocm::tunable::blas::column_major; - ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning, params->stream, params->handle, + ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning_ctx, params->stream, params->handle, params->opb, params->opa, params->n, params->m, params->k, params->alpha, params->b, params->ldb, params->a, params->lda, @@ -41,7 +41,7 @@ Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { // // Note: If any change cause directly usage of GemmFastGeluUnfused, add PreTuning() and PostTuning() in FastGeluTunableOp // to protect original 
input value. - return onnxruntime::contrib::rocm::LaunchFastGeluKernel(params->tuning, + return onnxruntime::contrib::rocm::LaunchFastGeluKernel(params->tuning_ctx, params->stream, static_cast(fast_gelu_input_length), static_cast(bias_length), @@ -51,7 +51,7 @@ Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { } template -class GemmFastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp> { +class GemmFastGeluTunableOp : public TunableOp> { public: GemmFastGeluTunableOp() { this->RegisterOp(GemmFastGeluUnfused); diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc index c1b99eb7f5b1d..a254a8c04fbcd 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc @@ -43,6 +43,10 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { Tensor* output = ctx->Output(0, input->Shape()); + // For inferencing, we support one more optional output which is the sum + // of the input and skip tensors + Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); + if (input->Shape() != skip->Shape()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "skip is expected to have same shape as input"); @@ -98,9 +102,10 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { typedef typename ToHipType::MappedType HipT; return LaunchSkipLayerNormKernel( - IsTunableOpEnabled(), + GetTuningContext(), Stream(ctx), reinterpret_cast(output->MutableData()), + skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, reinterpret_cast(input->Data()), reinterpret_cast(skip->Data()), reinterpret_cast(gamma->Data()), diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu index 13baf5ec1e5d5..bf33f940b3936 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu @@ -41,31 +41,32 @@ namespace rocm { template Status LaunchSkipLayerNormKernel( - bool tuning, hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, float epsilon, int ld, int element_count) { + RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, + const T* skip, const T* gamma, const T* beta, const T* bias, float epsilon, int ld, int element_count) { // this must be true because element_count is the total size of the tensor assert(element_count % ld == 0); - SkipLayerNormParams params(stream, output, input, skip, gamma, beta, bias, epsilon, ld, element_count); + SkipLayerNormParams params(tuning_ctx, stream, output, skip_input_bias_add_output, input, skip, gamma, beta, bias, epsilon, ld, element_count); - if (tuning) { + if (tuning_ctx->IsTunableOpEnabled()) { static SkipLayerNormTunableOp op; - op.EnableTuning(); return op(¶ms); } return SkipLayerNormStaticSelection(¶ms); } -template Status LaunchSkipLayerNormKernel(bool tuning, hipStream_t stream, float* output, const float* input, - const float* skip, const float* gamma, const float* beta, - const float* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel(bool tuning, hipStream_t stream, half* output, const half* input, - const half* skip, const half* gamma, const half* beta, - const half* bias, float epsilon, int ld, - int element_count); +template Status 
LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, hipStream_t stream, float* output, float* skip_input_bias_add_output, const float* input, + const float* skip, const float* gamma, const float* beta, + const float* bias, float epsilon, int ld, + int element_count); + +template Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, hipStream_t stream, half* output, half* skip_input_bias_add_output, const half* input, + const half* skip, const half* gamma, const half* beta, + const half* bias, float epsilon, int ld, + int element_count); } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h index c32b6c48a8441..911164af92292 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h @@ -2,7 +2,9 @@ // Licensed under the MIT License. #pragma once + #include "core/common/common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" namespace onnxruntime { namespace contrib { @@ -10,9 +12,10 @@ namespace rocm { template Status LaunchSkipLayerNormKernel( - bool tuning, + RocmTuningContext* tuning, hipStream_t stream, T* output, // output tensor + T* skip_input_bias_add_output, // optional output tensor const T* input, // input tensor const T* skip, // skip tensor const T* gamma, // Layer normalization gamma tensor diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h index ee38b1c7e70cf..bcef54871a837 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h @@ -26,7 +26,7 @@ half maybe2half(float x) { template __global__ void SkipLayerNormKernel( const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias, - const T epsilon, T* output) { + const T epsilon, T* output, T* skip_input_bias_add_output) { const T reverse_ld = T(1.f / ld); const int offset = blockIdx.x * ld; @@ -39,6 +39,11 @@ __global__ void SkipLayerNormKernel( const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i]; const T rldval = reverse_ld * val; thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); + + if (skip_input_bias_add_output != nullptr) { + skip_input_bias_add_output[idx] = val; + } + output[idx] = val; } @@ -49,7 +54,8 @@ __global__ void SkipLayerNormKernel( template __global__ void SkipLayerNormKernelVec( const int ld, const T* input, const T* skip, const T* beta, const T* gamma, - const T* bias, const T epsilon, T* output, bool hasBias) { + const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output, + bool hasBias, bool hasSkipInputBiasAdditionOutput) { const T reverse_ld = T(1.f / ld); const int offset = blockIdx.x * ld; @@ -58,7 +64,7 @@ __global__ void SkipLayerNormKernelVec( hipcub::KeyValuePair thread_data(0, 0); using VecT = aligned_vector; - T input_v[ILP], skip_v[ILP], bias_v[ILP]; + T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP];; if (threadIdx.x * ILP < ld) { VecT* input_val = reinterpret_cast(&input_v); VecT* skip_val = reinterpret_cast(&skip_v); @@ -76,9 +82,19 @@ __global__ void SkipLayerNormKernelVec( #pragma unroll for (int k = 0; k < ILP; k++) { input_v[k] += hasBias ? 
skip_v[k] + bias_v[k] : skip_v[k]; + + if (hasSkipInputBiasAdditionOutput) { + skip_input_bias_add_output_v[i] = input_v[i]; + } + const T rldval = reverse_ld * input_v[k]; thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * input_v[k])); } + + if (hasSkipInputBiasAdditionOutput) { + *(reinterpret_cast(&skip_input_bias_add_output[idx])) = *reinterpret_cast(&skip_input_bias_add_output_v); + } + *(reinterpret_cast(&output[idx])) = *reinterpret_cast(&input_v[0]); } } @@ -90,12 +106,13 @@ __global__ void SkipLayerNormKernelVec( template __global__ void SkipLayerNormKernelSmall( const int ld, const T* input, const T* skip, const T* beta, const T* gamma, - const T* bias, const T epsilon, T* output, bool hasBias) { + const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output, + bool hasBias, bool hasSkipInputBiasAdditionOutput) { const T rld = T(1.f / ld); const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld using VecT = aligned_vector; - T input_v[ILP], skip_v[ILP], bias_v[ILP]; + T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP]; hipcub::KeyValuePair thread_data(T(0.f), T(0.f)); @@ -116,10 +133,20 @@ __global__ void SkipLayerNormKernelSmall( #pragma unroll for (int i = 0; i < ILP; i++) { input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i]; + + if (hasSkipInputBiasAdditionOutput) { + skip_input_bias_add_output_v[i] = input_v[i]; + } + const T rldval = rld * input_v[i]; rldval_sum += rldval; rldvalsq_sum += rldval * input_v[i]; } + + if (hasSkipInputBiasAdditionOutput) { + *(reinterpret_cast(&skip_input_bias_add_output[idx])) = *reinterpret_cast(&skip_input_bias_add_output_v); + } + thread_data = hipcub::KeyValuePair(rldval_sum, rldvalsq_sum); } LayerNormSmall(input_v, thread_data, ld, idx, beta, gamma, epsilon, output); diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h index f7ae2bc905dfa..b8d0dfee74f9e 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h @@ -19,12 +19,12 @@ namespace contrib { namespace rocm { template -struct SkipLayerNormParams : onnxruntime::rocm::tunable::OpParams { - SkipLayerNormParams(hipStream_t stream, T* output, const T* input, +struct SkipLayerNormParams : OpParams { + SkipLayerNormParams(RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, const T* beta, const T* bias, float epsilon, int ld, int element_count) - : OpParams(stream), output(output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias), - epsilon(epsilon), ld(ld), element_count(element_count) {} + : OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip), + gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {} std::string Signature() const override { std::string sig = std::to_string(ld) + "_" + std::to_string(element_count); @@ -32,6 +32,7 @@ struct SkipLayerNormParams : onnxruntime::rocm::tunable::OpParams { } T* output; + T* skip_input_bias_add_output; const T* input; const T* skip; const T* gamma; @@ -51,8 +52,8 @@ Status SkipLayerNormSmallOp(const SkipLayerNormParams* params) { dim3(ThreadsPerBlock), 0, params->stream>>>( params->ld, params->input, params->skip, - params->beta, params->gamma, params->bias, 
maybe2half(params->epsilon), params->output, - (params->bias == nullptr) ? false : true); + params->beta, params->gamma, params->bias, maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, + (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); return HIP_CALL(hipGetLastError()); } @@ -66,51 +67,52 @@ Status SkipLayerNormRegularOp(const SkipLayerNormParams* params) { dim3(ThreadsPerBlock), 0, params->stream>>>( params->ld, params->input, params->skip, - params->beta, params->gamma, params->bias, maybe2half(params->epsilon), params->output, - (params->bias == nullptr) ? false : true); + params->beta, params->gamma, params->bias, maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, + (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); return HIP_CALL(hipGetLastError()); } template Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { bool hasBias = (params->bias == nullptr) ? false : true; + bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? false : true; if (0 == (params->ld % 4)) { const int grid_size = params->element_count / params->ld; if (params->ld <= 32) { constexpr int block_size = 32; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 64) { constexpr int block_size = 64 / 2; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 128) { constexpr int block_size = 128 / 4; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 384) { constexpr int block_size = 384 / 4; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 768) { constexpr int block_size = 768 / 4; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 1024) { constexpr int block_size = 1024 / 4; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else { constexpr int block_size = 256; 
SkipLayerNormKernel<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output); } } else { const int grid_size = params->element_count / params->ld; @@ -118,27 +120,27 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { constexpr int block_size = 32; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 64) { constexpr int block_size = 64; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 128) { constexpr int block_size = 128; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld == 384) { constexpr int block_size = 384; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else { constexpr int block_size = 256; SkipLayerNormKernel<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output); } } return HIP_CALL(hipPeekAtLastError()); @@ -160,7 +162,7 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { ADD_OP_FOR_ALL_VEC_SIZE(name, 384) template -class SkipLayerNormTunableOp : public onnxruntime::rocm::tunable::TunableOp> { +class SkipLayerNormTunableOp : public TunableOp> { public: SkipLayerNormTunableOp() { this->RegisterOp(SkipLayerNormStaticSelection); diff --git a/onnxruntime/contrib_ops/rocm/fused_conv.cc b/onnxruntime/contrib_ops/rocm/fused_conv.cc index 1b98142f1fe4b..4049d6dcb1600 100644 --- a/onnxruntime/contrib_ops/rocm/fused_conv.cc +++ b/onnxruntime/contrib_ops/rocm/fused_conv.cc @@ -74,7 +74,10 @@ struct FNVHash { void HashConvolutionDescriptor(miopenConvolutionDescriptor_t cdesc) { int spatial_dim = 1; - // Current MIOpen doesn't provide API to probe the dimension of a +#if ROCM_VERSION >= 50500 + miopenGetConvolutionDescriptorSize(cdesc, &spatial_dim); +#else + // Previous versions of MIOpen doesn't provide API to probe the dimension of a // miopenConvolutionDescriptor_t, so we have to guess. 
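As a reference for the optional output threaded through the ROCm SkipLayerNorm kernels above: the op computes sum = input + skip (+ bias), optionally writes that sum to the new skip_input_bias_add_output tensor, and then layer-normalizes the sum over the last dimension of size ld using gamma, beta, and epsilon. A plain CPU float sketch of that math; SkipLayerNormRef is an illustrative helper, not ONNX Runtime code:

    // Illustrative reference: fused skip + bias add followed by layer norm,
    // with the pre-normalization sum optionally written to a second output.
    #include <cmath>

    void SkipLayerNormRef(const float* input, const float* skip, const float* gamma,
                          const float* beta, const float* bias, float epsilon,
                          int ld, int element_count, float* output,
                          float* skip_input_bias_add_output /* may be nullptr */) {
      for (int row = 0; row < element_count / ld; ++row) {
        const int offset = row * ld;
        float mean = 0.f, sq = 0.f;
        for (int i = 0; i < ld; ++i) {
          const float v = input[offset + i] + skip[offset + i] + (bias ? bias[i] : 0.f);
          if (skip_input_bias_add_output) skip_input_bias_add_output[offset + i] = v;
          output[offset + i] = v;  // stash the sum; normalize in a second pass
          mean += v;
          sq += v * v;
        }
        mean /= ld;
        const float var = sq / ld - mean * mean;
        const float inv_std = 1.f / std::sqrt(var + epsilon);
        for (int i = 0; i < ld; ++i) {
          const float norm = (output[offset + i] - mean) * inv_std;
          output[offset + i] = norm * (gamma ? gamma[i] : 1.f) + (beta ? beta[i] : 0.f);
        }
      }
    }
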
// This algorithm is based on a specific behavior of miopenGetConvolutionNdDescriptor, // which fails when requestedSpatialDim > the convolution's spatial dimension @@ -113,6 +116,7 @@ struct FNVHash { "miopenGetConvolutionNdDescriptor is supposed to fail before spatial_dim gets to ", spatial_dim); } +#endif } private: uint32_t value_ = BASIS; diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index b92efc3a6109a..252f943c43df3 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -69,6 +69,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); @@ -108,6 +109,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen); #endif +#if defined(USE_MPI) && defined(ORT_USE_NCCL) +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather); +#endif + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -116,115 +122,120 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, //default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, - // These ops were experimental ops in onnx domain which have been removed now. We add them here as - // contrib ops to maintain backward compatibility - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // These ops were experimental ops in onnx domain which have been removed now. 
We add them here as + // contrib ops to maintain backward compatibility + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo - BuildKernelCreateInfo, - // TransposedMatMul is still here for backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + // TransposedMatMul is still here for backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN - BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif +#if defined(USE_MPI) && defined(ORT_USE_NCCL) + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif - }; + }; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index b950e4e734fa5..03460c9def5bd 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -143,6 +143,7 @@ void CPUIDInfo::ArmLinuxInit() { if (pytorch_cpuinfo_init_) { is_hybrid_ = cpuinfo_get_uarchs_count() > 1; has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); + has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); is_armv8_narrow_ld_.resize(core_cnt, false); @@ -165,6 +166,7 @@ void CPUIDInfo::ArmLinuxInit() { } } else { has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); + has_fp16_ |= has_arm_neon_dot_; } } @@ -220,9 +222,45 @@ void CPUIDInfo::ArmWindowsInit() { lastUarch = uarch; } } + + switch (lastUarch) { + case cpuinfo_uarch_cortex_a55: + case cpuinfo_uarch_cortex_a55r0: + case cpuinfo_uarch_cortex_a76: + case cpuinfo_uarch_neoverse_n1: + case cpuinfo_uarch_cortex_a77: + 
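In this rendering of the diff the template arguments of the BuildKernelCreateInfo<...> entries do not survive, so the new AllReduce/AllGather registrations above appear truncated. As a hedged sketch, the usual ONNX Runtime pattern for such entries looks roughly like this (illustrative fragment, not a standalone program):

#if defined(USE_MPI) && defined(ORT_USE_NCCL)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather);
#endif

static const BuildKernelCreateInfoFn function_table[] = {
    BuildKernelCreateInfo<void>,  // default entry so the table is never empty after op reduction
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce)>,
    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather)>,
#endif
};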
case cpuinfo_uarch_exynos_m4: + case cpuinfo_uarch_exynos_m5: + has_fp16_ = true; + break; + default: + break; + } + if (!has_fp16_) { + /* + * Detecting fp16 support. Different cores should have the same instruction set. + * So we just check the first ID_AA64PFR0_EL1 + * Op0(0b11), Op1(0b000), CRn(0b0000), CRm(0b0100), Op2(0b000), + */ + uint64_t ID_AA64PFR0_EL1; + unsigned long valsize = sizeof(uint64_t); + auto retCode = ::RegGetValueA( + HKEY_LOCAL_MACHINE, + "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", + "CP 4020", RRF_RT_REG_QWORD, nullptr, + &ID_AA64PFR0_EL1, &valsize); + if (retCode == ERROR_SUCCESS) { + // AdvSIMD, bits [23:20] + auto advSimd = ID_AA64PFR0_EL1 >> 20; + if ((advSimd & 0xfULL) == 1) { + has_fp16_ = true; + } + } + } #endif /* Application Family or OneCore Family */ has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); + has_fp16_ |= has_arm_neon_dot_; } #endif /* (arm or arm64) and windows */ diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 858f8595b8220..c413e0ca7ed5f 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -21,7 +21,7 @@ class CPUIDInfo { bool HasAVX512f() const { return has_avx512f_; } bool HasAVX512_BF16() const {return has_avx512_bf16_;} bool HasAVX512Skylake() const { return has_avx512_skylake_; } - bool HasF16C() const { return has_f16c_; } + bool HasF16C() const { return has_f16c_; } /*fp16 conversion inst*/ bool HasSSE3() const { return has_sse3_; } bool HasSSE4_1() const { return has_sse4_1_; } bool IsHybrid() const { return is_hybrid_; } @@ -85,6 +85,9 @@ class CPUIDInfo { return is_armv8_narrow_ld_[coreIdx]; } + bool HasFp16VectorAcceleration() const { + return has_fp16_; + } private: CPUIDInfo() { @@ -118,6 +121,7 @@ class CPUIDInfo { std::vector is_armv8_narrow_ld_; bool has_arm_neon_dot_{false}; + bool has_fp16_{false}; #ifdef CPUIDINFO_ARCH_X86 diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index ff1079a22f57d..0af5924987458 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -718,7 +718,7 @@ class PlannerImpl { if (!is_implicit_input) { OrtMemType mem_type = p_kernel_def->InputMemoryType(arg_idx); - plan_.SetLocation(static_cast(index), exec_provider->GetAllocator(0, mem_type)->Info()); + plan_.SetLocation(static_cast(index), exec_provider->GetAllocator(mem_type)->Info()); set_node_arg_has_explicit_consumer.insert(index); } else { // implicit input // Only process an implicit input if there are explicit consumers at this graph level @@ -790,16 +790,16 @@ class PlannerImpl { if (already_seen_ep_for_node_arg == map_implicitly_consumed_node_arg_to_ep.end()) { // First time we are encountering this implicitly consumed input at this graph level (or) - plan_.SetLocation(static_cast(index), exec_provider->GetAllocator(exec_provider->GetDeviceId(), OrtMemType::OrtMemTypeDefault)->Info()); + plan_.SetLocation(static_cast(index), exec_provider->GetAllocator(OrtMemType::OrtMemTypeDefault)->Info()); map_implicitly_consumed_node_arg_to_ep.insert({index, exec_provider}); } else if (already_seen_ep_for_node_arg->second == exec_provider) { // The EP that we previously seen for this implicit input is the same one as the current EP // we have seen - plan_.SetLocation(static_cast(index), exec_provider->GetAllocator(exec_provider->GetDeviceId(), OrtMemType::OrtMemTypeDefault)->Info()); + 
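The Windows-on-Arm probe above reads ID_AA64PFR0_EL1 from the registry and inspects the AdvSIMD field in bits [23:20], where a value of 1 indicates Advanced SIMD with half-precision arithmetic. A standalone sketch of that decode, with a made-up register value for illustration:

#include <cstdint>
#include <cstdio>

// Decode the AdvSIMD field (bits [23:20]) of ID_AA64PFR0_EL1.
// 0b0001 means Advanced SIMD with FP16 arithmetic is implemented.
bool AdvSimdHasFp16(uint64_t id_aa64pfr0_el1) {
  return ((id_aa64pfr0_el1 >> 20) & 0xfULL) == 1;
}

int main() {
  const uint64_t sample = 0x0000000000100000ULL;  // hypothetical value with AdvSIMD == 1
  std::printf("fp16 vector arithmetic: %s\n", AdvSimdHasFp16(sample) ? "yes" : "no");
  return 0;
}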
plan_.SetLocation(static_cast(index), exec_provider->GetAllocator(OrtMemType::OrtMemTypeDefault)->Info()); } else { // Default the location to CPU plan_.SetLocation(static_cast(index), - execution_providers_.Get(CPU)->GetAllocator(exec_provider->GetDeviceId(), OrtMemType::OrtMemTypeDefault)->Info()); + execution_providers_.Get(CPU)->GetAllocator(OrtMemType::OrtMemTypeDefault)->Info()); set_implicitly_consumed_node_arg_has_heterogenous_ep_consumers.insert(index); } } @@ -822,8 +822,7 @@ class PlannerImpl { if (!node_output->Exists()) continue; OrtValueIndex index = Index(node_output->Name()); ProcessDef(index, node_output); - int device_id = p_kernel_def->IsOutputOnCpu(i) ? 0 : exec_provider->GetDeviceId(); - auto allocator = exec_provider->GetAllocator(device_id, p_kernel_def->OutputMemoryType(i)); + auto allocator = exec_provider->GetAllocator(p_kernel_def->OutputMemoryType(i)); ORT_ENFORCE(allocator); plan_.SetLocation(static_cast(index), allocator->Info()); @@ -844,7 +843,7 @@ class PlannerImpl { if (utils::IsInputOnCpu(node, &kernel_create_info, input_index)) // weights are not output from any node, so it's OK to put its location on CPU provider return execution_providers_.GetDefaultCpuMemoryInfo(); - return p_provider->GetAllocator(p_provider->GetDeviceId(), OrtMemTypeDefault)->Info(); + return p_provider->GetAllocator(OrtMemTypeDefault)->Info(); } void GeneratePlanForWeightsHelper(const GraphViewer& graph_viewer, @@ -1741,7 +1740,7 @@ class PlannerImpl { onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType(); const IExecutionProvider* ep = execution_providers.Get(exec_provider_name); ORT_ENFORCE(ep); - auto& node_device_mem_location = ep->GetAllocator(ep->GetDeviceId(), OrtMemType::OrtMemTypeDefault)->Info(); + auto& node_device_mem_location = ep->GetAllocator(OrtMemType::OrtMemTypeDefault)->Info(); execution_plan.emplace_back(std::make_unique(node_device_mem_location.device)); // 2. add steps to the execution plan for (auto node_index : stream_nodes_[0]) { @@ -1781,7 +1780,7 @@ class PlannerImpl { onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType(); const IExecutionProvider* ep = execution_providers.Get(exec_provider_name); ORT_ENFORCE(ep); - auto& node_device_mem_location = ep->GetAllocator(ep->GetDeviceId(), OrtMemType::OrtMemTypeDefault)->Info(); + auto& node_device_mem_location = ep->GetAllocator(OrtMemType::OrtMemTypeDefault)->Info(); execution_plan.emplace_back(std::make_unique(node_device_mem_location.device)); } else { execution_plan.emplace_back(nullptr); @@ -1809,7 +1808,7 @@ class PlannerImpl { for (size_t i = 0; i < num_logic_streams_; ++i) { for (auto node_index : stream_nodes_[i]) { auto* node = graph_viewer_.GetNode(node_index); - // Neither trigger ActivateNotification/WaitOnEPStep for Shape op (whose output is ready for all the EPs), nor + // Neither trigger ActivateNotification/WaitOnEPStep for Shape op (whose output is ready for all the EPs), nor // upstream is on CPU device (As currently we never invoke RegisterWaitFn(CPU, ...) 
for all kinds of EP, thus no wait_handle can be retrieved for this case) if (node->OpType() != "Shape" && execution_plan[i]->device_.Type() != OrtDevice::CPU) { for (auto it = node->OutputNodesBegin(); it != node->OutputNodesEnd(); ++it) { @@ -1850,11 +1849,11 @@ class PlannerImpl { auto* node = graph_viewer_.GetNode(node_index); onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType(); const IExecutionProvider* ep = execution_providers.Get(exec_provider_name); - auto& node_device_mem_location = ep->GetAllocator(ep->GetDeviceId(), OrtMemType::OrtMemTypeDefault)->Info(); + auto& node_device_mem_location = ep->GetAllocator(OrtMemType::OrtMemTypeDefault)->Info(); ORT_ENFORCE(execution_plan[node_stream_map_[node_index]]->device_.Type() == node_device_mem_location.device.Type()); } } - + // 4. add commands to logic queue for (size_t i = 0; i < num_logic_streams_; ++i) { for (size_t j = 0; j < stream_nodes_[i].size(); ++j) { @@ -2256,7 +2255,7 @@ Status DeviceBasedPartitioner::PartitionGraph(const onnxruntime::GraphViewer& gr const auto& op_type = node->OpType(); const auto& node_name = node->Name(); auto* ep = execution_providers.Get(*node); - auto& device_mem_location = ep->GetAllocator(ep->GetDeviceId(), OrtMemType::OrtMemTypeDefault)->Info(); + auto& device_mem_location = ep->GetAllocator(OrtMemType::OrtMemTypeDefault)->Info(); auto device_type = device_mem_location.device.Type(); // log the device @@ -2396,4 +2395,4 @@ std::unique_ptr IGraphPartitioner::CreateGraphPartitioner(con #endif -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc index d081d01ee53e6..ec9f4888c7c52 100644 --- a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc +++ b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc @@ -335,7 +335,7 @@ void DumpTensor( if (tensor_location.device.Type() == OrtDevice::GPU) { const auto& execution_providers = session_state.GetExecutionProviders(); const auto* cpu_execution_provider = execution_providers.Get(onnxruntime::kCpuExecutionProvider); - auto cpu_allocator = cpu_execution_provider->GetAllocator(0, OrtMemTypeDefault); + auto cpu_allocator = cpu_execution_provider->GetAllocator(OrtMemTypeDefault); Tensor cpu_tensor{data_type, tensor.Shape(), cpu_allocator}; const auto& data_transfer_mgr = session_state.GetDataTransferMgr(); auto status = data_transfer_mgr.CopyTensor(tensor, cpu_tensor); diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index 5bc5dcdbd7696..e9e13dbbbda84 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -18,8 +18,12 @@ inline int MakeKey(int id, OrtMemType mem_type) { } } // namespace -AllocatorPtr IExecutionProvider::GetAllocator(int device_id, OrtMemType mem_type) const { - // TODO(leca): ORT_ENFORCE(mem_type != OrtMemType::OrtMemTypeDefault || device_id == GetDeviceId(), "Rule out the case that search on OrtMemTypeDefault but device_id doesn't match GetDeviceId()"); +AllocatorPtr IExecutionProvider::GetAllocator(OrtMemType mem_type) const { + // if mem_type is OrtMemType::OrtMemTypeDefault, it will allocate memory from the current device + // otherwise (mem_type is OrtMemTypeCpu...) it will allocate memory from Cpu as input/output, thus set the device_id + // to 0 as there is only 1 CPU in each machine. 
+ int device_id = GetDeviceId(); + if (mem_type != OrtMemType::OrtMemTypeDefault) device_id = 0; auto iter = allocators_.find(MakeKey(device_id, mem_type)); if (iter != allocators_.end()) { return iter->second; diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index 944a983b56639..2b602703ff75e 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -74,7 +74,7 @@ class ExecutionProviders { const_iterator end() const noexcept { return exec_providers_.cend(); } const AllocatorPtr GetDefaultCpuAllocator() const { - return Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault); + return Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(OrtMemTypeDefault); } OrtMemoryInfo GetDefaultCpuMemoryInfo() const { diff --git a/onnxruntime/core/framework/func_kernel.h b/onnxruntime/core/framework/func_kernel.h index 9bf7a5bff6cd9..c0e13babf068a 100644 --- a/onnxruntime/core/framework/func_kernel.h +++ b/onnxruntime/core/framework/func_kernel.h @@ -27,7 +27,7 @@ class FunctionKernel : public OpKernel { if (compute->create_state_func) { //TODO: we are only provide host allocate method in compute context. //Do we need to hold the ref-counting here? - funckernel->host_allocator_ = info.GetAllocator(0, OrtMemType::OrtMemTypeDefault); + funckernel->host_allocator_ = info.GetAllocator(OrtMemType::OrtMemTypeDefault); ComputeContext context = {allocate_helper_func, release_helper_func, funckernel->host_allocator_.get(), info.node().Name().c_str()}; int ret = funckernel->compute_info_->create_state_func(&context, &funckernel->func_state_); diff --git a/onnxruntime/core/framework/op_kernel.cc b/onnxruntime/core/framework/op_kernel.cc index fd795ef8fbcd4..b1a6ab5291e97 100644 --- a/onnxruntime/core/framework/op_kernel.cc +++ b/onnxruntime/core/framework/op_kernel.cc @@ -21,8 +21,8 @@ const onnxruntime::KernelDef& OpKernel::KernelDef() const { return op_kernel_info_->GetKernelDef(); } -const OrtMemoryInfo& OpKernel::Allocator(int id, OrtMemType mem_type) const { - return op_kernel_info_->GetMemoryInfo(id, mem_type); +const OrtMemoryInfo& OpKernel::Allocator(OrtMemType mem_type) const { + return op_kernel_info_->GetMemoryInfo(mem_type); } OpKernelContext::OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel, @@ -93,7 +93,7 @@ int OpKernelContext::NumVariadicInputs(size_t arg_num) const { } Status OpKernelContext::GetTempSpaceAllocator(AllocatorPtr* output) const { - *output = execution_frame_->GetAllocator(kernel_->Allocator(0, OrtMemTypeDefault)); + *output = execution_frame_->GetAllocator(kernel_->Allocator(OrtMemTypeDefault)); if (!*output) return Status(common::ONNXRUNTIME, common::FAIL, "TempSpace allocator not found"); return Status::OK(); diff --git a/onnxruntime/core/framework/op_kernel_info.cc b/onnxruntime/core/framework/op_kernel_info.cc index 2c150486885b0..e73b64784b162 100644 --- a/onnxruntime/core/framework/op_kernel_info.cc +++ b/onnxruntime/core/framework/op_kernel_info.cc @@ -28,14 +28,14 @@ OpKernelInfo::OpKernelInfo(const OpKernelInfo& other) : OpKernelInfo(other.node_, other.kernel_def_, *other.execution_provider_, other.constant_initialized_tensors_, other.ort_value_name_idx_map_, other.data_transfer_mgr_) {} -const OrtMemoryInfo& OpKernelInfo::GetMemoryInfo(int device_id, OrtMemType mem_type) const { - AllocatorPtr alloc = GetAllocator(device_id, mem_type); +const OrtMemoryInfo& OpKernelInfo::GetMemoryInfo(OrtMemType 
mem_type) const { + AllocatorPtr alloc = GetAllocator(mem_type); if (alloc == nullptr) ORT_THROW("cannot find allocator"); return alloc->Info(); } -AllocatorPtr OpKernelInfo::GetAllocator(int device_id, OrtMemType mem_type) const { - return execution_provider_->GetAllocator(device_id, mem_type); +AllocatorPtr OpKernelInfo::GetAllocator(OrtMemType mem_type) const { + return execution_provider_->GetAllocator(mem_type); } const KernelDef& OpKernelInfo::GetKernelDef() const { diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index facce93cde798..7204683912820 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -100,8 +100,8 @@ void SessionState::SetupAllocators() { } else { // slightly weird indirection to go back to the provider to get the allocator each time it's needed // in order to support scenarios such as the CUDA EP's per-thread allocator. - allocators_[memory_info] = [&provider](int id, OrtMemType mem_type) { - return provider->GetAllocator(id, mem_type); + allocators_[memory_info] = [&provider](OrtMemType mem_type) { + return provider->GetAllocator(mem_type); }; } } @@ -112,7 +112,7 @@ AllocatorPtr SessionState::GetAllocator(const OrtMemoryInfo& location) const noe AllocatorPtr result; auto entry = allocators_.find(location); if (entry != allocators_.cend()) { - result = entry->second(location.id, location.mem_type); + result = entry->second(location.mem_type); } return result; @@ -121,7 +121,7 @@ AllocatorPtr SessionState::GetAllocator(const OrtMemoryInfo& location) const noe AllocatorPtr SessionState::GetAllocator(OrtDevice device) const noexcept { for (const auto& iter : allocators_) { if (iter.first.device == device) { - return iter.second(device.Id(), iter.first.mem_type); + return iter.second(iter.first.mem_type); } } return nullptr; @@ -451,7 +451,7 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapInfo().GetAllocator(0, OrtMemType::OrtMemTypeDefault); + AllocatorPtr session_cpu_alloc = kernel->Info().GetAllocator(OrtMemType::OrtMemTypeDefault); ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, session_cpu_alloc, // use allocator tied to this session is_packed, @@ -1004,7 +1004,7 @@ Status SessionState::CreateSubgraphSessionState() { for (auto& node : graph_.Nodes()) { for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { const auto& ep = node.GetExecutionProviderType(); - if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) { + if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && ep != kRocmExecutionProvider) { // SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow // node containing the subgraph it will create whatever state it needs internally. continue; diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index f2d2eb1d93978..6028c19d3c3ba 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -447,7 +447,7 @@ class SessionState { // for internal allocations by CUDAExecutionProvider::GetScratchBuffer, but could access the per-thread allocator // directly instead of going through CUDAExecutionProvider::GetAllocator. // If that can be validated we could simply store the AllocatorPtr here and get rid of the delegate. 
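The allocator lookups changed throughout the planner, session state and kernel code above all follow the same pattern: the device id argument is gone and IExecutionProvider::GetAllocator(OrtMemType) resolves the device itself (the provider's own device for OrtMemTypeDefault, device 0 for CPU-accessible memory types). A hedged before/after fragment of a typical call site; exec_provider stands for any IExecutionProvider pointer:

// Before this change: the caller supplied a device id that had to agree with the provider.
AllocatorPtr old_style = exec_provider->GetAllocator(exec_provider->GetDeviceId(), OrtMemTypeDefault);

// After this change: only the memory type is passed.
AllocatorPtr device_alloc = exec_provider->GetAllocator(OrtMemTypeDefault);  // provider's own device
AllocatorPtr cpu_alloc = exec_provider->GetAllocator(OrtMemTypeCPUOutput);   // CPU-side memory, device 0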
- std::map, + std::map, OrtMemoryInfoLessThanIgnoreNameAndAllocType> allocators_; diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index f1bb78bf771ab..50442268e797e 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -23,25 +23,31 @@ #ifndef SHARED_PROVIDER #include "core/common/logging/logging.h" #endif +#include "core/framework/execution_provider.h" +#include "core/framework/tuning_context.h" namespace onnxruntime { -namespace tunable { -template +template struct OpParams { - OpParams() : stream{} {} - explicit OpParams(StreamT stream) : stream(stream) {} + OpParams() : tuning_ctx{nullptr}, stream{} {} + OpParams(TuningContextT* tuning_ctx, StreamT stream) : tuning_ctx(tuning_ctx), stream(stream) {} virtual ~OpParams() = default; virtual std::string Signature() const = 0; virtual StreamT Stream() const { return stream; } + virtual TuningContextT* TuningContext() const { return tuning_ctx; } + + // NOTE: the reason of TuningContext does not contains the Stream is that ORT now supports multiple stream and the + // stream may change from call to call. + TuningContextT* tuning_ctx; StreamT stream; }; template -class Timer { +class ITimer { public: - explicit Timer(StreamT stream) : stream_{stream} {} - virtual ~Timer() = default; + explicit ITimer(StreamT stream) : stream_{stream} {} + virtual ~ITimer() = default; virtual void Start() = 0; virtual void End() = 0; @@ -126,37 +132,30 @@ class TunableOp { TunableOp(TunableOp&&) = default; Status operator()(const ParamsT* params) { - int id; - if (tuning_) { - if (kernel_map_.find(params->Signature()) == kernel_map_.end()) { - auto maybe_proxy_params = this->PreTuning(params); + int id = default_id_; + ITuningContext* ctx = params->TuningContext(); + if (ctx->IsTunableOpEnabled()) { + auto& mgr = ctx->GetTuningResultsManager(); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + id = mgr.Lookup(op_sig, params_sig); + if (id > static_cast(ops_.size())) { + LOGS_DEFAULT(ERROR) << "Invalid TunableOp kernel id for " << op_sig + << ", id:" << id << ", registered op:" << ops_.size(); + mgr.Delete(op_sig, params_sig); + id = -1; + } + if (id < 0) { + auto maybe_proxy_params = PreTuning(params); id = FindFastest(maybe_proxy_params); PostTuning(maybe_proxy_params); - kernel_map_.insert({params->Signature(), id}); - } else { - id = kernel_map_[params->Signature()]; + mgr.Add(op_sig, params_sig, id); } - } else { - id = default_id_; } ORT_RETURN_IF_ERROR(ops_[id](params)); return Status::OK(); } - void EnableTuning() { - tuning_ = true; - for (auto nested_op_ptr : nested_tunable_ops_) { - nested_op_ptr->EnableTuning(); - } - } - - void DisableTuning() { - tuning_ = false; - for (auto nested_op_ptr : nested_tunable_ops_) { - nested_op_ptr->DisableTuning(); - } - } - // We might want to do some tricks to the `params`, e.g., some op will use a buffer for input and output at the same // time, so it will do inplace update to it. If we blindly tune over the `params`, there will be accumulated update // to that buffer during FindFastest, which is an undesired side effect. In this case, we must prepare a new (proxy) @@ -184,11 +183,6 @@ class TunableOp { void RegisterNestedTunableOp(TunableOp* op_ptr) { nested_tunable_ops_.insert(op_ptr); - if (tuning_) { - op_ptr->EnableTuning(); - } else { - op_ptr->DisableTuning(); - } // Add an op for this tunable op as well. 
RegisterOp([op_ptr](const ParamsT* params) { @@ -196,6 +190,10 @@ class TunableOp { }); } + std::string Signature() const { + return signature_; + } + private: static void WarmUp(Op& op, const ParamsT* param) { constexpr const int num_iter = 4; @@ -224,30 +222,13 @@ class TunableOp { return true; } - std::string OpSignature() const { -#ifdef ORT_NO_RTTI - ORT_THROW("TunableOp must be built with RTTI enabled"); -#else -#ifndef _WIN32 - const auto* name = typeid(*this).name(); - char buf[256]; - size_t buf_len = 256; - abi::__cxa_demangle(name, buf, &buf_len, nullptr); - buf[255] = '\0'; - return buf; -#else - return typeid(*this).name(); -#endif -#endif - } - protected: virtual int FindFastest(const ParamsT* params) { return FindFastestImpl(params, ops_); } int FindFastestImpl(const ParamsT* params, const std::vector>& candidates) { - auto op_sig = OpSignature(); + auto op_sig = Signature(); auto param_sig = params->Signature(); LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ')'; auto min_time = std::numeric_limits::infinity(); @@ -256,8 +237,7 @@ class TunableOp { for (size_t i = 0; i < candidates.size(); i++) { auto& candidate = const_cast&>(candidates[i]); if (!IsSupported(candidate, params)) { - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl found unsupported " << op_sig - << '(' << param_sig << ") id=" << i; + LOGS_DEFAULT(VERBOSE) << "FindFastestImpl found unsupported " << op_sig << '(' << param_sig << ") id=" << i; continue; } @@ -275,11 +255,27 @@ class TunableOp { } private: - // mapping from Signature to best impl - std::unordered_map kernel_map_; + std::string CreateSignature() { +#ifdef ORT_NO_RTTI + ORT_THROW("TunableOp must be built with RTTI enabled"); +#else +#ifndef _WIN32 + const auto* name = typeid(*this).name(); + char buf[256]; + size_t buf_len = 256; + abi::__cxa_demangle(name, buf, &buf_len, nullptr); + buf[255] = '\0'; + return buf; +#else + return typeid(*this).name(); +#endif +#endif + } + + std::string signature_{CreateSignature()}; + // the default impl to use when tuning is disabled int default_id_{0}; - bool tuning_{false}; std::vector> ops_; @@ -287,5 +283,4 @@ class TunableOp { std::unordered_set*> nested_tunable_ops_; }; -} // namespace tunable } // namespace onnxruntime diff --git a/onnxruntime/core/framework/tuning_context.h b/onnxruntime/core/framework/tuning_context.h new file mode 100644 index 0000000000000..6cd61931b8aaf --- /dev/null +++ b/onnxruntime/core/framework/tuning_context.h @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
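With the rewritten TunableOp above, tuning state no longer lives in a per-instance kernel_map_ and tuning_ flag; instead the params object carries an ITuningContext whose TuningResultsManager is consulted, keyed by the op signature and params signature. A hedged usage sketch follows; MyParams, MyTunableOp and the hipStream_t stream type are illustrative placeholders, and OpParams is assumed to be templated on the tuning-context and stream types as its constructor suggests.

// Illustrative only; not compiled against the real headers.
struct MyParams : OpParams<ITuningContext, hipStream_t> {
  MyParams(ITuningContext* ctx, hipStream_t s, int n) : OpParams(ctx, s), n(n) {}
  std::string Signature() const override { return "n_" + std::to_string(n); }
  int n;
};

Status RunMyOp(ITuningContext* tuning_ctx, hipStream_t stream) {
  static MyTunableOp op;  // hypothetical TunableOp specialization over MyParams
  MyParams params(tuning_ctx, stream, 256);
  // When tuning_ctx->IsTunableOpEnabled(), op() looks the best kernel id up in the
  // TuningResultsManager, tunes and records it on a miss; otherwise it uses default_id_.
  return op(&params);
}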
+ +#pragma once + +#include + +#include "core/common/common.h" +#include "core/platform/ort_mutex.h" +#include "core/framework/tuning_results.h" + +namespace onnxruntime { + +class IExecutionProvider; +class TuningResultsManager; +class TuningResultsValidator; + +class ITuningContext { + public: + explicit ITuningContext(IExecutionProvider* ep) : ep_(ep) {} + virtual ~ITuningContext() = default; + + virtual void EnableTunableOp() = 0; + virtual void DisableTunableOp() = 0; + virtual bool IsTunableOpEnabled() const = 0; + + virtual TuningResultsManager& GetTuningResultsManager() = 0; + virtual const TuningResultsManager& GetTuningResultsManager() const = 0; + + virtual const TuningResultsValidator& GetTuningResultsValidator() const = 0; + + virtual TuningResults GetTuningResults() const; + virtual Status LoadTuningResults(const TuningResults& tr); + + protected: + IExecutionProvider* ep_; +}; + +class TuningResultsManager { + public: + TuningResultsManager() = default; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TuningResultsManager); + + KernelMap Lookup(const std::string& op_signature) const; + int Lookup(const std::string& op_signature, const std::string& params_signature) const; + + void Add(const std::string& op_signature, const std::string& params_signature, int best_id); + void Delete(const std::string& op_signature, const std::string& params_signature); + + void Load(const std::unordered_map& results_to_load); + std::unordered_map Dump() const; + + void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map); + + // Mainly for testing purpose + void Clear(); + + private: + mutable OrtMutex lock_; + std::unordered_map results_; +}; + +class TuningResultsValidator { + public: + using GetFunc = std::function; + using ValidateFunc = std::function; + using GetValidateFuncs = std::unordered_map>; + + TuningResultsValidator(); + virtual ~TuningResultsValidator() = default; + + std::unordered_map GetAllValidators() const; + Status ValidateAll(const std::unordered_map& to_validate) const; + + protected: + void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf); + + virtual std::string GetOrtVersion() const; + virtual Status ValidateOrtVersion(const std::string& value) const; + + virtual std::string GetOrtGitCommit() const; + virtual Status ValidateOrtGitCommit(const std::string& value) const; + + virtual std::string GetOrtBuildConfig() const; + virtual Status ValidateOrtBuildConfig(const std::string& value) const; + + public: + static constexpr const std::array mandatory_keys{"ORT_VERSION", "ORT_GIT_COMMIT", "ORT_BUILD_CONFIG"}; + + private: + GetValidateFuncs validators_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/tuning_context_impl.h b/onnxruntime/core/framework/tuning_context_impl.h new file mode 100644 index 0000000000000..c8b0583e3ea5f --- /dev/null +++ b/onnxruntime/core/framework/tuning_context_impl.h @@ -0,0 +1,285 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file contains the implementation of TuningContext. At the moment, there is no necessity to expose these +// methods as OrtApis. This will cause missing symbols when loading provider dynamic libraries, because the libraries +// are not whole-archive linked and these symbols are not referenced at framework level. To circumvent this problem, +// the EP must has and only has one translation unit include this file. 
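As the note above says, tuning_context_impl.h must be included by exactly one translation unit per provider, with TUNING_CONTEXT_IMPL defined so the guard below it is satisfied. A minimal sketch of such a unit; the file name is hypothetical:

// rocm_tuning_context.cc (hypothetical file name): the single translation unit in an EP
// that pulls in the out-of-line ITuningContext / TuningResultsManager / validator definitions.
#define TUNING_CONTEXT_IMPL
#include "core/framework/tuning_context_impl.h"
#undef TUNING_CONTEXT_IMPL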
+#ifndef TUNING_CONTEXT_IMPL +#error define TUNING_CONTEXT_IMPL to use this header (impl) file +#endif + +#pragma once + +#include +#include +#include + +#include "core/framework/tunable.h" +#include "core/framework/tuning_context.h" +#include "core/framework/tuning_results.h" + +namespace onnxruntime { + +TuningResults ITuningContext::GetTuningResults() const { + TuningResults tr; + tr.ep = ep_->Type(); + tr.validators = GetTuningResultsValidator().GetAllValidators(); + tr.results = GetTuningResultsManager().Dump(); + return tr; +} + +Status ITuningContext::LoadTuningResults(const TuningResults& tr) { + ORT_RETURN_IF(tr.ep != ep_->Type(), "EP mismatch"); + LOGS_DEFAULT(VERBOSE) << "Loading tuning results for " << tr.ep; + ORT_RETURN_IF_ERROR(GetTuningResultsValidator().ValidateAll(tr.validators)); + GetTuningResultsManager().Load(tr.results); + return Status::OK(); +} + +KernelMap TuningResultsManager::Lookup(const std::string& op_signature) const { + std::scoped_lock l{lock_}; + auto it = results_.find(op_signature); + if (it == results_.cend()) { + return {}; + } + return it->second; // copied +} + +// NOLINTNEXTLINE(bugprone-easily-swappable-parameters) +int TuningResultsManager::Lookup(const std::string& op_signature, const std::string& params_signature) const { + std::scoped_lock l{lock_}; + auto kernel_map_it = results_.find(op_signature); + if (kernel_map_it == results_.cend()) { + return -1; + } + + const auto& km = kernel_map_it->second; + auto it = km.find(params_signature); + if (it == km.cend()) { + return -1; + } + return it->second; +} + +inline void AddImpl(const std::string& op_signature, + const std::string& params_signature, + int best_id, + KernelMap& kernel_map) { + auto it = kernel_map.find(params_signature); + if (it != kernel_map.end()) { + if (it->second != best_id) { + LOGS_DEFAULT(WARNING) << op_signature << "(" << params_signature << ") already has a best kernel " + << "id=" << it->second << " selected, want to add a different best kernel id=" << best_id + << ", the new kernel id will be ignored."; + } + return; + } + + LOGS_DEFAULT(VERBOSE) << op_signature << "(" << params_signature << ") -> " << best_id; + kernel_map[params_signature] = best_id; +} + +void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, int best_id) { + std::scoped_lock l{lock_}; + + auto it = results_.find(op_signature); + if (it == results_.end()) { + it = results_.insert({op_signature, {}}).first; + } + + AddImpl(op_signature, params_signature, best_id, it->second); +} + +// NOLINTNEXTLINE(bugprone-easily-swappable-parameters) +void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) { + std::scoped_lock l{lock_}; + + auto it = results_.find(op_signature); + if (it == results_.end()) { + return; + } + + auto it2 = it->second.find(params_signature); + if (it2 == it->second.end()) { + return; + } + + LOGS_DEFAULT(VERBOSE) << op_signature << "(" << params_signature << ")"; + it->second.erase(it2); +} + +std::unordered_map TuningResultsManager::Dump() const { + std::scoped_lock l{lock_}; + return results_; +} + +void DisjointMergeImpl( + const std::string& op_signature, + const KernelMap& kernel_map, + /*out*/ std::unordered_map& results) { + auto it = results.find(op_signature); + if (it == results.end()) { + for(const auto& [param_sig, kernel_id] : kernel_map) { + LOGS_DEFAULT(VERBOSE) << op_signature << "(" << param_sig << ") -> " << kernel_id; + } + results[op_signature] = kernel_map; + return; + 
} + + for (const auto& [params_signature, best_id] : kernel_map) { + AddImpl(op_signature, params_signature, best_id, it->second); + } +} + +void TuningResultsManager::Load(const std::unordered_map& results_to_load) { + std::scoped_lock l{lock_}; + for (const auto& [op_signature, kernel_map] : results_to_load) { + DisjointMergeImpl(op_signature, kernel_map, results_); + } +} + +void TuningResultsManager::DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map) { + std::scoped_lock l{lock_}; + DisjointMergeImpl(op_signature, kernel_map, results_); +} + +void TuningResultsManager::Clear() { + results_ = {}; +} + +static Status CheckMandatoryKeys( + const TuningResultsValidator::GetValidateFuncs& gv_funcs, + const std::unordered_map& to_check) { + bool passed = true; + std::ostringstream oss; + for (const auto& k : TuningResultsValidator::mandatory_keys) { + if (gv_funcs.find(k) == gv_funcs.end()) { + passed = false; + oss << "key=\"" << k << "\" is not registered for Get and Validate. "; + } + + if (to_check.find(k) == to_check.end()) { + passed = false; + oss << "key=\"" << k << "\" is not provided for validation. "; + } + } + ORT_RETURN_IF(!passed, oss.str()); + return Status::OK(); +} + +static Status CheckKeysMatching( + const TuningResultsValidator::GetValidateFuncs& gv_funcs, + const std::unordered_map& to_check) { + auto get_keys = [](const auto& it) -> std::string { return it.first; }; + std::vector required_keys; + std::vector provided_keys; + std::transform(gv_funcs.cbegin(), gv_funcs.cend(), std::back_inserter(required_keys), get_keys); + std::transform(to_check.cbegin(), to_check.cend(), std::back_inserter(provided_keys), get_keys); + std::sort(required_keys.begin(), required_keys.end()); + std::sort(provided_keys.begin(), provided_keys.end()); + + std::unordered_set intersection; + std::set_intersection(required_keys.cbegin(), required_keys.cend(), + provided_keys.cbegin(), provided_keys.cend(), + std::inserter(intersection, intersection.end())); + bool matched = true; + std::ostringstream oss; + if (intersection.size() != required_keys.size()) { + matched = false; + for (const auto& k : required_keys) { + if (intersection.find(k) == intersection.end()) { + oss << "Unmatched validator: \"" << k << "\" is required, but the tuning results does not provide it. "; + } + } + } + if (intersection.size() != provided_keys.size()) { + matched = false; + for (const auto& k : provided_keys) { + if (intersection.find(k) == intersection.end()) { + oss << "Unmatched validator: \"" << k << "\" is provided, but onnxruntime is unable to consume it. 
"; + } + } + } + ORT_RETURN_IF(!matched, oss.str()); + return Status::OK(); +} + +std::string TuningResultsValidator::GetOrtVersion() const { + return ORT_VERSION; +} + +Status TuningResultsValidator::ValidateOrtVersion(const std::string& value) const { + ORT_RETURN_IF(value != ORT_VERSION, "onnxruntime version mismatch"); + return Status::OK(); +} + +std::string TuningResultsValidator::GetOrtGitCommit() const { + // TODO: + return ""; +} + +Status TuningResultsValidator::ValidateOrtGitCommit(const std::string& value) const { + // TODO: + ORT_UNUSED_PARAMETER(value); + return Status::OK(); +} + +std::string TuningResultsValidator::GetOrtBuildConfig() const { + return ""; +} + +Status TuningResultsValidator::ValidateOrtBuildConfig(const std::string& value) const { + auto current = GetOrtBuildConfig(); + ORT_RETURN_IF(current != value, + "onnxruntime building configuration mismatch: tuning results produced with library \"", + value, "\", current library built with \"", current, "\""); + return Status::OK(); +} + +TuningResultsValidator::TuningResultsValidator() { + RegisterValidator( + "ORT_VERSION", + [this]() { return GetOrtVersion(); }, + [this](auto&& k) { return ValidateOrtVersion(std::forward(k)); }); + + RegisterValidator( + "ORT_GIT_COMMIT", + [this]() { return GetOrtGitCommit(); }, + [this](auto&& k) { return ValidateOrtGitCommit(std::forward(k)); }); + + RegisterValidator( + "ORT_BUILD_CONFIG", + [this]() { return GetOrtBuildConfig(); }, + [this](auto&& k) { return ValidateOrtBuildConfig(std::forward(k)); }); +} + +Status TuningResultsValidator::ValidateAll(const std::unordered_map& to_validate) const { + ORT_RETURN_IF_ERROR(CheckMandatoryKeys(validators_, to_validate)); + ORT_RETURN_IF_ERROR(CheckKeysMatching(validators_, to_validate)); + + for (const auto& [key, value] : to_validate) { + const auto& it = validators_.find(key); + ORT_ENFORCE(it != validators_.cend()); + const ValidateFunc& validator = it->second.second; + ORT_RETURN_IF_ERROR(validator(value)); + } + + return Status::OK(); +} + +std::unordered_map TuningResultsValidator::GetAllValidators() const { + std::unordered_map ret; + for (const auto& [key, get_validate_func_pair] : validators_) { + const GetFunc& getter = get_validate_func_pair.first; + ret[key] = getter(); + } + return ret; +} + +void TuningResultsValidator::RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf) { + ORT_ENFORCE(validators_.find(key) == validators_.end()); + validators_[key] = std::make_pair(gf, vf); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/tuning_results.h b/onnxruntime/core/framework/tuning_results.h new file mode 100644 index 0000000000000..1a32d81f29908 --- /dev/null +++ b/onnxruntime/core/framework/tuning_results.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include + +namespace onnxruntime { + +// Mapping from params signature to kernel id +using KernelMap = std::unordered_map; + +struct TuningResults { + std::string ep; + + // Validates if these results are compatible with the libraries, the validation process is EP defined + std::unordered_map validators; + + // Mapping from Op signature to Op's tuning result + std::unordered_map results; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 79691d7b516dc..f88d098454479 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -20,6 +20,8 @@ #include "core/framework/tensorprotoutils.h" #include "core/mlas/inc/mlas.h" #include "core/framework/TensorSeq.h" +#include "core/framework/run_options.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #ifdef USE_AZURE #include "core/framework/cloud_executor.h" #endif @@ -793,13 +795,14 @@ common::Status ExecuteGraph(const SessionState& session_state, logger); } #endif + bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0"; return ExecuteGraph(session_state, feeds_fetches_manager, feeds, fetches, execution_mode, run_options.terminate, logger, - run_options.synchronize_execution_providers, + synchronize_execution_providers, run_options.only_execute_path_to_fetches); } diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index b4ad4d64e7ddb..580a0a993454a 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -126,33 +126,65 @@ void RestorePaddingTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) } void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { - // Input 0 (query) has shape (batch_size, sequence_length, hidden_size) - // Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) - // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) // Output 0 has shape (batch_size, sequence_length, v_hidden_size) + // Q, K and V without packing: + // Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + // Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) + // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) + + // Packed KV: + // Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + // Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size) + // Input 2 nullptr + + // Packed QKV: + // Input 0 (batch_size, sequence_length, num_heads, 3, head_size) + // Input 1 nullptr + // Input 2 nullptr + // Type inference ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); // Shape inference - if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { + if (hasInputShape(ctx, 0)) { auto& query_shape = getInputShape(ctx, 0); auto& query_dims = query_shape.dim(); - if (query_dims.size() != 3) { - fail_shape_inference("Inputs 0 (query) shall be 3 dimensions"); + + if (query_dims.size() != 3 && query_dims.size() != 5) { + fail_shape_inference("Inputs 0 (query) shall be 3 or 5 dimensions"); } - auto& value_shape = getInputShape(ctx, 2); - auto& value_dims = value_shape.dim(); - if (value_dims.size() != 3) { - fail_shape_inference("Inputs 2 (value) shall be 3 dimensions"); + if (query_dims.size() == 5) { // packed QKV + 
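Stepping back to the ExecuteGraph change above: the per-run synchronization switch now comes from the RunOptions configuration map (key kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, default "0", i.e. synchronize) instead of a dedicated RunOptions field. A hedged sketch of how a caller might opt out for a single run, assuming the C++ API's RunOptions::AddConfigEntry and the public run-options key header:

#include "onnxruntime_cxx_api.h"
#include "onnxruntime_run_options_config_keys.h"

// Disable execution-provider synchronization for one Run() call.
// "1" disables the sync; the default "0" keeps the previous behaviour.
Ort::RunOptions MakeNoSyncRunOptions() {
  Ort::RunOptions run_options;
  run_options.AddConfigEntry(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "1");
  return run_options;
}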
ONNX_NAMESPACE::TensorShapeProto output_shape; + *output_shape.add_dim() = query_dims[0]; + *output_shape.add_dim() = query_dims[1]; + *output_shape.add_dim() = query_dims[2] * query_dims[4]; + updateOutputShape(ctx, 0, output_shape); + return; } - ONNX_NAMESPACE::TensorShapeProto output_shape; - *output_shape.add_dim() = query_dims[0]; - *output_shape.add_dim() = query_dims[1]; - *output_shape.add_dim() = value_dims[2]; - updateOutputShape(ctx, 0, output_shape); + if (hasInputShape(ctx, 2)) { + auto& value_shape = getInputShape(ctx, 2); + auto& value_dims = value_shape.dim(); + if (value_dims.size() != 3) { + fail_shape_inference("Inputs 2 (value) shall be 3 dimensions"); + } + + ONNX_NAMESPACE::TensorShapeProto output_shape; + *output_shape.add_dim() = query_dims[0]; + *output_shape.add_dim() = query_dims[1]; + *output_shape.add_dim() = value_dims[2]; + updateOutputShape(ctx, 0, output_shape); + return; + } + + if (hasInputShape(ctx, 1)) { + auto& key_shape = getInputShape(ctx, 1); + if (key_shape.dim().size() == 5) { // packed KV + ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx); + } + } } } @@ -234,7 +266,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T", OpSchema::Optional) .Input(5, - "extra_add", + "relative_position_bias", "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)", "T", OpSchema::Optional) @@ -283,25 +315,34 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::FLOAT, OPTIONAL_VALUE) .Input(0, "query", - "Query with shape (batch_size, sequence_length, hidden_size)", + "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)", "T") .Input(1, "key", - "Key with shape (batch_size, kv_sequence_length, hidden_size)", - "T") + "Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)", + "T", + OpSchema::Optional) .Input(2, "value", "Value with shape (batch_size, kv_sequence_length, v_hidden_size)", - "T") + "T", + OpSchema::Optional) .Input(3, "bias", "Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection", - "T") + "T", + OpSchema::Optional) .Input(4, "key_padding_mask", "Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)", "M", OpSchema::Optional) + .Input(5, + "relative_position_bias", + "relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)" + " or (1, num_heads, sequence_length, total_sequence_length)", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, v_hidden_size)", @@ -657,5 +698,41 @@ ONNX_MS_OPERATOR_SET_SCHEMA( RestorePaddingTypeAndShapeInference(ctx); })); +constexpr const char* GatedRelativePositionBias_ver1_doc = R"DOC( + query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2) + gate_u, gate_r = torch.sigmoid( + self.gate_ur_linear(query_layer).view(batch_size, num_head, seq_len, 2, D/2).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0 + rel_pos_bias = gate_u_1 * rel_pos +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + GatedRelativePositionBias, 1, + OpSchema() + .SetDoc(GatedRelativePositionBias_ver1_doc) + .Attr("num_heads", "Number of attention heads", AttributeProto::INT) + .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T") + 
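To make the MultiHeadAttention layouts above concrete, the shape-inference rule can be summarized by a small standalone helper: a rank-5 packed-QKV query produces (batch, seq, num_heads * head_size), plain Q/K/V takes the last dimension from value, and packed KV falls back to propagating the query shape. The helper and the sample numbers below are illustrative only.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Output-shape rule from the MultiHeadAttention type/shape inference above.
std::vector<int64_t> MhaOutputShape(const std::vector<int64_t>& query_shape,
                                    const std::vector<int64_t>& value_shape) {
  if (query_shape.size() == 5)  // packed QKV: (B, S, num_heads, 3, head_size)
    return {query_shape[0], query_shape[1], query_shape[2] * query_shape[4]};
  assert(query_shape.size() == 3);
  if (value_shape.size() == 3)  // plain Q/K/V: take v_hidden from value
    return {query_shape[0], query_shape[1], value_shape[2]};
  return query_shape;           // packed KV: output matches the query shape
}

int main() {
  const auto s = MhaOutputShape({2, 8, 12, 3, 64}, {});  // packed QKV, 12 heads of size 64
  std::printf("[%lld, %lld, %lld]\n",
              static_cast<long long>(s[0]), static_cast<long long>(s[1]),
              static_cast<long long>(s[2]));  // prints [2, 8, 768]
  return 0;
}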
.Input(1, "query_bias", "1-d tensor with shape (num_heads x head_size)", "T") + .Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T") + .Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T") + .Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T") + .Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T") + .Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + int64_t num_heads = getAttribute(ctx, "num_heads", -1L); + if (hasInputShape(ctx, 0)) { + auto& query_layer_shape = getInputShape(ctx, 0); + TensorShapeProto output_shape; + *output_shape.add_dim() = query_layer_shape.dim(0); + output_shape.add_dim()->set_dim_value(num_heads); + *output_shape.add_dim() = query_layer_shape.dim(1); + *output_shape.add_dim() = query_layer_shape.dim(1); + updateOutputShape(ctx, 0, output_shape); + } + })); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc new file mode 100644 index 0000000000000..167b80238a3d6 --- /dev/null +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/contrib_ops/contrib_defs.h" +#include "core/graph/constants.h" + +namespace onnxruntime { +namespace contrib { + +using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::InferenceContext; +using ONNX_NAMESPACE::OpSchema; +using ONNX_NAMESPACE::OPTIONAL_VALUE; +using ONNX_NAMESPACE::TypeProto; + +void RegisterCollectiveOps() { + ONNX_CONTRIB_OPERATOR_SCHEMA(AllReduce) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "input", "tensors to be reduced", "T", OpSchema::Variadic) + .Output(0, "output", "reduced tensors", "T", OpSchema::Variadic) + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain to float, float16 and double tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateShapeAndTypeFromFirstInput(ctx); + }); + + ONNX_CONTRIB_OPERATOR_SCHEMA(AllGather) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("group_size", + "total size in the group that need to be gathered.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "tensors to be sent", "T", OpSchema::Variadic) + .Output(0, "output", "gathered tensors", "T", OpSchema::Variadic) + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain to float, float16 and double tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + auto group_size = getAttribute(ctx, "group_size", 1); + assert(group_size >= static_cast(1)); + // propagate type for output + propagateElemTypeFromInputToOutput(ctx, 0, 0); + + // propagate shape for output. + // output shape is [group_size * input_shape[0], ...] 
+ auto output_type = ctx.getOutputType(0); + auto input_type = ctx.getInputType(0); + if (hasShape(*input_type)) { + auto shape = input_type->tensor_type().shape(); + auto dim = shape.dim(0) * group_size; + *shape.mutable_dim(0) = dim; + *output_type->mutable_tensor_type()->mutable_shape() = shape; + } + }); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index a8c870d1442cf..e80841ef63543 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2708,6 +2708,10 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t RegisterNchwcSchemas(); } #endif + +#ifdef USE_MPI + RegisterCollectiveOps(); +#endif } } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.h b/onnxruntime/core/graph/contrib_ops/contrib_defs.h index 7d70c708a9c7b..4c24b284c6ddb 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.h +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.h @@ -13,7 +13,7 @@ #define ONNX_MS_OPERATOR_SET_SCHEMA(name, ver, impl) \ ONNX_OPERATOR_SET_SCHEMA_EX(name, Microsoft, ::onnxruntime::kMSDomain, ver, true, impl) -//They are in ONNX domain but they are in our source code +// They are in ONNX domain but they are in our source code #define ONNX_CONTRIB_OPERATOR_SET_SCHEMA(name, ver, impl) \ ONNX_OPERATOR_SET_SCHEMA_EX(name, Onnx, ::ONNX_NAMESPACE::ONNX_DOMAIN, ver, true, impl) @@ -29,7 +29,7 @@ inline bool HasRawData(const ONNX_NAMESPACE::TensorProto& ten_proto) { return ten_proto.data_type() != ONNX_NAMESPACE::TensorProto::UNDEFINED && ten_proto.has_raw_data(); // XXX: Figure out how to do in proto3 } -} +} // namespace utils #define ONNX_CONTRIB_OPERATOR_SCHEMA(name) \ ONNX_CONTRIB_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name) @@ -53,6 +53,10 @@ void RegisterContribSchemas(); void RegisterNchwcSchemas(); void RegisterQuantizationSchemas(); +#if defined(USE_MPI) +void RegisterCollectiveOps(); +#endif + constexpr const float kDefaultSkipLayerNormEpsilon = 1e-12f; constexpr const float kDefaultEmbedLayerNormEpsilon = 1e-12f; } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc new file mode 100644 index 0000000000000..c6d3db7fbe6da --- /dev/null +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/constants.h" +#include "core/graph/contrib_ops/contrib_defs.h" +#include "core/graph/contrib_ops/onnx_function_util.h" +#include "core/graph/contrib_ops/shape_inference_functions.h" + +// Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from +// ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build +#if defined(_WIN32) && !defined(NDEBUG) +#pragma warning(disable : 26426) +#endif + +namespace onnxruntime { +namespace contrib { +using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::OpSchema; +using ONNX_NAMESPACE::TensorShapeProto; +#ifndef NDEBUG +using ONNX_NAMESPACE::DbgOperatorSetTracker; +#endif + +constexpr const char* GroupNorm_ver1_doc = R"DOC( +Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization (https://arxiv.org/abs/1803.08494). 
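The AllGather inference above scales only the leading dimension by group_size; a tiny standalone example of the rule:

#include <cstdint>
#include <cstdio>
#include <vector>

// AllGather output shape per the schema above: out[0] = group_size * in[0], the rest unchanged.
std::vector<int64_t> AllGatherOutputShape(std::vector<int64_t> input_shape, int64_t group_size) {
  input_shape[0] *= group_size;
  return input_shape;
}

int main() {
  const auto out = AllGatherOutputShape({4, 1024}, 8);  // e.g. 8 ranks gathering a (4, 1024) tensor
  std::printf("[%lld, %lld]\n", static_cast<long long>(out[0]), static_cast<long long>(out[1]));  // [32, 1024]
  return 0;
}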
+ +This operator transforms input according to + y = gamma * (x - mean) / sqrt(variance + epsilon) + beta + +The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. num_channels must be divisible by num_groups. The mean and standard-deviation are calculated separately over the each group. +The weight and bias are per-channel affine transform parameter vectors of size num_channels. + +The activation attribute can be used to enable activation after group normalization. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + GroupNorm, 1, + OpSchema() + .SetDoc(GroupNorm_ver1_doc) + .Attr("epsilon", "The epsilon value to use to avoid division by zero", AttributeProto::FLOAT, static_cast(1e-5)) + .Attr("groups", + "The number of groups of channels. It should be a divisor of the number of channels C", + AttributeProto::INT) + .Attr("activation", + "Activation after group normalization: 0 for None, 1 for Swish", + AttributeProto::INT) + .Input(0, + "X", + "Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data", + "T") + .Input(1, + "gamma", + "1D gamma tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(2, + "beta", + "1D beta tensor for normalization with shape (C), where C is number of channels", + "M") + .Output(0, + "Y", + "The output tensor of the same shape as X", + "T") + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.") + .TypeConstraint("M", {"tensor(float)"}, "Constrain gamma and beta to float tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + +constexpr const char* BiasSplitGelu_ver1_doc = R"DOC( +A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left +tensor multiplies the Gelu activation result of right tensor. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + BiasSplitGelu, 1, + OpSchema() + .SetDoc(BiasSplitGelu_ver1_doc) + .Input(0, + "X", + "Input tensor. Dimensions are (N, S, D), where N is the batch size, S are image size, and D is hidden dimension", + "T") + .Input(1, + "bias", + "Bias tensor. Dimensions are (D), where D is the same hidden dimension as input tensor", + "T") + .Output(0, + "Y", + "The output tensor with dimensions (N, S, D/2)", + "T") + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { + auto& input_shape = getInputShape(ctx, 0); + if (input_shape.dim().size() != 3) { + fail_shape_inference("input shall be 3 dimensions"); + } + + auto& bias_shape = getInputShape(ctx, 1); + if (bias_shape.dim().size() != 1) { + fail_shape_inference("bias shall be 1 dimension"); + } + + TensorShapeProto output_shape; + *output_shape.add_dim() = input_shape.dim(0); + *output_shape.add_dim() = input_shape.dim(1); + if (bias_shape.dim(0).has_dim_value()) { + output_shape.add_dim()->set_dim_value(bias_shape.dim(0).dim_value() / 2); + } else { + output_shape.add_dim(); + } + + updateOutputShape(ctx, 0, output_shape); + } + })); + +constexpr const char* BiasAdd_ver1_doc = R"DOC( +Add input with bias, then add residual inputs. 
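+The computation is equivalent to Y = X + bias + skip, where the per-channel bias is broadcast over the batch and spatial dimensions.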
+)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + BiasAdd, 1, + OpSchema() + .SetDoc(BiasAdd_ver1_doc) + .Input(0, + "X", + "Input tensor. Dimensions are (N, S, C), where N is the batch size, S is image size H*W, and C is number of channels", + "T") + .Input(1, + "bias", + "Bias tensor. Dimensions are (C)", + "T") + .Input(2, + "skip", + "Residual tensor. Dimensions are (N, S, C)", + "T") + .Output(0, + "Y", + "The output tensor with dimensions (N, S, C)", + "T") + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 1f0af31a4bdd0..548f0a7ecc353 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -49,6 +49,8 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BeamSearch); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasDropout); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BitmaskBiasDropout); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasGelu); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSplitGelu); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasAdd); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSoftmax); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BifurcationDetector); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CDist); @@ -69,6 +71,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Gelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QuickGelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GreedySearch); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GridSample); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupNorm); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Inverse); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Irfft); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite); @@ -79,6 +82,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatedRelativePositionBias); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); @@ -135,6 +139,8 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -155,6 +161,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -167,6 +174,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 6111afbd5d817..91e4f5d8ff81a 100644 --- 
a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -1140,7 +1140,7 @@ where value of each element is the end position, or valid length of actual seque left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past and present state are optional. Present state could appear in output even when past state is not in input. -Current version does not support past/present, extra_add and qkv_hidden_sizes. +Current version does not support past/present, relative_position_bias and qkv_hidden_sizes. TODO: Support them if needed in the future. )DOC"; @@ -1202,7 +1202,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(18, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "Q", OpSchema::Optional) - .Input(19, "extra_add", + .Input(19, "relative_position_bias", "additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).", "S", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "Q") diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 5b6756e4fb90b..07757917de59d 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -90,7 +90,7 @@ typedef enum { CblasLeft=141, CblasRight=142} CBLAS_SIDE; #endif // -// Forward declare the thread pool implementation class. +// Forward declare the thread pool implementation class and half precision floating point. // // N.B. Avoid including ONNX Runtime headers here to keep the dependencies for // standalone MLAS test executables smaller. @@ -100,10 +100,12 @@ namespace onnxruntime { namespace concurrency { class ThreadPool; }; -}; + struct MLFloat16; +}; // namespace onnxruntime using MLAS_THREADPOOL = onnxruntime::concurrency::ThreadPool; + // // Platform routines. // @@ -613,7 +615,7 @@ MlasGemm( // Currently only supported in ARM64 // #if defined(MLAS_TARGET_ARM64) -constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 15; +constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 30; #else constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 0; #endif @@ -1085,7 +1087,8 @@ MLASCALL MlasReorderOutputNchw( const int64_t* OutputShape, const float* S, - float* D + float* D, + MLAS_THREADPOOL* ThreadPool ); void @@ -1366,3 +1369,173 @@ MlasQLinearMul( size_t N, bool IsScalarB ); + +// +// Half precision routines +// + +// Any type with size=2 should work +using MLAS_FP16 = onnxruntime::MLFloat16; + +constexpr size_t FP16_SIZE = sizeof(uint16_t); + + +bool MLASCALL +MlasFp16AccelerationSupported(); + +/** + * @brief Interface for half gemm post processors. + * + * Example implementation of this interface includes activations, + * conversion from half precision to single precision, etc. + * + * Half GEMM is computed tile by tile. When a tile of result matrix + * is produced, the method Process() is called to process this tile. + * Parameters of this method describe the location and shape of the + * tile. 
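+ *
+ * MLAS_HALF_GEMM_2FLOAT_PROCESSOR below is one such implementation: for each
+ * finished tile it receives the tile origin (StartM, StartN), the tile size
+ * (CountM, CountN) and the leading dimension ldc, and converts that tile from
+ * half precision to single precision into a separate row-major buffer.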
+*/ +class MLAS_HALF_GEMM_POSTPROCESSOR { +public: + virtual + void + Process( + MLAS_FP16*, /**< the address of matrix to process */ + size_t, /**< the start row index of matrix */ + size_t, /**< the start col index of matrix */ + size_t, /**< the element count per row to process */ + size_t, /**< the element count per col to process */ + size_t /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_HALF_GEMM_POSTPROCESSOR() {} +}; + +/** + * @brief Convert half gemm result matrix to single precision float matrix +*/ +class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR { +public: + MLAS_HALF_GEMM_2FLOAT_PROCESSOR( + float* Output, /**< address of the output matrix, row major */ + size_t RowStride /**< row stride of the output matrix */ + ) : + Output_(Output), + RowStride_(RowStride) + {} + + void + Process( + MLAS_FP16* C, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc + ) const override; + +private: + float* Output_; + size_t RowStride_; +}; + + +/** + * @brief Data parameters for half precision GEMM routine + * All except C are [in] parameters +*/ +struct MLAS_HALF_GEMM_DATA_PARAMS { + const void* A = nullptr; /**< address of A */ + const void* B = nullptr; /**< address of B */ + const MLAS_FP16* Bias = nullptr; /**< address of Bias, vector size N */ + MLAS_FP16* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_HALF_GEMM_POSTPROCESSOR* OutputProcessor = nullptr; + bool AIsfp32 = false; /**< matrix A is fp32, needs to be casted into fp16*/ + bool BIsfp32 = false; /**< matrix B is fp32, needs to be casted into fp16*/ +}; + +/** + * @brief Half precision Batched GEMM: C = A * B + Bias + * Either A or B can be fp32 or fp16 + * + * Note: We only support uniform batching, so shapes and types of the + * input must be same across all parameter blocks. 
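+ *
+ * Minimal single-batch call sketch (buffer names are illustrative; allocation
+ * and error handling omitted; A, B and C are row-major fp16 here, so the
+ * AIsfp32/BIsfp32 defaults of false apply):
+ *
+ *     MLAS_HALF_GEMM_DATA_PARAMS params;
+ *     params.A = A16;     // M x K
+ *     params.lda = K;
+ *     params.B = B16;     // K x N
+ *     params.ldb = N;
+ *     params.C = C16;     // M x N result
+ *     params.ldc = N;
+ *     MlasHalfGemmBatch(M, N, K, 1, &params, ThreadPool);  // BatchN == 1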
+ * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return +*/ +void +MLASCALL +MlasHalfGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool = nullptr + ); + +/** + * @brief For half precision GEMM, returns size of the + * packing buffer needed for right hand side + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] float2half Whether the input is float that + * needs to be converted to half precision + * @return size of the packing buffer, + * 0 if operation not supported +*/ +size_t +MLASCALL +MlasHalfGemmPackBSize( + size_t N, + size_t K, + bool float2half + ); + +/** + * @brief For half precision GEMM, pack the right hand + * side matrix B + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix +*/ +void +MLASCALL +MlasHalfGemmPackB( + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB + ); + +/** + * @brief For half precision GEMM, convert the float matrix B + * to half precision and pack it into a packing buffer + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix +*/ +void +MLASCALL +MlasHalfGemmConvertPackB( + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB + ); diff --git a/onnxruntime/core/mlas/inc/mlas_float16.h b/onnxruntime/core/mlas/inc/mlas_float16.h new file mode 100644 index 0000000000000..33227ea90d6be --- /dev/null +++ b/onnxruntime/core/mlas/inc/mlas_float16.h @@ -0,0 +1,115 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_float16.h + +Abstract: + + Utilities for half precision floating type conversions. Used internally + by MLAS on platforms without half precision support. Provided here as + convenience for tests or other client libraries/apps. + +--*/ + +#pragma once + +#include +#include +#include + + +using _mlas_fp16_ = uint16_t; + +union fp32_bits { + uint32_t u; + float f; +}; + +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) + +/*PreFast told us to convert them to constexpr but the compiler says we can't.*/ +#pragma warning(disable : 26497) + +/*Added whole bunch of casts, still can't get rid of these overflow warnings.*/ +#pragma warning(disable : 26450) +#pragma warning(disable : 26451) +#endif + +inline +_mlas_fp16_ +MLAS_Float2Half(float ff) +{ + constexpr fp32_bits f32infty = {255 << 23}; + constexpr fp32_bits f16max = {(127 + 16) << 23}; + constexpr fp32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; + constexpr uint32_t sign_mask = 0x80000000u; + + auto val = static_cast(0x0u); + fp32_bits f; + f.f = ff; + + uint32_t sign = f.u & sign_mask; + f.u ^= sign; + + if (f.u >= f16max.u) { + // Inf or NaN (all exponent bits set) + val = (f.u > f32infty.u) ? 
0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf + } else { + if (f.u < (113 << 23)) { + // Subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + f.f += denorm_magic.f; + + // and one integer subtract of the bias later, we have our final float! + val = static_cast(f.u - denorm_magic.u); + } else { + uint32_t mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + f.u += ((uint32_t)(15 - 127) << 23) + 0xfff; + // rounding bias part 2 + f.u += mant_odd; + // take the bits! + val = static_cast(f.u >> 13); + } + } + + val |= static_cast(sign >> 16); + return val; +} + +inline +float +MLAS_Half2Float(_mlas_fp16_ val) +{ + constexpr fp32_bits magic = {113 << 23}; + constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift + fp32_bits o; + + o.u = (val & 0x7fff) << 13; // exponent/mantissa bits + uint32_t exp = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.u += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // renormalize + } + + o.u |= (val & 0x8000) << 16; // sign bit + return o.f; +} + +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S new file mode 100644 index 0000000000000..036928d21b8ca --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S @@ -0,0 +1,550 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + HalfGemmKernelNeon.s + +Abstract: + + This module implements the kernels for the half precision matrix/matrix + multiply operation (HALF GEMM). + +--*/ + +#include "asmmacro.h" + +// +// Stack frame layout for the half gemm kernel. +// Callee save registers: d8-d15, x19-x30. x18 is reserved by the OS. +// + .equ .LHGemmKernelFrame_SavedRegs, (2 * 8) + .equ .LHGemmKernelFrame_B, 0 + .LHGemmKernelFrame_SavedRegs + .equ .LHGemmKernelFrame_ldb, 8 + .LHGemmKernelFrame_SavedRegs + .equ .LHGemmKernelFrame_ZeroMode, 16 + .LHGemmKernelFrame_SavedRegs + + .text + +/*++ + +Routine Description: + + This routine is an inner kernel to compute 6 rows of GEMM + +Arguments: + + CountM - (x0) the number of rows for matrix A and matrix C. + only process 6 rows + + CountN - (x1) the number of columns from matrix B and matrix C + + CountK - (x2/x0) the number of columns from matrix A and the + number of rows from matrix B. + + C - (x3) the address of matrix C. + + ldc - (x4) - the first dimension of matrix C. + + Bias - (x5) - the address of the Bias vector (optional) + + A - (x6) - the address of matrix A + + lda - (x7) - the first dimension of matrix A + + B - the address of matrix B + + ldb - the first dimension of matrix B + + ZeroMode - true if the output matrix must be zero initialized, else + if the output matrix is accumulated into + +--*/ + + FUNCTION_ENTRY MlasHalfGemmKernelNeon + + str x19,[sp,#-.LHGemmKernelFrame_SavedRegs]! 
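+        // The pre-indexed store above moves sp down by
+        // .LHGemmKernelFrame_SavedRegs (16 bytes), so the incoming stack
+        // arguments B, ldb and ZeroMode are read at offsets biased by that
+        // amount, matching the .LHGemmKernelFrame_* constants defined above.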
+ ldr x9,[sp,#.LHGemmKernelFrame_ldb] + lsl x2,x2,#1 // k *= sizeof(fp16) + cmp x0,2 + add x14,x6,x7,lsl #1 // a1 = a0 + lda + add x10,x3,x4,lsl #1 // c1 = c0 + ldc + ldr x8,[sp,#.LHGemmKernelFrame_B] + csel x14,x6,x14,LO // M < 2 ? a1 = a0 + csel x10,x3,x10,LO // c1 = c0 + add x15,x14,x7,lsl #1 // a2 = a1 + lda + add x11,x10,x4,lsl #1 // c2 = c1 + ldc + csel x15,x14,x15,LS // M <= 2 ? a2 = a1 + csel x11,x10,x11,LS // c2 = c1 + cmp x0,4 + add x16,x15,x7,lsl #1 // a3 = a2 + lda + add x12,x11,x4,lsl #1 // c3 = c2 + ldc + csel x16,x15,x16,LO // M < 4 ? a3 = a2 + csel x12,x11,x12,LO // c3 = c2 + add x17,x16,x7,lsl #1 // a4 = a3 + lda + add x13,x12,x4,lsl #1 // c4 = c3 + ldc + csel x17,x16,x17,LS // M <= 4 ? a4 = a3 + csel x13,x12,x13,LS // c4 = c3 + cmp x0,6 + add x7,x17,x7,lsl #1 // a5 = a4 + lda + add x4,x13,x4,lsl #1 // c5 = c4 + ldc + csel x7,x17,x7,LO // M < 6 ? a5 = a4 + csel x4,x13,x4,LO // c5 = c4 + lsl x9,x9,#1 // ldb *= sizeof(fp16) + ldrb w19,[sp,#.LHGemmKernelFrame_ZeroMode] + sub x9,x9,16 // ldb -= 16 + +/**** +Main loop processes 6x16 tile, depth 4. + B 4x16 + --------------------------------------- + |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 + |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 + |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 + |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 + A 6x4 --------------------------------------- + ------------------ --------------------------------------- +x6 |v0.h[0]..v0.h[3]| |v20.h[0]..v20.h[7] v21.h[0]..v21.h[7]| x3 +x14 |v1.h[0]..v1.h[3]| |v22.h[0]..v22.h[7] v23.h[0]..v23.h[7]| x10 +x15 |v2.h[0]..v2.h[3]| |v24.h[0]..v24.h[7] v25.h[0]..v25.h[7]| x11 +x16 |v3.h[0]..v3.h[3]| |v26.h[0]..v26.h[7] v27.h[0]..v27.h[7]| x12 +x17 |v4.h[0]..v4.h[3]| |v28.h[0]..v28.h[7] v29.h[0]..v29.h[7]| x13 +x7 |v5.h[0]..v5.h[3]| |v30.h[0]..v30.h[7] v31.h[0]..v31.h[7]| x4 + ------------------ --------------------------------------- +****/ + +.LM6N16OutterLoopN: + cbz x5,.LM6N16SkipBias + ldp q20,q21,[x5],32 // Load 16 Bias values + b .LM6N16PopulateAccumulators + +.LM6N16SkipBias: + eor v20.16b,v20.16b,v20.16b // No bias, reset regs + eor v21.16b,v21.16b,v21.16b + +.LM6N16PopulateAccumulators: + mov v22.16b,v20.16b + mov v23.16b,v21.16b + mov v24.16b,v20.16b + mov v25.16b,v21.16b + mov v26.16b,v20.16b + mov v27.16b,v21.16b + mov v28.16b,v20.16b + subs x0,x2,8 // k -= 4 (8 bytes) + mov v29.16b,v21.16b + mov v30.16b,v20.16b + mov v31.16b,v21.16b + b.LO .LM6N16RemainderK123 // remaining k 1~3 + + ldr d0,[x6],8 // A0 + ldr q16,[x8],16 // B0.l + ld1 {v17.16b},[x8],x9 // B0.high x8 <- next row + subs x0,x0,8 // over decement k -= 4 (8 bytes) + ldr d1,[x14],8 // A1 + ldr d2,[x15],8 // A2 + ldr d3,[x16],8 // A3 + b.LO .LM6N16LoopK_Epilogue // need k>=8 for main loop + +.LM6N16InnerLoopK: + fmla v20.8h,v16.8h,v0.h[0] + fmla v21.8h,v17.8h,v0.h[0] + ldr d4,[x17],8 // A4 + fmla v22.8h,v16.8h,v1.h[0] + fmla v23.8h,v17.8h,v1.h[0] + ldr d5,[x7],8 // A5 + fmla v24.8h,v16.8h,v2.h[0] + fmla v25.8h,v17.8h,v2.h[0] + ldr q18,[x8],16 // B1.low + fmla v26.8h,v16.8h,v3.h[0] + fmla v27.8h,v17.8h,v3.h[0] + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row + fmla v28.8h,v16.8h,v4.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v31.8h,v17.8h,v5.h[0] + subs x0,x0,8 // k -= 4 + + fmla v20.8h,v18.8h,v0.h[1] + fmla v21.8h,v19.8h,v0.h[1] + ldr q16,[x8],16 // B2.low + fmla v22.8h,v18.8h,v1.h[1] + fmla v23.8h,v19.8h,v1.h[1] + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row + fmla v24.8h,v18.8h,v2.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v27.8h,v19.8h,v3.h[1] 
+ fmla v28.8h,v18.8h,v4.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v31.8h,v19.8h,v5.h[1] + + fmla v20.8h,v16.8h,v0.h[2] + fmla v21.8h,v17.8h,v0.h[2] + ldr q18,[x8],16 // B3.low + fmla v22.8h,v16.8h,v1.h[2] + fmla v23.8h,v17.8h,v1.h[2] + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row + fmla v24.8h,v16.8h,v2.h[2] + fmla v25.8h,v17.8h,v2.h[2] + fmla v26.8h,v16.8h,v3.h[2] + fmla v27.8h,v17.8h,v3.h[2] + fmla v28.8h,v16.8h,v4.h[2] + fmla v29.8h,v17.8h,v4.h[2] + fmla v30.8h,v16.8h,v5.h[2] + fmla v31.8h,v17.8h,v5.h[2] + + ldr q16,[x8],16 // Load B0.low for next iter + fmla v20.8h,v18.8h,v0.h[3] + fmla v21.8h,v19.8h,v0.h[3] + ld1 {v17.16b},[x8],x9 // Load B0.high for next iter + fmla v22.8h,v18.8h,v1.h[3] + fmla v23.8h,v19.8h,v1.h[3] + ldr d0,[x6],8 // Load A0 for next iter + fmla v24.8h,v18.8h,v2.h[3] + fmla v25.8h,v19.8h,v2.h[3] + ldr d1,[x14],8 // Load A1 for next iter + fmla v26.8h,v18.8h,v3.h[3] + fmla v27.8h,v19.8h,v3.h[3] + ldr d2,[x15],8 // Load A2 for next iter + fmla v28.8h,v18.8h,v4.h[3] + fmla v29.8h,v19.8h,v4.h[3] + ldr d3,[x16],8 // Load A3 for next iter + fmla v30.8h,v18.8h,v5.h[3] + fmla v31.8h,v19.8h,v5.h[3] + b.hs .LM6N16InnerLoopK // k >= 8 for main loop + +.LM6N16LoopK_Epilogue: + // last block of k >= 4, no pre-load for next iter + fmla v20.8h,v16.8h,v0.h[0] + fmla v21.8h,v17.8h,v0.h[0] + ldr d4,[x17],8 // A4 + fmla v22.8h,v16.8h,v1.h[0] + fmla v23.8h,v17.8h,v1.h[0] + ldr d5,[x7],8 // A5 + fmla v24.8h,v16.8h,v2.h[0] + fmla v25.8h,v17.8h,v2.h[0] + ldr q18,[x8],16 // B1.low + fmla v26.8h,v16.8h,v3.h[0] + fmla v27.8h,v17.8h,v3.h[0] + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row + fmla v28.8h,v16.8h,v4.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v31.8h,v17.8h,v5.h[0] + adds x0,x0,8 // revert k over-decrement + + fmla v20.8h,v18.8h,v0.h[1] + fmla v21.8h,v19.8h,v0.h[1] + ldr q16,[x8],16 // B2.low + fmla v22.8h,v18.8h,v1.h[1] + fmla v23.8h,v19.8h,v1.h[1] + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row + fmla v24.8h,v18.8h,v2.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v31.8h,v19.8h,v5.h[1] + + fmla v20.8h,v16.8h,v0.h[2] + fmla v21.8h,v17.8h,v0.h[2] + ldr q18,[x8],16 // B3.low + fmla v22.8h,v16.8h,v1.h[2] + fmla v23.8h,v17.8h,v1.h[2] + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row + fmla v24.8h,v16.8h,v2.h[2] + fmla v25.8h,v17.8h,v2.h[2] + fmla v26.8h,v16.8h,v3.h[2] + fmla v27.8h,v17.8h,v3.h[2] + fmla v28.8h,v16.8h,v4.h[2] + fmla v29.8h,v17.8h,v4.h[2] + fmla v30.8h,v16.8h,v5.h[2] + fmla v31.8h,v17.8h,v5.h[2] + + fmla v20.8h,v18.8h,v0.h[3] + fmla v21.8h,v19.8h,v0.h[3] + fmla v22.8h,v18.8h,v1.h[3] + fmla v23.8h,v19.8h,v1.h[3] + fmla v24.8h,v18.8h,v2.h[3] + fmla v25.8h,v19.8h,v2.h[3] + fmla v26.8h,v18.8h,v3.h[3] + fmla v27.8h,v19.8h,v3.h[3] + fmla v28.8h,v18.8h,v4.h[3] + fmla v29.8h,v19.8h,v4.h[3] + fmla v30.8h,v18.8h,v5.h[3] + fmla v31.8h,v19.8h,v5.h[3] + b.NE .LM6N16RemainderK123 // remaining k 1~3 + +.LM6N16OutterLoopNTail: + subs x1,x1,16 // N -= 16 + ldr x8,[sp,#.LHGemmKernelFrame_B] + b.LO .LM6StoreRemainderN // remaining N < 16 + + cbnz x19,.LM6N16SkipAccumulateOutput + ldp q0,q1,[x3] + ldp q2,q3,[x10] + ldp q4,q5,[x11] + ldp q6,q7,[x12] + ldp q16,q17,[x13] + ldp q18,q19,[x4] + fadd v20.8h,v20.8h,v0.8h // !ZeroMode + fadd v21.8h,v21.8h,v1.8h // accumulate into C + fadd v22.8h,v22.8h,v2.8h + fadd v23.8h,v23.8h,v3.8h + fadd v24.8h,v24.8h,v4.8h + fadd v25.8h,v25.8h,v5.8h + fadd 
v26.8h,v26.8h,v6.8h + fadd v27.8h,v27.8h,v7.8h + fadd v28.8h,v28.8h,v16.8h + fadd v29.8h,v29.8h,v17.8h + fadd v30.8h,v30.8h,v18.8h + fadd v31.8h,v31.8h,v19.8h + +.LM6N16SkipAccumulateOutput: + st1 {v20.16b,v21.16b},[x3],32 + sub x6,x6,x2 // restore a0 + st1 {v22.16b,v23.16b},[x10],32 + sub x14,x14,x2 // restore a1 + st1 {v24.16b,v25.16b},[x11],32 + sub x15,x15,x2 // restore a2 + st1 {v26.16b,v27.16b},[x12],32 + sub x16,x16,x2 // restore a3 + st1 {v28.16b,v29.16b},[x13],32 + sub x17,x17,x2 // restore a4 + add x8,x8,32 // B <- next 16 columns + st1 {v30.16b,v31.16b},[x4],32 + sub x7,x7,x2 // restore a5 + str x8,[sp,#.LHGemmKernelFrame_B] + b.HI .LM6N16OutterLoopN + +.LExitKernel: + ldr x19,[sp],#.LHGemmKernelFrame_SavedRegs + ret + +.LM6N16RemainderK123: + tbz x0,2,.LM6N16RemainderK1 + ldr s0,[x6],4 // A0 + ldr q16,[x8],16 // B0.low + ld1 {v17.16b},[x8],x9 // B0.high + ldr s1,[x14],4 // A1 + ldr s2,[x15],4 // A2 + ldr s3,[x16],4 // A3 + ldr s4,[x17],4 // A4 + ldr s5,[x7],4 // A5 + ldr q18,[x8],16 // B1.low + ld1 {v19.16b},[x8],x9 // B2.high + fmla v20.8h,v16.8h,v0.h[0] + fmla v22.8h,v16.8h,v1.h[0] + fmla v24.8h,v16.8h,v2.h[0] + fmla v26.8h,v16.8h,v3.h[0] + fmla v28.8h,v16.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v21.8h,v17.8h,v0.h[0] + fmla v23.8h,v17.8h,v1.h[0] + fmla v25.8h,v17.8h,v2.h[0] + fmla v27.8h,v17.8h,v3.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v31.8h,v17.8h,v5.h[0] + + fmla v20.8h,v18.8h,v0.h[1] + fmla v22.8h,v18.8h,v1.h[1] + fmla v24.8h,v18.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v21.8h,v19.8h,v0.h[1] + fmla v23.8h,v19.8h,v1.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v31.8h,v19.8h,v5.h[1] + tbz x0,1,.LM6N16OutterLoopNTail + +.LM6N16RemainderK1: + ldr h0,[x6],2 // A0 + ldr q16,[x8],16 // B0.low + ld1 {v17.16b},[x8],x9 // B0.high + ldr h1,[x14],2 // A1 + ldr h2,[x15],2 // A2 + ldr h3,[x16],2 // A3 + ldr h4,[x17],2 // A4 + ldr h5,[x7],2 // A5 + fmla v20.8h,v16.8h,v0.h[0] + fmla v22.8h,v16.8h,v1.h[0] + fmla v24.8h,v16.8h,v2.h[0] + fmla v26.8h,v16.8h,v3.h[0] + fmla v28.8h,v16.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v21.8h,v17.8h,v0.h[0] + fmla v23.8h,v17.8h,v1.h[0] + fmla v25.8h,v17.8h,v2.h[0] + fmla v27.8h,v17.8h,v3.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v31.8h,v17.8h,v5.h[0] + b .LM6N16OutterLoopNTail + +.LM6StoreRemainderN: + cbnz x19,.LM6StoreRemainderNZeroMode + tbz x1,3,.LM6StoreRemainderN4 + ldr q0,[x3] + ldr q1,[x10] + ldr q2,[x11] + ldr q3,[x12] + ldr q4,[x13] + ldr q5,[x4] + fadd v20.8h,v20.8h,v0.8h + fadd v22.8h,v22.8h,v1.8h + fadd v24.8h,v24.8h,v2.8h + str q20,[x3],16 + mov v20.16b,v21.16b + str q22,[x10],16 + mov v22.16b,v23.16b + str q24,[x11],16 + mov v24.16b,v25.16b + fadd v26.8h,v26.8h,v3.8h + fadd v28.8h,v28.8h,v4.8h + fadd v30.8h,v30.8h,v5.8h + str q26,[x12],16 + mov v26.16b,v27.16b + str q28,[x13],16 + mov v28.16b,v29.16b + str q30,[x4],16 + mov v30.16b,v31.16b + +.LM6StoreRemainderN4: + tbz x1,2,.LM6StoreRemainderN2 + ldr d0,[x3] + ldr d1,[x10] + ldr d2,[x11] + ldr d3,[x12] + ldr d4,[x13] + ldr d5,[x4] + fadd v21.4h,v20.4h,v0.4h + dup d20,v20.d[1] + fadd v23.4h,v22.4h,v1.4h + dup d22,v22.d[1] + fadd v25.4h,v24.4h,v2.4h + dup d24,v24.d[1] + fadd v27.4h,v26.4h,v3.4h + dup d26,v26.d[1] + fadd v29.4h,v28.4h,v4.4h + dup d28,v28.d[1] + fadd v31.4h,v30.4h,v5.4h + dup d30,v30.d[1] + str d21,[x3],8 + str d23,[x10],8 + str d25,[x11],8 + str d27,[x12],8 + str d29,[x13],8 + str d31,[x4],8 + +.LM6StoreRemainderN2: + tbz 
x1,1,.LM6StoreRemainderN1 + ldr s0,[x3] + ldr s1,[x10] + ldr s2,[x11] + ldr s3,[x12] + ldr s4,[x13] + ldr s5,[x4] + fadd v21.4h,v20.4h,v0.4h + fadd v23.4h,v22.4h,v1.4h + fadd v25.4h,v24.4h,v2.4h + fadd v27.4h,v26.4h,v3.4h + fadd v29.4h,v28.4h,v4.4h + fadd v31.4h,v30.4h,v5.4h + str s21,[x3],4 + str s23,[x10],4 + dup s20,v20.s[1] + dup s22,v22.s[1] + str s25,[x11],4 + str s27,[x12],4 + dup s24,v24.s[1] + dup s26,v26.s[1] + str s29,[x13],4 + str s31,[x4],4 + dup s28,v28.s[1] + dup s30,v30.s[1] + +.LM6StoreRemainderN1: + tbz x1,0,.LExitKernel + ldr h0,[x3] + ldr h1,[x10] + ldr h2,[x11] + ldr h3,[x12] + ldr h4,[x13] + ldr h5,[x4] + fadd v20.4h,v20.4h,v0.4h + fadd v22.4h,v22.4h,v1.4h + fadd v24.4h,v24.4h,v2.4h + fadd v26.4h,v26.4h,v3.4h + fadd v28.4h,v28.4h,v4.4h + fadd v30.4h,v30.4h,v5.4h + str h20,[x3] + str h22,[x10] + str h24,[x11] + str h26,[x12] + str h28,[x13] + str h30,[x4] + b .LExitKernel + +.LM6StoreRemainderNZeroMode: + tbz x1,3,.LM6StoreRemainderN4ZeroMode + str q20,[x3],16 + mov v20.16b,v21.16b + str q22,[x10],16 + mov v22.16b,v23.16b + str q24,[x11],16 + mov v24.16b,v25.16b + str q26,[x12],16 + mov v26.16b,v27.16b + str q28,[x13],16 + mov v28.16b,v29.16b + str q30,[x4],16 + mov v30.16b,v31.16b + +.LM6StoreRemainderN4ZeroMode: + tbz x1,2,.LM6StoreRemainderN2ZeroMode + str d20,[x3],8 + str d22,[x10],8 + dup d20,v20.d[1] + dup d22,v22.d[1] + str d24,[x11],8 + str d26,[x12],8 + dup d24,v24.d[1] + dup d26,v26.d[1] + str d28,[x13],8 + str d30,[x4],8 + dup d28,v28.d[1] + dup d30,v30.d[1] + +.LM6StoreRemainderN2ZeroMode: + tbz x1,1,.LM6StoreRemainderN1ZeroMode + str s20,[x3],4 + str s22,[x10],4 + dup s20,v20.s[1] + dup s22,v22.s[1] + str s24,[x11],4 + str s26,[x12],4 + dup s24,v24.s[1] + dup s26,v26.s[1] + str s28,[x13],4 + str s30,[x4],4 + dup s28,v28.s[1] + dup s30,v30.s[1] + +.LM6StoreRemainderN1ZeroMode: + tbz x1,0,.LExitKernel + str h20,[x3] + str h22,[x10] + str h24,[x11] + str h26,[x12] + str h28,[x13] + str h30,[x4] + b .LExitKernel + + .end diff --git a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm new file mode 100644 index 0000000000000..d7b626327780c --- /dev/null +++ b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm @@ -0,0 +1,552 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + HalfGemmKernelNeon.asm + +Abstract: + + This module implements the kernels for the half precision matrix/matrix + multiply operation (HALF GEMM). + +--*/ + +#include "kxarm64.h" + +// +// Stack frame layout for the half gemm kernel. +// Callee save registers: d8-d15, x19-x30. x18 is reserved by the OS. +// + +#define HGemmKernelFrame_SavedRegs (2 * 8) +#define HGemmKernelFrame_B 0 + HGemmKernelFrame_SavedRegs +#define HGemmKernelFrame_ldb 8 + HGemmKernelFrame_SavedRegs +#define HGemmKernelFrame_ZeroMode 16 + HGemmKernelFrame_SavedRegs + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute 6 rows of GEMM + +Arguments: + + CountM - (x0) the number of rows for matrix A and matrix C. + only process 6 rows + + CountN - (x1) the number of columns from matrix B and matrix C + + CountK - (x2/x0) the number of columns from matrix A and the + number of rows from matrix B. + + C - (x3) the address of matrix C. + + ldc - (x4) - the first dimension of matrix C. 
+ + Bias - (x5) - the address of the Bias vector (optional) + + A - (x6) - the address of matrix A + + lda - (x7) - the first dimension of matrix A + + B - the address of matrix B + + ldb - the first dimension of matrix B + + ZeroMode - true if the output matrix must be zero initialized, else + if the output matrix is accumulated into + +--*/ + + LEAF_ENTRY MlasHalfGemmKernelNeon + + PROLOG_SAVE_REG x19,#-HGemmKernelFrame_SavedRegs! + ldr x9,[sp,#HGemmKernelFrame_ldb] + lsl x2,x2,#1 // k *= sizeof(fp16) + cmp x0,2 + add x14,x6,x7,lsl #1 // a1 = a0 + lda + add x10,x3,x4,lsl #1 // c1 = c0 + ldc + ldr x8,[sp,#HGemmKernelFrame_B] + csel x14,x6,x14,LO // M < 2 ? a1 = a0 + csel x10,x3,x10,LO // c1 = c0 + add x15,x14,x7,lsl #1 // a2 = a1 + lda + add x11,x10,x4,lsl #1 // c2 = c1 + ldc + csel x15,x14,x15,LS // M <= 2 ? a2 = a1 + csel x11,x10,x11,LS // c2 = c1 + cmp x0,4 + add x16,x15,x7,lsl #1 // a3 = a2 + lda + add x12,x11,x4,lsl #1 // c3 = c2 + ldc + csel x16,x15,x16,LO // M < 4 ? a3 = a2 + csel x12,x11,x12,LO // c3 = c2 + add x17,x16,x7,lsl #1 // a4 = a3 + lda + add x13,x12,x4,lsl #1 // c4 = c3 + ldc + csel x17,x16,x17,LS // M <= 4 ? a4 = a3 + csel x13,x12,x13,LS // c4 = c3 + cmp x0,6 + add x7,x17,x7,lsl #1 // a5 = a4 + lda + add x4,x13,x4,lsl #1 // c5 = c4 + ldc + csel x7,x17,x7,LO // M < 6 ? a5 = a4 + csel x4,x13,x4,LO // c5 = c4 + lsl x9,x9,#1 // ldb *= sizeof(fp16) + ldrb w19,[sp,#HGemmKernelFrame_ZeroMode] + sub x9,x9,16 // ldb -= 16 + +/**** +Main loop processes 6x16 tile, depth 4. + B 4x16 + --------------------------------------- + |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 + |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 + |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 + |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 + A 6x4 --------------------------------------- + ------------------ --------------------------------------- +x6 |v0.h[0]..v0.h[3]| |v20.h[0]..v20.h[7] v21.h[0]..v21.h[7]| x3 +x14 |v1.h[0]..v1.h[3]| |v22.h[0]..v22.h[7] v23.h[0]..v23.h[7]| x10 +x15 |v2.h[0]..v2.h[3]| |v24.h[0]..v24.h[7] v25.h[0]..v25.h[7]| x11 +x16 |v3.h[0]..v3.h[3]| |v26.h[0]..v26.h[7] v27.h[0]..v27.h[7]| x12 +x17 |v4.h[0]..v4.h[3]| |v28.h[0]..v28.h[7] v29.h[0]..v29.h[7]| x13 +x7 |v5.h[0]..v5.h[3]| |v30.h[0]..v30.h[7] v31.h[0]..v31.h[7]| x4 + ------------------ --------------------------------------- +****/ + +M6N16OutterLoopN + cbz x5,M6N16SkipBias + ldp q20,q21,[x5],32 // Load 16 Bias values + b M6N16PopulateAccumulators + +M6N16SkipBias + eor q20.16b,q20.16b,q20.16b // No bias, reset regs + eor q21.16b,q21.16b,q21.16b + +M6N16PopulateAccumulators + mov v22.16b,v20.16b + mov v23.16b,v21.16b + mov v24.16b,v20.16b + mov v25.16b,v21.16b + mov v26.16b,v20.16b + mov v27.16b,v21.16b + mov v28.16b,v20.16b + subs x0,x2,8 // k -= 4 (8 bytes) + mov v29.16b,v21.16b + mov v30.16b,v20.16b + mov v31.16b,v21.16b + b.LO M6N16RemainderK123 // remaining k 1~3 + + ldr d0,[x6],8 // A0 + ldr q16,[x8],16 // B0.l + ld1 {v17.16b},[x8],x9 // B0.high x8 <- next row + subs x0,x0,8 // over decement k -= 4 (8 bytes) + ldr d1,[x14],8 // A1 + ldr d2,[x15],8 // A2 + ldr d3,[x16],8 // A3 + b.LO M6N16LoopK_Epilogue // need k>=8 for main loop + +M6N16InnerLoopK + fmla v20.8h,v16.8h,v0.h[0] + fmla v21.8h,v17.8h,v0.h[0] + ldr d4,[x17],8 // A4 + fmla v22.8h,v16.8h,v1.h[0] + fmla v23.8h,v17.8h,v1.h[0] + ldr d5,[x7],8 // A5 + fmla v24.8h,v16.8h,v2.h[0] + fmla v25.8h,v17.8h,v2.h[0] + ldr q18,[x8],16 // B1.low + fmla v26.8h,v16.8h,v3.h[0] + fmla v27.8h,v17.8h,v3.h[0] + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row + fmla v28.8h,v16.8h,v4.h[0] + fmla 
v29.8h,v17.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v31.8h,v17.8h,v5.h[0] + subs x0,x0,8 // k -= 4 + + fmla v20.8h,v18.8h,v0.h[1] + fmla v21.8h,v19.8h,v0.h[1] + ldr q16,[x8],16 // B2.low + fmla v22.8h,v18.8h,v1.h[1] + fmla v23.8h,v19.8h,v1.h[1] + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row + fmla v24.8h,v18.8h,v2.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v31.8h,v19.8h,v5.h[1] + + fmla v20.8h,v16.8h,v0.h[2] + fmla v21.8h,v17.8h,v0.h[2] + ldr q18,[x8],16 // B3.low + fmla v22.8h,v16.8h,v1.h[2] + fmla v23.8h,v17.8h,v1.h[2] + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row + fmla v24.8h,v16.8h,v2.h[2] + fmla v25.8h,v17.8h,v2.h[2] + fmla v26.8h,v16.8h,v3.h[2] + fmla v27.8h,v17.8h,v3.h[2] + fmla v28.8h,v16.8h,v4.h[2] + fmla v29.8h,v17.8h,v4.h[2] + fmla v30.8h,v16.8h,v5.h[2] + fmla v31.8h,v17.8h,v5.h[2] + + ldr q16,[x8],16 // Load B0.low for next iter + fmla v20.8h,v18.8h,v0.h[3] + fmla v21.8h,v19.8h,v0.h[3] + ld1 {v17.16b},[x8],x9 // Load B0.high for next iter + fmla v22.8h,v18.8h,v1.h[3] + fmla v23.8h,v19.8h,v1.h[3] + ldr d0,[x6],8 // Load A0 for next iter + fmla v24.8h,v18.8h,v2.h[3] + fmla v25.8h,v19.8h,v2.h[3] + ldr d1,[x14],8 // Load A1 for next iter + fmla v26.8h,v18.8h,v3.h[3] + fmla v27.8h,v19.8h,v3.h[3] + ldr d2,[x15],8 // Load A2 for next iter + fmla v28.8h,v18.8h,v4.h[3] + fmla v29.8h,v19.8h,v4.h[3] + ldr d3,[x16],8 // Load A3 for next iter + fmla v30.8h,v18.8h,v5.h[3] + fmla v31.8h,v19.8h,v5.h[3] + b.hs M6N16InnerLoopK // k >= 8 for main loop + +M6N16LoopK_Epilogue + // last block of k >= 4, no pre-load for next iter + fmla v20.8h,v16.8h,v0.h[0] + fmla v21.8h,v17.8h,v0.h[0] + ldr d4,[x17],8 // A4 + fmla v22.8h,v16.8h,v1.h[0] + fmla v23.8h,v17.8h,v1.h[0] + ldr d5,[x7],8 // A5 + fmla v24.8h,v16.8h,v2.h[0] + fmla v25.8h,v17.8h,v2.h[0] + ldr q18,[x8],16 // B1.low + fmla v26.8h,v16.8h,v3.h[0] + fmla v27.8h,v17.8h,v3.h[0] + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row + fmla v28.8h,v16.8h,v4.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v31.8h,v17.8h,v5.h[0] + adds x0,x0,8 // revert k over-decrement + + fmla v20.8h,v18.8h,v0.h[1] + fmla v21.8h,v19.8h,v0.h[1] + ldr q16,[x8],16 // B2.low + fmla v22.8h,v18.8h,v1.h[1] + fmla v23.8h,v19.8h,v1.h[1] + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row + fmla v24.8h,v18.8h,v2.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v31.8h,v19.8h,v5.h[1] + + fmla v20.8h,v16.8h,v0.h[2] + fmla v21.8h,v17.8h,v0.h[2] + ldr q18,[x8],16 // B3.low + fmla v22.8h,v16.8h,v1.h[2] + fmla v23.8h,v17.8h,v1.h[2] + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row + fmla v24.8h,v16.8h,v2.h[2] + fmla v25.8h,v17.8h,v2.h[2] + fmla v26.8h,v16.8h,v3.h[2] + fmla v27.8h,v17.8h,v3.h[2] + fmla v28.8h,v16.8h,v4.h[2] + fmla v29.8h,v17.8h,v4.h[2] + fmla v30.8h,v16.8h,v5.h[2] + fmla v31.8h,v17.8h,v5.h[2] + + fmla v20.8h,v18.8h,v0.h[3] + fmla v21.8h,v19.8h,v0.h[3] + fmla v22.8h,v18.8h,v1.h[3] + fmla v23.8h,v19.8h,v1.h[3] + fmla v24.8h,v18.8h,v2.h[3] + fmla v25.8h,v19.8h,v2.h[3] + fmla v26.8h,v18.8h,v3.h[3] + fmla v27.8h,v19.8h,v3.h[3] + fmla v28.8h,v18.8h,v4.h[3] + fmla v29.8h,v19.8h,v4.h[3] + fmla v30.8h,v18.8h,v5.h[3] + fmla v31.8h,v19.8h,v5.h[3] + b.NE M6N16RemainderK123 // remaining k 1~3 + +M6N16OutterLoopNTail + subs x1,x1,16 // N -= 16 + ldr x8,[sp,#HGemmKernelFrame_B] + 
b.LO M6StoreRemainderN // remaining N < 16 + + cbnz x19,M6N16SkipAccumulateOutput + ldp q0,q1,[x3] + ldp q2,q3,[x10] + ldp q4,q5,[x11] + ldp q6,q7,[x12] + ldp q16,q17,[x13] + ldp q18,q19,[x4] + fadd v20.8h,v20.8h,v0.8h // !ZeroMode + fadd v21.8h,v21.8h,v1.8h // accumulate into C + fadd v22.8h,v22.8h,v2.8h + fadd v23.8h,v23.8h,v3.8h + fadd v24.8h,v24.8h,v4.8h + fadd v25.8h,v25.8h,v5.8h + fadd v26.8h,v26.8h,v6.8h + fadd v27.8h,v27.8h,v7.8h + fadd v28.8h,v28.8h,v16.8h + fadd v29.8h,v29.8h,v17.8h + fadd v30.8h,v30.8h,v18.8h + fadd v31.8h,v31.8h,v19.8h + +M6N16SkipAccumulateOutput + st1 {v20.16b,v21.16b},[x3],32 + sub x6,x6,x2 // restore a0 + st1 {v22.16b,v23.16b},[x10],32 + sub x14,x14,x2 // restore a1 + st1 {v24.16b,v25.16b},[x11],32 + sub x15,x15,x2 // restore a2 + st1 {v26.16b,v27.16b},[x12],32 + sub x16,x16,x2 // restore a3 + st1 {v28.16b,v29.16b},[x13],32 + sub x17,x17,x2 // restore a4 + add x8,x8,32 // B <- next 16 columns + st1 {v30.16b,v31.16b},[x4],32 + sub x7,x7,x2 // restore a5 + str x8,[sp,#HGemmKernelFrame_B] + b.HI M6N16OutterLoopN + +ExitKernel + EPILOG_RESTORE_REG x19,#HGemmKernelFrame_SavedRegs! + EPILOG_RETURN + +M6N16RemainderK123 + tbz x0,2,M6N16RemainderK1 + ldr s0,[x6],4 // A0 + ldr q16,[x8],16 // B0.low + ld1 {v17.16b},[x8],x9 // B0.high + ldr s1,[x14],4 // A1 + ldr s2,[x15],4 // A2 + ldr s3,[x16],4 // A3 + ldr s4,[x17],4 // A4 + ldr s5,[x7],4 // A5 + ldr q18,[x8],16 // B1.low + ld1 {v19.16b},[x8],x9 // B2.high + fmla v20.8h,v16.8h,v0.h[0] + fmla v22.8h,v16.8h,v1.h[0] + fmla v24.8h,v16.8h,v2.h[0] + fmla v26.8h,v16.8h,v3.h[0] + fmla v28.8h,v16.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v21.8h,v17.8h,v0.h[0] + fmla v23.8h,v17.8h,v1.h[0] + fmla v25.8h,v17.8h,v2.h[0] + fmla v27.8h,v17.8h,v3.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v31.8h,v17.8h,v5.h[0] + + fmla v20.8h,v18.8h,v0.h[1] + fmla v22.8h,v18.8h,v1.h[1] + fmla v24.8h,v18.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v21.8h,v19.8h,v0.h[1] + fmla v23.8h,v19.8h,v1.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v31.8h,v19.8h,v5.h[1] + tbz x0,1,M6N16OutterLoopNTail + +M6N16RemainderK1 + ldr h0,[x6],2 // A0 + ldr q16,[x8],16 // B0.low + ld1 {v17.16b},[x8],x9 // B0.high + ldr h1,[x14],2 // A1 + ldr h2,[x15],2 // A2 + ldr h3,[x16],2 // A3 + ldr h4,[x17],2 // A4 + ldr h5,[x7],2 // A5 + fmla v20.8h,v16.8h,v0.h[0] + fmla v22.8h,v16.8h,v1.h[0] + fmla v24.8h,v16.8h,v2.h[0] + fmla v26.8h,v16.8h,v3.h[0] + fmla v28.8h,v16.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v21.8h,v17.8h,v0.h[0] + fmla v23.8h,v17.8h,v1.h[0] + fmla v25.8h,v17.8h,v2.h[0] + fmla v27.8h,v17.8h,v3.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v31.8h,v17.8h,v5.h[0] + b M6N16OutterLoopNTail + +M6StoreRemainderN + cbnz x19,M6StoreRemainderNZeroMode + tbz x1,3,M6StoreRemainderN4 + ldr q0,[x3] + ldr q1,[x10] + ldr q2,[x11] + ldr q3,[x12] + ldr q4,[x13] + ldr q5,[x4] + fadd v20.8h,v20.8h,v0.8h + fadd v22.8h,v22.8h,v1.8h + fadd v24.8h,v24.8h,v2.8h + str q20,[x3],16 + mov v20.16b,v21.16b + str q22,[x10],16 + mov v22.16b,v23.16b + str q24,[x11],16 + mov v24.16b,v25.16b + fadd v26.8h,v26.8h,v3.8h + fadd v28.8h,v28.8h,v4.8h + fadd v30.8h,v30.8h,v5.8h + str q26,[x12],16 + mov v26.16b,v27.16b + str q28,[x13],16 + mov v28.16b,v29.16b + str q30,[x4],16 + mov v30.16b,v31.16b + +M6StoreRemainderN4 + tbz x1,2,M6StoreRemainderN2 + ldr d0,[x3] + ldr d1,[x10] + ldr d2,[x11] + ldr d3,[x12] + ldr d4,[x13] + ldr d5,[x4] + fadd v21.4h,v20.4h,v0.4h + dup d20,v20.d[1] + fadd 
v23.4h,v22.4h,v1.4h + dup d22,v22.d[1] + fadd v25.4h,v24.4h,v2.4h + dup d24,v24.d[1] + fadd v27.4h,v26.4h,v3.4h + dup d26,v26.d[1] + fadd v29.4h,v28.4h,v4.4h + dup d28,v28.d[1] + fadd v31.4h,v30.4h,v5.4h + dup d30,v30.d[1] + str d21,[x3],8 + str d23,[x10],8 + str d25,[x11],8 + str d27,[x12],8 + str d29,[x13],8 + str d31,[x4],8 + +M6StoreRemainderN2 + tbz x1,1,M6StoreRemainderN1 + ldr s0,[x3] + ldr s1,[x10] + ldr s2,[x11] + ldr s3,[x12] + ldr s4,[x13] + ldr s5,[x4] + fadd v21.4h,v20.4h,v0.4h + fadd v23.4h,v22.4h,v1.4h + fadd v25.4h,v24.4h,v2.4h + fadd v27.4h,v26.4h,v3.4h + fadd v29.4h,v28.4h,v4.4h + fadd v31.4h,v30.4h,v5.4h + str s21,[x3],4 + str s23,[x10],4 + dup s20,v20.s[1] + dup s22,v22.s[1] + str s25,[x11],4 + str s27,[x12],4 + dup s24,v24.s[1] + dup s26,v26.s[1] + str s29,[x13],4 + str s31,[x4],4 + dup s28,v28.s[1] + dup s30,v30.s[1] + +M6StoreRemainderN1 + tbz x1,0,ExitKernel + ldr h0,[x3] + ldr h1,[x10] + ldr h2,[x11] + ldr h3,[x12] + ldr h4,[x13] + ldr h5,[x4] + fadd v20.4h,v20.4h,v0.4h + fadd v22.4h,v22.4h,v1.4h + fadd v24.4h,v24.4h,v2.4h + fadd v26.4h,v26.4h,v3.4h + fadd v28.4h,v28.4h,v4.4h + fadd v30.4h,v30.4h,v5.4h + str h20,[x3] + str h22,[x10] + str h24,[x11] + str h26,[x12] + str h28,[x13] + str h30,[x4] + b ExitKernel + +M6StoreRemainderNZeroMode + tbz x1,3,M6StoreRemainderN4ZeroMode + str q20,[x3],16 + mov v20.16b,v21.16b + str q22,[x10],16 + mov v22.16b,v23.16b + str q24,[x11],16 + mov v24.16b,v25.16b + str q26,[x12],16 + mov v26.16b,v27.16b + str q28,[x13],16 + mov v28.16b,v29.16b + str q30,[x4],16 + mov v30.16b,v31.16b + +M6StoreRemainderN4ZeroMode + tbz x1,2,M6StoreRemainderN2ZeroMode + str d20,[x3],8 + str d22,[x10],8 + dup d20,v20.d[1] + dup d22,v22.d[1] + str d24,[x11],8 + str d26,[x12],8 + dup d24,v24.d[1] + dup d26,v26.d[1] + str d28,[x13],8 + str d30,[x4],8 + dup d28,v28.d[1] + dup d30,v30.d[1] + +M6StoreRemainderN2ZeroMode + tbz x1,1,M6StoreRemainderN1ZeroMode + str s20,[x3],4 + str s22,[x10],4 + dup s20,v20.s[1] + dup s22,v22.s[1] + str s24,[x11],4 + str s26,[x12],4 + dup s24,v24.s[1] + dup s26,v26.s[1] + str s28,[x13],4 + str s30,[x4],4 + dup s28,v28.s[1] + dup s30,v30.s[1] + +M6StoreRemainderN1ZeroMode + tbz x1,0,ExitKernel + str h20,[x3] + str h22,[x10] + str h24,[x11] + str h26,[x12] + str h28,[x13] + str h30,[x4] + b ExitKernel + + LEAF_END MlasHalfGemmKernelNeon + + END diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp new file mode 100644 index 0000000000000..778db2003d6c6 --- /dev/null +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -0,0 +1,334 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + half gemm.cpp + +Abstract: + + This module implements the half precision (fp16) matrix/matrix multiply + operation (QGEMM). 
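+
+    The public entry point MlasHalfGemmBatch partitions the output across the
+    thread pool and forwards each tile range to the platform kernel selected
+    through the MLAS_HALFGEMM_DISPATCH table.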
+ +--*/ + +#include "mlasi.h" +#include "mlas_float16.h" + +#include "halfgemm.h" + +#include + +bool MLASCALL +MlasFp16AccelerationSupported() +{ + return MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration(); +} + + +void +MLASCALL +MlasHalfGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool + ) +{ + const MLAS_HALFGEMM_DISPATCH* dispatch = MlasHalfGemmGetDispatch(); + MLAS_HALFGEMM_OPERATION* operation = dispatch->Operation; + + if (ThreadPool == nullptr) { + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + auto Data = &DataParams[gemm_i]; + operation(N, K, Data, 0, M, 0, N); + } + return; + } + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. + // + + const double Complexity = double(M) * double(N) * double(K) * double(BatchN); + + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; + if (ThreadsPerGemm < 1) { + ThreadsPerGemm = 1; + } + + const size_t StrideM = dispatch->StrideM; + + size_t nc = N; + if ((size_t)MlasGetMaximumThreadCount(ThreadPool) > BatchN) { + // more than one thread per GEMM + + const size_t BlockedM = MlasDivRoundup(M, StrideM); + const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min(nc, MlasDivRoundup(nc, max_nc * MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * + MLAS_QGEMM_STRIDEN_THREAD_ALIGN); + } + } + const size_t StrideN = nc; + + const size_t ThreadCountM = MlasDivRoundup(M, StrideM); + const size_t ThreadCountN = MlasDivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { + const auto gemm_i = tid / ThreadsPerGemm; + const auto blk_i = tid % ThreadsPerGemm; + auto Data = &DataParams[gemm_i]; + + const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; + const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + operation(N, K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + }); +} + + +size_t +MLASCALL +MlasHalfGemmPackBSize( + size_t N, + size_t K, + bool float2half + ) +{ + const auto* dispatch = MlasHalfGemmGetDispatch(); + const auto padding = dispatch->BufOverRead; + const auto PackedK = dispatch->PackededK; + if (!float2half && dispatch->CopyPackBRoutine == nullptr) { + // No packing routine provided + return 0; + } + const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); + const size_t BytesRequired = N * AlignedK * FP16_SIZE + padding; + const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); + const size_t AlignedBytesRequired = + (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + return AlignedBytesRequired; +} + +void +MLASCALL +MlasHalfGemmPackB( + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB + ) +{ + const auto* dispatch = MlasHalfGemmGetDispatch(); + dispatch->CopyPackBRoutine((_mlas_fp16_*)PackedB, (const 
_mlas_fp16_*)B, ldb, N, K); +} + +void +MLASCALL +MlasHalfGemmConvertPackB( + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB + ) +{ + const auto* dispatch = MlasHalfGemmGetDispatch(); + dispatch->ConvertPackBRoutine((_mlas_fp16_*)PackedB, B, ldb, N, K); +} + + +// +// Post Processor Implementations +// + +MLAS_FORCEINLINE +void +CvtHalf2Float( + float* dest, + const _mlas_fp16_* src, + size_t len +) +{ +#ifdef MLAS_TARGET_ARM64 + while (len >= 4) { + const auto* srcPtr = reinterpret_cast(src); + auto* dstPtr = reinterpret_cast(dest); + *dstPtr = vcvt_f32_f16(*srcPtr); + src += 4; + dest += 4; + len -= 4; + } + + if (0 == len) { + return; + } + + float16x4_t buf; + std::memcpy(&buf, src, len * sizeof(_mlas_fp16_)); + float32x4_t res = vcvt_f32_f16(buf); + + if ((len & 2) != 0) { + auto wide = vreinterpretq_f64_f32(res); + vst1q_lane_f64((float64_t*)dest, wide, 0); + res = vreinterpretq_f32_f64(vdupq_laneq_f64(wide, 1)); + dest += 2; + } + if ((len & 1) != 0) { + vst1q_lane_f32(dest, res, 0); + } +#else + for (size_t i = 0; i < len; i++) { + *dest++ = MLAS_Half2Float(*src++); + } +#endif // MLAS_TARGET_ARM64 +} + +void +MLAS_HALF_GEMM_2FLOAT_PROCESSOR::Process( + MLAS_FP16* C, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc + ) const +{ + // + // TODO!! use templates to add activations in this impl + // + float* Output = Output_; + const auto* CRow = reinterpret_cast(C); + CRow += StartM * ldc + StartN; + Output += StartM * RowStride_ + StartN; + + while (CountM-- > 0) { + CvtHalf2Float(Output, CRow, CountN); + + CRow += ldc; + Output += RowStride_; + } +} + + +// +// Dummy C++ implementation that runs very slowly +// + +struct MLAS_HALF_GEMM_KERNEL_DEFAULT { + + static constexpr bool PackNeeded = false; + static constexpr size_t KernelMaxM = 128; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 1; + + static constexpr MLAS_HALF_GEMM_STRIDES Strides{8, 16, 32}; +}; + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmConvertPackA( + _mlas_fp16_* D, + const float* A, + size_t lda, + size_t CountM, + size_t CountK +) +{ + for (size_t m = 0; m < CountM; m++) { + for (size_t k = 0; k < CountK; k++) { + *D++ = MLAS_Float2Half(*(A + m * lda + k)); + } + } +} + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmConvertPackB( + _mlas_fp16_* D, + const float* B, + size_t ldb, + size_t CountN, + size_t CountK +) +{ + for (size_t k = 0; k < CountK; k++) { + for (size_t n = 0; n < CountN; n++) { + *D++ = MLAS_Float2Half(*(B + k * ldb + n)); + } + } +} + + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmKernel( + size_t CountM, + size_t CountN, + size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, + const _mlas_fp16_* A, + size_t lda, + const _mlas_fp16_* B, + size_t ldb, + const bool ZeroMode) +{ + for (size_t m = 0; m < CountM; m++) { + for (size_t n = 0; n < CountN; n++) { + const auto* a = A + (m * lda); + const auto* b = B + n; + auto* c = C + (m * ldc) + n; + + float sum = Bias == nullptr ? 
0.0f : MLAS_Half2Float(Bias[n]); + if (!ZeroMode) { + sum += MLAS_Half2Float(*c); + } + + for (size_t k = 0; k < CountK; k++) { + auto down = MLAS_Float2Half(MLAS_Half2Float(*a) * MLAS_Half2Float(*b) + sum); + sum = MLAS_Half2Float(down); + b += ldb; + a += 1; + } + + *c = MLAS_Float2Half(sum); + } + } +} + + +const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = { + MlasHalfGemmOperation, + nullptr, + MlasHalfGemmConvertPackB, + MLAS_HALF_GEMM_KERNEL_DEFAULT::PackedK, + MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM, + 0 +}; diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h new file mode 100644 index 0000000000000..9e781207571a4 --- /dev/null +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -0,0 +1,515 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm.h + +Abstract: + + This module defines the set of template functions to implement half + precision matrix/matrix multiply operation (QGEMM). + + To implement a new kernel, template functions below need to be specialized: + MlasHalfGemmCopyPackB + MlasHalfGemmConvertPackA + MlasHalfGemmConvertPackB + MlasHalfGemmPackedBOffset + MlasHalfGemmPackedBLeadingDim + MlasHalfGemmKernel + + MlasHalfGemmOperation is the shared kernel driver. + + A kernel type should define the following constants: + bool PackNeeded; Whether fp16 B needs to be packed + size_t KernelMaxM; Max # rows the vectorized kernel can process + size_t PackedK; Packed alignment on the K dim (power of 2) + MLAS_HALF_GEMM_STRIDES Strides{128, 128, 128}; +--*/ + +#pragma once + +#include +#include +#include + +#include "mlasi.h" +#include "mlas_float16.h" + + +/** + * @brief Define the default striding parameters for + * the half precision gemm operation + */ +struct MLAS_HALF_GEMM_STRIDES { + size_t M; + size_t N; + size_t K; +}; + +/** + * @brief Packing function for fp16 B matrix + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack +*/ +template +MLAS_FORCEINLINE +void +MlasHalfGemmCopyPackB( + _mlas_fp16_* D, + const _mlas_fp16_* B, + size_t ldb, + size_t CountN, + size_t CountK +) +{ + MLAS_UNREFERENCED_PARAMETER(D); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(ldb); + MLAS_UNREFERENCED_PARAMETER(CountN); + MLAS_UNREFERENCED_PARAMETER(CountK); + // No packing needed by default +} + +/** + * @brief Convert fp32 matrix A to fp16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of the packing buffer + * @param[in] A Address of fp32 matrix A + * @param[in] lda leading dimension of A + * @param[in] CountM # of rows to pack + * @param[in] CountK # of columns to pack +*/ +template +void +MlasHalfGemmConvertPackA( + _mlas_fp16_* D, + const float* A, + size_t lda, + size_t CountM, + size_t CountK +); + +/** + * @brief Convert fp32 matrix B to fp16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B in fp32 + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack + */ +template +void +MlasHalfGemmConvertPackB( + _mlas_fp16_* D, + const float* B, + size_t ldb, + size_t CountN, + size_t CountK +); + +/** + * @brief Find the location of PackedB[StartK, StartN] + * + * @tparam KernelType + * @param PackedB + 
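+ *                  Base address of the packed buffer for matrix B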
* @param DimN Total columns of the packing buffer + * @param DimK Total rows of the packing buffer + * @param StartN + * @param StartK + * @return Address of PackedB[StartK, StartN] +*/ +template +MLAS_FORCEINLINE +const _mlas_fp16_* +MlasHalfGemmPackedBOffset( + const _mlas_fp16_* PackedB, + size_t DimN, + size_t DimK, + size_t StartN, + size_t StartK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return PackedB + StartK * DimN + StartN; +} + +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +/*No it can NOT be constexpr!.*/ +#pragma warning(disable : 26497) +#endif + +/** + * @brief leading dimension of the packed B buffer + * Related to how B is packed + * @tparam KernelType + * @param DimN + * @param DimK + * @return leading dimension of the packed B buffer +*/ +template +MLAS_FORCEINLINE +size_t +MlasHalfGemmPackedBLeadingDim( + size_t DimN, + size_t DimK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return DimN; +} +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif + +template +void +MlasHalfGemmKernel( + const size_t CountM, + const size_t CountN, + const size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, + const _mlas_fp16_* A, + const size_t lda, + const _mlas_fp16_* B, + const size_t ldb, + const bool ZeroMode +); + + +template +MLAS_FORCEINLINE +void +MlasHalfGemmNoPackOperation( + const size_t N, + const size_t K, + const MLAS_HALF_GEMM_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ) +{ + // + // Optimize for the special case where no packing is needed. + // Simpler tiling as we are not restricted by packing panel size + // + + const size_t lda = Data->lda; + size_t ldb = Data->ldb; // 0 if prepacked + const size_t ldc = Data->ldc; + + const auto* pa = reinterpret_cast(Data->A) + + RangeStartM * lda; + const _mlas_fp16_* pb; + if (ldb == 0) { + pb = MlasHalfGemmPackedBOffset( + reinterpret_cast(Data->B), + N, + K, + RangeStartN, + 0); + ldb = MlasHalfGemmPackedBLeadingDim(N, K); + } else { + pb = reinterpret_cast(Data->B) + RangeStartN; + } + + const _mlas_fp16_* Bias = (nullptr == Data->Bias) + ? 
nullptr + : reinterpret_cast(Data->Bias) + RangeStartN; + _mlas_fp16_* c = reinterpret_cast<_mlas_fp16_*>(Data->C) + + RangeStartM * ldc + RangeStartN; + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + MlasHalfGemmKernel( + RowsRemaining, + RangeCountN, + K, + c, + ldc, + Bias, + pa, + lda, + pb, + ldb, + true); + + size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); + + if (Data->OutputProcessor != nullptr) { + Data->OutputProcessor->Process( + Data->C, + RangeStartM + RangeCountM - RowsRemaining, + RangeStartN, + RowsHandled, + RangeCountN, + Data->ldc); + } + + c += ldc * RowsHandled; + pa += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } +} + + +template +void +MlasHalfGemmOperation( + const size_t N, + const size_t K, + const MLAS_HALF_GEMM_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ) +{ + const size_t lda = Data->lda; + const size_t ldb = Data->ldb; + const size_t ldc = Data->ldc; + + if (!Data->AIsfp32 && (ldb == 0 || (!KernelType::PackNeeded && !Data->BIsfp32))) { + // !Data->AIsfp32 => A is fp16, no packing on the left hand side + // ldb == 0 => B is already packed, no packing on the right hand side + // !KernelType::PackNeeded && !Data->BIsfp32 => B is fp16 and the kernel + // does not require packing + // + // So no packing needed on either A or B, use a simpler driver instead + + MlasHalfGemmNoPackOperation( + N, + K, + Data, + RangeStartM, + RangeCountM, + RangeStartN, + RangeCountN); + return; + } + + const auto* Bias = reinterpret_cast(Data->Bias); + _mlas_fp16_* C = reinterpret_cast<_mlas_fp16_*>(Data->C) + + RangeStartM * ldc + RangeStartN; + + // + // Three dimensional tiling due to limited packing panel size + // + constexpr MLAS_HALF_GEMM_STRIDES Strides = KernelType::Strides; + constexpr size_t packASize = UpAlignSize(Strides.M * Strides.K * FP16_SIZE); + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * FP16_SIZE); + MlasThreadedBufAlloc(packASize + packBSize); + + uint8_t* p = ThreadedBufHolder.get(); + auto* PanelA = reinterpret_cast<_mlas_fp16_*>(p); + p += packASize; + auto* PanelB = reinterpret_cast<_mlas_fp16_*>(p); + + // + // Step through each slice of matrix B along the K dimension. + // + + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, Strides.K); + const size_t PackedCountK = (CountK + KernelType::PackedK - 1) / KernelType::PackedK; + + // + // Step through each slice of matrix B along the N dimension. + // + + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, Strides.N); + + // + // Copy a panel of matrix B to a local packed buffer. 
+ // + size_t ld_pb; + const _mlas_fp16_* pb; + if (ldb == 0) { + // Already packed + pb = MlasHalfGemmPackedBOffset( + reinterpret_cast(Data->B), + N, + K, + RangeStartN + n, + k); + ld_pb = MlasHalfGemmPackedBLeadingDim(N, K); + } else if (Data->BIsfp32) { + // fp32, need conversion and packing + MlasHalfGemmConvertPackB( + PanelB, + reinterpret_cast(Data->B) + ldb * k + RangeStartN + n, + ldb, + CountN, + CountK); + pb = PanelB; + ld_pb = MlasHalfGemmPackedBLeadingDim(CountN, CountK); + } else if (KernelType::PackNeeded) { + // fp16, need packing + MlasHalfGemmCopyPackB( + PanelB, + reinterpret_cast(Data->B) + ldb * k + RangeStartN + n, + ldb, + CountN, + CountK); + pb = PanelB; + ld_pb = MlasHalfGemmPackedBLeadingDim(CountN, CountK); + } else { + // fp16, and no packing needed + pb = reinterpret_cast(Data->B) + ldb * k + RangeStartN + n; + ld_pb = ldb; + } + + // + // Step through each slice of matrix A along the M dimension. + // + + auto* c = C + n; + const auto* pbias = (nullptr == Bias) ? nullptr : Bias + RangeStartN + n; + size_t CountM; + for (size_t m = 0; m < RangeCountM; m += CountM) { + CountM = std::min(RangeCountM - m, Strides.M); + + // + // Copy a panel of matrix A to a local packed buffer. + // + const _mlas_fp16_* pa; + size_t ld_pa; + if (Data->AIsfp32) { + MlasHalfGemmConvertPackA( + PanelA, + reinterpret_cast(Data->A) + (RangeStartM + m) * lda + k, + lda, + CountM, + CountK); + pa = PanelA; + ld_pa = KernelType::PackedK * PackedCountK; + } else { + pa = reinterpret_cast(Data->A) + (RangeStartM + m) * lda + k; + ld_pa = lda; + } + + size_t RowsRemaining = CountM; + bool ZeroMode = (k == 0); + bool PostProcess = (k + CountK == K); + + while (RowsRemaining > 0) { + MlasHalfGemmKernel( + RowsRemaining, + CountN, + CountK, + c, + ldc, + ZeroMode ? pbias : nullptr, + pa, + ld_pa, + pb, + ld_pb, + ZeroMode); + + size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); + + if (PostProcess && Data->OutputProcessor != nullptr) { + Data->OutputProcessor->Process( + Data->C, + RangeStartM + m + CountM - RowsRemaining, + RangeStartN + n, + RowsHandled, + CountN, + Data->ldc); + } + + c += ldc * RowsHandled; + pa += ld_pa * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + } + } +} + + +// +// dispatch structure. 
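The MlasHalfGemmOperation driver above walks the GEMM in K, then N, then M tiles sized by KernelType::Strides, packs the active A and B panels into a thread-local buffer, lets the kernel overwrite C only on the first K slice (ZeroMode) and accumulate on later slices, and runs the output processor once the last slice of K is done. The following is a minimal, self-contained sketch of that first-slice-writes / later-slices-accumulate tiling pattern, using plain float and fixed strides instead of the fp16 kernels and MLAS types; it is an illustration, not the MLAS code.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Minimal sketch of the K/N/M tiling used by the half-gemm driver above,
// with float in place of _mlas_fp16_. "ZeroMode" means: on the first K
// slice the inner kernel overwrites C, on later slices it accumulates.
static void TiledGemm(size_t M, size_t N, size_t K,
                      const float* A, const float* B, float* C,
                      size_t StrideM, size_t StrideN, size_t StrideK) {
    for (size_t k = 0; k < K; k += StrideK) {
        const size_t CountK = std::min(K - k, StrideK);
        const bool ZeroMode = (k == 0);                 // first slice of K?
        for (size_t n = 0; n < N; n += StrideN) {
            const size_t CountN = std::min(N - n, StrideN);
            for (size_t m = 0; m < M; m += StrideM) {
                const size_t CountM = std::min(M - m, StrideM);
                // "Kernel" for one tile: C[m..m+CountM, n..n+CountN]
                for (size_t i = 0; i < CountM; ++i) {
                    for (size_t j = 0; j < CountN; ++j) {
                        float sum = ZeroMode ? 0.0f : C[(m + i) * N + (n + j)];
                        for (size_t kk = 0; kk < CountK; ++kk) {
                            sum += A[(m + i) * K + (k + kk)] * B[(k + kk) * N + (n + j)];
                        }
                        C[(m + i) * N + (n + j)] = sum;
                    }
                }
            }
        }
    }
}

int main() {
    const size_t M = 5, N = 7, K = 9;
    std::vector<float> A(M * K, 1.0f), B(K * N, 2.0f), C(M * N, -1.0f);
    TiledGemm(M, N, K, A.data(), B.data(), C.data(),
              /*StrideM*/ 2, /*StrideN*/ 3, /*StrideK*/ 4);
    std::printf("C[0] = %f (expected %f)\n", C[0], 2.0f * K);  // 18
    return 0;
}
```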
+// + +typedef +void +(MLAS_HALFGEMM_OPERATION)( + const size_t N, + const size_t K, + const MLAS_HALF_GEMM_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ); + + +typedef +void +(MLAS_HALFGEMM_COPYPACKB_ROUTINE)( + _mlas_fp16_* D, + const _mlas_fp16_* B, + size_t ldb, + size_t CountN, + size_t CountK + ); + +typedef +void +(MLAS_HALFGEMM_CONVERTPACKB_ROUTINE)( + _mlas_fp16_* D, + const float* B, + size_t ldb, + size_t CountN, + size_t CountK + ); + +/** + * @brief Hardware dependent dispatch for half precision GEMM +*/ +struct MLAS_HALFGEMM_DISPATCH { + MLAS_HALFGEMM_OPERATION* Operation; /**< HalfGemm driver */ + MLAS_HALFGEMM_COPYPACKB_ROUTINE* CopyPackBRoutine; /**< Pack function for B */ + MLAS_HALFGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ + size_t PackededK; + size_t StrideM; + size_t BufOverRead; +}; + +extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault; + +#if defined(MLAS_TARGET_ARM64) +extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon; +#endif + +MLAS_FORCEINLINE +const MLAS_HALFGEMM_DISPATCH* +MlasHalfGemmGetDispatch() +{ +#if defined(MLAS_TARGET_ARM64) + return &MlasHalfGemmDispatchNeon; +#else + return &MlasHalfGemmDispatchDefault; +#endif +} diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..d7f5a90b00589 --- /dev/null +++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp @@ -0,0 +1,187 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm_kernel_neon.cpp + +Abstract: + + This module implements half precision GEMM kernel for neon. + +--*/ + +#include "mlasi.h" +#include "halfgemm.h" + +#include "arm_neon.h" + +// +// Define the prototypes of the NEON routines written in assembly. +// +// N.B. The kernel has not been ported to build with the Windows ARM32 toolset. 
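MLAS_HALFGEMM_DISPATCH above bundles the driver, the B packing routines, and a few kernel constants into one per-target table, and MlasHalfGemmGetDispatch returns the NEON table on ARM64 and the portable default elsewhere. The sketch below strips that function-pointer dispatch pattern down to its shape; the names, signatures, and the compile-time selector are illustrative stand-ins, not the MLAS definitions.

```cpp
#include <cstddef>
#include <cstdio>

// Sketch of a per-target dispatch table: one struct of function pointers and
// constants per implementation, plus a selector that returns the best match.
struct GemmDispatch {
    void (*Operation)(const char* who);  // stand-in for the real driver signature
    std::size_t KernelMaxM;              // max rows the vectorized kernel handles
    std::size_t BufOverRead;             // bytes the kernel may read past a buffer
};

static void GenericOperation(const char* who) { std::printf("generic driver for %s\n", who); }
static void NeonOperation(const char* who)    { std::printf("NEON driver for %s\n", who); }

static const GemmDispatch DispatchDefault = {GenericOperation, 4, 0};
static const GemmDispatch DispatchNeon    = {NeonOperation, 6, 32};

// MLAS makes this choice per target (and per CPU feature); a compiler macro
// stands in for that decision here.
static const GemmDispatch* GetDispatch() {
#if defined(__aarch64__)
    return &DispatchNeon;
#else
    return &DispatchDefault;
#endif
}

int main() {
    const GemmDispatch* d = GetDispatch();
    d->Operation("HalfGemm");
    std::printf("KernelMaxM=%zu BufOverRead=%zu\n", d->KernelMaxM, d->BufOverRead);
    return 0;
}
```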
+// + +extern "C" { + + size_t + MLASCALL + MlasHalfGemmKernelNeon( + const size_t CountM, + const size_t CountN, + const size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, + const _mlas_fp16_* A, + const size_t lda, + const _mlas_fp16_* B, + const size_t ldb, + const bool ZeroMode + ); + +} + + +struct MLAS_HALF_GEMM_KERNEL_NEON { + static constexpr bool PackNeeded = false; + static constexpr size_t KernelMaxM = 6; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 1; + + static constexpr MLAS_HALF_GEMM_STRIDES Strides{24, 128, 512}; +}; + + +MLAS_FORCEINLINE +void +CvtFloat2Half( + _mlas_fp16_* dest, + const float* src, + size_t len +) +{ + while (len >= 4) { + const auto* srcPtr = reinterpret_cast(src); + auto* dstPtr = reinterpret_cast(dest); + *dstPtr = vcvt_f16_f32(*srcPtr); + src += 4; + dest += 4; + len -= 4; + } + + if (0 == len) { + return; + } + + float32x4_t buf; + std::memcpy(&buf, src, len * sizeof(float)); + float16x4_t res = vcvt_f16_f32(buf); + + if ((len & 2) != 0) { + auto wide = vreinterpret_f32_f16(res); + vst1_lane_f32((float32_t*)dest, wide, 0); + res = vreinterpret_f16_f32(vdup_lane_f32(wide, 1)); + dest += 2; + } + if ((len & 1) != 0) { + vst1_lane_u16(dest, vreinterpret_u16_f16(res), 0); + } +} + +/** + * @brief Convert a 2D matrix from float to fp16 +*/ +MLAS_FORCEINLINE +void +CvtFloat2Half2D( + _mlas_fp16_* dest, + const float* src, + size_t stride, + size_t CntRow, + size_t CntCol + ) +{ + if (stride == CntCol) { + const size_t len = CntRow * CntCol; + CvtFloat2Half(dest, src, len); + return; + } + while (CntRow > 0) { + CvtFloat2Half(dest, src, CntCol); + src += stride; + dest += CntCol; + CntRow--; + } +} + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmConvertPackA( + _mlas_fp16_* D, + const float* A, + size_t lda, + size_t CountM, + size_t CountK +) +{ + CvtFloat2Half2D(D, A, lda, CountM, CountK); +} + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmConvertPackB( + _mlas_fp16_* D, + const float* B, + size_t ldb, + size_t CountN, + size_t CountK +) +{ + CvtFloat2Half2D(D, B, ldb, CountK, CountN); +} + + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmKernel( + size_t CountM, + size_t CountN, + size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, + const _mlas_fp16_* A, + size_t lda, + const _mlas_fp16_* B, + size_t ldb, + const bool ZeroMode) +{ + MlasHalfGemmKernelNeon( + CountM, + CountN, + CountK, + C, + ldc, + Bias, + A, + lda, + B, + ldb, + ZeroMode); +} + + +const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon = { + MlasHalfGemmOperation, + nullptr, + MlasHalfGemmConvertPackB, + MLAS_HALF_GEMM_KERNEL_NEON::PackedK, + MLAS_HALF_GEMM_KERNEL_NEON::KernelMaxM, + 32 // kernel may read beyond buffer end by 32 bytes +}; diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 31999f3294999..21949535cf63b 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -107,6 +107,8 @@ Module Name: #include "core/common/cpuid_info.h" using MLAS_CPUIDINFO = onnxruntime::CPUIDInfo; +#include "core/framework/float16.h" + #else // BUILD_MLAS_NO_ONNXRUNTIME class MLASCPUIDInfo @@ -121,6 +123,8 @@ class MLASCPUIDInfo // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } + bool HasFp16VectorAcceleration() const { return has_fp16_; } + uint32_t GetCurrentCoreIdx() const { return 0xFFFFFFFF; } int32_t GetCurrentUarch() const { return -1; } @@ -135,6 +139,7 @@ class MLASCPUIDInfo MLASCPUIDInfo(); bool 
has_arm_neon_dot_{false}; + bool has_fp16_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -179,7 +184,49 @@ enum MlasUArch { #endif // MLAS_TARGET_ARM64 -#endif // BUILD_MLAS_NO_ONNXRUNTIME +// +// Define MLAS_FP16 +// +#include "mlas_float16.h" + +namespace onnxruntime +{ +struct MLFloat16 { + uint16_t val{0}; + + MLFloat16() = default; + explicit constexpr MLFloat16(uint16_t x) : val(x) {} + explicit MLFloat16(float ff) : val(MLAS_Float2Half(ff)) {} + + float ToFloat() const { return MLAS_Half2Float(val); } + + operator float() const { return ToFloat(); } + + MLFloat16& operator=(float ff) + { + val = MLAS_Float2Half(ff); + return *this; + } +}; + +inline bool +operator==(const MLFloat16& left, const MLFloat16& right) +{ + return left.val == right.val; +} + +inline bool +operator!=(const MLFloat16& left, const MLFloat16& right) +{ + return left.val != right.val; +} + +} + +#endif // BUILD_MLAS_NO_ONNXRUNTIME + +static_assert(sizeof(MLAS_FP16) == FP16_SIZE); + // // Define the maximum number of threads supported by this implementation. @@ -700,9 +747,9 @@ extern "C" { // thread to perform additional work. // -#define MLAS_SGEMM_THREAD_COMPLEXITY (64 * 1024) -#define MLAS_DGEMM_THREAD_COMPLEXITY (64 * 1024) -#define MLAS_QGEMM_THREAD_COMPLEXITY (64 * 1024) +#define MLAS_SGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#define MLAS_QGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) // // Single-threaded single precision matrix/matrix multiply operation. diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 7d8624a32a218..c52d4f3b0b8c4 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -36,6 +36,9 @@ Module Name: MLASCPUIDInfo::MLASCPUIDInfo() { has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); + + // raw hack! Need CPUIDInfo implementation for more precise detection + has_fp16_ = has_arm_neon_dot_; } #endif @@ -50,7 +53,13 @@ MLASCPUIDInfo::MLASCPUIDInfo() #endif #if defined(BUILD_MLAS_NO_ONNXRUNTIME) -MLASCPUIDInfo::MLASCPUIDInfo() { has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); } +MLASCPUIDInfo::MLASCPUIDInfo() +{ + has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); + + // raw hack! Need CPUIDInfo implementation for more precise detection + has_fp16_ = has_arm_neon_dot_; +} #endif #else diff --git a/onnxruntime/core/mlas/lib/reorder.cpp b/onnxruntime/core/mlas/lib/reorder.cpp index 0d7fbd97a4a6f..99c1dbac3b692 100644 --- a/onnxruntime/core/mlas/lib/reorder.cpp +++ b/onnxruntime/core/mlas/lib/reorder.cpp @@ -1,6 +1,7 @@ /*++ Copyright (c) Microsoft Corporation. All rights reserved. +Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. Licensed under the MIT License. @@ -17,6 +18,20 @@ Module Name: #include "mlasi.h" +// +// Define the parameters to execute segments of a NCHW output reordering +// operation on worker threads. 
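The platform.cpp change above sets has_fp16_ = has_arm_neon_dot_ and flags it as a "raw hack" pending a proper CPUIDInfo implementation. On Linux/AArch64, one more precise probe would be the dedicated half-precision HWCAP bit; the snippet below is only a sketch of that idea under the assumption that the AArch64 hwcap header is available, and it is not the CPUIDInfo implementation the comment asks for.

```cpp
// Sketch only: a more targeted fp16 capability probe for Linux/AArch64,
// along the lines the "raw hack" comment above asks for.
#include <cstdio>

#if defined(__aarch64__) && defined(__linux__)
#include <sys/auxv.h>
#include <asm/hwcap.h>

static bool HasFp16VectorSupport() {
    unsigned long hwcap = getauxval(AT_HWCAP);
    // HWCAP_ASIMDHP: Advanced SIMD half-precision arithmetic (FEAT_FP16).
    return (hwcap & HWCAP_ASIMDHP) != 0;
}
#else
static bool HasFp16VectorSupport() { return false; }  // conservative fallback elsewhere
#endif

int main() {
    std::printf("fp16 SIMD support: %s\n", HasFp16VectorSupport() ? "yes" : "no");
    return 0;
}
```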
+// + +struct MLAS_REORDER_OUTPUT_NCHW_BLOCK { + ptrdiff_t TargetThreadCount; + const float* S; + float* D; + size_t OutputChannels; + size_t OutputSize; + size_t TasksCount; +}; + MLAS_FORCEINLINE void MlasReorderGatherFloat32x4( @@ -396,25 +411,22 @@ Return Value: } void -MLASCALL -MlasReorderOutputNchw( - const int64_t* OutputShape, - const float* S, - float* D +MlasReorderOutputNchwThreaded( + void* Context, + ptrdiff_t Index ) /*++ Routine Description: - This routine reorders an output buffer from NCHWc to NCHW format. + This routine is invoked from a worker thread to execute a segment of a + NCHW output reordering operation. Arguments: - OutputShape - Supplies the shape of the output tensor. - - S - Supplies the address of the source tensor. + Context - Supplies the pointer to the context for the threaded operation. - D - Supplies the address of the destination tensor. + Index - Supplies the current index of the threaded operation. Return Value: @@ -422,77 +434,168 @@ Return Value: --*/ { + const auto* WorkBlock = (MLAS_REORDER_OUTPUT_NCHW_BLOCK*)Context; + + const size_t OutputChannels = WorkBlock->OutputChannels; + const size_t OutputSize = WorkBlock->OutputSize; + const float* S = WorkBlock->S; + float* D = WorkBlock->D; + const size_t BlockSize = MlasNchwcGetBlockSize(); + const size_t TasksPerBatch = size_t(ceil(((float)OutputChannels) / BlockSize)); + const size_t LastTaskInBatchIndex = TasksPerBatch - 1; - const size_t BatchCount = size_t(OutputShape[0]); - const size_t OutputChannels = size_t(OutputShape[1]); - const size_t OutputSize = size_t(OutputShape[2]) * size_t(OutputShape[3]); + // + // Compute the range of task indices to use for this thread. + // + + size_t TaskStart; + size_t TasksRemaining; + MlasPartitionWork(Index, WorkBlock->TargetThreadCount, WorkBlock->TasksCount, + &TaskStart, &TasksRemaining); + + size_t TaskEnd = TaskStart + TasksRemaining; + // - // Transpose NCHWc blocks from the source buffer to the destination buffer. + // Rebase the pointers to the source and destination buffers for this thread. // - for (size_t batch = 0; batch < BatchCount; batch++) { + size_t FirstBatchIndex = TaskStart / TasksPerBatch; + size_t FirstTaskInBatchIndex = TaskStart % TasksPerBatch; + S += BlockSize * OutputSize * (FirstBatchIndex * TasksPerBatch + FirstTaskInBatchIndex); + D += OutputSize * (FirstBatchIndex * OutputChannels + BlockSize * FirstTaskInBatchIndex); - for (size_t o = OutputChannels; o > 0;) { + // + // Transpose NCHWc blocks associated with tasks in the range [TaskStart, TaskEnd) + // from the source buffer to the destination buffer. + // - const size_t OutputChannelsThisIteration = std::min(o, BlockSize); - const size_t AlignedOutputChannelsThisIteration = OutputChannelsThisIteration & (~3); - o -= OutputChannelsThisIteration; + for (size_t t = TaskStart; t < TaskEnd; t++) { + size_t TaskInBatchIndex = t % TasksPerBatch; - const float* s = S; - float* d = D; - size_t OutputSizeRemaining = OutputSize; + const size_t OutputChannelsThisIteration = (TaskInBatchIndex < LastTaskInBatchIndex) ? 
+ BlockSize : OutputChannels - BlockSize * LastTaskInBatchIndex; + const size_t AlignedOutputChannelsThisIteration = OutputChannelsThisIteration & (~3); - for (; OutputSizeRemaining >= 4; OutputSizeRemaining -= 4) { + const float* s = S; + float* d = D; + size_t OutputSizeRemaining = OutputSize; - const float* ss = s; - float* dd = d; - size_t bc = 0; + for (; OutputSizeRemaining >= 4; OutputSizeRemaining -= 4) { - for (; bc < AlignedOutputChannelsThisIteration; bc += 4) { - MlasReorderTransposeFloat32x4x4(ss, dd, BlockSize, OutputSize); - ss += 4; - dd += 4 * OutputSize; - } + const float* ss = s; + float* dd = d; + size_t bc = 0; - for (; bc < OutputChannelsThisIteration; bc += 1) { - MlasReorderGatherFloat32x4(ss, dd, BlockSize); - ss += 1; - dd += OutputSize; - } + for (; bc < AlignedOutputChannelsThisIteration; bc += 4) { + MlasReorderTransposeFloat32x4x4(ss, dd, BlockSize, OutputSize); + ss += 4; + dd += 4 * OutputSize; + } - s += 4 * BlockSize; - d += 4; + for (; bc < OutputChannelsThisIteration; bc += 1) { + MlasReorderGatherFloat32x4(ss, dd, BlockSize); + ss += 1; + dd += OutputSize; } - for (; OutputSizeRemaining > 0; OutputSizeRemaining--) { + s += 4 * BlockSize; + d += 4; + } - const float* ss = s; - float* dd = d; - size_t bc = 0; + for (; OutputSizeRemaining > 0; OutputSizeRemaining--) { - for (; bc < AlignedOutputChannelsThisIteration; bc += 4) { - MlasReorderScatterFloat32x4(ss, dd, OutputSize); - ss += 4; - dd += 4 * OutputSize; - } + const float* ss = s; + float* dd = d; + size_t bc = 0; - for (; bc < OutputChannelsThisIteration; bc += 1) { - *dd = *ss++; - dd += OutputSize; - } + for (; bc < AlignedOutputChannelsThisIteration; bc += 4) { + MlasReorderScatterFloat32x4(ss, dd, OutputSize); + ss += 4; + dd += 4 * OutputSize; + } - s += BlockSize; - d += 1; + for (; bc < OutputChannelsThisIteration; bc += 1) { + *dd = *ss++; + dd += OutputSize; } - S += BlockSize * OutputSize; - D += OutputChannelsThisIteration * OutputSize; + s += BlockSize; + d += 1; } + + S += BlockSize * OutputSize; + D += OutputChannelsThisIteration * OutputSize; } } + +void +MLASCALL +MlasReorderOutputNchw( + const int64_t* OutputShape, + const float* S, + float* D, + MLAS_THREADPOOL* ThreadPool + ) +/*++ + +Routine Description: + + This routine reorders an output buffer from NCHWc to NCHW format. + +Arguments: + + OutputShape - Supplies the shape of the output tensor. + + S - Supplies the address of the source tensor. + + D - Supplies the address of the destination tensor. + +Return Value: + + None. + +--*/ +{ + MLAS_REORDER_OUTPUT_NCHW_BLOCK WorkBlock; + + // + // Capture the NCHW reorder output operation parameters to the work block. + // + + WorkBlock.S = S; + WorkBlock.D = D; + + WorkBlock.OutputChannels = size_t(OutputShape[1]); + WorkBlock.OutputSize = size_t(OutputShape[2]) * size_t(OutputShape[3]); + + const size_t BlockSize = MlasNchwcGetBlockSize(); + const size_t TasksPerBatch = size_t(ceil(((float)WorkBlock.OutputChannels) / BlockSize)); + const size_t BatchCount = size_t(OutputShape[0]); + const size_t TasksCount = BatchCount * TasksPerBatch; + WorkBlock.TasksCount = TasksCount; + + // + // Schedule the operation across a set of worker threads if the output + // tensor is sufficienly large. Limit the number of threads to at least + // the number of available tasks. 
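The threaded reorder above asks MlasPartitionWork for a contiguous range [TaskStart, TaskStart + TasksRemaining) per worker and then rebases the source and destination pointers to that range. The helper below is a self-contained sketch of that kind of even partitioning (the usual split where the first remainder threads take one extra task); it shows the same idea, not the MlasPartitionWork source.

```cpp
#include <cstddef>
#include <cstdio>

// Sketch of even work partitioning: thread `index` out of `thread_count`
// gets a contiguous range of `total_tasks`, with the first
// `total_tasks % thread_count` threads taking one extra task.
static void PartitionWork(std::ptrdiff_t index, std::ptrdiff_t thread_count,
                          std::size_t total_tasks,
                          std::size_t* task_start, std::size_t* task_count) {
    const std::size_t per_thread = total_tasks / static_cast<std::size_t>(thread_count);
    const std::size_t remainder = total_tasks % static_cast<std::size_t>(thread_count);
    const std::size_t idx = static_cast<std::size_t>(index);

    if (idx < remainder) {
        *task_start = idx * (per_thread + 1);
        *task_count = per_thread + 1;
    } else {
        *task_start = remainder * (per_thread + 1) + (idx - remainder) * per_thread;
        *task_count = per_thread;
    }
}

int main() {
    const std::size_t tasks = 10;     // e.g. BatchCount * TasksPerBatch
    const std::ptrdiff_t threads = 4;
    for (std::ptrdiff_t t = 0; t < threads; ++t) {
        std::size_t start = 0, count = 0;
        PartitionWork(t, threads, tasks, &start, &count);
        std::printf("thread %td -> tasks [%zu, %zu)\n", t, start, start + count);
    }
    return 0;
}
```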
+ // + + ptrdiff_t TargetThreadCount = 1; + const size_t BufferSize = BatchCount * WorkBlock.OutputChannels * WorkBlock.OutputSize; + if (BufferSize > 1024 && TasksCount > 1) { + TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool); + if (size_t(TargetThreadCount) > TasksCount) { + TargetThreadCount = ptrdiff_t(TasksCount); + } + } + WorkBlock.TargetThreadCount = TargetThreadCount; + + MlasExecuteThreaded(MlasReorderOutputNchwThreaded, &WorkBlock, TargetThreadCount, ThreadPool); +} + void MLASCALL MlasReorderOutputNhwc( diff --git a/onnxruntime/core/optimizer/bias_softmax_fusion.cc b/onnxruntime/core/optimizer/bias_softmax_fusion.cc index 80603cdbd3270..7c34449d583cc 100755 --- a/onnxruntime/core/optimizer/bias_softmax_fusion.cc +++ b/onnxruntime/core/optimizer/bias_softmax_fusion.cc @@ -135,6 +135,7 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node new_axis = (int)HandleNegativeAxis(axis, rank); // The axis attribute for Softmax in OpSet-11 and OpSet-13 are different. + // Details in function documentatin. if (is_since_opset_13 && new_axis != rank - 1) return false; int singlebatch_rank = rank - new_axis; diff --git a/onnxruntime/core/optimizer/constant_sharing.cc b/onnxruntime/core/optimizer/constant_sharing.cc index 96c60bfd145d8..fa9a309098c76 100644 --- a/onnxruntime/core/optimizer/constant_sharing.cc +++ b/onnxruntime/core/optimizer/constant_sharing.cc @@ -129,7 +129,9 @@ struct GetOrAddValueInConstantStoreDispatcher { } // namespace Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, - const logging::Logger& /*logger*/) const { + const logging::Logger& logger) const { + int shared_count = 0; + // Accumulated map from type/value/rank to initializer: // > The key is a string representation of initializer's data type, value and rank. // > The value is newly created initializer NodeArg* to be shared. @@ -138,9 +140,11 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve InlinedVector original_initializer_names; original_initializer_names.reserve(initialized_tensor_set.size()); for (const auto& entry : initialized_tensor_set) { - // Ignore if the initializer already handled, or not a constant initializer. + // Ignore if the initializer exists in graph output, already handled, + // or not a constant initializer (implicitly excludes the graph input). 
if (IsSharedInitializer(entry.first) || !graph_utils::IsConstantInitializer(graph, entry.first) || + graph.IsOutput(graph.GetNodeArg(entry.first)) || excluded_initializers_.find(entry.first) != excluded_initializers_.end()) { continue; } @@ -191,6 +195,8 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve NodeArg& shared_scalar_initializer_node_arg = graph_utils::AddInitializer(graph, constant_tensor_proto_as_replacement); pattern_key_to_shared_arg_map[pattern_key] = &shared_scalar_initializer_node_arg; + } else { + shared_count += 1; } ReplaceInputsToUseSharedInitializer(graph, consumer_node_to_input_ports_map, origin_initializer_node_arg, @@ -199,6 +205,8 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve modified = true; } + LOGS(logger, INFO) << "Total shared scalar initializer count: " << shared_count; + return Status::OK(); } diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index fdee3c19f2e8e..7fbef8a9b4de4 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -200,7 +200,6 @@ InlinedVector> GenerateTransformers( // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by // default, CSE will not merge them, because the different initializers are represented by different NodeArg. if (session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableDoubleQDQRemover, "0") == "0"){ - transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique()); } transformers.emplace_back(std::make_unique()); @@ -217,7 +216,7 @@ InlinedVector> GenerateTransformers( // run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around. // shouldn't affect the end result - just easier to debug any issue if it's last. - auto cpu_allocator = cpu_execution_provider.GetAllocator(0, OrtMemTypeDefault); + auto cpu_allocator = cpu_execution_provider.GetAllocator(OrtMemTypeDefault); transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); // add __backwardpass attribute to nodes after YieldOp, ROCm-only @@ -335,7 +334,7 @@ InlinedVector> GenerateTransformers( if (MlasNchwcGetBlockSize() > 1) { transformers.emplace_back(std::make_unique()); } - auto cpu_allocator = cpu_execution_provider.GetAllocator(0, OrtMemTypeDefault); + auto cpu_allocator = cpu_execution_provider.GetAllocator(OrtMemTypeDefault); transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); // NCHWCtransformer should have a higher priority versus this. Because NCHWCtransformer also do the similar things // of fusion patterns and target on CPU. 
However, NCHWCtransformer will reorder the layout to nchwc which is only available for @@ -408,7 +407,7 @@ InlinedVector> GenerateTransformersForMinimalB if (!saving) { #ifndef DISABLE_CONTRIB_OPS const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; - auto cpu_allocator = cpu_execution_provider.GetAllocator(0, OrtMemTypeDefault); + auto cpu_allocator = cpu_execution_provider.GetAllocator(OrtMemTypeDefault); transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); #else ORT_UNUSED_PARAMETER(cpu_execution_provider); diff --git a/onnxruntime/core/optimizer/identical_children_consolidation.cc b/onnxruntime/core/optimizer/identical_children_consolidation.cc index 17f01cebcdb6c..07dc25dabde5f 100644 --- a/onnxruntime/core/optimizer/identical_children_consolidation.cc +++ b/onnxruntime/core/optimizer/identical_children_consolidation.cc @@ -117,8 +117,11 @@ string_view IdenticalChildrenConsolidation::IdentityBuilder(const Graph& graph, } else { identity.append(name); } + } else { + return ignore_identity; } + identity.append("####"); } - return {identity.append("####")}; + return {identity}; } } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 9895918dd2653..25feb5b8d702c 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -4,6 +4,7 @@ #include "core/optimizer/layer_norm_fusion.h" #include "core/graph/graph_utils.h" #include "core/optimizer/utils.h" +#include "core/optimizer/transpose_optimizer/optimizer_api.h" #include "float.h" #include @@ -16,12 +17,17 @@ static constexpr std::array supported_data_types{"tensor(fl // Default epsilon static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f; -static bool IsSupportedDataType(const Node& node) { +static bool IsSupportedDataType(const Node& node, int first_n_inputs=-1) { + int input_index = 0; for (const auto& input_arg : node.InputDefs()) { + if (first_n_inputs != -1 && input_index >= first_n_inputs) { + return true; + } if (std::find(supported_data_types.begin(), supported_data_types.end(), *(input_arg->Type())) == supported_data_types.end()) { return false; } + ++input_index; } return true; } @@ -99,11 +105,11 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& reduce_mean_node = *p_reduce_mean; ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13, 18}) || !graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) || (reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) || graph.NodeProducesGraphOutput(reduce_mean_node) || - !IsSupportedDataType(reduce_mean_node)) { + !IsSupportedDataType(reduce_mean_node, 1)) { continue; } nodes_to_remove.push_back(reduce_mean_node); @@ -263,10 +269,10 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } Node& reduce_mean2_node = *graph.GetNode(p_reduce_mean2->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11, 13, 18}) || reduce_mean2_node.GetExecutionProviderType() != 
reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, reduce_mean2_node, 1) || - !IsSupportedDataType(reduce_mean2_node) || + !IsSupportedDataType(reduce_mean2_node, 1) || reduce_mean2_node.GetInputEdgesCount() == 0) { continue; } @@ -333,8 +339,16 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // get axes attributes const onnxruntime::NodeAttributes& attributes = reduce_mean_node.GetAttributes(); std::vector axes_values; + // TODO: modify this codes when opset >= 18 (axes is an input). if (attributes.find("axes") != attributes.end()) { axes_values = RetrieveValues(attributes.at("axes")); + } else if (reduce_mean_node.InputDefs().size() == 2) { + auto axes = reduce_mean_node.InputDefs()[1]; + auto axes_const = graph.GetConstantInitializer(axes->Name(), true); + if (axes_const != nullptr) { + Initializer initializer{*axes_const, graph.ModelPath()}; + axes_values.insert(axes_values.end(), initializer.DataAsSpan().begin(), initializer.DataAsSpan().end()); + } } // Get the inputs for the new LayerNormalization node. @@ -485,9 +499,9 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr continue; } Node& reduce_mean_node = *graph.GetNode(p_reduce_mean->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13, 18}) || reduce_mean_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() || - !optimizer_utils::CheckOutputEdges(graph, reduce_mean_node, 1) || !IsSupportedDataType(reduce_mean_node) || + !optimizer_utils::CheckOutputEdges(graph, reduce_mean_node, 1) || !IsSupportedDataType(reduce_mean_node, 1) || reduce_mean_node.GetInputEdgesCount() == 0) { continue; } @@ -585,6 +599,13 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr std::vector axes_values; if (attributes.find("axes") != attributes.end()) { axes_values = RetrieveValues(attributes.at("axes")); + } else if (reduce_mean_node.InputDefs().size() == 2) { + auto axes = reduce_mean_node.InputDefs()[1]; + auto axes_const = graph.GetConstantInitializer(axes->Name(), true); + if (axes_const != nullptr && axes_const->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) { + Initializer initializer{*axes_const, graph.ModelPath()}; + axes_values.insert(axes_values.end(), initializer.DataAsSpan().begin(), initializer.DataAsSpan().end()); + } } // Get the inputs for the new LayerNormalization node. 
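The fusion passes above now accept ReduceMean from opset 18, where axes moved from an attribute to an optional (usually constant) second input, so the pass reads the attribute when present and otherwise pulls the values out of the constant initializer feeding input 1. The sketch below restates that fallback with small stand-in types; Node, Graph, and the containers here are hypothetical simplifications, not the onnxruntime graph API.

```cpp
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for the graph objects the fusion pass
// works on; only what is needed to show the attribute-vs-input fallback.
struct Node {
    std::map<std::string, std::vector<int64_t>> int_list_attributes;  // "axes" before opset 18
    std::vector<std::string> input_names;                             // input 1 may carry axes from opset 18 on
};

struct Graph {
    std::map<std::string, std::vector<int64_t>> constant_initializers;  // name -> int64 data
};

// Mirror of the logic in the patch: prefer the "axes" attribute, otherwise
// read the values from a constant initializer feeding the second input.
static std::vector<int64_t> GetReduceAxes(const Graph& graph, const Node& reduce_mean) {
    auto attr = reduce_mean.int_list_attributes.find("axes");
    if (attr != reduce_mean.int_list_attributes.end()) {
        return attr->second;
    }
    if (reduce_mean.input_names.size() == 2) {
        auto init = graph.constant_initializers.find(reduce_mean.input_names[1]);
        if (init != graph.constant_initializers.end()) {
            return init->second;
        }
    }
    return {};  // axes unknown here; the real pass bails out in that case
}

int main() {
    Graph g;
    g.constant_initializers["axes_const"] = {-1};

    Node opset13_node;
    opset13_node.int_list_attributes["axes"] = {-1};

    Node opset18_node;
    opset18_node.input_names = {"X", "axes_const"};

    std::printf("opset13 axes[0]=%lld, opset18 axes[0]=%lld\n",
                static_cast<long long>(GetReduceAxes(g, opset13_node)[0]),
                static_cast<long long>(GetReduceAxes(g, opset18_node)[0]));
    return 0;
}
```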
diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index a9c11604a65f9..e10874f79e394 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -35,7 +35,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, const std::function& is_sparse_initializer_func) : execution_provider_(execution_provider), is_sparse_initializer_func_(is_sparse_initializer_func) { - allocator_ptr_ = execution_provider_.GetAllocator(device_id_, mem_type_); + allocator_ptr_ = execution_provider_.GetAllocator(mem_type_); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); ORT_THROW_IF_ERROR(data_transfer_mgr_.RegisterDataTransfer(std::make_unique())); @@ -89,7 +89,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, const std::function& is_sparse_initializer_func) : execution_provider_(execution_provider), is_sparse_initializer_func_(is_sparse_initializer_func) { - allocator_ptr_ = execution_provider_.GetAllocator(device_id_, mem_type_); + allocator_ptr_ = execution_provider_.GetAllocator(mem_type_); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); ORT_THROW_IF_ERROR(data_transfer_mgr_.RegisterDataTransfer(std::make_unique())); diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h index c9aed249910d3..4f3a1cc62cbf1 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.h +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h @@ -35,7 +35,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame { ~Info() = default; AllocatorPtr GetAllocator(const OrtMemoryInfo& info) const { - return execution_provider_.GetAllocator(info.id, info.mem_type); + return execution_provider_.GetAllocator(info.mem_type); } const AllocatorPtr& GetAllocator() const { @@ -68,8 +68,6 @@ class OptimizerExecutionFrame final : public IExecutionFrame { } private: - // The optimizer is running on CPU execution provider by default. 
- const int device_id_{0}; const OrtMemType mem_type_{OrtMemTypeDefault}; AllocatorPtr allocator_ptr_; DataTransferManager data_transfer_mgr_; diff --git a/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc index 37409086a0458..ec46893233f05 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc @@ -850,7 +850,7 @@ const std::unordered_set& GetORTLayoutSensitiveOps() { Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvider& execution_provider) { // sub graph recurse will be added later - auto api_graph = MakeApiGraph(graph, execution_provider.GetAllocator(0, OrtMemTypeDefault), nullptr); + auto api_graph = MakeApiGraph(graph, execution_provider.GetAllocator(OrtMemTypeDefault), nullptr); const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); for (auto& node : api_graph->Nodes()) { diff --git a/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc b/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc index 0ac7cbb8fa058..700c91ab85974 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc @@ -1040,7 +1040,7 @@ static bool HandlePad(HandlerArgs& args) { constexpr HandlerInfo pad_handler = {&FirstInput, &HandlePad}; -static bool HandleReduceOp(HandlerArgs& args) { +static bool HandleReduceOpWithArg(HandlerArgs& args) { int64_t keepdims = args.node.GetAttributeIntDefault("keepdims", 1); std::optional> axes = args.node.GetAttributeInts("axes"); @@ -1078,11 +1078,11 @@ static bool HandleReduceOp(HandlerArgs& args) { return true; } -constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOp}; - -static bool HandleReduceSum(HandlerArgs& args) { - if (args.ctx.opset < 13) { - return HandleReduceOp(args); +static bool HandleReduceOps(HandlerArgs& args) { + if ((args.node.OpType() == "ReduceSum" && args.ctx.opset < 13) || + // or all other reduce operators since opset 18 + (args.node.OpType() != "ReduceSum" && args.ctx.opset < 18)) { + return HandleReduceOpWithArg(args); } bool keepdims = args.node.GetAttributeIntDefault("keepdims", 1) != 0; @@ -1147,7 +1147,7 @@ static bool HandleReduceSum(HandlerArgs& args) { return true; } -constexpr HandlerInfo reduce_sum_handler = {&FirstInput, &HandleReduceSum}; +constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps}; static bool HandleSqueeze(HandlerArgs& args) { std::vector new_axes; @@ -1709,7 +1709,7 @@ static const std::unordered_map handler_ma #if !defined(USE_CUDA) && !defined(USE_ROCM) {"Resize", resize_handler}, #endif - {"ReduceSum", reduce_sum_handler}, + {"ReduceSum", reduce_op_handler}, {"ReduceLogSum", reduce_op_handler}, {"ReduceLogSumExp", reduce_op_handler}, diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 08a37c345fde0..87b8af2afcf47 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -95,7 +95,7 @@ class WindowsThread : public EnvThread { } if (custom_create_thread_fn) { - custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, (OrtThreadWorkerFn)CustomThreadMain, local_param.get()); + custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, CustomThreadMain, local_param.get()); if (!custom_thread_handle) 
{ ORT_THROW("custom_create_thread_fn returned invalid handle."); } @@ -217,7 +217,7 @@ class WindowsThread : public EnvThread { } #pragma warning(pop) - static void __stdcall CustomThreadMain(void* param) { + static void CustomThreadMain(void* param) { std::unique_ptr p(static_cast(param)); ORT_TRY { p->start_address(p->index, p->param); diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index da9816b97b1db..d4428a4c093b2 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1418,20 +1418,12 @@ Status CANNExecutionProvider::Compile(const std::vector& fuse return Status::OK(); } -AllocatorPtr CANNExecutionProvider::GetAllocator(int id, OrtMemType mem_type) const { - if (mem_type == OrtMemTypeDefault) { - return IExecutionProvider::GetAllocator(info_.device_id, mem_type); - } else { - return IExecutionProvider::GetAllocator(id, mem_type); - } -} - void CANNExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manager) { OrtDevice cann_device{OrtDevice::NPU, OrtDevice::MemType::DEFAULT, info_.device_id}; OrtDevice pinned_device{OrtDevice::CPU, OrtDevice::MemType::CANN_PINNED, DEFAULT_CPU_ALLOCATOR_DEVICE_ID}; OrtDevice cpu_device{OrtDevice::CPU, OrtDevice::MemType::DEFAULT, DEFAULT_CPU_ALLOCATOR_DEVICE_ID}; - auto cann_alloc = IExecutionProvider::GetAllocator(cann_device.Id(), OrtMemTypeDefault); + auto cann_alloc = IExecutionProvider::GetAllocator(OrtMemTypeDefault); if (!cann_alloc) { cann_alloc = allocator_manager.GetAllocator(OrtMemTypeDefault, cann_device); @@ -1453,7 +1445,7 @@ void CANNExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manage InsertAllocator(cann_alloc); } - auto cann_pinned_alloc = IExecutionProvider::GetAllocator(pinned_device.Id(), OrtMemTypeCPUOutput); + auto cann_pinned_alloc = IExecutionProvider::GetAllocator(OrtMemTypeCPUOutput); if (!cann_pinned_alloc) { cann_pinned_alloc = allocator_manager.GetAllocator(OrtMemTypeCPUOutput, pinned_device); @@ -1471,7 +1463,7 @@ void CANNExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manage InsertAllocator(cann_pinned_alloc); } - auto cann_cpu_alloc = IExecutionProvider::GetAllocator(cpu_device.Id(), OrtMemTypeCPUInput); + auto cann_cpu_alloc = IExecutionProvider::GetAllocator(OrtMemTypeCPUInput); if (!cann_cpu_alloc) { cann_cpu_alloc = allocator_manager.GetAllocator(OrtMemTypeCPUInput, cpu_device); diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index a05feac4b60fa..2fe4024487c49 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -45,14 +45,14 @@ class CANNExecutionProvider : public IExecutionProvider { if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(GetAllocator(info_.device_id, OrtMemTypeDefault), count_or_bytes); + return IAllocator::MakeUniquePtr(GetAllocator(OrtMemTypeDefault), count_or_bytes); } template IAllocatorUniquePtr GetScratchBufferOnCANNPinned(size_t count_or_bytes) const { if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPU), + return IAllocator::MakeUniquePtr(GetAllocator(OrtMemTypeCPU), count_or_bytes); } @@ -84,7 +84,6 @@ class CANNExecutionProvider : public IExecutionProvider { return 
CANNExecutionProviderInfo::ToProviderOptions(info_); } - AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override; void RegisterAllocator(AllocatorManager& allocator_manager) override; private: diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.cc b/onnxruntime/core/providers/cpu/controlflow/loop.cc index 69301f52aca29..77ac161e01d08 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.cc +++ b/onnxruntime/core/providers/cpu/controlflow/loop.cc @@ -319,7 +319,7 @@ common::Status Loop::SetupSubgraphExecutionInfo(const SessionState& session_stat // 'cond' is first output and we need it to be on CPU so we can read the latest value const auto& cpu_allocator_info = session_state.GetExecutionProviders() .Get(onnxruntime::kCpuExecutionProvider) - ->GetAllocator(0, OrtMemTypeDefault) + ->GetAllocator(OrtMemTypeDefault) ->Info(); fetch_locations.push_back(&cpu_allocator_info); @@ -411,7 +411,7 @@ Status LoopImpl::Initialize() { // these need to be on CPU auto cpu_allocator = session_state_.GetExecutionProviders() .Get(onnxruntime::kCpuExecutionProvider) - ->GetAllocator(0, OrtMemTypeDefault); + ->GetAllocator(OrtMemTypeDefault); iter_num_mlvalue_ = MakeScalarMLValue(cpu_allocator, 0, iter_num_rank != 0); condition_mlvalue_ = MakeScalarMLValue(cpu_allocator, condition_, condition_rank != 0); diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index caba0090753dc..c866226cf77c1 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -35,7 +35,7 @@ CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info, void CPUExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manager) { OrtDevice cpu_device{OrtDevice::CPU, OrtDevice::MemType::DEFAULT, DEFAULT_CPU_ALLOCATOR_DEVICE_ID}; // if EP is used in multiple inference sessions we may already have an allocator. if so use that. 
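A large part of this patch replaces the two-argument IExecutionProvider::GetAllocator(device_id, mem_type) calls with a single-argument GetAllocator(mem_type): the device is implied by the provider instance, so allocators are looked up by memory type alone. The toy classes below only illustrate that lookup-by-memory-type shape; they are not the onnxruntime interfaces, and the enum merely mimics the OrtMemType names.

```cpp
#include <cstddef>
#include <cstdio>
#include <map>
#include <memory>

enum class MemType { Default, CPUInput, CPUOutput };

struct Allocator {
    const char* name;
    void* Alloc(std::size_t bytes) {
        std::printf("%s: alloc %zu bytes\n", name, bytes);
        return nullptr;  // toy allocator, no real allocation
    }
};

class ExecutionProvider {
 public:
    explicit ExecutionProvider(int device_id) : device_id_(device_id) {}

    int DeviceId() const { return device_id_; }

    // The device id is part of the provider's own state, so callers no longer
    // pass it; they just ask for a memory type.
    std::shared_ptr<Allocator> GetAllocator(MemType mem_type) const {
        auto it = allocators_.find(mem_type);
        return it == allocators_.end() ? nullptr : it->second;
    }

    void RegisterAllocator(MemType mem_type, std::shared_ptr<Allocator> alloc) {
        allocators_.emplace(mem_type, std::move(alloc));
    }

 private:
    int device_id_;
    std::map<MemType, std::shared_ptr<Allocator>> allocators_;
};

int main() {
    ExecutionProvider ep(/*device_id*/ 0);
    ep.RegisterAllocator(MemType::Default, std::make_shared<Allocator>(Allocator{"device default"}));
    ep.RegisterAllocator(MemType::CPUOutput, std::make_shared<Allocator>(Allocator{"pinned output"}));

    std::printf("provider for device %d\n", ep.DeviceId());
    if (auto alloc = ep.GetAllocator(MemType::Default)) {   // was GetAllocator(0, OrtMemTypeDefault)
        alloc->Alloc(1024);
    }
    return 0;
}
```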
- auto cpu_alloc = GetAllocator(cpu_device.Id(), OrtMemTypeDefault); + auto cpu_alloc = GetAllocator(OrtMemTypeDefault); if (!cpu_alloc) { // use shared allocator if available cpu_alloc = allocator_manager.GetAllocator(OrtMemTypeDefault, cpu_device); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index b4a92019992b5..c0a75fc50b07e 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -198,12 +198,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) override { return p->contrib::AttentionBase::CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, - extra_add_qk, + relative_position_bias, parameters, max_threads_per_block, past_seq_len); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 2490789dd31a2..f12e080adf30a 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -145,7 +145,7 @@ struct ProviderHostCPU { const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) = 0; diff --git a/onnxruntime/core/providers/cpu/nn/conv_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_attributes.h index 51a1e7acafe11..b31030acc52c1 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/conv_attributes.h @@ -73,7 +73,7 @@ struct ConvAttributes { ~ConvAttributes() = default; - Status ComputeKernelShape(const TensorShape& weight_shape, TensorShapeVector& kernel_shape) const { + Status ComputeKernelShape(const TensorShape& weight_shape, TensorShapeVector& kernel_shape, bool weight_channels_last = false) const { if (kernel_shape_specified) { kernel_shape = kernel_shape_; if (kernel_shape.size() + 2 != weight_shape.NumDimensions()) { @@ -82,15 +82,20 @@ struct ConvAttributes { " W: ", weight_shape.ToString().c_str()); } for (size_t i = 0; i < kernel_shape.size(); ++i) { - if (kernel_shape[i] != weight_shape[i + 2]) { + if (kernel_shape[i] != weight_shape[i + (weight_channels_last ? 
1 : 2)]) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.", " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), - " W: ", weight_shape.ToString().c_str()); + " W: ", weight_shape.ToString().c_str(), + " channels_last: ", weight_channels_last); } } } else { auto weight_dims = weight_shape.GetDims(); - kernel_shape.assign(weight_dims.begin() + 2, weight_dims.end()); + if (weight_channels_last) { + kernel_shape.assign(weight_dims.begin() + 1, weight_dims.end() - 1); + } else { + kernel_shape.assign(weight_dims.begin() + 2, weight_dims.end()); + } } return Status::OK(); @@ -98,7 +103,8 @@ struct ConvAttributes { Status ValidateInputShape(const TensorShape& input_shape, const TensorShape& weight_shape, - bool channels_last = false) const { + bool input_channels_last = false, + bool weight_channels_last = false) const { if (input_shape.NumDimensions() != weight_shape.NumDimensions()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "X num_dims does not match W num_dims.", " X: ", input_shape.ToString().c_str(), @@ -106,9 +112,9 @@ struct ConvAttributes { } const int64_t M = weight_shape[0]; - const int64_t C = channels_last ? input_shape.GetDims().back() : input_shape[1]; + const int64_t C = input_channels_last ? input_shape.GetDims().back() : input_shape[1]; - if (C != weight_shape[1] * group) { + if (C != (weight_channels_last ? weight_shape.GetDims().back() : weight_shape[1]) * group) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input channels C is not equal to kernel channels * group.", " C: ", C, " kernel channels: ", weight_shape[1], diff --git a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h index e5e641e96f7af..59b512def619d 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h @@ -414,7 +414,8 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in static_cast(output_height * 2), [&](std::ptrdiff_t first, std::ptrdiff_t last) { if (output_height == input_height) { - std::copy_n(Xdata_span.begin() + narrow(first * input_width), narrow((last - first) * output_width), + auto workload_in_thread = narrow(last) - narrow(first); + std::copy_n(Xdata_span.begin() + narrow(first * input_width), narrow(workload_in_thread * output_width), Ydata_span.begin() + narrow(first * output_width)); return; } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 7a673a05858ca..50c739021db5a 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -220,7 +220,8 @@ void OverrideTunableOpInfoByEnv(CUDAExecutionProviderInfo& info) { CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kCudaExecutionProvider}, - info_{info} { + info_{info}, + tuning_context_(this, &info_.tunable_op) { CUDA_CALL_THROW(cudaSetDevice(info_.device_id)); // must wait GPU idle, otherwise cudaGetDeviceProperties might fail @@ -271,18 +272,8 @@ CUDAExecutionProvider::~CUDAExecutionProvider() { } } -void CUDAExecutionProvider::EnableTunableOp() { - LOGS_DEFAULT(INFO) << "Enable TunableOp for CUDA Execution Provider"; - info_.tunable_op.enabled = true; -} - -void CUDAExecutionProvider::DisableTunableOp() { - LOGS_DEFAULT(INFO) << "Disable TunableOp for CUDA Execution Provider"; - 
info_.tunable_op.enabled = false; -} - -bool CUDAExecutionProvider::IsTunableOpEnabled() const { - return info_.tunable_op.enabled; +ITuningContext* CUDAExecutionProvider::GetTuningContext() const { + return const_cast(&tuning_context_); } std::unique_ptr CUDAExecutionProvider::GetProfiler() { @@ -358,9 +349,9 @@ void CUDAExecutionProvider::ReleasePerThreadContext() const { per_thread_context_cache->erase(cached_context_it); } -AllocatorPtr CUDAExecutionProvider::GetAllocator(int id, OrtMemType mem_type) const { +AllocatorPtr CUDAExecutionProvider::GetAllocator(OrtMemType mem_type) const { if (mem_type == OrtMemTypeDefault) { - auto cuda_alloc = IExecutionProvider::GetAllocator(id, mem_type); + auto cuda_alloc = IExecutionProvider::GetAllocator(mem_type); if (!cuda_alloc) { // this means the program invoke GetAllocator before RegsiterAllocators, // which only happnens in some UTs. @@ -371,7 +362,7 @@ AllocatorPtr CUDAExecutionProvider::GetAllocator(int id, OrtMemType mem_type) co } } - return IExecutionProvider::GetAllocator(id, mem_type); + return IExecutionProvider::GetAllocator(mem_type); } Status CUDAExecutionProvider::Sync() const { @@ -2446,7 +2437,7 @@ void CUDAExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manage // if EP is used in multiple inference sessions we may already have an allocator. if so use that. // NOTE: We call IExecutionProvider::GetAllocator as CUDAExecutionProvider::GetAllocator will return // a per-thread allocator for OrtMemTypeDefault. - auto cuda_alloc = IExecutionProvider::GetAllocator(cuda_device.Id(), OrtMemTypeDefault); + auto cuda_alloc = IExecutionProvider::GetAllocator(OrtMemTypeDefault); if (!cuda_alloc) { // use shared allocator if available cuda_alloc = allocator_manager.GetAllocator(OrtMemTypeDefault, cuda_device); @@ -2464,7 +2455,7 @@ void CUDAExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manage // OrtMemTypeCPUOutput -- allocated by cudaMallocHost, used to copy CUDA device memory to CPU // Use pinned memory instead of pageable memory make the data transfer faster // Used by node MemcpyToHost only - auto cuda_pinned_alloc = IExecutionProvider::GetAllocator(pinned_device.Id(), OrtMemTypeCPUOutput); + auto cuda_pinned_alloc = IExecutionProvider::GetAllocator(OrtMemTypeCPUOutput); if (!cuda_pinned_alloc) { cuda_pinned_alloc = allocator_manager.GetAllocator(OrtMemTypeCPUOutput, pinned_device); @@ -2488,7 +2479,7 @@ void CUDAExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manage } // OrtMemTypeCPUInput -- op place the input on CPU and will not be accessed by CUDA kernel, no sync issue - auto cuda_cpu_alloc = IExecutionProvider::GetAllocator(cpu_device.Id(), OrtMemTypeCPUInput); + auto cuda_cpu_alloc = IExecutionProvider::GetAllocator(OrtMemTypeCPUInput); if (!cuda_cpu_alloc) { cuda_cpu_alloc = allocator_manager.GetAllocator(OrtMemTypeCPUInput, cpu_device); @@ -2516,25 +2507,15 @@ void CUDAExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manage void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry) const { // This allocator must be the same to the allocator // used in AllocateBufferOnCPUPinned. 
- auto allocator = GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPU); - if (use_ep_level_unified_stream_) - RegisterCudaStreamHandles(stream_handle_registry, - OrtDevice::GPU, - allocator, - !IsGraphCaptureEnabled(), - stream_, - use_ep_level_unified_stream_, - GetPerThreadContext().CudnnHandle(), - GetPerThreadContext().CublasHandle()); - else - RegisterCudaStreamHandles(stream_handle_registry, - OrtDevice::GPU, - allocator, - !IsGraphCaptureEnabled(), - stream_, - use_ep_level_unified_stream_, - GetPerThreadContext().CudnnHandle(), - GetPerThreadContext().CublasHandle()); + auto allocator = GetAllocator(OrtMemTypeCPU); + RegisterCudaStreamHandles(stream_handle_registry, + OrtDevice::GPU, + allocator, + !IsGraphCaptureEnabled(), + stream_, + use_ep_level_unified_stream_, + GetPerThreadContext().CudnnHandle(), + GetPerThreadContext().CublasHandle()); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 0bf9f75b710c1..d95cac03017eb 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -15,6 +15,7 @@ #include "core/providers/cuda/cuda_pch.h" #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/providers/cuda/shared_inc/cuda_call.h" +#include "core/providers/cuda/tunable/cuda_tuning_context.h" namespace onnxruntime { @@ -26,7 +27,7 @@ class CUDAExecutionProvider : public IExecutionProvider { explicit CUDAExecutionProvider(const CUDAExecutionProviderInfo& info); virtual ~CUDAExecutionProvider(); - AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override; + AllocatorPtr GetAllocator(OrtMemType mem_type) const override; Status Sync() const override; @@ -61,7 +62,7 @@ class CUDAExecutionProvider : public IExecutionProvider { IAllocatorUniquePtr GetScratchBuffer(size_t count_or_bytes, Stream* stream, WaitNotificationFn wait_fn) const { if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(GetAllocator(info_.device_id, OrtMemTypeDefault), count_or_bytes, false, stream, wait_fn); + return IAllocator::MakeUniquePtr(GetAllocator(OrtMemTypeDefault), count_or_bytes, false, stream, wait_fn); } template @@ -69,7 +70,7 @@ class CUDAExecutionProvider : public IExecutionProvider { if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(GetAllocator(info_.device_id, OrtMemTypeDefault), count_or_bytes, true); + return IAllocator::MakeUniquePtr(GetAllocator(OrtMemTypeDefault), count_or_bytes, true); } template @@ -78,7 +79,7 @@ class CUDAExecutionProvider : public IExecutionProvider { // In some CUDA async if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPU), + return IAllocator::MakeUniquePtr(GetAllocator(OrtMemTypeCPU), count_or_bytes); } @@ -104,9 +105,7 @@ class CUDAExecutionProvider : public IExecutionProvider { static AllocatorPtr CreateCudaAllocator(OrtDevice::DeviceId device_id, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy, CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg); - void EnableTunableOp(); - void DisableTunableOp(); - bool IsTunableOpEnabled() const; + ITuningContext* GetTuningContext() const override; std::unique_ptr GetProfiler() override; @@ -126,6 +125,9 @@ class CUDAExecutionProvider : public IExecutionProvider { bool use_ep_level_unified_stream_ = false; + // the tuning 
context might be altered when calling into a TunableOp + mutable cuda::tunable::CudaTuningContext tuning_context_; + class PerThreadContext final { public: PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy, diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index d5cb96b4d0148..f7d7daddeda7b 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -87,7 +87,9 @@ class CudaKernel : public OpKernel { return stream->cublas_handle_; } - bool IsTunableOpEnabled() const { return provider_->IsTunableOpEnabled(); } + tunable::CudaTuningContext* GetTuningContext() const { + return static_cast(provider_->GetTuningContext()); + } // To support cudaMemcpyAsync, the cpu memory should be allocated in pinned memory // and it can only be released after the copy has finished diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index d7858fd1c4e24..d059f93cebf1c 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -28,7 +28,7 @@ using namespace onnxruntime; namespace onnxruntime { -#if defined(USE_CUDA) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) +#if defined(USE_CUDA) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) && defined(ENABLE_TRAINING) namespace cuda { cuda::INcclService& GetINcclService(); } @@ -164,7 +164,7 @@ struct ProviderInfo_CUDA_Impl : ProviderInfo_CUDA { info = CUDAExecutionProviderInfo::FromProviderOptions(options); } -#if defined(USE_CUDA) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) +#if defined(USE_CUDA) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) && defined(ENABLE_TRAINING) cuda::INcclService& GetINcclService() override { return cuda::GetINcclService(); } diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.h b/onnxruntime/core/providers/cuda/cuda_provider_factory.h index 259fd199120e7..76b1693d62b9b 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.h +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.h @@ -20,7 +20,7 @@ class NvtxRangeCreator; } struct ProviderInfo_CUDA { - virtual ~ProviderInfo_CUDA() {} // This is declared due to a TSA warning, the only instantiation of this class is a global variable of automatic storage. + virtual ~ProviderInfo_CUDA() {} // This is declared due to a TSA warning, the only instantiation of this class is a global variable of automatic storage. 
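Instead of the removed EnableTunableOp/DisableTunableOp/IsTunableOpEnabled methods, the provider now owns a mutable tuning context and hands it out through GetTuningContext(), and CudaKernel fetches that context from the provider when a tunable op runs. Below is a bare-bones sketch of that ownership pattern with invented class names; it mirrors only what the diff shows (an enable flag living in a context object reachable from the kernel) and is not the real ITuningContext interface.

```cpp
#include <cstdio>

// Invented, minimal stand-ins to show the ownership pattern: the provider
// owns a mutable tuning context, kernels reach it through the provider.
class ITuningContext {
 public:
    virtual ~ITuningContext() = default;
    virtual void EnableTunableOp() = 0;
    virtual bool IsTunableOpEnabled() const = 0;
};

class ToyTuningContext : public ITuningContext {
 public:
    void EnableTunableOp() override { enabled_ = true; }
    bool IsTunableOpEnabled() const override { return enabled_; }
 private:
    bool enabled_ = false;
};

class ToyExecutionProvider {
 public:
    // Mirrors the diff: the context is a mutable member so a const provider
    // can still hand out a usable pointer to it.
    ITuningContext* GetTuningContext() const { return &tuning_context_; }
 private:
    mutable ToyTuningContext tuning_context_;
};

class ToyKernel {
 public:
    explicit ToyKernel(const ToyExecutionProvider* provider) : provider_(provider) {}
    void Compute() const {
        if (provider_->GetTuningContext()->IsTunableOpEnabled()) {
            std::printf("running tuned path\n");
        } else {
            std::printf("running default path\n");
        }
    }
 private:
    const ToyExecutionProvider* provider_;
};

int main() {
    ToyExecutionProvider provider;
    ToyKernel kernel(&provider);
    kernel.Compute();                                 // default path
    provider.GetTuningContext()->EnableTunableOp();   // toggled via the context, not the provider
    kernel.Compute();                                 // tuned path
    return 0;
}
```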
virtual OrtStatus* SetCurrentGpuDeviceId(_In_ int device_id) = 0; virtual OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) = 0; @@ -43,7 +43,7 @@ struct ProviderInfo_CUDA { virtual int cudaGetDeviceCount() = 0; virtual void CUDAExecutionProviderInfo__FromProviderOptions(const onnxruntime::ProviderOptions& options, onnxruntime::CUDAExecutionProviderInfo& info) = 0; -#if defined(USE_CUDA) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) +#if defined(USE_CUDA) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) && defined(ENABLE_TRAINING) virtual onnxruntime::cuda::INcclService& GetINcclService() = 0; #endif diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index b818e9b57a7b2..81d0070f1aeaf 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -102,7 +102,7 @@ struct CpuBuffersInfo { // should contain all values in // deferred_release_buffer_pool_[my_stream] // when release my_stream's buffers. - void** buffers; + std::unique_ptr<void*[]> buffers; // CPU buffer buffers[i]. // Number of buffer points in "buffers". size_t n_buffers; @@ -117,7 +117,6 @@ static void CUDART_CB ReleaseCpuBufferCallback(void* raw_info) { for (size_t i = 0; i < info->n_buffers; ++i) { info->allocator->Free(info->buffers[i]); } - delete[] info->buffers; } Status CudaStream::CleanUpOnRunEnd() { @@ -128,7 +127,7 @@ Status CudaStream::CleanUpOnRunEnd() { if (release_cpu_buffer_on_cuda_stream_ && cpu_allocator_->Info().alloc_type == OrtArenaAllocator) { std::unique_ptr<CpuBuffersInfo> cpu_buffers_info = std::make_unique<CpuBuffersInfo>(); cpu_buffers_info->allocator = cpu_allocator_; - cpu_buffers_info->buffers = new void*[deferred_cpu_buffers_.size()]; + cpu_buffers_info->buffers = std::make_unique<void*[]>(deferred_cpu_buffers_.size()); for (size_t i = 0; i < deferred_cpu_buffers_.size(); ++i) { cpu_buffers_info->buffers[i] = deferred_cpu_buffers_.at(i); } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index d62a651880a85..4c9cbbe605a7a 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -42,6 +42,12 @@ Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dat return Status::OK(); } +Status CudnnTensor::Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int n, int c, int h, int w) { + ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); + CUDNN_RETURN_IF_ERROR(cudnnSetTensor4dDescriptor(tensor_, format, dataType, n, c, h, w)); + return Status::OK(); +} + Status CudnnTensor::Set(const CudnnTensor& x_desc, cudnnBatchNormMode_t mode) { ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); CUDNN_RETURN_IF_ERROR(cudnnDeriveBNTensorDescriptor(tensor_, x_desc, mode)); @@ -113,15 +119,23 @@ Status CudnnFilterDescriptor::Set(gsl::span filter_dims, cudnnDat return Status::OK(); } +Status CudnnFilterDescriptor::Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int k, int c, int h, int w) { + if (!desc_) + CUDNN_RETURN_IF_ERROR(cudnnCreateFilterDescriptor(&desc_)); + + CUDNN_RETURN_IF_ERROR(cudnnSetFilter4dDescriptor(desc_, dataType, format, k, c, h, w)); + return Status::OK(); +} + template <typename ElemType> cudnnDataType_t CudnnTensor::GetDataType() { ORT_THROW("cuDNN engine currently supports only single/double/half/int8/uint8 precision data types. 
Got:", - typeid(ElemType).name()); + typeid(ElemType).name()); // Not reachable but GCC complains return CUDNN_DATA_FLOAT; } -template<> +template <> cudnnDataType_t CudnnTensor::GetDataType() { return CUDNN_DATA_FLOAT; } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index f104373b9413a..ba75ab4f2c029 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -18,6 +18,8 @@ class CudnnTensor final { Status Set(gsl::span input_dims, cudnnDataType_t dataType); Status Set(const CudnnTensor& x_desc, cudnnBatchNormMode_t mode); + // Set 4D tensor format (for NHWC) + Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int n, int c, int h, int w); operator cudnnTensorDescriptor_t() const { return tensor_; } @@ -58,6 +60,9 @@ class CudnnFilterDescriptor final { Status Set(gsl::span filter_dims, cudnnDataType_t data_typ); + // Set 4D filter where k is output channels, c is input channels, h and w is rows and columns per filter. + Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int k, int c, int h, int w); + operator cudnnFilterDescriptor_t() const { return desc_; } private: diff --git a/onnxruntime/core/providers/cuda/math/softmax.cc b/onnxruntime/core/providers/cuda/math/softmax.cc index dc1830a192945..5047a70242a5c 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.cc +++ b/onnxruntime/core/providers/cuda/math/softmax.cc @@ -26,15 +26,12 @@ Status SoftMaxComputeHelper( auto X_data = reinterpret_cast(X); if (D <= 1024 && D * sizeof(T) <= 4096) { - dispatch_warpwise_softmax_forward, is_log_softmax>( + return dispatch_warpwise_softmax_forward, is_log_softmax>( stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(N)); - } else { - dispatch_blockwise_softmax_forward, is_log_softmax>( - stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(D), - gsl::narrow_cast(N)); } - - return Status::OK(); + return dispatch_blockwise_softmax_forward, is_log_softmax>( + stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(D), + gsl::narrow_cast(N)); } #define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \ diff --git a/onnxruntime/core/providers/cuda/math/softmax.h b/onnxruntime/core/providers/cuda/math/softmax.h index b2528bb0c8855..b66ad32517458 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.h +++ b/onnxruntime/core/providers/cuda/math/softmax.h @@ -18,12 +18,12 @@ Status SoftMaxComputeHelper( int64_t axis); template -void dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src, - int softmax_elements, int softmax_elements_stride, int batch_count); +Status dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src, + int softmax_elements, int softmax_elements_stride, int batch_count); template -void dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input, - int softmax_elements, int input_stride, int output_stride, int batch_count); +Status dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input, + int softmax_elements, int input_stride, int output_stride, int batch_count); template class Softmax final : public CudaKernel { diff --git a/onnxruntime/core/providers/cuda/math/softmax_blockwise_impl.cuh b/onnxruntime/core/providers/cuda/math/softmax_blockwise_impl.cuh index bb26f5fdccad6..6cb65ea8e739c 100644 --- 
a/onnxruntime/core/providers/cuda/math/softmax_blockwise_impl.cuh +++ b/onnxruntime/core/providers/cuda/math/softmax_blockwise_impl.cuh @@ -1,19 +1,19 @@ /** -* Copyright (c) 2016-present, Facebook, Inc. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ // The code below is mostly copied from Pytorch SoftMax.cuh @@ -23,7 +23,6 @@ namespace onnxruntime { namespace cuda { -constexpr int ALIGN_BYTES = 16; const int max_threads = 1024; dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { @@ -45,33 +44,28 @@ dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { return dim3(static_cast(block_size)); } - //////////////////////////////////////////////////////////////////////////////// // Regular kernel (fast when dim_size is large; requires inner_size == 1) //////////////////////////////////////////////////////////////////////////////// - template -struct MaxFloat -{ +struct MaxFloat { __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { return ::max(max, (AccumT)v); } }; -template -struct AddFloat -{ +template +struct AddFloat { __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { return sum + (AccumT)v; } }; -template -struct SumExpFloat -{ +template +struct SumExpFloat { __device__ __forceinline__ SumExpFloat(AccumT v) - : max_k(v) {} + : max_k(v) {} __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { return sum + std::exp((AccumT)v - max_k); @@ -80,12 +74,23 @@ struct SumExpFloat const AccumT max_k; }; -template class Reduction, typename AccumT> -__device__ __forceinline__ AccumT -blockReduce(AccumT* smem, AccumT val, - const Reduction& r, - AccumT defaultVal) -{ +// One block has N(warps_per_block) warps, one warp has M(WARP_SIZE) threads. +// 1. All the threads in one block read data into shared memory. +// 2. Reduce all data to the first warp. Only the threads of warp-0 are used. Each thread in warp-0 reads data from the +// same location of every warp and computes result. For example, thread-0 computes the first data of every warp and +// writes the result into the location of data0. +// Shared memory +// ----------------------------------------------------------------------------------------------------------------------- +// | data0 | data1 | data2 | .... | dataM | ... | dataM*2 | ... 
| +// ----------------------------------------------------------------------------------------------------------------------- +// | | | | +// -------------------warp-0----------------------------------warp-1----------------------------------warp-2-------------- +// 3. Thread-0 reduces all data in warp-0 and writes the results into the location of data0, then return data0. + +template