diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..884d08b46 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*~ +*.o +*build*/ +*.pyc \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 03c48bb4e..76757d608 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,12 +16,23 @@ project(FasterTransformer LANGUAGES CXX CUDA) find_package(CUDA 10.1 REQUIRED) -option(BUILD_TRT "Build in TensorRT mode" OFF) option(BUILD_TF "Build in TensorFlow mode" OFF) -option(BUILD_THE "Build in PyTorch eager mode" OFF) -option(BUILD_THS "Build in TorchScript class mode" OFF) +option(BUILD_PYT "Build in PyTorch TorchScript class mode" OFF) +option(BUILD_GPT "Build project with gpt" ON) # TODO Set default to OFF -if(BUILD_THS) +if(BUILD_GPT) + message(STATUS "Add DBUILD_GPT, requires MPI and NCCL") + add_definitions("-DBUILD_GPT") + set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) + find_package(MPI REQUIRED) + find_package(NCCL REQUIRED) + #if(${NCCL_VERSION} LESS 2.7) + # message(FATAL_ERROR "NCCL_VERSION ${NCCL_VERSION} is less than 2.7") + #endif() + set(CMAKE_MODULE_PATH "") # prevent the bugs for pytorch building +endif() + +if(BUILD_PYT) if(DEFINED ENV{NVIDIA_PYTORCH_VERSION}) if($ENV{NVIDIA_PYTORCH_VERSION} VERSION_LESS "20.03") message(FATAL_ERROR "NVIDIA PyTorch image is too old for TorchScript mode.") @@ -32,7 +43,11 @@ if(BUILD_THS) endif() endif() -set(CXX_STD "11" CACHE STRING "C++ standard") +if(BUILD_PYT OR BUILD_GPT) + set(CXX_STD "14" CACHE STRING "C++ standard") +else() + set(CXX_STD "11" CACHE STRING "C++ standard") +endif() set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) @@ -42,12 +57,6 @@ if(BUILD_TF AND NOT TF_PATH) message(FATAL_ERROR "TF_PATH must be set if BUILD_TF(=TensorFlow mode) is on.") endif() -set(TRT_PATH "" CACHE STRING "TensorRT path") - -if(BUILD_TRT AND NOT TRT_PATH) - message(FATAL_ERROR "TRT_PATH must be set if BUILD_TRT(=TensorRT mode) is on.") -endif() - list(APPEND CMAKE_MODULE_PATH ${CUDA_PATH}/lib64) if (${CUDA_VERSION} GREATER_EQUAL 11.0) @@ -55,44 +64,99 @@ if (${CUDA_VERSION} GREATER_EQUAL 11.0) add_definitions("-DCUDA11_MODE") endif() +# profiling +option(USE_NVTX "Whether or not to use nvtx" OFF) +if(USE_NVTX) + message(STATUS "NVTX is enabled.") + add_definitions("-DUSE_NVTX") +endif() + # setting compiler flags -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall") - -if (SM STREQUAL 80 OR - SM STREQUAL 86 OR - SM STREQUAL 70 OR - SM STREQUAL 75 OR - SM STREQUAL 61 OR - SM STREQUAL 60) -#set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM},code=\\\"sm_${SM},compute_${SM}\\\" -rdc=true") -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM},code=\\\"sm_${SM},compute_${SM}\\\"") - if (SM STREQUAL 70 OR SM STREQUAL 75 OR SM STREQUAL 80 OR SM STREQUAL 86) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall -ldl") + +# if (SM STREQUAL 80 OR +# SM STREQUAL 86 OR +# SM STREQUAL 70 OR +# SM STREQUAL 75 OR +# SM STREQUAL 61 OR +# 
SM STREQUAL 60) +# #set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM},code=\\\"sm_${SM},compute_${SM}\\\" -rdc=true") +# set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM},code=\\\"sm_${SM},compute_${SM}\\\"") +# if (SM STREQUAL 70 OR SM STREQUAL 75 OR SM STREQUAL 80 OR SM STREQUAL 86) +# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA") +# set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA") +# endif() +# message("-- Assign GPU architecture (sm=${SM})") + +# else() +# set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \ +# -gencode=arch=compute_70,code=\\\"sm_70,compute_70\\\" \ +# -gencode=arch=compute_75,code=\\\"sm_75,compute_75\\\" \ +# ") +# # -rdc=true") +# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA") +# set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA") +# message("-- Assign GPU architecture (sm=70,75)") +# endif() + +set(SM_SETS 52 60 61 70 75 80) +set(USING_WMMA False) +set(FIND_SM False) + +foreach(SM_NUM IN LISTS SM_SETS) + string(FIND "${SM}" "${SM_NUM}" SM_POS) + if(SM_POS GREATER -1) + if(FIND_SM STREQUAL False) + set(ENV{TORCH_CUDA_ARCH_LIST} "") + endif() + set(FIND_SM True) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM_NUM},code=\\\"sm_${SM_NUM},compute_${SM_NUM}\\\"") + + if (SM_NUM STREQUAL 70 OR SM_NUM STREQUAL 75 OR SM_NUM STREQUAL 80 OR SM_NUM STREQUAL 86) + set(USING_WMMA True) + endif() + + if(BUILD_PYT) + string(SUBSTRING ${SM_NUM} 0 1 SM_MAJOR) + string(SUBSTRING ${SM_NUM} 1 1 SM_MINOR) + set(ENV{TORCH_CUDA_ARCH_LIST} "$ENV{TORCH_CUDA_ARCH_LIST}\;${SM_MAJOR}.${SM_MINOR}") + endif() + + set(CMAKE_CUDA_ARCHITECTURES ${SM_NUM}) + message("-- Assign GPU architecture (sm=${SM_NUM})") endif() -if(BUILD_THE OR BUILD_THS) - string(SUBSTRING ${SM} 0 1 SM_MAJOR) - string(SUBSTRING ${SM} 1 1 SM_MINOR) - set(ENV{TORCH_CUDA_ARCH_LIST} "${SM_MAJOR}.${SM_MINOR}") +endforeach() + +if(USING_WMMA STREQUAL True) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA") + message("-- Use WMMA") endif() -message("-- Assign GPU architecture (sm=${SM})") -else() -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \ - -gencode=arch=compute_70,code=\\\"sm_70,compute_70\\\" \ - -gencode=arch=compute_75,code=\\\"sm_75,compute_75\\\" \ - ") -# -rdc=true") -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA") -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA") -if(BUILD_THE OR BUILD_THS) - set(ENV{TORCH_CUDA_ARCH_LIST} "7.0;7.5") +if(NOT (FIND_SM STREQUAL True)) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \ + -gencode=arch=compute_70,code=\\\"sm_70,compute_70\\\" \ + -gencode=arch=compute_75,code=\\\"sm_75,compute_75\\\" \ + -gencode=arch=compute_80,code=\\\"sm_80,compute_80\\\" \ + ") + # -rdc=true") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA") + if(BUILD_PYT) + set(ENV{TORCH_CUDA_ARCH_LIST} "7.0;7.5;8.0") + endif() + set(CMAKE_CUDA_ARCHITECTURES 70 75 80) + message("-- Assign GPU architecture (sm=70,75,80)") endif() -message("-- Assign GPU architecture (sm=70,75)") + +if(BUILD_PYT) + set(TORCH_CUDA_ARCH_LIST $ENV{TORCH_CUDA_ARCH_LIST}) endif() set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0") @@ -128,21 +192,14 @@ if(BUILD_TF) list(APPEND COMMON_LIB_DIRS ${TF_PATH}) endif() -if(BUILD_TRT) - list(APPEND 
COMMON_HEADER_DIRS ${TRT_PATH}/include) - list(APPEND COMMON_LIB_DIRS ${TRT_PATH}/lib) -endif() - set(PYTHON_PATH "python" CACHE STRING "Python path") -if(BUILD_THS) +if(BUILD_PYT) execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import torch; print(torch.__version__,end='');" RESULT_VARIABLE _PYTHON_SUCCESS OUTPUT_VARIABLE TORCH_VERSION) if (TORCH_VERSION VERSION_LESS "1.5.0") message(FATAL_ERROR "PyTorch >= 1.5.0 is needed for TorchScript mode.") endif() -endif() -if(BUILD_THE OR BUILD_THS) execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import os; import torch; print(os.path.dirname(torch.__file__),end='');" RESULT_VARIABLE _PYTHON_SUCCESS @@ -152,34 +209,25 @@ print(os.path.dirname(torch.__file__),end='');" endif() list(APPEND CMAKE_PREFIX_PATH ${TORCH_DIR}) find_package(Torch REQUIRED) - execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; from distutils import sysconfig; -print(sysconfig.get_python_inc()); -print(sysconfig.get_config_var('SO'));" +print(sysconfig.get_python_inc());" RESULT_VARIABLE _PYTHON_SUCCESS - OUTPUT_VARIABLE _PYTHON_VALUES) + OUTPUT_VARIABLE PY_INCLUDE_DIR) if (NOT _PYTHON_SUCCESS MATCHES 0) message(FATAL_ERROR "Python config Error.") endif() - string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES}) - string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES}) - list(GET _PYTHON_VALUES 0 PY_INCLUDE_DIR) - list(GET _PYTHON_VALUES 1 PY_SUFFIX) list(APPEND COMMON_HEADER_DIRS ${PY_INCLUDE_DIR}) - - execute_process(COMMAND ${PYTHON_PATH} "-c" "from torch.utils import cpp_extension; print(' '.join(cpp_extension._prepare_ldflags([],True,False)),end='');" - RESULT_VARIABLE _PYTHON_SUCCESS - OUTPUT_VARIABLE TORCH_LINK) - if (NOT _PYTHON_SUCCESS MATCHES 0) - message(FATAL_ERROR "PyTorch link config Error.") - endif() endif() +list(APPEND COMMON_HEADER_DIRS ${MPI_INCLUDE_PATH}) include_directories( ${COMMON_HEADER_DIRS} ) +# set path of mpi +list(APPEND COMMON_LIB_DIRS /usr/local/mpi/lib) + link_directories( ${COMMON_LIB_DIRS} ) @@ -196,7 +244,7 @@ if(BUILD_TF) ) endif() -if(BUILD_THE OR BUILD_THS) +if(BUILD_PYT) add_custom_target(copy ALL COMMENT "Copying pytorch test scripts") add_custom_command(TARGET copy POST_BUILD @@ -205,3 +253,110 @@ if(BUILD_THE OR BUILD_THS) COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/utils/translation/test.* ${PROJECT_BINARY_DIR}/pytorch/translation/data/ ) endif() + +######################################## + +if(BUILD_GPT) +# Following feature requires cmake 3.15 +# TODO Remove this part or modify such that we can run it under cmake 3.10 +cmake_minimum_required(VERSION 3.15 FATAL_ERROR) +add_library(transformer-static STATIC + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $) +set_property(TARGET transformer-static PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET transformer-static PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(transformer-static PUBLIC -lcublas -lcudart -lcurand -lnccl -lmpi nvtx_utils) + +add_library(transformer-shared SHARED + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $) +## add_library(transformer-shared SHARED $) +set_target_properties(transformer-shared PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(transformer-shared PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) +set_target_properties(transformer-shared PROPERTIES LINKER_LANGUAGE CXX) +target_link_libraries(transformer-shared PUBLIC ${NCCL_LIBRARIES} ${MPI_LIBRARIES} -lcublas -lcublasLt 
-lcudart -lcurand ) + +include(GNUInstallDirs) +set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/FasterTransformer) + +include(CMakePackageConfigHelpers) +configure_package_config_file( + ${CMAKE_CURRENT_LIST_DIR}/cmake/FasterTransformerConfig.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/FasterTransformerConfig.cmake + INSTALL_DESTINATION ${INSTALL_CONFIGDIR} +) + +install( + FILES + ${CMAKE_CURRENT_BINARY_DIR}/FasterTransformerConfig.cmake + DESTINATION ${INSTALL_CONFIGDIR} +) + +install( + TARGETS + transformer-shared + EXPORT + transformer-shared-targets + LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib + ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/lib +) + +install( + EXPORT + transformer-shared-targets + FILE + FasterTransformerTargets.cmake + DESTINATION + ${INSTALL_CONFIGDIR} +) + +file(GLOB_RECURSE HEADER_FILES "*.h" "*.hpp" "*.cuh") +foreach ( file ${HEADER_FILES} ) + file( RELATIVE_PATH rfile ${CMAKE_CURRENT_SOURCE_DIR} ${file} ) + get_filename_component( dir ${rfile} DIRECTORY ) + install( FILES ${file} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${dir} ) +endforeach() + + +################################################################################ +add_executable(gpt sample/cpp/gpt_sample.cc ) +target_link_libraries(gpt PUBLIC -lcublas -lcublasLt -lcudart -lcurand -lnccl -lmpi transformer-static) +# target_link_libraries(gpt PUBLIC -lcublas -lcublasLt -lcudart -lcurand -lnccl -lmpi decoder decoding) + +export( + EXPORT + transformer-shared-targets + FILE + ${CMAKE_CURRENT_BINARY_DIR}/FasterTransformerTargets.cmake + NAMESPACE + TritonCore:: +) + +export(PACKAGE FasterTransformer) + +endif() # BUILD_GPT diff --git a/README.md b/README.md index df995575d..201b7ef0e 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ This repository provides a script and recipe to run the highly optimized transfo - [FasterTransformer](#fastertransformer) - [Table Of Contents](#table-of-contents) - [Model overview](#model-overview) + - [Architecture matrix](#architecture-matrix) - [Configuration support matrix](#configuration-support-matrix) - [Setup](#setup) - [Requirements](#requirements) @@ -15,53 +16,69 @@ This repository provides a script and recipe to run the highly optimized transfo - [Execute the encoder demos](#execute-the-encoder-demos) - [Execute the decoder/decoding demos](#execute-the-decoderdecoding-demos) - [Translation demos](#translation-demos) + - [GPT demo](#gpt-demo) - [Advanced](#advanced) - [Scripts and sample codes](#scripts-and-sample-codes) - [Command-line options](#command-line-options) - [Inference process](#inference-process) - [Performance](#performance) - [Encoder performance](#encoder-performance) - - [Encoder performances of FasterTransformer new features on cpp](#encoder-performances-of-fastertransformer-new-features-on-cpp) + - [Encoder performances of FasterTransformer new features](#encoder-performances-of-fastertransformer-new-features) - [Encoder performance on TensorFlow](#encoder-performance-on-tensorflow) - [Encoder performance on PyTorch](#encoder-performance-on-pytorch) - [Decoding and Decoder performance](#decoding-and-decoder-performance) - - [Decoder and Decoding performance on TensorFlow](#decoder-and-decoding-performance-on-tensorflow) - - [Decoder and decoding performance on PyTorch](#decoder-and-decoding-performance-on-pytorch) - - [TensorFlow performance on translation](#tensorflow-performance-on-translation) - - [PyTorch performance on translation](#pytorch-performance-on-translation) + - [Decoder and Decoding end-to-end translation 
performance on TensorFlow](#decoder-and-decoding-end-to-end-translation-performance-on-tensorflow) + - [Decoder and Decoding end-to-end translation performance on PyTorch](#decoder-and-decoding-end-to-end-translation-performance-on-pytorch) + - [GPT performance](#gpt-performance) - [Release notes](#release-notes) - [Changelog](#changelog) - [Known issues](#known-issues) + - [TODO](#todo) ## Model overview - In NLP, encoder and decoder are two important components, with the transformer layer becoming a popular architecture for both components. FasterTransformer implements a highly optimized transformer layer for both the encoder and decoder for inference. On Volta, Turing and Ampere GPUs, the computing power of Tensor Cores are used automatically when the precision of the data and weights are FP16. -In FasterTransformer 1.0, we implemented a highly optimized BERT transformer layer, which is used in the encoder. +FasterTransformer v1.0 provides a highly optimized BERT equivalent Transformer layer for inference, including C++ API, TensorFlow op and TensorRT plugin. The experiments show that FasterTransformer v1 can provide 1.3 ~ 2 times speedup on NVIDIA Tesla T4 and NVIDIA Tesla V100 for inference. + +In FasterTransformer v2.0, we have added a highly optimized decoder and decoding models based on OpenNMT-TF, an open-source library. Here, the decoder is the model that contains some transformer layers. On the other hand, decoding refers to the whole translating process, including the lookup embedding table, position encoding, a decoder and beam search. -In FasterTransformer 2.0, we have added a highly optimized decoder and decoding models based on OpenNMT-TF, an open-source library. Here, the decoder is the model that contains some transformer layers. On the other hand, decoding refers to the whole translating process, including the lookup embedding table, position encoding, a decoder and beam search. +In FasterTransformer v2.1, we add some important features. First one is the supporting on PyTorch. Recently, there are more and more PyTorch users. We hope the users of PyTorch can also use the FasterTransformer in their application and research. The second feature is the supporting of [Effective Transformer](https://github.com/bytedance/effective_transformer). This idea is proposed by ByteDance. We call this feature as Effective FasterTransformer It removes the useless padding of encoder input to reduce the computing cost. Third, in addition to decoding with beam search, we also provide the decoding with sampling module. Finally, we optimize many kernels of encoder, decoder and beam search to improve the speed of FasterTransformer. -In FasterTransformer 2.1, we add some important features. First one is the supporting on PyTorch. Recently, there are more and more PyTorch users. We hope the users of PyTorch can also use the FasterTransformer in their application and research. The second feature is the supporting of [Effective Transformer](https://github.com/bytedance/effective_transformer). This idea is proposed by ByteDance. We call this feature as Effective FasterTransformer It removes the useless padding of encoder input to reduce the computing cost. Third, in addition to decoding with beam search, we also provide the decoding with sampling module. Finally, we optimize many kernels of encoder, decoder and beam search to improve the speed of FasterTransformer. +In FasterTransformer v3.0, we implemented the INT8 quantization for encoder (also supporting Effective FasterTransformer). 
With INT8 quantization, we can take advantage of the powerful INT8 Tensor Cores in Turing GPUs to achieve better inference performance (INT8 quantization in FT 3.0 is only supported on devices with SM >= 7.5). We also provide quantization tools for TensorFlow.

-In FasterTransformer 3.0, we implemented the INT8 quantization for encoder (also supporting Effective FasterTransformer). With INT8 quantization, we can take advantage of the powerful INT8 tensor core in Turing GPU to achieve better inference performance (INT8 quantization in FT 3.0 is only supported on device with SM >= 7.5). We also provide quantization tools of tensorflow.

+In FasterTransformer v3.1, we provide the following new features and enhancements. First, we optimize the INT8 kernel of the encoder to achieve better performance: compared to FasterTransformer v3.0, INT8 quantization brings at most a 1.75x speedup. Second, we provide a PyTorch tool that lets users train an INT8 quantized model in PyTorch, and FasterTransformer now also supports INT8 inference through the PyTorch op, so PyTorch users can leverage INT8 inference as well. Third, we integrate the fused multi-head attention kernel of the TensorRT plugin into FasterTransformer to improve the speed of the encoder on Turing and newer GPUs; this optimization brings about a 10% ~ 20% speedup compared to the original implementation. Finally, we add support for the GPT-2 model, which is an important and popular decoder model.

-In FasterTransformer 3.1, we provide following new features and enhancements. First, we optimize the INT8 kernel of encoder to achieve better performance. Compare to FasterTransformer 3.0, the performance of INT8 quantization brings at most 1.75x speedup. Second, we provide a PyTorch tool to let user be able to train a INT8 quantized model on PyTorch. Besides, FasterTransformer also starts to support the INT8 inference with PyTorch op. So, the users of PyTorch can leverage the INT8 inference. Third, we integrate the fused multi-head attention kernel of TensorRT plugin into FasterTransformer to improve the speed of encoder on Turing and new GPUs. This optimization can bring about 10% ~ 20% speedup compare to original implementation. Finally, we add the supporting of GPT-2 model, which is an important and popular model for decoder.

+In FasterTransformer v4.0, we provide multi-node, multi-GPU inference for the GPT model. Compared to frameworks commonly used to train giant models, such as Megatron, FasterTransformer provides a 1.2x ~ 3x speedup. Besides, we integrate the INT8 fused multi-head attention kernel of the TensorRT plugin to further improve the INT8 performance of the FasterTransformer encoder, and we add an FP16 fused multi-head attention kernel for V100. Finally, we optimize the decoding module: compared to v3.1, v4.0 provides at most a 2x speedup.

The following graph demonstrates the model architecture.
Fig. 1 Encoder-Decoding model architecture.
+ FasterTransformer is built on top of CUDA, cuBLAS and cuBLASLt, providing the C++ API and TensorFlow/PyTorch OPs. Users can integrate them into TensorFlow, PyTorch, or other inference service codes that are built in native C++. We also provide some simple sample code to demonstrate how to use the encoder, decoder and to carry out decoding in C++, TensorFlow and PyTorch. -More details are in [`docs/encoder_guide.md`](docs/encoder_guide.md) and [`docs/decoder_guide.md`](docs/decoder_guide.md). Some common questions and the respective answers are put in [`docs/QAList.md`](docs/QAList.md) +More details are in [`docs/encoder_guide.md`](docs/encoder_guide.md), [`docs/decoder_guide.md`](docs/decoder_guide.md) and [`docs/gpt_guide.md`](docs/gpt_guide.md). Some common questions and the respective answers are put in [`docs/QAList.md`](docs/QAList.md) + +### Architecture matrix + +The following matrix shows the architecture differences between the model. + +| Architecure | Encoder | Encoder INT8
quantization | Decoder | Decoding with
beam search | Decoding with
sampling | GPT-2 | GPT-3 | +|-------------|---------|---------------------------------|---------|---------------------------------|------------------------------|-------|-------| +| v1 | Yes | No | No | No | No | No | No | +| v2 | Yes | No | Yes | Yes | No | No | No | +| v2.1 | Yes | No | Yes | Yes | Yes | No | No | +| v3.0 | Yes | Yes | Yes | Yes | Yes | No | No | +| v3.1 | Yes | Yes | Yes | Yes | Yes | Yes | No | +| v4.0 | Yes | Yes | Yes | Yes | Yes | Yes | Yes | ### Configuration support matrix -The following configurations are supported in the FasterTransformer encoder. +The following configurations are supported in the FasterTransformer encoder. - Batch size (B1): smaller or equal to 4096 -- Sequence length (S): smaller or equal to 1024. For INT8 data type, sequence length should be a multiple of 32. +- Sequence length (S): smaller or equal to 1024. - Head number (H) and size per head (N): - 16 heads * 64 per heads - 12 heads * 64 per heads @@ -90,8 +107,7 @@ The following section lists the requirements to use FasterTransformer. - CUDA 10.1 or newer version - Python 3 is recommended because some features are not supported in python 2 - Tensorflow 1.13 or 1.14 or 1.15 -- PyTorch >= 1.4.0 -- TensorRT 5 or newer version +- PyTorch >= 1.5.0 These components are readily available within the NGC TensorFlow/PyTorch Docker image below. @@ -118,25 +134,22 @@ The following section shows how to use FasterTransformer on the NGC container. You can choose the tensorflow version and python version you want. Here, we list some possible images: - - `nvcr.io/nvidia/tensorflow:19.06-py3` contains the TensorFlow 1.13 and python 3.5. - `nvcr.io/nvidia/tensorflow:19.07-py2` contains the TensorFlow 1.14 and python 2.7. - - `nvcr.io/nvidia/tensorflow:20.07-tf1-py3` contains the TensorFlow 1.15 and python 3.6. - - `nvcr.io/nvidia/tensorrt:20.03-py3` contains the TensorRT 7.0.0 and python 3.6. - - `nvcr.io/nvidia/pytorch:20.01-py3` contains the PyTorch 1.4.0 and python 3.6 + - `nvcr.io/nvidia/tensorflow:20.12-tf1-py3` contains the TensorFlow 1.15 and python 3.8. - `nvcr.io/nvidia/pytorch:20.03-py3` contains the PyTorch 1.5.0 and python 3.6 - `nvcr.io/nvidia/pytorch:20.07-py3` contains the PyTorch 1.6.0 and python 3.6 + - `nvcr.io/nvidia/pytorch:20.12-py3` contains the PyTorch 1.8.0 and python 3.8 - For example, running image `nvcr.io/nvidia/tensorflow:19.07-py2` by + To achieve best performance, we recommand to use the latest image. For example, running image `nvcr.io/nvidia/tensorflow:20.12-tf1-py3` by ```bash - nvidia-docker run -ti --rm nvcr.io/nvidia/tensorflow:19.07-py2 bash + nvidia-docker run -ti --rm nvcr.io/nvidia/tensorflow:20.12-tf1-py3 bash ``` 2. Clone the repository. ```bash - git clone https://github.com/NVIDIA/DeepLearningExamples - cd DeepLearningExamples/FasterTransformer/v3.1 + git clone https://github.com/NVIDIA/FasterTransformer.git mkdir -p build cd build ``` @@ -150,92 +163,36 @@ The following section shows how to use FasterTransformer on the NGC container. make ``` - Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). + Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4) or 80 (A100). 3.2 build with TensorFlow - * `nvcr.io/nvidia/tensorflow:19.06-py3` - - First, update the cmake to cmake 3.8 or later version, and then build the project by the following scripts. 
- - ```bash - cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python3.5/dist-packages/tensorflow .. - make - ``` - - Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). - - * `nvcr.io/nvidia/tensorflow:19.07-py2` - - First, link the `libtensorflow_framework.so`, and then build the project by the following scripts. - - ```bash - ln -s /usr/local/lib/python2.7/dist-packages/tensorflow/libtensorflow_framework.so.1 /usr/local/lib/python2.7/dist-packages/tensorflow/libtensorflow_framework.so - cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python2.7/dist-packages/tensorflow .. - make - ``` - - Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). - - * `nvcr.io/nvidia/tensorflow:20.07-tf1-py3` - - First, link the `libtensorflow_framework.so`, and then build the project by the following scripts. + Uses need to set the path of TensorFlow. For example, if we use `nvcr.io/nvidia/tensorflow:20.12-tf1-py3`, then ```bash - ln -s /usr/local/lib/python3.6/dist-packages/tensorflow_core/libtensorflow_framework.so.1 /usr/local/lib/python3.6/dist-packages/tensorflow_core/libtensorflow_framework.so - cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python3.6/dist-packages/tensorflow_core/ .. + cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TF=ON -DTF_PATH=/usr/local/lib/python3.8/dist-packages/tensorflow_core/ .. make ``` - Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). - - 3.3 build with TensorRT - - * `nvcr.io/nvidia/tensorrt:20.03-py3` - - ```bash - cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_TRT=ON -DTRT_PATH=/opt/tensorrt/ .. - make - ``` - - Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). - - 3.4 build with PyTorch - - * `nvcr.io/nvidia/pytorch:20.01-py3` - - ```bash - cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_THE=ON .. - make - ``` + Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4) or 80 (A100). - * `nvcr.io/nvidia/pytorch:20.03-py3` or later + 3.3 build with PyTorch ```bash - cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_THE=ON -DBUILD_THS=ON -DCXX_STD=14 .. + cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON .. make ``` - Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4). (You can ignore this variable.) + Note: `xx` is the compute capability of your GPU. For example, 60 (P40) or 61 (P4) or 70 (V100) or 75(T4) or 80 (A100). - `-DBUILD_THE=ON` is to build the regular PyTorch extension for eager mode. It is not compatible with TorchScript, but it may be compatible with more PyTorch versions. + This will build the TorchScript custom class. Please make sure that the `PyTorch >= 1.5.0`. - `-DBUILD_THS=ON` is to build the TorchScript custom class. If you want to use this custom class, please make sure that the `PyTorch >= 1.5.0`. - - ***You can choose one of them or all. No need to add all options.*** - - For `PyTorch == 1.4.0`, please use C++11, that is, `-DCXX_STD=11` or just ignore this variable. - - For `PyTorch >= 1.5.0`, please use C++14, that is, `-DCXX_STD=14`. - - Note: From `FasterTransformer 3.1`, TorchScript custom op (function type) is deprecated. + Note: From `FasterTransformer 3.1`, TorchScript custom op (function type) is deprecated. 
From `FasterTransformer 4.0`, Eager mode PyTorch extension is deprecated. ### Execute the encoder demos 1. Run FasterTransformer encoder on C++ - - ```bash ./bin/encoder_gemm ./bin/encoder_sample @@ -261,10 +218,10 @@ The following section shows how to use FasterTransformer on the NGC container. | feature | int8_mode == 1 | int8_mode == 2 | - |:-------:|:-------------:|:-------------:| - | quantize residual | No | Yes | - | int8 output gemm | No | Yes | - | per-channel quantiztion for weights | Yes | No | + |:-------:|:--------------:|:--------------:| + | quantize residual | No | Yes | + | int8 output gemm | No | Yes | + | per-channel quantiztion for weights | Yes | No | ```bash #For int8_mode == 1 @@ -311,6 +268,8 @@ The following section shows how to use FasterTransformer on the NGC container. --allow_gemm_test False ``` + If use sets `--test_time 1`, the program will show the performance of TensorFlow, FasterTransformer and FasterTransformer with removing padding. + 2.2 Run FasterTransformer encoder under FP16 on TensorFlow ```bash @@ -355,56 +314,6 @@ The following section shows how to use FasterTransformer on the NGC container. --allow_gemm_test False ``` - 2.4 Run Effective FasterTransformer under FP32 on TensorFlow - - ```bash - ./bin/encoder_gemm 32 32 12 64 0 0 - python tensorflow/encoder_sample.py \ - --batch_size 32 \ - --max_seq_len 32 \ - --head_number 12 \ - --size_per_head 64 \ - --num_layer 12 \ - --data_type fp32 \ - --test_time 1 \ - --remove_padding True \ - --avg_seq_len 16 \ - --allow_gemm_test False - ``` - - 2.5 Run Effective FasterTransformer under INT8 on TensorFlow - ```bash - #For int8_mode == 1 - ./bin/encoder_gemm 32 32 12 64 1 1 - python tensorflow/encoder_sample.py \ - --batch_size 32 \ - --max_seq_len 32 \ - --head_number 12 \ - --size_per_head 64 \ - --num_layer 12 \ - --data_type fp16 \ - --test_time 1 \ - --remove_padding True \ - --avg_seq_len 16 \ - --int8_mode 1 \ - --allow_gemm_test False - - #For int8_mode == 2 - ./bin/encoder_gemm 32 32 12 64 1 2 - python tensorflow/encoder_sample.py \ - --batch_size 32 \ - --max_seq_len 32 \ - --head_number 12 \ - --size_per_head 64 \ - --num_layer 12 \ - --data_type fp16 \ - --test_time 1 \ - --remove_padding True \ - --avg_seq_len 16 \ - --int8_mode 2 \ - --allow_gemm_test False - ``` - 3. Run FasterTransformer on PyTorch Please install HuggingFace's `transformers` first before run the demos by @@ -438,29 +347,6 @@ The following section shows how to use FasterTransformer on the NGC container. python pytorch/encoder_sample.py 32 12 32 12 64 --int8_mode 2 --time ``` - 3.4 Run Effective FasterTransformer under FP32 on PyTorch - - ```bash - ./bin/encoder_gemm 32 32 12 64 0 0 - python pytorch/encoder_sample.py 32 12 32 12 64 --time --remove_padding - ``` - -4. Run FasterTransformer on TensorRT - - 4.1 Run FasterTransformer under FP32 on TensorRT - - ```bash - ./bin/encoder_gemm 32 32 12 64 0 0 - ./bin/transformer_trt 32 12 32 12 64 fp32 - ``` - - 4.2 Run FasterTransformer under FP16 on TensorRT - - ```bash - ./bin/encoder_gemm 32 32 12 64 1 0 - ./bin/transformer_trt 32 12 32 12 64 fp16 - ``` - ### Execute the decoder/decoding demos 1. Run FasterTransformer decoding on C++ @@ -716,6 +602,54 @@ The following section shows how to use FasterTransformer on the NGC container. python pytorch/run_translation.py --batch_size 128 --beam_size 4 --model_type decoding_ext --data_type fp16 ``` +### GPT demo + +Here, we demonstrate how to run Fastertransformer on Megatron model with C++ and PyTorch api. 
More details are in [`docs/gpt_guide.md`](docs/gpt_guide.md). + +1. Prepare + +```bash +pip install -r ../requirement.txt +wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -P models +wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -P models +wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip +mkdir -p models/megatron-models/345m +unzip megatron_lm_345m_v0.0.zip -d models/megatron-models/345m +git clone https://github.com/NVIDIA/Megatron-LM.git +python ../sample/pytorch/utils/megatron_ckpt_convert.py -i ./models/megatron-models/345m/release/ -o ./models/megatron-models/c-model/345m/ -t_g 1 -i_g 1 +``` + +Note that there are different checkpoint version of Megatron. The version of the checkpoint above is 0. If users have trained a model by themselves, the default version of latest Megatron is 3. To convert the checkpoint with version 3, please add `-checkpoint_version 3`. + +2. Run GPT + + 2.1 Run on C++ + + Users can see the details of arguments in `sample/cpp/gpt_config.ini`. It controls the model path, model size, tensor parallelism size, and some hyper-parameters. And then run gpt by following script: + + ```bash + ./bin/gpt_sample + python ../sample/pytorch/utils/convert_gpt_token.py --vocab_file=./models/gpt2-vocab.json --bpe_file=./models/gpt2-merges.txt + ``` + + The following script run multi-gpus (Note that users need to modify the `gpt_config.ini`. For example, set `tensor_para_size` to 8.) + + ```bash + mpirun -n 8 ./bin/gpt_sample + python ../sample/pytorch/utils/convert_gpt_token.py --vocab_file=./models/gpt2-vocab.json --bpe_file=./models/gpt2-merges.txt + ``` + + 2.2 Run on Pytorch + + ```bash + # No parallelism (tensor_para_size=1, layer_para_size=1) + mpirun -n 1 --allow-run-as-root python ./pytorch/gpt_sample.py + + # TP (tensor_para_size=8, layer_para_size=1) + mpirun -n 8 --allow-run-as-root python ./pytorch/gpt_sample.py --tensor_para_size=8 --layer_para_size=1 --ckpt_path="/workspace/fastertransformer/models/megatron-models/c-model/345m/8-gpu" + ``` + + ## Advanced The following sections provide greater details. @@ -729,13 +663,13 @@ The following code lists the directory structure of FasterTransformer: |--/cuda: some CUDA kernels and multi-head attention implementation, both are compiled with cuda/cuBLAS/cuBLASLt. 
|--/tf_op: custom Tensorflow OP implementation |--/th_op: custom PyTorch OP implementation - |--/trt_plugin: TensorRT plugin implementation + |--/triton_backend: custom triton backend implementation + |--/trt_fused_multihead_attention: fused multihead attention kernels of TensorRT /sample: C++ and tensorflow transformer interface samples |--/cpp: C++ interface samples |--/pytorch: PyTorch OP samples |--/tensorflow: TensorFlow OP samples |--/tensorflow_bert: samples that show of how to integrate our Tensorflow OP into the open source BERT model for sentence (and sentence-pair) classification tasks (GLUE), the samples support both FP16 and FP32, see readme file within this folder more details - |--/tensorRT: both FP16 and FP32 tensorRT plugin samples /tools/gemm_test: loop over all GEMM algorithms to pick the best one /bert-quantization/ |--bert-tf-quantization: TensorFlow quantization tool and sample codes @@ -758,7 +692,7 @@ The `fastertransformer/` folder encapsulates all the source codes of FasterTrans * `open_decoder.h` - Contains the decoder transformer layer * `decoding_beamsearch.h` - Contains the progress of decoding with beam search * `decoding_sampling.h` - Contains the progress of decoding with beam search -* `gpt2.h` - Contains the progress of GPT-2 +* `gpt.h` - Contains the progress of GPT The `tools/` folder contains the tools to generate the GEMM configuration of FasterTransformer for different settings: * `tools/gemm_test/encoder_gemm.cc` - Encoder GEMM config @@ -768,16 +702,13 @@ The `sample/` folder contains useful sample codes for FasterTransformer: * `sample/cpp/encoder_sample.cc` - C encoder sample codes * `sample/cpp/decoding_beamsearch_sample.cc` - C decoding with beam search sample codes * `sample/cpp/decoding_sampling_sample.cc` - C decoding with sampling sample codes -* `sample/cpp/gpt2_sample.cc` - C GPT-2 codes +* `sample/cpp/gpt_sample.cc` - C GPT codes * `sample/tensorflow/encoder_sample.py` - TensorFlow encoder sample codes -* `sample/tensorflow/encoder_sample_int8.py` - TensorFlow encoder sample codes for INT8 * `sample/tensorflow/decoder_sample.py` - TensorFlow decoder sample codes * `sample/tensorflow/decoding_sample.py` - TensorFlow decoding sample codes * `sample/tensorflow/tensorflow_bert/` - TensorFlow using FasterTransformer in BERT sample codes -* `sample/tensorflow/encoder_decoder_sample.py` - TensorFlow `encoder_decoder` sample codes -* `sample/tensorflow/encoder_decoding_sample.py` - TensorFlow `encoder_decoding` sample codes * `sample/tensorflow/translate_sample.py` - TensorFlow translation sample codes -* `sample/tensorflow/gpt2_sample.py` - TensorFlow GPT-2 sample codes +* `sample/tensorflow/gpt_sample.py` - TensorFlow GPT sample codes * `sample/pytorch/encoder_sample.py` - PyTorch encoder sample codes * `sample/pytorch/decoder_sample.py` - PyTorch decoder sample codes * `sample/pytorch/decoding_sample.py` - PyTorch decoding sample codes @@ -791,11 +722,8 @@ To see the full list of available options and their descriptions, use the `-h` o ```bash python tensorflow/encoder_sample.py --help -python tensorflow/encoder_sample_int8.py --help python tensorflow/decoder_sample.py --help python tensorflow/decoding_sample.py --help -python tensorflow/encoder_decoder_sample.py --help -python tensorflow/encoder_decoding_sample.py --help python tensorflow/translate_sample.py --help ``` @@ -805,9 +733,10 @@ This subsection provides the details about how to use the encoder, the decoder a ## Performance -Hardware settings: +Hardware settings: + +* 8xA100-80GBs 
(with mclk 1593MHz, pclk 1410MHz) with AMD EPYC 7742 64-Core Processor * T4 (with mclk 5000MHz, pclk 1590MHz) with Intel(R) Xeon(R) CPU E5-2670 0 @ 2.60GHz -* V100 (with mclk 877MHz, pclk 1380MHz) with Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz (dgx-1 server) In order to run the following benchmark, we need to install the unix computing tool "bc" by @@ -817,19 +746,13 @@ apt-get install bc ### Encoder performance -We demonstrate the inference time of FasterTransformer in C++, TensorFlow and PyTorch, and compare to the performance of pure TensorFlow and PyTorch on T4 with FP16 and INT8. Besides, we also show the performance of Effective FasterTransformer on T4 with FP16. Note that the total sequence length of Effective FasterTransformer is not fixed, so we use the default gemm configuration to run the benchmark. - -For the benchmark of TensorFlow, we compare the performance of TensorFlow with XLA (TF), the performance of TensorFlow with FasterTransformer OP (FT-OP) and the performance of FasterTransformer on C++ (TF-CPP), and show the speedup of FT-OP and FT-CPP compare to the TensorFlow. +The FP16 results of TensorFlow were obtained by running the `sample/tensorflow/scripts/profile_encoder_performance.sh`. -For the benchmark of PyTorch, we compare the performance of PyTorch, and performance of TorchScript and the performance of PyTorch with FasterTransformer custom extension (CustomExt) and show the speedup of CustomExt compare to the PyTorch and TorchScript. Because CustomExt has no obvious overhead compare to the FasterTransformer on C++, we skip the comparison with the C++ implementation. +The INT8 results of TensorFlow were obtained by running the `sample/tensorflow/scripts/profile_encoder_performance_int8.sh`. -The FP16 results of C++ and TensorFlow were obtained by running the `sample/tensorflow/scripts/profile_encoder_performance.sh` and `sample/tensorflow/scripts/profile_effective_transformer_performance.sh`. +The FP16 results of PyTorch were obtained by running the `sample/pytorch/scripts/profile_encoder.sh`. -The INT8 results of C++ and TensorFlow were obtained by running the `sample/tensorflow/scripts/profile_encoder_performance_int8.sh` and `sample/tensorflow/scripts/profile_effective_transformer_performance_int8.sh`. - -The FP16 results of PyTorch were obtained by running the `sample/pytorch/scripts/profile_encoder.sh` and `sample/pytorch/scripts/profile_encoder_effective_transformer.sh`. - -The INT8 results of PyTorch were obtained by running the `sample/pytorch/scripts/profile_encoder_int8.sh` and `sample/pytorch/scripts/profile_encoder_effective_transformer_int8.sh`. +The INT8 results of PyTorch were obtained by running the `sample/pytorch/scripts/profile_encoder_int8.sh`. In the experiments of encoder, we updated the following parameters: @@ -839,11 +762,11 @@ In the experiments of encoder, we updated the following parameters: More benchmarks are put in [`docs/encoder_guide.md`](docs/encoder_guide.md#encoder-performance). -#### Encoder performances of FasterTransformer new features on cpp +#### Encoder performances of FasterTransformer new features The following figure compares the performances of different features of FasterTransformer and FasterTransformer under FP16 on T4. -For large batch size and sequence length, Effective FasterTransformer brings about 2x speedup and int8v2 brings about 1.75x speedup. Using Effective FasterTransformer and int8v2 at the same time can bring more than 3x speedup compared to FasterTransformer FP16. 
+For large batch size and sequence length, both EFF-FT and FT-INT8-v2 bring about a 2x speedup, and using Effective FasterTransformer and INT8-v2 together brings about a 3.5x speedup over FasterTransformer FP16 in the largest cases.
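As a point of reference, the combined EFF-FT + INT8-v2 configuration measured here corresponds to enabling both padding removal and `--int8_mode 2` in the encoder sample. The following is a sketch based on the encoder sample flags used elsewhere in this README; the batch size, sequence length, and GEMM arguments are illustrative:

```bash
# Generate the GEMM configuration for batch 32, seq len 32, 12 heads, size 64, FP16, int8_mode 2
./bin/encoder_gemm 32 32 12 64 1 2

# Effective FasterTransformer (padding removed) combined with INT8-v2 quantization
python tensorflow/encoder_sample.py \
    --batch_size 32 \
    --max_seq_len 32 \
    --head_number 12 \
    --size_per_head 64 \
    --num_layer 12 \
    --data_type fp16 \
    --test_time 1 \
    --remove_padding True \
    --avg_seq_len 16 \
    --int8_mode 2
```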
@@ -851,103 +774,86 @@ For large batch size and sequence length, Effective FasterTransformer brings abo The following figure compares the performances of different features of FasterTransformer and TensorFlow XLA under FP16 on T4. -For small batch size and sequence length, using FasterTransformer CPP can bring about 3x ~ 6.5x speedup. +For small batch size and sequence length, using FasterTransformer can bring about 3x speedup. -For large batch size and sequence length, using Effective FasterTransformer with INT8 quantization can bring about 4x speedup. +For large batch size and sequence length, using Effective FasterTransformer with INT8-v2 quantization can bring about 5x speedup. -
+
#### Encoder performance on PyTorch The following figure compares the performances of different features of FasterTransformer and PyTorch TorchScript under FP16 on T4. -For small batch size and sequence length, using FasterTransformer CustomExt can bring about 4x ~ 6x speedup. +For small batch size and sequence length, using FasterTransformer CustomExt can bring about 4x ~ 6x speedup. -For large batch size and sequence length, using Effective FasterTransformer with INT8 quantization can bring about 4x speedup. +For large batch size and sequence length, using Effective FasterTransformer with INT8-v2 quantization can bring about 5x speedup. -
+
### Decoding and Decoder performance -We demonstrate the inference time of FasterTransformer in C++, TensorFlow and PyTorch, and compare to the performance of pure TensorFlow and PyTorch on T4 with FP16. - -For the benchmark of TensorFlow, we compare the performance of TensorFlow (TF), the performance of FasterTransformer OP Decoder, FasterTransformer OP Decoding and the FasterTransformer CPP Decoding. - -We do not demonstrate the performance of TensorFlow with XLA since we did not find that using XLA has obvious speedup. +The results of TensorFlow were obtained by running the `profile_decoding_beamsearch_performance.sh` and `profile_decoding_sampling_performance.sh` -For the benchmark of PyTorch, we compare the performance of PyTorch, the performance of FasterTransformer OP Decoder and FasterTransformer OP Decoding. Due to the dynamic property, it is hard to trace/script the PyTorch decoder/decoding model, so we only test on plain PyTorch. - -The results of C++ and TensorFlow were obtained by running the `sample/tensorflow/scripts/profile_decoder_performance.sh` and `sample/tensorflow/scripts/profile_decoding_performance.sh`. - -The results of PyTorch were obtained by running the `../sample/pytorch/scripts/profile_decoder_decoding.sh`. +The results of PyTorch were obtained by running the `profile_decoder_decoding.sh`. In the experiments of decoding, we updated the following parameters: * head_num = 8 * size_per_head = 64 -* num_layers = 6 +* num_layers = 6 for both encoder and decoder * vocabulary_size = 30000 for TensorFlow sample codes, 31538 for PyTorch sample codes * memory_hidden_dim = 512 +* max sequenc elength = 128 More benchmarks are put in [`docs/decoder_guide.md`](docs/decoder_guide.md#decoding-performance). -#### Decoder and Decoding performance on TensorFlow - -
- -#### Decoder and decoding performance on PyTorch +#### Decoder and Decoding end-to-end translation performance on TensorFlow -
+The following figure shows the speedup of of FT-Decoder op and FT-Decoding op compared to TensorFlow under FP16 with T4. Here, we use the throughput of translating a test set to prevent the total tokens of each methods may be different. Compared to TensorFlow, FT-Decoder provides 1.5x ~ 3x speedup; while FT-Decoding provides 4x ~ 18x speedup. -#### TensorFlow performance on translation +
-We test with batch size 128, beam width 4 on V100. +#### Decoder and Decoding end-to-end translation performance on PyTorch -| Type | tokens per seconds | BLEU | -|:----:|:------------------:|:----:| -| TensorFlow, beam search, FP32 | 2137 | BLEU 26.29 | -| Decoder, beam search, FP32 | 6473 | BLEU 26.29 | -| Decoding, beam search, FP32 | 8513 | BLEU 26.31 | -| TensorFlow, sampling, FP32 | 4178 | BLEU 25.79 | -| Decoder, sampling, FP32 | 10781 | BLEU 25.79 | -| Decoding, sampling, FP32 | 16524 | BLEU 25.79 | -| TensorFlow, beam search, FP16 | 2949 | BLEU 26.31 | -| Decoder, beam search, FP16 | 8682 | BLEU 26.30 | -| Decoding, beam search, FP16 | 12746 | BLEU 26.33 | -| TensorFlow, sampling, FP16 | 6968 | BLEU 25.83 | -| Decoder, sampling, FP16 | 13773 | BLEU 25.80 | -| Decoding, sampling, FP16 | 26718 | BLEU 25.82 | +The following figure shows the speedup of of FT-Decoder op and FT-Decoding op compared to PyTorch under FP16 with T4. Here, we use the throughput of translating a test set to prevent the total tokens of each methods may be different. Compared to PyTorch, FT-Decoder provides 1.2x ~ 3x speedup; while FT-Decoding provides 3.8x ~ 13x speedup. -#### PyTorch performance on translation +
-batch size 128, beam width 4, max_seq_len 32, beam search algorithm on V100: +### GPT performance -| Type | tokens per seconds | BLEU | -|:----:|:------------------:|:----:| -| PyTorch, FP32 | 2462 | BLEU 24.1 | -| Decoder, FP32 | 3358 | BLEU 24.1 | -| Decoding, FP32 | 8959 | BLEU 24.1 | -| PyTorch, FP16 | 4019 | BLEU 24.1 | -| Decoder, FP16 | 4377 | BLEU 24.1 | -| Decoding, FP16 | 15048 | BLEU 24.1 | +The following figure compares the performances of Megatron and FasterTransformer under FP16 on A100. +In the experiments of decoding, we updated the following parameters: - +
## Release notes ### Changelog +April 2021 +- Support multi-gpus and multi-nodes inference for GPT model on C++ and PyTorch. +- Support single node, multi-gpus inference for GPT model on triton. +- Add the int8 fused multi-head attention kernel for bert. +- Add the FP16 fused multi-head attention kernel of V100 for bert. +- Optimize the kernel of decoder. +- Move to independent repo. +- **Release the FasterTransformer 4.0** + Dec 2020 +- Optimize the decoding by adding the finisehd mask to prevent useless computing. +- Support opennmt encoder. +- Remove the TensorRT plugin supporting. - **Release the FasterTransformer 3.1** Nov 2020 @@ -1020,7 +926,9 @@ July 2019 - Undefined symbol errors when import the extension - Please `import torch` first. If this has been done, it is due to the incompatible C++ ABI. You may need to check the PyTorch used during compilation and execution are the same, or you need to check how your PyTorch is compiled, or the version of your GCC, etc. -- batch_size should be smaller or equal to 1024 in Decoder. -- batch_size x beam_width should be smaller or equal to 1024 in Decoding. - Results of TensorFlow and OP would be different in decoding. This problem is caused by the accumulated log probability, and we do not avoid this problem. - If encounter some problem in the custom environment, try to use the gcc/g++ 4.8 to build the project of TensorFlow op, especially for TensorFlow 1.14. + +### TODO + +- Support the decoding sampling in PyTorch. diff --git a/bert-quantization/bert-pyt-quantization/README.md b/bert-quantization/bert-pyt-quantization/README.md index 0d686045d..cf5de5d85 100644 --- a/bert-quantization/bert-pyt-quantization/README.md +++ b/bert-quantization/bert-pyt-quantization/README.md @@ -21,7 +21,8 @@ export MODEL_DIR= git clone https://github.com/NVIDIA/TensorRT.git cd TensorRT git checkout release/7.2 -pip install tools/pytorch-quantization/. +cd tools/pytorch-quantization +pip install . ``` download SQuAD data: @@ -76,7 +77,7 @@ The results would be like: {"exact_match": 82.63, "f1": 89.53} ``` -Then do PTQ, `ft_mode` is unified with int8_mode in FasterTransformer, can be one of `1` or `2`. +Then do PTQ, `quant_mode` is unified with int8_mode in FasterTransformer, can be one of `ft1` or `ft2` or `ft3`. ```bash python run_squad.py \ @@ -101,14 +102,14 @@ python run_squad.py \ --fp16 \ --calibrator percentile \ --percentile 99.999 \ - --ft_mode 2 + --quant_mode ft2 ``` The results would be like: ```bash -{"exact_match": 81.93, "f1": 89.05} # for mode 1 -{"exact_match": 80.41, "f1": 88.15} # for mode 2 +{"exact_match": 81.92, "f1": 89.09} # for mode 1 +{"exact_match": 80.36, "f1": 88.09} # for mode 2 ``` @@ -117,7 +118,7 @@ The results would be like: If PTQ does not yield an acceptable result you can finetune with quantization to recover accuracy. We recommend to calibrate the pretrained model and finetune to avoid overfitting: -`ft_mode` is unified with int8_mode in FasterTransformer, can be one of `1` or `2`. +`quant_mode` is unified with int8_mode in FasterTransformer, can be one of `ft1` or `ft2` or `ft3`. 
```bash python run_squad.py \ @@ -138,7 +139,7 @@ python run_squad.py \ --fp16 \ --calibrator percentile \ --percentile 99.99 \ - --ft_mode 2 + --quant_mode ft2 python -m torch.distributed.launch --nproc_per_node=8 run_squad.py \ --init_checkpoint=$MODEL_DIR/bert-base-uncased-calib-mode-2/pytorch_model.bin \ @@ -161,14 +162,14 @@ python -m torch.distributed.launch --nproc_per_node=8 run_squad.py \ --output_dir=$MODEL_DIR/bert-base-uncased-QAT-mode-2 \ --max_steps=-1 \ --fp16 \ - --ft_mode 2 + --quant_mode ft2 ``` The results would be like: ```bash -{"exact_match": 81.91, "f1": 89.09} # for mode 1 -{"exact_match": 81.72, "f1": 89.09} # for mode 2 +{"exact_match": 82.17, "f1": 89.37} # for mode 1 +{"exact_match": 82.02, "f1": 89.30} # for mode 2 ``` The results of quantization may differ if different seeds are provided. @@ -197,7 +198,7 @@ python run_squad.py \ --fp16 \ --calibrator percentile \ --percentile 99.99 \ - --ft_mode 2 + --quant_mode ft2 python -m torch.distributed.launch --nproc_per_node=8 run_squad.py \ --init_checkpoint=$MODEL_DIR/bert-base-uncased-PTQ-mode-2-for-KD/pytorch_model.bin \ @@ -220,7 +221,7 @@ python -m torch.distributed.launch --nproc_per_node=8 run_squad.py \ --output_dir=$MODEL_DIR/bert-base-uncased-QAT-mode-2 \ --max_steps=-1 \ --fp16 \ - --ft_mode 2 \ + --quant_mode ft2 \ --distillation \ --teacher=$MODEL_DIR/bert-base-uncased-finetuned/pytorch_model.bin ``` @@ -228,5 +229,5 @@ python -m torch.distributed.launch --nproc_per_node=8 run_squad.py \ The results would be like: ```bash -{"exact_match": 83.96, "f1": 90.37} +{"exact_match": 83.67, "f1": 90.37} ``` diff --git a/bert-quantization/bert-pyt-quantization/quant_utils.py b/bert-quantization/bert-pyt-quantization/quant_utils.py index b848900c6..2ff7b55ed 100644 --- a/bert-quantization/bert-pyt-quantization/quant_utils.py +++ b/bert-quantization/bert-pyt-quantization/quant_utils.py @@ -58,31 +58,48 @@ def add_arguments(parser): help='percentile for PercentileCalibrator') group.add_argument('--fuse-qkv', action='store_true', help='use the same scale factor for qkv') - group.add_argument('--quant-asymmetric', action='store_true', - help='use an asymmetric integer range for quantization') - group.add_argument('--ft_mode', type=int, default=None, - help='int8 mode in FasterTransformer') + group.add_argument('--narrow_range', action='store_true', + help='use [-127, 127] range for activations rather than [-128, 127]') + group.add_argument('--quant_mode', type=str, default=None, + help='predefined quantization mode, choices: ["ft1", "ft2", "ft3", "trt"]') def set_args(args): - if args.ft_mode == 1: + if args.quant_mode == 'ft1': args.wprec = 8 args.aprec = 8 args.quant_per_tensor = False args.quant_disable = False args.quant_disable_keyword = ['final_input', 'layernorm_input', 'softmax_input', 'residual_input', 'local_input', 'aftergemm'] args.fuse_qkv = False - args.quant_asymmetric = False - elif args.ft_mode == 2: + args.narrow_range = False + elif args.quant_mode == 'ft2': + args.wprec = 8 + args.aprec = 8 + args.quant_per_tensor = True + args.quant_disable = False + args.quant_disable_keyword = ['final_input', 'layernorm_input', 'softmax_input', 'local_input'] + args.fuse_qkv = True + args.narrow_range = False + elif args.quant_mode == 'ft3': args.wprec = 8 args.aprec = 8 args.quant_per_tensor = True args.quant_disable = False args.quant_disable_keyword = ['final_input', 'layernorm_input', 'local_input'] args.fuse_qkv = True - args.quant_asymmetric = False + args.narrow_range = False + elif args.quant_mode == 
'trt': + # for demobert + args.wprec = 8 + args.aprec = 8 + args.quant_per_tensor = True + args.quant_disable = False + args.quant_disable_keyword = ['layernorm_input', 'softmax_input', 'aftergemm'] + args.fuse_qkv = True + args.narrow_range = False else: - raise ValueError("wrong argument value for 'ft_mode'") + raise ValueError("wrong argument value for 'quant_mode'") return args def set_default_quantizers(args): @@ -103,7 +120,7 @@ def set_default_quantizers(args): input_desc = QuantDescriptor(num_bits=args.aprec, calib_method=calib_method, - narrow_range=not args.quant_asymmetric, + narrow_range=args.narrow_range, ) weight_desc = QuantDescriptor(num_bits=args.wprec, axis=(None if args.quant_per_tensor else (0,)), diff --git a/bert-quantization/bert-pyt-quantization/run_squad.py b/bert-quantization/bert-pyt-quantization/run_squad.py index 318e736aa..7c72d8fe6 100755 --- a/bert-quantization/bert-pyt-quantization/run_squad.py +++ b/bert-quantization/bert-pyt-quantization/run_squad.py @@ -888,7 +888,7 @@ def main(): help="scale applied to distillation component of loss") args = parser.parse_args() - if args.ft_mode is not None: + if args.quant_mode is not None: args = quant_utils.set_args(args) args.fp16 = args.fp16 or args.amp print(args) diff --git a/bert-quantization/bert-tf-quantization/README.md b/bert-quantization/bert-tf-quantization/README.md index f5ad53749..3a74beae3 100644 --- a/bert-quantization/bert-tf-quantization/README.md +++ b/bert-quantization/bert-tf-quantization/README.md @@ -9,7 +9,7 @@ Modified the following files: * run_squad.py Hardware settings: - * 8 x Tesla V100-SXM2-16GB (with mclk 877MHz, pclk 1530MHz) + * 4 x Tesla V100-SXM2-16GB (with mclk 877MHz, pclk 1530MHz) ## Setup @@ -27,8 +27,7 @@ Download pretrained bert checkpoint. ```bash wget https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-12_H-768_A-12.zip -O uncased_L-12_H-768_A-12.zip -unzip uncased_L-12_H-768_A-12.zip -mv uncased_L-12_H-768_A-12 squad_model +unzip uncased_L-12_H-768_A-12.zip -d squad_model ``` Download SQuAD dataset @@ -44,7 +43,7 @@ wget -P squad_data https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.j ### Finetune a high precision model with: ```bash -mpirun -np 8 -H localhost:8 \ +mpirun -np 4 -H localhost:4 \ --allow-run-as-root -bind-to none -map-by slot \ -x NCCL_DEBUG=INFO \ -x LD_LIBRARY_PATH \ @@ -59,24 +58,24 @@ mpirun -np 8 -H localhost:8 \ --do_train=True \ --do_predict=True \ --if_quant=False \ - --train_batch_size=4 \ + --train_batch_size=8 \ --learning_rate=1e-5 \ --num_train_epochs=2.0 \ --save_checkpoints_steps 1000 \ --horovod -python ../sample/tensorflow_bert/squad_evaluate_v1_1.py squad_data/dev-v1.1.json squad_model/finetuned_base/predictions.json +python ../../sample/tensorflow/tensorflow_bert/squad_evaluate_v1_1.py squad_data/dev-v1.1.json squad_model/finetuned_base/predictions.json ``` The results would be like: ```bash -{"exact_match": 82.03, "f1": 89.55} +{"exact_match": 82.44, "f1": 89.57} ``` ### PTQ by calibrating: -`ft_mode` is unified with int8_mode in FasterTransformer, can be one of `1` or `2`. +`quant_mode` is unified with int8_mode in FasterTransformer, can be one of `ft1` or `ft2` or `ft3`. 
```bash python run_squad.py \ @@ -94,16 +93,16 @@ python run_squad.py \ --calib_batch=16 \ --calib_method=percentile \ --percentile=99.999 \ - --ft_mode=2 + --quant_mode=ft2 -python ../sample/tensorflow_bert/squad_evaluate-v1.1.py squad_data/dev-v1.1.json squad_model/PTQ_mode_2/predictions.json +python ../../sample/tensorflow/tensorflow_bert/squad_evaluate_v1_1.py squad_data/dev-v1.1.json squad_model/PTQ_mode_2/predictions.json ``` The results would be like: ```bash -{"exact_match": 81.68, "f1": 88.97} # for mode 1 -{"exact_match": 80.65, "f1": 88.31} # for mode 2 +{"exact_match": 81.67, "f1": 88.94} # for mode 1 +{"exact_match": 80.44, "f1": 88.30} # for mode 2 ``` @@ -112,7 +111,7 @@ The results would be like: If PTQ does not yield an acceptable result you can finetune with quantization to recover accuracy. We recommend to calibrate the pretrained model and finetune to avoid overfitting: -`ft_mode` is unified with int8_mode in FasterTransformer, can be one of `1` or `2`. +`quant_mode` is unified with int8_mode in FasterTransformer, can be one of `ft1` or `ft2` or `ft3`. ```bash python run_squad.py \ @@ -128,10 +127,10 @@ python run_squad.py \ --calib_batch=16 \ --calib_method=percentile \ --percentile=99.99 \ - --ft_mode=2 + --quant_mode=ft2 -mpirun -np 8 -H localhost:8 \ +mpirun -np 4 -H localhost:4 \ --allow-run-as-root -bind-to none -map-by slot \ -x NCCL_DEBUG=INFO \ -x LD_LIBRARY_PATH \ @@ -146,21 +145,21 @@ mpirun -np 8 -H localhost:8 \ --do_train=True \ --do_predict=True \ --if_quant=True \ - --train_batch_size=4 \ - --learning_rate=5e-6 \ + --train_batch_size=8 \ + --learning_rate=1e-5 \ --num_train_epochs=2.0 \ --save_checkpoints_steps 1000 \ - --ft_mode=2 + --quant_mode=ft2 \ --horovod -python ../sample/tensorflow_bert/squad_evaluate-v1.1.py squad_data/dev-v1.1.json squad_model/QAT_mode_2/predictions.json +python ../../sample/tensorflow/tensorflow_bert/squad_evaluate_v1_1.py squad_data/dev-v1.1.json squad_model/QAT_mode_2/predictions.json ``` The results would be like: ```bash -{"exact_match": 82.17, "f1": 89.34} # for mode 1 -{"exact_match": 81.99, "f1": 89.14} # for mode 2 +{"exact_match": 82.11, "f1": 89.39} # for mode 1 +{"exact_match": 81.74, "f1": 89.12} # for mode 2 ``` @@ -187,9 +186,9 @@ python run_squad.py \ --calib_batch=16 \ --calib_method=percentile \ --percentile=99.99 \ - --ft_mode=2 + --quant_mode=ft2 -mpirun -np 8 -H localhost:8 \ +mpirun -np 4 -H localhost:4 \ --allow-run-as-root -bind-to none -map-by slot \ -x NCCL_DEBUG=INFO \ -x LD_LIBRARY_PATH \ @@ -204,20 +203,21 @@ mpirun -np 8 -H localhost:8 \ --do_train=True \ --do_predict=True \ --if_quant=True \ - --train_batch_size=4 \ - --learning_rate=5e-6 \ + --train_batch_size=8 \ + --learning_rate=2e-5 \ --num_train_epochs=10.0 \ --save_checkpoints_steps 1000 \ - --ft_mode=2 - --horovod + --quant_mode=ft2 \ + --horovod \ --distillation=True \ --teacher=squad_model/finetuned_base/model.ckpt-5474 -python ../sample/tensorflow_bert/squad_evaluate-v1.1.py squad_data/dev-v1.1.json squad_model/QAT_KD_mode_2/predictions.json +python ../../sample/tensorflow/tensorflow_bert/squad_evaluate_v1_1.py squad_data/dev-v1.1.json squad_model/QAT_KD_mode_2/predictions.json ``` The results would be like: ```bash -{"exact_match": 83.56, "f1": 90.22} +{"exact_match": 84.06, "f1": 90.63} # for mode 1 +{"exact_match": 84.02, "f1": 90.56} # for mode 2 ``` diff --git a/bert-quantization/bert-tf-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/layers/tensor_quantizer.py 
b/bert-quantization/bert-tf-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/layers/tensor_quantizer.py index c178c9a9a..16492918a 100644 --- a/bert-quantization/bert-tf-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/layers/tensor_quantizer.py +++ b/bert-quantization/bert-tf-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/layers/tensor_quantizer.py @@ -40,6 +40,7 @@ class QuantDescriptor(): Default None. unsigned: A Boolean. If True, use unsigned. Default False. affine: A Boolean. If True, use affine quantization. Default False. + narrow_range: A Boolean. If True, use narrow range. Default True. disable_key_words: A list of string, indicates disabled quantizer. Raises: @@ -51,6 +52,7 @@ class QuantDescriptor(): - num_bits: read-only property. - unsigned: read-only property. - affine: read-only property. + - narrow_range: read-only property. - axis: read-only property. - disable_key_words: read-only property. """ @@ -64,6 +66,7 @@ def __init__(self, collection_name_prefix, num_bits=8, **kwargs): self._unsigned = kwargs.pop('unsigned', False) self._affine = kwargs.pop('affine', False) + self._narrow_range = kwargs.pop('narrow_range', True) self._axis = kwargs.pop('axis', None) self._collection_name_prefix = collection_name_prefix @@ -92,6 +95,10 @@ def unsigned(self): def affine(self): return self._affine + @property + def narrow_range(self): + return self._narrow_range + @property def axis(self): return self._axis @@ -140,6 +147,7 @@ def __init__(self, quant_desc: QuantDescriptor, scope_name="tensor_quantizer", i self._axis = quant_desc.axis self._unsigned = quant_desc.unsigned self._affine = quant_desc.affine + self._narrow_range = quant_desc.narrow_range self._collection_name_prefix = quant_desc.collection_name_prefix self._scope_name = scope_name self._disable_key_words = quant_desc._disable_key_words @@ -189,7 +197,7 @@ def __call__(self, inputs): if self._if_quant: outputs = fake_quantize(inputs, self._quant_min, self._quant_max, self._num_bits, self._axis, self._unsigned, - self._affine) + self._affine, self._narrow_range) else: outputs = inputs diff --git a/bert-quantization/bert-tf-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/ops/fake_quantize.py b/bert-quantization/bert-tf-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/ops/fake_quantize.py index 8fe17e023..ca744fa77 100644 --- a/bert-quantization/bert-tf-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/ops/fake_quantize.py +++ b/bert-quantization/bert-tf-quantization/ft-tensorflow-quantization/ft_tensorflow_quantization/python/ops/fake_quantize.py @@ -22,7 +22,7 @@ __all__ = ["fake_quantize"] -def fake_quantize(inputs, quant_min=None, quant_max=None, num_bits=8, axis=None, unsigned=False, affine=False): +def fake_quantize(inputs, quant_min=None, quant_max=None, num_bits=8, axis=None, unsigned=False, affine=False, narrow_range=True): """Universal tensor fake quantization function Args: @@ -35,6 +35,7 @@ def fake_quantize(inputs, quant_min=None, quant_max=None, num_bits=8, axis=None, Default None, indicates per tensor quantization. unsigned: A boolean. If True, use unsigned int8. Default False. affine: A boolean. If True, use affine quantization. Default False. + narrow_range: A boolean. If True, use narrow range. Default False. 
Returns: outputs: A Tensor with same type as inputs @@ -89,6 +90,12 @@ def fake_quantize_core(inputs, quant_min, quant_max): def _scaled_fake_quantize(inputs, quant_min, quant_max): # TODO(haow): Add check for negative values in inputs if unsigned bound = 2.0**(num_bits - 1 + int(unsigned)) - 1.0 + if unsigned: + min_bound = 0 + elif narrow_range: + min_bound = -bound + else: + min_bound = -bound - 1 quant_amax = tf.maximum(tf.abs(quant_min), tf.abs(quant_max)) scale = bound / quant_amax @@ -96,7 +103,7 @@ def _scaled_fake_quantize(inputs, quant_min, quant_max): # Value quantized with quant_amax=0 should all be 0, thus set scale to 1 scale = tf.compat.v2.where(tf.math.less_equal(quant_amax, epsilon), tf.constant(1.), scale) - quantized = tf.clip_by_value(tf.math.round(inputs * scale), -bound, bound) + quantized = tf.clip_by_value(tf.math.round(inputs * scale), min_bound, bound) outputs = quantized / scale return outputs diff --git a/bert-quantization/bert-tf-quantization/modeling.py b/bert-quantization/bert-tf-quantization/modeling.py index 25216a5c7..982bcaeb4 100644 --- a/bert-quantization/bert-tf-quantization/modeling.py +++ b/bert-quantization/bert-tf-quantization/modeling.py @@ -229,7 +229,9 @@ def __init__(self, do_return_all_layers=True, if_quant=if_quant) - self.sequence_output = self.all_encoder_layers[-1] + # self.sequence_output = tf.cast(self.all_encoder_layers[-1], tf.float32) + final_input_quantizer = FakeQuantizer(QuantDense.default_quant_desc_input, 'final_input_quantizer', if_quant) + self.sequence_output = tf.cast(final_input_quantizer(self.all_encoder_layers[-1]), tf.float32) # The "pooler" converts the encoded sequence tensor of shape # [batch_size, seq_length, hidden_size] to a tensor of shape # [batch_size, hidden_size]. This is necessary for segment-level diff --git a/bert-quantization/bert-tf-quantization/run_squad.py b/bert-quantization/bert-tf-quantization/run_squad.py index 1375d4445..10daddb14 100644 --- a/bert-quantization/bert-tf-quantization/run_squad.py +++ b/bert-quantization/bert-tf-quantization/run_squad.py @@ -168,21 +168,36 @@ flags.DEFINE_string("calib_method", "percentile", "calibration method [percentile, mse, max, entropy]") flags.DEFINE_float("percentile", 99.99, "percentile for percentile calibrator") flags.DEFINE_string("calibrator_file", "calibrators.pkl", "pickle file for calibrators") -flags.DEFINE_integer("ft_mode", 1, "int8 mode in FasterTransformer.") +flags.DEFINE_string("quant_mode", 'ft2', "predefined quantization mode, choices: ['ft1', 'ft2', 'ft3', 'trt']") flags.DEFINE_bool("distillation", False, "Whether or not to use the techer-student model for finetuning (Knowledge distillation)") flags.DEFINE_string("teacher", None, "teacher checkpoint file for distillation") flags.DEFINE_float("distillation_loss_scale", 10000., "scale applied to distillation component of loss") -if FLAGS.ft_mode == 1: +if FLAGS.quant_mode == 'ft1': KERNEL_AXIS = 1 - DISABLE_LIST = ['aftergemm', 'softmax_input', 'residual_input', 'local_input'] -elif FLAGS.ft_mode == 2: + ACTIVATION_NARROW_RANGE = False + DISABLE_LIST = ['aftergemm', 'softmax_input', 'residual_input', 'local_input', 'final_input'] + FUSE_QKV = False +elif FLAGS.quant_mode == 'ft2': KERNEL_AXIS = None - DISABLE_LIST = ['local_input'] + ACTIVATION_NARROW_RANGE = False + DISABLE_LIST = ['local_input', 'softmax_input', 'final_input'] + FUSE_QKV = True +elif FLAGS.quant_mode == 'ft3': + KERNEL_AXIS = None + ACTIVATION_NARROW_RANGE = False + DISABLE_LIST = ['local_input', 'final_input'] + FUSE_QKV = 
True +elif FLAGS.quant_mode == 'trt': + # for demobert + KERNEL_AXIS = None + ACTIVATION_NARROW_RANGE = False + DISABLE_LIST = ['aftergemm', 'softmax_input'] + FUSE_QKV = True else: - raise ValueError("wrong argument value for 'ft_mode'") -input_desc = QuantDescriptor('input', disable_key_words=DISABLE_LIST) + raise ValueError("wrong argument value for 'quant_mode'") +input_desc = QuantDescriptor('input', narrow_range=ACTIVATION_NARROW_RANGE, disable_key_words=DISABLE_LIST) kernel_desc = QuantDescriptor('kernel', axis=KERNEL_AXIS, disable_key_words=DISABLE_LIST) QuantDense.set_default_quant_desc_input(input_desc) QuantDense.set_default_quant_desc_kernel(kernel_desc) @@ -248,7 +263,7 @@ def end(self, session): for calibrator in self.calibrator_lists['kernel']: calibrator.compute_range('max') - if FLAGS.ft_mode == 2: + if FUSE_QKV: tf.compat.v1.logging.info("fusing QKV") for i in range(self.layer_num): prefix = f"bert/encoder/layer_{i}/attention/self" diff --git a/cmake/FasterTransformerConfig.cmake.in b/cmake/FasterTransformerConfig.cmake.in new file mode 100644 index 000000000..73295c587 --- /dev/null +++ b/cmake/FasterTransformerConfig.cmake.in @@ -0,0 +1,39 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +include(CMakeFindDependencyMacro) + +get_filename_component( + FASTERTRANSFORMER_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH +) + +list(APPEND CMAKE_MODULE_PATH ${FASTERTRANSFORMER_CMAKE_DIR}) + +if(NOT TARGET transformer-shared) + include("${FASTERTRANSFORMER_CMAKE_DIR}/FasterTransformerTargets.cmake") +endif() + +set(FASTERTRANSFORMER_LIBRARIES transformer-shared) diff --git a/cmake/Modules/FindNCCL.cmake b/cmake/Modules/FindNCCL.cmake new file mode 100644 index 000000000..ed4dbb986 --- /dev/null +++ b/cmake/Modules/FindNCCL.cmake @@ -0,0 +1,88 @@ +# Find the nccl libraries + # + # The following variables are optionally searched for defaults + # NCCL_ROOT: Base directory where all NCCL components are found + # NCCL_INCLUDE_DIR: Directory where NCCL header is found + # NCCL_LIB_DIR: Directory where NCCL library is found + # + # The following are set after configuration is done: + # NCCL_FOUND + # NCCL_INCLUDE_DIRS + # NCCL_LIBRARIES + # + # The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks + # install NCCL in the same location as the CUDA toolkit. + # See https://github.com/caffe2/caffe2/issues/1601 + +set(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH "Folder contains NVIDIA NCCL headers") +set(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH "Folder contains NVIDIA NCCL libraries") +set(NCCL_VERSION $ENV{NCCL_VERSION} CACHE STRING "Version of NCCL to build with") + +if ($ENV{NCCL_ROOT_DIR}) + message(WARNING "NCCL_ROOT_DIR is deprecated. Please set NCCL_ROOT instead.") +endif() +list(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}) +# Compatible layer for CMake <3.12. NCCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12. +list(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT}) + +find_path(NCCL_INCLUDE_DIRS + NAMES nccl.h + HINTS ${NCCL_INCLUDE_DIR}) + +if (USE_STATIC_NCCL) + MESSAGE(STATUS "USE_STATIC_NCCL is set. Linking with static NCCL library.") + SET(NCCL_LIBNAME "nccl_static") + if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +else() + SET(NCCL_LIBNAME "nccl") + if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +endif() + +find_library(NCCL_LIBRARIES + NAMES ${NCCL_LIBNAME} + HINTS ${NCCL_LIB_DIR}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES) + +if(NCCL_FOUND) # obtaining NCCL version and some sanity checks + set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h") + message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...") + set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES}) + list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS}) + include(CheckCXXSymbolExists) + check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED) + + if (NCCL_VERSION_DEFINED) + set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc") + file(WRITE ${file} " + #include <iostream> + #include <nccl.h> + int main() + { + std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.'
<< NCCL_PATCH << std::endl; + int x; + ncclGetVersion(&x); + return x == NCCL_VERSION_CODE; + } +") + try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file} + RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER + LINK_LIBRARIES ${NCCL_LIBRARIES}) + if (NOT NCCL_VERSION_MATCHED) + message(FATAL_ERROR "Found NCCL header version and library version do not match! \ +(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.") + endif() + message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}") + else() + # message(STATUS "NCCL version < 2.3.5-5") + endif () + set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) + + message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) +endif() diff --git a/docs/QAList.md b/docs/QAList.md index d32e80c08..64bceea2f 100644 --- a/docs/QAList.md +++ b/docs/QAList.md @@ -20,7 +20,7 @@ Another method is using multi-threading on the same TensorFlow graph and session ### 5. Which GPUs are supported in FasterTransformer? ### -We have verified the correctness and performance for GPUs with Compute Compatibility >= 7.0 such as V100 and T4. A100 also works, but still have some performance issue for small batch size. +We have verified the correctness and performance for GPUs with Compute Compatibility >= 7.0 such as V100, T4 and A100. ### 6. Do the users only be able to use the docker image we recommend? ### @@ -36,4 +36,4 @@ In C, users need to load the model by themselves and copy into GPU memory. In TensorFlow or PyTorch, users can load the checkpoint and put the weight tensor into FasterTransformer directly. Users can also load the model in other formats, like numpy, and put them into FasterTransformer directly like the weight tensor. - +The multi-gpu inference of GPT is special. FasterTransformer provides a tool to convert the checkpoint of OpenAI and Megatron, and then load the converted model by FasterTransformer directly. diff --git a/docs/decoder_guide.md b/docs/decoder_guide.md index db0edd822..77df0f56d 100644 --- a/docs/decoder_guide.md +++ b/docs/decoder_guide.md @@ -1,6 +1,6 @@ # FasterTransformer Decoder -The FasterTransformer Decoder contains the transformer decoder block, whole decoding progress, and GPT-2 model. +The FasterTransformer Decoder contains the transformer decoder block, whole decoding progress, and GPT model. 
## Table Of Contents @@ -10,30 +10,26 @@ The FasterTransformer Decoder contains the transformer decoder block, whole deco - [Decoder](#decoder) - [Decoding progress](#decoding-progress) - [Decoder and Decoding](#decoder-and-decoding) - - [GPT-2](#gpt-2) + - [GPT](#gpt) - [Setup](#setup) - [Requirements](#requirements) - [How to use](#how-to-use) - [Decoder and decoding process](#decoder-and-decoding-process) - [Translation process](#translation-process) - - [GPT-2 process](#gpt-2-process) - [Performance](#performance) - - [Decoder performance](#decoder-performance) - - [Decoder performance on T4 and TensorFlow](#decoder-performance-on-t4-and-tensorflow) - - [Decoder performance on V100 and TensorFlow](#decoder-performance-on-v100-and-tensorflow) - - [Decoding performance](#decoding-performance) - - [Decoding performance on T4 and TensorFlow](#decoding-performance-on-t4-and-tensorflow) - - [Decoding performance on V100 and TensorFlow](#decoding-performance-on-v100-and-tensorflow) - - [Decoder and decoding performance on T4 and PyTorch](#decoder-and-decoding-performance-on-t4-and-pytorch) - - [Decoder and decoding performance on V100 and PyTorch](#decoder-and-decoding-performance-on-v100-and-pytorch) - - [TensorFlow performance on translation](#tensorflow-performance-on-translation) - - [PyTorch performance on translation](#pytorch-performance-on-translation) - - [GPT-2 performance on V100 and TensorFlow](#gpt-2-performance-on-v100-and-tensorflow) + - [End to end translation performance on TensorFlow](#end-to-end-translation-performance-on-tensorflow) + - [Beamsearch performance on V100 and TensorFlow](#beamsearch-performance-on-v100-and-tensorflow) + - [Sampling performance on V100 and TensorFlow](#sampling-performance-on-v100-and-tensorflow) + - [Beamsearch performance on T4 and TensorFlow](#beamsearch-performance-on-t4-and-tensorflow) + - [Sampling performance on T4 and TensorFlow](#sampling-performance-on-t4-and-tensorflow) + - [End to end translation performance on PyTorch](#end-to-end-translation-performance-on-pytorch) + - [Beamsearch performance on V100 and PyTorch](#beamsearch-performance-on-v100-and-pytorch) + - [Beamsearch performance on T4 and PyTorch](#beamsearch-performance-on-t4-and-pytorch) ## Model architecture
-
Fig. 1 Flowchart of Decoding and GPT-2.
+
Fig. 1 Flowchart of Decoding and GPT.
### Decoder @@ -112,28 +108,9 @@ Note that K and P cannot be zero or non-zero value at the same time. FasterTrans Although the decoding process of most methods is similar, we find that there are lots of different kinds to compute the probability and implement the beam search. Therefore, if your chosen beam search algorithm is different from our implementation and it is hard for you to modify the beam search kernel, TensorFlow decoding with FasterTransformer Decoder is the recommended choice. However, the performance of the TensorFlow decoding with the FasterTransformer Decoder is worse than the performance of the FasterTransformer Decoding, especially for small batch sizes. -### GPT-2 +### GPT -The GPT-2 model is based on This project is based on [OpenAI gpt-2 project](https://github.com/openai/gpt-2). GPT-2 is a special case of Decoding, it does not require the cross-attention block and the results from encoder. Users can put some started words into GPT-2, and GPT-2 will use these words to generate the next word. By this method, GPT-2 can translate the sentence, reply the questions, and do many different applications. Fig. 1 shows the difference between GPT-2 standard decoding model. - -* Arguments: - 1. batch size (B) - 2. Top k value (K) - 3. Top p value (P) - 4. maximum sequence length (S) - 5. Head number (H) - 6. Size per head (N) - 7. Number of decoder layers - 8. Start id of the vocabulary - 9. Start ids, which are converted from the input sentences. If “start ids” is empty list, then FasterTransformer will use the “start id” as the first token. - 10. End id of the vocabulary - 11. temperature: -* Inputs: - 1. The table for embedding lookup. The shape is \[ V, H x N \]. - 2. The weights of all parameters. - 3. Position encoding table. The shape is \[ S, H x N \]. -* Outputs: - 1. The output ids. The shape is \[ B \]. +The GPT model is based on [OpenAI gpt-2 project](https://github.com/openai/gpt-2). GPT is a special case of Decoding, it does not require the cross-attention block and the results from encoder. Users can put some started words into GPT, and GPT will use these words to generate the next word. By this method, GPT can translate the sentence, reply the questions, and do many different applications. Fig. 1 shows the difference between GPT standard decoding model. More details are put in [`docs/gpt_guide.md`](gpt_guide.md). ## Setup @@ -539,48 +516,6 @@ For those unable to use the NGC container, to set up the required environment or Note that the results of FasterTransformer may be different, especially when the batch size is larger. - 2.5 Run FasterTransformer encoder and decoder/decoding on TensorFlow at the same time - - In this subsection, we demonstrate how to use the FasterTransformer encoder and decoder/decoding at the same time. - - ```bash - ./bin/encoder_gemm 32 32 8 64 0 0 - ./bin/decoding_gemm 32 4 8 64 30000 32 512 0 - python tensorflow/encoder_decoder_sample.py \ - --batch_size 32 \ - --beam_width 4 \ - --encoder_head_number 8 \ - --encoder_size_per_head 64 \ - --decoder_head_number 8 \ - --decoder_size_per_head 64 \ - --vocab_size 30000 \ - --max_seq_len 32 \ - --encoder_num_layer 6 \ - --decoder_num_layer 6 \ - --data_type fp32 - ``` - - The `encoder_decoder_sample.py` files show the results of "TensorFlow encoder + FasterTransformer decoder" and the results of "FasterTransformer encoder + FasterTransformer decoder. The usage is similar to `decoder_sample.py`. 
- - ```bash - ./bin/encoder_gemm 32 32 8 64 0 0 - ./bin/decoding_gemm 32 4 8 64 30000 32 512 0 - python tensorflow/encoder_decoding_sample.py \ - --batch_size 32 \ - --beam_width 4 \ - --encoder_head_number 8 \ - --encoder_size_per_head 64 \ - --decoder_head_number 8 \ - --decoder_size_per_head 64 \ - --vocab_size 30000 \ - --max_seq_len 32 \ - --encoder_num_layer 6 \ - --decoder_num_layer 6 \ - --data_type fp32 - ``` - - For convenience, we only show how to use the FasterTransformer encoder and decoding with beam search in the `encoder_decoding_sample.py`. The usage is like `decoding_sample.py`. - 3. Run FasterTransformer decoder/decoding on PyTorch Please install OpenNMT-py first before running the demos by @@ -602,7 +537,7 @@ For those unable to use the NGC container, to set up the required environment or python pytorch/decoder_sample.py <--fp16> <--time> python pytorch/decoder_sample.py 8 6 32 8 64 --fp16 --time ``` - Remove `--fp16` for fp32 mode. `--ths` will use the TorchScript custom class. + Remove `--fp16` for fp32 mode. The outputs should be like to the following: @@ -621,7 +556,7 @@ For those unable to use the NGC container, to set up the required environment or python pytorch/decoding_sample.py <--fp16> <--time> python pytorch/decoding_sample.py 8 6 32 8 64 4 31538 --fp16 --time ``` - Remove `--fp16` for fp32 mode. `--ths` will use the TorchScript custom class. + Remove `--fp16` for fp32 mode. The outputs should be like to the following: @@ -729,11 +664,11 @@ For those unable to use the NGC container, to set up the required environment or ```bash python pytorch/run_translation.py --batch_size --beam_size --model_type --data_type --output_file ``` - you can also use `--module_path` to set the FasterTransformer module `.so` file path, and use `--input_file` to set the input file to be translated. + you can also use `--input_file` to set the input file to be translated. the `` can be: - - `ori`: original OpenNMT model - - `decoder_ext`: replace the decoder in OpenNMT model with our FasterTransformer decoder + + - `decoding_ext`: using our FasterTransformer decoding module - `torch_decoding`: PyTorch version decoding with the method FasterTransformer decoding uses - `torch_decoding_with_decoder_ext`: PyTorch version decoding with the method FasterTransformer decoding uses but replace the decoder with the FasterTransformer decoder @@ -766,168 +701,6 @@ For those unable to use the NGC container, to set up the required environment or cat debpe_output.txt | sacrebleu debpe_ref.txt ``` -### GPT-2 process - -1. Install required tools - -```bash -pip install -r tensorflow/utils/gpt2_requirement.txt -``` - -2. Download the gpt-2 model by following scripts. - -```bash -python tensorflow/utils/download_gpt2_model.py -e.g. python tensorflow/utils/download_gpt2_model.py 124M -``` - -In the repo of OpenAI, they provide many models, including `124M`, `355M`, `774M` and `1558M` - -Next, using the following script to convert the model to csv files, which are used in C++ sample codes. - -```bash -mkdir tmp -python tensorflow/utils/convert_tf_ckpt_to_csv.py -o tmp -i models/124M/model.ckpt -``` - -3. Run GPT-2 - - 3.1 Generate the `decoding_gemm_config.in` file. - - - `./bin/decoding_gemm` can generate the best GEMM configuration. The arguments of `decoding_gemm` are: - - ```bash - ./bin/decoding_gemm - ``` - - Assume the settings of decoding are as follows (the hyper-parameters of GPT-2 124M model). 
- - - `batch_size`=4 - - `top k`=4 - - `top p`=0.6 - - `head_number`=12 - - `size_per_head`=64 - - `vocabulary_size`=50257 - - `sequence_length`=32 - - `data_type`=FP32 - - Then the following scripts can generate the best GEMM configuration under such settings, and record the configuration into the `decoding_gemm_config.in` file. - - ```bash - ./bin/decoding_gemm 4 1 12 64 50257 32 768 0 - ``` - - Note that the beam width of sampling algorithm is always 1, so we need to generate the new configuration. - - 3.2 Run GPT-2 under FP32 on C++ - - Assume the settings are the same as above, and the decoder contains 12 transformer layers. - - The arguments of `gpt2_sample` is - - ```bash - ./bin/decoding_sample - ``` - - where `candidate_num` is the k value of top k, while `probability_threshold` is the p value of top p. - - ```bash - ./bin/gpt2_sample 4 4 0.6 12 64 50257 32 12 0 - ``` - - The outputs should be like to the following: - - ```bash - Device Tesla V100-PCIE-32GB - [INFO] batch_size 4 head_num 12 size_per_head 64 seq_len 32 decoder_layers 12 vocab_size 50257 FT-CPP-gpt2-time 52.67 ms - ``` - - 3.3 Run GPT-2 under FP16 on C++ - - So far, we use the FP32 to run the FasterTransformer. If we use the volta or newer NVIDIA GPU, we can use tensor core to accelerate when we use the FP16. - - To use the FP16, we only need to set the `` flag to 1 like following: - - ```bash - ./bin/decoding_gemm 4 1 12 64 50257 32 768 1 - ./bin/gpt2_sample 4 4 0.6 12 64 50257 32 12 1 - ``` - - Note that the configuration of FP32 and FP16 are different, so we need to generate the configuration again. - - The outputs should be like to the following: - - ```bash - Device Tesla V100-PCIE-32GB - [INFO] batch_size 4 head_num 12 size_per_head 64 seq_len 32 decoder_layers 12 vocab_size 50257 FT-CPP-gpt2-time 40.65 ms - ``` - - 3.4 Run FasterTransformer GPT-2 under FP32 on TensorFlow - - ```bash - ./bin/decoding_gemm 4 1 12 64 50257 32 768 0 - python tensorflow/gpt2_sample.py --batch_size=4 \ - --length=32 \ - --top_k=4 \ - --top_p=0.6 \ - --data_type=fp32 - ``` - - The outputs should be like to the following: - - ```bash - [INFO] FT op time: 57.9542 - ======================================== SAMPLE 1 ======================================== - The first of three films from the acclaimed director, who is also a producer on the upcoming Star Wars: The Force Awakens, is set to be released on December 8 - ======================================== SAMPLE 2 ======================================== - - The first time I saw the new "The Last of Us" trailer, I was blown away. I had never seen anything like it before. It was so - ======================================== SAMPLE 3 ======================================== - - A new study from the University of California, Berkeley, and the University of California, Santa Barbara found that the number of people who have been diagnosed with schizophrenia in - ======================================== SAMPLE 4 ======================================== - - The U.S. government has been accused of using a secret surveillance program to spy on American citizens. - - The Justice Department has said that the program was - ``` - - It generates four different sentences due to different random seeds. 
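Since the removed example above drives `gpt2_sample` with `candidate_num` (the top-k value) and `probability_threshold` (the top-p value), the following is a rough NumPy sketch of what sampling from the top-k tokens or from the top-p probability mass means; it applies one criterion at a time and is not the CUDA sampling kernel.

```python
# Illustrative top-k / top-p (nucleus) sampling over a logits vector.
import numpy as np

def sample_top_k_top_p(logits, top_k=0, top_p=0.0, seed=None):
    rng = np.random.default_rng(seed)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(-probs)                      # token ids, most probable first
    if top_k > 0:                                   # keep the k most probable tokens
        keep = order[:top_k]
    else:                                           # keep the smallest set reaching top_p mass
        cutoff = np.searchsorted(np.cumsum(probs[order]), top_p) + 1
        keep = order[:cutoff]
    kept = probs[keep] / probs[keep].sum()          # renormalize and sample
    return int(rng.choice(keep, p=kept))

logits = np.random.randn(50257)                     # GPT-2 vocabulary size
print(sample_top_k_top_p(logits, top_k=4))
print(sample_top_k_top_p(logits, top_p=0.6))
```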
- - 3.5 Run FasterTransformer GPT-2 under FP16 on TensorFlow - - ```bash - ./bin/decoding_gemm 4 1 12 64 50257 32 768 1 - python tensorflow/gpt2_sample.py --batch_size=4 \ - --length=32 \ - --top_k=4 \ - --top_p=0.6 \ - --data_type=fp16 - ``` - - The outputs should be like to the following: - - ```bash - [INFO] FT op time: 47.6268 - ======================================== SAMPLE 1 ======================================== - - The U.S. Department of Justice has filed a lawsuit against the company that owns the video game company, Electronic Arts, alleging that the company violated antitrust laws - ======================================== SAMPLE 2 ======================================== - The U.S. government has been trying to find ways to prevent the spread of Zika virus in the Americas, but it has been unable to do so. - - ======================================== SAMPLE 3 ======================================== - - The U.S. Department of Justice has filed a lawsuit against the company that owns the video game company, Electronic Arts, alleging that the company violated antitrust laws - ======================================== SAMPLE 4 ======================================== - - The first of two episodes of the new season of "The Walking Dead" is here, and it's a great one. - - The first episode of the - ``` - - User can check the arguments in `sample/tensorflow/gpt2_sample.py` file. - ## Performance Hardware settings: @@ -941,532 +714,219 @@ To run the following benchmark, we need to install the unix computing tool "bc" apt-get install bc ``` -### Decoder performance +To understand the speedup on real application, we use real end to end model and task in this benchmark on both TensorFlow and PyTorch. It is hard to compare the performance of v3.1 and v4.0 this the benchmark directly. But by our testing, compared to v3.1, v4.0 brings at most 50% speedup, especially for large batch size. -We demonstrate the inference time of FasterTransformer in C++, TensorFlow, and compare to the performance of pure TensorFlow on T4 and V100. The performance of PyTorch are put in the "Decoding performance" subsection. +### End to end translation performance on TensorFlow -In this benchmark, we compare the performance of TensorFlow decoding with beam search method (TF), and the performance of replacing the decoder of TensorFlow by FasterTransformer (FT-OP), and show the speedup of FT-OP compare to TF. +We demonstrate the throughput of TensorFlow (`TF`), `FT Decoder` and `FT Decoding` for end to end translation. Here, TensorFlow means that the program fully runs on TensorFlow. FT Decoder means that we replace the decoder transformer layer by FasterTransformer. FT Decoding means that we replace the whole procedure of decoder by FasterTransformer. Besides, we also replace the encoder transformer layer by FasterTransformer Encoder in FT Decoding. -We do not demonstrate the performance of TensorFlow with XLA since we did not find that using XLA has obvious speedup. +We do not demonstrate the performance of TensorFlow with XLA since we did not find that using XLA has obvious speedup. We also skip the BLEU score because the score of TensorFlow, FT Decoder and FT Decoding are close. -Our results of C++ and TensorFlow were obtained by running the `sample/tensorflow/scripts/profile_decoder_performance.sh` +Althought the bleu scores of all methods are close, the results may be little different and the number of generated tokens may be also different. 
So, we use throughput but not latency to show the peformance in this benchmark. -In the experiments of decoding, we updated the following parameters: +The benchmark of beamsearch were obtained by running the `sample/tensorflow/scripts/profile_decoding_beamsearch_performance.sh`; while The benchmark of sampling were obtained by running the `sample/tensorflow/scripts/profile_decoding_sampling_performance.sh`.. -* head_num = 8 -* size_per_head = 64 -* num_layers = 6 -* vocabulary_size = 30000 for TensorFlow sample codes, 31538 for PyTorch sample codes -* memory_hidden_dim = 512 +In this benchmark, we updated the following parameters: -#### Decoder performance on T4 and TensorFlow - -* Performance on FP32 - -| | TF (ms) | FT-OP (ms) | FT-OP Speedup | -|:---------------------------------:|:-------:|:----------:|:-------------:| -| <1, 1, 32> | 509.16 | 107.98 | 4.71 | -| <1, 1, 64> | 951.49 | 223.69 | 4.25 | -| <1, 1, 128> | 1943.97 | 425.28 | 4.57 | -| <1, 4, 32> | 497.88 | 126.70 | 3.92 | -| <1, 4, 64> | 1050.92 | 243.64 | 4.31 | -| <1, 4, 128> | 2064.92 | 508.16 | 4.06 | -| <8, 1, 32> | 510.90 | 125.96 | 4.05 | -| <8, 1, 64> | 995.81 | 244.18 | 4.07 | -| <8, 1, 128> | 2041.21 | 479.02 | 4.26 | -| <8, 4, 32> | 539.70 | 129.21 | 4.17 | -| <8, 4, 64> | 1100.77 | 267.75 | 4.11 | -| <8, 4, 128> | 2100.58 | 558.91 | 3.75 | -| <32, 1, 32> | 575.80 | 123.16 | 4.67 | -| <32, 1, 64> | 1070.51 | 251.52 | 4.25 | -| <32, 1, 128> | 2172.67 | 554.32 | 3.91 | -| <32, 4, 32> | 673.70 | 204.51 | 3.29 | -| <32, 4, 64> | 1335.84 | 492.47 | 2.71 | -| <32, 4, 128> | 3136.18 | 1331.35 | 2.35 | -| <64, 1, 32> | 582.22 | 142.49 | 4.08 | -| <64, 1, 64> | 1243.74 | 312.54 | 3.97 | -| <64, 1, 128> | 2420.20 | 791.30 | 3.05 | -| <64, 4, 32> | 850.54 | 350.63 | 2.42 | -| <64, 4, 64> | 1833.49 | 874.46 | 2.09 | -| <64, 4, 128> | 4586.01 | 2450.19 | 1.87 | -| <128, 1, 32> | 656.85 | 208.91 | 3.14 | -| <128, 1, 64> | 1461.70 | 499.76 | 2.92 | -| <128, 1, 128> | 3209.60 | 1361.95 | 2.35 | -| <128, 4, 32> | 1260.55 | 656.29 | 1.92 | -| <128, 4, 64> | 2875.73 | 1663.91 | 1.72 | -| <128, 4, 128> | 8018.63 | 4718.32 | 1.69 | - -* Performance on FP16 +* head_num = 8 for both encoder and decoder +* size_per_head = 64 for both encoder and decoder +* num_layers = 6 for both encoder and decoder +* vocabulary_size = 32001 +* max_seq_len = 128 -| | TF (ms) | FT-OP (ms) | FT-OP Speedup | -|:---------------------------------:|:-------:|:----------:|:-------------:| -| <1, 1, 32> | 400.02 | 121.19 | 3.30 | -| <1, 1, 64> | 823.41 | 233.93 | 3.51 | -| <1, 1, 128> | 1616.38 | 422.73 | 3.82 | -| <1, 4, 32> | 476.33 | 128.45 | 3.70 | -| <1, 4, 64> | 868.67 | 261.18 | 3.32 | -| <1, 4, 128> | 1857.95 | 464.51 | 3.99 | -| <8, 1, 32> | 452.70 | 119.73 | 3.78 | -| <8, 1, 64> | 906.15 | 222.74 | 4.06 | -| <8, 1, 128> | 1789.19 | 428.80 | 4.17 | -| <8, 4, 32> | 484.09 | 127.14 | 3.80 | -| <8, 4, 64> | 973.28 | 252.81 | 3.84 | -| <8, 4, 128> | 1907.93 | 527.98 | 3.61 | -| <32, 1, 32> | 476.66 | 124.72 | 3.82 | -| <32, 1, 64> | 933.16 | 240.70 | 3.87 | -| <32, 1, 128> | 1953.02 | 518.10 | 3.76 | -| <32, 4, 32> | 607.62 | 159.24 | 3.81 | -| <32, 4, 64> | 1280.93 | 352.51 | 3.63 | -| <32, 4, 128> | 2511.20 | 882.21 | 2.84 | -| <64, 1, 32> | 501.07 | 135.40 | 3.70 | -| <64, 1, 64> | 1020.40 | 281.34 | 3.62 | -| <64, 1, 128> | 2243.14 | 627.33 | 3.57 | -| <64, 4, 32> | 692.42 | 213.80 | 3.23 | -| <64, 4, 64> | 1517.27 | 542.75 | 2.79 | -| <64, 4, 128> | 3351.21 | 1554.97 | 2.15 | -| <128, 1, 32> | 593.39 | 163.73 | 3.62 | -| <128, 1, 64> | 1258.93 | 358.26 
| 3.51 | -| <128, 1, 128> | 2672.11 | 910.34 | 2.93 | -| <128, 4, 32> | 989.35 | 364.63 | 2.71 | -| <128, 4, 64> | 2216.00 | 962.84 | 2.30 | -| <128, 4, 128> | 5515.29 | 2913.02 | 1.89 | - -#### Decoder performance on V100 and TensorFlow +#### Beamsearch performance on V100 and TensorFlow * Performance on FP32 -| | TF (ms) | FT-OP (ms) | FT-OP Speedup | -|:---------------------------------:|:-------:|:----------:|:-------------:| -| <1, 1, 32> | 239.38 | 68.88 | 3.47 | -| <1, 1, 64> | 500.20 | 133.88 | 3.73 | -| <1, 1, 128> | 1021.87 | 261.55 | 3.90 | -| <1, 4, 32> | 242.70 | 74.93 | 3.23 | -| <1, 4, 64> | 509.43 | 145.60 | 3.49 | -| <1, 4, 128> | 893.73 | 296.82 | 3.01 | -| <8, 1, 32> | 241.06 | 68.85 | 3.50 | -| <8, 1, 64> | 494.16 | 145.88 | 3.38 | -| <8, 1, 128> | 1028.89 | 285.51 | 3.60 | -| <8, 4, 32> | 274.33 | 73.38 | 3.73 | -| <8, 4, 64> | 534.15 | 152.04 | 3.51 | -| <8, 4, 128> | 1090.66 | 321.77 | 3.38 | -| <32, 1, 32> | 249.78 | 71.74 | 3.48 | -| <32, 1, 64> | 527.18 | 150.84 | 3.49 | -| <32, 1, 128> | 1053.79 | 313.93 | 3.35 | -| <32, 4, 32> | 313.01 | 114.31 | 2.73 | -| <32, 4, 64> | 666.00 | 252.23 | 2.64 | -| <32, 4, 128> | 1376.10 | 593.28 | 2.31 | -| <64, 1, 32> | 288.73 | 86.66 | 3.33 | -| <64, 1, 64> | 553.34 | 177.65 | 3.11 | -| <64, 1, 128> | 1125.72 | 404.00 | 2.78 | -| <64, 4, 32> | 377.06 | 156.55 | 2.40 | -| <64, 4, 64> | 806.34 | 373.36 | 2.15 | -| <64, 4, 128> | 1913.47 | 974.17 | 1.96 | -| <128, 1, 32> | 319.11 | 110.49 | 2.88 | -| <128, 1, 64> | 666.36 | 243.54 | 2.73 | -| <128, 1, 128> | 1426.32 | 591.99 | 2.40 | -| <128, 4, 32> | 528.52 | 256.18 | 2.06 | -| <128, 4, 64> | 1215.82 | 620.55 | 1.95 | -| <128, 4, 128> | 3167.89 | 1733.38 | 1.82 | +| Batch Size | Beam Width | Precision | TF
<br/> Throughput (token/sec) | FT Decoder <br/> Throughput (token/sec) | FT Decoding <br/> Throughput (token/sec) | FT Decoder <br/> Speedup | FT Decoding <br/>
Speedup | +|:----------:|:----------:|:---------:|:-------------------------------:|:---------------------------------------:|:----------------------------------------:|:------------------------:|:-------------------------:| +| 1 | 1 | FP32 | 95 | 351 | 800 | 3.69 | 8.42 | +| 1 | 4 | FP32 | 110 | 341 | 763 | 3.10 | 6.93 | +| 1 | 32 | FP32 | 78 | 171 | 489 | 2.19 | 6.26 | +| 8 | 1 | FP32 | 484 | 1645 | 3694 | 3.39 | 7.63 | +| 8 | 4 | FP32 | 511 | 1435 | 3068 | 2.80 | 6.00 | +| 8 | 32 | FP32 | 231 | 427 | 916 | 1.84 | 3.96 | +| 128 | 1 | FP32 | 3157 | 8373 | 19803 | 2.65 | 6.27 | +| 128 | 4 | FP32 | 1773 | 3648 | 7848 | 2.05 | 4.42 | * Performance on FP16 -| | TF (ms) | FT-OP (ms) | FT-OP Speedup | -|:---------------------------------:|:-------:|:----------:|:-------------:| -| <1, 1, 32> | 209.70 | 70.37 | 2.97 | -| <1, 1, 64> | 423.41 | 141.34 | 2.99 | -| <1, 1, 128> | 775.10 | 287.64 | 2.69 | -| <1, 4, 32> | 215.05 | 81.37 | 2.64 | -| <1, 4, 64> | 449.72 | 146.28 | 3.07 | -| <1, 4, 128> | 910.03 | 291.50 | 3.12 | -| <8, 1, 32> | 226.01 | 68.60 | 3.29 | -| <8, 1, 64> | 437.30 | 153.32 | 2.85 | -| <8, 1, 128> | 915.96 | 286.39 | 3.19 | -| <8, 4, 32> | 248.44 | 75.81 | 3.27 | -| <8, 4, 64> | 463.51 | 154.71 | 2.99 | -| <8, 4, 128> | 960.88 | 293.46 | 3.27 | -| <32, 1, 32> | 233.93 | 69.80 | 3.35 | -| <32, 1, 64> | 482.73 | 147.54 | 3.27 | -| <32, 1, 128> | 922.02 | 294.40 | 3.13 | -| <32, 4, 32> | 279.34 | 88.29 | 3.16 | -| <32, 4, 64> | 582.95 | 193.42 | 3.01 | -| <32, 4, 128> | 1198.26 | 454.66 | 2.63 | -| <64, 1, 32> | 245.73 | 76.29 | 3.22 | -| <64, 1, 64> | 463.44 | 158.65 | 2.92 | -| <64, 1, 128> | 1007.24 | 332.69 | 3.02 | -| <64, 4, 32> | 331.58 | 114.84 | 2.88 | -| <64, 4, 64> | 699.38 | 262.69 | 2.66 | -| <64, 4, 128> | 1618.15 | 695.07 | 2.32 | -| <128, 1, 32> | 270.86 | 82.38 | 3.28 | -| <128, 1, 64> | 537.55 | 181.03 | 2.96 | -| <128, 1, 128> | 1183.11 | 442.73 | 2.67 | -| <128, 4, 32> | 433.38 | 165.23 | 2.62 | -| <128, 4, 64> | 928.87 | 410.96 | 2.26 | -| <128, 4, 128> | 2297.10 | 1175.40 | 1.95 | - - +#### Sampling performance on V100 and TensorFlow -### Decoding performance +* Performance on FP32 -We demonstrate the inference time of FasterTransformer in C++, TensorFlow and PyTorch, and compare to the performance of pure TensorFlow and PyTorch on T4 and V100. +| Batch Size | Topk/Topp | Precision | TF
<br/> Throughput (token/sec) | FT Decoder <br/> Throughput (token/sec) | FT Decoding <br/> Throughput (token/sec) | FT Decoder <br/> Speedup | FT Decoding <br/>
Speedup | +|:----------:|:---------:|:---------:|:-------------------------------:|:---------------------------------------:|:----------------------------------------:|:------------------------:|:-------------------------:| +| 1 | 4 | FP32 | 119 | 379 | 759 | 3.18 | 6.37 | +| 1 | 32 | FP32 | 103 | 368 | 739 | 3.57 | 7.17 | +| 1 | 0.75 | FP32 | 111 | 324 | 619 | 2.91 | 5.57 | +| 8 | 4 | FP32 | 491 | 1765 | 3475 | 3.59 | 7.07 | +| 8 | 32 | FP32 | 483 | 1637 | 3395 | 3.38 | 7.02 | +| 8 | 0.75 | FP32 | 460 | 1460 | 2645 | 3.17 | 5.75 | +| 128 | 4 | FP32 | 3387 | 9203 | 18165 | 2.71 | 5.36 | +| 128 | 32 | FP32 | 3380 | 8605 | 17541 | 2.54 | 5.18 | +| 128 | 0.75 | FP32 | 3194 | 6898 | 13925 | 2.15 | 4.35 | -For the benchmark of TensorFlow, we compare the performance of TensorFlow (TF), the performance of TensorFlow with FasterTransformer OP (FT-OP) and the performance of FasterTransformer on C++ (TF-CPP), and show the speedup of FT-OP and FT-CPP compare to the TensorFlow. +* Performance on FP16 -We do not demonstrate the performance of TensorFlow with XLA since we did not find that using XLA has obvious speedup. +| Batch Size | Topk/Topp | Precision | TF
<br/> Throughput (token/sec) | FT Decoder <br/> Throughput (token/sec) | FT Decoding <br/> Throughput (token/sec) | FT Decoder <br/> Speedup | FT Decoding <br/>
Speedup | +|:----------:|:---------:|:---------:|:-------------------------------:|:---------------------------------------:|:----------------------------------------:|:------------------------:|:-------------------------:| +| 1 | 4 | FP16 | 169 | 412 | 992 | 2.43 | 5.86 | +| 1 | 32 | FP16 | 167 | 376 | 970 | 2.25 | 5.80 | +| 1 | 0.75 | FP16 | 160 | 350 | 845 | 2.18 | 5.28 | +| 8 | 4 | FP16 | 739 | 1802 | 4620 | 2.43 | 6.25 | +| 8 | 32 | FP16 | 785 | 1754 | 4425 | 2.23 | 5.63 | +| 8 | 0.75 | FP16 | 715 | 1586 | 3634 | 2.21 | 5.08 | +| 128 | 4 | FP16 | 6217 | 11392 | 29409 | 1.83 | 4.73 | +| 128 | 32 | FP16 | 5937 | 10366 | 27995 | 1.74 | 4.71 | +| 128 | 0.75 | FP16 | 5129 | 8423 | 22094 | 1.64 | 4.30 | -For the benchmark of PyTorch, we compare the performance of PyTorch decoding with beam search (PyTorch), the performance of replacing the decoder of PyTorch by FasterTransformer (Decoder) and performance of FasterTransformer Decoding with beam search (Decoding), and show the speedup Decoder and Decoding compare to the PyTorch. Due to the dynamic property, it is hard to trace/script the PyTorch decoder/decoding model, so we only test on plain PyTorch. +#### Beamsearch performance on T4 and TensorFlow -The results of C++ and TensorFlow were obtained by running the `sample/tensorflow/scripts/profile_decoding_performance.sh`. +* Performance on FP32 -The results of PyTorch were obtained by running the `../sample/pytorch/scripts/profile_decoder_decoding.sh`. +| Batch Size | Beam Width | Precision | TF
<br/> Throughput (token/sec) | FT Decoder <br/> Throughput (token/sec) | FT Decoding <br/> Throughput (token/sec) | FT Decoder <br/> Speedup | FT Decoding <br/>
Speedup | +|:----------:|:----------:|:---------:|:-------------------------------:|:---------------------------------------:|:----------------------------------------:|:------------------------:|:-------------------------:| +| 1 | 1 | FP32 | 40 | 151 | 599 | 3.77 | 14.97 | +| 1 | 4 | FP32 | 34 | 137 | 563 | 4.02 | 16.55 | +| 1 | 32 | FP32 | 37 | 91 | 330 | 2.45 | 8.91 | +| 8 | 1 | FP32 | 193 | 807 | 2868 | 4.18 | 14.86 | +| 8 | 4 | FP32 | 198 | 644 | 2205 | 3.25 | 11.13 | +| 8 | 32 | FP32 | 94 | 209 | 366 | 2.22 | 3.89 | +| 128 | 1 | FP32 | 1234 | 3420 | 10313 | 2.77 | 8.35 | +| 128 | 4 | FP32 | 677 | 1260 | 3114 | 1.86 | 4.59 | -In the experiments of decoding, we updated the following parameters: +* Performance on FP16 -* head_num = 8 -* size_per_head = 64 -* num_layers = 6 -* vocabulary_size = 30000 for TensorFlow sample codes, 31538 for PyTorch sample codes -* memory_hidden_dim = 512 +| Batch Size | Beam Width | Precision | TF
<br/> Throughput (token/sec) | FT Decoder <br/> Throughput (token/sec) | FT Decoding <br/> Throughput (token/sec) | FT Decoder <br/> Speedup | FT Decoding <br/>
Speedup | +|:----------:|:----------:|:---------:|:-------------------------------:|:---------------------------------------:|:----------------------------------------:|:------------------------:|:-------------------------:| +| 1 | 1 | FP16 | 57 | 175 | 786 | 3.07 | 13.78 | +| 1 | 4 | FP16 | 55 | 169 | 766 | 3.07 | 13.92 | +| 1 | 32 | FP16 | 45 | 94 | 465 | 2.08 | 10.33 | +| 8 | 1 | FP16 | 226 | 683 | 4077 | 3.02 | 18.03 | +| 8 | 4 | FP16 | 217 | 631 | 3440 | 2.90 | 15.85 | +| 8 | 32 | FP16 | 151 | 259 | 619 | 1.71 | 4.09 | +| 128 | 1 | FP16 | 2060 | 4474 | 21675 | 2.17 | 10.52 | +| 128 | 4 | FP16 | 1250 | 1948 | 8796 | 1.55 | 7.03 | -#### Decoding performance on T4 and TensorFlow +#### Sampling performance on T4 and TensorFlow * Performance on FP32 -| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | -|:---------------------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| -| <1, 1, 32> | 453.10 | 31.84 | 14.23 | 28.00 | 16.18 | -| <1, 1, 64> | 882.08 | 61.51 | 14.34 | 57.33 | 15.38 | -| <1, 1, 128> | 1843.03 | 126.54 | 14.56 | 122.76 | 15.01 | -| <1, 4, 32> | 471.63 | 40.71 | 11.58 | 36.44 | 12.94 | -| <1, 4, 64> | 937.28 | 79.41 | 11.80 | 75.54 | 12.40 | -| <1, 4, 128> | 1926.79 | 166.26 | 11.58 | 160.75 | 11.98 | -| <8, 1, 32> | 482.82 | 43.48 | 11.10 | 39.85 | 12.11 | -| <8, 1, 64> | 921.57 | 87.21 | 10.56 | 83.39 | 11.05 | -| <8, 1, 128> | 1894.78 | 184.38 | 10.27 | 183.43 | 10.32 | -| <8, 4, 32> | 515.76 | 56.47 | 9.13 | 53.63 | 9.61 | -| <8, 4, 64> | 1014.02 | 119.61 | 8.47 | 120.85 | 8.39 | -| <8, 4, 128> | 2020.41 | 277.44 | 7.28 | 300.16 | 6.73 | -| <32, 1, 32> | 534.25 | 56.06 | 9.52 | 53.65 | 9.95 | -| <32, 1, 64> | 1034.65 | 121.27 | 8.53 | 121.52 | 8.51 | -| <32, 1, 128> | 1966.53 | 285.25 | 6.89 | 300.35 | 6.54 | -| <32, 4, 32> | 640.24 | 154.65 | 4.13 | 154.34 | 4.14 | -| <32, 4, 64> | 1354.65 | 350.07 | 3.86 | 367.81 | 3.68 | -| <32, 4, 128> | 3027.38 | 859.86 | 3.52 | 947.46 | 3.19 | -| <64, 1, 32> | 553.85 | 86.66 | 6.39 | 85.61 | 6.46 | -| <64, 1, 64> | 1114.51 | 192.89 | 5.77 | 198.66 | 5.61 | -| <64, 1, 128> | 2318.32 | 472.83 | 4.90 | 512.98 | 4.51 | -| <64, 4, 32> | 825.52 | 285.46 | 2.89 | 289.26 | 2.85 | -| <64, 4, 64> | 1752.80 | 653.98 | 2.68 | 685.59 | 2.55 | -| <64, 4, 128> | 4390.23 | 1631.13 | 2.69 | 1798.83 | 2.44 | -| <128, 1, 32> | 620.29 | 151.94 | 4.08 | 153.28 | 4.04 | -| <128, 1, 64> | 1366.14 | 342.94 | 3.98 | 358.99 | 3.80 | -| <128, 1, 128> | 2987.18 | 868.05 | 3.44 | 945.11 | 3.16 | -| <128, 4, 32> | 1170.25 | 542.47 | 2.15 | 552.39 | 2.11 | -| <128, 4, 64> | 2760.15 | 1257.03 | 2.19 | 1334.39 | 2.06 | -| <128, 4, 128> | 7774.93 | 3155.91 | 2.46 | 3445.01 | 2.25 | +| Batch Size | Topk/Topp | Precision | TF
<br/> Throughput (token/sec) | FT Decoder <br/> Throughput (token/sec) | FT Decoding <br/> Throughput (token/sec) | FT Decoder <br/> Speedup | FT Decoding <br/>
Speedup | +|:----------:|:---------:|:---------:|:-------------------------------:|:---------------------------------------:|:----------------------------------------:|:------------------------:|:-------------------------:| +| 1 | 4 | FP32 | 49 | 201 | 584 | 4.10 | 11.91 | +| 1 | 32 | FP32 | 50 | 175 | 568 | 3.50 | 11.36 | +| 1 | 0.75 | FP32 | 48 | 156 | 494 | 3.25 | 10.29 | +| 8 | 4 | FP32 | 226 | 791 | 2753 | 3.50 | 12.18 | +| 8 | 32 | FP32 | 230 | 859 | 2643 | 3.73 | 11.49 | +| 8 | 0.75 | FP32 | 230 | 706 | 2225 | 3.06 | 9.67 | +| 128 | 4 | FP32 | 1443 | 3729 | 8822 | 2.58 | 6.11 | +| 128 | 32 | FP32 | 1372 | 3396 | 8694 | 2.47 | 6.33 | +| 128 | 0.75 | FP32 | 1259 | 2640 | 7127 | 2.09 | 5.66 | * Performance on FP16 -| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | -|:---------------------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| -| <1, 1, 32> | 396.28 | 34.38 | 11.52 | 26.66 | 14.86 | -| <1, 1, 64> | 768.43 | 63.88 | 12.02 | 56.44 | 13.61 | -| <1, 1, 128> | 1543.99 | 129.90 | 11.88 | 123.63 | 12.48 | -| <1, 4, 32> | 419.53 | 35.09 | 11.95 | 26.25 | 15.98 | -| <1, 4, 64> | 806.38 | 59.80 | 13.48 | 54.02 | 14.92 | -| <1, 4, 128> | 1570.90 | 123.67 | 12.70 | 115.83 | 13.56 | -| <8, 1, 32> | 410.31 | 36.86 | 11.13 | 26.83 | 15.29 | -| <8, 1, 64> | 795.15 | 63.40 | 12.54 | 58.65 | 13.55 | -| <8, 1, 128> | 1639.86 | 132.13 | 12.41 | 127.12 | 12.90 | -| <8, 4, 32> | 439.64 | 38.89 | 11.30 | 35.99 | 12.21 | -| <8, 4, 64> | 891.54 | 82.09 | 10.86 | 79.82 | 11.16 | -| <8, 4, 128> | 1766.03 | 182.58 | 9.67 | 193.54 | 9.12 | -| <32, 1, 32> | 466.24 | 40.58 | 11.48 | 35.76 | 13.03 | -| <32, 1, 64> | 886.57 | 82.15 | 10.79 | 80.28 | 11.04 | -| <32, 1, 128> | 1837.41 | 187.04 | 9.82 | 195.01 | 9.42 | -| <32, 4, 32> | 536.00 | 84.37 | 6.35 | 82.82 | 6.47 | -| <32, 4, 64> | 1116.74 | 189.16 | 5.90 | 198.95 | 5.61 | -| <32, 4, 128> | 2473.57 | 470.40 | 5.25 | 518.77 | 4.76 | -| <64, 1, 32> | 480.88 | 53.39 | 9.00 | 50.89 | 9.44 | -| <64, 1, 64> | 939.87 | 114.97 | 8.17 | 118.25 | 7.94 | -| <64, 1, 128> | 2051.09 | 280.67 | 7.30 | 305.32 | 6.71 | -| <64, 4, 32> | 668.45 | 143.41 | 4.66 | 144.53 | 4.62 | -| <64, 4, 64> | 1476.17 | 332.89 | 4.43 | 351.14 | 4.20 | -| <64, 4, 128> | 3282.27 | 860.21 | 3.81 | 966.68 | 3.39 | -| <128, 1, 32> | 587.50 | 80.61 | 7.28 | 80.79 | 7.27 | -| <128, 1, 64> | 1107.02 | 182.72 | 6.05 | 193.22 | 5.72 | -| <128, 1, 128> | 2635.13 | 467.93 | 5.63 | 518.73 | 5.07 | -| <128, 4, 32> | 996.88 | 265.51 | 3.75 | 271.80 | 3.66 | -| <128, 4, 64> | 2157.85 | 627.24 | 3.44 | 671.76 | 3.21 | -| <128, 4, 128> | 5389.81 | 1646.64 | 3.27 | 1848.24 | 2.91 | - -#### Decoding performance on V100 and TensorFlow +| Batch Size | Topk/Topp | Precision | TF
<br/> Throughput (token/sec) | FT Decoder <br/> Throughput (token/sec) | FT Decoding <br/> Throughput (token/sec) | FT Decoder <br/> Speedup | FT Decoding <br/>
Speedup | +|:----------:|:---------:|:---------:|:-------------------------------:|:---------------------------------------:|:----------------------------------------:|:------------------------:|:-------------------------:| +| 1 | 4 | FP16 | 70 | 211 | 765 | 3.01 | 10.92 | +| 1 | 32 | FP16 | 68 | 201 | 756 | 2.95 | 11.11 | +| 1 | 0.75 | FP16 | 65 | 163 | 658 | 2.50 | 10.12 | +| 8 | 4 | FP16 | 296 | 904 | 3821 | 3.05 | 12.90 | +| 8 | 32 | FP16 | 291 | 851 | 3929 | 2.92 | 13.50 | +| 8 | 0.75 | FP16 | 280 | 723 | 3168 | 2.58 | 11.31 | +| 128 | 4 | FP16 | 2649 | 4810 | 21185 | 1.81 | 7.99 | +| 128 | 32 | FP16 | 2337 | 4632 | 18966 | 1.98 | 8.11 | +| 128 | 0.75 | FP16 | 1937 | 3269 | 15599 | 1.68 | 8.05 | -* Performance of FP32 +### End to end translation performance on PyTorch -| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | -|:---------------------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| -| <1, 1, 32> | 247.70 | 20.99 | 11.80 | 19.17 | 12.92 | -| <1, 1, 64> | 495.89 | 43.63 | 11.36 | 39.93 | 12.41 | -| <1, 1, 128> | 936.57 | 90.46 | 10.35 | 87.20 | 10.74 | -| <1, 4, 32> | 234.78 | 30.85 | 7.61 | 28.12 | 8.34 | -| <1, 4, 64> | 464.19 | 54.83 | 8.46 | 52.79 | 8.79 | -| <1, 4, 128> | 909.90 | 117.46 | 7.74 | 113.13 | 8.04 | -| <8, 1, 32> | 231.98 | 28.18 | 8.23 | 25.61 | 9.05 | -| <8, 1, 64> | 457.38 | 56.72 | 8.06 | 53.44 | 8.55 | -| <8, 1, 128> | 923.71 | 121.91 | 7.57 | 117.66 | 7.85 | -| <8, 4, 32> | 249.10 | 31.72 | 7.85 | 29.34 | 8.49 | -| <8, 4, 64> | 503.95 | 65.72 | 7.66 | 64.22 | 7.84 | -| <8, 4, 128> | 1020.94 | 147.66 | 6.91 | 149.51 | 6.82 | -| <32, 1, 32> | 245.18 | 31.71 | 7.73 | 29.16 | 8.40 | -| <32, 1, 64> | 521.13 | 65.71 | 7.93 | 64.31 | 8.10 | -| <32, 1, 128> | 968.92 | 149.11 | 6.49 | 149.72 | 6.47 | -| <32, 4, 32> | 290.96 | 67.00 | 4.34 | 66.66 | 4.36 | -| <32, 4, 64> | 662.04 | 147.43 | 4.49 | 155.35 | 4.26 | -| <32, 4, 128> | 1445.38 | 352.77 | 4.09 | 382.38 | 3.77 | -| <64, 1, 32> | 267.80 | 42.61 | 6.28 | 42.18 | 6.34 | -| <64, 1, 64> | 573.75 | 93.68 | 6.12 | 94.01 | 6.10 | -| <64, 1, 128> | 1204.28 | 217.32 | 5.54 | 228.94 | 5.26 | -| <64, 4, 32> | 369.10 | 113.17 | 3.26 | 114.41 | 3.22 | -| <64, 4, 64> | 811.20 | 251.04 | 3.23 | 265.57 | 3.05 | -| <64, 4, 128> | 1896.34 | 615.58 | 3.08 | 687.73 | 2.75 | -| <128, 1, 32> | 300.77 | 67.01 | 4.48 | 66.01 | 4.55 | -| <128, 1, 64> | 619.74 | 150.08 | 4.12 | 151.31 | 4.09 | -| <128, 1, 128> | 1406.48 | 356.22 | 3.94 | 387.80 | 3.62 | -| <128, 4, 32> | 497.61 | 202.93 | 2.45 | 207.86 | 2.39 | -| <128, 4, 64> | 1194.74 | 463.58 | 2.57 | 496.50 | 2.40 | -| <128, 4, 128> | 3068.19 | 1135.37 | 2.70 | 1259.20 | 2.43 | - -* Performance of FP16 - -| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | -|:---------------------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| -| <1, 1, 32> | 179.29 | 22.79 | 7.86 | 19.90 | 9.00 | -| <1, 1, 64> | 424.71 | 46.31 | 9.17 | 42.07 | 10.09 | -| <1, 1, 128> | 800.49 | 106.68 | 7.50 | 102.70 | 7.79 | -| <1, 4, 32> | 215.21 | 22.99 | 9.36 | 20.42 | 10.53 | -| <1, 4, 64> | 426.36 | 47.33 | 9.00 | 42.67 | 9.99 | -| <1, 4, 128> | 842.32 | 105.93 | 7.95 | 105.07 | 8.01 | -| <8, 1, 32> | 218.83 | 22.45 | 9.74 | 20.29 | 10.78 | -| <8, 1, 64> | 429.64 | 46.16 | 9.30 | 42.66 | 10.07 | -| <8, 1, 128> | 827.80 | 96.64 | 8.56 | 94.76 | 8.73 | -| <8, 4, 32> | 228.45 | 25.30 | 9.02 | 23.36 | 9.77 | -| <8, 4, 64> | 434.26 | 51.36 | 8.45 | 49.95 | 8.69 | -| <8, 4, 128> | 879.69 | 
113.05 | 7.78 | 115.80 | 7.59 | -| <32, 1, 32> | 224.73 | 25.34 | 8.86 | 23.12 | 9.72 | -| <32, 1, 64> | 447.28 | 51.98 | 8.60 | 50.01 | 8.94 | -| <32, 1, 128> | 887.31 | 114.14 | 7.77 | 114.74 | 7.73 | -| <32, 4, 32> | 249.40 | 43.55 | 5.72 | 43.17 | 5.77 | -| <32, 4, 64> | 549.04 | 96.69 | 5.67 | 101.74 | 5.39 | -| <32, 4, 128> | 1182.18 | 225.50 | 5.24 | 248.09 | 4.76 | -| <64, 1, 32> | 227.12 | 30.99 | 7.32 | 29.93 | 7.58 | -| <64, 1, 64> | 494.82 | 67.05 | 7.37 | 67.49 | 7.33 | -| <64, 1, 128> | 1000.46 | 154.54 | 6.47 | 160.94 | 6.21 | -| <64, 4, 32> | 304.52 | 68.84 | 4.42 | 69.72 | 4.36 | -| <64, 4, 64> | 666.90 | 154.89 | 4.30 | 164.80 | 4.04 | -| <64, 4, 128> | 1494.30 | 373.57 | 4.00 | 425.44 | 3.51 | -| <128, 1, 32> | 252.69 | 43.08 | 5.86 | 42.74 | 5.91 | -| <128, 1, 64> | 535.56 | 93.53 | 5.72 | 97.05 | 5.51 | -| <128, 1, 128> | 1134.44 | 225.94 | 5.02 | 245.81 | 4.61 | -| <128, 4, 32> | 410.80 | 114.56 | 3.58 | 118.16 | 3.47 | -| <128, 4, 64> | 934.86 | 263.50 | 3.54 | 283.36 | 3.29 | -| <128, 4, 128> | 2236.95 | 653.69 | 3.42 | 746.66 | 2.99 | +We demonstrate the throughput of PyTorch, FT Decoder and FT Decoding for end to end translation. Here, PyTorch means that the program fully runs on PyTorch. FT Decoder means that we replace the decoder transformer layer by FasterTransformer. FT Decoding means that we replace the whole procedure of decoder by FasterTransformer. - +This benchmark were obtained by running the `../sample/pytorch/scripts/profile_decoder_decoding.sh`. -#### Decoder and decoding performance on T4 and PyTorch +In this benchmark, we updated the following parameters: -* Performance on FP32 +* head_num = 8 for both encoder and decoder +* size_per_head = 64 for both encoder and decoder +* num_layers = 6 for both encoder and decoder +* vocabulary_size = 31538 +* max_seq_len = 128 -| | PyTorch (ms) | Decoder (ms) | Decoding (ms) | Decoder Speedup | Decoding Speedup | -|:-----------------------:|:------:|:------:|:------:|:---------:|:---------:| -| <1, 32, 1> | 484.75 | 144.20 | 29.08 | 3.36 | 16.66 | -| <1, 64, 1> | 964.91 | 295.16 | 57.97 | 3.26 | 16.64 | -| <1, 128, 1> | 2482.00 | 716.21 | 118.97 | 3.46 | 20.86 | -| <8, 32, 1> | 640.09 | 198.37 | 41.27 | 3.22 | 15.50 | -| <8, 64, 1> | 1026.29 | 326.66 | 86.32 | 3.14 | 11.88 | -| <8, 128, 1> | 2077.31 | 683.36 | 180.75 | 3.03 | 11.49 | -| <32, 32, 1> | 539.02 | 182.05 | 55.35 | 2.96 | 9.73 | -| <32, 64, 1> | 1060.14 | 368.43 | 121.32 | 2.87 | 8.73 | -| <32, 128, 1> | 2198.63 | 822.78 | 294.63 | 2.67 | 7.46 | -| <64, 32, 1> | 544.38 | 216.06 | 87.28 | 2.51 | 6.23 | -| <64, 64, 1> | 1359.49 | 483.68 | 196.35 | 2.81 | 6.92 | -| <64, 128, 1> | 2409.26 | 1239.34 | 487.91 | 1.94 | 4.93 | -| <128, 32, 1> | 705.29 | 321.99 | 157.30 | 2.19 | 4.48 | -| <128, 64, 1> | 1490.15 | 765.70 | 359.43 | 1.94 | 4.14 | -| <128, 128, 1> | 3328.75 | 2032.92 | 900.86 | 1.63 | 3.69 | -| <1, 32, 4> | 519.91 | 170.90 | 37.49 | 3.04 | 13.86 | -| <1, 64, 4> | 1022.17 | 329.85 | 75.47 | 3.09 | 13.54 | -| <1, 128, 4> | 2087.35 | 654.85 | 156.97 | 3.18 | 13.29 | -| <8, 32, 4> | 653.81 | 212.86 | 55.83 | 3.07 | 11.71 | -| <8, 64, 4> | 1056.50 | 363.22 | 121.80 | 2.90 | 8.67 | -| <8, 128, 4> | 2187.94 | 842.20 | 298.90 | 2.59 | 7.31 | -| <32, 32, 4> | 588.74 | 320.21 | 160.45 | 1.83 | 3.66 | -| <32, 64, 4> | 1280.28 | 773.54 | 363.31 | 1.65 | 3.52 | -| <32, 128, 4> | 2869.27 | 2116.43 | 916.30 | 1.35 | 3.13 | -| <64, 32, 4> | 694.86 | 530.53 | 297.42 | 1.30 | 2.33 | -| <64, 64, 4> | 1777.26 | 1331.30 | 687.77 | 1.33 | 2.58 | -| <64, 128, 4> | 
4769.54 | 3960.06 | 1740.75 | 1.20 | 2.73 | -| <128, 32, 4> | 990.83 | 975.95 | 576.75 | 1.01 | 1.71 | -| <128, 64, 4> | 2794.30 | 2610.29 | 1310.25 | 1.07 | 2.13 | +#### Beamsearch performance on V100 and PyTorch -* Performance on FP16 +* Perofrmance on FP32 -| | PyTorch (ms) | Decoder (ms) | Decoding (ms) | Decoder Speedup | Decoding Speedup | -|:-----------------------:|:------:|:------:|:------:|:---------:|:---------:| -| <1, 32, 1> | 636.17 | 187.04 | 28.32 | 3.40 | 22.46 | -| <1, 64, 1> | 1030.81 | 313.46 | 53.82 | 3.28 | 19.15 | -| <1, 128, 1> | 2029.57 | 612.47 | 121.08 | 3.31 | 16.76 | -| <8, 32, 1> | 546.08 | 163.20 | 34.43 | 3.34 | 15.86 | -| <8, 64, 1> | 1112.37 | 315.34 | 73.64 | 3.52 | 15.10 | -| <8, 128, 1> | 2237.78 | 638.65 | 160.04 | 3.50 | 13.98 | -| <32, 32, 1> | 546.68 | 171.72 | 40.91 | 3.18 | 13.36 | -| <32, 64, 1> | 1374.25 | 342.27 | 89.34 | 4.01 | 15.38 | -| <32, 128, 1> | 2219.99 | 712.94 | 206.78 | 3.11 | 10.73 | -| <64, 32, 1> | 557.29 | 196.28 | 60.96 | 2.83 | 9.14 | -| <64, 64, 1> | 1127.56 | 423.53 | 133.64 | 2.66 | 8.43 | -| <64, 128, 1> | 2431.01 | 1024.73 | 324.01 | 2.37 | 7.50 | -| <128, 32, 1> | 604.19 | 260.15 | 100.36 | 2.32 | 6.02 | -| <128, 64, 1> | 1252.95 | 594.85 | 228.57 | 2.10 | 5.48 | -| <128, 128, 1> | 2727.85 | 1526.56 | 567.00 | 1.78 | 4.81 | -| <1, 32, 4> | 568.26 | 165.05 | 33.89 | 3.44 | 16.76 | -| <1, 64, 4> | 1099.60 | 321.63 | 68.78 | 3.41 | 15.98 | -| <1, 128, 4> | 2177.06 | 630.75 | 146.24 | 3.45 | 14.88 | -| <8, 32, 4> | 558.22 | 173.52 | 41.02 | 3.21 | 13.60 | -| <8, 64, 4> | 1105.78 | 343.64 | 88.14 | 3.21 | 12.54 | -| <8, 128, 4> | 2240.45 | 728.21 | 205.81 | 3.07 | 10.88 | -| <32, 32, 4> | 606.68 | 267.60 | 104.44 | 2.26 | 5.80 | -| <32, 64, 4> | 1254.07 | 606.08 | 237.79 | 2.06 | 5.27 | -| <32, 128, 4> | 2741.17 | 1553.44 | 580.81 | 1.76 | 4.71 | -| <64, 32, 4> | 669.47 | 399.96 | 192.19 | 1.67 | 3.48 | -| <64, 64, 4> | 1424.02 | 966.43 | 436.73 | 1.47 | 3.26 | -| <64, 128, 4> | 3638.59 | 2843.25 | 1091.42 | 1.27 | 3.33 | -| <128, 32, 4> | 968.40 | 690.89 | 369.87 | 1.40 | 2.61 | -| <128, 64, 4> | 2087.75 | 1808.63 | 838.92 | 1.15 | 2.48 | -| <128, 128, 4> | 6735.41 | 5440.68 | 2082.84 | 1.23 | 3.23 | - -#### Decoder and decoding performance on V100 and PyTorch +| Batch Size | Beam Width | Precision | PyTorch
Throughput (token/sec) | FT Decoder
Throughput (token/sec) | FT Decoding
Throughput (token/sec) | FT Decoder
Speedup | FT Decoding
Speedup | +|:----------:|:----------:|:---------:|:------------------------------------:|:---------------------------------------:|:----------------------------------------:|:------------------------:|:-------------------------:| +| 1 | 1 | FP32 | 92 | 277 | 699 | 3.00 | 7.56 | +| 1 | 4 | FP32 | 80 | 226 | 703 | 2.82 | 8.76 | +| 1 | 32 | FP32 | 69 | 217 | 471 | 3.12 | 6.76 | +| 8 | 1 | FP32 | 385 | 1232 | 3225 | 3.20 | 8.37 | +| 8 | 4 | FP32 | 352 | 1121 | 2756 | 3.18 | 7.81 | +| 8 | 32 | FP32 | 262 | 465 | 950 | 1.77 | 3.62 | +| 128 | 1 | FP32 | 2968 | 6213 | 12848 | 2.09 | 4.32 | +| 128 | 4 | FP32 | 1953 | 2447 | 6759 | 1.25 | 3.46 | -* Performance on FP32 +* Performance on FP16 -| | PyTorch (ms) | Decoder (ms) | Decoding (ms) | Decoder Speedup | Decoding Speedup | -|:-----------------------:|:------:|:------:|:------:|:---------:|:---------:| -| <1, 32, 1> | 353.90 | 103.39 | 19.72 | 3.42 | 17.94 | -| <1, 64, 1> | 698.88 | 212.27 | 40.61 | 3.29 | 17.20 | -| <1, 128, 1> | 1449.20 | 441.20 | 79.19 | 3.28 | 18.30 | -| <8, 32, 1> | 439.07 | 139.12 | 27.43 | 3.15 | 16.00 | -| <8, 64, 1> | 761.94 | 237.07 | 55.40 | 3.21 | 13.75 | -| <8, 128, 1> | 1731.31 | 535.99 | 117.83 | 3.23 | 14.69 | -| <32, 32, 1> | 373.02 | 124.94 | 30.53 | 2.98 | 12.21 | -| <32, 64, 1> | 771.97 | 250.84 | 66.12 | 3.07 | 11.67 | -| <32, 128, 1> | 1563.37 | 527.23 | 147.27 | 2.96 | 10.61 | -| <64, 32, 1> | 391.65 | 166.63 | 43.54 | 2.35 | 8.99 | -| <64, 64, 1> | 763.75 | 347.91 | 95.53 | 2.19 | 7.99 | -| <64, 128, 1> | 1626.91 | 734.35 | 225.06 | 2.21 | 7.22 | -| <128, 32, 1> | 399.32 | 205.76 | 65.84 | 1.94 | 6.06 | -| <128, 64, 1> | 845.62 | 428.30 | 147.87 | 1.97 | 5.71 | -| <128, 128, 1> | 1780.45 | 1061.66 | 362.33 | 1.67 | 4.91 | -| <1, 32, 4> | 361.21 | 113.60 | 29.08 | 3.17 | 12.42 | -| <1, 64, 4> | 733.17 | 220.84 | 52.21 | 3.31 | 14.04 | -| <1, 128, 4> | 1489.75 | 467.02 | 125.59 | 3.18 | 11.86 | -| <8, 32, 4> | 382.98 | 124.76 | 30.43 | 3.06 | 12.58 | -| <8, 64, 4> | 768.14 | 248.43 | 64.50 | 3.09 | 11.90 | -| <8, 128, 4> | 1535.88 | 532.08 | 149.88 | 2.88 | 10.24 | -| <32, 32, 4> | 401.86 | 196.38 | 69.34 | 2.04 | 5.79 | -| <32, 64, 4> | 842.37 | 435.26 | 151.97 | 1.93 | 5.54 | -| <32, 128, 4> | 1758.36 | 1076.28 | 367.99 | 1.63 | 4.77 | -| <64, 32, 4> | 433.80 | 283.74 | 114.21 | 1.52 | 3.79 | -| <64, 64, 4> | 955.72 | 698.55 | 256.37 | 1.36 | 3.72 | -| <64, 128, 4> | 2137.94 | 1777.37 | 642.46 | 1.20 | 3.32 | -| <128, 32, 4> | 510.07 | 456.99 | 213.86 | 1.11 | 2.38 | -| <128, 64, 4> | 1140.04 | 1192.74 | 485.95 | .95 | 2.34 | +| Batch Size | Beam Width | Precision | PyTorch
Throughput (token/sec) | FT Decoder
Throughput (token/sec) | FT Decoding
Throughput (token/sec) | FT Decoder
Speedup | FT Decoding
Speedup | +|:----------:|:----------:|:---------:|:------------------------------------:|:---------------------------------------:|:----------------------------------------:|:------------------------:|:-------------------------:| +| 1 | 1 | FP16 | 78 | 267 | 967 | 3.40 | 12.39 | +| 1 | 4 | FP16 | 76 | 251 | 868 | 3.29 | 11.39 | +| 1 | 32 | FP16 | 70 | 217 | 635 | 3.10 | 9.07 | +| 8 | 1 | FP16 | 357 | 1242 | 4508 | 3.47 | 12.61 | +| 8 | 4 | FP16 | 336 | 886 | 3769 | 2.63 | 11.20 | +| 8 | 32 | FP16 | 265 | 575 | 1454 | 2.17 | 5.48 | +| 128 | 1 | FP16 | 3193 | 7396 | 19264 | 2.31 | 6.03 | +| 128 | 4 | FP16 | 2141 | 3141 | 12609 | 1.46 | 5.88 | + + +#### Beamsearch performance on T4 and PyTorch + +* Performance on FP32 + +| Batch Size | Beam Width | Precision | PyTorch<br/>
Throughput (token/sec) | FT Decoder
Throughput (token/sec) | FT Decoding
Throughput (token/sec) | FT Decoder
Speedup | FT Decoding
Speedup | +|:----------:|:----------:|:---------:|:------------------------------------:|:---------------------------------------:|:----------------------------------------:|:------------------------:|:-------------------------:| +| 1 | 1 | FP32 | 62 | 179 | 566 | 2.85 | 8.99 | +| 1 | 4 | FP32 | 56 | 158 | 535 | 2.79 | 9.46 | +| 1 | 32 | FP32 | 47 | 144 | 312 | 3.06 | 6.62 | +| 8 | 1 | FP32 | 259 | 764 | 2418 | 2.94 | 9.30 | +| 8 | 4 | FP32 | 239 | 711 | 1914 | 2.97 | 7.99 | +| 8 | 32 | FP32 | 140 | 183 | 358 | 1.30 | 2.54 | +| 128 | 1 | FP32 | 1803 | 2885 | 6400 | 1.60 | 3.54 | +| 128 | 4 | FP32 | 690 | 836 | 2519 | 1.21 | 3.64 | * Performance on FP16 -| | PyTorch (ms) | Decoder (ms) | Decoding (ms) | Decoder Speedup | Decoding Speedup | -|:-----------------------:|:------:|:------:|:------:|:---------:|:---------:| -| <1, 32, 1> | 364.93 | 104.67 | 23.59 | 3.48 | 15.46 | -| <1, 64, 1> | 730.63 | 219.29 | 48.02 | 3.33 | 15.21 | -| <1, 128, 1> | 1448.80 | 435.08 | 90.06 | 3.32 | 16.08 | -| <8, 32, 1> | 396.70 | 113.47 | 28.43 | 3.49 | 13.95 | -| <8, 64, 1> | 766.96 | 213.44 | 58.41 | 3.59 | 13.13 | -| <8, 128, 1> | 1508.97 | 430.11 | 123.92 | 3.50 | 12.17 | -| <32, 32, 1> | 380.00 | 113.32 | 30.81 | 3.35 | 12.33 | -| <32, 64, 1> | 755.43 | 230.70 | 56.28 | 3.27 | 13.42 | -| <32, 128, 1> | 1592.17 | 481.88 | 140.00 | 3.30 | 11.37 | -| <64, 32, 1> | 385.02 | 150.23 | 36.38 | 2.56 | 10.58 | -| <64, 64, 1> | 1006.94 | 352.55 | 77.56 | 2.85 | 12.98 | -| <64, 128, 1> | 1647.93 | 669.11 | 174.38 | 2.46 | 9.45 | -| <128, 32, 1> | 393.47 | 172.10 | 49.39 | 2.28 | 7.96 | -| <128, 64, 1> | 846.32 | 371.34 | 109.92 | 2.27 | 7.69 | -| <128, 128, 1> | 1812.89 | 892.29 | 260.72 | 2.03 | 6.95 | -| <1, 32, 4> | 403.72 | 111.89 | 28.33 | 3.60 | 14.25 | -| <1, 64, 4> | 758.80 | 215.31 | 58.97 | 3.52 | 12.86 | -| <1, 128, 4> | 1565.94 | 431.89 | 113.51 | 3.62 | 13.79 | -| <8, 32, 4> | 388.91 | 117.17 | 31.56 | 3.31 | 12.32 | -| <8, 64, 4> | 768.24 | 232.11 | 61.85 | 3.30 | 12.42 | -| <8, 128, 4> | 1618.71 | 497.68 | 136.25 | 3.25 | 11.88 | -| <32, 32, 4> | 415.84 | 183.10 | 51.08 | 2.27 | 8.14 | -| <32, 64, 4> | 874.10 | 390.93 | 112.19 | 2.23 | 7.79 | -| <32, 128, 4> | 1806.96 | 876.53 | 255.26 | 2.06 | 7.07 | -| <64, 32, 4> | 453.94 | 234.66 | 84.20 | 1.93 | 5.39 | -| <64, 64, 4> | 948.13 | 517.52 | 185.68 | 1.83 | 5.10 | -| <64, 128, 4> | 2071.99 | 1333.14 | 446.57 | 1.55 | 4.63 | -| <128, 32, 4> | 486.71 | 349.62 | 146.36 | 1.39 | 3.32 | -| <128, 64, 4> | 1084.80 | 808.79 | 330.19 | 1.34 | 3.28 | -| <128, 128, 4> | 2638.70 | 2248.28 | 800.58 | 1.17 | 3.29 | - -#### TensorFlow performance on translation - -We test with batch_size 128, beam width 4 on V100. 
- -| Type | tokens per seconds | BLEU | -|:----:|:------------------:|:----:| -| TensorFlow, beam search, FP32 | 2137 | BLEU 26.29 | -| Decoder, beam search, FP32 | 6473 | BLEU 26.29 | -| Decoding, beam search, FP32 | 8513 | BLEU 26.31 | -| TensorFlow, sampling, FP32 | 4178 | BLEU 25.79 | -| Decoder, sampling, FP32 | 10781 | BLEU 25.79 | -| Decoding, sampling, FP32 | 16524 | BLEU 25.79 | -| TensorFlow, beam search, FP16 | 2949 | BLEU 26.31 | -| Decoder, beam search, FP16 | 8682 | BLEU 26.30 | -| Decoding, beam search, FP16 | 12746 | BLEU 26.33 | -| TensorFlow, sampling, FP16 | 6968 | BLEU 25.83 | -| Decoder, sampling, FP16 | 13773 | BLEU 25.80 | -| Decoding, sampling, FP16 | 26718 | BLEU 25.82 | - -#### PyTorch performance on translation - -batch size 128, beam width 4, max_seq_len 32, beam search algorithm on V100: - -| Type | tokens per seconds | BLEU | -|:----:|:------------------:|:----:| -| PyTorch, FP32 | 2462 | BLEU 24.1 | -| Decoder, FP32 | 3358 | BLEU 24.1 | -| Decoding, FP32 | 8959 | BLEU 24.1 | -| PyTorch, FP16 | 4019 | BLEU 24.1 | -| Decoder, FP16 | 4377 | BLEU 24.1 | -| Decoding, FP16 | 15048 | BLEU 24.1 | - -#### GPT-2 performance on V100 and TensorFlow - -* Performance on 124M GPT-2 model, top 1 sampling - -We use the source codes of [here](https://github.com/openai/gpt-2) as the baseline, and compare to FasterTransformer TensorFlow OP and CPP under FP16. - -| | TF (ms) | FT-OP (ms) | FT-OP Speedup | FT-CPP (ms) | FT-CPP Speedup | -|:---------------------------------:|:-------:|:----------:|:-------------:|:-----------:|:--------------:| -| <1, 128> | 1887.46 | 171.04 | 11.03 | 151.12 | 12.49 | -| <1, 1024> | 15734.21 | 1676.85 | 9.38 | 1651.21 | 9.53 | -| <8, 128> | 1756.14 | 180.87 | 9.71 | 173.97 | 10.09 | -| <8, 1024> | 17227.00 | 1917.78 | 8.98 | 1906.97 | 9.03 | -| <32, 128> | 2275.04 | 214.84 | 10.59 | 204.82 | 11.11 | -| <32, 1024> | 31674.81 | 2462.69 | 12.86 | 2440.94 | 12.98 | +| Batch Size | Beam Width | Precision | PyTorch
Throughput (token/sec) | FT Decoder
Throughput (token/sec) | FT Decoding
Throughput (token/sec) | FT Decoder
Speedup | FT Decoding
Speedup | +|:----------:|:----------:|:---------:|:------------------------------------:|:---------------------------------------:|:----------------------------------------:|:------------------------:|:-------------------------:| +| 1 | 1 | FP16 | 60 | 176 | 774 | 2.93 | 12.81 | +| 1 | 4 | FP16 | 55 | 170 | 699 | 3.08 | 12.68 | +| 1 | 32 | FP16 | 46 | 147 | 468 | 3.17 | 10.06 | +| 8 | 1 | FP16 | 254 | 832 | 3389 | 3.27 | 13.32 | +| 8 | 4 | FP16 | 237 | 759 | 2981 | 3.19 | 12.53 | +| 8 | 32 | FP16 | 164 | 256 | 636 | 1.56 | 3.87 | +| 128 | 1 | FP16 | 2035 | 4000 | 10836 | 1.96 | 5.32 | +| 128 | 4 | FP16 | 977 | 1192 | 6369 | 1.21 | 6.51 | + diff --git a/docs/encoder_guide.md b/docs/encoder_guide.md index ad7dffa61..c0dba089b 100644 --- a/docs/encoder_guide.md +++ b/docs/encoder_guide.md @@ -14,14 +14,14 @@ The FasterTransformer Encoder contains the optimized BERT model, Effective Faste - [Encoder process](#encoder-process) - [Performance](#performance) - [Encoder performance](#encoder-performance) - - [Encoder performance on T4 and cpp](#encoder-performance-on-t4-and-cpp) + - [Encoder performance on A100 and TensorFlow](#encoder-performance-on-a100-and-tensorflow) - [Encoder performance on T4 and TensorFlow](#encoder-performance-on-t4-and-tensorflow) - [Encoder performance on V100 and TensorFlow](#encoder-performance-on-v100-and-tensorflow) + - [Encoder performance comparison between T4, V100, A100 and A100 with MIG mode on TensorFlow](#encoder-performance-comparison-between-t4-v100-a100-and-a100-with-mig-mode-on-tensorflow) + - [Encoder performance comparison between different features on T4 and TensorFlow](#encoder-performance-comparison-between-different-features-on-t4-and-tensorflow) + - [Encoder performance on A100 and PyTorch](#encoder-performance-on-a100-and-pytorch) - [Encoder performance on T4 and PyTorch](#encoder-performance-on-t4-and-pytorch) - [Encoder performance on V100 and PyTorch](#encoder-performance-on-v100-and-pytorch) - - [Effective FasterTransformer performance](#effective-fastertransformer-performance) - - [Performance on TensorFlow](#performance-on-tensorflow) - - [Effective FasterTransformer performance on PyTorch](#effective-fastertransformer-performance-on-pytorch) - [Performance on BERT Applications: SQuAD MRPC](#performance-on-bert-applications-squad-mrpc) - [Performance of TensorFlow](#performance-of-tensorflow) - [Performance of PyTorch](#performance-of-pytorch) @@ -32,7 +32,7 @@ The FasterTransformer Encoder contains the optimized BERT model, Effective Faste The following configurations are supported in the FasterTransformer encoder. - Batch size (B1): smaller or equal to 4096 -- Sequence length (S): smaller or equal to 1024. For INT8 data type, sequence length should be a multiple of 32. +- Sequence length (S): smaller or equal to 1024. For INT8 mode=1, S should be a multiple of 32 when S > 384. - Head number (H) and size per head (N): - 16 heads * 64 per heads - 12 heads * 64 per heads @@ -41,7 +41,7 @@ The following configurations are supported in the FasterTransformer encoder. - Data type: FP32, FP16 and INT8 - Any number layer (N1) if the memory is enough -In the FasterTransformer v1.0, we provide a highly optimized BERT-equivalent encoder model. Next, based on the idea of [Effective Transformer](https://github.com/bytedance/effective_transformer), we further optimize BERT inference by removing the useless padding in FasterTransformer v2.1 and provide the Effective FasterTransformer. 
In FasterTransformer v3.0, we provide INT8 quantization inference to get better performance. In FasterTransformer v3.1, we optimize the INT8 kernels to improve the performance of INT8 inference, and integrate the multi-head attention of TensorRT plugin into FasterTransformer. The following graph demonstrates the flow chart of these optimization, except INT8. +In FasterTransformer v1.0, we provide a highly optimized BERT-equivalent encoder model. Next, based on the idea of [Effective Transformer](https://github.com/bytedance/effective_transformer), we further optimize BERT inference by removing the useless padding in FasterTransformer v2.1 and provide the Effective FasterTransformer. In FasterTransformer v3.0, we provide INT8 quantization inference to get better performance. In FasterTransformer v3.1, we optimize the INT8 kernels to improve the performance of INT8 inference, and integrate the multi-head attention kernel of the TensorRT plugin into FasterTransformer. In FasterTransformer v4.0, we add the multi-head attention kernel to support FP16 on V100 and INT8 on T4 and A100. The following graph demonstrates the flow chart of these optimizations, except INT8.</br>
Fig. 1 Flowchart of encoder.
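The Effective FasterTransformer path mentioned above (padding removal) is exposed through the `--remove_padding` option of the samples in this guide. The block below is only an illustrative sketch assembled from the sample flags that appear elsewhere in this document; the shapes (batch size 32, sequence length 32, 12 heads of size 64, 12 layers, average length 16) are an example, not a requirement:

```bash
# Illustrative only: run the TensorFlow encoder sample with padding removal
# enabled, assuming the inputs average 16 valid tokens per sentence.
python tensorflow/encoder_sample.py \
        --batch_size 32 \
        --max_seq_len 32 \
        --head_number 12 \
        --size_per_head 64 \
        --num_layer 12 \
        --data_type fp16 \
        --test_time 1 \
        --remove_padding True \
        --avg_seq_len 16
```

When the average sequence length is much smaller than `max_seq_len`, this path skips the padded positions, which is where the EFF-FT speedups reported in the tables below come from.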
@@ -83,7 +83,7 @@ Note that S2 means that the total sequence length after removing padd Besides, notice that the multi-head attention kernel from TensorRT is powerful but have some limitation. First, this kernel requires Turing or new GPU and the size per head must be 64. When the conditions are not satisfied, we use original multi-head attention implementation of FasterTransformer. Second, it requires an additional sequence length offset as we describe above. When the input has padding, the shape of the sequence length offset is \[2 x B1 + 1 \]. Assume there are there sentences with sequence length s1, s2 and s3, and the sequence length after padding is S. Then the sequence length offset is \[0, s1, S, s2 + S, 2 x S, 2 x S + s3, 3 x S\]. On the other hand, when we remove the padding, the shape of the sequence length offset is \[B1 + 1\], and the sequence length offset is \[0, s1, s1 + s2, s1 + s2 + s3 \]. Namely, the sequence length offset records the sequence length for each sentence. When we have padding, we view the padding as some independent sentences. -In FasterTransformer v3.1, we implement two pipelines of INT8 inference, as shown in Fig. 3.. For int8_mode == 1 (int8v1), we don't quantize residual connection, use int32 as the output of int8 gemms and use per-channel quantization for weights. For int8_mode == 2 (int8v2), we quantize residual connection, use int8 as the output of int8 gemms and use per-tensor quantization for weights. Generally speaking, int8_mode == 1 will have higher accuracy while int8_mode == 2 will have better performance. +In FasterTransformer v4.0, we implement two pipelines of INT8 inference, as shown in Fig. 3.. For int8_mode == 1 (int8v1), we don't quantize residual connection, use int32 as the output of int8 gemms and use per-channel quantization for weights. For int8_mode == 2 (int8v2), we quantize residual connection, use int8 as the output of int8 gemms and use per-tensor quantization for weights. Generally speaking, int8_mode == 1 will have higher accuracy while int8_mode == 2 will have better performance.
Fig. 3 Workflow of int8 inference.
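To make the difference between the two INT8 pipelines concrete, they map directly onto the `--int8_mode` flag of the TensorFlow encoder sample used elsewhere in this guide. The commands below are a minimal sketch based on the gemm-test and sample invocations shown in the INT8 sections of this document, with example shapes of batch size 32, sequence length 32, 12 layers:

```bash
# int8_mode == 1 (int8v1): per-channel weight scales, int32 gemm outputs,
# residual connection not quantized -> higher accuracy.
./bin/encoder_gemm 32 32 12 64 1 1
python tensorflow/encoder_sample.py --batch_size 32 --max_seq_len 32 \
        --head_number 12 --size_per_head 64 --num_layer 12 \
        --data_type fp16 --test_time 1 --int8_mode 1 --allow_gemm_test False

# int8_mode == 2 (int8v2): per-tensor weight scales, int8 gemm outputs,
# quantized residual connection -> better performance.
./bin/encoder_gemm 32 32 12 64 1 2
python tensorflow/encoder_sample.py --batch_size 32 --max_seq_len 32 \
        --head_number 12 --size_per_head 64 --num_layer 12 \
        --data_type fp16 --test_time 1 --int8_mode 2 --allow_gemm_test False
```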
@@ -107,7 +107,6 @@ The following section lists the requirements to use FasterTransformer Encoder. - Python 3 is recommended because some features are not supported in python 2 - Tensorflow 1.13 or 1.14 or 1.15 - PyTorch >= 1.4.0 -- TensorRT 5 or newer version These components are readily available within the NGC TensorFlow/PyTorch Docker image below. @@ -238,15 +237,17 @@ For those unable to use the NGC container, to set up the required environment or ```bash #For int8_mode == 1 Device Tesla T4 + Device Tesla T4 before allocate free 14.64 GB total 14.76 GB After allocate free 14.62 GB used 0.13 GB total 14.76 GB - [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 8.14 ms + [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 7.49 ms ( 50 iterations) #For int8_mode == 2 Device Tesla T4 + Device Tesla T4 before allocate free 14.64 GB total 14.76 GB After allocate free 14.62 GB used 0.13 GB total 14.76 GB - [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 5.52 ms + [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 4.79 ms ( 50 iterations) ``` 1.5 Run Effective FasterTransformer under FP32 on C++ @@ -286,16 +287,18 @@ For those unable to use the NGC container, to set up the required environment or ```bash #For int8_mode == 1 Device Tesla T4 + Device Tesla T4 before allocate free 14.64 GB total 14.76 GB After allocate free 14.62 GB used 0.14 GB total 14.76 GB - [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 4.81 ms + [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 4.06 ms ( 50 iterations) #For int8_mode == 2 Device Tesla T4 + Device Tesla T4 before allocate free 14.64 GB total 14.76 GB After allocate free 14.62 GB used 0.14 GB total 14.76 GB - [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 3.49 ms + [INFO] batch_size 32 seq_len 32 layer 12 FT-CPP-time 2.69 ms ( 50 iterations) ``` 2. Run FasterTransformer on TensorFlow (on T4 GPU) @@ -354,7 +357,7 @@ For those unable to use the NGC container, to set up the required environment or [INFO] batch_size 32 max_seq_len 32 12 layer FT-OP-tensor-time 10.13 ms ``` - 2.3 Run FasterTransformer encoder under INT8 on TensorFlow + 2.3 Run FasterTransformer and Effective FasterTransformer encoder under INT8 on TensorFlow To use the INT8 in TensorFlow, we only need to set the `--int8_mode 1` or `--int8_mode 2` like following: @@ -391,17 +394,25 @@ For those unable to use the NGC container, to set up the required environment or ```bash #For int8_mode == 1 [INFO] Encoder TF v.s. FT with tensor input Cross check False - [INFO] Max diff 4.1328125 + [INFO] Max diff 4.19140625 + [INFO] min diff 0.0 + [INFO] Encoder TF v.s. EFF-FT with tensor input Cross check False + [INFO] Max diff 4.19140625 [INFO] min diff 0.0 - [INFO] batch_size 32 max_seq_len 32 12 layer TF-time 14.01 ms - [INFO] batch_size 32 max_seq_len 32 12 layer FT-OP-tensor-time 9.85 ms + [INFO] batch_size 32 max_seq_len 32 precision FP16 12 layer TF-while-time 12.64 ms ( 50 iterations) + [INFO] batch_size 32 max_seq_len 32 precision INT8-v1 12 layer FT-OP-while-time 7.68 ms ( 50 iterations) + [INFO] batch_size 32 max_seq_len 32 precision INT8-v1 12 layer EFF-OP-while-time 4.34 ms ( 50 iterations) #For int8_mode == 2 [INFO] Encoder TF v.s. FT with tensor input Cross check False [INFO] Max diff 6.06640625 [INFO] min diff 0.0 - [INFO] batch_size 32 max_seq_len 32 12 layer TF-time 14.01 ms - [INFO] batch_size 32 max_seq_len 32 12 layer FT-OP-tensor-time 7.55 ms + [INFO] Encoder TF v.s. 
EFF-FT with tensor input Cross check False + [INFO] Max diff 6.06640625 + [INFO] min diff 0.0 + [INFO] batch_size 32 max_seq_len 32 precision FP16 12 layer TF-while-time 12.47 ms ( 50 iterations) + [INFO] batch_size 32 max_seq_len 32 precision INT8-v2 12 layer FT-OP-while-time 4.94 ms ( 50 iterations) + [INFO] batch_size 32 max_seq_len 32 precision INT8-v2 12 layer EFF-OP-while-time 2.93 ms ( 50 iterations) ``` Note: since we do not use the correct scales for quantization in this test, the Cross Check between TF and FT should fail. @@ -435,61 +446,6 @@ For those unable to use the NGC container, to set up the required environment or [INFO] batch_size 32 max_seq_len 32 12 layer FT-OP-tensor-time 24.17 ms ``` - 2.5 Run Effective FasterTransformer under INT8 on TensorFlow - - To use the Effective FasterTransformer under INT8, we only need to set the `--remove_padding True` and `--int8_mode 1`/`--int8_mode 2` like following: - - ```bash - #For int8_mode == 1 - ./bin/encoder_gemm 32 32 12 64 1 1 - python tensorflow/encoder_sample.py \ - --batch_size 32 \ - --max_seq_len 32 \ - --head_number 12 \ - --size_per_head 64 \ - --num_layer 12 \ - --data_type fp16 \ - --test_time 1 \ - --remove_padding True \ - --avg_seq_len 16 \ - --int8_mode 1 \ - --allow_gemm_test False - - #For int8_mode == 2 - ./bin/encoder_gemm 32 32 12 64 1 2 - python tensorflow/encoder_sample.py \ - --batch_size 32 \ - --max_seq_len 32 \ - --head_number 12 \ - --size_per_head 64 \ - --num_layer 12 \ - --data_type fp16 \ - --test_time 1 \ - --remove_padding True \ - --avg_seq_len 16 \ - --int8_mode 2 \ - --allow_gemm_test False - ``` - - The outputs should be like to the following: - - ```bash - #For int8_mode == 1 - [INFO] Encoder TF v.s. FT with tensor input Cross check False - [INFO] Max diff 4.1328125 - [INFO] min diff 0.0 - [INFO] batch_size 32 max_seq_len 32 12 layer TF-time 13.96 ms - [INFO] batch_size 32 max_seq_len 32 12 layer FT-OP-tensor-time 7.32 ms - - #For int8_mode == 2 - [INFO] Encoder TF v.s. FT with tensor input Cross check False - [INFO] Max diff 6.06640625 - [INFO] min diff 0.0 - [INFO] batch_size 32 max_seq_len 32 12 layer TF-time 13.94 ms - [INFO] batch_size 32 max_seq_len 32 12 layer FT-OP-tensor-time 6.15 ms - ``` - Note: since we do not use the correct scales for quantization in this test, the Cross Check between TF and FT should fail. - 2.6 Run FasterTransformer for GLUE dataset This subsection demonstrates how to integrate the FasterTransformer in TensorFlow, and evaluate the accuracy of FasterTransformer on GLUE dataset. To evaluate on GLUE dataset, it requires the repo of [BERT](https://github.com/google-research/bert). @@ -505,7 +461,7 @@ For those unable to use the NGC container, to set up the required environment or 2.6.2 Download the GLUE MRPC dataset. Note that the file `download_glue_data.py` can only executed under python3. ```bash - wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py + wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/1502038877f6a88c225a34450793fbc3ea87eaba/download_glue_data.py python download_glue_data.py --tasks MRPC ``` @@ -787,19 +743,19 @@ For those unable to use the NGC container, to set up the required environment or 2.7.7 Evaluate the accuracy of FasterTransformer under INT8 - Please refer to the directory `bert-quantization\bert-tf-quantization` first for how to get a quantized model. 
In `section 2.7.7` and `section 2.7.8`, to keep consistent with the procedures described in `bert-quantization\bert-tf-quantization`, we use `https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-12_H-768_A-12.zip` as initial checkpoint with finetuned accuracy of == <89.55, 82.42>. + Please refer to the directory `bert-quantization\bert-tf-quantization` first for how to get a quantized model. In `section 2.7.7` and `section 2.7.8`, to keep consistent with the procedures described in `bert-quantization\bert-tf-quantization`, we use `https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-12_H-768_A-12.zip` as initial checkpoint with finetuned accuracy of == <89.57, 82.44>. - In `bert-tf-quantization`, we give detailed procedure of Post Training Quantization(PTQ), Quantization Aware Training (QAT) and QAT with Knowledge-distillation. Since they have the same inference procedure, we use QAT checkpoint to show how to evaluate the accuracy of FasterTransformer under INT8. + In `bert-tf-quantization`, we give detailed procedure of Post Training Quantization(PTQ), Quantization Aware Training (QAT) and QAT with Knowledge-distillation. Since they have the same inference procedure, we use QAT-KD checkpoint to show how to evaluate the accuracy of FasterTransformer under INT8. - Suppose we already fine-tuned a FP32 checkpoint using QAT with ft_mode == 2 as described in `bert-tf-quantization`. The path to checkpoint is `squad_model/QAT_mode_2/`. + Suppose we already fine-tuned a FP32 checkpoint using QAT-KD with int8_mode == 2 as described in `bert-tf-quantization`. The path to checkpoint is `squad_model/QAT_KD_mode_2/`. We first convert the checkpoint from FP32 to FP16 (this step is not nesseary, but it will give us a better performance) and then quantize the FP16 checkpoint using `tensorflow/tensorflow_bert/ckpt_quantization.py`. This file requires three arguments, the location of initial checkpoint, the location putting the quantized checkpoint and the int8_mode. 
```bash - python tensorflow/tensorflow_bert/ckpt_type_convert.py --init_checkpoint=squad_model/QAT_mode_2/model.ckpt-5474 --fp16_checkpoint=squad_model/QAT_mode_2_fp16/model.ckpt + python tensorflow/tensorflow_bert/ckpt_type_convert.py --init_checkpoint=squad_model/QAT_KD_mode_2/model.ckpt-27374 --fp16_checkpoint=squad_model/QAT_KD_mode_2_fp16/model.ckpt - python tensorflow/tensorflow_bert/ckpt_quantization.py --init_checkpoint=squad_model/QAT_mode_2_fp16/model.ckpt --quantized_checkpoint=squad_model/QAT_mode_2_fp16_quantized/model.ckpt --int8_mode=2 + python tensorflow/tensorflow_bert/ckpt_quantization.py --init_checkpoint=squad_model/QAT_KD_mode_2_fp16/model.ckpt --quantized_checkpoint=squad_model/QAT_KD_mode_2_fp16_quantized/model.ckpt --int8_mode=2 ./bin/encoder_gemm 8 384 12 64 1 1 python tensorflow/tensorflow_bert/run_squad_wrap.py \ @@ -807,7 +763,7 @@ For those unable to use the NGC container, to set up the required environment or --predict_batch_size=8 \ --vocab_file=squad_model/vocab.txt \ --bert_config_file=squad_model/bert_config.json \ - --init_checkpoint=squad_model/QAT_mode_2_fp16_quantized/model.ckpt \ + --init_checkpoint=squad_model/QAT_KD_mode_2_fp16_quantized/model.ckpt \ --train_file=squad_data/train-v1.1.json \ --do_predict=True \ --predict_file=squad_data/dev-v1.1.json \ @@ -822,7 +778,7 @@ For those unable to use the NGC container, to set up the required environment or The results of TensorFlow would be like: ```bash - {"exact_match": 81.95837275307474, "f1": 89.2021841747768} + {"exact_match": 83.85052034058657, "f1": 90.46351799300075} ``` 2.7.8 Evaluate the accuracy of Effective FasterTransformer under INT8 @@ -836,7 +792,7 @@ For those unable to use the NGC container, to set up the required environment or --predict_batch_size=8 \ --vocab_file=squad_model/vocab.txt \ --bert_config_file=squad_model/bert_config.json \ - --init_checkpoint=squad_model/QAT_mode_2_fp16_quantized/model.ckpt \ + --init_checkpoint=squad_model/QAT_KD_mode_2_fp16_quantized/model.ckpt \ --train_file=squad_data/train-v1.1.json \ --do_predict=True \ --predict_file=squad_data/dev-v1.1.json \ @@ -852,7 +808,7 @@ For those unable to use the NGC container, to set up the required environment or The results of TensorFlow would be like: ```bash - {"exact_match": 81.95837275307474, "f1": 89.2021841747768} + {"exact_match": 83.85052034058657, "f1": 90.46351799300075} ``` 2.7.9 Compare the speed of BERT of TensorFlow and FasterTransformer under both FP32 and FP16. @@ -908,11 +864,11 @@ For those unable to use the NGC container, to set up the required environment or 3.2 Run the PyTorch encoder sample: ```bash - python pytorch/encoder_sample.py <--fp16> <--int8_mode 0/1/2> <--time> <--ths> <--remove_padding> <--allow_gemm_test> + python pytorch/encoder_sample.py <--fp16> <--int8_mode 0/1/2> <--time> <--remove_padding> <--allow_gemm_test> python pytorch/encoder_sample.py 1 12 32 12 64 --fp16 --time ``` - Remove `--fp16` for fp32 mode. `--int8_mode 1` or `--int8_mode 2` will use int8_mode 1 or 2 in FasterTransformer. `--ths` will run on TorchScript mode. `--remove_padding` will remove the padding of sentence and this brings speedups when the average of sequence length is smaller than the maximum sequence length. `--allow_gemm_test` will enable gemm test config during forward pass. + Remove `--fp16` for fp32 mode. `--int8_mode 1` or `--int8_mode 2` will use int8_mode 1 or 2 in FasterTransformer. 
`--remove_padding` will remove the padding of sentence and this brings speedups when the average of sequence length is smaller than the maximum sequence length. `--allow_gemm_test` will enable gemm test config during forward pass. The outputs should be like to the following: @@ -934,7 +890,6 @@ For those unable to use the NGC container, to set up the required environment or ``` the `` can be: - `ori`: original HuggingFace's BERT encoder - - `ext`: our PyTorch eager extension - `ths`: original HuggingFace's BERT encoder in TorchScript mode - `thsext`: our TorchScript custom class @@ -987,7 +942,7 @@ For those unable to use the NGC container, to set up the required environment or --cache_dir pytorch/bert_squad/models/ \ --max_seq_length 384 \ --per_gpu_eval_batch_size 8 \ - --model_type ext \ # or thsext, quantized model cannot run on original PyTorch codes (ori or ths) + --model_type thsext \ # quantized model cannot run on original PyTorch codes (ori or ths) --data_type fp16 \ --int8_mode 2 \ # or 1, should match with the quantized checkpoint --allow_gemm_test @@ -996,6 +951,7 @@ For those unable to use the NGC container, to set up the required environment or ## Performance Hardware settings: +* A100 (with mclk 1593MHz, pclk 1410MHz) with AMD EPYC 7742 64-Core Processor * T4 (with mclk 5000MHz, pclk 1590MHz) with Intel(R) Xeon(R) CPU E5-2670 0 @ 2.60GHz * V100 (with mclk 877MHz, pclk 1380MHz) with Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz (dgx-1 server) @@ -1009,9 +965,9 @@ apt-get install bc ### Encoder performance -We demonstrate the inference time of FasterTransformer in C++, TensorFlow and PyTorch, and compare to the performance on T4 and V100. +We demonstrate the inference time of FasterTransformer in C++, TensorFlow and PyTorch, and compare to the performance on A100, T4 and V100. -For the benchmark of TensorFlow, we compare the performance of TensorFlow with XLA (TF), the performance of TensorFlow with FasterTransformer OP (FT-OP) and the performance of FasterTransformer on C++ (TF-CPP), and show the speedup of FT-OP and FT-CPP compare to the TensorFlow. +For the benchmark of TensorFlow, we compare the performance of TensorFlow with XLA (TF), the performance of TensorFlow with FasterTransformer OP (FT) and the performance of TensorFlow with Effctive FasterTransformer (EFF-FT), and show the speedup of FT and EFF-FT compare to the TensorFlow. Compare to v3.1, we modify the profiling method to hide the overhead of seession run and memory copy. Thus, both TensorFlow and FasterTransformer are faster than v3.1 in TensorFlow benchmark on small batch size and sequence length. Because this new method has no obvious overhead compare to the FasterTransformer on C++, we skip the comparison with the C++ implementation. For the benchmark of PyTorch, we compare the performance of TorchScript and the performance of PyTorch with FasterTransformer custom extension (CustomExt), and show the speedup of CustomExt compare to the TorchScript. For standard PyTorch implementation, we find that its performance is smaller or close to TorchScript, so we only demonstrate the performance of TorchScript. Besides, because CustomExt has no obvious overhead compare to the FasterTransformer on C++, we skip the comparison with the C++ implementation. 
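For readers who want to reproduce a single entry of the PyTorch tables that follow, the encoder sample from section 3.2 can be timed directly. A sketch, assuming the positional arguments follow the earlier example invocation (batch size, layer number, sequence length, head number, size per head):

```bash
# Hypothetical spot check for the (batch_size=32, seq_len=128) row:
# TorchScript vs. FasterTransformer custom extension timing.
python pytorch/encoder_sample.py 32 12 128 12 64 --fp16 --time

# Add --remove_padding to measure the EFF-FT variant on inputs whose
# average length is shorter than the maximum sequence length.
python pytorch/encoder_sample.py 32 12 128 12 64 --fp16 --time --remove_padding
```

The TensorFlow numbers are collected in the same spirit with `tensorflow/encoder_sample.py`, and the TF32 results on A100 additionally rely on the `NVIDIA_TF32_OVERRIDE` environment variable mentioned below.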
@@ -1029,402 +985,370 @@ In the experiments of encoder, we updated the following parameters: * size_per_head = 64 * num_layers = 12 -#### Encoder performance on T4 and cpp - -| | FT-FP16-cpp (ms) | FT-EFF-cpp (ms) | FT-int8v2-cpp (ms) | FT-EFF-int8v2-cpp (ms) | FT-EFF-cpp Speedup | FT-int8v2-cpp Speedup | FT-EFF-int8v2-cpp Speedup | -|:---------------------:|:----------------:|:---------------:|:------------------:|:----------------------:|:------------------:|:---------------------:|:-------------------------:| -| <1, 32> | 1.45 | 2.92 | 1.66 | 1.74 | 0.50 | 0.87 | 0.83 | -| <1, 64> | 1.53 | 2.95 | 1.66 | 1.74 | 0.52 | 0.92 | 0.88 | -| <1, 128> | 1.92 | 3.41 | 1.75 | 1.76 | 0.56 | 1.10 | 1.10 | -| <8, 32> | 2.61 | 3.28 | 2.13 | 1.94 | 0.80 | 1.23 | 1.35 | -| <8, 64> | 4.38 | 3.34 | 2.94 | 2.25 | 1.31 | 1.49 | 1.95 | -| <8, 128> | 8.55 | 5.09 | 5.33 | 3.24 | 1.68 | 1.60 | 2.64 | -| <32, 32> | 8.53 | 5.07 | 5.58 | 3.48 | 1.68 | 1.53 | 2.45 | -| <32, 64> | 17.84 | 9.14 | 10.13 | 5.98 | 1.95 | 1.76 | 2.98 | -| <32, 128> | 35.14 | 18.05 | 20.01 | 11.28 | 1.95 | 1.76 | 3.12 | -| <64, 32> | 18.4 | 9.3 | 10.58 | 6.44 | 1.98 | 1.74 | 2.86 | -| <64, 64> | 34.19 | 17.98 | 19.64 | 11.15 | 1.90 | 1.74 | 3.07 | -| <64, 128> | 68.28 | 34.49 | 38.92 | 22.27 | 1.98 | 1.75 | 3.07 | -| <128, 32> | 35.41 | 18.19 | 20.52 | 12.07 | 1.95 | 1.73 | 2.93 | -| <128, 64> | 67.73 | 34.09 | 37.74 | 21.78 | 1.99 | 1.79 | 3.11 | -| <128, 128> | 139.72 | 66.88 | 80.12 | 49.95 | 2.09 | 1.74 | 2.80 | +#### Encoder performance on A100 and TensorFlow -#### Encoder performance on T4 and TensorFlow +* Performance on TF32 -* Performance on FP32 +User can use `export NVIDIA_TF32_OVERRIDE=0` to enforce the program run under FP32. + +| Batch_size | Seq_len | Precision | TF
Latency (ms) | FT
Latency (ms) | EFF-FT
Latency (ms) | FT
Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:---------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | FP32 | 2.57 | 1.87 | 1.97 | 1.37 | 1.30 | +| 1 | 128 | FP32 | 5.37 | 4.70 | 2.55 | 1.14 | 2.10 | +| 1 | 384 | FP32 | 7.39 | 6.61 | 9.03 | 1.11 | 0.81 | +| 8 | 32 | FP32 | 5.26 | 4.59 | 4.65 | 1.14 | 1.13 | +| 8 | 128 | FP32 | 13.29 | 12.54 | 7.03 | 1.05 | 1.89 | +| 8 | 384 | FP32 | 38.07 | 36.66 | 22.17 | 1.03 | 1.71 | +| 32 | 32 | FP32 | 13.78 | 13.24 | 7.69 | 1.04 | 1.79 | +| 32 | 128 | FP32 | 45.90 | 45.02 | 24.63 | 1.01 | 1.86 | +| 32 | 384 | FP32 | 150.26 | 143.41 | 84.28 | 1.04 | 1.78 | -| | TF (ms) | FT-OP (ms) | FT-CPP (ms) | FT-OP Speedup | FT-CPP Speedup | -|:---------------------:|:-------:|:----------:|:-----------:|:-------------:|:--------------:| -| <1, 32> | 8.09 | 4.70 | 2.57 | 1.72 | 3.14 | -| <1, 64> | 8.47 | 5.75 | 3.71 | 1.47 | 2.28 | -| <1, 128> | 9.78 | 7.50 | 6.17 | 1.30 | 1.58 | -| <8, 32> | 14.45 | 12.56 | 11.70 | 1.15 | 1.23 | -| <8, 64> | 24.99 | 23.17 | 22.22 | 1.07 | 1.12 | -| <8, 128> | 49.23 | 48.09 | 46.99 | 1.02 | 1.04 | -| <32, 32> | 47.23 | 46.44 | 45.58 | 1.01 | 1.03 | -| <32, 64> | 94.47 | 87.76 | 86.57 | 1.07 | 1.09 | -| <32, 128> | 195.23 | 177.79 | 174.46 | 1.09 | 1.11 | -| <64, 32> | 94.07 | 86.56 | 85.25 | 1.08 | 1.10 | -| <64, 64> | 191.65 | 173.67 | 170.94 | 1.10 | 1.12 | -| <64, 128> | 393.74 | 352.82 | 341.39 | 1.11 | 1.15 | -| <128, 32> | 190.89 | 170.55 | 167.32 | 1.11 | 1.14 | -| <128, 64> | 384.93 | 344.23 | 332.85 | 1.11 | 1.15 | -| <128, 128> | 817.94 | 729.46 | 711.44 | 1.12 | 1.14 | +* Performance on TF32 (Note that the absolute differences of results under TF32 are larger than FP32.) + +User can use `export NVIDIA_TF32_OVERRIDE=1` to enforce the program run under TF32. + +| Batch_size | Seq_len | Precision | TF
Latency (ms) | FT
Latency (ms) | EFF-FT
Latency (ms) | FT
Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:---------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | TF32 | 2.65 | 1.70 | 1.76 | 1.55 | 1.50 | +| 1 | 128 | TF32 | 2.58 | 1.74 | 1.83 | 1.48 | 1.40 | +| 1 | 384 | TF32 | 3.48 | 2.69 | 2.22 | 1.29 | 1.56 | +| 8 | 32 | TF32 | 2.84 | 2.11 | 1.90 | 1.34 | 1.49 | +| 8 | 128 | TF32 | 4.88 | 3.99 | 2.74 | 1.22 | 1.78 | +| 8 | 384 | TF32 | 12.16 | 12.16 | 7.54 | 1.00 | 1.61 | +| 32 | 32 | TF32 | 4.82 | 4.11 | 2.89 | 1.17 | 1.66 | +| 32 | 128 | TF32 | 12.97 | 11.77 | 7.03 | 1.10 | 1.84 | +| 32 | 384 | TF32 | 46.86 | 45.79 | 25.87 | 1.02 | 1.81 | * Performance on FP16 -| | TF (ms) | FT-OP (ms) | FT-CPP (ms) | FT-OP Speedup | FT-CPP Speedup | -|:---------------------:|:-------:|:----------:|:-----------:|:-------------:|:--------------:| -| <1, 32> | 9.70 | 4.24 | 1.45 | 2.28 | 6.68 | -| <1, 64> | 9.08 | 4.83 | 1.53 | 1.87 | 5.93 | -| <1, 128> | 6.56 | 3.91 | 1.92 | 1.67 | 3.41 | -| <8, 32> | 8.22 | 6.00 | 2.61 | 1.37 | 3.14 | -| <8, 64> | 9.55 | 5.94 | 4.38 | 1.60 | 2.18 | -| <8, 128> | 14.76 | 10.04 | 8.55 | 1.47 | 1.72 | -| <32, 32> | 14.53 | 10.13 | 8.53 | 1.43 | 1.70 | -| <32, 64> | 27.73 | 18.84 | 17.84 | 1.47 | 1.55 | -| <32, 128> | 53.68 | 36.33 | 35.14 | 1.47 | 1.52 | -| <64, 32> | 27.85 | 19.53 | 18.40 | 1.42 | 1.51 | -| <64, 64> | 52.34 | 35.67 | 34.19 | 1.46 | 1.53 | -| <64, 128> | 104.39 | 70.87 | 68.28 | 1.47 | 1.52 | -| <128, 32> | 52.92 | 36.55 | 35.41 | 1.44 | 1.49 | -| <128, 64> | 101.59 | 70.16 | 67.73 | 1.44 | 1.49 | -| <128, 128> | 209.16 | 143.13 | 139.72 | 1.46 | 1.49 | - -* Performance on INT8 - -| | TF FP16 (ms) | FT-int8v1-op (ms) | FT-int8v2-op (ms) | FT-int8v1-op Speedup | FT-int8v2-op Speedup | -|----------------------:|:------------:|:-----------------:|:-----------------:|:--------------------:|:--------------------:| -| <1, 32> | 8.95 | 5.23 | 5.70 | 1.71 | 1.57 | -| <1, 64> | 9.54 | 5.01 | 4.98 | 1.90 | 1.92 | -| <1, 128> | 9.27 | 5.85 | 5.31 | 1.58 | 1.75 | -| <8, 32> | 9.04 | 5.86 | 5.77 | 1.54 | 1.57 | -| <8, 64> | 8.87 | 6.42 | 5.84 | 1.38 | 1.52 | -| <8, 128> | 14.74 | 10.78 | 7.45 | 1.37 | 1.98 | -| <32, 32> | 14.01 | 9.85 | 7.54 | 1.42 | 1.86 | -| <32, 64> | 27.13 | 17.84 | 10.95 | 1.52 | 2.48 | -| <32, 128> | 53.44 | 33.62 | 19.78 | 1.59 | 2.70 | -| <64, 32> | 26.46 | 17.21 | 11.42 | 1.54 | 2.32 | -| <64, 64> | 51.2 | 31.64 | 19.37 | 1.62 | 2.64 | -| <64, 128> | 103.66 | 65.08 | 36.57 | 1.59 | 2.83 | -| <128, 32> | 50.91 | 31.27 | 20.17 | 1.63 | 2.52 | -| <128, 64> | 100.14 | 60.55 | 35.77 | 1.65 | 2.80 | -| <128, 128> | 206.65 | 128.28 | 73.08 | 1.61 | 2.83 | +| Batch_size | Seq_len | Precision | TF
Latency (ms) | FT
Latency (ms) | EFF-FT
Latency (ms) | FT
Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:---------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | FP16 | 2.08 | 0.90 | 0.96 | 2.31 | 2.16 | +| 1 | 128 | FP16 | 2.21 | 1.00 | 1.02 | 2.21 | 2.16 | +| 1 | 384 | FP16 | 2.50 | 1.33 | 1.20 | 1.87 | 2.08 | +| 8 | 32 | FP16 | 2.08 | 1.04 | 1.04 | 2.00 | 2.00 | +| 8 | 128 | FP16 | 3.32 | 1.77 | 1.43 | 1.87 | 2.32 | +| 8 | 384 | FP16 | 8.55 | 4.39 | 2.57 | 1.94 | 3.32 | +| 32 | 32 | FP16 | 3.28 | 1.77 | 1.40 | 1.85 | 2.34 | +| 32 | 128 | FP16 | 8.38 | 4.81 | 2.75 | 1.74 | 3.04 | +| 32 | 384 | FP16 | 27.80 | 14.67 | 7.71 | 1.89 | 3.60 | + +* Performance on INT8-v1 + +| Batch_size | Seq_len | TF-FP16
Latency (ms) | FT-INT8-v1
Latency (ms) | EFF-FT-INT8-v1
Latency (ms) | FT-INT8-v1
Speedup | EFF-FT-INT8-v1
Speedup | +|:----------:|:-------:|:--------------------------:|:----------------------------------------:|:--------------------------------------------:|:-----------------------------------:|:---------------------------------------:| +| 1 | 32 | 2.09 | 1.40 | 1.51 | 1.49 | 1.38 | +| 1 | 128 | 2.17 | 1.74 | 1.70 | 1.24 | 1.27 | +| 1 | 384 | 2.49 | 1.99 | 2.04 | 1.25 | 1.22 | +| 8 | 32 | 2.10 | 1.77 | 1.87 | 1.18 | 1.12 | +| 8 | 128 | 3.31 | 2.34 | 2.01 | 1.41 | 1.64 | +| 8 | 384 | 8.57 | 4.75 | 3.17 | 1.80 | 2.70 | +| 32 | 32 | 3.29 | 2.59 | 2.09 | 1.27 | 1.57 | +| 32 | 128 | 8.32 | 5.53 | 3.36 | 1.50 | 2.47 | +| 32 | 384 | 27.81 | 14.80 | 8.34 | 1.87 | 3.33 | + +* Performance on INT8-v2 + +Note that the INT8-v2 leads to larger overhead because some optimization requires the weights of Q, K and V are continous. Using c API or preprocess the model can bring better performance. + +| Batch_size | Seq_len | TF-FP16
Latency (ms) | FT-INT8-v2
Latency (ms) | EFF-FT-INT8-v2
Latency (ms) | FT-INT8-v2
Speedup | EFF-FT-INT8-v2
Speedup | +|:----------:|:-------:|:--------------------------:|:----------------------------------------:|:--------------------------------------------:|:-----------------------------------:|:---------------------------------------:| +| 1 | 32 | 2.06 | 1.53 | 1.53 | 1.34 | 1.34 | +| 1 | 128 | 2.18 | 1.57 | 1.57 | 1.38 | 1.38 | +| 1 | 384 | 2.44 | 1.70 | 1.62 | 1.43 | 1.50 | +| 8 | 32 | 2.07 | 1.63 | 1.62 | 1.26 | 1.27 | +| 8 | 128 | 3.33 | 1.93 | 1.73 | 1.72 | 1.92 | +| 8 | 384 | 8.52 | 3.40 | 2.50 | 2.50 | 3.40 | +| 32 | 32 | 3.33 | 2.18 | 1.85 | 1.52 | 1.80 | +| 32 | 128 | 8.32 | 3.85 | 2.48 | 2.16 | 3.35 | +| 32 | 384 | 27.80 | 10.07 | 5.63 | 2.76 | 4.93 | -#### Encoder performance on V100 and TensorFlow +#### Encoder performance on T4 and TensorFlow * Performance on FP32 -| | TF (ms) | FT-OP (ms) | FT-CPP (ms) | FT-OP Speedup | FT-CPP Speedup | -|:---------------------:|:-------:|:----------:|:-----------:|:-------------:|:--------------:| -| <1, 32> | 3.78 | 2.99 | 1.76 | 1.26 | 2.14 | -| <1, 64> | 4.55 | 3.29 | 2.16 | 1.38 | 2.10 | -| <1, 128> | 5.23 | 4.15 | 2.94 | 1.26 | 1.77 | -| <8, 32> | 7.42 | 6.14 | 4.66 | 1.20 | 1.59 | -| <8, 64> | 10.80 | 9.98 | 8.48 | 1.08 | 1.27 | -| <8, 128> | 18.73 | 17.63 | 15.50 | 1.06 | 1.20 | -| <32, 32> | 18.16 | 16.97 | 15.34 | 1.07 | 1.18 | -| <32, 64> | 33.87 | 32.69 | 30.01 | 1.03 | 1.12 | -| <32, 128> | 66.11 | 64.31 | 59.46 | 1.02 | 1.11 | -| <64, 32> | 34.17 | 32.56 | 29.91 | 1.04 | 1.14 | -| <64, 64> | 66.21 | 63.51 | 58.84 | 1.04 | 1.12 | -| <64, 128> | 133.61 | 126.58 | 119.08 | 1.05 | 1.12 | -| <128, 32> | 65.36 | 62.72 | 58.22 | 1.04 | 1.12 | -| <128, 64> | 131.12 | 123.94 | 117.80 | 1.05 | 1.11 | -| <128, 128> | 253.90 | 251.03 | 234.30 | 1.01 | 1.08 | +| Batch_size | Seq_len | Precision | TF
Latency (ms) | FT
Latency (ms) | EFF-FT
Latency (ms) | FT
Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:---------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | FP32 | 3.63 | 2.51 | 2.96 | 1.44 | 1.22 | +| 1 | 128 | FP32 | 7.10 | 6.11 | 4.78 | 1.16 | 1.48 | +| 1 | 384 | FP32 | 19.69 | 19.01 | 11.16 | 1.03 | 1.76 | +| 8 | 32 | FP32 | 12.85 | 11.44 | 6.75 | 1.12 | 1.90 | +| 8 | 128 | FP32 | 49.19 | 42.88 | 23.48 | 1.14 | 2.09 | +| 8 | 384 | FP32 | 154.09 | 140.36 | 84.84 | 1.09 | 1.81 | +| 32 | 32 | FP32 | 45.37 | 41.66 | 22.33 | 1.08 | 2.03 | +| 32 | 128 | FP32 | 183.49 | 162.62 | 93.41 | 1.12 | 1.96 | +| 32 | 384 | FP32 | 602.69 | 552.46 | 331.55 | 1.09 | 1.81 | * Performance on FP16 -| | TF (ms) | FT-OP (ms) | FT-CPP (ms) | FT-OP Speedup | FT-CPP Speedup | -|:---------------------:|:-------:|:----------:|:-----------:|:-------------:|:--------------:| -| <1, 32> | 3.44 | 3.05 | 1.24 | 1.12 | 2.77 | -| <1, 64> | 4.96 | 2.88 | 1.45 | 1.72 | 3.42 | -| <1, 128> | 3.59 | 2.79 | 1.57 | 1.28 | 2.28 | -| <8, 32> | 3.94 | 3.00 | 1.80 | 1.31 | 2.18 | -| <8, 64> | 5.12 | 3.86 | 2.45 | 1.32 | 2.08 | -| <8, 128> | 7.16 | 5.21 | 3.79 | 1.37 | 1.88 | -| <32, 32> | 7.27 | 5.25 | 3.60 | 1.38 | 2.01 | -| <32, 64> | 11.26 | 8.47 | 6.61 | 1.32 | 1.70 | -| <32, 128> | 20.62 | 15.52 | 12.52 | 1.32 | 1.64 | -| <64, 32> | 11.31 | 8.57 | 6.59 | 1.31 | 1.71 | -| <64, 64> | 19.94 | 15.63 | 12.22 | 1.27 | 1.63 | -| <64, 128> | 36.25 | 28.86 | 23.73 | 1.25 | 1.52 | -| <128, 32> | 20.15 | 15.27 | 12.24 | 1.31 | 1.64 | -| <128, 64> | 35.67 | 28.73 | 23.40 | 1.24 | 1.52 | -| <128, 128> | 68.84 | 54.53 | 46.11 | 1.26 | 1.49 | +| Batch_size | Seq_len | Precision | TF
Latency (ms) | FT
Latency (ms) | EFF-FT
Latency (ms) | FT
Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:---------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | FP16 | 5.84 | 1.92 | 2.33 | 3.04 | 2.50 | +| 1 | 128 | FP16 | 3.30 | 2.57 | 2.47 | 1.28 | 1.33 | +| 1 | 384 | FP16 | 5.91 | 3.85 | 3.34 | 1.53 | 1.76 | +| 8 | 32 | FP16 | 4.48 | 2.69 | 2.76 | 1.66 | 1.62 | +| 8 | 128 | FP16 | 12.34 | 8.31 | 4.89 | 1.48 | 2.52 | +| 8 | 384 | FP16 | 39.88 | 24.37 | 12.47 | 1.63 | 3.19 | +| 32 | 32 | FP16 | 12.21 | 8.40 | 4.89 | 1.45 | 2.49 | +| 32 | 128 | FP16 | 45.50 | 32.32 | 17.04 | 1.40 | 2.67 | +| 32 | 384 | FP16 | 159.93 | 97.66 | 48.60 | 1.63 | 3.29 | + +* Performance on INT8-v1 + +| Batch_size | Seq_len | TF-FP16
Latency (ms) | FT-INT8-v1
Latency (ms) | EFF-FT-INT8-v1
Latency (ms) | FT-INT8-v1
Speedup | EFF-FT-INT8-v1
Speedup | +|:----------:|:-------:|:--------------------------:|:----------------------------------------:|:--------------------------------------------:|:-----------------------------------:|:---------------------------------------:| +| 1 | 32 | 4.44 | 1.98 | 2.50 | 2.24 | 1.77 | +| 1 | 128 | 3.48 | 2.93 | 3.03 | 1.18 | 1.14 | +| 1 | 384 | 5.94 | 3.68 | 3.24 | 1.61 | 1.83 | +| 8 | 32 | 4.51 | 2.71 | 3.04 | 1.66 | 1.48 | +| 8 | 128 | 12.10 | 6.63 | 3.93 | 1.82 | 3.07 | +| 8 | 384 | 39.64 | 18.10 | 9.68 | 2.19 | 4.09 | +| 32 | 32 | 12.30 | 7.54 | 4.39 | 1.63 | 2.80 | +| 32 | 128 | 45.26 | 22.97 | 12.45 | 1.97 | 3.63 | +| 32 | 384 | 173.49 | 68.89 | 39.49 | 2.51 | 4.39 | + +* Performance on INT8-v2 + +| Batch_size | Seq_len | TF-FP16
Latency (ms) | FT-INT8-v2
Latency (ms) | EFF-FT-INT8-v2
Latency (ms) | FT-INT8-v2
Speedup | EFF-FT-INT8-v2
Speedup | +|:----------:|:-------:|:--------------------------:|:----------------------------------------:|:--------------------------------------------:|:-----------------------------------:|:---------------------------------------:| +| 1 | 32 | 4.57 | 3.03 | 2.96 | 1.50 | 1.54 | +| 1 | 128 | 3.75 | 2.41 | 2.20 | 1.55 | 1.70 | +| 1 | 384 | 5.88 | 2.90 | 3.17 | 2.02 | 1.85 | +| 8 | 32 | 4.10 | 2.87 | 2.91 | 1.42 | 1.40 | +| 8 | 128 | 12.27 | 5.13 | 3.38 | 2.39 | 3.63 | +| 8 | 384 | 39.50 | 13.47 | 7.14 | 2.93 | 5.53 | +| 32 | 32 | 12.09 | 5.90 | 3.51 | 2.04 | 3.44 | +| 32 | 128 | 45.23 | 17.19 | 9.09 | 2.63 | 4.97 | +| 32 | 384 | 158.10 | 52.41 | 28.65 | 3.01 | 5.51 | -#### Encoder performance on T4 and PyTorch +#### Encoder performance on V100 and TensorFlow * Performance on FP32 -| | TorchScript (ms) | CustomExt (ms) | CustomExt Speedup | -|:---------------------:|:----------------:|:--------------:|:-----------------:| -| <1, 32> | 12.16 | 2.71 | 4.48 | -| <1, 64> | 12.45 | 3.98 | 3.12 | -| <1, 128> | 12.46 | 6.37 | 1.95 | -| <8, 32> | 13.50 | 11.34 | 1.19 | -| <8, 64> | 24.91 | 21.14 | 1.17 | -| <8, 128> | 52.93 | 42.62 | 1.24 | -| <32, 32> | 51.04 | 41.27 | 1.23 | -| <32, 64> | 93.57 | 82.27 | 1.13 | -| <32, 128> | 197.38 | 164.14 | 1.20 | -| <64, 32> | 93.19 | 81.00 | 1.15 | -| <64, 64> | 191.26 | 162.27 | 1.17 | -| <64, 128> | 394.24 | 333.26 | 1.18 | -| <128, 32> | 191.31 | 158.40 | 1.20 | -| <128, 64> | 383.33 | 325.44 | 1.17 | -| <128, 128> | 799.05 | 670.33 | 1.19 | +| Batch_size | Seq_len | Precision | TF
Latency (ms) | FT
Latency (ms) | EFF-FT
Latency (ms) | FT
Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:---------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | FP32 | 2.84 | 1.56 | 1.81 | 1.82 | 1.56 | +| 1 | 128 | FP32 | 3.97 | 2.90 | 2.27 | 1.36 | 1.74 | +| 1 | 384 | FP32 | 8.32 | 6.72 | 5.02 | 1.23 | 1.65 | +| 8 | 32 | FP32 | 5.71 | 4.65 | 3.33 | 1.22 | 1.71 | +| 8 | 128 | FP32 | 16.98 | 15.51 | 9.76 | 1.09 | 1.73 | +| 8 | 384 | FP32 | 53.46 | 48.70 | 28.52 | 1.09 | 1.87 | +| 32 | 32 | FP32 | 16.65 | 15.24 | 9.48 | 1.09 | 1.75 | +| 32 | 128 | FP32 | 62.98 | 57.85 | 32.53 | 1.08 | 1.93 | +| 32 | 384 | FP32 | 208.40 | 188.17 | 109.95 | 1.10 | 1.89 | * Performance on FP16 -| | TorchScript (ms) | CustomExt (ms) | CustomExt Speedup | -|:---------------------:|:----------------:|:--------------:|:-----------------:| -| <1, 32> | 11.59 | 1.84 | 6.29 | -| <1, 64> | 12.15 | 1.91 | 6.36 | -| <1, 128> | 12.47 | 2.26 | 5.51 | -| <8, 32> | 11.17 | 2.90 | 3.85 | -| <8, 64> | 10.74 | 4.61 | 2.32 | -| <8, 128> | 13.51 | 8.65 | 1.56 | -| <32, 32> | 13.03 | 8.70 | 1.49 | -| <32, 64> | 25.71 | 16.81 | 1.52 | -| <32, 128> | 53.81 | 32.88 | 1.63 | -| <64, 32> | 25.67 | 17.24 | 1.48 | -| <64, 64> | 51.31 | 32.36 | 1.58 | -| <64, 128> | 106.72 | 63.85 | 1.67 | -| <128, 32> | 51.73 | 33.42 | 1.54 | -| <128, 64> | 101.73 | 62.72 | 1.62 | -| <128, 128> | 209.20 | 130.55 | 1.60 | - -* Performance on INT8 - -| | TorchScript FP16 (ms) | FT-INT8v1-OP (ms) | FT-INT8v2-OP (ms) | FT-INT8v1-OP Speedup | FT-INT8v2-OP Speedup | -|:---------------------:|:---------------------:|:-----------------:|:-----------------:|:--------------------:|:--------------------:| -| <1, 32> | 11.66 | 1.99 | 2.23 | 5.86 | 5.23 | -| <1, 64> | 12.2 | 1.96 | 2.18 | 6.22 | 5.60 | -| <1, 128> | 11.27 | 2.16 | 2.31 | 5.22 | 4.88 | -| <8, 32> | 11.09 | 2.83 | 2.59 | 3.92 | 4.28 | -| <8, 64> | 10.79 | 4.43 | 3.31 | 2.44 | 3.26 | -| <8, 128> | 13.5 | 8.82 | 5.67 | 1.53 | 2.38 | -| <32, 32> | 12.85 | 8.49 | 6 | 1.51 | 2.14 | -| <32, 64> | 25.57 | 16.07 | 10.18 | 1.59 | 2.51 | -| <32, 128> | 53.58 | 32.76 | 19.72 | 1.64 | 2.72 | -| <64, 32> | 25.34 | 16.25 | 10.71 | 1.56 | 2.37 | -| <64, 64> | 51.28 | 30.78 | 19.74 | 1.67 | 2.60 | -| <64, 128> | 104.8 | 64.53 | 39.04 | 1.62 | 2.68 | -| <128, 32> | 50.8 | 31.03 | 20.66 | 1.64 | 2.46 | -| <128, 64> | 99.47 | 60.09 | 38.77 | 1.66 | 2.57 | -| <128, 128> | 210.27 | 128.63 | 79.11 | 1.63 | 2.66 | +| Batch_size | Seq_len | Precision | TF
Latency (ms) | FT
Latency (ms) | EFF-FT
Latency (ms) | FT
Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:---------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | FP16 | 2.96 | 1.13 | 1.58 | 2.61 | 1.87 | +| 1 | 128 | FP16 | 2.49 | 1.43 | 1.49 | 1.74 | 1.67 | +| 1 | 384 | FP16 | 3.61 | 2.12 | 2.04 | 1.70 | 1.76 | +| 8 | 32 | FP16 | 2.86 | 1.74 | 1.63 | 1.64 | 1.75 | +| 8 | 128 | FP16 | 5.58 | 3.59 | 2.99 | 1.55 | 1.86 | +| 8 | 384 | FP16 | 14.59 | 10.60 | 7.31 | 1.37 | 1.99 | +| 32 | 32 | FP16 | 5.56 | 3.55 | 3.00 | 1.56 | 1.85 | +| 32 | 128 | FP16 | 17.24 | 12.41 | 8.05 | 1.38 | 2.14 | +| 32 | 384 | FP16 | 53.24 | 40.20 | 26.55 | 1.32 | 2.00 | + +#### Encoder performance comparison between T4, V100, A100 and A100 with MIG mode on TensorFlow + +* Performance of EFF-FT on FP16 + +| Batch Size | Sequence
Length"| Precision | T4 Latency | T4
Throughput | V100 Latency | V100
Throughput | A100 Latency | A100
Throughput | A100 MIG
Latency | A100 MIG
Throughput | +|:----------:|:---------------------:|:---------:|:----------:|:-------------------:|:------------:|:---------------------:|:------------:|:---------------------:|:---------------------:|:------------------:| +| 1 | 32 | FP16 | 1.84 | 543 | 1.58 | 632 | 0.88 | 1136 | 1.42 | 4929 | +| 1 | 128 | FP16 | 1.88 | 531 | 1.49 | 671 | 0.91 | 1098 | 1.54 | 4545 | +| 1 | 384 | FP16 | 2.99 | 334 | 2.04 | 490 | 1.12 | 892 | 3.38 | 2071 | +| 8 | 32 | FP16 | 2.1 | 3809 | 1.63 | 4907 | 0.96 | 8333 | 2 | 28000 | +| 8 | 128 | FP16 | 4.93 | 1622 | 2.99 | 2675 | 1.38 | 5797 | 5.23 | 10707 | +| 8 | 384 | FP16 | 13.47 | 593 | 7.31 | 1094 | 2.64 | 3030 | 13.87 | 4037 | +| 32 | 32 | FP16 | 4.89 | 6543 | 3 | 10666 | 1.34 | 23880 | 5.26 | 42585 | +| 32 | 128 | FP16 | 17.32 | 1847 | 8.05 | 3975 | 3.01 | 10631 | 16.41 | 13650 | +| 32 | 384 | FP16 | 52.75 | 606 | 26.55 | 1205 | 8.98 | 3563 | 53.2 | 4210 | + +* Performance of EFF-FT on INT8-v2 + +| Batch Size | Sequence
Length"| Precision | T4 Latency | T4
Throughput | V100 Latency | V100
Throughput | A100 Latency | A100
Throughput | A100 MIG
Latency | A100 MIG
Throughput | +|:----------:|:---------------------:|:---------:|:----------:|:-------------------:|:------------:|:---------------------:|:------------:|:---------------------:|:-----------------:|:------------------:| +| 1 | 32 | INT8-v2 | 1.87 | 534 | x | x | 1.19 | 840 | 1.35 | 5185 | +| 1 | 128 | INT8-v2 | 1.88 | 531 | x | x | 1.2 | 833 | 1.46 | 4794 | +| 1 | 384 | INT8-v2 | 2.4 | 416 | x | x | 1.31 | 763 | 2.23 | 3139 | +| 8 | 32 | INT8-v2 | 2.37 | 3375 | x | x | 1.23 | 6504 | 1.87 | 29946 | +| 8 | 128 | INT8-v2 | 3.07 | 2605 | x | x | 1.32 | 6060 | 3.43 | 16326 | +| 8 | 384 | INT8-v2 | 7.42 | 1078 | x | x | 2.24 | 3571 | 10.94 | 5118 | +| 32 | 32 | INT8-v2 | 3.41 | 9384 | x | x | 1.51 | 21192 | 4.26 | 52582 | +| 32 | 128 | INT8-v2 | 9.53 | 3357 | x | x | 2.28 | 14035 | 11.57 | 19360 | +| 32 | 384 | INT8-v2 | 30.53 | 1048 | x | x | 5.71 | 5604 | 37.03 | 6049 | + +#### Encoder performance comparison between different features on T4 and TensorFlow + +| Batch Size | Sequence
Length | FT-FP16
Latency | EFF-FT
Latency | FT-INT8-v2
Latency | EFF-FT-INT8-v2
Latency | EFF-FT
Speedup | FT-INT8-v2
Speedup | EFF-FT-INT8-v2
Speedup | +|:----------:|:---------------------:|:---------------------:|:--------------------:|:------------------------:|:----------------------------:|:--------------------:|:--------------------:|:---------------------:| +| 1 | 32 | 1.58 | 1.84 | 2.06 | 1.87 | 0.859 | 0.767 | 0.845 | +| 1 | 128 | 1.96 | 1.88 | 2.19 | 1.88 | 1.043 | 0.895 | 1.043 | +| 1 | 384 | 3.71 | 2.99 | 2.68 | 2.4 | 1.241 | 1.384 | 1.546 | +| 8 | 32 | 2.62 | 2.1 | 2.61 | 2.37 | 1.248 | 1.004 | 1.105 | +| 8 | 128 | 8.37 | 4.93 | 5.19 | 3.07 | 1.698 | 1.613 | 2.726 | +| 8 | 384 | 25.87 | 13.47 | 14.54 | 7.42 | 1.921 | 1.779 | 3.487 | +| 32 | 32 | 8.43 | 4.89 | 6 | 3.41 | 1.724 | 1.405 | 2.472 | +| 32 | 128 | 34.13 | 17.32 | 17.91 | 9.53 | 1.971 | 1.906 | 3.581 | +| 32 | 384 | 99.72 | 52.75 | 55.39 | 30.53 | 1.890 | 1.800 | 3.266 | + +#### Encoder performance on A100 and PyTorch + +* Performance on TF32 + +User can use `export NVIDIA_TF32_OVERRIDE=1` to enforce the program run under TF32. + +| Batch_size | Seq_len | Precision | TorchScript
Latency (ms) | FT
Latency (ms) | EFF-FT
Latency (ms) | FT
Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:------------------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | TF32 | 5.01 | 1.99 | 2.08 | 2.51 | 2.40 | +| 1 | 128 | TF32 | 4.89 | 1.94 | 2.02 | 2.52 | 2.42 | +| 1 | 384 | TF32 | 4.52 | 2.74 | 2.07 | 1.64 | 2.18 | +| 8 | 32 | TF32 | 5.07 | 2.18 | 2.09 | 2.32 | 2.42 | +| 8 | 128 | TF32 | 4.96 | 4.03 | 3.82 | 1.23 | 1.29 | +| 8 | 384 | TF32 | 12.87 | 12.21 | 8.02 | 1.05 | 1.60 | +| 32 | 32 | TF32 | 4.78 | 4.16 | 3.92 | 1.14 | 1.21 | +| 32 | 128 | TF32 | 13.49 | 11.64 | 7.00 | 1.15 | 1.92 | +| 32 | 384 | TF32 | 47.17 | 45.47 | 25.52 | 1.03 | 1.84 | -#### Encoder performance on V100 and PyTorch +* Performance on FP16 -* Performance on FP32 +| Batch_size | Seq_len | Precision | TorchScript
Latency (ms) | FT
Latency (ms) | EFF-FT
Latency (ms) | FT
Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:------------------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | FP16 | 4.22 | 0.93 | 1.15 | 4.53 | 3.66 | +| 1 | 128 | FP16 | 4.47 | 1.09 | 1.20 | 4.10 | 3.72 | +| 1 | 384 | FP16 | 4.16 | 1.36 | 1.16 | 3.05 | 3.58 | +| 8 | 32 | FP16 | 4.19 | 1.11 | 1.13 | 3.77 | 3.70 | +| 8 | 128 | FP16 | 4.35 | 1.81 | 1.59 | 2.40 | 2.73 | +| 8 | 384 | FP16 | 9.51 | 4.77 | 3.04 | 1.99 | 3.12 | +| 32 | 32 | FP16 | 4.19 | 1.82 | 1.58 | 2.30 | 2.65 | +| 32 | 128 | FP16 | 9.57 | 4.81 | 2.79 | 1.98 | 3.43 | +| 32 | 384 | FP16 | 32.43 | 14.68 | 8.07 | 2.20 | 4.01 | + +* Performance on INT8-v1 + +| Batch_size | Seq_len | TorchScript-FP16
Latency (ms) | FT-INT8-v1<br/>Latency (ms) | EFF-FT-INT8-v1<br/>Latency (ms) | FT-INT8-v1<br/>Speedup | EFF-FT-INT8-v1
Speedup | +|:----------:|:-------:|:-----------------------------------:|:-----------------------------:|:---------------------------------:|:------------------------:|:----------------------------:| +| 1 | 32 | 4.20 | 1.44 | 1.56 | 2.91 | 2.69 | +| 1 | 128 | 4.30 | 1.75 | 1.69 | 2.45 | 2.54 | +| 1 | 384 | 4.28 | 1.98 | 1.76 | 2.16 | 2.43 | +| 8 | 32 | 4.18 | 1.77 | 1.91 | 2.36 | 2.18 | +| 8 | 128 | 4.44 | 2.28 | 1.76 | 1.94 | 2.52 | +| 8 | 384 | 9.50 | 4.73 | 2.70 | 2.00 | 3.51 | +| 32 | 32 | 4.28 | 2.51 | 2.10 | 1.70 | 2.03 | +| 32 | 128 | 9.59 | 5.57 | 3.53 | 1.72 | 2.71 | +| 32 | 384 | 32.38 | 14.97 | 8.57 | 2.16 | 3.77 | + +* Performance on INT8-v2 + +| Batch_size | Seq_len | TorchScript-FP16
Latency (ms) | FT-INT8-v2<br/>Latency (ms) | EFF-FT-INT8-v2<br/>Latency (ms) | FT-INT8-v2<br/>Speedup | EFF-FT-INT8-v2
Speedup | +|:----------:|:-------:|:-----------------------------------:|:-----------------------------:|:---------------------------------:|:------------------------:|:----------------------------:| +| 1 | 32 | 4.35 | 1.55 | 1.60 | 2.80 | 2.71 | +| 1 | 128 | 4.40 | 1.59 | 1.62 | 2.76 | 2.71 | +| 1 | 384 | 4.27 | 1.75 | 1.69 | 2.44 | 2.52 | +| 8 | 32 | 4.18 | 1.67 | 1.64 | 2.50 | 2.54 | +| 8 | 128 | 4.24 | 1.79 | 1.63 | 2.36 | 2.60 | +| 8 | 384 | 9.46 | 3.45 | 2.37 | 2.74 | 3.99 | +| 32 | 32 | 4.21 | 2.07 | 1.72 | 2.03 | 2.44 | +| 32 | 128 | 9.57 | 3.88 | 2.45 | 2.46 | 3.90 | +| 32 | 384 | 32.37 | 10.39 | 6.08 | 3.11 | 5.32 | -| | TorchScript (ms) | CustomExt (ms) | CustomExt Speedup | -|:---------------------:|:----------------:|:--------------:|:-----------------:| -| <1, 32> | 6.39 | 1.80 | 3.55 | -| <1, 64> | 8.63 | 2.20 | 3.92 | -| <1, 128> | 6.76 | 3.03 | 2.23 | -| <8, 32> | 6.71 | 4.74 | 1.41 | -| <8, 64> | 9.52 | 8.34 | 1.14 | -| <8, 128> | 18.80 | 15.34 | 1.22 | -| <32, 32> | 18.24 | 15.08 | 1.20 | -| <32, 64> | 34.39 | 29.60 | 1.16 | -| <32, 128> | 65.60 | 58.64 | 1.11 | -| <64, 32> | 34.24 | 29.60 | 1.15 | -| <64, 64> | 63.26 | 58.85 | 1.07 | -| <64, 128> | 130.51 | 117.66 | 1.10 | -| <128, 32> | 63.47 | 57.86 | 1.09 | -| <128, 64> | 126.92 | 115.19 | 1.10 | -| <128, 128> | 254.07 | 230.81 | 1.10 | +#### Encoder performance on T4 and PyTorch -* Performance on FP16 +* Performance on FP32 -| | TorchScript (ms) | CustomExt (ms) | CustomExt Speedup | -|:---------------------:|:----------------:|:--------------:|:-----------------:| -| <1, 32> | 8.50 | 1.69 | 5.02 | -| <1, 64> | 8.66 | 1.71 | 5.06 | -| <1, 128> | 6.74 | 1.91 | 3.52 | -| <8, 32> | 7.72 | 1.84 | 4.19 | -| <8, 64> | 6.74 | 2.51 | 2.68 | -| <8, 128> | 6.67 | 3.73 | 1.78 | -| <32, 32> | 6.19 | 3.70 | 1.67 | -| <32, 64> | 9.36 | 6.78 | 1.38 | -| <32, 128> | 18.41 | 12.63 | 1.45 | -| <64, 32> | 9.20 | 6.63 | 1.38 | -| <64, 64> | 17.35 | 12.36 | 1.40 | -| <64, 128> | 34.14 | 23.90 | 1.42 | -| <128, 32> | 17.09 | 12.32 | 1.38 | -| <128, 64> | 33.28 | 23.44 | 1.41 | -| <128, 128> | 66.03 | 46.83 | 1.40 | +| Batch_size | Seq_len | Precision | TorchScript
Latency (ms) | FT<br/>Latency (ms) | EFF-FT<br/>Latency (ms) | FT<br/>Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:------------------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | FP32 | 11.51 | 2.65 | 4.35 | 4.34 | 2.64 | +| 1 | 128 | FP32 | 11.45 | 6.23 | 4.60 | 1.83 | 2.48 | +| 1 | 384 | FP32 | 22.09 | 18.81 | 11.24 | 1.17 | 1.96 | +| 8 | 32 | FP32 | 13.82 | 11.31 | 6.75 | 1.22 | 2.04 | +| 8 | 128 | FP32 | 53.91 | 43.51 | 21.87 | 1.23 | 2.46 | +| 8 | 384 | FP32 | 180.82 | 144.51 | 87.56 | 1.25 | 2.06 | +| 32 | 32 | FP32 | 52.66 | 42.34 | 25.22 | 1.24 | 2.08 | +| 32 | 128 | FP32 | 203.04 | 168.30 | 109.82 | 1.20 | 1.84 | +| 32 | 384 | FP32 | 707.27 | 587.40 | 373.25 | 1.20 | 1.89 | -### Effective FasterTransformer performance +* Performance on FP16 -We demonstrate the inference time of Effective FasterTransformer OP (Effective FT), TensorFlow XLA (TF) and PyTorch (TorchScript) under FP32, FP16 and INT8, and compare to the performance on T4 and V100. +| Batch_size | Seq_len | Precision | TorchScript
Latency (ms) | FT<br/>Latency (ms) | EFF-FT<br/>Latency (ms) | FT<br/>Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:------------------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | FP16 | 11.04 | 1.72 | 3.56 | 6.41 | 3.10 | +| 1 | 128 | FP16 | 10.51 | 2.05 | 3.90 | 5.12 | 2.69 | +| 1 | 384 | FP16 | 10.37 | 3.78 | 3.73 | 2.74 | 2.78 | +| 8 | 32 | FP16 | 10.94 | 2.67 | 4.20 | 4.09 | 2.60 | +| 8 | 128 | FP16 | 13.41 | 8.55 | 4.71 | 1.56 | 2.84 | +| 8 | 384 | FP16 | 46.82 | 26.03 | 16.88 | 1.79 | 2.77 | +| 32 | 32 | FP16 | 13.02 | 8.51 | 5.32 | 1.52 | 2.44 | +| 32 | 128 | FP16 | 54.28 | 33.38 | 15.71 | 1.62 | 3.45 | +| 32 | 384 | FP16 | 192.80 | 103.63 | 51.00 | 1.86 | 3.78 | + +* Performance on INT8-v1 + +| Batch_size | Seq_len | TorchScript-FP16
Latency (ms) | FT-INT8-v1<br/>Latency (ms) | EFF-FT-INT8-v1<br/>Latency (ms) | FT-INT8-v1<br/>Speedup | EFF-FT-INT8-v1
Speedup | +|:----------:|:-------:|:-----------------------------------:|:-----------------------------:|:---------------------------------:|:------------------------:|:----------------------------:| +| 1 | 32 | 11.09 | 1.90 | 2.34 | 5.83 | 4.73 | +| 1 | 128 | 10.51 | 1.87 | 2.05 | 5.62 | 5.12 | +| 1 | 384 | 10.29 | 4.25 | 3.48 | 2.42 | 2.95 | +| 8 | 32 | 10.54 | 2.65 | 2.20 | 3.97 | 4.79 | +| 8 | 128 | 13.44 | 8.48 | 3.93 | 1.58 | 3.41 | +| 8 | 384 | 47.11 | 23.19 | 14.75 | 2.03 | 3.19 | +| 32 | 32 | 13.09 | 9.10 | 5.27 | 1.43 | 2.48 | +| 32 | 128 | 53.91 | 29.45 | 15.56 | 1.83 | 3.46 | +| 32 | 384 | 190.43 | 88.69 | 53.14 | 2.14 | 3.58 | + +* Performance on INT8-v2 + +| Batch_size | Seq_len | TorchScript-FP16
Latency (ms) | FT-INT8-v2<br/>Latency (ms) | EFF-FT-INT8-v2<br/>Latency (ms) | FT-INT8-v2<br/>Speedup | EFF-FT-INT8-v2
Speedup | +|:----------:|:-------:|:-----------------------------------:|:-----------------------------:|:---------------------------------:|:------------------------:|:----------------------------:| +| 1 | 32 | 10.96 | 1.98 | 1.99 | 5.53 | 5.50 | +| 1 | 128 | 10.52 | 2.04 | 1.95 | 5.15 | 5.39 | +| 1 | 384 | 10.49 | 2.81 | 2.71 | 3.73 | 3.87 | +| 8 | 32 | 10.49 | 2.61 | 2.10 | 4.01 | 4.99 | +| 8 | 128 | 13.46 | 5.45 | 3.22 | 2.46 | 4.18 | +| 8 | 384 | 47.14 | 15.19 | 9.37 | 3.10 | 5.03 | +| 32 | 32 | 13.00 | 6.10 | 3.46 | 2.13 | 3.75 | +| 32 | 128 | 53.09 | 18.90 | 10.73 | 2.80 | 4.94 | +| 32 | 384 | 186.43 | 59.04 | 35.05 | 3.15 | 5.31 | -The FP32/FP16 results of TensorFlow were obtained by running the `sample/tensorflow/scripts/profile_effective_transformer_performance.sh`. +#### Encoder performance on V100 and PyTorch -The INT8 results of TensorFlow are obtained by running the `profile_effective_transformer_performance_int8.sh`. +* Performance on FP32 -The FP32/FP16 results of PyTorch were obtained by running the `profile_encoder_effective_transformer.sh`. +| Batch_size | Seq_len | Precision | TorchScript
Latency (ms) | FT<br/>Latency (ms) | EFF-FT<br/>Latency (ms) | FT<br/>Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:------------------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | FP32 | 6.91 | 1.75 | 1.99 | 3.94 | 3.47 | +| 1 | 128 | FP32 | 7.84 | 3.23 | 2.20 | 2.42 | 3.56 | +| 1 | 384 | FP32 | 9.46 | 7.40 | 7.66 | 1.27 | 1.23 | +| 8 | 32 | FP32 | 8.78 | 5.09 | 4.79 | 1.72 | 1.83 | +| 8 | 128 | FP32 | 20.86 | 16.65 | 11.43 | 1.25 | 1.82 | +| 8 | 384 | FP32 | 61.20 | 52.38 | 37.15 | 1.16 | 1.64 | +| 32 | 32 | FP32 | 20.29 | 16.41 | 9.68 | 1.23 | 2.09 | +| 32 | 128 | FP32 | 70.65 | 62.44 | 36.18 | 1.13 | 1.95 | +| 32 | 384 | FP32 | 232.66 | 209.90 | 119.96 | 1.10 | 1.93 | -The INT8 results of PyTorch were obtained by running the `profile_encoder_effective_transformer_int8.sh`. +* Performance on FP16 -In the experiments, we updated the following parameters: +| Batch_size | Seq_len | Precision | TorchScript
Latency (ms) | FT<br/>Latency (ms) | EFF-FT<br/>Latency (ms) | FT<br/>Speedup | EFF-FT
Speedup | +|:----------:|:-------:|:---------:|:------------------------------:|:---------------------:|:-------------------------:|:----------------:|:--------------------:| +| 1 | 32 | FP16 | 6.66 | 1.31 | 1.81 | 5.08 | 3.67 | +| 1 | 128 | FP16 | 6.99 | 1.59 | 1.81 | 4.39 | 3.86 | +| 1 | 384 | FP16 | 7.16 | 2.37 | 2.05 | 3.02 | 3.49 | +| 8 | 32 | FP16 | 7.18 | 1.99 | 1.84 | 3.60 | 3.90 | +| 8 | 128 | FP16 | 6.83 | 3.97 | 3.48 | 1.72 | 1.96 | +| 8 | 384 | FP16 | 17.51 | 11.47 | 7.44 | 1.52 | 2.35 | +| 32 | 32 | FP16 | 8.99 | 3.93 | 3.43 | 2.28 | 2.62 | +| 32 | 128 | FP16 | 20.20 | 13.53 | 9.29 | 1.49 | 2.17 | +| 32 | 384 | FP16 | 65.62 | 44.13 | 27.92 | 1.48 | 2.35 | -* head_num = 12 -* size_per_head = 64 -* num_layers = 12 -* average sequence length of Effective FasterTransformer is max_seq_len / 2 - -#### Performance on TensorFlow - -* Performance on V100 with FP32 - -| | TF (ms) | Effective FT (ms) | Effective FT Speedup | -|:-------------------------:|:-------:|:-----------------:|:--------------------:| -| <1, 32> | 3.94 | 2.8 | 1.4 | -| <1, 64> | 4.13 | 2.86 | 1.44 | -| <1, 128> | 5.31 | 3.57 | 1.48 | -| <8, 32> | 6.99 | 4.34 | 1.61 | -| <8, 64> | 10.77 | 6.5 | 1.65 | -| <8, 128> | 18.55 | 11.01 | 1.68 | -| <32, 32> | 18.31 | 10.76 | 1.7 | -| <32, 64> | 34.51 | 19.61 | 1.75 | -| <32, 128> | 66.97 | 36.94 | 1.81 | -| <64, 32> | 34.64 | 19.47 | 1.77 | -| <64, 64> | 66.38 | 36.26 | 1.83 | -| <64, 128> | 131.9 | 71.79 | 1.83 | -| <128, 32> | 66.98 | 35.62 | 1.88 | -| <128, 64> | 129.4 | 69.98 | 1.84 | -| <128, 128> | 258.44 | 139.94 | 1.84 | - -* Performance on V100 with FP16 - -| | TF (ms) | Effective FT (ms) | Effective FT Speedup | -|:-------------------------:|:-------:|:-----------------:|:--------------------:| -| <1, 32> | 3.49 | 2.64 | 1.32 | -| <1, 64> | 3.27 | 2.77 | 1.18 | -| <1, 128> | 3.49 | 2.74 | 1.27 | -| <8, 32> | 3.87 | 2.83 | 1.36 | -| <8, 64> | 5.04 | 3.42 | 1.47 | -| <8, 128> | 7.11 | 4.44 | 1.60 | -| <32, 32> | 7.00 | 4.37 | 1.60 | -| <32, 64> | 10.99 | 6.03 | 1.82 | -| <32, 128> | 19.89 | 10.71 | 1.85 | -| <64, 32> | 11.06 | 5.98 | 1.84 | -| <64, 64> | 19.81 | 10.42 | 1.90 | -| <64, 128> | 36.47 | 19.21 | 1.89 | -| <128, 32> | 19.67 | 10.37 | 1.89 | -| <128, 64> | 35.34 | 18.58 | 1.90 | -| <128, 128> | 69.08 | 36.76 | 1.87 | - -* Performance on T4 with FP32 - -| | TF (ms) | Effective FT (ms) | Effective FT Speedup | -|:-------------------------:|:-------:|:-----------------:|:--------------------:| -| <1, 32> | 9.61 | 5.26 | 1.82 | -| <1, 64> | 7.27 | 5.37 | 1.35 | -| <1, 128> | 9.04 | 6.02 | 1.50 | -| <8, 32> | 14.50 | 8.34 | 1.73 | -| <8, 64> | 25.38 | 13.98 | 1.81 | -| <8, 128> | 49.90 | 27.24 | 1.83 | -| <32, 32> | 48.08 | 26.08 | 1.84 | -| <32, 64> | 96.04 | 51.82 | 1.85 | -| <32, 128> | 203.03 | 90.93 | 2.23 | -| <64, 32> | 95.96 | 50.05 | 1.91 | -| <64, 64> | 189.31 | 87.43 | 2.16 | -| <64, 128> | 387.62 | 199.81 | 1.93 | -| <128, 32> | 189.80 | 84.86 | 2.23 | -| <128, 64> | 377.90 | 192.17 | 1.96 | -| <128, 128> | 820.11 | 408.03 | 2.00 | - -* Performance on T4 with FP16 - -| | TF (ms) | Effective FT (ms) | Effective FT Speedup | -|:-------------------------:|:-------:|:-----------------:|:--------------------:| -| <1, 32> | 9.69 | 4.57 | 2.12 | -| <1, 64> | 8.43 | 5.42 | 1.55 | -| <1, 128> | 8.24 | 4.41 | 1.86 | -| <8, 32> | 8.57 | 4.56 | 1.87 | -| <8, 64> | 9.59 | 6.26 | 1.53 | -| <8, 128> | 15.16 | 7.51 | 2.01 | -| <32, 32> | 15.12 | 6.59 | 2.29 | -| <32, 64> | 27.64 | 10.96 | 2.52 | -| <32, 128> | 53.27 | 20.32 | 2.62 | -| <64, 32> | 27.59 | 11.22 | 
2.45 | -| <64, 64> | 51.94 | 20.24 | 2.56 | -| <64, 128> | 103.40 | 38.71 | 2.67 | -| <128, 32> | 52.46 | 20.27 | 2.58 | -| <128, 64> | 101.15 | 38.14 | 2.65 | -| <128, 128> | 207.24 | 75.06 | 2.76 | - -* Performance on T4 with INT8 - -| | TF FP16 (ms) | EFF-FT-int8v1-op (ms) | EFF-FT-int8v2-op (ms) | EFF-FT-int8v1-op Speedup | EFF-FT-int8v2-op Speedup | -|:-------------------------:|:------------:|:---------------------:|:---------------------:|:------------------------:|:------------------------:| -| <1, 32> | 7.16 | 5.54 | 6.06 | 1.29 | 1.18 | -| <1, 64> | 6.68 | 5.98 | 5.61 | 1.12 | 1.19 | -| <1, 128> | 6.78 | 5.53 | 5.2 | 1.23 | 1.30 | -| <8, 32> | 6.38 | 5.75 | 5.18 | 1.11 | 1.23 | -| <8, 64> | 9.7 | 6.36 | 5.71 | 1.53 | 1.70 | -| <8, 128> | 14.72 | 7.59 | 6.37 | 1.94 | 2.31 | -| <32, 32> | 13.96 | 7.32 | 6.26 | 1.91 | 2.23 | -| <32, 64> | 26.72 | 12.01 | 8.03 | 2.22 | 3.33 | -| <32, 128> | 53.56 | 22.63 | 13.16 | 2.37 | 4.07 | -| <64, 32> | 26.56 | 11.96 | 8.71 | 2.22 | 3.05 | -| <64, 64> | 52.13 | 21.22 | 12.81 | 2.46 | 4.07 | -| <64, 128> | 102.85 | 43.84 | 23.03 | 2.35 | 4.47 | -| <128, 32> | 50.07 | 20.51 | 13.81 | 2.44 | 3.63 | -| <128, 64> | 99.58 | 39.26 | 22.73 | 2.54 | 4.38 | -| <128, 128> | 208.58 | 87.74 | 49.87 | 2.38 | 4.18 | - -#### Effective FasterTransformer performance on PyTorch - -* Performance on T4 with FP16 - -| | TorchScript (ms) | Effective FT (ms) | Effective FT Speedup | -|:-------------------------:|:----------------:|:-----------------:|:--------------------:| -| <1, 32> | 12.45 | 2.16 | 5.76 | -| <1, 64> | 12.82 | 2.08 | 6.16 | -| <1, 128> | 11.54 | 2.24 | 5.15 | -| <8, 32> | 11.25 | 2.91 | 3.87 | -| <8, 64> | 11.1 | 3.34 | 3.32 | -| <8, 128> | 13.59 | 5.19 | 2.62 | -| <32, 32> | 12.95 | 5.18 | 2.50 | -| <32, 64> | 25.89 | 9.64 | 2.69 | -| <32, 128> | 54.72 | 18.83 | 2.91 | -| <64, 32> | 25.57 | 9.94 | 2.57 | -| <64, 64> | 51.11 | 18.15 | 2.82 | -| <64, 128> | 105.87 | 35.53 | 2.98 | -| <128, 32> | 50.54 | 18.69 | 2.70 | -| <128, 64> | 102.3 | 35.14 | 2.91 | -| <128, 128> | 214.55 | 72.22 | 2.97 | - -* Performance on T4 with INT8 - -| | TorchScript (ms) | EFF-FT-INT8v1-OP (ms) | EFF-FT-INT8v2-OP (ms) | EFF-FT-INT8v1-OP Speedup | EFF-FT-INT8v2-OP Speedup | -|:-------------------------:|:----------------:|:---------------------:|:---------------------:|:------------------------:|:------------------------:| -| <1, 32> | 11.74 | 2.49 | 2.42 | 4.71 | 4.85 | -| <1, 64> | 12.04 | 2.3 | 2.45 | 5.23 | 4.91 | -| <1, 128> | 14.17 | 2.93 | 3.01 | 4.84 | 4.71 | -| <8, 32> | 11.13 | 2.53 | 2.6 | 4.40 | 4.28 | -| <8, 64> | 10.7 | 3.27 | 2.9 | 3.27 | 3.69 | -| <8, 128> | 13.63 | 5.97 | 4.08 | 2.28 | 3.34 | -| <32, 32> | 12.97 | 5.62 | 4.3 | 2.31 | 3.02 | -| <32, 64> | 25.83 | 10.22 | 6.97 | 2.53 | 3.71 | -| <32, 128> | 54.59 | 21.32 | 12.92 | 2.56 | 4.23 | -| <64, 32> | 25.85 | 10.28 | 7.43 | 2.51 | 3.48 | -| <64, 64> | 52.35 | 19.59 | 12.81 | 2.67 | 4.09 | -| <64, 128> | 109.09 | 42.38 | 25.09 | 2.57 | 4.35 | -| <128, 32> | 52.08 | 19.67 | 13.66 | 2.65 | 3.81 | -| <128, 64> | 102.25 | 37.93 | 24.69 | 2.70 | 4.14 | -| <128, 128> | 215.16 | 86.5 | 54.04 | 2.49 | 3.98 | ### Performance on BERT Applications: SQuAD MRPC #### Performance of TensorFlow -* [BERT-base-SQuAD-1.1 model]( For FP32/FP16, we use https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_ckpt_base_qa_squad11_amp_128/versions/19.03.1/zip; for INT8, we use https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-12_H-768_A-12.zip, finetune it with FP32 precision and then use QAT to generate INT8 
checkpoint.), batch size 8, seq len 128, on T4. +* [BERT-base-SQuAD-1.1 model]( For FP32/FP16, we use https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_ckpt_base_qa_squad11_amp_128/versions/19.03.1/zip; for INT8, we use https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-12_H-768_A-12.zip, finetune it with FP32 precision and then use QAT-KD to generate INT8 checkpoint.), batch size 8, seq len 128, on T4. We use `tensorflow/tensorflow_bert/profile_transformer_inference.py` to generate the following results. @@ -1437,11 +1361,11 @@ We use `tensorflow/tensorflow_bert/profile_transformer_inference.py` to generate | FasterTransformer OP FP16 | 79.03% | 86.23% | 1.19 | | FasterTransformer OP FP16 (remove padding) | 79.03% | 86.23% | 0.68 | | | | | -| Finetune FP32 | 82.42% | 89.55% | x | -| FasterTransformer OP INT8v1 | 81.86% | 89.30% | 1.27 | -| FasterTransformer OP INT8v2 | 81.94% | 89.20% | 0.92 | -| FasterTransformer OP INT8v1 (remove padding) | 81.86% | 89.30% | 0.89 | -| FasterTransformer OP INT8v2 (remove padding) | 81.94% | 89.20% | 0.66 | +| Finetune FP32 | 78.81% | 86.25% | x | +| FasterTransformer OP INT8v1 | 80.84% | 87.69% | 1.25 | +| FasterTransformer OP INT8v2 | 80.70% | 87.59% | 0.83 | +| FasterTransformer OP INT8v1 (remove padding) | 80.84% | 87.69% | 0.77 | +| FasterTransformer OP INT8v2 (remove padding) | 80.70% | 87.59% | 0.62 | #### Performance of PyTorch @@ -1456,10 +1380,10 @@ We use `tensorflow/tensorflow_bert/profile_transformer_inference.py` to generate | FasterTransformer OP FP16 | 86.93% | 93.15% | 11.28 | | FasterTransformer OP FP16 (remove padding) | 86.96% | 93.16% | 5.60 | | | | | -| FasterTransformer OP INT8 mode 2 | 86.08% * | 92.59% * | 7.19 | -| FasterTransformer OP INT8 mode 2 (remove padding) | 86.08% * | 92.59% * | 4.51 | +| FasterTransformer OP INT8 mode 2 | 87.50% * | 93.48% * | 6.81 | +| FasterTransformer OP INT8 mode 2 (remove padding) | 87.50% * | 93.48% * | 3.48 | -Note *: the checkpoint for INT8 is QAT-finetuned by ourselves, and is different with other's, so the accuracy is not comparable, the corresponding FP16 accuracy is `F1/EM: 92.89%/86.62%`. +Note *: the checkpoint for INT8 is QAT-finetuned with knowledge-distillation by ourselves, and is different with other's, so the accuracy is not comparable, the corresponding FP16 accuracy is `F1/EM: 92.89%/86.62%`. * BERT-base-MRPC, dev set: batch size 8, seq len 128, on T4 diff --git a/docs/gpt_guide.md b/docs/gpt_guide.md new file mode 100644 index 000000000..f6a271643 --- /dev/null +++ b/docs/gpt_guide.md @@ -0,0 +1,294 @@ +# GPT + +## Table Of Contents + +- [GPT](#gpt) + - [Table Of Contents](#table-of-contents) + - [Model architecture](#model-architecture) + - [Introduction](#introduction) + - [Setup](#setup) + - [Requirements](#requirements) + - [How to use](#how-to-use) + - [Prepare](#prepare) + - [Run GPT](#run-gpt) + - [gpt with triton backend](#gpt-with-triton-backend) + - [Performance](#performance) + - [Perofrmance of GPT-89B](#perofrmance-of-gpt-89b) + - [Performance C of GPT-175B](#performance-c-of-gpt-175b) + +## Model architecture + +
+<img src="images/gpt_flowchart.png"/>
+Fig. 1 Flowchart of GPT model.
+ +## Introduction + +GPT model is a variant of Decoding model. GPT model does not require the results from encoder and the cross multi-head attention, and use GeLU as the activation. However, OpenAI shows that using very giant model and lots of training data can significantly improve the capacity of GPT model in [their paper](https://arxiv.org/abs/2005.14165). However, it is impossible to put such model into a single GPU. For example, the largest model, GPT-3, has 175 billion parameters, which takes about 350GBs under half data type. Therefore, multi-gpus, even multi-nodes, is necessary. + +In FasterTransformer 4.0, we propose the multi-gpu inference library to run GPT-3. FasterTransformer supports `Tensor Parallel` and `Layer Parallel` in the same time and provides the api of cpp, TensorFlow/PyTorch op and triton backend. In cpp and PyTorch op, users can use MPI to run multiple gpus on multiple nodes. For example, using 4 dgx-1 V100 nodes (16 GBs memory per GPU) to run the GPT-3 model. To be convenient on serving, we also provide the triton backend. However, this backend only supports single nodes, multi-gpus currently. For TensorFlow op, FasterTransformer only supports single gpu now. + +The arguments, inputs, and outputs of GPT: + +* Arguments: + 1. Maximum batch size (B) + 2. Maximum sequence length (S) + 3. Top k value (K) + 4. Top p value (P) + 5. Head number (H) + 6. Size per head (N) + 7. Number of decoder layers + 8. Start id of the vocabulary + 9. End id of the vocabulary + 10. Vocab size (V) + 11. Tensor parallel size + 12. Layer parallel size +* Inputs: + 1. The table for embedding lookup. The shape is \[ V, H x N \]. + 2. The weights of all parameters. + 3. Position encoding table. The shape is \[ S, H x N \]. + 4. Inputs contexts. The shape is \[ b, s \], where b <= B, s <= S. +* Outputs: + 1. The output ids. The shape is \[b, S \]. + +## Setup + +### Requirements + +- CMake >= 3.8 for Tensorflow, CMake >= 3.13 for PyTorch +- CUDA 10.1 or newer version +- Python 3 is recommended because some features are not supported in python 2 +- Tensorflow 1.13 or 1.14 or 1.15 +- PyTorch >= 1.5.0 + +Recommend use nvcr image like `nvcr.io/nvidia/tensorflow:20.12-tf1-py3` or `nvcr.io/nvidia/pytorch:20.12-py3`. + +## How to use + +### Prepare + +* Install required tools + +```bash +pip install -r ../requirement.txt +``` + +To run the GPT on c, users need to convert the checkpoint of TensorFlow or PyTorch to binary files, and then load by FasterTransformer c api. Unfortunately, there is no published large model. So, users are only able to verify the correctness by smaller model. Currently, FasterTransformer provides two kinds of samples. First one is using the checkpoint of [OpenAI GPT-2 model](https://github.com/openai/gpt-2) (which is trained by TensorFlow); Another choice is using the checkpoint of [Megatron](https://github.com/NVIDIA/Megatron-LM) (which is trained by pytorch). + +* Download vocab and merge table + +They can be used in both OpenAI GPT-2 and Megatron. + +```bash +wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -P models +wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -P models +``` + +* Downlaod openai-gpt model and convert + +To convert the OpenAI GPT model to binary, FasterTransformer provides a tool `sample/tensorflow/utils/openai_gpt_ckpt_convert.py` to convert the checkpoint. + +```bash +python tensorflow/utils/download_gpt2_model.py +e.g. 
python tensorflow/utils/download_gpt2_model.py 124M +python ../sample/tensorflow/utils/openai_gpt_ckpt_convert.py -o models/openai-gpt-models/c-model/124m/ -i models/124M/model.ckpt -g 1 # convert 124M model with 1 TP mode +python ../sample/tensorflow/utils/openai_gpt_ckpt_convert.py -o models/openai-gpt-models/c-model/124m/ -i models/124M/model.ckpt -g 4 # convert 124M model with 4 TP mode +``` + +In the repo of OpenAI, they provide many models, including `124M`, `355M`, `774M` and `1558M` + +* Download megatron model and convert + +To convert the Megatron GPT model to binary, FasterTransformer provides a tool `sample/pytorch/utils/megatron_ckpt_convert.py` to convert the checkpoint. + +```bash +wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip +mkdir -p models/megatron-models/345m +unzip megatron_lm_345m_v0.0.zip -d models/megatron-models/345m +git clone https://github.com/NVIDIA/Megatron-LM.git +python ../sample/pytorch/utils/megatron_ckpt_convert.py -i ./models/megatron-models/345m/release/ -o ./models/megatron-models/c-model/345m/ -t_g 1 -i_g 1 +python ../sample/pytorch/utils/megatron_ckpt_convert.py -i ./models/megatron-models/345m/release/ -o ./models/megatron-models/c-model/345m/ -t_g 1 -i_g 8 +``` + +where `t_g` means the number GPUs of TP during training, and `i_g` means the number of GPUs for TP during inference. + +Note that there are different checkpoint version of Megatron. The version of the checkpoint above is 0. If users have trained a model by themselves, the default version of latest Megatron is 3. To convert the checkpoint with version 3, please add `-checkpoint_version 3`. + +### Run GPT + +1. Run GPT under on C++ with multiple gpu + + 1.1 Generate the `decoding_gemm_config.in` file. + + ```bash + ./bin/gpt_gemm + E.g., ./bin/gpt_gemm 8 8 12 64 50257 32 1 1 + ``` + + Here, `local_batch_size` can be set as `batch_size` if users do not use the layer parallelism. If users use layer parallelism, we recommand to set `local_batch_size` to be smaller than `batch_size` to hide the bubble. But this requires larger `batch_size`. `context_local_batch_size` is used for computing the k/v cache of input. Similar to `local_batch_size`, users can use `batch_size` directly if you don't use layer parallelism, and setting to be smaller than `batch_size` when you use layer parallelism. + + 1.2 Run GPT on C++ + + Users can see the details of arguments in `sample/cpp/gpt_config.ini`. It controls the model path, model size, tensor parallelism size, and some hyper-parameters. + + ```bash + ./bin/gpt_sample + ``` + + then use following script to convert the token ids to sentence. + + ```bash + python ../sample/pytorch/utils/convert_gpt_token.py --vocab_file=./models/gpt2-vocab.json --bpe_file=./models/gpt2-merges.txt + ``` + + By setting the `is_half` of `gpt_config.ini` to 1, users can run gpt model under fp16. + + 1.3 Run with tensor parallelism (TP), layer parallelism (LP) and pipeline parallelism (PP) + + Users can use `tensor_para_size` and `layer_para_size` in `gpt_config.ini` to control the size of model parallel. Besides, in the layer parallelism, we can use pipeline parallelism to reduce the bubbles. We can set the `layer_para_batch_size` to determine the real batch size for each forward. For example, if the total batch size is 4, and layer_para_batch_size is 1, then we will split the total batch into 4 parts, and each time we only use 1 batch size. Users can set them in the `gpt_config.ini`. 
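+    As a purely illustrative sketch, the parallelism-related entries of `gpt_config.ini` might look like the following; the key names are the ones discussed in this section, but the exact layout of the file (and any other required fields) should be taken from `sample/cpp/gpt_config.ini` itself:
+
+    ```
+    ; hypothetical excerpt for 8 GPUs: 4-way tensor parallelism x 2-way layer parallelism
+    tensor_para_size=4        ; each layer is split across 4 GPUs
+    layer_para_size=2         ; the decoder layers are split into 2 groups
+    layer_para_batch_size=1   ; pipeline micro-batch; a total batch of 4 runs as 4 forward passes per group
+    is_half=1                 ; run the model under FP16
+    ```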
+ + Note that we split the definition of LP and PP here, but we often combine them to hide the cost of bubble. + + ```bash + mpirun -n 8 ./bin/gpt_sample + python ../sample/pytorch/utils/convert_gpt_token.py --vocab_file=./models/gpt2-vocab.json --bpe_file=./models/gpt2-merges.txt + ``` + + 1.4 Run gpt on multi-nodes + + Since the c sample codes use the MPI to communicate, it can extend to multi-nodes easily, except that users need to setup some network environment to communicate between multi-nodes. The following scripts are an example to show how to run multi-nodes inference on slurm. + + ```bash + srun -N2 -n2 -t 600 --pty bash # Assume we get 2 nodes: prm-dgx-09 and prm-dgx-10 + srun -N2 -n2 docker pull nvcr.io/nvidia/tensorflow:20.07-tf1-py3 + + srun -N2 -n2 nvidia-docker run -itd --rm --privileged --network=host --pid=host --cap-add=IPC_LOCK --device=/dev/infiniband -v $PWD:$PWD -w $PWD --name ft-test nvcr.io/nvidia/tensorflow:20.12-tf1-py3 /bin/bash + + srun -N2 -n2 nvidia-docker exec -i --env SLURM_NTASKS --env SLURM_NODEID --env SLURM_PROCID --env SLURM_STEP_NODELIST --env SLURMD_NODENAME --privileged ft-test bash -c "mkdir /root/.ssh && cp $PWD/ssh/* /root/.ssh && chmod 700 /root/.ssh && chmod 640 /root/.ssh/authorized_keys2 && chmod 400 /root/.ssh/id_rsa && apt-get update && apt-get install ssh -y && mkdir /run/sshd/ && /usr/sbin/sshd -p 11068 && nvidia-smi -lgc 1530" + + nvidia-docker exec -ti ft-test bash + cd FasterTransformer/build + mpirun --allow-run-as-root -np 2 -H prm-dgx-09:1,prm-dgx-10:1 -mca plm_rsh_args "-p 11068" ./bin/gpt_sample + srun -N2 -n2 docker stop ft-test + ``` + +2. Run GPT on PyTorch + + Basically, `gpt_sample.py` includes the example how to declare a model, load a ckeckpoint, and forward context inputs and get generated outputs in Pytorch. + + For generating outputs based on context inputs, create a text file including the context inputs (line by line) and set `--sample_file_input` to the text file path. (By default, the script will generate outputs without context inputs.) Set `--sample_file_output` to write the outputs to a file. Use `--fp_16` to run in FP16. + + Run with `-h` to see more settings. + ```bash + python ./pytorch/gpt_sample.py -h + ``` + + 2.1 Run GPT with TP and PP on single node (NVIDIA DGX A100) + ```bash + # No parallelism (tensor_para_size=1, layer_para_size=1) + mpirun -n 1 --allow-run-as-root python ./pytorch/gpt_sample.py + + # TP (tensor_para_size=8, layer_para_size=1) + mpirun -n 8 --allow-run-as-root python ./pytorch/gpt_sample.py --tensor_para_size=8 --layer_para_size=1 --ckpt_path="/workspace/fastertransformer/models/megatron-models/c-model/345m/8-gpu" + + # LP (tensor_para_size=1, layer_para_size=8) + mpirun -n 8 --allow-run-as-root python ./pytorch/gpt_sample.py --tensor_para_size=1 --layer_para_size=8 --ckpt_path="/workspace/fastertransformer/models/megatron-models/c-model/345m/1-gpu" + + # TP and LP (tensor_para_size=4, layer_para_size=2) + mpirun -n 8 --allow-run-as-root python ./pytorch/gpt_sample.py --tensor_para_size=4 --layer_para_size=2 --ckpt_path="/workspace/fastertransformer/models/megatron-models/c-model/345m/4-gpu" + ``` + + For PP, set `--layer_para_batch_size` so that batch_size >= layer_para_batch_size. 
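+    As an illustrative sketch only, combining TP, LP and PP on the same 8 GPUs could look like the command below; the flags are the ones used in the 2.1 commands above, and the micro-batch value of 2 is an arbitrary example rather than a tuned setting:
+
+    ```bash
+    # TP=4, LP=2, pipeline micro-batch of 2 (valid as long as the total batch size is >= 2)
+    mpirun -n 8 --allow-run-as-root python ./pytorch/gpt_sample.py --tensor_para_size=4 --layer_para_size=2 --layer_para_batch_size=2 --ckpt_path="/workspace/fastertransformer/models/megatron-models/c-model/345m/4-gpu"
+    ```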
+ + 2.2 Run GPT with TP and PP on single-node/multi-node (NVIDIA SuperPOD) + #### Set up in interactive mode + + ```bash + srun -A devtech -J devtech-gpt:gpt  -p luna -N1 --mpi=pmix --ntasks-per-node=8 --container-image nvcr.io/nvidia/pytorch:20.12-py3 --container-mounts /lustre/fsw/devtech/hpc-devtech/dahn/FasterTransformer:/workspace/fastertransformer --container-workdir /workspace/fastertransformer --pty bash + + mkdir build && cd build + cmake -DSM=80 -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON .. && make -j12 + ``` + + #### Run on singe-node + * tensor_para_size=8, layer_para_size=1 + + ```bash + srun -A devtech -p luna -N1 --mpi=pmix --ntasks-per-node=8 --container-image nvcr.io/nvidia/pytorch:20.12-py3 --container-mounts /lustre/fsw/devtech/hpc-devtech/dahn/FasterTransformer:/workspace/fastertransformer --container-workdir /workspace/fastertransformer/build python ./pytorch/gpt_sample.py --tensor_para_size=8 --layer_para_size=1 --ckpt_path="/workspace/fastertransformer/models/megatron-models/c-model/345m/8-gpu" + ``` + + #### Run on multi-node + * tensor_para_size=8, layer_para_size=2 + + ```bash + srun -A devtech -p luna -N2 --mpi=pmix --ntasks-per-node=8 --container-image nvcr.io/nvidia/pytorch:20.12-py3 --container-mounts /lustre/fsw/devtech/hpc-devtech/dahn/FasterTransformer:/workspace/fastertransformer --container-workdir /workspace/fastertransformer/build python ./pytorch/gpt_sample.py --tensor_para_size=8 --layer_para_size=2 --ckpt_path="/workspace/fastertransformer/models/megatron-models/c-model/345m/8-gpu" + ``` + +3. Run GPT on tensorflow + + Note that the tensorflow op only supports single gpu. + + ```bash + ./bin/gpt_gemm 4 4 12 64 50257 1 1 1 + python tensorflow/gpt_sample.py --batch_size=4 \ + --length=32 \ + --top_k=4 \ + --top_p=0.6 \ + --data_type=fp16 + ``` + +### gpt with triton backend + +Details are in [transformer_backend](https://github.com/triton-inference-server/fastertransformer_backend) + +## Performance + +Hardware settings: +* 8xA100-80GBs (with mclk 1593MHz, pclk 1410MHz) with AMD EPYC 7742 64-Core Processor + +We demonstrate the inference time of Megatron and FasterTransformer on Triton, and show the speedup of FasterTransformer compare to Megatron. In the experiments of encoder, we updated the following parameters: + +* head_num = 96 +* size_per_head = 128 +* num_layers = 48 for GPT-89B model, 96 for GPT-175B model +* data_type = FP16 +* vocab_size = 51200 +* top_p = 0.9 +* tensor parallel size = 8 + +### Perofrmance of GPT-89B + +| Batch_size | Input Seqlen | Output Seqlen | Megatron
Latency (ms) | FT<br/>Latency (ms) | FT
Speedup | +|:----------:|:------------:|:-------------:|:---------------------------:|:---------------------:|:----------------:| +| 1 | 128 | 8 | 342.86 | 279.44 | 1.23 | +| 2 | 128 | 8 | 369.43 | 280.24 | 1.32 | +| 4 | 128 | 8 | 540.97 | 317.71 | 1.70 | +| 8 | 128 | 8 | 912.46 | 377.50 | 2.42 | +| 12 | 128 | 8 | 1263.39 | 445.46 | 2.84 | +| 16 | 128 | 8 | 1663.39 | 524.80 | 3.17 | +| 20 | 128 | 8 | 1991.16 | 575.83 | 3.46 | +| 32 | 128 | 8 | 3086.85 | 786.57 | 3.92 | +| | | | | | | +| 1 | 512 | 32 | 1244.81 | 887.52 | 1.40 | +| 2 | 512 | 32 | 1357.54 | 940.11 | 1.44 | +| 4 | 512 | 32 | 1970.08 | 1133.22 | 1.74 | +| 8 | 512 | 32 | 3341.66 | 1415.02 | 2.36 | +| 16 | 512 | 32 | 6090.07 | 1952.2 | 3.12 | + +### Performance C of GPT-175B + +| Batch_size | Input Seqlen | Output Seqlen | Megatron
Latency (ms) | FT<br/>Latency (ms) | FT
Speedup | +|:----------:|:------------:|:-------------:|:---------------------------:|:---------------------:|:----------------:| +| 1 | 128 | 8 | 660.38 | 488.86 | 1.35 | +| 2 | 128 | 8 | 687.34 | 509.47 | 1.35 | +| 4 | 128 | 8 | 1004.88 | 629.64 | 1.60 | +| 8 | 128 | 8 | 1705.07 | 749.86 | 2.27 | +| 12 | 128 | 8 | 2365.02 | 886.24 | 2.67 | +| 16 | 128 | 8 | 3111.57 | 1037.47 | 3.00 | +| 20 | 128 | 8 | 3723.73 | 1135.72 | 3.28 | +| 32 | 128 | 8 | 5778.72 | 1547.44 | 3.73 | +| | | | | | | +| 1 | 512 | 32 | 2384.78 | 1719.96 | 1.39 | +| 2 | 512 | 32 | 2503.24 | 1830.56 | 1.37 | +| 4 | 512 | 32 | 3658.65 | 2092.56 | 1.75 | +| 8 | 512 | 32 | 6238.79 | 2629.97 | 2.37 | +| 16 | 512 | 32 | 11409.53 | 3706.23 | 3.08 | \ No newline at end of file diff --git a/docs/images/FT_Encoder_T4.png b/docs/images/FT_Encoder_T4.png index 72c978eb9..17b13f4fc 100644 Binary files a/docs/images/FT_Encoder_T4.png and b/docs/images/FT_Encoder_T4.png differ diff --git a/docs/images/FT_GPT_A100.png b/docs/images/FT_GPT_A100.png new file mode 100644 index 000000000..10605fcae Binary files /dev/null and b/docs/images/FT_GPT_A100.png differ diff --git a/docs/images/PyTorch_Decoder_T4_fp16.png b/docs/images/PyTorch_Decoder_T4_fp16.png deleted file mode 100644 index fb728ab6e..000000000 Binary files a/docs/images/PyTorch_Decoder_T4_fp16.png and /dev/null differ diff --git a/docs/images/Py_Decoder_T4.png b/docs/images/Py_Decoder_T4.png new file mode 100644 index 000000000..c4f398b98 Binary files /dev/null and b/docs/images/Py_Decoder_T4.png differ diff --git a/docs/images/Py_Encoder_T4.png b/docs/images/Py_Encoder_T4.png new file mode 100644 index 000000000..0268f2a21 Binary files /dev/null and b/docs/images/Py_Encoder_T4.png differ diff --git a/docs/images/TF_Decoder_T4.png b/docs/images/TF_Decoder_T4.png new file mode 100644 index 000000000..ab940c7be Binary files /dev/null and b/docs/images/TF_Decoder_T4.png differ diff --git a/docs/images/TF_Decoder_T4_fp16.png b/docs/images/TF_Decoder_T4_fp16.png deleted file mode 100644 index 294aac5df..000000000 Binary files a/docs/images/TF_Decoder_T4_fp16.png and /dev/null differ diff --git a/docs/images/TF_Encoder_T4.png b/docs/images/TF_Encoder_T4.png index 284fb8b33..d4af8ed8a 100644 Binary files a/docs/images/TF_Encoder_T4.png and b/docs/images/TF_Encoder_T4.png differ diff --git a/docs/images/TF_cpp_Encoder_T4.png b/docs/images/TF_cpp_Encoder_T4.png deleted file mode 100644 index 72aea7de7..000000000 Binary files a/docs/images/TF_cpp_Encoder_T4.png and /dev/null differ diff --git a/docs/images/gpt_flowchart.png b/docs/images/gpt_flowchart.png new file mode 100644 index 000000000..f6ba9250f Binary files /dev/null and b/docs/images/gpt_flowchart.png differ diff --git a/docs/images/workflow-of-int8-inference.png b/docs/images/workflow-of-int8-inference.png index 3a24ac8da..8adcf2bd9 100644 Binary files a/docs/images/workflow-of-int8-inference.png and b/docs/images/workflow-of-int8-inference.png differ diff --git a/fastertransformer/CMakeLists.txt b/fastertransformer/CMakeLists.txt index 239d97ba3..de18cf29f 100644 --- a/fastertransformer/CMakeLists.txt +++ b/fastertransformer/CMakeLists.txt @@ -13,13 +13,15 @@ # limitations under the License. 
cmake_minimum_required(VERSION 3.8) add_subdirectory(cuda) +add_subdirectory(utils) add_subdirectory(gemm_test) if(BUILD_TF) add_subdirectory(tf_op) endif() -if(BUILD_THE OR BUILD_THS) +if(BUILD_PYT) add_subdirectory(th_op) endif() -add_subdirectory(trt_fused_multihead_attention) \ No newline at end of file +add_subdirectory(trt_fused_multihead_attention) +add_subdirectory(triton_backend) \ No newline at end of file diff --git a/fastertransformer/bert_encoder_transformer.h b/fastertransformer/bert_encoder_transformer.h old mode 100755 new mode 100644 index b28c96f1c..bc76a9993 --- a/fastertransformer/bert_encoder_transformer.h +++ b/fastertransformer/bert_encoder_transformer.h @@ -21,19 +21,20 @@ #pragma once #include -#include "fastertransformer/allocator.h" +#include "fastertransformer/utils/allocator.h" +#include "fastertransformer/utils/common_structure.h" #include "fastertransformer/cuda/cuda_kernels.h" #include "fastertransformer/cuda/cuda_int8_kernels.h" #include "fastertransformer/cuda/open_attention.h" -#include "fastertransformer/common_structure.h" #include "fastertransformer/gemm_test/encoder_gemm_func.h" #include "fastertransformer/gemm_test/encoder_igemm_func.h" +#include "fastertransformer/utils/functions.h" namespace fastertransformer { template -class EncoderInitParam +class BertInitParam { public: const T *from_tensor = nullptr; @@ -55,11 +56,15 @@ class EncoderInitParam int valid_word_num = -1; int layer_idx = 0; int layer_num = 12; - - //First 80 are for activation amaxs. - //For each activation amax, there are 4 values: amax, amax/127.0f, amax/127.0f/127.0f, 127.0f/amax -- input_amax 0-3 , Q_aftergemm_amax 4-7, Qbias_amax 8-11, K_aftergemm_amax 12-15, Kbias_amax 16-19, V_aftergemm_amax 20-23, Vbias_amax 24-27, bmm1_amax 28-31, Softmax_amax 32-35, bmm2_amax 36-39, Proj_aftergemm_scale 40-43, ProjBiasNorm_amax 44-47, FC1_aftergemm_amax 48-51, F1Bias_amax 52-55, FC2_aftergemm_amax 56-59, F2BiasNorm_amax 60-63, reserve 64-79 - //following by kernel amaxs : query_weight_amax_list, key_weight_amax_list, value_weight_amax_list, proj_weight_amax_list, FC1_weight_amax_list, FC2_weight_amax_list - //following by int8 gemm deQ scale list: Q_deQ_scale, K_deQ_scale, V_deQ_scale, bmm1_deQ_scale, bmm2_deQ_scale, FC0_deQ_scale, FC1_deQ_scale, FC2_deQ_scale + + //Part 1: + // First 80 are for activation amaxs. 
For each activation amax, there are 4 values: amax, amax/127.0f, amax/127.0f/127.0f, 127.0f/amax -- input_amax 0-3 , Q_aftergemm_amax 4-7, Qbias_amax 8-11, K_aftergemm_amax 12-15, Kbias_amax 16-19, V_aftergemm_amax 20-23, Vbias_amax 24-27, bmm1_amax 28-31, Softmax_amax 32-35, bmm2_amax 36-39, Proj_aftergemm_scale 40-43, ProjBiasNorm_amax 44-47, FC1_aftergemm_amax 48-51, F1Bias_amax 52-55, FC2_aftergemm_amax 56-59, F2BiasNorm_amax 60-63, reserve 64-79 + //Part 2: + // Kernel amaxs, for each kernel amax list, there are output_channel values : query_weight_amax_list, key_weight_amax_list, value_weight_amax_list, proj_weight_amax_list, FC1_weight_amax_list, FC2_weight_amax_list + //Part 3: + // Int8 gemm deQFactor list (8 values): Q_deQ_scale, K_deQ_scale, V_deQ_scale, bmm1_deQ_scale, bmm2_deQ_scale, FC0_deQ_scale, FC1_deQ_scale, FC2_deQ_scale + //Part 4: + // Amax used in trt fused mha kernel (3 values) : QKVbias_amax, Softmax_amax, bmm2_amax const float *amaxList = nullptr; const int* trt_seqlen_offset = nullptr; int trt_seqlen_size = -1; @@ -90,25 +95,20 @@ class BertEncoderTransformer IAllocator *allocator_ = NULL; typename Traits_::MultiHeadAttention *attention_ = NULL; typedef typename Traits_::DataType DataType_; - EncoderInitParam param_; + BertInitParam param_; - const cudaDataType_t computeType_ = Traits_::computeType; const cudaDataType_t AType_ = Traits_::AType; const cudaDataType_t BType_ = Traits_::BType; const cudaDataType_t CType_ = Traits_::CType; - int cublasAlgo_[3]; - std::map cublasLtAlgoMap_; - std::map cublasAlgoMap_; + std::map cublasAlgoMap_; std::map parameterMap_; DataType_ *buf_ = NULL; DataType_ *attr_out_buf_; DataType_ *attr_matmul_buf_; DataType_ *inter_matmul_buf_; - DataType_ *attr_out_tmp_buf_; - - DataType_ *out_tmp_buf_; - DataType_ *from_tensor_tmp_buf_; + DataType_ *attr_matmul_unnormed_buf_; + void* cublas_workspace_ = NULL; int batch_size_; int from_seq_len_; @@ -116,17 +116,18 @@ class BertEncoderTransformer int head_num_; int size_per_head_; + int sm_; bool allow_gemm_test_ = false; bool use_ORDER_COL32_2R_4R4_ = false; - //for int8 quantization const float *FC0_weight_amax_list, *FC1_weight_amax_list, *FC2_weight_amax_list; - float int8O_gemm_deQ_scale_list[INT8O_GEMM_NUM]; - const float *bmm2_amax_ptr, *ProjBiasNorm_amax_ptr, *F1Bias_amax_ptr, *F2BiasNorm_amax_ptr, *to_tensor_amax_ptr, *Proj_aftergemm_amax_ptr, *F1_aftergemm_amax_ptr, *F2_aftergemm_amax_ptr; + float scale_list[INT8O_GEMM_NUM+TRT_FUSED_MHA_AMAX_NUM]; + const float *bmm2_amax_ptr, *ProjBiasNorm_amax_ptr, *F1Bias_amax_ptr, *F2BiasNorm_amax_ptr, *to_tensor_amax_ptr, *Proj_aftergemm_amax_ptr, *F1_aftergemm_amax_ptr, *F2_aftergemm_amax_ptr, *int8O_gemm_deQ_scale_list; //int8_mode == 0 -- not use int8 - //int8_mode == 1 -- use int8 without quantized residual - //int8_mode == 2 -- use int8 with quantized residual + //int8_mode == 1 -- use int8; without quantized residual; when (batch*seqLen >= 512) or (seqLen % 32 !=0 ), using trt fused mha + //int8_mode == 2 -- use int8; with quantized residual; with trt fused mha + //int8_mode == 3 -- use int8; with quantized residual; without trt fused mha int int8_mode_; int layer_idx_; int layer_num_; @@ -143,11 +144,11 @@ class BertEncoderTransformer layer_idx_ = layer_idx; } - int calBufSizeInByte(int batch_size, int seq_len, int head_num, int size_per_head, int int8_mode){ - int m = batch_size*seq_len; - int n = head_num*size_per_head; - int k = n; - int normal_buf_size; + size_t calBufSizeInByte(int batch_size, int seq_len, int head_num, int 
size_per_head, int int8_mode){ + size_t m = batch_size*seq_len; + size_t n = head_num*size_per_head; + size_t k = n; + size_t normal_buf_size; if (int8_mode != 0){ //transA_from_tensor & transformer_out_tmp_DataType normal_buf_size = m*k*sizeof(DataType_) + @@ -163,7 +164,7 @@ class BertEncoderTransformer m*n*sizeof(DataType_); } else{ - normal_buf_size = sizeof(DataType_) * (m*n) * (6 + 3); + normal_buf_size = sizeof(DataType_) * (m*n) * 7 + ((sizeof(half) == sizeof(DataType_)) ? CUBLAS_WORKSPACE_SIZE : 0); } return normal_buf_size; } @@ -172,12 +173,12 @@ class BertEncoderTransformer { char mark[1000]; bool parameterInMap; + int dataType = is_fp16 == 0 ? FLOAT_DATATYPE : HALF_DATATYPE; if (int8_mode != 0) { - int8_mode = 1; - is_fp16 = 1; + dataType = INT8_DATATYPE; } - sprintf(mark, "%d_%d_%d_%d_%d_%d", batch_size, seq_len, head_num, size_per_head, int8_mode, is_fp16); + sprintf(mark, "%d_%d_%d_%d_%d", batch_size, seq_len, head_num, size_per_head, dataType); if (parameterMap_.find(std::string(mark)) != parameterMap_.end()) parameterInMap = true; else @@ -211,65 +212,6 @@ class BertEncoderTransformer buffer = reinterpret_cast(allocator->malloc(buf_size_in_byte, false)); } - void readAlgoFromConfig(int int8_mode) - { - - if (int8_mode != 0) - { - cublasLtAlgoMap_.clear(); - parameterMap_.clear(); - FILE* fd = fopen(IGEMM_CONFIG, "r"); - if (fd == NULL) - return; - int batchCount2, m2, n2, k2, algoId, customOption, tile, splitK_val, swizzle, reductionScheme, workspaceSize, stages; - int batch_size, seq_len, head_num, size_per_head; - while(fscanf(fd,"%d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d\n", &batch_size, &seq_len, &head_num, &size_per_head, &batchCount2, &m2, &n2, &k2, &algoId, &customOption, &tile, &splitK_val, &swizzle, &reductionScheme, &workspaceSize, &stages)!=EOF) - { - char mark[256]; - sprintf(mark, "%d_%d_%d_%d_1_1", batch_size, seq_len, head_num, size_per_head); - std::string markStr0(mark); - sprintf(mark, "%d_%d_%d_%d", batchCount2, m2, n2, k2); - std::string markStr(mark); - //workspaceSize should be zero - if (cublasLtAlgoMap_.find(markStr) == cublasLtAlgoMap_.end() && workspaceSize == 0) - { - parameterMap_[markStr0] = 1; - cublasLtAlgoMap_[markStr].algoId = algoId; - cublasLtAlgoMap_[markStr].customOption = customOption; - cublasLtAlgoMap_[markStr].tile = tile; - cublasLtAlgoMap_[markStr].splitK_val = splitK_val; - cublasLtAlgoMap_[markStr].swizzle = swizzle; - cublasLtAlgoMap_[markStr].reductionScheme = reductionScheme; - cublasLtAlgoMap_[markStr].workspaceSize = workspaceSize; - cublasLtAlgoMap_[markStr].stages = stages; - } - } - fclose(fd); - } - else - { - cublasAlgoMap_.clear(); - parameterMap_.clear(); - FILE* fd = fopen(GEMM_CONFIG, "r"); - if (fd == NULL) - return; - int batchCount2, m2, n2, k2, is_fp16, algoId; - int batch_size, seq_len, head_num, size_per_head; - float runtime; - while(fscanf(fd,"%d %d %d %d ### %d %d %d %d %d %d %f\n", &batch_size, &seq_len, &head_num, &size_per_head, &batchCount2, &m2, &n2, &k2, &is_fp16, &algoId, &runtime)!=EOF) - { - char mark[256]; - sprintf(mark, "%d_%d_%d_%d_0_%d", batch_size, seq_len, head_num, size_per_head, is_fp16); - std::string markStr0(mark); - parameterMap_[markStr0] = 1; - sprintf(mark, "%d_%d_%d_%d_%d", batchCount2, m2, n2, k2, is_fp16); - std::string markStr(mark); - cublasAlgoMap_[markStr] = algoId; - } - fclose(fd); - } - } - bool gemmTest(int batch_size, int seq_len, int head_num, int size_per_head, int int8_mode, int is_fp16) { @@ -283,7 +225,7 @@ class BertEncoderTransformer if 
(!checkParameterInMap(batch_size, seq_len, head_num, size_per_head, int8_mode, is_fp16)) { - readAlgoFromConfig(int8_mode); + readAlgoFromConfig(int8_mode, cublasAlgoMap_, parameterMap_); } else { @@ -302,7 +244,7 @@ class BertEncoderTransformer { generate_encoder_igemm_config(batch_size, seq_len, head_num, size_per_head, gemm_test_buf); freeBufferForGemmTest(allocator_, gemm_test_buf); - readAlgoFromConfig(int8_mode); + readAlgoFromConfig(int8_mode, cublasAlgoMap_, parameterMap_); hasChangedConfig = true; } } @@ -320,7 +262,7 @@ class BertEncoderTransformer if (!checkParameterInMap(batch_size, seq_len, head_num, size_per_head, int8_mode, is_fp16)) { - readAlgoFromConfig(int8_mode); + readAlgoFromConfig(int8_mode, cublasAlgoMap_, parameterMap_); } else { @@ -342,7 +284,7 @@ class BertEncoderTransformer else generate_encoder_gemm_config(batch_size, seq_len, head_num, size_per_head, gemm_test_buf); freeBufferForGemmTest(allocator_, gemm_test_buf); - readAlgoFromConfig(int8_mode); + readAlgoFromConfig(int8_mode, cublasAlgoMap_, parameterMap_); hasChangedConfig = true; } } @@ -355,60 +297,6 @@ class BertEncoderTransformer return hasChangedConfig; } - void getBestAlgoFromMap(int batch_size, int seq_len, int head_num, int size_per_head, int is_fp16) - { - int m = batch_size * seq_len; - int n = head_num * size_per_head; - int k = n; - char mark[256]; - int foundAlgo = 0; - sprintf(mark, "1_%d_%d_%d_%d", m, n, k, is_fp16); - std::string markStr(mark); - if (cublasAlgoMap_.find(markStr) != cublasAlgoMap_.end()) - { - cublasAlgo_[0] = cublasAlgoMap_[markStr]; - foundAlgo += 1; - } - if (foundAlgo == 1) - { - sprintf(mark, "1_%d_%d_%d_%d", m, 4*n, k, is_fp16); - std::string markStr(mark); - if (cublasAlgoMap_.find(markStr) != cublasAlgoMap_.end()) - { - cublasAlgo_[1] = cublasAlgoMap_[markStr]; - foundAlgo += 1; - } - if (foundAlgo == 2) - { - sprintf(mark, "1_%d_%d_%d_%d", m, n, 4*k, is_fp16); - std::string markStr(mark); - if (cublasAlgoMap_.find(markStr) != cublasAlgoMap_.end()) - { - cublasAlgo_[2] = cublasAlgoMap_[markStr]; - foundAlgo += 1; - } - } - } - - if (foundAlgo != 3) - { - printf("[WARNING][BertEncoderTransformer] Loading GEMM algorithms error, using default GEMM algorithms!\n"); - if (is_fp16 == 0) - { - cublasAlgo_[0] = -1; - cublasAlgo_[1] = -1; - cublasAlgo_[2] = -1; - } - else - { - cublasAlgo_[0] = 99; - cublasAlgo_[1] = 99; - cublasAlgo_[2] = 99; - } - } - } - - //free buffer for BertEncoderTransformer void freeBuffer() { @@ -465,34 +353,28 @@ class BertEncoderTransformer int n = k; int buf_size = m * n; - int buf_size_in_byte = calBufSizeInByte(batch_size_, from_seq_len_, head_num_, size_per_head_, int8_mode_); + size_t buf_size_in_byte = calBufSizeInByte(batch_size_, from_seq_len_, head_num_, size_per_head_, int8_mode_); //allocate buffer if (int8_mode_ != 0){ - //check if seq_len is a multiple of 32 - if (from_seq_len_ % 32 != 0){ - printf("[ERROR] seq_len should be a multiple of 32 when using int8 quantization\n"); - exit(-1); - } - buf_ = reinterpret_cast(allocator_->malloc(buf_size_in_byte, false)); if (buf_ == nullptr) throw std::runtime_error(std::string("Allocator failed to allocate internal buffer.")); - attr_out_buf_ = (DataType_*)(((void*)buf_) + m*k*sizeof(DataType_) + m*k*sizeof(int8_t) + 3*n*k*sizeof(int8_t) + 4*m*k * sizeof(int)); + attr_out_buf_ = (DataType_*)(((char*)buf_) + m*k*sizeof(DataType_) + m*k*sizeof(int8_t) + 3*n*k*sizeof(int8_t) + 4*m*k * sizeof(int)); attr_matmul_buf_ = attr_out_buf_ + buf_size; inter_matmul_buf_ = attr_matmul_buf_ + buf_size; - 
int8_from_tensor_tmp_ = (int8_t *)(((void*)buf_) + m*k*(sizeof(DataType_))); + int8_from_tensor_tmp_ = (int8_t *)(((char*)buf_) + m*k*(sizeof(DataType_))); attr_matmul_buf_tmp_ = int8_from_tensor_tmp_; transformer_out_tmp_int8_ = int8_from_tensor_tmp_; transA_from_tensor_tmp_ = (DataType_*)buf_; transformer_out_tmp_DataType_ = transA_from_tensor_tmp_; - int_buf_ = (int32_t*)(((void*)buf_) + (m * k) * (sizeof(DataType_) + sizeof(int8_t)) + 3*n*k*sizeof(int8_t)); + int_buf_ = (int32_t*)(((char*)buf_) + (m * k) * (sizeof(DataType_) + sizeof(int8_t)) + 3*n*k*sizeof(int8_t)); - tmp_DataType_ = (DataType_*)(((void*)buf_) + m*k*(sizeof(DataType_)+sizeof(int8_t)) + 3*n*k*sizeof(int8_t) + 4*m*k * sizeof(int32_t) + 6*m*n*sizeof(DataType_)); + tmp_DataType_ = (DataType_*)(((char*)buf_) + m*k*(sizeof(DataType_)+sizeof(int8_t)) + 3*n*k*sizeof(int8_t) + 4*m*k * sizeof(int32_t) + 6*m*n*sizeof(DataType_)); tmp_int8_ = (int8_t*)tmp_DataType_; } else{ @@ -500,13 +382,21 @@ class BertEncoderTransformer if (buf_ == nullptr) throw std::runtime_error(std::string("Allocator failed to allocate internal buffer.")); - attr_out_buf_ = buf_; + if (sizeof(half) == sizeof(DataType_)) + { + //cublas_workspace_ should be the start pointer of cudaMalloc() + //to ensure 16B alignemnet + cublas_workspace_ = buf_; + attr_out_buf_ = (DataType_*)((char *)cublas_workspace_ + CUBLAS_WORKSPACE_SIZE); + } + else + { + cublas_workspace_ = nullptr; + attr_out_buf_ = (DataType_*)buf_; + } attr_matmul_buf_ = attr_out_buf_ + buf_size; inter_matmul_buf_ = attr_matmul_buf_ + buf_size; - - attr_out_tmp_buf_ = inter_matmul_buf_ + 4 * buf_size; - out_tmp_buf_ = attr_out_tmp_buf_ + buf_size; - from_tensor_tmp_buf_ = out_tmp_buf_ + buf_size; + attr_matmul_unnormed_buf_ = inter_matmul_buf_ + 4 * buf_size; } } @@ -523,14 +413,8 @@ class BertEncoderTransformer size_per_head_, int8_mode_, is_fp16); } - if (int8_mode_ == 0) - { - //get best FP16/FP32 algo from map - getBestAlgoFromMap(batch_size_, from_seq_len_, head_num_, size_per_head_, is_fp16); - } - //allocate buffer for attention_ - attention_->allocateBuffer(allocator, batch_size_, from_seq_len_, to_seq_len, + attention_->allocateBuffer(allocator, cublas_workspace_, batch_size_, from_seq_len_, to_seq_len, head_num_, size_per_head_, hasChangedConfig, use_trt_kernel); } catch (std::runtime_error &error) @@ -549,48 +433,33 @@ class BertEncoderTransformer try { - if (int8_mode_ != 0){ - - // check sm version -#ifdef CUDA11_MODE - int device{-1}; - cudaGetDevice(&device); - cudaDeviceProp props; - cudaGetDeviceProperties(&props, device); - if (props.major * 10 + props.minor >= 80){ - use_ORDER_COL32_2R_4R4_ = true; - } -#endif + sm_ = getSMVersion(); + if (sm_ >= 80){ + use_ORDER_COL32_2R_4R4_ = true; + } + if (sm_ < 75 && int8_mode_ != 0){ + printf("[ERROR][BertEncoderTransformer] int8 mode only works with sm >= 75.\n"); + exit(-1); + } - //read best algos from config - int isConfigExist = access(IGEMM_CONFIG, 0); - if (isConfigExist == -1) - { - if (!allow_gemm_test_) - { - printf("[WARNING][BertEncoderTransformer] %s is not found; using default GEMM algo\n", IGEMM_CONFIG); - } - } - else + int isConfigExist = -1; + if (int8_mode_ != 0) + isConfigExist = access(IGEMM_CONFIG, 0); + else + isConfigExist = access(GEMM_CONFIG, 0); + if (isConfigExist == -1) + { + if (!allow_gemm_test_) { - readAlgoFromConfig(int8_mode_); + printf("[WARNING][BertEncoderTransformer] %s is not found; using default GEMM algo\n", int8_mode_ != 0 ? 
IGEMM_CONFIG : GEMM_CONFIG); } } - else{ - int isConfigExist = access(GEMM_CONFIG, 0); - if (isConfigExist == -1) - { - if (!allow_gemm_test_) - { - printf("[WARNING][BertEncoderTransformer] %s is not found; using default GEMM algo\n", GEMM_CONFIG); - } - } - else - { - readAlgoFromConfig(int8_mode_); - } + else + { + readAlgoFromConfig(int8_mode_, cublasAlgoMap_, parameterMap_); } - attention_ = new typename Traits_::MultiHeadAttention(int8_mode_, allow_gemm_test_, use_ORDER_COL32_2R_4R4_); + + attention_ = new typename Traits_::MultiHeadAttention(int8_mode_, allow_gemm_test_, use_ORDER_COL32_2R_4R4_, sm_); } catch (std::runtime_error &error) { @@ -598,6 +467,22 @@ class BertEncoderTransformer } } + BertEncoderTransformer(const BertEncoderTransformer *transformer) + { +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + sm_ = transformer->sm_; + use_ORDER_COL32_2R_4R4_ = transformer->use_ORDER_COL32_2R_4R4_; + int8_mode_ = transformer->int8_mode_; + allow_gemm_test_ = transformer->allow_gemm_test_; + + cublasAlgoMap_ = transformer->cublasAlgoMap_; + parameterMap_ = transformer->parameterMap_; + + attention_ = new typename Traits_::MultiHeadAttention(transformer->attention_); + } + void genTransATensorAndInt8TensorForFirstLayer(){ const int m = param_.sequence_id_offset == nullptr ? batch_size_ * from_seq_len_ : param_.valid_word_num; const int k = head_num_ * size_per_head_; @@ -607,7 +492,7 @@ class BertEncoderTransformer transA_from_tensor_ = (const DataType_*)transA_from_tensor_tmp_; quantized_kernelLauncher(int8_from_tensor_tmp_, transA_from_tensor_, m*k, to_tensor_amax_ptr+3, param_.stream); } - else + else if (int8_mode_ == 2 || int8_mode_ == 3) { transposeMatrix_colMajorToCOL32_quantize_kernelLauncher(int8_from_tensor_tmp_, param_.from_tensor, k, m, to_tensor_amax_ptr+3, param_.stream); } @@ -619,7 +504,7 @@ class BertEncoderTransformer * We will keep the Ctor empty to ensure the sub classes follow the same init routine. * Please be aware that no dynamic memory allocation should be placed **/ - void initialize(EncoderInitParam param) + void initialize(BertInitParam param) { #ifndef NDEBUG PRINT_FUNC_NAME_(); @@ -645,10 +530,14 @@ class BertEncoderTransformer FC0_weight_amax_list = param_.amaxList + ACTIVATION_AMAX_NUM + 3*hidden_dim; FC1_weight_amax_list = FC0_weight_amax_list + hidden_dim; FC2_weight_amax_list = FC1_weight_amax_list + 4*hidden_dim; - - if (int8_mode_ == 2) - check_cuda_error(cudaMemcpyAsync(int8O_gemm_deQ_scale_list, FC2_weight_amax_list + hidden_dim, INT8O_GEMM_NUM*sizeof(float), cudaMemcpyDeviceToHost, param_.stream)); + //This D2H copy operation will cause performance degradation + if ( (int8_mode_ == 1 && ((batch_size_*from_seq_len_ >= 512) || (from_seq_len_ % 32 != 0)) ) || int8_mode_ == 2 || int8_mode_ == 3) + { + //copy (int8O_gemm_deQ_scale_list + trt_fused_mha_amax_list) amax into scale_list + check_cuda_error(cudaMemcpyAsync(scale_list, FC2_weight_amax_list + hidden_dim, (INT8O_GEMM_NUM+TRT_FUSED_MHA_AMAX_NUM)*sizeof(float), cudaMemcpyDeviceToHost, param_.stream)); + int8O_gemm_deQ_scale_list = scale_list; + } int k = hidden_dim; const int m = param_.sequence_id_offset == nullptr ? 
batch_size_ * from_seq_len_ : param_.valid_word_num; @@ -658,10 +547,10 @@ class BertEncoderTransformer else { transA_from_tensor_ = param_.from_tensor; - if (int8_mode_ == 2){ + if (int8_mode_ == 2 || int8_mode_ == 3){ int8_from_tensor_ = (const int8_t*)transA_from_tensor_; } - else{ + else if (int8_mode_ == 1){ quantized_kernelLauncher(int8_from_tensor_tmp_, transA_from_tensor_, m*k, to_tensor_amax_ptr + 3, param_.stream); int8_from_tensor_ = (const int8_t*)(int8_from_tensor_tmp_); } @@ -669,11 +558,11 @@ class BertEncoderTransformer multi_head_init_param.int8_from_tensor = int8_from_tensor_; - multi_head_init_param.cublaslt_handle = param_.cublaslt_handle; - multi_head_init_param.amaxList = param_.amaxList; multi_head_init_param.int8O_gemm_deQ_scale_list = int8O_gemm_deQ_scale_list; + + multi_head_init_param.trt_fused_mha_amax_list = scale_list + INT8O_GEMM_NUM; } multi_head_init_param.from_tensor = param.from_tensor; @@ -682,12 +571,13 @@ class BertEncoderTransformer multi_head_init_param.attr_mask = param.attr_mask; multi_head_init_param.stream = param.stream; multi_head_init_param.cublas_handle = param.cublas_handle; + multi_head_init_param.cublaslt_handle = param_.cublaslt_handle; multi_head_init_param.attr_out = attr_out_buf_; multi_head_init_param.valid_word_num = param.valid_word_num; multi_head_init_param.sequence_id_offset = param.sequence_id_offset; multi_head_init_param.trt_seqlen_offset = param_.trt_seqlen_offset; multi_head_init_param.trt_seqlen_size = param_.trt_seqlen_size; - + attention_->initialize(multi_head_init_param); } @@ -719,16 +609,16 @@ class BertEncoderTransformer { cublasLtMM_withAlgo(int_buf_, 1, m, n, k, m*k, n*k, m*n, (int8_t*)attr_out_buf_, (int8_t*)(param_.self_attention. attention_output_weight.kernel), - param_.cublaslt_handle, param_.stream, cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + param_.cublaslt_handle, param_.stream, cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); add_bias_input_layernorm_COL32_int32I_DataTypeO_kernelLauncher(attr_matmul_buf_, int_buf_, transA_from_tensor_, param_.self_attention.attention_output_weight.bias, param_.self_layernorm.gamma, param_.self_layernorm.beta, m, n, param_.stream, FC0_weight_amax_list, bmm2_amax_ptr); } - else + else if (int8_mode_ == 2 || int8_mode_ == 3) { cublasLtMM_withAlgo_int8IO((int8_t*)int_buf_, 1, m, n, k, m*k, n*k, m*n, int8O_gemm_deQ_scale_list[5], (int8_t*)attr_out_buf_, (int8_t*)(param_.self_attention. 
attention_output_weight.kernel), - param_.cublaslt_handle, param_.stream, cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + param_.cublaslt_handle, param_.stream, cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); add_bias_input_layernorm_COL32_int8IO_kernelLauncher((int8_t*)attr_matmul_buf_, (int8_t*)int_buf_, int8_from_tensor_, param_.self_attention.attention_output_weight.bias, param_.self_layernorm.gamma, param_.self_layernorm.beta, @@ -747,16 +637,16 @@ class BertEncoderTransformer quantized_kernelLauncher(attr_matmul_buf_tmp_, attr_matmul_buf_, k*m, ProjBiasNorm_amax_ptr + 3, param_.stream); cublasLtMM_withAlgo(int_buf_, 1, m, n, k, m*k, n*k, m*n, attr_matmul_buf_tmp_, (int8_t*)(param_.ffn.intermediate_weight.kernel), - param_.cublaslt_handle, param_.stream, cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + param_.cublaslt_handle, param_.stream, cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); add_bias_act_COL32_int32I_int8O_kernelLauncher((int8_t*)inter_matmul_buf_, int_buf_, param_.ffn.intermediate_weight.bias, m, n, param_.stream, FC1_weight_amax_list, ProjBiasNorm_amax_ptr+2, F1Bias_amax_ptr+3); } - else + else if (int8_mode_ == 2 || int8_mode_ == 3) { cublasLtMM_withAlgo_int8IO((int8_t*)int_buf_, 1, m, n, k, m*k, n*k, m*n, int8O_gemm_deQ_scale_list[6], (int8_t*)attr_matmul_buf_, (int8_t*)(param_.ffn.intermediate_weight.kernel), - param_.cublaslt_handle, param_.stream, cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + param_.cublaslt_handle, param_.stream, cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); add_bias_act_COL32_int8IO_kernelLauncher((int8_t*)inter_matmul_buf_, (int8_t*)int_buf_, param_.ffn.intermediate_weight.bias, m, n, param_.stream, F1_aftergemm_amax_ptr+1, F1Bias_amax_ptr+3); @@ -774,7 +664,7 @@ class BertEncoderTransformer { cublasLtMM_withAlgo(int_buf_, 1, m, n, k, m*k, n*k, m*n, (int8_t*)inter_matmul_buf_, (int8_t*)(param_.ffn.output_weight.kernel), - param_.cublaslt_handle, param_.stream, cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + param_.cublaslt_handle, param_.stream, cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); if (layer_idx_ != layer_num_ - 1) { add_bias_input_layernorm_COL32_int32I_DataTypeO_kernelLauncher(param_.transformer_out, int_buf_, attr_matmul_buf_, @@ -791,11 +681,11 @@ class BertEncoderTransformer transposeMatrix_COL32ToColMajor_kernelLauncher(param_.transformer_out, transformer_out_tmp_DataType_, m, n, param_.stream); } } - else + else if (int8_mode_ == 2 || int8_mode_ == 3) { cublasLtMM_withAlgo_int8IO((int8_t*)int_buf_, 1, m, n, k, m*k, n*k, m*n, int8O_gemm_deQ_scale_list[7], (int8_t*)inter_matmul_buf_, (int8_t*)(param_.ffn.output_weight.kernel), - param_.cublaslt_handle, param_.stream, cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + param_.cublaslt_handle, param_.stream, cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); if (layer_idx_ != layer_num_ - 1) { add_bias_input_layernorm_COL32_int8IO_kernelLauncher((int8_t*)param_.transformer_out, (int8_t*)int_buf_, (int8_t*)attr_matmul_buf_, @@ -819,21 +709,20 @@ class BertEncoderTransformer #endif } else{ - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.attention_output_weight.kernel, AType_, n, - attr_out_buf_, BType_, k, - &beta, - attr_matmul_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); + cublasMM_cublasLtMM_wrapper(param_.cublaslt_handle, param_.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, &alpha, + param_.self_attention.attention_output_weight.kernel, AType_, n, + attr_out_buf_, BType_, k, + &beta, (DataType_ *)attr_matmul_buf_, 
CType_, n, + param_.stream, cublasAlgoMap_, sm_, cublas_workspace_); + add_bias_input_layernorm_kernelLauncher(attr_matmul_buf_, - param_.from_tensor, param_.self_attention.attention_output_weight.bias, + param_.from_tensor, + param_.self_attention.attention_output_weight.bias, param_.self_layernorm.gamma, - param_.self_layernorm.beta, m, n, param_.stream); - + param_.self_layernorm.beta, + m, n, param_.stream); + #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); @@ -841,18 +730,14 @@ class BertEncoderTransformer n *= 4; - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.ffn.intermediate_weight.kernel, AType_, n, - attr_matmul_buf_, BType_, k, - &beta, - inter_matmul_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[1]))); - - add_bias_act_kernelLauncher(inter_matmul_buf_, param_.ffn.intermediate_weight.bias, m, n, param_.stream); + cublasMM_cublasLtMM_wrapper(param_.cublaslt_handle, param_.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, &alpha, + param_.ffn.intermediate_weight.kernel, AType_, n, + attr_matmul_buf_, BType_, k, + &beta, (DataType_ *)inter_matmul_buf_, CType_, n, + param_.stream, cublasAlgoMap_, sm_, cublas_workspace_); + + add_bias_act_kernelLauncher(inter_matmul_buf_, param_.ffn.intermediate_weight.bias, m, n, ActivationType::GELU, param_.stream); #ifndef NDEBUG cudaDeviceSynchronize(); @@ -862,18 +747,15 @@ class BertEncoderTransformer n = k; k *= 4; - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.ffn.output_weight.kernel, AType_, n, - inter_matmul_buf_, BType_, k, - &beta, - param_.transformer_out, CType_, n, - computeType_, - static_cast(cublasAlgo_[2]))); + cublasMM_cublasLtMM_wrapper(param_.cublaslt_handle, param_.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, &alpha, + param_.ffn.output_weight.kernel, AType_, n, + inter_matmul_buf_, BType_, k, + &beta, (DataType_ *)(param_.transformer_out), CType_, n, + param_.stream, cublasAlgoMap_, sm_, cublas_workspace_); - add_bias_input_layernorm_kernelLauncher(param_.transformer_out, attr_matmul_buf_, + add_bias_input_layernorm_kernelLauncher(param_.transformer_out, + attr_matmul_buf_, param_.ffn.output_weight.bias, param_.ffn_layernorm.gamma, param_.ffn_layernorm.beta, @@ -891,16 +773,6 @@ class BertEncoderTransformer } } - void trt_initialize(DataType_ *from_tensor, DataType_ *to_tensor, DataType_ *attr_mask, DataType_ *out, cudaStream_t stream, cublasHandle_t cublas_handle) - { - param_.from_tensor = from_tensor; - param_.to_tensor = to_tensor; - param_.stream = stream; - param_.transformer_out = out; - param_.cublas_handle = cublas_handle; - attention_->trt_initialize(from_tensor, to_tensor, attr_mask, stream, param_.cublas_handle); - } - ~BertEncoderTransformer() { if (buf_ != NULL) @@ -918,3 +790,4 @@ class BertEncoderTransformer }; } // namespace fastertransformer + diff --git a/fastertransformer/common.h b/fastertransformer/common.h deleted file mode 100644 index bcc628775..000000000 --- a/fastertransformer/common.h +++ /dev/null @@ -1,485 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
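The FP32/FP16 branch of BertEncoderTransformer::forward above now routes its three GEMMs through cublasMM_cublasLtMM_wrapper, passing the profiled cublasAlgoMap_, the SM version and a cuBLAS workspace instead of a hard-coded cublasAlgo_[i]. The sketch below only illustrates the general lookup-then-fallback pattern behind such a wrapper; the helper name pick_algo_and_gemm, the map key format and the int-valued algo map are assumptions for illustration, not FasterTransformer's actual API.

#include <cstdio>
#include <map>
#include <string>
#include <cublas_v2.h>

// Illustrative sketch: look up a pre-profiled algo for this (batch, m, n, k),
// otherwise fall back to CUBLAS_GEMM_DEFAULT. Assumes the CUDA 10.x signature of
// cublasGemmEx (cudaDataType_t computeType); CUDA 11 takes a cublasComputeType_t.
static cublasStatus_t pick_algo_and_gemm(cublasHandle_t handle,
                                         int m, int n, int k,
                                         const void* alpha,
                                         const void* A, cudaDataType_t Atype, int lda,
                                         const void* B, cudaDataType_t Btype, int ldb,
                                         const void* beta,
                                         void* C, cudaDataType_t Ctype, int ldc,
                                         cudaDataType_t computeType,
                                         const std::map<std::string, int>& algoMap)
{
  char key[64];
  snprintf(key, sizeof(key), "1_%d_%d_%d", m, n, k);   // "batch_m_n_k"-style key
  auto it = algoMap.find(key);
  cublasGemmAlgo_t algo = (it != algoMap.end())
                            ? static_cast<cublasGemmAlgo_t>(it->second)
                            : CUBLAS_GEMM_DEFAULT;
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                      alpha, A, Atype, lda, B, Btype, ldb,
                      beta, C, Ctype, ldc, computeType, algo);
}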
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include "stdio.h" - -#define MAX_CONFIG_NUM 20 -#define GEMM_NUM 6 -#define COL32_ 32 -#define ACTIVATION_AMAX_NUM 80 -#define INT8O_GEMM_NUM 8 -#define GEMM_CONFIG "gemm_config.in" -#define IGEMM_CONFIG "igemm_config.in" - -#include "fastertransformer/gemm_test/encoder_gemm_func.h" -#include "fastertransformer/gemm_test/encoder_igemm_func.h" - -namespace fastertransformer -{ - -enum class OperationType -{ - FP32, - FP16 -}; -enum class AllocatorType -{ - CUDA, - TF, - TH -}; - -#define PRINT_FUNC_NAME_() \ - do \ - { \ - std::cout << "[FT][CALL] " << __FUNCTION__ << " " << std::endl; \ - } while (0) - -static const char *_cudaGetErrorEnum(cudaError_t error) -{ - return cudaGetErrorString(error); -} - -static inline __device__ int8_t float_to_int8_rn(float x) -{ - uint32_t dst; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" - : "=r"(dst) - : "f"(x)); - return reinterpret_cast(dst); -} - -static const char *_cudaGetErrorEnum(cublasStatus_t error) -{ - switch (error) - { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; - - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR"; - } - return ""; -} - -//for int8 cublasLtMM with algo -//ATransform should be m*n, CUBLASLT_ORDER_COL32 -//kernel should be n*k, CUBLASLT_ORDER_COL4_4R2_8C or CUBLASLT_ORDER_COL32_2R_4R4 -//res is m*n, CUBLASLT_ORDER_COL32 -template -void cublasLtMM_withAlgo(int *res, int batchCount, int m, int n, int k, - int64_t stridea, int64_t strideb, int64_t stridec, - const int8_t *ATransform, const T *kernel, cublasLtHandle_t cublasLt_handle, - cudaStream_t stream, std::map &cublasLtAlgoMap, - bool use_ORDER_COL32_2R_4R4, bool use_default_algo = false) -{ - cublasOperation_t opTranspose = CUBLAS_OP_T; -#ifdef CUDA11_MODE - cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; -#else - cudaDataType_t computeType = CUDA_R_32I; -#endif - cublasLtMatmulDesc_t matmulDesc; - cublasLtMatrixLayout_t AtransformDesc = NULL; - cublasLtMatrixLayout_t BtransformDesc = NULL; - cublasLtMatrixLayout_t CtransformDesc = NULL; - cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; - - cublasLtOrder_t order_matrixB; -#ifdef CUDA11_MODE - if (use_ORDER_COL32_2R_4R4) - order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; - else - order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; -#else - order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; -#endif - - - - int ldaTransform = 32 * m; - int ldbTransform; - if 
(use_ORDER_COL32_2R_4R4) - ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; - else - ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; - int ldcTransform = 32 * m; - - // create matmulDesc -#ifdef CUDA11_MODE - cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I); -#else - cublasLtMatmulDescCreate(&matmulDesc, computeType); -#endif - cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t)); - cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform); - cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); - cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, n, k, ldbTransform); - cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_matrixB, sizeof(order_matrixB)); - cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_32I, m, n, ldcTransform); - cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); - if (batchCount > 1) - { - cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, sizeof(stridea)); - cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, sizeof(strideb)); - cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, sizeof(stridec)); - } - - int alphaI = 1; - int betaI = 0; - - //get algo - cublasLtMatmulAlgo_t algo; - char mark[1000]; - sprintf(mark, "%d_%d_%d_%d", batchCount, m, n, k); - std::string markStr(mark); - int findAlgo = 0; - if ((!use_default_algo) && cublasLtAlgoMap.find(markStr) != cublasLtAlgoMap.end() && cublasLtAlgoMap[markStr].workspaceSize == 0) - { - //printf("find algo %s\n", markStr.c_str()); - findAlgo = 1; - - cublasLtMatmulAlgoInit(cublasLt_handle, computeType, CUDA_R_32I, CUDA_R_8I, CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, cublasLtAlgoMap[markStr].algoId, &algo); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(cublasLtAlgoMap[markStr].customOption), sizeof(cublasLtAlgoMap[markStr].customOption)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(cublasLtAlgoMap[markStr].tile), sizeof(cublasLtAlgoMap[markStr].tile)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(cublasLtAlgoMap[markStr].splitK_val), sizeof(cublasLtAlgoMap[markStr].splitK_val)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(cublasLtAlgoMap[markStr].swizzle), sizeof(cublasLtAlgoMap[markStr].swizzle)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(cublasLtAlgoMap[markStr].reductionScheme), sizeof(int)); -#ifdef CUDA11_MODE - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(cublasLtAlgoMap[markStr].stages), sizeof(cublasLtAlgoMap[markStr].stages)); -#endif - } - else - { - findAlgo = 1; - int algoId; - if (use_ORDER_COL32_2R_4R4) - { - algoId = 7; - } - else - { - algoId = 6; - } - int swizzle = 0; - int customOption = 0; - int tile = 20; - int splitK_val = 0; - int 
reductionScheme = 0; - cublasLtMatmulAlgoInit(cublasLt_handle, computeType, CUDA_R_32I, CUDA_R_8I, CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, &algo); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(customOption), sizeof(customOption)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(splitK_val), sizeof(splitK_val)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int)); -#ifdef CUDA11_MODE - int stages; - if (use_ORDER_COL32_2R_4R4) - stages = 15; - else - stages = 13; - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); -#endif - } - - cublasLtMatmul(cublasLt_handle, - matmulDesc, - &alphaI, - ATransform, - AtransformDesc, - kernel, - BtransformDesc, - &betaI, - res, - CtransformDesc, - res, - CtransformDesc, - (findAlgo == 1 ? (&algo) : NULL), NULL, 0, stream); - - cublasLtMatmulDescDestroy(matmulDesc); - cublasLtMatrixLayoutDestroy(AtransformDesc); - cublasLtMatrixLayoutDestroy(BtransformDesc); - cublasLtMatrixLayoutDestroy(CtransformDesc); -} - -//for int8 IO cublasLtMM with algo -//ATransform should be m*k CUBLASLT_ORDER_COL32 -//kernel should be n*k CUBLASLT_ORDER_COL4_4R2_8C -//res is m*n CUBLASLT_ORDER_COL32 -template -void cublasLtMM_withAlgo_int8IO(int8_t *res, int batchCount, int m, int n, int k, - int64_t stridea, int64_t strideb, int64_t stridec, - const float alpha, const int8_t *ATransform, const T *kernel, - cublasLtHandle_t cublasLt_handle, cudaStream_t stream, - std::map &cublasLtAlgoMap, - bool use_ORDER_COL32_2R_4R4, bool use_default_algo=false) -{ - cublasOperation_t opTranspose = CUBLAS_OP_T; - //int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE - //cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; - cudaDataType_t scaleType = CUDA_R_32F; -#ifdef CUDA11_MODE - cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; -#else - cudaDataType_t computeType = CUDA_R_32I; -#endif - cublasLtMatmulDesc_t matmulDesc; - cublasLtMatrixLayout_t AtransformDesc = NULL; - cublasLtMatrixLayout_t BtransformDesc = NULL; - cublasLtMatrixLayout_t CtransformDesc = NULL; - cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; - - cublasLtOrder_t order_matrixB; -#ifdef CUDA11_MODE - if (use_ORDER_COL32_2R_4R4) - order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; - else - order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; -#else - order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; -#endif - - - int ldaTransform = 32 * m; - - int ldbTransform; - if (use_ORDER_COL32_2R_4R4) - ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; - else - ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; - - - int ldcTransform = 32 * m; - - // create matmulDesc -#ifdef CUDA11_MODE - cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); -#else - cublasLtMatmulDescCreate(&matmulDesc, computeType); -#endif - cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t)); - cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scaleType, sizeof(scaleType)); - //cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, sizeof(cublasLtPointerMode_t)); - cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, 
k, ldaTransform); - cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); - cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, n, k, ldbTransform); - cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_matrixB, sizeof(order_matrixB)); - cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_8I, m, n, ldcTransform); - cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); - if (batchCount > 1) - { - cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, sizeof(stridea)); - cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, sizeof(strideb)); - cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, sizeof(stridec)); - } - //get algo - cublasLtMatmulAlgo_t algo; - char mark[1000]; - sprintf(mark, "%d_%d_%d_%d", batchCount, m, n, k); - std::string markStr(mark); - int findAlgo = 0; - if ((!use_default_algo) && cublasLtAlgoMap.find(markStr) != cublasLtAlgoMap.end() && cublasLtAlgoMap[markStr].workspaceSize == 0) - { - findAlgo = 1; - cublasLtMatmulAlgoInit(cublasLt_handle, computeType, CUDA_R_32F, CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, cublasLtAlgoMap[markStr].algoId, &algo); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(cublasLtAlgoMap[markStr].customOption), sizeof(cublasLtAlgoMap[markStr].customOption)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(cublasLtAlgoMap[markStr].tile), sizeof(cublasLtAlgoMap[markStr].tile)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(cublasLtAlgoMap[markStr].splitK_val), sizeof(cublasLtAlgoMap[markStr].splitK_val)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(cublasLtAlgoMap[markStr].swizzle), sizeof(cublasLtAlgoMap[markStr].swizzle)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(cublasLtAlgoMap[markStr].reductionScheme), sizeof(int)); -#ifdef CUDA11_MODE - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(cublasLtAlgoMap[markStr].stages), sizeof(cublasLtAlgoMap[markStr].stages)); -#endif - } - else - { - findAlgo = 1; - int algoId; - if (use_ORDER_COL32_2R_4R4) - { - algoId = 7; - } - else - { - algoId = 6; - } - int swizzle = 0; - int customOption = 0; - int tile = 20; - int splitK_val = 0; - int reductionScheme = 0; - cublasLtMatmulAlgoInit(cublasLt_handle, computeType, CUDA_R_32F, CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, algoId, &algo); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(customOption), sizeof(customOption)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(splitK_val), sizeof(splitK_val)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); - 
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int)); -#ifdef CUDA11_MODE - int stages; - if (use_ORDER_COL32_2R_4R4) - stages = 15; - else - stages = 13; - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); -#endif - } - - float beta = 0.0f; - cublasLtMatmul(cublasLt_handle, - matmulDesc, - &alpha, - ATransform, - AtransformDesc, - kernel, - BtransformDesc, - &beta, - res, - CtransformDesc, - res, - CtransformDesc, - (findAlgo == 1 ? (&algo) : NULL), NULL, 0, stream); - - cublasLtMatmulDescDestroy(matmulDesc); - cublasLtMatrixLayoutDestroy(AtransformDesc); - cublasLtMatrixLayoutDestroy(BtransformDesc); - cublasLtMatrixLayoutDestroy(CtransformDesc); -} - -template -void check(T result, char const *const func, const char *const file, int const line) -{ - if (result) - { - throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + - (_cudaGetErrorEnum(result)) + " " + file + - ":" + std::to_string(line) + " \n"); - } -} - -#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) - -template -void print_to_file(T *result, const int size, char *file) -{ - FILE *fd = fopen(file, "w"); - T *tmp = (T *)malloc(sizeof(T) * size); - check_cuda_error(cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost)); - for (int i = 0; i < size; ++i) - { - float val; - if (sizeof(T) == 2) - val = (T)__half2float(tmp[i]); - else - val = (T)tmp[i]; - fprintf(fd, "%f\n", val); - } - free(tmp); - fclose(fd); -} - -template -void print_to_screen(T *result, const int size) -{ - T *tmp = (T *)malloc(sizeof(T) * size); - check_cuda_error(cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost)); - for (int i = 0; i < size; ++i) - printf("%d, %f\n", i, (float)tmp[i]); - free(tmp); -} - -template -void check_max_val(const T *result, const int size) -{ - T *tmp = new T[size]; - cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost); - float max_val = -100000; - for (int i = 0; i < size; i++) - { - float val = (float)(tmp[i]); - if (val > max_val) - max_val = val; - } - delete tmp; - printf("[INFO][CUDA] addr %p max val: %f \n", result, max_val); -} - -inline int getSMVersion() -{ - int device{-1}; - check_cuda_error(cudaGetDevice(&device)); - cudaDeviceProp props; - check_cuda_error(cudaGetDeviceProperties(&props, device)); - return props.major * 10 + props.minor; -} - -template -void check_abs_mean_val(const T *result, const int size) -{ - T *tmp = new T[size]; - cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost); - float sum = 0.0f; - for (int i = 0; i < size; i++) - { - sum += abs((float)tmp[i]); - } - delete tmp; - printf("[INFO][CUDA] addr %p abs mean val: %f \n", result, sum / size); -} - -inline int div_up(int a, int n) -{ - return (a + n - 1) / n; -} - -} //namespace fastertransformer diff --git a/fastertransformer/cuda/CMakeLists.txt b/fastertransformer/cuda/CMakeLists.txt index 6aa9b77e9..6fdefcb89 100644 --- a/fastertransformer/cuda/CMakeLists.txt +++ b/fastertransformer/cuda/CMakeLists.txt @@ -19,6 +19,7 @@ set(encoder_kernel_files set(decoder_kernel_files open_decoder.cu + masked_multihead_attention.cu ) set(online_softmax_beamsearch_kernel_files @@ -38,6 +39,16 @@ set_property(TARGET cuda_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET cuda_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(cuda_kernels PUBLIC -lcublas -lcudart -lcurand) +add_library(attention_kernels STATIC 
attention_kernels.cu) +set_property(TARGET attention_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET attention_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(attention_kernels PUBLIC -lcublas -lcudart -lcurand) + +add_library(transformer_kernels STATIC transformer_kernels.cu) +set_property(TARGET transformer_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET transformer_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(transformer_kernels PUBLIC -lcublas -lcudart -lcurand) + add_library(cuda_int8_kernels STATIC cuda_int8_kernels.cu) set_property(TARGET cuda_int8_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET cuda_int8_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) @@ -46,12 +57,14 @@ target_link_libraries(cuda_int8_kernels PUBLIC -lcublas -lcudart -lcurand cuda_k add_library(encoder STATIC ${encoder_kernel_files}) set_property(TARGET encoder PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET encoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(encoder PUBLIC -lcublas -lcublasLt -lcudart -lcurand cuda_kernels cuda_int8_kernels trt_fused_multi_head_attention encoder_gemm_func encoder_igemm_func) +target_link_libraries(encoder PUBLIC -lcublas -lcublasLt -lcudart -lcurand cuda_kernels + cuda_int8_kernels attention_kernels transformer_kernels + trt_fused_multi_head_attention encoder_gemm_func encoder_igemm_func) add_library(decoder STATIC ${decoder_kernel_files}) set_property(TARGET decoder PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET decoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(decoder PUBLIC -lcublas -lcudart -lcurand) +target_link_libraries(decoder PUBLIC -lcublas -lcudart -lcurand cuda_kernels attention_kernels transformer_kernels nccl_utils nvtx_utils) add_library(online_softmax_beamsearch STATIC ${online_softmax_beamsearch_kernel_files}) set_property(TARGET online_softmax_beamsearch PROPERTY POSITION_INDEPENDENT_CODE ON) @@ -67,3 +80,11 @@ add_library(decoding STATIC ${decoding_kernel_files}) set_property(TARGET decoding PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET decoding PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(decoding PUBLIC -lcublas -lcudart -lcurand topk online_softmax_beamsearch cuda_kernels) + +if (BUILD_GPT) + target_compile_features(cuda_kernels PRIVATE cxx_std_14) + target_compile_features(attention_kernels PRIVATE cxx_std_14) + target_compile_features(decoder PRIVATE cxx_std_14) + target_compile_features(online_softmax_beamsearch PRIVATE cxx_std_14) + target_compile_features(topk PRIVATE cxx_std_14) +endif() diff --git a/fastertransformer/cuda/attention_kernels.cu b/fastertransformer/cuda/attention_kernels.cu new file mode 100644 index 000000000..0bd550ab6 --- /dev/null +++ b/fastertransformer/cuda/attention_kernels.cu @@ -0,0 +1,845 @@ +/* +* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
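The attention_kernels.cu file added below builds its softmax and bias kernels on two reduction helpers: warpReduceSum/warpReduceMax fold the 32 values of a warp together with an XOR-shuffle butterfly, and blockReduceSum/blockReduceMax stage one partial per warp in shared memory and reduce those partials with the first warp. A self-contained toy of the butterfly sum, runnable on its own (the kernel name butterfly_sum_demo is illustrative and not part of this patch):

#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff

// Butterfly reduction: after the 5 XOR-shuffle steps every lane holds the warp-wide sum.
__global__ void butterfly_sum_demo(const float* in, float* out)
{
  float v = in[threadIdx.x];                       // one value per lane, blockDim.x == 32
  for (int mask = 16; mask > 0; mask >>= 1)
    v += __shfl_xor_sync(FULL_MASK, v, mask, 32);
  if (threadIdx.x == 0) *out = v;                  // lane 0 writes the total
}

int main()
{
  float h_in[32], h_out = 0.f;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.f;      // expected sum: 32
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  butterfly_sum_demo<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f\n", h_out);                // prints 32.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}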
+*/
+
+#include "fastertransformer/cuda/attention_kernels.cuh"
+
+namespace fastertransformer
+{
+
+#define FINAL_MASK 0xffffffff
+
+template <typename T>
+__inline__ __device__
+T warpReduceSum(T val)
+{
+  #pragma unroll
+  for(int mask = 16; mask > 0; mask >>= 1)
+    val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
+  return val;
+}
+
+/* Calculate the sum of all elements in a block */
+template <typename T>
+__inline__ __device__
+T blockReduceSum(T val)
+{
+  static __shared__ T shared[32];
+  int lane = threadIdx.x & 0x1f;
+  int wid = threadIdx.x >> 5;
+
+  val = warpReduceSum(val);
+
+  if(lane == 0)
+    shared[wid] = val;
+
+  __syncthreads();
+
+  val = (threadIdx.x < (blockDim.x >> 5 )) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum(val);
+
+  return val;
+}
+
+template <typename T>
+__inline__ __device__
+T warpReduceMax(T val)
+{
+  #pragma unroll
+  for(int mask = 16; mask > 0; mask >>= 1)
+    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
+  return val;
+}
+
+/* Calculate the maximum of all elements in a block */
+template <typename T>
+__inline__ __device__
+T blockReduceMax(T val)
+{
+  static __shared__ T shared[32];
+  int lane = threadIdx.x & 0x1f; // in-warp idx
+  int wid = threadIdx.x >> 5;    // warp idx
+
+  val = warpReduceMax(val); // get max in each warp
+
+  if(lane == 0) // record in-warp max by warp idx
+    shared[wid] = val;
+
+  __syncthreads();
+
+  val = (threadIdx.x < (blockDim.x >> 5 )) ? shared[lane] : -1e20f;
+  val = warpReduceMax(val);
+
+  return val;
+}
+
+__inline__ __device__
+int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4)
+{
+  return id1 * (dim_2 * dim_3 * dim_4) + id3 * (dim_2 * dim_4) + id2 * dim_4 + id4;
+}
+
+template <typename T>
+__global__
+void add_QKV_bias(T* Q, const T* bias_Q, T* K, const T* bias_K, T* V, const T* bias_V, T* q_buf_, T* k_buf_, T* v_buf_,
+  const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int word_per_block)
+{
+
+  T* data_ptr;
+  T* buf_ptr;
+  const T* bias_ptr;
+
+  int m = batch_size * seq_len;
+  int n = head_num * size_per_head;
+
+  int qkv_id = blockIdx.x * word_per_block / m;
+  int row_offset = (blockIdx.x * word_per_block % m) * n;
+
+  if(qkv_id == 0)
+  {
+    data_ptr = Q + row_offset;
+    buf_ptr = q_buf_;
+    bias_ptr = bias_Q;
+  }
+  else if(qkv_id == 1)
+  {
+    data_ptr = K + row_offset;
+    buf_ptr = k_buf_;
+    bias_ptr = bias_K;
+  }
+  else
+  {
+    data_ptr = V + row_offset;
+    buf_ptr = v_buf_;
+    bias_ptr = bias_V;
+  }
+
+  int batch_id = (blockIdx.x * word_per_block % m) / seq_len;
+  int head_id = threadIdx.x / size_per_head;
+  int id_in_head = threadIdx.x % size_per_head;
+  int word_start_id = (blockIdx.x * word_per_block) % seq_len;
+
+  T bias = __ldg(&bias_ptr[threadIdx.x]);
+
+  for(int i = word_start_id; i < word_start_id + word_per_block; ++i)
+  {
+    T tmp = data_ptr[threadIdx.x] + bias;
+
+    int target_id = batch_id * (seq_len * head_num * size_per_head) + head_id * seq_len * size_per_head +
+      i * size_per_head + id_in_head;
+
+    buf_ptr[target_id] = tmp;
+    data_ptr += n;
+  }
+}
+
+template <>
+__global__
+void add_QKV_bias(half* Q, const half* bias_Q, half* K, const half* bias_K, half* V, const half* bias_V,
+  half* q_buf_, half* k_buf_, half* v_buf_,
+  const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int word_per_block)
+{
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int batch_id = tid / (head_num * seq_len * size_per_head);
+  int seq_id = (tid % (head_num * seq_len * size_per_head)) / (head_num * size_per_head);
+  int head_id = (tid % (head_num *
size_per_head)) / size_per_head; + int id = tid % size_per_head; + int target_id = target_index(batch_id, seq_id, head_id, id, batch_size, seq_len, head_num, size_per_head); + + int bias_id = threadIdx.x; + + half2* src_ptr = (half2*)Q; + half2* dst_ptr = (half2*)q_buf_; + const half2* bias_ptr = (const half2*)bias_Q; + + dst_ptr[target_id] = __hadd2(src_ptr[tid], __ldg(&bias_ptr[bias_id])); + + src_ptr = (half2*)K; + dst_ptr = (half2*)k_buf_; + bias_ptr = (const half2*)bias_K; + dst_ptr[target_id] = __hadd2(src_ptr[tid], __ldg(&bias_ptr[bias_id])); + + src_ptr = (half2*)V; + dst_ptr = (half2*)v_buf_; + bias_ptr = (const half2*)bias_V; + dst_ptr[target_id] = __hadd2(src_ptr[tid], __ldg(&bias_ptr[bias_id])); +} + + +template +__global__ +void add_QKV_bias_generalized(const T* __restrict Q, + const T* __restrict bias_Q, + const T* __restrict K, + const T* __restrict bias_K, + const T* __restrict V, + const T* __restrict bias_V, + T* q_buf_, T* k_buf_, T* v_buf_, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const int word_per_block) +{ + + const T* data_ptr; + T* buf_ptr; + T bias; + + int n = head_num * size_per_head; + const int blocks_per_word = n / blockDim.x; + const int blocks_per_buffer = gridDim.x / 3; + const int qkv_id = blockIdx.x / blocks_per_buffer; + const int block_id_in_buffer = blockIdx.x % blocks_per_buffer; + const int offset = block_id_in_buffer * blockDim.x + threadIdx.x; + const int bias_id = offset % n; + + if(qkv_id == 0) + { + data_ptr = Q + offset; + buf_ptr = q_buf_; + bias = __ldg(&bias_Q[bias_id]); + } + else if(qkv_id == 1) + { + data_ptr = K + offset; + buf_ptr = k_buf_; + bias = __ldg(&bias_K[bias_id]); + } + else + { + data_ptr = V + offset; + buf_ptr = v_buf_; + bias = __ldg(&bias_V[bias_id]); + } + + const int head_id = bias_id / size_per_head; + const int size_id = bias_id % size_per_head; + + for(int i = 0; i < word_per_block; i++) + { + const int block_lane = i * blocks_per_buffer; + const int batch_id = (block_id_in_buffer + block_lane) / seq_len / blocks_per_word; + const int word_id = ((block_id_in_buffer + block_lane) / blocks_per_word) % seq_len; + + int target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head + + word_id * size_per_head + size_id; + buf_ptr[target_id] = __ldg(&data_ptr[block_lane * blockDim.x]) + bias; + } +} + +template +void add_QKV_bias_transpose_kernelLauncher( + T* q_buf, + T* k_buf, + T* v_buf, + T* Q, + const T* bias_Q, + T* K, + const T* bias_K, + T* V, + const T* bias_V, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream) +{ + const int k = head_num * size_per_head; + dim3 grid, block; + if(k <= 1024) + { + if(sizeof(T) == 4) + { + const int m = batch_size * seq_len; + const int word_per_block = 1; + assert(k <= 1024); + assert(m / word_per_block * 3 <= 65536); + + dim3 grid(m / word_per_block * 3); + dim3 block(k); + add_QKV_bias<<>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf, k_buf, v_buf, + batch_size, seq_len, head_num, size_per_head, word_per_block); + } + else + { + const int word_per_block = 1; + grid.x = batch_size * seq_len / word_per_block; + block.x = head_num * size_per_head * word_per_block / 2; + + assert(block.x <= 1024); + + add_QKV_bias<<>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf, k_buf, + v_buf, batch_size, seq_len, head_num, size_per_head / 2, word_per_block); + } + } + else + { + // k > 1024, so split into many block + if(sizeof(T) == 4) + { + const int m = 
batch_size * seq_len; + const int word_per_block = 4; + dim3 block; + if(k % 512 == 0) + block.x = 512; + else if(k % 384 == 0) + block.x = 384; + else if(k % 256 == 0) + block.x = 256; + else if(k % 128 == 0) + block.x = 128; + else + printf("[ERROR] no supported k %d \n", k); + assert(k % block.x == 0); + dim3 grid(m * k / block.x / word_per_block * 3); + assert(grid.x <= 65536 && grid.x > 0); + add_QKV_bias_generalized<<>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf, k_buf, v_buf, + batch_size, seq_len, head_num, size_per_head, word_per_block); + + } + else + { + const int m = batch_size * seq_len; + const int word_per_block = 4; + const int half_k = k / 2; + dim3 block; + if(half_k % 512 == 0) + block.x = 512; + else if(half_k % 384 == 0) + block.x = 384; + else if(half_k % 256 == 0) + block.x = 256; + else if(half_k % 128 == 0) + block.x = 128; + else if(half_k % 64 == 0) + block.x = 64; + else + printf("[ERROR] no supported half_k %d \n", half_k); + assert(half_k % block.x == 0); + dim3 grid(m * half_k / block.x / word_per_block * 3); + assert(grid.x <= 65536 && grid.x > 0); + add_QKV_bias_generalized<<>>((const half2*)Q, (const half2*)bias_Q, + (const half2*)K, (const half2*)bias_K, + (const half2*)V, (const half2*)bias_V, + (half2*)q_buf, (half2*)k_buf, (half2*)v_buf, + batch_size, seq_len, head_num, + size_per_head / 2, word_per_block); + } + } +} + +template +__global__ +void add_fusedQKV_bias_transpose_kernel( + T* q_buf, + T* k_buf, + T* v_buf, + const T* __restrict QKV, + const T* __restrict qkv_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head) +{ + // QKV: [m, 3, n] + // qkv_bias: [3, n] + // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] + + T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; + const int n = head_num * size_per_head; + for(int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 3 * n; index += gridDim.x * blockDim.x) + { + int bias_id = index % (3 * n); + T val = __ldg(&QKV[index]) + __ldg(&qkv_bias[bias_id]); + + int tmp_index = index; + const int target_batch_id = tmp_index / (seq_len * 3 * n); + tmp_index -= target_batch_id * seq_len * 3 * n; + const int seq_id = tmp_index / (3 * n); + tmp_index -= seq_id * 3 * n; + const int qkv_id = tmp_index / n; + tmp_index -= qkv_id * n; + const int head_id = tmp_index / size_per_head; + const int size_id = tmp_index - head_id * size_per_head; + + qkv_ptr[qkv_id][ + target_batch_id * head_num * seq_len * size_per_head + + head_id * seq_len * size_per_head + + seq_id * size_per_head + + size_id] = val; + } +} + +template +void add_fusedQKV_bias_transpose_kernelLauncher( + T* q_buf, + T* k_buf, + T* v_buf, + T* QKV, + const T* qkv_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream) +{ + const int m = batch_size * seq_len; + const int n = head_num * size_per_head; + dim3 block(384); + dim3 grid((int)(ceil(1.0 * m * n / 384))); + add_fusedQKV_bias_transpose_kernel<<>>( + q_buf, k_buf, v_buf, QKV, qkv_bias, + batch_size, seq_len, head_num, size_per_head); +} + +template +__global__ +void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, + const T scalar) +{ + int batch_id = blockIdx.x / head_num; + int qk_offset = blockIdx.x * seq_len * seq_len; + int mask_offset = batch_id * seq_len * seq_len; + + __shared__ float s_sum, s_max; + + for(int i = 0; i < seq_len; ++i) + { + float qk = threadIdx.x < seq_len ? 
(float)qk_buf_[threadIdx.x + qk_offset] : 0.0f; + float mask_val = threadIdx.x < seq_len ? (float)attr_mask[threadIdx.x + mask_offset] : 0.0f; + + mask_val = (1.0f - mask_val) * -10000.0f; + + float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scalar + mask_val): -1e20f; + + float max_val = blockReduceMax(tmp); + + if(threadIdx.x == 0) + s_max = max_val; + __syncthreads(); + + qk = threadIdx.x < seq_len ? __expf(tmp - s_max) : 0.0f; + + float sum_val = blockReduceSum(qk); + + if(threadIdx.x == 0) + { + s_sum = sum_val + 1e-6f; + } + __syncthreads(); + + if(threadIdx.x < seq_len) + qk_buf_[threadIdx.x + qk_offset] = (T)(qk / s_sum); + + qk_offset += seq_len; + mask_offset += seq_len; + } +} + + +template +__global__ +void softmax_kernel_v2(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, + const int seq_len, const float scalar) +{ + int batch_id = blockIdx.x / head_num / seq_len; + int seq_id = blockIdx.x % seq_len; + int qk_offset = blockIdx.x * seq_len; + int mask_offset = batch_id * seq_len * seq_len + seq_id * seq_len; + + __shared__ float s_sum, s_max; + + float qk = threadIdx.x < seq_len ? (float)qk_buf_[threadIdx.x + qk_offset] : 0.0f; + float mask_val = threadIdx.x < seq_len ? (float)attr_mask[threadIdx.x + mask_offset] : 0.0f; + + mask_val = (1.0f - mask_val) * -10000.0f; + + float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scalar + mask_val) : -1e20f; + float max_val = blockReduceMax(tmp); + if(threadIdx.x == 0) + s_max = max_val; + __syncthreads(); + + float qk_tmp = threadIdx.x < seq_len ? __expf((float)(tmp - s_max)) : 0.0f; + float sum_val = blockReduceSum(qk_tmp); + + if(threadIdx.x == 0) + { + s_sum = sum_val + 1e-6f; + } + __syncthreads(); + + if(threadIdx.x < seq_len) + qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum); +} + +//grid = (seq_len/word_per_thread, batch_size, head_num) +//block.x = max(32, (seq_len + 31)/32*32) +template +__global__ +void softmax_kernel_v3(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) +{ + + bool qual = threadIdx.x < seq_len; + for (int seq_id = blockIdx.x ; seq_id < seq_len ; seq_id += gridDim.x){ + float tmp = -1e20f; + int qk_offset; + __shared__ float s_mean, s_max; + if (qual){ + qk_offset = ((blockIdx.y*head_num + blockIdx.z)*seq_len + seq_id) *seq_len + threadIdx.x; + int mask_offset = (blockIdx.y * seq_len + seq_id) * seq_len + threadIdx.x; + + float qk = static_cast(qk_buf_[qk_offset]); + float mask_val = static_cast(__ldg(&attr_mask[mask_offset])); + + mask_val = (1.0f - mask_val) * -10000.0f; + + tmp = qk * static_cast(scalar) + mask_val; + } + + float max_val = blockReduceMax(tmp); + if (threadIdx.x == 0){ + s_max = max_val; + } + __syncthreads(); + + float qk_tmp = qual ? 
__expf(tmp - s_max) : 0.0f; + float sum_val = blockReduceSum(qk_tmp); + if (threadIdx.x == 0){ + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + if(qual) + qk_buf_[qk_offset] = (T)(qk_tmp * s_mean); + } +} + + +//grid = (seq_len/word_per_thread, batch_size, head_num) +//block.x = max(32, (seq_len/2 + 31)/32*32) +//seq_len % 2 == 0 +template <> +__global__ +void softmax_kernel_v3(half* qk_buf_, const half* attr_mask, + const int batch_size, const int head_num, + const int seq_len, const half scalar) +{ + int threadIdx2 = threadIdx.x << 1; + bool qual = threadIdx2 < seq_len; + half2* qk_buf_half2Ptr = (half2*) qk_buf_; + const half2* attr_mask_half2Ptr = (const half2*) attr_mask; + __shared__ float s_mean, s_max; + for (int seq_id = blockIdx.x ; seq_id < seq_len ; seq_id += gridDim.x){ + int qk_offset; + half2 tmp = __float2half2_rn(0.0f); + + float max_val = -1e20f; + half2 qk; + if (qual){ + qk_offset = ((((blockIdx.y*head_num + blockIdx.z)*seq_len + seq_id) *seq_len) >> 1) + threadIdx.x; + int mask_offset = (((blockIdx.y * seq_len + seq_id) * seq_len) >> 1) + threadIdx.x; + + qk = qk_buf_half2Ptr[qk_offset]; + half2 mask_val = __ldg(&attr_mask_half2Ptr[mask_offset]); + half2 mask_val_tmp = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val), __float2half2_rn(-10000.0f)); + tmp = __hadd2(__hmul2(__half2half2(scalar), qk), mask_val_tmp); + max_val = fmax((float)tmp.x, (float)tmp.y); + } + + max_val = blockDim.x <= 32 ? warpReduceMax(max_val) : blockReduceMax(max_val); + + if (threadIdx.x == 0){ + s_max = max_val; + } + __syncthreads(); + + if (qual){ + tmp = h2exp(__hsub2(tmp, __float2half2_rn(s_max))); + } + float sum_val = blockDim.x <= 32 ? warpReduceSum((float)(tmp.x + tmp.y)) : blockReduceSum((float)(tmp.x + tmp.y)); + + if (threadIdx.x == 0){ + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + if(qual){ + qk = __hmul2(tmp, __float2half2_rn(s_mean)); + qk_buf_half2Ptr[qk_offset] = qk; + } + } +} + +template +__global__ +void softmax_kernel_v3_LE32(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) +{ + bool qual = threadIdx.x < seq_len; + for (int seq_id = blockIdx.x ; seq_id < seq_len ; seq_id += gridDim.x){ + int qk_offset; + __shared__ float s_mean, s_max; + float tmp = -1e20f; + if (qual){ + qk_offset = ((blockIdx.y*head_num + blockIdx.z)*seq_len + seq_id) *seq_len + threadIdx.x; + int mask_offset = (blockIdx.y * seq_len + seq_id) * seq_len + threadIdx.x; + + float qk = static_cast(qk_buf_[qk_offset]); + float mask_val = static_cast(__ldg(&attr_mask[mask_offset])); + + mask_val = (1.0f - mask_val) * -10000.0f; + + tmp = static_cast(qk) * static_cast(scalar) + mask_val; + } + float max_val = warpReduceMax(tmp); + + if (threadIdx.x == 0){ + s_max = max_val; + } + __syncthreads(); + + tmp = qual ? 
__expf(tmp - s_max) : 0.0f; + float sum_val = warpReduceSum(tmp); + + if (threadIdx.x == 0){ + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + if(qual) + qk_buf_[qk_offset] = (T)(tmp * s_mean); + } +} + +template +void attn_softmax_kernelLauncher( + T* buffer, + const T* attr_mask, + const int batch_size, + const int seq_len, + const int head_num, + const T scalar, + cudaStream_t stream) +{ + dim3 grid, block; + //deal with odd seq_len + if (seq_len % 2 != 0){ + if(seq_len <= 32) + block.x = 32; + else if(seq_len > 32 && seq_len <= 64) + block.x = 64; + else if(seq_len > 64 && seq_len <= 128) + block.x = 128; + else if(seq_len > 128 && seq_len <= 256) + block.x = 256; + else if(seq_len > 256 && seq_len <= 512) + block.x = 512; + else + block.x = 1024; + + if(batch_size * head_num <= 120) + { + grid.x = batch_size * head_num * seq_len; + softmax_kernel_v2<<>>(buffer, attr_mask, batch_size, head_num, seq_len, scalar); + } + else + { + grid.x = batch_size * head_num; + softmax_kernel<<>>(buffer, attr_mask, batch_size, head_num, seq_len, scalar); + } + } + //deal with even seq_len + else{ + grid.x = seq_len; + if (batch_size * head_num > 360) + grid.x = ceil(float(seq_len)/32.0f); + grid.y = batch_size; + grid.z = head_num; + if (seq_len <= 32){ + block.x = 32; + softmax_kernel_v3_LE32<<>>(buffer, attr_mask, batch_size, head_num, seq_len, scalar); + } + else{ + if (sizeof(T) == 2){ + block.x = (seq_len/2 + 31)/32*32; + softmax_kernel_v3<<>>(buffer, attr_mask, batch_size, head_num, seq_len, scalar); + } + else{ + block.x = (seq_len + 31)/32*32; + softmax_kernel_v3<<>>(buffer, attr_mask, batch_size, head_num, seq_len, scalar); + } + } + grid.x = grid.y = grid.z = 1; + } +} + +template +__global__ +void transpose(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head) +{ + int batch_id = blockIdx.x / (head_num * seq_len); + int seq_id = blockIdx.x % seq_len; + int head_id = (blockIdx.x % (head_num * seq_len))/ seq_len; + dst[batch_id * (head_num * seq_len * size_per_head) + seq_id * head_num * size_per_head + + head_id * size_per_head + threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x]; +} + +template<> + __global__ +void transpose(half* src, half* dst, + const int batch_size, const int seq_len, const int head_num, const int size_per_head) +{ + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + int batch_id = tid / (head_num * seq_len * size_per_head); + int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head); + int seq_id = (tid % (seq_len * size_per_head)) / size_per_head; + int id = tid % size_per_head; + + int target_id = target_index(batch_id, head_id, seq_id, id, batch_size, head_num, seq_len, size_per_head); + half2* src_ptr = (half2*)src; + half2* dst_ptr = (half2*)dst; + + dst_ptr[target_id] = src_ptr[tid]; +} + +template +void transpose_kernelLauncher( + T* dst, + T* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream) +{ + dim3 grid, block; + if(sizeof(T) == 2) + { + const int seq_per_block = 4; + grid.x = batch_size * head_num * seq_len / seq_per_block; + block.x = seq_per_block * size_per_head / 2; + + assert(grid.x * seq_per_block == batch_size * head_num * seq_len); + + transpose<<>>(src, dst, + batch_size, seq_len, head_num, size_per_head / 2); + } + else + { + const int seq_per_block = 1; + grid.x = batch_size * head_num * seq_len / seq_per_block; + block.x = seq_per_block * 
size_per_head; + transpose<<>>(src, dst, + batch_size, seq_len, head_num, size_per_head); + } +} + +template void add_QKV_bias_transpose_kernelLauncher( + float* q_buf, + float* k_buf, + float* v_buf, + float* Q, + const float* bias_Q, + float* K, + const float* bias_K, + float* V, + const float* bias_V, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream); + +template void add_QKV_bias_transpose_kernelLauncher( + half* q_buf, + half* k_buf, + half* v_buf, + half* Q, + const half* bias_Q, + half* K, + const half* bias_K, + half* V, + const half* bias_V, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream); + +template void add_fusedQKV_bias_transpose_kernelLauncher( + float* q_buf, + float* k_buf, + float* v_buf, + float* QKV, + const float* qkv_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream); + +template void add_fusedQKV_bias_transpose_kernelLauncher( + half* q_buf, + half* k_buf, + half* v_buf, + half* QKV, + const half* qkv_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream); + +template void attn_softmax_kernelLauncher( + float* buffer, + const float* attr_mask, + const int batch_size, + const int seq_len, + const int head_num, + const float scalar, + cudaStream_t stream); + +template void attn_softmax_kernelLauncher( + half* buffer, + const half* attr_mask, + const int batch_size, + const int seq_len, + const int head_num, + const half scalar, + cudaStream_t stream); + +template void transpose_kernelLauncher( + float* dst, + float* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream); + +template void transpose_kernelLauncher( + half* dst, + half* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream); + +} // namespace fastertransformer \ No newline at end of file diff --git a/fastertransformer/cuda/attention_kernels.cuh b/fastertransformer/cuda/attention_kernels.cuh new file mode 100644 index 000000000..7fbe4b229 --- /dev/null +++ b/fastertransformer/cuda/attention_kernels.cuh @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
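The add_fusedQKV_bias_transpose kernels instantiated just above treat QKV as one [m, 3, n] buffer (m = batch_size * seq_len, n = head_num * size_per_head), add the [3, n] bias, and scatter the result into three [batch_size, head_num, seq_len, size_per_head] buffers. A host-side reference of the same index decomposition can serve as a test oracle for the kernel; the function below is an illustrative sketch (fused_qkv_reference is not part of the patch), assuming q, k and v are pre-sized to batch_size * head_num * seq_len * size_per_head elements.

#include <cstddef>
#include <vector>

// Reference permutation for QKV[m, 3, n] + bias[3, n] -> {q, k, v}[batch][head][seq][size].
static void fused_qkv_reference(const std::vector<float>& QKV, const std::vector<float>& bias,
                                std::vector<float>& q, std::vector<float>& k, std::vector<float>& v,
                                int batch_size, int seq_len, int head_num, int size_per_head)
{
  const int n = head_num * size_per_head;
  std::vector<float>* out[3] = {&q, &k, &v};
  for (int b = 0; b < batch_size; ++b)
    for (int s = 0; s < seq_len; ++s)
      for (int which = 0; which < 3; ++which)          // 0: Q, 1: K, 2: V
        for (int h = 0; h < head_num; ++h)
          for (int d = 0; d < size_per_head; ++d)
          {
            // source: row (b * seq_len + s), slot `which`, column (h * size_per_head + d)
            size_t src = (((size_t)b * seq_len + s) * 3 + which) * n + h * size_per_head + d;
            // destination: [b][h][s][d] in the transposed per-head layout
            size_t dst = (((size_t)b * head_num + h) * seq_len + s) * size_per_head + d;
            (*out[which])[dst] = QKV[src] + bias[which * n + h * size_per_head + d];
          }
}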
+ */ + +#pragma once +#include +#include +#include "fastertransformer/utils/arguments.h" +#include + +namespace fastertransformer +{ + +template +void add_QKV_bias_transpose_kernelLauncher( + T* q_buf, + T* k_buf, + T* v_buf, + T* Q, + const T* bias_Q, + T* K, + const T* bias_K, + T* V, + const T* bias_V, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream); + +template +void add_fusedQKV_bias_transpose_kernelLauncher( + T* q_buf, + T* k_buf, + T* v_buf, + T* QKV, + const T* qkv_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream); + +template +void attn_softmax_kernelLauncher( + T* buffer, + const T* attr_mask, + const int batch_size, + const int seq_len, + const int head_num, + const T scalar, + cudaStream_t stream); + +template +void transpose_kernelLauncher( + T* dst, + T* src, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + cudaStream_t stream); + +} // namespace fastertransformer \ No newline at end of file diff --git a/fastertransformer/cuda/cuda_int8_kernels.cu b/fastertransformer/cuda/cuda_int8_kernels.cu index 835f114b7..157761317 100644 --- a/fastertransformer/cuda/cuda_int8_kernels.cu +++ b/fastertransformer/cuda/cuda_int8_kernels.cu @@ -14,8 +14,7 @@ * limitations under the License. */ -#include "fastertransformer/common.h" - +#include "fastertransformer/utils/common.h" #include "cuda_kernels.h" #include "cuda_int8_kernels.h" #include @@ -23,6 +22,7 @@ #include #include #include + namespace fastertransformer{ template @@ -47,7 +47,6 @@ half2 gelu(half2 val) } - template __inline__ __device__ T warpReduceSum(T val) @@ -166,6 +165,7 @@ template void transposeMatrix_COL32ToColMajor_kernelLauncher(half *dst, co template void transposeMatrix_COL32ToColMajor_kernelLauncher(int8_t* dst, const int8_t* src, const int m, const int n, cudaStream_t stream); + //transpose matrix & transfrom col-major to COL32 & quantize //input matrix is (m, n) col-major //output matrix is (n, m) COL32, using char4 to write out @@ -263,7 +263,7 @@ void transposeMatrix_colMajorToCOL32_kernel(half2*dst, const half2* src, const i } } -//transpose matrix & transfrom col-major to COL32 & quantize +//transpose matrix & transfrom col-major to COL32 //input matrix is (m, n) col-major //output matrix is (n, m) COL32, using char4 to write out //m should be a mutiple of 32 @@ -283,6 +283,41 @@ template void transposeMatrix_colMajorToCOL32_kernelLauncher(float* dst, template void transposeMatrix_colMajorToCOL32_kernelLauncher(half *dst, const half* src, const int m, const int n, cudaStream_t stream); +//transfrom row-major to COL32 +//input matrix is (m, n) row-major +//output matrix is (m, n) COL32 +//n should be a mutiple of 32 +//grid((n+31)/32, (m+31)/32) +//block(8, 32) +__global__ +void rowMajorToCOL32_kernel(char4*dst, const char4* src, const int m, const int n) +{ + + int n_id = (blockIdx.x*blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y*blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) + { + + // COL32_col = n_id >> 5 ; COL32_row = (m_id << 5) + (n_id & 31); + // COL32_idx = (COL32_col << 5) * m + COL32_row = (n_id & 0xffffffe0)*m + (m_id << 5) + (n_id & 31) + dst[((n_id & 0xffffffe0)*m + (m_id << 5) + (n_id & 31)) >> 2] = __ldg(src+((m_id*n+n_id) >> 2)); + } +} + +//transfrom row-major to COL32 +//input matrix is (m, n) row-major +//output matrix is (m, n) COL32 +//n should be a mutiple 
of 32 +//grid((n+31)/32, (m+31)/32) +//block(8, 32) +void rowMajorToCOL32_kernelLauncher(int8_t* dst, const int8_t* src, const int m, const int n, cudaStream_t stream) +{ + assert(n%32 == 0); + rowMajorToCOL32_kernel<<>>((char4*)dst, (const char4*)src, m, n); +} + //add bias to matrix of m * n, CUBLASLT_ORDER_COL32 //grid, thread = (m), (n/4) @@ -939,7 +974,6 @@ void transpose_COL32_kernel(int8_t* dst, const int8_t* src, const int batch_size const int size_per_head, const float *bmm2_deQFactor, const float* out_scale_ptr, const int batch_size_x_seq_len, const int seq_len_x_size_per_head) { - const float scale = __ldg(bmm2_deQFactor) * __ldg(out_scale_ptr); int threadIdx4 = threadIdx.x << 2; int batch_id = blockIdx.y; int seq_id = blockIdx.x; @@ -964,15 +998,9 @@ void transpose_COL32_kernel(int8_t* dst, const int8_t* src, const int batch_size COL32_col = mk_col >> 5; int inIdx = ((batch_id*head_num + head_id)*seq_len_x_size_per_head + (COL32_col << 5 )*seq_len + COL32_row) >> 2; - char4 tmp; const char4* src_ptr4 = (const char4*)src; - tmp = __ldg(src_ptr4 + inIdx); - tmp.x = float_to_int8_rn(tmp.x*scale); - tmp.y = float_to_int8_rn(tmp.y*scale); - tmp.z = float_to_int8_rn(tmp.z*scale); - tmp.w = float_to_int8_rn(tmp.w*scale); char4 *dst_ptr4 = (char4 *)dst; - dst_ptr4[outIdx] = tmp; + dst_ptr4[outIdx] = __ldg(src_ptr4 + inIdx); } void transpose_COL32_kernelLauncher(int8_t* dst, const int8_t* src, const int batch_size, const int seq_len, const int head_num, @@ -1046,7 +1074,6 @@ __global__ void transpose_COL32_rebuild_padding_kernel(int8_t* dst, const int8_t* src, const int* sequence_id_map, const int valid_word_num, const int batch_size, const int seq_len, const int head_num, const int size_per_head, const float *bmm2_deQFactor, const float* out_scale_ptr, const int seq_len_x_size_per_head) { - const float scale = __ldg(bmm2_deQFactor) * __ldg(out_scale_ptr); int threadIdx4 = threadIdx.x << 2; int batch_id = blockIdx.y; int seq_id = blockIdx.x; @@ -1072,17 +1099,11 @@ void transpose_COL32_rebuild_padding_kernel(int8_t* dst, const int8_t* src, cons COL32_col = mk_col >> 5; int inIdx = ((batch_id*head_num + head_id)*seq_len_x_size_per_head + (COL32_col << 5 )*seq_len + COL32_row) >> 2; - char4 tmp; const char4* src_ptr4 = (const char4*)src; - tmp = __ldg(src_ptr4 + inIdx); - tmp.x = float_to_int8_rn(tmp.x*scale); - tmp.y = float_to_int8_rn(tmp.y*scale); - tmp.z = float_to_int8_rn(tmp.z*scale); - tmp.w = float_to_int8_rn(tmp.w*scale); char4 *dst_ptr4 = (char4 *)dst; - dst_ptr4[outIdx] = tmp; + dst_ptr4[outIdx] = __ldg(src_ptr4 + inIdx); } } diff --git a/fastertransformer/cuda/cuda_int8_kernels.h b/fastertransformer/cuda/cuda_int8_kernels.h index c3975ae56..676803233 100644 --- a/fastertransformer/cuda/cuda_int8_kernels.h +++ b/fastertransformer/cuda/cuda_int8_kernels.h @@ -117,4 +117,7 @@ void rebuild_sequence_length_padding_COL32_kernelLauncher(const T* src, T* tgt, const int* mask_offset, const int m, const int n, const int tgt_m, cudaStream_t stream); +void rowMajorToCOL32_kernelLauncher(int8_t* dst, const int8_t* src, + const int m, const int n, cudaStream_t stream); + } //namespace fastertransformer diff --git a/fastertransformer/cuda/cuda_kernels.cu b/fastertransformer/cuda/cuda_kernels.cu index a9c74c4c4..467bf4fec 100644 --- a/fastertransformer/cuda/cuda_kernels.cu +++ b/fastertransformer/cuda/cuda_kernels.cu @@ -14,38 +14,15 @@ * limitations under the License. 
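The rowMajorToCOL32 kernel and launcher added to cuda_int8_kernels above hinge on the COL32 offset identity spelled out in the kernel comment: for element (m_id, n_id) of an (m, n) matrix, COL32_idx = (n_id & 0xffffffe0) * m + (m_id << 5) + (n_id & 31), i.e. columns are packed in groups of 32 and each group stores its m rows contiguously. A scalar helper expressing the same mapping (col32_offset is an illustrative name, not part of the patch):

#include <cassert>
#include <cstdint>

// Offset of element (row, col) in a COL32-ordered (rows x cols) matrix:
// columns are grouped 32 at a time, and each 32-column group holds rows * 32
// contiguous elements.
static inline int64_t col32_offset(int row, int col, int rows, int cols)
{
  assert(cols % 32 == 0 && row < rows && col < cols);
  const int group      = col >> 5;          // which block of 32 columns
  const int col_in_grp = col & 31;          // column inside the block
  return (int64_t)group * 32 * rows         // skip whole 32-column groups: (col & ~31) * rows
       + (int64_t)row * 32                  // (row << 5)
       + col_in_grp;                        // (col & 31)
}

In the kernel the same arithmetic runs on char4 units (four int8 columns per thread), which is why the launcher asserts n % 32 == 0 and the computed byte offsets are shifted right by 2 before indexing.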
*/ -#include "fastertransformer/common.h" - +#include "fastertransformer/utils/common.h" #include "cuda_kernels.h" #include #include #include #include #include -namespace fastertransformer{ - -template -__inline__ __device__ -T gelu(T x) -{ - float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -template <> -__inline__ __device__ -half2 gelu(half2 val) -{ - half2 val_pow3 = __hmul2(val, __hmul2(val, val)); - float2 tmp_pow = __half22float2(val_pow3); - float2 tmp = __half22float2(val); - - tmp.x = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); - tmp.y = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); - return __hmul2(val, __float22half2_rn(tmp)); - -} +namespace fastertransformer{ template __inline__ __device__ @@ -75,7 +52,6 @@ T blockReduceSum(T val) return val; } - template __inline__ __device__ T warpReduceMax(T val) @@ -107,278 +83,7 @@ T blockReduceMax(T val) } template -__global__ -void add_bias_act(T* out, const T* bias, int m, int n) -{ - T val, reg_bias; - - int row_id = blockIdx.x; - int ite = n / blockDim.x; - int tid = threadIdx.x; - - for(int i = 0; i < ite; ++i) - { - reg_bias = __ldg(&bias[i * blockDim.x + tid]); - row_id = blockIdx.x; - - while(row_id < m){ - val = out[tid + i * blockDim.x + row_id * n]+ reg_bias; - out[tid + i * blockDim.x + row_id * n] = gelu(val); - row_id += gridDim.x; - } - } -} - -template <> -__global__ -void add_bias_act(half* out, const half* bias, int m, int n) -{ - half2 val, reg_bias; - int row_id = blockIdx.x; - int ite = n / blockDim.x / 2; - int tid = threadIdx.x; - - half2* out_ptr = (half2*) out; - const half2* bias_ptr = (half2*) bias; - for(int i = 0; i < ite; ++i) - { - reg_bias = __ldg(&bias_ptr[i * blockDim.x + tid]); - row_id = blockIdx.x; - - while(row_id < m){ - val = out_ptr[tid + i * blockDim.x + row_id * n / 2]; - val = __hadd2(val, reg_bias); - out_ptr[tid + i * blockDim.x + row_id * n / 2] = gelu(val); - row_id += gridDim.x; - } - } -} - -template -__global__ -void add_bias_input_layernorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n) -{ - int tid = threadIdx.x; - - __shared__ float s_mean; - __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - - float local_out = 0.0f; - local_out += (float)(out[blockIdx.x * n + tid] + input[blockIdx.x * n + tid] + __ldg(&bias[tid])); - - mean = blockReduceSum(local_out); - if(threadIdx.x == 0) - s_mean = mean / n; - __syncthreads(); - - variance = blockReduceSum((local_out - s_mean) * (local_out - s_mean)); - if(threadIdx.x == 0) - s_variance = variance / n + 1e-6f; - __syncthreads(); - - out[blockIdx.x * n + tid] = - (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); -} - -template <> -__global__ -void add_bias_input_layernorm(half* out, const half* input, const half* bias, - const half* gamma, const half* beta, int m, int n) -{ - - int tid = threadIdx.x; - __shared__ float s_mean; - __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float2 local_out_fp2; - - half2* out_ptr = (half2*)out; - const half2* input_ptr = (const half2*)input; - const half2* bias_ptr = (const half2*)bias; - const half2* gamma_ptr = (const half2*)gamma; - const half2* beta_ptr = (const half2*)beta; - - float local_out = 0.0f; - int id = blockIdx.x * n / 2 + tid; - local_out_fp2 = __half22float2(__hadd2(__hadd2(out_ptr[id], input_ptr[id]), 
__ldg(&bias_ptr[tid]))); - local_out += local_out_fp2.x; - local_out += local_out_fp2.y; - - mean = blockReduceSum(local_out); - if(threadIdx.x == 0) - s_mean = mean / n; - __syncthreads(); - - variance = (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean); - variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean); - variance = blockReduceSum(variance); - if(threadIdx.x == 0) - s_variance = rsqrtf(variance / n + 1e-6f); - __syncthreads(); - - float2 gamma_val = __half22float2(__ldg(&gamma_ptr[tid])); - float2 beta_val = __half22float2(__ldg(&beta_ptr[tid])); - local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; - local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; - out_ptr[id] = __float22half2_rn(local_out_fp2); -} - - -template -__global__ -void add_bias_input_layernorm_v2(T* out, const T* __restrict input, const T* __restrict bias, - const T* __restrict gamma, const T* __restrict beta, int n) -{ - const int ite = 4; - const int tid = threadIdx.x; - const int bid = blockIdx.x; - - __shared__ float s_mean; - __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float local_out[ite]; - - float sum = 0.0f; - #pragma unroll - for(int i = 0; i < ite; i++) - { - int col_id = i * blockDim.x + tid; - int id = bid * n + col_id; - local_out[i] = (float)(out[id] + __ldg(&input[id]) + __ldg(&bias[col_id])); - sum += local_out[i]; - } - - mean = blockReduceSum(sum); - if(tid == 0) - s_mean = mean / n; - __syncthreads(); - - float var = 0.0f; - #pragma unroll - for(int i = 0; i < ite; i++) - { - float diff = local_out[i] - s_mean; - var += diff * diff; - } - - variance = blockReduceSum(var); - if(tid == 0) - s_variance = rsqrtf(variance / n + 1e-6f); - __syncthreads(); - - #pragma unroll - for(int i = 0; i < ite; i++) - { - int col_id = i * blockDim.x + tid; - int id = bid * n + col_id; - out[id] = (T)((local_out[i] - s_mean) * s_variance * (float)__ldg(&gamma[col_id]) + (float)__ldg(&beta[col_id])); - } -} - -template <> -__global__ -void add_bias_input_layernorm_v2(half* out, const half* __restrict input, const half* __restrict bias, - const half* __restrict gamma, const half* __restrict beta, int n) -{ - const int ite = 4; - const int tid = threadIdx.x; - const int bid = blockIdx.x; - __shared__ float s_mean; - __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - half2 local_out_half2[ite]; - - half2* out_ptr = (half2*)out; - const half2* input_ptr = (const half2*)input; - const half2* bias_ptr = (const half2*)bias; - const half2* gamma_ptr = (const half2*)gamma; - const half2* beta_ptr = (const half2*)beta; - - // float sum = 0.0f; - half2 sum = __float2half2_rn(0.0f); - #pragma unroll - for(int i = 0; i < ite; i++) - { - int col_id = i * blockDim.x + tid; - int id = bid * n / 2 + col_id; - local_out_half2[i] = out_ptr[id] + __ldg(&input_ptr[id]) + __ldg(&bias_ptr[col_id]); - sum += local_out_half2[i]; - } - - mean = blockReduceSum((float)(sum.x + sum.y)); - if(threadIdx.x == 0) - s_mean = mean / n; - __syncthreads(); - - float var = 0.0f; - half2 s_mean_2 = __float2half2_rn(s_mean); - #pragma unroll - for(int i = 0; i < ite; i++) - { - local_out_half2[i] = local_out_half2[i] - s_mean_2; - float v1 = (float)local_out_half2[i].x; - float v2 = (float)local_out_half2[i].y; - var += v1 * v1 + v2 * v2; - } - - variance = blockReduceSum(var); - if(threadIdx.x == 0) - s_variance = rsqrtf(variance / n + 1e-6f); - __syncthreads(); - - half2 s_var_2 = 
__float2half2_rn(s_variance); - #pragma unroll - for(int i = 0; i < ite; i++) - { - int col_id = i * blockDim.x + tid; - int id = bid * n / 2 + col_id; - out_ptr[id] = local_out_half2[i] * s_var_2 * __ldg(&gamma_ptr[col_id]) + __ldg(&beta_ptr[col_id]); - } -} - -template -void add_bias_act_kernelLauncher(T* out, const T* bias, int m, int n, cudaStream_t stream) -{ - dim3 grid(ceil(m / 4.)); - dim3 block(n / 4); - assert(block.x <= 1024); - add_bias_act<<>>(out, bias, m, n); -} - -template -void add_bias_input_layernorm_kernelLauncher(T* out, const T* input, const T* bias, - const T* gamma, const T* beta, int m, int n, cudaStream_t stream) -{ - dim3 grid(m); - dim3 block(n); - assert(n <= 1024); - if(n == 768 || n == 1024) - add_bias_input_layernorm_v2<<>>(out, input, bias, gamma, beta, n); - else - add_bias_input_layernorm<<>>(out, input, bias, gamma, beta, m, n); -} - -template <> -void add_bias_input_layernorm_kernelLauncher(half* out, const half* input, const half* bias, - const half* gamma, const half* beta, int m, int n, cudaStream_t stream) -{ - dim3 grid(m); - dim3 block(n / 2); - assert(n / 2 <= 1024); - - if(m >= 512 && (n == 768 || n == 1024)) - add_bias_input_layernorm_v2<<>>(out, input, bias, gamma, beta, n); - else - add_bias_input_layernorm<<>>(out, input, bias, gamma, beta, m, n); -} - -template -__global__ void update_logits_kernel(T* logits, const T* bias, const int end_id, const bool* finished, const int n) +__global__ void update_logits_kernel(float* logits, const T* tmp_logits, const T* bias, const int end_id, const bool* finished, const int n) { int bid = blockIdx.x; bool finish = finished[bid]; @@ -388,35 +93,45 @@ __global__ void update_logits_kernel(T* logits, const T* bias, const int end_id, __shared__ float s_max_val; __shared__ float s_sum_val; - for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + if(finish) { - if(finish) - logits[offset + tid] = (tid == end_id) ? FLT_MAX : -1 * FLT_MAX; - else - logits[offset + tid] += bias[tid]; - max_val = max(max_val, logits[offset + tid]); + for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + { + logits[offset + tid] = (tid == end_id) ? 0 : -FLT_MAX; + } } + else + { + for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + { + if(finish) + logits[offset + tid] = (tid == end_id) ? 
FLT_MAX : -1 * FLT_MAX; + else + logits[offset + tid] = (float)(tmp_logits[offset + tid] + bias[tid]); + max_val = max(max_val, logits[offset + tid]); + } - max_val = blockReduceMax((float)max_val); - if(threadIdx.x == 0) - s_max_val = max_val; - __syncthreads(); + max_val = blockReduceMax((float)max_val); + if(threadIdx.x == 0) + s_max_val = max_val; + __syncthreads(); - float sum_val = 0.0f; - for(int tid = threadIdx.x; tid < n; tid += blockDim.x) - { - logits[offset + tid] = __expf((float)logits[offset + tid] - s_max_val); - sum_val += (float)logits[offset + tid]; - } + float sum_val = 0.0f; + for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + { + logits[offset + tid] = __expf((float)logits[offset + tid] - s_max_val); + sum_val += (float)logits[offset + tid]; + } - sum_val = blockReduceSum(sum_val); - if(threadIdx.x == 0) - s_sum_val = sum_val; - __syncthreads(); + sum_val = blockReduceSum(sum_val); + if(threadIdx.x == 0) + s_sum_val = sum_val; + __syncthreads(); - for(int tid = threadIdx.x; tid < n; tid += blockDim.x) - { - logits[offset + tid] = logf((float)logits[offset + tid] / s_sum_val); + for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + { + logits[offset + tid] = logf((float)logits[offset + tid] / s_sum_val); + } } } @@ -445,11 +160,11 @@ __global__ void update_logits_kernel_without_softmax(T* logits, const T* bias, c template __global__ void softmax_kernel(T* logits, const T* bias, const int end_id, const bool* finished, - const int n) + const int n_padded, const int n) { int bid = blockIdx.x; bool finish = (finished != nullptr) ? finished[bid] : false; - int offset = bid * n; + int offset = bid * n_padded; float max_val = -1 * FLT_MAX; const bool IS_FP16 = std::is_same::value; @@ -457,14 +172,21 @@ __global__ void softmax_kernel(T* logits, const T* bias, __shared__ float s_max_val; __shared__ float s_sum_val; - for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + for(int tid = threadIdx.x; tid < n_padded; tid += blockDim.x) { - if(finish) - logits[offset + tid] = (tid == end_id) ? MAX_T_VAL : -MAX_T_VAL; + if(tid < n) + { + if(finish) + logits[offset + tid] = (tid == end_id) ? MAX_T_VAL : -MAX_T_VAL; + else + { + T bias_val = (bias != nullptr) ? bias[tid] : (T)0.0f; + logits[offset + tid] += bias_val; + } + } else { - T bias_val = (bias != nullptr) ? 
bias[tid] : (T)0.0f; - logits[offset + tid] += bias_val; + logits[offset + tid] = -MAX_T_VAL; } max_val = max(max_val, (float)logits[offset + tid]); } @@ -475,7 +197,7 @@ __global__ void softmax_kernel(T* logits, const T* bias, __syncthreads(); float sum_val = 0.0f; - for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + for(int tid = threadIdx.x; tid < n_padded; tid += blockDim.x) { logits[offset + tid] = __expf((float)logits[offset + tid] - s_max_val); sum_val += (float)logits[offset + tid]; @@ -486,7 +208,7 @@ __global__ void softmax_kernel(T* logits, const T* bias, s_sum_val = sum_val; __syncthreads(); - for(int tid = threadIdx.x; tid < n; tid += blockDim.x) + for(int tid = threadIdx.x; tid < n_padded; tid += blockDim.x) { logits[offset + tid] = ((float)logits[offset + tid] / s_sum_val); } @@ -596,15 +318,43 @@ template void remove_sequence_length_padding_kernelLauncher(const half* src, hal int* mask_offset, const int m, const int n, cudaStream_t stream); -void update_logits(float* logits, const float* bias, const int end_id, const bool* finished, +template +__global__ void cuda_random_uniform_kernel(T* buffer, const int size) +{ + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((T)1337.f, idx, 0, &local_state); + for(int index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = (T)(curand_uniform(&local_state) * 0.2f - 0.1f); + } +} + +template +void cuda_random_uniform_kernelLauncher(T *buffer, const int size) +{ + cuda_random_uniform_kernel<<<256, 256>>>(buffer, size); +} + +template void cuda_random_uniform_kernelLauncher(float *buffer, const int size); +template void cuda_random_uniform_kernelLauncher(half *buffer, const int size); + +template +void update_logits(float* logits, const T* tmp_logits, const T* bias, const int end_id, const bool* finished, const int m, const int n, cudaStream_t stream) { dim3 grid(m); dim3 block(min(n, 1024)); /*n is the vocab_size, e.g., 30000, 7000.... vocab_size is usually very big. */ - update_logits_kernel<<>>(logits, bias, end_id, finished, n); + update_logits_kernel<<>>(logits, tmp_logits, bias, end_id, finished, n); } +template void update_logits(float* logits, const float* tmp_logits, const float* bias, const int end_id, + const bool* finished, const int m, const int n, cudaStream_t stream); + +template void update_logits(float* logits, const half* tmp_logits, const half* bias, const int end_id, + const bool* finished, const int m, const int n, cudaStream_t stream); + template void update_logits_without_softmax(T* logits, const T* bias, const int end_id, const bool* finished, const int m, const int n, cudaStream_t stream) @@ -623,33 +373,19 @@ template void update_logits_without_softmax(half* logits, const half* bias, cons template void softmax_kernelLauncher(T* logits, const T* bias, const int end_id, const bool* finished, - const int m, const int n, cudaStream_t stream) + const int m, const int n_padded, const int n, cudaStream_t stream) { dim3 grid(m); dim3 block(min(n, 1024)); /*n is the vocab_size, e.g., 30000, 7000.... vocab_size is usually very big. 
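// Sketch of the padded-vocabulary handling introduced by the n_padded/n split in
// softmax_kernel above (illustrative only; softmax_padded_reference is an assumed name
// mirroring the float path, and needs <cfloat>/<cmath>): entries in
// [vocab_size, vocab_size_padded) are forced to the most negative value, so they end up
// with zero probability and the padded tail can never be sampled.
static void softmax_padded_reference(float* logits, int vocab_size, int vocab_size_padded)
{
    for (int i = vocab_size; i < vocab_size_padded; ++i) logits[i] = -FLT_MAX;

    float max_val = -FLT_MAX;
    for (int i = 0; i < vocab_size_padded; ++i) max_val = fmaxf(max_val, logits[i]);

    float sum = 0.f;
    for (int i = 0; i < vocab_size_padded; ++i) {
        logits[i] = expf(logits[i] - max_val);  // padded entries evaluate to 0
        sum += logits[i];
    }
    for (int i = 0; i < vocab_size_padded; ++i) logits[i] /= sum;
}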
*/ - softmax_kernel<<>>(logits, bias, end_id, finished, n); + softmax_kernel<<>>(logits, bias, end_id, finished, n_padded, n); } template void softmax_kernelLauncher(float* logits, const float* bias, const int end_id, const bool* finished, - const int m, const int n, cudaStream_t stream); + const int m, const int n_padded, const int n, cudaStream_t stream); template void softmax_kernelLauncher(half* logits, const half* bias, const int end_id, const bool* finished, - const int m, const int n, cudaStream_t stream); - -template void add_bias_act_kernelLauncher( - float* out, const float* bias, int m, int n, cudaStream_t stream); - -template void add_bias_input_layernorm_kernelLauncher( - float* out, const float* input, const float* bias, const float* gamma, const float* beta, - int m, int n, cudaStream_t stream); - -template void add_bias_act_kernelLauncher( - half* out, const half* bias, int m, int n, cudaStream_t stream); - -template void add_bias_input_layernorm_kernelLauncher( - half* out, const half* input, const half* bias, const half* gamma, const half* beta, - int m, int n, cudaStream_t stream); + const int m, const int n_padded, const int n, cudaStream_t stream); /* *********************************** Debug tools *********************************** */ @@ -678,6 +414,17 @@ void print_kernel(const T* buf, uint size) printf("\n"); } +template <> +__global__ +void print_kernel(const int* buf, uint size) +{ + for(int i = 0; i < size; i++) + { + printf("%d ", buf[i]); + } + printf("\n"); +} + template void print_first_k(const T* buf, uint size, cudaStream_t stream) { @@ -701,6 +448,7 @@ void print_abs_mean(const T* buf, uint size, cudaStream_t stream) template void print_first_k(const float*, uint size, cudaStream_t); template void print_first_k(const half*, uint size, cudaStream_t); template void print_first_k(const int*, uint size, cudaStream_t); +template void print_first_k(const bool*, uint size, cudaStream_t); template void print_abs_mean(const float* buf, uint size, cudaStream_t stream); template void print_abs_mean(const half* buf, uint size, cudaStream_t stream); @@ -708,6 +456,7 @@ template void print_abs_mean(const int* buf, uint size, cudaStream_t stream); /* **************************** end of Debug tools *********************************** */ +// TODO remove in v4.1 /* *************************** depreciated kernels *********************************** */ template diff --git a/fastertransformer/cuda/cuda_kernels.h b/fastertransformer/cuda/cuda_kernels.h index aba208494..a86ce7dbd 100644 --- a/fastertransformer/cuda/cuda_kernels.h +++ b/fastertransformer/cuda/cuda_kernels.h @@ -17,7 +17,7 @@ #include #include #include -#include "fastertransformer/arguments.h" +#include "fastertransformer/utils/arguments.h" #include "fastertransformer/cuda/topk_kernels.cuh" namespace fastertransformer @@ -33,15 +33,6 @@ void init_kernelLauncher(bool* finished, int* sequence_length, int* word_ids, T* cum_log_probs, const int sentence_id, const int batch_size, const int beam_width, cudaStream_t stream); -template -void add_bias_act_kernelLauncher(T* out, const T* bias, int m, int n, cudaStream_t stream); - -template -void add_bias_input_layernorm_kernelLauncher(T *out, const T *input_tensor, - const T *bias, const T *gamma, - const T *beta, int m, int n, - cudaStream_t stream); - template void embedding_lookup_sine_position_encoding_kernel_launcher(T *from_tensor, const T *embedding_table, @@ -72,22 +63,63 @@ void embedding_position_lookups_kernel_launcher(T* from_tensor, int step, cudaStream_t 
stream); +template +void start_id_embedding_position_lookups_kernel_launcher(T* from_tensor, + int* output_ids, + const T* embedding_table, + const T* pos_table, + const int* word_ids, + const int start_step, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream); + template void apply_temperature_penalty_kernelLauncher(T* logits, const T temperature, const int m, - const int n, + const int vocab_size, + const int vocab_size_padd, cudaStream_t stream); +void set_start_ids_kernelLauncher(int* out_ids, + const int* in_ids, + const int max_start_len, + const int step, + const int ite, + const int batch_size, + const int local_batch_size, + const int end_id, + cudaStream_t stream); + +template +void kernel_padding_kernelLauncher(T *padded_kernel, const T *kernel, + const int row_dim, const int col_dim, + const int padded_col_dim, cudaStream_t stream); + +template +void bias_padding_kernelLauncher(T1 *padded_bias, const T2 *bias, + const int col_dim, const int padded_col_dim, + cudaStream_t stream); + template -void transpose(T *out, const T *in, int batch, +void transpose(T *out, const T *in, int batch, int height, int width, int stride, cudaStream_t stream); -/* *************************** end of common kernel *********************************** */ + +template +void transpose_axis_01_kernelLauncher(DataType_ *out, DataType_ *in, const int dim0, + const int dim1, const int dim2, cudaStream_t stream); + void build_sequence_length_padding_offset_kernelLauncher(const int *sequence_length, const int batch_size, const int max_seq_len, int *valid_word_num, int *tmp_mask_offset, cudaStream_t stream); +template +void cuda_random_uniform_kernelLauncher(T *buffer, const int size); + /* *************************** end of common kernel *********************************** */ /* ********************************** BeamSearch kernel *********************************** */ @@ -96,7 +128,8 @@ void broadcast_kernelLauncher(float *log_probs, float *cum_log_probs, const int batch_size, const int beam_width, const int vocab_size, cudaStream_t stream); -void update_logits(float* logits, const float* bias, const int end_ids, +template +void update_logits(float* logits, const T *tmp_logits, const T* bias, const int end_ids, const bool* finished, const int m, const int n, cudaStream_t stream); @@ -106,9 +139,9 @@ void apply_logit_penalties(int step, int* current_ids, int* previous_ids, int* parent_ids, - Gpt2Arguments args, + GptArguments args, cudaStream_t stream); - + void update_kernelLauncher(float* log_probs, float* cum_log_probs, bool* finished, int* parent_ids, int* sequence_length, int* word_ids, int* output_ids, @@ -126,8 +159,10 @@ void update_kernelLauncher_v2(bool* finished, int* parent_ids, template void update_KV_cache_kernelLauncher(T **key_cache, T **value_cache, const int *beam_ids, + const bool* finished, const int batch_size, const int beam_width, - const int hidden_dim, const int step, + const int head_num, const int size_per_head, + const int step, const int decoder_max_seq_len, const int cache_size, const int decoder_layers, cudaStream_t stream); @@ -163,6 +198,16 @@ void topp_initialization_kernelLauncher(bool* finished, const int logits_buf_size, DecodingSamplingArguments args, cudaStream_t stream); + +void topp_initialization_kernelLauncher_v2(bool* finished, + int* sequence_length, + int* word_ids, + int* topp_id_val_buf, + int* topp_offset_buf, + int* begin_topp_offset_buf_, + const int logits_buf_size, + DecodingSamplingArguments 
args, + cudaStream_t stream); void init_topp_id_val_kernel_kernelLauncher(int *topp_id_val_buf, int *topp_offset_buf, @@ -177,7 +222,7 @@ void update_logits_without_softmax(T* logits, const T* bias, const int end_ids, template void softmax_kernelLauncher(T* logits, const T* bias, const int end_ids, - const bool* finished, const int m, const int n, + const bool* finished, const int m, const int n_padded, const int n, cudaStream_t stream); /* *************************** end of Sampling kernel *********************************** */ diff --git a/fastertransformer/cuda/decoding_kernel_check.cpp b/fastertransformer/cuda/decoding_kernel_check.cpp index 95f8ca843..e5636836f 100644 --- a/fastertransformer/cuda/decoding_kernel_check.cpp +++ b/fastertransformer/cuda/decoding_kernel_check.cpp @@ -93,7 +93,7 @@ void init_kernel_check(bool *d_finished, int *d_sequence_length, int *d_word_ids printf("[INFO] decoding init check Finish. \n"); } -void update_logits_kernel_check(float *logits, const float *bias, const int end_id, const bool *finished, const int m, const int n, cudaStream_t stream) +void update_logits_kernel_check(float *logits, const float *tmp_logits, const float *bias, const int end_id, const bool *finished, const int m, const int n, cudaStream_t stream) { // m: batch_size * beam_width // n: vocab size @@ -109,7 +109,7 @@ void update_logits_kernel_check(float *logits, const float *bias, const int end_ check_cuda_error(cudaMemcpy(h_logits, logits, sizeof(float) * m * n, cudaMemcpyDeviceToHost)); check_cuda_error(cudaMemcpy(h_bias, bias, sizeof(float) * n, cudaMemcpyDeviceToHost)); check_cuda_error(cudaMemcpy(h_finished, finished, sizeof(bool) * m, cudaMemcpyDeviceToHost)); - update_logits(logits, bias, end_id, finished, m, n, stream); + update_logits(logits, tmp_logits, bias, end_id, finished, m, n, stream); cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); check_cuda_error(cudaMemcpy(h_logits_after_update, logits, sizeof(float) * m * n, cudaMemcpyDeviceToHost)); diff --git a/fastertransformer/cuda/decoding_kernel_check.h b/fastertransformer/cuda/decoding_kernel_check.h index 6308a11ef..f496e519b 100644 --- a/fastertransformer/cuda/decoding_kernel_check.h +++ b/fastertransformer/cuda/decoding_kernel_check.h @@ -16,7 +16,7 @@ limitations under the License. #pragma once #include "cuda_kernels.h" -#include "fastertransformer/common.h" +#include "fastertransformer/utils/common.h" #include #include #include @@ -41,9 +41,11 @@ void update_kernel_check(float* log_probs, float* cum_log_probs, int* ids, bool* const int end_id, int* finished_count); template -void update_KV_cache_kernel_check(T** key_cache, T** value_cache, const int* beam_ids, const int batch_size, const int beam_width, const int hidden_dim, - const int step, const int cache_size, const int decoder_layers, cudaStream_t stream){ +void update_KV_cache_kernel_check(T** key_cache, T** value_cache, const int* beam_ids, const int batch_size, const int beam_width, + const int head_num, const int size_per_head, const int step, const int cache_size, const int decoder_layers, cudaStream_t stream){ + const int hidden_dim = head_num * size_per_head; + printf("[INFO] decoding update KV cache check for step %d. 
\n", step); const int src_id = step & 0x1; const int tgt_id = 1 - src_id; @@ -67,7 +69,8 @@ void update_KV_cache_kernel_check(T** key_cache, T** value_cache, const int* bea check_cuda_error(cudaMemcpy(h_beam_ids, beam_ids, sizeof(int) * batch_size * beam_width, cudaMemcpyDeviceToHost)); // compute on GPU and copy the result to CPU - update_KV_cache_kernelLauncher(key_cache, value_cache, beam_ids, batch_size, beam_width, hidden_dim, step, cache_size, decoder_layers, stream); + // we use sequence major cache format here + update_KV_cache_kernelLauncher(key_cache, value_cache, beam_ids, nullptr, batch_size, beam_width, head_num, size_per_head, step, -1, cache_size, decoder_layers, stream); cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); check_cuda_error(cudaMemcpy(h_key_cache_tgt_after_update, key_cache[tgt_id], sizeof(T) * cache_size * decoder_layers, cudaMemcpyDeviceToHost)); diff --git a/fastertransformer/cuda/decoding_kernels.cu b/fastertransformer/cuda/decoding_kernels.cu index a806c4e7f..2c0b56ea7 100644 --- a/fastertransformer/cuda/decoding_kernels.cu +++ b/fastertransformer/cuda/decoding_kernels.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "fastertransformer/common.h" +#include "fastertransformer/utils/common.h" #include "cuda_kernels.h" #include "cub/cub.cuh" @@ -36,15 +36,18 @@ namespace fastertransformer int* word_ids, T* cum_log_probs, const int sentence_id, - const int beam_width) + const int beam_width, + const int batch_size) { const bool IS_FP16 = std::is_same::value; const T MAX_T_VAL = (IS_FP16)? HALF_FLT_MAX : 1e20f; - int tid = threadIdx.x; - finished[tid] = false; - sequence_length[tid] = 0; - word_ids[tid] = sentence_id; - cum_log_probs[tid] = (tid % beam_width == 0) ? (T)0.0f: -MAX_T_VAL; + for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width; index += blockDim.x * gridDim.x) + { + finished[index] = false; + sequence_length[index] = 0; + word_ids[index] = sentence_id; + cum_log_probs[index] = (index % beam_width == 0) ? 
(T)0.0f: -MAX_T_VAL; + } } template @@ -57,27 +60,30 @@ namespace fastertransformer const int beam_width, cudaStream_t stream) { - dim3 grid(1); - dim3 block(min(1024, batch_size * beam_width)); - assert(batch_size * beam_width <= 1024); + dim3 grid((int)ceil(batch_size * beam_width * 1.0 / 256)); + dim3 block(256); init_kernel<<>>(finished, sequence_length, word_ids, cum_log_probs, sentence_id, - beam_width); + beam_width, + batch_size); } __global__ void sampling_init_kernel(bool* finished, int* sequence_length, int* word_ids, - const int start_id) + const int start_id, + const int batch_size) { - const int tid = threadIdx.x; - finished[tid] = false; - sequence_length[tid] = 0; - word_ids[tid] = start_id; + for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size; index += blockDim.x * gridDim.x) + { + finished[index] = false; + sequence_length[index] = 0; + word_ids[index] = start_id; + } } void sampling_init_kernelLauncher(bool* finished, @@ -87,14 +93,15 @@ namespace fastertransformer const int batch_size, cudaStream_t stream) { - dim3 grid(1); - dim3 block(min(1024, batch_size)); - assert(batch_size <= 1024); + dim3 grid((int)ceil(batch_size * 1.0 / 256)); + dim3 block(256); + sampling_init_kernel<<>>(finished, sequence_length, word_ids, - start_id); + start_id, + batch_size); } template @@ -136,8 +143,7 @@ namespace fastertransformer hidden_units); } - - + // TODO Add half2 implementation template __global__ void embedding_position_lookups_kernel(T* from_tensor, const T* embedding_table, @@ -178,15 +184,84 @@ namespace fastertransformer step); } + template + __global__ void start_id_embedding_position_lookups_kernel(T* from_tensor, + int* output_ids, + const T* embedding_table, + const T* pos_table, + const int* word_ids, + const int start_step, + const int length, + const int max_length, + const int batch_size, + const int hidden_units) + { + for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * length * hidden_units; index += blockDim.x * gridDim.x) + { + // transpose the word_ids [batch, length] (part of [batch, max_length]) to output_ids [length, batch] + if(index < batch_size * max_length) + { + const int seq_id = index % max_length; + const int batch_id = index / max_length; + if(seq_id < length) + output_ids[seq_id * batch_size + batch_id] = word_ids[index]; + // output_ids[index] = word_ids[index]; + } + + // embedding lookup from word ids [batch, length] (part of [batch, max_length]) and [vocab, hidden] to generate embedding [batch, length, hidden] + const int word_index = index / hidden_units; + const int word_index_row = word_index / length; + const int word_index_col = word_index % length; + const int real_word_index = word_index_row * max_length + word_index_col; + const int step = start_step + word_index % length; + const int col_index = index % hidden_units; + from_tensor[index] = embedding_table[word_ids[real_word_index] * hidden_units + col_index] + + pos_table[(step - 1) * hidden_units + col_index]; + } + } + + + template + void start_id_embedding_position_lookups_kernel_launcher(T* from_tensor, + int *output_ids, + const T* embedding_table, + const T* pos_table, + const int* word_ids, + const int start_step, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream) + { + dim3 grid(min(batch_size * length, 65536)); + dim3 block(min(hidden_units, 1024)); + start_id_embedding_position_lookups_kernel<<>>(from_tensor, + output_ids, + embedding_table, + pos_table, + 
word_ids, + start_step, + length, + max_length, + batch_size, + hidden_units); + } + + // TODO Add half2 implementation template __global__ void apply_temperature_penalty_kernel(T* logits, const T temperature_inverse, const int m, - const int n) + const int vocab_size, + const int vocab_size_padd) { - for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < m * n; index += blockDim.x * gridDim.x) + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16)? HALF_FLT_MAX : FLT_MAX; + for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < m * vocab_size_padd; index += blockDim.x * gridDim.x) { - logits[index] = logits[index] * temperature_inverse; + if(index % vocab_size_padd < vocab_size) logits[index] = logits[index] * temperature_inverse; + else logits[index] = -MAX_T_VAL; } } @@ -194,16 +269,112 @@ namespace fastertransformer void apply_temperature_penalty_kernelLauncher(T* logits, const T temperature, const int m, - const int n, + const int vocab_size, + const int vocab_size_padd, cudaStream_t stream) { dim3 grid(min(m, 65536)); - dim3 block(min(n, 1024)); + dim3 block(min(vocab_size_padd, 1024)); const T temperature_inverse = (T)(1.f / (float) temperature); apply_temperature_penalty_kernel<<>>(logits, temperature_inverse, m, - n); + vocab_size, + vocab_size_padd); + } + + __global__ void set_start_ids_kernel(int* out_ids, + const int* in_ids, + const int max_start_len, + const int step, + const int ite, + const int batch_size, + const int local_batch_size, + const int end_id) + { + const int id = blockIdx.x * blockDim.x + threadIdx.x; + if(id < local_batch_size) + { + int in_id = in_ids[(ite * local_batch_size + id) * max_start_len + step]; + if(in_id != end_id) + out_ids[step * batch_size + ite * local_batch_size + id] = in_ids[(ite * local_batch_size + id) * max_start_len + step]; + } + } + + void set_start_ids_kernelLauncher(int* out_ids, + const int* in_ids, + const int max_start_len, + const int step, + const int ite, + const int batch_size, + const int local_batch_size, + const int end_id, + cudaStream_t stream) + { + dim3 grid((int)(ceil(local_batch_size / 512.))); + set_start_ids_kernel<<>>(out_ids, + in_ids, + max_start_len, + step, + ite, + batch_size, + local_batch_size, + end_id); + } + + template + __global__ void kernel_padding_kernel(T *padded_kernel, const T *kernel, + const int row_dim, const int col_dim, const int padded_col_dim) + { + for(int id = threadIdx.x + blockIdx.x * blockDim.x; id < row_dim * padded_col_dim; id += blockDim.x * gridDim.x) + { + int row_id = id / padded_col_dim; + int col_id = id % padded_col_dim; + if(col_id < col_dim) + { + padded_kernel[id] = kernel[row_id * col_dim + col_id]; + } + else + { + padded_kernel[id] = (T)(0.0f); + } + } + } + + template + void kernel_padding_kernelLauncher(T *padded_kernel, const T *kernel, + const int row_dim, const int col_dim, const int padded_col_dim, cudaStream_t stream) + { + // pad 0 into the kernel from shape [row_dim, col_dim] to [row_dim, padded_col_dim] + dim3 block(512); + dim3 grid(min(65536, (int)(ceil(row_dim * padded_col_dim / 512.)) )); + kernel_padding_kernel<<>>(padded_kernel, kernel, row_dim, col_dim, padded_col_dim); + } + + template + __global__ void bias_padding_kernel(T1 *padded_bias, const T2 *bias, + const int col_dim, const int padded_col_dim) + { + const int index = blockIdx.x * blockDim.x + threadIdx.x; + if(index < col_dim) + { + padded_bias[index] = (T1)bias[index]; + } + else if(index >= col_dim && index < padded_col_dim) + { + padded_bias[index] = 
(T1)(std::is_same::value ? -60000 : -1e20f); + } + } + + template + void bias_padding_kernelLauncher(T1 *padded_bias, const T2 *bias, + const int col_dim, const int padded_col_dim, cudaStream_t stream) + { + // pad -max into the bias from shape [col_dim] to [padded_col_dim] + dim3 block(512); + dim3 grid( (int)(ceil(padded_col_dim / 512.)) ); + assert(grid.x < 65536); + bias_padding_kernel<<>>(padded_bias, bias, col_dim, padded_col_dim); } /* *************************** end of common kernel *********************************** */ @@ -289,19 +460,22 @@ namespace fastertransformer int* sequence_length, int* word_ids, int* output_ids, const int vocab_size, const int end_id, + const int batch_size, const int beam_width, int* finished_count) { - int tid = threadIdx.x; - sequence_length[tid] = finished[tid] ? sequence_length[tid] : sequence_length[tid] + 1; - - int beam_id = word_ids[tid] / vocab_size; - int word_id = word_ids[tid] % vocab_size; - - sequence_length[tid] = sequence_length[beam_id]; - finished[tid] = word_id == end_id ? 1 : 0; - parent_ids[tid] = beam_id; - word_ids[tid] = word_id; - output_ids[tid] = word_id; + for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width; index += blockDim.x * gridDim.x) + { + sequence_length[index] = finished[index] ? sequence_length[index] : sequence_length[index] + 1; + + int beam_id = word_ids[index] / vocab_size; + int word_id = word_ids[index] % vocab_size; + + sequence_length[index] = sequence_length[beam_id]; + finished[index] = word_id == end_id ? 1 : 0; + parent_ids[index] = beam_id; + word_ids[index] = word_id; + output_ids[index] = word_id; + } } void update_kernelLauncher_v2(bool* finished, int* parent_ids, @@ -311,14 +485,15 @@ namespace fastertransformer DecodingBeamsearchArguments args, cudaStream_t stream) { - dim3 grid(1); - dim3 block(args.batch_size_ * args.beam_width_); - assert(block.x <= 1024); + dim3 grid((int)ceil(args.batch_size_ * args.beam_width_ * 1.0 / 256)); + dim3 block(256); update_kernel_v2<<>>(finished, parent_ids, sequence_length, word_ids, - output_ids, args.vocab_size_, - args.end_id_, finished_count); + output_ids, args.vocab_size_padded_, + args.end_id_, + args.batch_size_, args.beam_width_, + finished_count); } template @@ -326,7 +501,8 @@ namespace fastertransformer T* key_tgt_cache, const T* __restrict value_src_cache, T* value_tgt_cache, - const int* beam_ids, + const int* beam_ids, + const bool* finished, const int batch_size, const int beam_width, const int hidden_dim, @@ -337,6 +513,7 @@ namespace fastertransformer int layer_id = blockIdx.x / batch_size / beam_width / step; int batch_id = (blockIdx.x % (batch_size * beam_width * step)) / (beam_width * step); int beam_id = (blockIdx.x % (beam_width * step)) / step; + if(finished[batch_id * beam_width + beam_id]) return; int step_id = blockIdx.x % step; int hidden_id = step_id * batch_size * beam_width * hidden_dim + @@ -365,6 +542,7 @@ namespace fastertransformer const half* __restrict value_src_cache, half* value_tgt_cache, const int* beam_ids, + const bool* finished, const int batch_size, const int beam_width, const int hidden_dim, @@ -375,6 +553,7 @@ namespace fastertransformer int layer_id = blockIdx.x / batch_size / beam_width / step; int batch_id = (blockIdx.x % (batch_size * beam_width * step)) / (beam_width * step); int beam_id = (blockIdx.x % (beam_width * step)) / step; + if(finished[batch_id * beam_width + beam_id]) return; int step_id = blockIdx.x % step; int hidden_id = (step_id * batch_size * beam_width * 
hidden_dim + @@ -396,29 +575,121 @@ namespace fastertransformer } + template + __global__ void update_KV_batch_major_cache_kernel(const T* __restrict key_src_cache, + T* key_tgt_cache, + const T* __restrict value_src_cache, + T* value_tgt_cache, + const int* beam_ids, + const bool* finished, + const int batch_size, + const int beam_width, + const int size_per_head, + const int cache_size, + const int step, + const int max_seq_len, + const int decoder_layers) + { + int layer_id = blockIdx.z; + int head_id = blockIdx.y; + int bb_id = blockIdx.x; + int batch_id = bb_id / beam_width; + int beam_id = bb_id % beam_width; + + if(finished[batch_id * beam_width + beam_id]) return; + + const int hidden_dim = size_per_head * gridDim.y; + + int src_offset = layer_id * cache_size + + (beam_ids[batch_id * beam_width + beam_id] * hidden_dim + + head_id * size_per_head) * max_seq_len; + int tgt_offset = layer_id * cache_size + + ((batch_id * beam_width + beam_id) * hidden_dim + + head_id * size_per_head) * max_seq_len; + + // for better memory access always do 16 byte loads. + // [B, H, Dh/x, L, x] and [B, H, L, Dh/x, x] (i.e. [B, H, L, Dh]) + auto key_src_ptr = reinterpret_cast(key_src_cache + src_offset); + auto value_src_ptr = reinterpret_cast(value_src_cache + src_offset); + auto key_tgt_ptr = reinterpret_cast(key_tgt_cache + tgt_offset); + auto value_tgt_ptr = reinterpret_cast(value_tgt_cache + tgt_offset); + constexpr int x = (sizeof(T) == 4)? 4 : 8; + + // step starts from 1 + #if 0 + constexpr int WARP_SIZE = 32; + const int num_warps = blockDim.x / WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + for (int dhx = warp_id; dhx < size_per_head/x; dhx += num_warps) + { + for (int tid = lane_id; tid < step; tid += WARP_SIZE) + { + key_tgt_ptr[dhx * max_seq_len + tid] = key_src_ptr[dhx * max_seq_len + tid]; + } + } + #else + // seems to be a bit faster + for (int tid = threadIdx.x; tid < max_seq_len * size_per_head/x; tid += blockDim.x) + { + // could consider fast int division here + if (tid % max_seq_len < step) + { + key_tgt_ptr[tid] = key_src_ptr[tid]; + } + } + #endif + + for (int tid = threadIdx.x; tid < step * size_per_head/x; tid += blockDim.x) + { + value_tgt_ptr[tid] = value_src_ptr[tid]; + } + } + template void update_KV_cache_kernelLauncher(T** key_cache, T** value_cache, const int* beam_ids, + const bool* finished, const int batch_size, - const int beam_width, - const int hidden_dim, - const int step, + const int beam_width, + const int head_num, + const int size_per_head, + const int step, + const int decoder_max_seq_len, const int cache_size, const int decoder_layers, cudaStream_t stream) { - dim3 grid(decoder_layers * batch_size * beam_width * step); - dim3 block(min(1024, hidden_dim)); - block.x = block.x / (4 / sizeof(T)); - int src_id = step & 0x1; int tgt_id = 1 - src_id; - update_KV_cache_kernel<<>>( - key_cache[src_id], key_cache[tgt_id], - value_cache[src_id], value_cache[tgt_id], - beam_ids, batch_size, beam_width, hidden_dim, cache_size, step, decoder_layers); + if (decoder_max_seq_len < 0) + { + int hidden_dim = head_num * size_per_head; + dim3 grid(decoder_layers * batch_size * beam_width * step); + dim3 block(min(1024, hidden_dim)); + block.x = block.x / (4 / sizeof(T)); + + update_KV_cache_kernel<<>>( + key_cache[src_id], key_cache[tgt_id], + value_cache[src_id], value_cache[tgt_id], + beam_ids, finished, + batch_size, beam_width, hidden_dim, cache_size, step, decoder_layers); + } + else + { + dim3 grid(batch_size * 
beam_width, head_num, decoder_layers); + constexpr int block_sz = 128; + + update_KV_batch_major_cache_kernel<<>>( + key_cache[src_id], key_cache[tgt_id], + value_cache[src_id], value_cache[tgt_id], + beam_ids, finished, + batch_size, beam_width, size_per_head, cache_size, step, + decoder_max_seq_len, decoder_layers); + } + } template @@ -488,10 +759,10 @@ namespace fastertransformer int* current_ids, int* previous_ids, int* parent_ids, - Gpt2Arguments args, + GptArguments args, cudaStream_t stream) { - int vocab_size = args.vocab_size_; + int vocab_size = args.vocab_size_padded_; int beam_width = 1; int batch_size = args.batch_size_; dim3 block(256); @@ -573,6 +844,36 @@ namespace fastertransformer transpose_kernel<<>>(out, in, height, width, tH, tW, stride); } + // TODO Add half2 implementation + template + __global__ void transpose_axis_01_kernel(DataType_ *out, DataType_ *in, const int dim0, const int dim1, const int dim2) + { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if(index < dim0 * dim1 * dim2) + { + const int input_dim2_index = index % dim2; + index = (index - input_dim2_index) / dim2; + const int input_dim1_index = index % dim1; + index = (index - input_dim1_index) / dim1; + const int input_dim0_index = index % dim0; + + out[input_dim1_index * dim0 * dim2 + + input_dim0_index * dim2 + + input_dim2_index] = in[input_dim0_index * dim1 * dim2 + + input_dim1_index * dim2 + + input_dim2_index]; + } + } + + template + void transpose_axis_01_kernelLauncher(DataType_ *out, DataType_ *in, const int dim0, + const int dim1, const int dim2, cudaStream_t stream) + { + dim3 block(512); + dim3 grid((int)(ceil(dim0 * dim1 * dim2 / 512.))); + transpose_axis_01_kernel<<>>(out, in, dim0, dim1, dim2); + } + /* *************************** end of BeamSearch kernel *********************************** */ /* ********************************** Sampling kernel *********************************** */ @@ -611,6 +912,45 @@ namespace fastertransformer } } + __global__ void topp_initialization_kernel_v2(bool* finished, + int* sequence_length, + int* word_ids, + int* topp_id_val_buf, + int* topp_offset_buf, + int* begin_topp_offset_buf_, + const int batch_size, + const int n, + const int start_id) + { + int tid = threadIdx.x; + int bid = blockIdx.x; + + if(bid == 0) + { + for(int i = tid; i < batch_size + 1; i+= blockDim.x) + { + topp_offset_buf[i] = i * n; + begin_topp_offset_buf_[i] = topp_offset_buf[i]; + } + + for(int i = tid; i < batch_size; i+= blockDim.x) + { + if(finished != nullptr) finished[i] = false; + if(sequence_length != nullptr) sequence_length[i] = 0; + if(word_ids != nullptr) word_ids[i] = start_id; + } + } + + int index = tid + bid * blockDim.x; + while(index < batch_size * n) + { + topp_id_val_buf[index] = index % n; + index += blockDim.x * gridDim.x; + } + } + + + void topp_initialization_kernelLauncher(bool* finished, int* sequence_length, int* word_ids, @@ -631,6 +971,28 @@ namespace fastertransformer args.start_id_); } + void topp_initialization_kernelLauncher_v2(bool* finished, + int* sequence_length, + int* word_ids, + int* topp_id_val_buf, + int* topp_offset_buf, + int* begin_topp_offset_buf_, + const int n, + DecodingSamplingArguments args, + cudaStream_t stream) +{ + // n: the coloumn number of logits_buffer for top_p sampling + topp_initialization_kernel_v2<<<32, 512, 0, stream>>>(finished, + sequence_length, + word_ids, + topp_id_val_buf, + topp_offset_buf, + begin_topp_offset_buf_, + args.batch_size_, + n, + args.start_id_); +} + template size_t 
get_topp_sort_temp_storage_size(const T* log_probs, const int* id_vals, @@ -720,8 +1082,6 @@ namespace fastertransformer step_ids, parent_ids, max_sequence_lengths, beams); } - - /* ********************************** Instantiation *********************************** */ template void embedding_lookup_sine_position_encoding_kernel_launcher(float* from_tensor, @@ -761,25 +1121,73 @@ namespace fastertransformer int step, cudaStream_t stream); + template + void start_id_embedding_position_lookups_kernel_launcher(float* from_tensor, + int* output_ids, + const float* embedding_table, + const float* pos_table, + const int* word_ids, + const int start_step, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream); + + template + void start_id_embedding_position_lookups_kernel_launcher(half* from_tensor, + int* output_ids, + const half* embedding_table, + const half* pos_table, + const int* word_ids, + const int start_step, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream); + template void apply_temperature_penalty_kernelLauncher(float* logits, const float temperature, const int m, - const int n, + const int vocab_size, + const int vocab_size_padd, cudaStream_t stream); template void apply_temperature_penalty_kernelLauncher(half* logits, const half temperature, const int m, - const int n, + const int vocab_size, + const int vocab_size_padd, cudaStream_t stream); + template void kernel_padding_kernelLauncher(float *padded_kernel, const float *kernel, + const int row_dim, const int col_dim, + const int padded_col_dim, cudaStream_t stream); + + template void kernel_padding_kernelLauncher(half *padded_kernel, const half *kernel, + const int row_dim, const int col_dim, + const int padded_col_dim, cudaStream_t stream); + + template void bias_padding_kernelLauncher(float *padded_bias, const float *bias, const int col_dim, + const int padded_col_dim, cudaStream_t stream); + + template void bias_padding_kernelLauncher(float *padded_bias, const half *bias, const int col_dim, + const int padded_col_dim, cudaStream_t stream); + + template void bias_padding_kernelLauncher(half *padded_bias, const half *bias, const int col_dim, + const int padded_col_dim, cudaStream_t stream); + template void update_KV_cache_kernelLauncher(float** key_cache, float** value_cache, const int* beam_ids, + const bool* finished, const int batch_size, const int beam_width, - const int hidden_dim, + const int head_num, + const int size_per_head, const int step, + const int decoder_max_seq_len, const int cache_size, const int decoder_layers, cudaStream_t stream); @@ -787,10 +1195,13 @@ namespace fastertransformer template void update_KV_cache_kernelLauncher(half** key_cache, half** value_cache, const int* beam_ids, + const bool* finished, const int batch_size, const int beam_width, - const int hidden_dim, + const int head_num, + const int size_per_head, const int step, + const int decoder_max_seq_len, const int cache_size, const int decoder_layers, cudaStream_t stream); @@ -800,7 +1211,7 @@ namespace fastertransformer int* current_ids, int* previous_ids, int* parent_ids, - Gpt2Arguments args, + GptArguments args, cudaStream_t stream); template void apply_logit_penalties(int step, @@ -808,7 +1219,7 @@ namespace fastertransformer int* current_ids, int* previous_ids, int* parent_ids, - Gpt2Arguments args, + GptArguments args, cudaStream_t stream); template size_t get_topp_sort_temp_storage_size(const float* 
log_probs, @@ -838,6 +1249,20 @@ namespace fastertransformer int width,int stride, cudaStream_t stream); + template void transpose_axis_01_kernelLauncher(float *out, + float *in, + const int dim0, + const int dim1, + const int dim2, + cudaStream_t stream); + + template void transpose_axis_01_kernelLauncher(half *out, + half *in, + const int dim0, + const int dim1, + const int dim2, + cudaStream_t stream); + template void init_kernelLauncher(bool* finished, int* sequence_length, int* word_ids, @@ -855,6 +1280,7 @@ namespace fastertransformer const int batch_size, const int beam_width, cudaStream_t stream); + /* *************************** end of Instantiation *********************************** */ } // end of name space fastertransformer diff --git a/fastertransformer/cuda/masked_multihead_attention.cu b/fastertransformer/cuda/masked_multihead_attention.cu new file mode 100644 index 000000000..005928de9 --- /dev/null +++ b/fastertransformer/cuda/masked_multihead_attention.cu @@ -0,0 +1,866 @@ +/*************************************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are not permit- + * ted. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include "masked_multihead_attention.h" +#include "masked_multihead_attention_utils.h" +#include +#include + +//#define MMHA_USE_HMMA_FOR_REDUCTION + +// Below are knobs to extend FP32 accumulation for higher FP16 accuracy + +// Does not seem to affect the accuracy that much +//#define MMHA_USE_FP32_ACUM_FOR_FMA + +// Seems to slightly improve the accuracy +#define MMHA_USE_FP32_ACUM_FOR_OUT + +#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) +// Does not seem to improve the accuracy +//#define MMHA_USE_FP32_ACUM_FOR_LOGITS +#endif + +namespace mmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. +// +// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use +// 64, 128 and 256 threads per block. +// +// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to +// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The +// cache buffer helps with memory accesses and contains keys with bias. 
+// +// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and +// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The +// values for x are chosen to create chunks of 16 bytes. +// +// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs +// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At +// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an +// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. +// +// After that loop, a parallel softmax is computed accross the different Q * K^T values stored in +// shared memory. +// +// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many +// timesteps are computed by loop iteration. As with the keys, the values are read from a cache +// except for the current timestep. The layout of the cache buffer for the values is much simpler +// as it is [B, H, L, Dh]. +// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T, int Dh > +struct Qk_vec_ {}; + +template<> struct Qk_vec_ { using Type = float; }; +template<> struct Qk_vec_ { using Type = float2; }; +template<> struct Qk_vec_ { using Type = float4; }; +template<> struct Qk_vec_ { using Type = uint32_t; }; +template<> struct Qk_vec_ { using Type = uint32_t; }; +template<> struct Qk_vec_ { using Type = uint2; }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T, int THREADS_PER_KEY > +struct K_vec_ {}; + +template<> struct K_vec_ { using Type = float; }; +template<> struct K_vec_ { using Type = float2; }; +template<> struct K_vec_ { using Type = float4; }; +template<> struct K_vec_ { using Type = uint32_t; }; +template<> struct K_vec_ { using Type = uint2; }; +template<> struct K_vec_ { using Type = uint4; }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T, int V_VEC_SIZE > +struct V_vec_ {}; + +template<> struct V_vec_ { using Type = float; }; +template<> struct V_vec_ { using Type = float2; }; +template<> struct V_vec_ { using Type = float4; }; +template<> struct V_vec_ { using Type = uint32_t; }; +template<> struct V_vec_ { using Type = uint2; }; +template<> struct V_vec_ { using Type = uint4; }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template< typename T> +struct Qk_vec_acum_fp32_ {}; + +template<> struct Qk_vec_acum_fp32_ { using Type = float; }; +template<> struct Qk_vec_acum_fp32_ { using Type = float2; }; +template<> struct Qk_vec_acum_fp32_ { using Type = float4; }; +//template<> struct Qk_vec_acum_fp32_ { using Type = float; }; +template<> struct Qk_vec_acum_fp32_ { using Type = float2; }; +template<> struct Qk_vec_acum_fp32_ { using Type = Float4_; }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T> +struct K_vec_acum_fp32_ {}; + +template<> struct K_vec_acum_fp32_ { using Type = float; }; +template<> struct K_vec_acum_fp32_ { using Type = float2; }; +template<> struct K_vec_acum_fp32_ { using Type = float4; }; +template<> struct K_vec_acum_fp32_ { using Type = float2; }; +template<> struct K_vec_acum_fp32_ 
{ using Type = Float4_; }; +template<> struct K_vec_acum_fp32_ { using Type = Float8_; }; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT +template< typename T > +struct V_vec_acum_fp32_ {}; + +template<> struct V_vec_acum_fp32_ { using Type = float; }; +template<> struct V_vec_acum_fp32_ { using Type = float2; }; +template<> struct V_vec_acum_fp32_ { using Type = float4; }; +template<> struct V_vec_acum_fp32_ { using Type = float2; }; +template<> struct V_vec_acum_fp32_ { using Type = Float4_; }; +template<> struct V_vec_acum_fp32_ { using Type = Float8_; }; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int THREADS_PER_KEY, typename K_vec, int N > +inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) { +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = K_vec; +#endif + // Compute the parallel products for Q*K^T (treat vector lanes separately). + K_vec_acum qk_vec = mul(q[0], k[0]); + #pragma unroll + for( int ii = 1; ii < N; ++ii ) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); + #pragma unroll + for( int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2 ) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T, int THREADS_PER_KEY > +struct Qk_dot { + template< typename K_vec, int N > + static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 hmma_fp32(const uint2 &a, uint32_t b) { + float4 c; float zero = 0.f; + asm volatile( \ + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" \ + " {%0, %1, %2, %3}, \n" \ + " {%4, %5}, \n" \ + " {%6}, \n" \ + " {%7, %7, %7, %7}; \n" \ + \ + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x) "r"(a.y) + , "r"(b) + , "f"(zero)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int N > +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum qk_vec = mul(q[0], k[0]); + #pragma unroll + for( int ii = 1; ii < N; ++ii ) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Qk_dot { + template< int N > + static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) { +#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) + return qk_hmma_dot_(q, k); +#else + return qk_dot_<4>(q, k); +#endif // defined MMHA_USE_HMMA_FOR_REDUCTION + } +}; + 
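// Host-side sketch of the cache indexing described in the comment block at the top of
// this file (keys: [B, H, Dh/x, L, x]; values: [B, H, L, Dh]), which is also the layout
// copied by update_KV_batch_major_cache_kernel in decoding_kernels.cu. These helpers are
// for reasoning/validation only and their names are assumptions, not library code; b is
// the flattened batch*beam index, h the head, t the timestep, d the channel in the head.
template <typename T>
inline size_t k_cache_offset(int b, int h, int t, int d,
                             int head_num, int max_seq_len, int size_per_head)
{
    const int x = 16 / sizeof(T);  // 8 for FP16, 4 for FP32: one 16-byte chunk per load
    return ((((size_t)b * head_num + h) * (size_per_head / x) + d / x) * max_seq_len + t) * x
           + (d % x);
}

template <typename T>
inline size_t v_cache_offset(int b, int h, int t, int d,
                             int head_num, int max_seq_len, int size_per_head)
{
    return (((size_t)b * head_num + h) * max_seq_len + t) * (size_t)size_per_head + d;
}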
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int WARPS_PER_BLOCK, int WARP_SIZE = 32 > +inline __device__ float block_sum(float *red_smem, float sum) { + + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. + #pragma unroll + for( int mask = WARP_SIZE / 2; mask >= 1; mask /= 2 ) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if( lane == 0 ) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if( lane < WARPS_PER_BLOCK ) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. + #pragma unroll + for( int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2 ) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float &dst, float src) { + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint16_t &dst, float src) { + dst = float_to_half(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint32_t &dst, float2 src) { + dst = float2_to_half2(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2 &dst, Float4_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float2 &dst, float2 src) { + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float4 &dst, float4 src) { + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(float4 u) { + return u.x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(uint4 u) { + float2 tmp = half2_to_float2(u.x); + return tmp.x; +} + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float cast_to_float(float u) { + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(float2 u) { + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(float4 u) { + return u; +} + 
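// The block_sum helper defined earlier in this file is the canonical two-stage block reduction:
// an XOR-shuffle butterfly inside each warp, warp leaders parking their partial sums in shared
// memory, and a final warp-level pass before lane 0's result is broadcast to the whole block.
// Below is a minimal, self-contained sketch of that same pattern; block_sum_sketch and
// block_sum_demo are illustrative names, not part of this patch. It sums one 1.f per thread and
// should print the block size.

#include <cstdio>
#include <cuda_runtime.h>

template< int WARPS_PER_BLOCK, int WARP_SIZE = 32 >
__device__ float block_sum_sketch(float *red_smem, float sum) {
    const int warp = threadIdx.x / WARP_SIZE;
    const int lane = threadIdx.x % WARP_SIZE;

    // Stage 1: reduce inside each warp with XOR shuffles.
    #pragma unroll
    for( int mask = WARP_SIZE / 2; mask >= 1; mask /= 2 ) {
        sum += __shfl_xor_sync(0xffffffffu, sum, mask);
    }
    if( lane == 0 ) {
        red_smem[warp] = sum;            // one partial sum per warp
    }
    __syncthreads();

    // Stage 2: the first WARPS_PER_BLOCK lanes of each warp reduce the per-warp partials.
    sum = lane < WARPS_PER_BLOCK ? red_smem[lane] : 0.f;
    #pragma unroll
    for( int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2 ) {
        sum += __shfl_xor_sync(0xffffffffu, sum, mask);
    }

    // Broadcast lane 0's value to every thread.
    return __shfl_sync(0xffffffffu, sum, 0);
}

template< int THREADS_PER_BLOCK >
__global__ void block_sum_demo(float *out) {
    constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / 32;
    __shared__ float red_smem[WARPS_PER_BLOCK];
    float total = block_sum_sketch<WARPS_PER_BLOCK>(red_smem, 1.f);
    if( threadIdx.x == 0 ) {
        *out = total;                    // expect THREADS_PER_BLOCK
    }
}

int main() {
    float *out;
    cudaMallocManaged(&out, sizeof(float));
    block_sum_demo<128><<<1, 128>>>(out);
    cudaDeviceSynchronize();
    printf("block sum = %f (expected 128)\n", *out);
    cudaFree(out);
    return 0;
}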
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(Float4_ u) { + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(Float8_ u) { + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(uint32_t u) { + return half2_to_float2(u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(uint2 u) { + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +inline __device__ __host__ T div_up(T m, T n) { + return (m + n-1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +inline size_t smem_size_in_bytes(const Masked_multihead_attention_params ¶ms, + int threads_per_value, + int threads_per_block) { + // The amount of shared memory needed to store the Q*K^T values in float. + size_t qk_sz = div_up(params.timestep + 1, 4) * 16; + + // The extra memory needed if we are not using floats for the final logits. + size_t logits_sz = 0; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if( sizeof(T) != 4 ) { + logits_sz = div_up(params.seq_length, 4) * 4 * sizeof(T); + } +#endif + + // The total size needed during softmax. + size_t softmax_sz = qk_sz + logits_sz; + + // The number of partial rows to reduce in the final reduction. + int rows_per_red = threads_per_block / threads_per_value; + // The amount of storage needed to finalize the outputs. + size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2; + + // The max. + return max(softmax_sz, red_sz); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ constexpr uint32_t shfl_mask(int threads) { + return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The type of the inputs. Supported types: float and half. + typename T, + // The hidden dimension per head. + int Dh, + // The number of threads per key. + int THREADS_PER_KEY, + // The number of threads per value. + int THREADS_PER_VALUE, + // The number of threads in a threadblock. + int THREADS_PER_BLOCK +> +__global__ void masked_multihead_attention_kernel(Masked_multihead_attention_params params) { + + // Make sure the hidden dimension per head is a multiple of the number of threads per key. + static_assert(Dh % THREADS_PER_KEY == 0, ""); + // Make sure the hidden dimension per head is a multiple of the number of threads per value. + static_assert(Dh % THREADS_PER_VALUE == 0, ""); + + // The size of a warp. + constexpr int WARP_SIZE = 32; + // The number of warps in a threadblock. 
+ constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + // Use smem_size_in_bytes (above) to determine the amount of shared memory. + extern __shared__ char smem_[]; + + // The shared memory for the Q*K^T values and partial logits in softmax. + float *qk_smem = reinterpret_cast(smem_); + + // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. + char *logits_smem_ = smem_; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if( sizeof(T) != 4 ) { + logits_smem_ += div_up(params.timestep + 1, 4) * 16; //sizeof(float); + } + T *logits_smem = reinterpret_cast(logits_smem_); +#else + float *logits_smem = reinterpret_cast(logits_smem_); +#endif + + // The shared memory to do the final reduction for the output values. Reuse qk_smem. + T *out_smem = reinterpret_cast(smem_); + + // The shared memory buffers for the block-wide reductions. One for max, one for sum. + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + // Shared memory to store Q inputs. + __shared__ T q_smem[Dh]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec = typename Qk_vec_::Type; + // The number of elements per vector. + constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh % QK_VEC_SIZE == 0 && Dh / QK_VEC_SIZE <= WARP_SIZE, ""); + // The number of vectors per warp. + constexpr int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE; + + // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread + // owns x elements, we have to decompose the linear index into chunks of x values and the posi- + // tion of the thread in that chunk. + + // The number of elements in a chunk of 16B (that's the x in the above formula). + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + // The number of K vectors in 16B. + constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); + + // The batch. + const int bi = blockIdx.y; + if(params.finished != nullptr && params.finished[bi] == true) return; + // The head. + const int hi = blockIdx.x; + // Combine the batch and the head indices. + const int bhi = bi * params.num_heads + hi; + // The thread in the block. + const int tidx = threadIdx.x; + + // While doing the product Q*K^T for the different keys we track the max. + float qk_max = -FLT_MAX; + + int qkv_base_offset = (params.stride == 0)? bhi*Dh : bi*params.stride + hi*Dh; + + // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. + if( tidx < QK_VECS_PER_WARP ) { + + // The offset in the Q and K buffer also accounts for the batch. + int qk_offset = qkv_base_offset + tidx*QK_VEC_SIZE; + // The offset in the bias buffer. + int qk_bias_offset = hi*Dh + tidx*QK_VEC_SIZE; + + // Trigger the loads from the Q and K buffers. + Qk_vec q = *reinterpret_cast(¶ms.q[qk_offset]); + Qk_vec k = *reinterpret_cast(¶ms.k[qk_offset]); + + // Trigger the loads from the Q and K bias buffers. + Qk_vec q_bias = *reinterpret_cast(¶ms.q_bias[qk_bias_offset]); + Qk_vec k_bias = *reinterpret_cast(¶ms.k_bias[qk_bias_offset]); + + // Computes the Q/K values with bias. + q = add(q, q_bias); + k = add(k, k_bias); + + // Store the Q values to shared memory. + *reinterpret_cast(&q_smem[tidx*QK_VEC_SIZE]) = q; + + // Write the K values to the global memory cache. + // + // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory + // system. 
We designed it this way as it allows much better memory loads (and there are many + // more loads) + the stores are really "write and forget" since we won't need the ack before + // the end of the kernel. There's plenty of time for the transactions to complete. + + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int offset = bhi*params.seq_length*Dh + + co*params.seq_length*QK_ELTS_IN_16B + + params.timestep*QK_ELTS_IN_16B + + ci; + + // Trigger the stores to global memory. + *reinterpret_cast(¶ms.k_cache[offset]) = k; + + // Compute \sum_i Q[i] * K^T[i] for the current timestep. +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; +#else + using Qk_vec_acum = Qk_vec; +#endif + float qk = dot(q, k); + #pragma unroll + for( int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2 ) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + + // Normalize qk. + qk *= params.inv_sqrt_dh; + + // Store that value in shared memory. Keep the Q*K^T value in register for softmax. + if( tidx == 0 ) { + qk_max = qk; + qk_smem[params.timestep] = qk; + } + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The type of queries and keys for the math in the Q*K^T product. + using K_vec = typename K_vec_::Type; + // The number of elements per vector. + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh % K_VEC_SIZE == 0, ""); + // The number of elements per thread. + constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY; + // The number of vectors per thread. + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + + // The position the first key loaded by each thread from the cache buffer (for this B * H). + int ko = tidx / THREADS_PER_KEY; + // The position of the thread in the chunk of keys. + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + + // Load the Q values from shared memory. The values are reused during the loop on K. + K_vec q[K_VECS_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < K_VECS_PER_THREAD; ++ii ) { + q[ii] = *reinterpret_cast(&q_smem[ki + ii*THREADS_PER_KEY*K_VEC_SIZE]); + } + + // The number of timesteps loaded per iteration. + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + // The base pointer for the key in the cache buffer. + T *k_cache = ¶ms.k_cache[bhi*params.seq_length*Dh + ki]; + + // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). + int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + + // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. + for( int ti = ko; ti < ti_end; ti += K_PER_ITER ) { + + // The keys loaded from the key cache. + K_vec k[K_VECS_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < K_VECS_PER_THREAD; ++ii ) { + int jj = ii * params.seq_length + ti; + if( ti < params.timestep ) { + k[ii] = *reinterpret_cast(&k_cache[jj*QK_ELTS_IN_16B]); + } + } + + // Perform the dot product and normalize qk. + // + // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! + float qk = Qk_dot::dot(q, k) * params.inv_sqrt_dh; + + // Store the product to shared memory. 
There's one qk value per timestep. Update the max. + if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { + qk_max = fmaxf(qk_max, qk); + qk_smem[ti] = qk; + } + } + + // Perform the final reduction to compute the max inside each warp. + // + // NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the + // group so it's not needed to run the reduction inside the group (again). + #pragma unroll + for( int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2 ) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + const int warp = tidx / WARP_SIZE; + const int lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if( lane == 0 ) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; + #pragma unroll + for( int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2 ) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Compute the logits and start the sum. + float sum = 0.f; + for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + float logit = __expf(qk_smem[ti] - qk_max); + sum += logit; + qk_smem[ti] = logit; + } + + // Compute the sum. + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); + + // Normalize the logits. + float inv_sum = __fdividef(1.f, sum + 1.e-6f); + for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum); + } + + // Make sure the logits are in shared memory. + __syncthreads(); + + // The number of elements per vector. + constexpr int V_VEC_SIZE = Dh / THREADS_PER_VALUE; + // A vector of V elements for the current timestep. + using V_vec = typename V_vec_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + + // The base pointer for the value in the cache buffer. + T *v_cache = ¶ms.v_cache[bhi*params.seq_length*Dh + vi]; + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec; +#endif + // The partial outputs computed by each thread. + V_vec_acum out; zero(out); + + // The number of values processed per iteration of the loop. + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + // Loop over the timesteps to compute the partial outputs. + for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { + + // Load the values from the cache. + V_vec v = *reinterpret_cast(&v_cache[ti*Dh]); + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti]; + out = fma(logit, cast_to_float(v), out); +#else + T logit = logits_smem[ti]; + + // Update the partial sums. + out = fma(logit, v, out); +#endif + } + + // One group of threads computes the product(s) for the current timestep. + if( vo == params.timestep % V_PER_ITER ) { + + // Trigger the loads from the V buffer. + V_vec v = *reinterpret_cast(¶ms.v[qkv_base_offset + vi]); + // Trigger the loads from the V bias buffer. + V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); + + // Compute the V values with bias. 
+ v = add(v, v_bias); + + // Store the values with bias back to global memory in the cache for V. + *reinterpret_cast(&v_cache[params.timestep*Dh]) = v; + + // Initialize the output value with the current timestep. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + out = fma(logits_smem[params.timestep], cast_to_float(v), out); +#else + out = fma(logits_smem[params.timestep], v, out); +#endif + } + + // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different partial outputs. + #pragma unroll + for( int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2 ) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if( vo >= midpoint && vo < active_groups ) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint)*Dh + vi]), out); +#else + *reinterpret_cast(&out_smem[(vo - midpoint)*Dh + vi]) = out; +#endif + } + __syncthreads(); + + // The bottom warps update their values. + if( vo < midpoint ) { + out = add(*reinterpret_cast(&out_smem[vo*Dh + vi]), out); + } + __syncthreads(); + } + + // Output the final values. + if( vo == 0 ) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(¶ms.out[bhi*Dh + vi]), out); +#else + *reinterpret_cast(¶ms.out[bhi*Dh + vi]) = out; +#endif + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace mmha + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel \ + <<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < typename T, int Dh > +void mmha_launch_kernel(const Masked_multihead_attention_params ¶ms, const cudaStream_t &stream) { + constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16; + if( params.timestep < 32 ) { + MMHA_LAUNCH_KERNEL(T, Dh, 4, THREADS_PER_VALUE, 64, stream); + } else if( params.timestep < 2048 ) { + MMHA_LAUNCH_KERNEL(T, Dh, 2, THREADS_PER_VALUE, 128, stream); + } else { + MMHA_LAUNCH_KERNEL(T, Dh, 1, THREADS_PER_VALUE, 256, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +void masked_multihead_attention_(const Masked_multihead_attention_params ¶ms, const cudaStream_t &stream) { + switch ( params.hidden_size_per_head ) { + case 32: + mmha_launch_kernel(params, stream); + break; + case 64: + mmha_launch_kernel(params, stream); + break; + case 128: + mmha_launch_kernel(params, stream); + break; + default: + assert(false); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params ¶ms, const cudaStream_t &stream) { + masked_multihead_attention_(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params ¶ms, const cudaStream_t &stream) { + 
masked_multihead_attention_(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#undef MMHA_LAUNCH_KERNEL + diff --git a/fastertransformer/cuda/masked_multihead_attention.h b/fastertransformer/cuda/masked_multihead_attention.h new file mode 100644 index 000000000..c7a0dd1b3 --- /dev/null +++ b/fastertransformer/cuda/masked_multihead_attention.h @@ -0,0 +1,92 @@ +/*************************************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are not permit- + * ted. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CHECK_CUDA(call) do { \ + cudaError_t status_ = call; \ + if( status_ != cudaSuccess ) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ +} while(0) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The structure of parameters for the masked multihead attention kernel. +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. + +template< typename T > +struct Masked_multihead_attention_params { + + // The output buffer. Dimensions B x D. + T *out; + + // The input Qs and the associated bias. Dimensions B x D and D, resp. + const T *q, *q_bias; + // The input Ks and the associated bias. Dimensions B x D and D, resp. + const T *k, *k_bias; + // The input Vs and the associated bias. Dimensions B x D and D, resp. + const T *v, *v_bias; + + // The cache for the Ks. The size must be at least B x L x D. + T *k_cache; + // The cache for the Vs. The size must be at least B x L x D. + T *v_cache; + + // allows to exist attention eary + bool *finished; + + // Stride to handle the case when KQV is a single buffer + int stride; + + // The batch size. + int batch_size; + // The sequence length. + int seq_length; + // The number of heads (H). + int num_heads; + // The hidden dimension per head (Dh). + int hidden_size_per_head; + // The current timestep. + int timestep; + + // The 1.f / sqrt(Dh). Computed on the host. 
+ float inv_sqrt_dh; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention (const Masked_multihead_attention_params ¶ms, const cudaStream_t &stream); +void masked_multihead_attention (const Masked_multihead_attention_params ¶ms, const cudaStream_t &stream); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/fastertransformer/cuda/masked_multihead_attention_utils.h b/fastertransformer/cuda/masked_multihead_attention_utils.h new file mode 100644 index 000000000..e6a833b7e --- /dev/null +++ b/fastertransformer/cuda/masked_multihead_attention_utils.h @@ -0,0 +1,665 @@ +/*************************************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are not permit- + * ted. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +namespace mmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Float4_ { + float2 x; + float2 y; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float add(float a, float b) { + return a + b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint16_t float_to_half(float f) { + union { uint32_t u32; uint16_t u16[2]; } tmp; +#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better? 
+ float zero = 0.f; + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); +#endif + return tmp.u16[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { uint32_t u32; uint16_t u16[2]; } tmp; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); +#endif + return tmp.u32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 half2_to_float2(uint32_t v) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ add(uint2 a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t h0_h0(uint16_t a) { + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(float a, float b, float c) { + return a * b + c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); 
+ d.w = fma(a, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa,fb,fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { + return fma(h0_h0(a), b, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + 
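// The uint32_t/uint2/uint4 fma overloads above issue fma.rn.f16x2, i.e. two FP16 multiply-adds
// per instruction on a packed pair of half values. The same operation is exposed by the CUDA
// __half2 intrinsics; the standalone sketch below (illustrative only, not part of this patch;
// it needs a GPU with native FP16 arithmetic, sm_53 or newer) computes d = a * b + c on one
// packed pair and converts the result back to FP32 for checking on the host.

#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

__global__ void half2_fma_demo(float2 *out) {
    // Pack two FP16 values per register; one __hfma2 performs both multiply-adds,
    // which is the intrinsic form of the fma.rn.f16x2 instruction wrapped above.
    __half2 a = __float2half2_rn(2.0f);
    __half2 b = __float2half2_rn(3.0f);
    __half2 c = __float2half2_rn(1.0f);
    __half2 d = __hfma2(a, b, c);                 // (2*3+1, 2*3+1)
    if( threadIdx.x == 0 ) {
        *out = __half22float2(d);                 // back to FP32 for printing
    }
}

int main() {
    float2 *out;
    cudaMallocManaged(&out, sizeof(float2));
    half2_fma_demo<<<1, 32>>>(out);
    cudaDeviceSynchronize();
    printf("d = (%f, %f), expected (7, 7)\n", out->x, out->y);
    cudaFree(out);
    return 0;
}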
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ Acc mul(A a, B b); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(float a, float b) { + return a * b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint2 mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + 
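// The mul specializations above, together with the sum() overloads that follow, exist so that
// the generic dot(a, b) = sum(mul(a, b)) defined later in this header compiles unchanged whether
// the operands are FP32 vectors or packed-FP16 words promoted to FP32 accumulators. A minimal
// host-side sketch of that overload-driven design; Vec4, mul_e, sum_e and dot_e are illustrative
// names, not part of this patch.

#include <cstdio>

struct Vec4 { float x, y, z, w; };

// Element-wise products, one overload per operand width.
inline float mul_e(float a, float b) { return a * b; }
inline Vec4  mul_e(Vec4 a, Vec4 b)   { return {a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w}; }

// Horizontal sums, one overload per accumulator width.
inline float sum_e(float v) { return v; }
inline float sum_e(Vec4 v)  { return v.x + v.y + v.z + v.w; }

// The generic dot product never needs to know which width it works on.
template< typename V >
inline float dot_e(V a, V b) { return sum_e(mul_e(a, b)); }

int main() {
    Vec4 a{1.f, 2.f, 3.f, 4.f};
    Vec4 b{1.f, 1.f, 1.f, 1.f};
    printf("dot(Vec4)  = %f\n", dot_e(a, b));      // 10
    printf("dot(float) = %f\n", dot_e(2.f, 3.f));  // 6
    return 0;
}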
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint4 mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ float mul(uint16_t a, uint16_t b) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(uint32_t a, uint32_t b) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(uint2 a, uint2 b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(uint4 a, uint4 b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float v) { + return v; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint16_t v) { + return half_to_float(v); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint4 v) { +#if 1 + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); +#else + uint32_t c = add(v.x, v.y); + uint32_t d = add(v.z, v.w); + c = 
add(c, d); +#endif + return sum(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x+ v.w.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename A, typename T > +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void zero(uint16_t &dst) { + dst = uint16_t(0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +inline __device__ void zero(T &dst) { + constexpr int WORDS = sizeof(T) / 4; + union { T raw; uint32_t words[WORDS]; } tmp; + #pragma unroll + for( int ii = 0; ii < WORDS; ++ii ) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace mmha + diff --git a/fastertransformer/cuda/multi_head_attention.h b/fastertransformer/cuda/multi_head_attention.h index 4474e5afd..2e420a38d 100644 --- a/fastertransformer/cuda/multi_head_attention.h +++ b/fastertransformer/cuda/multi_head_attention.h @@ -19,8 +19,8 @@ #pragma once -#include "fastertransformer/common.h" -#include "fastertransformer/common_structure.h" +#include "fastertransformer/utils/common.h" +#include "fastertransformer/utils/common_structure.h" namespace fastertransformer{ namespace cuda{ @@ -41,12 +41,9 @@ class MultiHeadInitParam{ cublasLtHandle_t cublaslt_handle; cudaStream_t stream; - //First 80 are for activation amaxs. 
- //For each activation amax, there are 4 values: amax, amax/127.0f, amax/127.0f/127.0f, 127.0f/amax -- input_amax 0-3 , Q_aftergemm_amax 4-7, Qbias_amax 8-11, K_aftergemm_amax 12-15, Kbias_amax 16-19, V_aftergemm_amax 20-23, Vbias_amax 24-27, bmm1_amax 28-31, Softmax_amax 32-35, bmm2_amax 36-39, Proj_aftergemm_scale 40-43, ProjBiasNorm_amax 44-47, FC1_aftergemm_amax 48-51, F1Bias_amax 52-55, FC2_aftergemm_amax 56-59, F2BiasNorm_amax 60-63, reserve 64-79 - //following by kernel amaxs : query_weight_amax_list, key_weight_amax_list, value_weight_amax_list, proj_weight_amax_list, FC1_weight_amax_list, FC2_weight_amax_list - //following by int8 gemm deQ scale list: Q_deQ_scale, K_deQ_scale, V_deQ_scale, bmm1_deQ_scale, bmm2_deQ_scale, proj_deQ_scale, FC1_deQ_scale, FC2_deQ_scale const float *amaxList; const float *int8O_gemm_deQ_scale_list; + const float *trt_fused_mha_amax_list; const int *trt_seqlen_offset; int trt_seqlen_size; @@ -61,6 +58,7 @@ class MultiHeadInitParam{ int8_from_tensor = nullptr; amaxList = nullptr; int8O_gemm_deQ_scale_list = nullptr; + trt_fused_mha_amax_list = nullptr; stream = 0; trt_seqlen_offset = nullptr; diff --git a/fastertransformer/cuda/online_softmax_beamsearch_kernels.cu b/fastertransformer/cuda/online_softmax_beamsearch_kernels.cu index f8326e4a3..3d84ccb9f 100644 --- a/fastertransformer/cuda/online_softmax_beamsearch_kernels.cu +++ b/fastertransformer/cuda/online_softmax_beamsearch_kernels.cu @@ -136,7 +136,7 @@ __global__ void batch_topk_kernel( const int * __restrict x, const T * __restrict y, int * __restrict z, - T * __restrict v, + float * __restrict v, int V, int K, T diversity_rate) @@ -178,7 +178,7 @@ __global__ void batch_topk_kernel( if (i < K) { z[i] = x[total.p[i]]; - v[i] = y[total.p[i]]; + v[i] = (float)y[total.p[i]]; } } } @@ -221,8 +221,8 @@ template __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_kernel( const T * __restrict x, - const float * __restrict b, - const T * __restrict c, + const T * __restrict b, + const float * __restrict c, const bool * __restrict finished, int * __restrict z, T * __restrict v, @@ -301,7 +301,7 @@ template __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_stage1_kernel( const T * __restrict x, - const float * __restrict b, + const T * __restrict b, const bool * __restrict finished, float * __restrict t, int V, @@ -362,7 +362,7 @@ __global__ void beam_online_softmax_topk_stage1_kernel( #pragma unroll 1 for(int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) { - T bias = (T)(b == nullptr ? 0.0f : b[elem_id]); // gpt-2 does not use bias + T bias = b == nullptr ? 
(T)0.0f : b[elem_id]; // gpt-2 does not use bias T elem = x[elem_id] + bias; MD new_elem{elem, 1.0F}; partial.md = reduce_md_op(partial.md, new_elem); @@ -397,7 +397,7 @@ template __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_stage2_kernel( const float * __restrict x, - const T * __restrict c, + const float * __restrict c, int * __restrict z, T * __restrict v, int K, @@ -479,7 +479,7 @@ void topK_kernelLauncher(T* log_probs, { const int batch_size = args.batch_size_; const int beam_width = args.beam_width_; - const int vocab_size = args.vocab_size_; + const int vocab_size = args.vocab_size_padded_; const int diversity_rate = args.beam_search_diversity_rate_; const int block_size = SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE; @@ -530,7 +530,7 @@ void topK_kernelLauncher(T* log_probs, template void beam_online_softmax_topk_stage2_kernelLauncher( const float * temp_storage, - const T * cum_log_probs, + const float * cum_log_probs, int * ids, T * vals, int batch_size, @@ -571,18 +571,18 @@ void beam_online_softmax_topk_stage2_kernelLauncher( template void topK_softMax_kernelLauncher(const T* log_probs, - const float* bias, - const bool* finished, - T* cum_log_probs, - int* ids, - void* temp_storage, - const int temp_storage_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int end_id, - T diversity_rate, - cudaStream_t stream) + const T* bias, + const bool* finished, + float* cum_log_probs, + int* ids, + void* temp_storage, + const int temp_storage_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int end_id, + T diversity_rate, + cudaStream_t stream) { const int items_per_thread = 1; const int block_sz = (MAX_K < 16)? (MAX_K < 8)? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE:128:64; @@ -591,9 +591,10 @@ void topK_softMax_kernelLauncher(const T* log_probs, assert(temp_storage_size % 2 == 0); assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width); + const int topk_buf_offset = ceil(batch_size * beam_width * beam_width / 4.) 
* 4; int* topk_tmp_id_buf = reinterpret_cast(temp_storage); - T* topk_tmp_val_buf = reinterpret_cast(topk_tmp_id_buf + batch_size * beam_width * beam_width); - float* tmp_buffer = reinterpret_cast(topk_tmp_val_buf + batch_size * beam_width * beam_width); + T* topk_tmp_val_buf = reinterpret_cast(topk_tmp_id_buf + topk_buf_offset); + float* tmp_buffer = reinterpret_cast(topk_tmp_val_buf + topk_buf_offset); #ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX int voc_parts = 4; @@ -638,7 +639,7 @@ void topK_softMax_kernelLauncher(const T* log_probs, else { #ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX - beam_online_softmax_topk_stage2_kernelLauncher + beam_online_softmax_topk_stage2_kernelLauncher (tmp_buffer, cum_log_probs, ids, cum_log_probs, batch_size, beam_width, voc_parts, stream); #else @@ -652,18 +653,18 @@ void topK_softMax_kernelLauncher(const T* log_probs, template void topK_softMax(const T* log_probs, - const float* bias, - const bool* finished, - T* cum_log_probs, - int* ids, - void* temp_storage, - DecodingBeamsearchArguments args, - cudaStream_t stream) + const T* bias, + const bool* finished, + float* cum_log_probs, + int* ids, + void* temp_storage, + DecodingBeamsearchArguments args, + cudaStream_t stream) { const int temp_storage_size = args.temp_storage_size_; const int batch_size = args.batch_size_; const int beam_width = args.beam_width_; - const int vocab_size = args.vocab_size_; + const int vocab_size = args.vocab_size_padded_; const int end_id = args.end_id_; const T diversity_rate = args.beam_search_diversity_rate_; @@ -699,6 +700,11 @@ void topK_softMax(const T* log_probs, (log_probs, bias, finished, cum_log_probs, ids, temp_storage, temp_storage_size, batch_size, beam_width, vocab_size, end_id, diversity_rate, stream); break; + case 32 : + topK_softMax_kernelLauncher + (log_probs, bias, finished, cum_log_probs, ids, temp_storage, temp_storage_size, + batch_size, beam_width, vocab_size, end_id, diversity_rate, stream); + break; default : printf("[ERROR] Topk kernel does not support beamwidth = %d \n", beam_width); exit(0); @@ -730,9 +736,9 @@ template void topK_softMax(const float* log_probs, cudaStream_t stream); template void topK_softMax(const half* log_probs, - const float* bias, - const bool* finished, - half* cum_log_probs, + const half* bias, + const bool* finished, + float* cum_log_probs, int* ids, void * tmp_storage, DecodingBeamsearchArguments args, diff --git a/fastertransformer/cuda/open_attention.cu b/fastertransformer/cuda/open_attention.cu index 25c03ccff..7e089209a 100644 --- a/fastertransformer/cuda/open_attention.cu +++ b/fastertransformer/cuda/open_attention.cu @@ -17,12 +17,14 @@ * Open sourced multi-head attention **/ -#include "fastertransformer/allocator.h" +#include "fastertransformer/utils/allocator.h" #include "fastertransformer/cuda/multi_head_attention.h" #include "fastertransformer/cuda/open_attention.h" +#include "fastertransformer/cuda/attention_kernels.cuh" #include #include #include + namespace fastertransformer{ namespace cuda{ @@ -111,6 +113,14 @@ __global__ void mappingRemovePaddingData(int *mapping, const int* sequence_id_of mapping[idx + __ldg(sequence_id_offset + idx)] = idx; } +void mappingRemovePaddingData_kernelLauncher(const int batch_size, const int seq_len, + const int valid_word_num, int *mapping, + const int* sequence_id_offset, cudaStream_t stream) +{ + cudaMemsetAsync(mapping, -1, batch_size * seq_len * sizeof(int), stream); + mappingRemovePaddingData<<>>(mapping, sequence_id_offset, valid_word_num); +} + //add_QK_bias_transform for batch int8 
cublasLtMatmul & per axis quantization for weight //1.add QK bias //2.transform each Q K CUBLASLT_ORDER_COL32 matrixes into a series of sub-matrix (with CUBLASLT_ORDER_COL32/CUBLASLT_ORDER_COL4_4R2_8C layout) @@ -225,6 +235,153 @@ void add_QK_bias_transform(int8_t *q_buf_, int8_t *k_buf_, const int32_t* Q, con buf_ptr4[(((batch_id*head_num + head_id) * stride + (new_col << 5)*seq_len + new_row) >> 2)] = tmp4; } +template +void add_QK_bias_transform_kernelLauncher(int8_t *q_buf, int8_t *k_buf, const int32_t* Q, const T* bias_Q, + const int32_t* K, const T* bias_K, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const float * q_weight_amax, const float *q_input_deQFactor_div127_ptr, + const float * k_weight_amax, const float *k_input_deQFactor_div127_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream) +{ + add_QK_bias_transform<<>>( + q_buf, k_buf, Q, bias_Q, K, bias_K, + batch_size * seq_len, batch_size, seq_len, head_num, size_per_head, seq_len*size_per_head, + q_weight_amax, q_input_deQFactor_div127_ptr, k_weight_amax, k_input_deQFactor_div127_ptr, + q_output_scale_ptr, k_output_scale_ptr, use_ORDER_COL32_2R_4R4); +} + +template +void add_QK_bias_transform_kernelLauncher(int8_t *q_buf, int8_t *k_buf, const int32_t* Q, const float* bias_Q, + const int32_t* K, const float* bias_K, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const float * q_weight_amax, const float *q_input_deQFactor_div127_ptr, + const float * k_weight_amax, const float *k_input_deQFactor_div127_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_QK_bias_transform_kernelLauncher(int8_t *q_buf, int8_t *k_buf, const int32_t* Q, const half* bias_Q, + const int32_t* K, const half* bias_K, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const float * q_weight_amax, const float *q_input_deQFactor_div127_ptr, + const float * k_weight_amax, const float *k_input_deQFactor_div127_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +//add_QK_bias_padding_transform for batch int8 cublasLtMatmul & per tensor quantization for weight +//1.add QK bias +//2.padding seq_len in k_buf_ to a multiple of 32 named seq_len_padded +//3.transform each Q K CUBLASLT_ORDER_COL32 matrixes into a series of sub-matrix (with CUBLASLT_ORDER_COL32/CUBLASLT_ORDER_COL4_4R2_8C layout) +// Q, K are CUBLASLT_ORDER_COL32 matrixes of m = batch_size * seq_len, n = head_num * size_per_head +// q_buf_ is of batchCount = batch_size * head_num, m = seq_len, n = size_per_head, CUBLASLT_ORDER_COL32 +// k_buf_ is of batchCount = batch_size * head_num, m = seq_len_padded, n = size_per_head, CUBLASLT_ORDER_COL4_4R2_8C +//only for int8 IO +//size_per_head must be a multiple of 32 +//grid.x = batch_size * seq_len * 2; +//block.x = head_num * size_per_head / 4; +//using char4 +template +__global__ +void add_QK_bias_transform_varlen(int8_t *q_buf_, int8_t *k_buf_, const int8_t* Q, const T* bias_Q, + const int8_t* K, const T* bias_K, const int m, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const int seq_len_padded, const int stride_q, const int stride_k, + const float *q_input_deQFactor_ptr, const float *k_input_deQFactor_ptr, + const float *q_output_scale_ptr, const 
float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4) +{ + const char4* data_ptr; + char4* buf_ptr4; + const T* bias_ptr; + int qk_id = blockIdx.x / m; + + data_ptr = qk_id == 0 ? (const char4*)Q : (const char4*)K; + buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; + bias_ptr = qk_id == 0 ? bias_Q : bias_K; + const float input_deQFactor = qk_id == 0 ? __ldg(q_input_deQFactor_ptr) : __ldg(k_input_deQFactor_ptr); + const float output_scale = qk_id == 0 ? __ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); + + int threadIdx4 = threadIdx.x << 2; + int batch_id = (blockIdx.x % m) / seq_len; + int head_id = threadIdx4 / size_per_head; + int id_in_head = threadIdx4 % size_per_head; + int word_id = blockIdx.x % seq_len; + + int data_id = (((threadIdx4 >> 5) << 5)*m + ((blockIdx.x%m) << 5) + (threadIdx4&31)) >> 2; + + float scale; + float tmp; + char4 tmp4 = __ldg(data_ptr+data_id); + scale = static_cast(tmp4.x) * input_deQFactor; + tmp = static_cast(__ldg(bias_ptr+threadIdx4)) + scale; + tmp4.x = float_to_int8_rn(tmp*output_scale); + + threadIdx4 = threadIdx4+1; + scale = static_cast(tmp4.y) * input_deQFactor; + tmp = static_cast(__ldg(bias_ptr+threadIdx4)) + scale; + tmp4.y = float_to_int8_rn(tmp*output_scale); + + threadIdx4 = threadIdx4+1; + scale = static_cast(tmp4.z) * input_deQFactor; + tmp = static_cast(__ldg(bias_ptr+threadIdx4)) + scale; + tmp4.z = float_to_int8_rn(tmp*output_scale); + + threadIdx4 = threadIdx4+1; + scale = static_cast(tmp4.w) * input_deQFactor;; + tmp = static_cast(__ldg(bias_ptr+threadIdx4)) + scale; + tmp4.w = float_to_int8_rn(tmp*output_scale); + + + //row_id, col_id of sub-matrix (m = seq_len/seq_len_padded, n = size_per_head), column-major + + int row_id = word_id; + int col_id = id_in_head; + //new (row, rol) of LtTrans COL32/COL4 sub-matrix, leading dim = (COL32_ * seq_len / COL32_ * seq_len_padded) + int new_col = col_id >> 5; + int new_row; + if (use_ORDER_COL32_2R_4R4) + { + int row_in_tile = row_id & 31; + int col_in_tile = col_id & 31; + new_row = (qk_id != 1) ? + //COL32 + ((row_id << 5) + (col_id&31)) + : + //COL32_2R_4R4 + ( + ((row_id >> 5) << 10) + + //(((row%8)/2*4+row/8)*2+row%2)*32+col + (((((((row_in_tile&7)>>1)<<2)+(row_in_tile>>3))<<1)+(row_in_tile&1))<<5)+col_in_tile + ) + ; + } + else + { + new_row = (qk_id != 1) ? + //COL32 + ((row_id << 5) + (col_id&31)) + : + //COL4 + ////row_id/8 is the number of tile of (8 rows 32 columns) -- column-major + ////row_id%2 is even row, otherwise odd row + ////col_id%COL32_/8 is the number tile of (8 rows 8 columns) + ( + ((((row_id >> 3) << 3) + ((row_id&1) << 2) + ((col_id&31) >> 3)) << 5) + + ////col_id%8 >= 4 is the right half of (8 rows 8 columns) tile + ////(row_id%8/2) is (the row id of alternating 4 rows) - 1 + (((((col_id&7) >= 4)?4:0) + ((row_id&7) >> 1)) << 2) + + ////col_id%4 is the id of 4 cols + (col_id&3) + ) + ; + } + + const int act_seq_len = (qk_id == 0) ? seq_len : seq_len_padded; + const int stride = (qk_id == 0) ? 
stride_q : stride_k; + buf_ptr4[(((batch_id*head_num + head_id) * stride + (new_col << 5)*act_seq_len + new_row) >> 2)] = tmp4; +} + //add_QK_bias_transform for batch int8 cublasLtMatmul & per axis quantization for weight //1.add QK bias //2.transform each Q K CUBLASLT_ORDER_COL32 matrixes into a series of sub-matrix (with CUBLASLT_ORDER_COL32/CUBLASLT_ORDER_COL4_4R2_8C layout) @@ -331,8 +488,57 @@ void add_QK_bias_transform(int8_t *q_buf_, int8_t *k_buf_, const int8_t* Q, cons } buf_ptr4[(((batch_id*head_num + head_id) * stride + (new_col << 5)*seq_len + new_row) >> 2)] = tmp4; +} + + +template +void add_QK_bias_transform_kernelLauncher(int8_t *q_buf, int8_t *k_buf, const int8_t* Q, const T* bias_Q, + const int8_t* K, const T* bias_K, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const float *q_input_deQFactor_ptr, const float *k_input_deQFactor_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream) +{ + assert(size_per_head % 32 == 0); + if (seq_len % 32 == 0) + { + add_QK_bias_transform_varlen<<>>( + q_buf, k_buf, Q, bias_Q, K, bias_K, + batch_size * seq_len, batch_size, seq_len, head_num, size_per_head, + seq_len, seq_len*size_per_head, seq_len*size_per_head, + q_input_deQFactor_ptr, k_input_deQFactor_ptr, q_output_scale_ptr, k_output_scale_ptr, + use_ORDER_COL32_2R_4R4); + } + else + { + int seq_len_padded = (seq_len + 31)/32*32; + //The padding words will not be considered in softmax, so we don't need memset for k_buf_ + //cudaMemsetAsync(k_buf, 0, batch_size * head_num * seq_len_padded * size_per_head * sizeof(int8_t), stream); + add_QK_bias_transform_varlen<<>>( + q_buf, k_buf, Q, bias_Q, K, bias_K, + batch_size * seq_len, batch_size, seq_len, head_num, size_per_head, + seq_len_padded, seq_len*size_per_head, seq_len_padded*size_per_head, + q_input_deQFactor_ptr, k_input_deQFactor_ptr, q_output_scale_ptr, k_output_scale_ptr, + use_ORDER_COL32_2R_4R4); + } } +template +void add_QK_bias_transform_kernelLauncher(int8_t *q_buf, int8_t *k_buf, const int8_t* Q, const float* bias_Q, + const int8_t* K, const float* bias_K, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const float *q_input_deQFactor_ptr, const float *k_input_deQFactor_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_QK_bias_transform_kernelLauncher(int8_t *q_buf, int8_t *k_buf, const int8_t* Q, const half* bias_Q, + const int8_t* K, const half* bias_K, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const float *q_input_deQFactor_ptr, const float *k_input_deQFactor_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + //add_QK_bias_transform & rebuild padding for batch int8 cublasLtMatmul & per axis quantization for weight //1.add QK bias //2.transform each Q K CUBLASLT_ORDER_COL32 matrixes into a series of sub-matrix (with CUBLASLT_ORDER_COL32/CUBLASLT_ORDER_COL4_4R2_8C layout) @@ -449,6 +655,59 @@ void add_QK_bias_transform_rebuild_padding(int8_t *q_buf_, int8_t *k_buf_, const buf_ptr4[(((batch_id*head_num + head_id) * stride + (new_col << 5)*seq_len + new_row) >> 2)] = tmp4; } +template +void add_QK_bias_transform_rebuild_padding_kernelLauncher(int8_t *q_buf, int8_t *k_buf, + const int32_t* Q, const T* bias_Q, + const int32_t* K, const T* bias_K, + const 
int* sequence_id_offset, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float * q_weight_amax, + const float *q_input_deQFactor_div127_ptr, + const float * k_weight_amax, + const float *k_input_deQFactor_div127_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream) +{ + add_QK_bias_transform_rebuild_padding<<>>( + q_buf, k_buf, Q, bias_Q, K, bias_K, + sequence_id_offset, valid_word_num, + batch_size*seq_len, batch_size, seq_len, + head_num, size_per_head, seq_len*size_per_head, + q_weight_amax, q_input_deQFactor_div127_ptr, + k_weight_amax, k_input_deQFactor_div127_ptr, + q_output_scale_ptr, k_output_scale_ptr, + use_ORDER_COL32_2R_4R4); +} + +template +void add_QK_bias_transform_rebuild_padding_kernelLauncher(int8_t *q_buf, int8_t *k_buf, + const int32_t* Q, const float* bias_Q, + const int32_t* K, const float* bias_K, + const int* sequence_id_offset, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float * q_weight_amax, + const float *q_input_deQFactor_div127_ptr, + const float * k_weight_amax, + const float *k_input_deQFactor_div127_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_QK_bias_transform_rebuild_padding_kernelLauncher(int8_t *q_buf, int8_t *k_buf, + const int32_t* Q, const half* bias_Q, + const int32_t* K, const half* bias_K, + const int* sequence_id_offset, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float * q_weight_amax, + const float *q_input_deQFactor_div127_ptr, + const float * k_weight_amax, + const float *k_input_deQFactor_div127_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + //add_QK_bias_transform & rebuild padding for batch int8 cublasLtMatmul & per tensor quantization for weight //1.add QK bias //2.transform each Q K CUBLASLT_ORDER_COL32 matrixes into a series of sub-matrix (with CUBLASLT_ORDER_COL32/CUBLASLT_ORDER_COL4_4R2_8C layout) @@ -563,6 +822,169 @@ void add_QK_bias_transform_rebuild_padding(int8_t *q_buf_, int8_t *k_buf_, const buf_ptr4[(((batch_id*head_num + head_id) * stride + (new_col << 5)*seq_len + new_row) >> 2)] = tmp4; } + +//add_QK_bias_transform & rebuild padding for batch int8 cublasLtMatmul & per tensor quantization for weight +//1.add QK bias +//2.transform each Q K CUBLASLT_ORDER_COL32 matrixes into a series of sub-matrix (with CUBLASLT_ORDER_COL32/CUBLASLT_ORDER_COL4_4R2_8C layout) +// Q, K are CUBLASLT_ORDER_COL32 matrixes of m = valid_word_num, n = head_num * size_per_head +// q_buf_ is of batchCount = batch_size * head_num, m = seq_len, n = size_per_head, CUBLASLT_ORDER_COL32 +// seq_len_padded = (seq_len + 31)/32*32; +// k_buf_ is of batchCount = batch_size * head_num, m = seq_len_padded, n = size_per_head, CUBLASLT_ORDER_COL4_4R2_8C or CUBLASLT_ORDER_COL32_2R_4R4 +//only for int8 IO +//seq_len, size_per_head must be a multiple of 32 +//grid.x = valid_word_num * 2; +//block.x = head_num * size_per_head / 4; +//using char4 +template +__global__ +void add_QK_bias_transform_rebuild_padding_varlen(int8_t *q_buf_, int8_t *k_buf_, const int8_t* Q, const T* bias_Q, + const int8_t* K, const T* bias_K, const int* sequence_id_offset, + const int valid_word_num, const int m, 
const int batch_size, + const int seq_len, const int seq_len_padded, const int head_num, + const int size_per_head, int stride_q, int stride_k, + const float *q_deQFactor_ptr, const float *k_deQFactor_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4) +{ + const char4* data_ptr; + char4* buf_ptr4; + const T* bias_ptr; + int qk_id = blockIdx.x / valid_word_num; + + data_ptr = qk_id == 0 ? (const char4*)Q : (const char4*)K; + buf_ptr4 = qk_id == 0 ? (char4*)q_buf_ : (char4*)k_buf_; + bias_ptr = qk_id == 0 ? bias_Q : bias_K; + + int threadIdx4 = threadIdx.x << 2; + int m_full_idx = blockIdx.x % valid_word_num; + m_full_idx = (valid_word_num != m) ? (m_full_idx + __ldg(sequence_id_offset+m_full_idx)) : m_full_idx; + int batch_id = m_full_idx / seq_len; + int head_id = threadIdx4 / size_per_head; + int id_in_head = threadIdx4 % size_per_head; + int word_id = m_full_idx % seq_len; + + const float deQFactor = qk_id == 0 ? __ldg(q_deQFactor_ptr) : __ldg(k_deQFactor_ptr); + const float output_scale = qk_id == 0 ? __ldg(q_output_scale_ptr) : __ldg(k_output_scale_ptr); + + int data_id = (((threadIdx4 >> 5) << 5)*valid_word_num + ((blockIdx.x%valid_word_num) << 5) + (threadIdx4&31)) >> 2; + + float scale; + float tmp; + char4 tmp4; + + tmp4 = __ldg(data_ptr+data_id); + + scale = static_cast(tmp4.x) * deQFactor; + tmp = static_cast(__ldg(bias_ptr+threadIdx4)) + scale; + tmp4.x = float_to_int8_rn(tmp*output_scale); + + threadIdx4 = threadIdx4+1; + scale = static_cast(tmp4.y) * deQFactor; + tmp = static_cast(__ldg(bias_ptr+threadIdx4)) + scale; + tmp4.y = float_to_int8_rn(tmp*output_scale); + + threadIdx4 = threadIdx4+1; + scale = static_cast(tmp4.z) * deQFactor; + tmp = static_cast(__ldg(bias_ptr+threadIdx4)) + scale; + tmp4.z = float_to_int8_rn(tmp*output_scale); + + threadIdx4 = threadIdx4+1; + scale = static_cast(tmp4.w) * deQFactor; + tmp = static_cast(__ldg(bias_ptr+threadIdx4)) + scale; + tmp4.w = float_to_int8_rn(tmp*output_scale); + + //row_id, col_id of sub-matrix (m = seq_len or seq_len_padded, n = size_per_head), column-major + int row_id = word_id; + int col_id = id_in_head; + //new (row, rol) of LtTrans COL32/COL4 sub-matrix, leading dim = (COL32_ * seq_len) or (COL32_ * seq_len_padded) + int new_col = col_id >> 5; + int new_row; + if (use_ORDER_COL32_2R_4R4) + { + int row_in_tile = row_id & 31; + int col_in_tile = col_id & 31; + new_row = (qk_id != 1) ? + //COL32 + ((row_id << 5) + (col_id&31)) + : + //COL32_2R_4R4 + ( + ((row_id >> 5) << 10) + + //(((row%8)/2*4+row/8)*2+row%2)*32+col + (((((((row_in_tile&7)>>1)<<2)+(row_in_tile>>3))<<1)+(row_in_tile&1))<<5)+col_in_tile + ) + ; + } + else + { + new_row = (qk_id != 1) ? + //COL32 + ((row_id << 5) + (col_id&31)) + : + //COL4 + ////row_id/8 is the number of tile of (8 rows 32 columns) -- column-major + ////row_id%2 is even row, otherwise odd row + ////col_id%COL32_/8 is the number tile of (8 rows 8 columns) + ( + ((((row_id >> 3) << 3) + ((row_id&1) << 2) + ((col_id&31) >> 3)) << 5) + + ////col_id%8 >= 4 is the right half of (8 rows 8 columns) tile + ////(row_id%8/2) is (the row id of alternating 4 rows) - 1 + (((((col_id&7) >= 4)?4:0) + ((row_id&7) >> 1)) << 2) + + ////col_id%4 is the id of 4 cols + (col_id&3) + ) + ; + } + + const int stride = (qk_id != 1) ? stride_q : stride_k; + const int len = (qk_id != 1) ? 
seq_len : seq_len_padded; + buf_ptr4[(((batch_id*head_num + head_id) * stride + (new_col << 5)*len + new_row) >> 2)] = tmp4; +} + +template +void add_QK_bias_transform_rebuild_padding_kernelLauncher(int8_t *q_buf, int8_t *k_buf, const int8_t* Q, const T* bias_Q, + const int8_t* K, const T* bias_K, const int* sequence_id_offset, + const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float *q_deQFactor_ptr, const float *k_deQFactor_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream) +{ + int seq_len_padded = (seq_len + 31)/32*32; + add_QK_bias_transform_rebuild_padding_varlen<<>>( + q_buf, k_buf, Q, bias_Q, K, bias_K, + sequence_id_offset, valid_word_num, + batch_size * seq_len, batch_size, + seq_len, seq_len_padded, head_num, size_per_head, + seq_len*size_per_head, seq_len_padded*size_per_head, + q_deQFactor_ptr, k_deQFactor_ptr, + q_output_scale_ptr, k_output_scale_ptr, + use_ORDER_COL32_2R_4R4); +} + +template +void add_QK_bias_transform_rebuild_padding_kernelLauncher(int8_t *q_buf, int8_t *k_buf, + const int8_t* Q, const float* bias_Q, + const int8_t* K, const float* bias_K, + const int* sequence_id_offset, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float *q_deQFactor_ptr, const float *k_deQFactor_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_QK_bias_transform_rebuild_padding_kernelLauncher(int8_t *q_buf, int8_t *k_buf, + const int8_t* Q, const half* bias_Q, + const int8_t* K, const half* bias_K, + const int* sequence_id_offset, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float *q_deQFactor_ptr, const float *k_deQFactor_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + //input matrix a matrix of m = batch_size*seq_len , n = head_num*size_per_head, CUBLASLT_ORDER_COL32 //output matrixes are a series of sub-matrixes with size of m = size_per_head, n = seq_len , CUBLASLT_ORDER_COL4_4R2_8C or CUBLASLT_ORDER_COL32_2R_4R4 //only for int32_t Input int8_t Output @@ -669,78 +1091,68 @@ void add_V_bias_transform(int8_t *v_buf_, const int32_t *V, const T *V_bias, con buf_ptr4[(blockIdx.z*stride + (col << 5)*size_per_head + row) >> 2] = dataTmp; } - - -//input matrix a matrix of m = batch_size*seq_len , n = head_num*size_per_head, CUBLASLT_ORDER_COL32 -//output matrixes are a series of sub-matrixes with size of m = size_per_head, n = seq_len , CUBLASLT_ORDER_COL4_4R2_8C or CUBLASLT_ORDER_COL32_2R_4R4 -//only for int8_t IO -//seq_len, size_per_head must be a multiple of 32 -//grid = (size_per_head/32, seq_len/32, batch_size*head_num) -//block = (8, 32); -//using char4 -//per tensor quantization for weight -template +template <> __global__ -void add_V_bias_transform(int8_t *v_buf_, const int8_t *V, const T *V_bias, const int batch_size, const int seq_len, - const int head_num, const int size_per_head, int stride, - const float *input_deQFactor_ptr, const float *out_scale_ptr, bool use_ORDER_COL32_2R_4R4) +void add_V_bias_transform(int8_t *v_buf_, const int32_t *V, const half *V_bias, const int batch_size, const int seq_len, + const int head_num, const int size_per_head, int stride, const float* weight_amax, + const float 
*input_deQFactor_div127_ptr, const float *out_scale_ptr, bool use_ORDER_COL32_2R_4R4) { - const float input_deQFactor = __ldg(input_deQFactor_ptr); + const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); const float out_scale = __ldg(out_scale_ptr); __shared__ int8_t shm[32][33]; - const char4* data_ptr = (const char4*)V; + const int32_t* data_ptr = V; char4* buf_ptr4 = (char4*) v_buf_; - const T* bias_ptr = V_bias; int threadIdx4 = threadIdx.x << 2; //for src of (seq_len, size_per_head) int batch_id = blockIdx.z/head_num; int head_id = blockIdx.z%head_num; - int word_id = (blockIdx.y << 5) + threadIdx.y; - int id_in_size = (blockIdx.x << 5) + threadIdx4; + + int blockIdy32 = (blockIdx.y << 5); + int blockIdx32 = (blockIdx.x << 5); + int word_id = blockIdy32 + threadIdx.y; + int id_in_size = blockIdx32 + threadIdx4; //for V layout (batch_size*seq_len, head_num*size_per_head) int col = head_id*size_per_head + id_in_size; int row = batch_id*seq_len + word_id; - int inIdx = (((col >> 5) << 5)*batch_size*seq_len + ((row << 5) + (col&31))) >> 2; + int inIdx = ((col & 0xffffffe0)*batch_size*seq_len + ((row << 5) + (col&31))); //for shm row-major int sh_col = threadIdx4; int sh_row = threadIdx.y; - float tmp; + int col_2 = col >> 1; float scale; - //const half2* bias_ptr2 = (const half2*)bias_ptr; - //half2 tmp2; + const half2* bias_ptr2 = (const half2*)V_bias; + half2 tmp2; - //tmp2 = __ldg(&bias_ptr2[col >> 1]); - - char4 dataTmp = __ldg(data_ptr + inIdx); + tmp2 = __ldg(bias_ptr2+col_2); - scale = dataTmp.x * input_deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr + col));//(tmp2.x); - shm[sh_row][sh_col] = float_to_int8_rn(tmp*out_scale); + scale = __ldg(data_ptr+inIdx) * __ldg(weight_amax+col) * input_deQFactor_div127; + scale = scale + static_cast(tmp2.x); + shm[sh_row][sh_col] = float_to_int8_rn(scale*out_scale); - scale = dataTmp.y * input_deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr+col+1));//(tmp2.y); - shm[sh_row][sh_col+1] = float_to_int8_rn(tmp*out_scale); + scale = __ldg(data_ptr+inIdx+1) * __ldg(weight_amax+col+1) * input_deQFactor_div127; + scale = scale + static_cast(tmp2.y); + shm[sh_row][sh_col+1] = float_to_int8_rn(scale*out_scale); - //tmp2 = __ldg(&bias_ptr2[(col >> 1) + 1]); + tmp2 = __ldg(bias_ptr2 + col_2 + 1); - scale = dataTmp.z * input_deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr+col+2));//(tmp2.x); - shm[sh_row][sh_col+2] = float_to_int8_rn(tmp*out_scale); + scale = __ldg(data_ptr + inIdx + 2) * __ldg(weight_amax + col + 2) * input_deQFactor_div127; + scale = scale + static_cast(tmp2.x); + shm[sh_row][sh_col+2] = float_to_int8_rn(scale*out_scale); - scale = dataTmp.w * input_deQFactor; - tmp = scale + static_cast(__ldg(bias_ptr+col+3));//(tmp2.y); - shm[sh_row][sh_col+3] = float_to_int8_rn(tmp*out_scale); + scale = __ldg(data_ptr + inIdx + 3) * __ldg(weight_amax + col + 3) * input_deQFactor_div127; + scale = scale + static_cast(tmp2.y); + shm[sh_row][sh_col+3] = float_to_int8_rn(scale*out_scale); __syncthreads(); //for dst of (size_per_head, seq_len) - word_id = (blockIdx.y << 5) + threadIdx4; - id_in_size = (blockIdx.x << 5) + threadIdx.y; + word_id = blockIdy32 + threadIdx4; + id_in_size = blockIdx32 + threadIdx.y; col = (word_id >> 5); if (use_ORDER_COL32_2R_4R4) @@ -761,7 +1173,7 @@ void add_V_bias_transform(int8_t *v_buf_, const int8_t *V, const T *V_bias, cons ////id_in_size/8 is the number of tile of (8 rows 32 columns) -- column-major ////id_in_size%2 is even row, otherwise odd row ////word_id%COL32_/8 is the number tile 
of (8 rows 8 columns) - ((((id_in_size >> 3) << 3) + ((id_in_size&1) << 2) + ((word_id&31) >> 3)) << 5) + + (((id_in_size & 0xfffffff8) + ((id_in_size&1) << 2) + ((word_id&31) >> 3)) << 5) + ////word_id%8 >= 4 is the right half of (8 rows 8 columns) tile ////(id_in_size%8/2) is (the row id of alternating 4 rows) - 1 (((((word_id&7) >= 4)?4:0) + ((id_in_size&7) >> 1)) << 2) + @@ -769,7 +1181,8 @@ void add_V_bias_transform(int8_t *v_buf_, const int8_t *V, const T *V_bias, cons (word_id&3) ); } - + + char4 dataTmp; dataTmp.x = shm[sh_col][sh_row]; dataTmp.y = shm[sh_col+1][sh_row]; dataTmp.z = shm[sh_col+2][sh_row]; @@ -777,68 +1190,106 @@ void add_V_bias_transform(int8_t *v_buf_, const int8_t *V, const T *V_bias, cons buf_ptr4[(blockIdx.z*stride + (col << 5)*size_per_head + row) >> 2] = dataTmp; } -template <> +template +void add_V_bias_transform_kernelLauncher(int8_t *v_buf, const int32_t *V, const T *V_bias, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float* weight_amax, + const float *input_deQFactor_div127_ptr, const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream) +{ + add_V_bias_transform<<>>(v_buf, V, V_bias, batch_size, seq_len, head_num, size_per_head, seq_len*size_per_head, weight_amax, input_deQFactor_div127_ptr, out_scale_ptr, use_ORDER_COL32_2R_4R4); +} + +template +void add_V_bias_transform_kernelLauncher(int8_t *v_buf, const int32_t *V, const float *V_bias, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float* weight_amax, + const float *input_deQFactor_div127_ptr, const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_V_bias_transform_kernelLauncher(int8_t *v_buf, const int32_t *V, const half *V_bias, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float* weight_amax, + const float *input_deQFactor_div127_ptr, const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +//input matrix a matrix of m = batch_size*seq_len , n = head_num*size_per_head, CUBLASLT_ORDER_COL32 +//seq_len_padded = (seq_len+31)/32*32 +//output matrixes are a series of sub-matrixes with size of m = size_per_head, n = seq_len_padded , CUBLASLT_ORDER_COL4_4R2_8C or CUBLASLT_ORDER_COL32_2R_4R4 +//only for int8_t IO +//size_per_head must be a multiple of 32 +//grid = (size_per_head/32, seq_len_padded/32, batch_size*head_num) +//block = (8, 32); +//using char4 +//per tensor quantization for weight +template __global__ -void add_V_bias_transform(int8_t *v_buf_, const int32_t *V, const half *V_bias, const int batch_size, const int seq_len, - const int head_num, const int size_per_head, int stride, const float* weight_amax, - const float *input_deQFactor_div127_ptr, const float *out_scale_ptr, bool use_ORDER_COL32_2R_4R4) +void add_V_bias_transform_varlen(int8_t *v_buf_, const int8_t *V, const T *V_bias, const int batch_size, const int seq_len, + const int head_num, const int size_per_head, const int seq_len_padded, int stride, + const float *input_deQFactor_ptr, const float *out_scale_ptr, bool use_ORDER_COL32_2R_4R4) { - const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); + const float input_deQFactor = __ldg(input_deQFactor_ptr); const float out_scale = __ldg(out_scale_ptr); __shared__ int8_t shm[32][33]; - const int32_t* data_ptr = V; + const char4* data_ptr = (const char4*)V; char4* buf_ptr4 = (char4*) v_buf_; + const T* bias_ptr = V_bias; 
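/* Editor's note (illustrative sketch, not part of the original kernel): the loads below
   index V in CUBLASLT_ORDER_COL32. For a matrix with `rows` rows and a column count that
   is a multiple of 32, the COL32 offset of element (r, c) can be written on the host as:

       // hypothetical reference helper, for testing/illustration only
       inline size_t col32_offset(int r, int c, int rows) {
           return (size_t)(c & ~31) * rows   // which 32-column tile the element lives in
                + ((size_t)r << 5)           // row inside that tile, stride 32
                + (c & 31);                  // column inside that tile
       }

   The kernel additionally shifts the offset right by 2 because it reads char4 (4 int8
   values at a time). The per-element math that follows is dequantize -> add bias ->
   requantize: out_int8 = float_to_int8_rn((in_int8 * input_deQFactor + bias[c]) * out_scale),
   where out_scale is assumed to be 127 / amax of the destination buffer, matching how the
   *_output_scale_ptr values are used elsewhere in this file. */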
int threadIdx4 = threadIdx.x << 2; //for src of (seq_len, size_per_head) int batch_id = blockIdx.z/head_num; int head_id = blockIdx.z%head_num; - - int blockIdy32 = (blockIdx.y << 5); - int blockIdx32 = (blockIdx.x << 5); - int word_id = blockIdy32 + threadIdx.y; - int id_in_size = blockIdx32 + threadIdx4; + int word_id = (blockIdx.y << 5) + threadIdx.y; + int id_in_size = (blockIdx.x << 5) + threadIdx4; - //for V layout (batch_size*seq_len, head_num*size_per_head) - int col = head_id*size_per_head + id_in_size; - int row = batch_id*seq_len + word_id; - int inIdx = ((col & 0xffffffe0)*batch_size*seq_len + ((row << 5) + (col&31))); + int col, row; //for shm row-major int sh_col = threadIdx4; int sh_row = threadIdx.y; + char4 dataTmp; + if (word_id < seq_len) + { + //for V layout (batch_size*seq_len, head_num*size_per_head) + col = head_id*size_per_head + id_in_size; + row = batch_id*seq_len + word_id; + int inIdx = (((col >> 5) << 5)*batch_size*seq_len + ((row << 5) + (col&31))) >> 2; - int col_2 = col >> 1; - float scale; - - const half2* bias_ptr2 = (const half2*)V_bias; - half2 tmp2; - - tmp2 = __ldg(bias_ptr2+col_2); + float tmp; + float scale; - scale = __ldg(data_ptr+inIdx) * __ldg(weight_amax+col) * input_deQFactor_div127; - scale = scale + static_cast(tmp2.x); - shm[sh_row][sh_col] = float_to_int8_rn(scale*out_scale); + dataTmp = __ldg(data_ptr + inIdx); - scale = __ldg(data_ptr+inIdx+1) * __ldg(weight_amax+col+1) * input_deQFactor_div127; - scale = scale + static_cast(tmp2.y); - shm[sh_row][sh_col+1] = float_to_int8_rn(scale*out_scale); + scale = dataTmp.x * input_deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col));//(tmp2.x); + shm[sh_row][sh_col] = float_to_int8_rn(tmp*out_scale); - tmp2 = __ldg(bias_ptr2 + col_2 + 1); + scale = dataTmp.y * input_deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr+col+1));//(tmp2.y); + shm[sh_row][sh_col+1] = float_to_int8_rn(tmp*out_scale); - scale = __ldg(data_ptr + inIdx + 2) * __ldg(weight_amax + col + 2) * input_deQFactor_div127; - scale = scale + static_cast(tmp2.x); - shm[sh_row][sh_col+2] = float_to_int8_rn(scale*out_scale); + scale = dataTmp.z * input_deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr+col+2));//(tmp2.x); + shm[sh_row][sh_col+2] = float_to_int8_rn(tmp*out_scale); - scale = __ldg(data_ptr + inIdx + 3) * __ldg(weight_amax + col + 3) * input_deQFactor_div127; - scale = scale + static_cast(tmp2.y); - shm[sh_row][sh_col+3] = float_to_int8_rn(scale*out_scale); + scale = dataTmp.w * input_deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr+col+3));//(tmp2.y); + shm[sh_row][sh_col+3] = float_to_int8_rn(tmp*out_scale); + } + else + { + shm[sh_row][sh_col] = shm[sh_row][sh_col+1] = shm[sh_row][sh_col+2] = shm[sh_row][sh_col+3] = 0; + } __syncthreads(); - //for dst of (size_per_head, seq_len) - word_id = blockIdy32 + threadIdx4; - id_in_size = blockIdx32 + threadIdx.y; + //for dst of (size_per_head, seq_len_padded) + word_id = (blockIdx.y << 5) + threadIdx4; + id_in_size = (blockIdx.x << 5) + threadIdx.y; col = (word_id >> 5); if (use_ORDER_COL32_2R_4R4) @@ -859,7 +1310,7 @@ void add_V_bias_transform(int8_t *v_buf_, const int32_t *V, const half *V_bias, ////id_in_size/8 is the number of tile of (8 rows 32 columns) -- column-major ////id_in_size%2 is even row, otherwise odd row ////word_id%COL32_/8 is the number tile of (8 rows 8 columns) - (((id_in_size & 0xfffffff8) + ((id_in_size&1) << 2) + ((word_id&31) >> 3)) << 5) + + ((((id_in_size >> 3) << 3) + ((id_in_size&1) << 2) + ((word_id&31) >> 3)) << 5) + 
////word_id%8 >= 4 is the right half of (8 rows 8 columns) tile ////(id_in_size%8/2) is (the row id of alternating 4 rows) - 1 (((((word_id&7) >= 4)?4:0) + ((id_in_size&7) >> 1)) << 2) + @@ -867,8 +1318,7 @@ void add_V_bias_transform(int8_t *v_buf_, const int32_t *V, const half *V_bias, (word_id&3) ); } - - char4 dataTmp; + dataTmp.x = shm[sh_col][sh_row]; dataTmp.y = shm[sh_col+1][sh_row]; dataTmp.z = shm[sh_col+2][sh_row]; @@ -876,6 +1326,44 @@ void add_V_bias_transform(int8_t *v_buf_, const int32_t *V, const half *V_bias, buf_ptr4[(blockIdx.z*stride + (col << 5)*size_per_head + row) >> 2] = dataTmp; } +template +void add_V_bias_transform_kernelLauncher(int8_t *v_buf, const int8_t *V, const T *V_bias, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const float *input_deQFactor_ptr, const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream) +{ + assert(size_per_head % 32 == 0); + if (seq_len % 32 == 0) + { + add_V_bias_transform_varlen<<>>( + v_buf, V, V_bias, + batch_size, seq_len, head_num, size_per_head, + seq_len, seq_len*size_per_head, + input_deQFactor_ptr, out_scale_ptr, use_ORDER_COL32_2R_4R4); + } + else + { + const int seq_len_padded = (seq_len + 31)/32*32; + add_V_bias_transform_varlen<<>>( + v_buf, V, V_bias, + batch_size, seq_len, head_num, size_per_head, + seq_len_padded, seq_len_padded*size_per_head, + input_deQFactor_ptr, out_scale_ptr, use_ORDER_COL32_2R_4R4); + } +} + +template +void add_V_bias_transform_kernelLauncher(int8_t *v_buf, const int8_t *V, const float *V_bias, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const float *input_deQFactor_ptr, const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_V_bias_transform_kernelLauncher(int8_t *v_buf, const int8_t *V, const half *V_bias, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const float *input_deQFactor_ptr, const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + //add bias into V & rebuild padding //input matrix a matrix of m = valid_word_num, n = head_num*size_per_head, CUBLASLT_ORDER_COL32 //output matrixes are a series of sub-matrixes with size of m = size_per_head, n = seq_len , CUBLASLT_ORDER_COL4_4R2_8C or CUBLASLT_ORDER_COL32_2R_4R4 @@ -1091,6 +1579,46 @@ void add_V_bias_transform_rebuild_padding(int8_t *v_buf_, const int32_t *V, cons buf_ptr4[(blockIdx.z*stride + (col << 5)*size_per_head + row) >> 2] = dataTmp; } +template +void add_V_bias_transform_rebuild_padding_kernelLauncher(int8_t *v_buf, const int32_t *V, const T *V_bias, + const int* sequence_id_map, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float* weight_amax, + const float *input_deQFactor_div127_ptr, + const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream) +{ + add_V_bias_transform_rebuild_padding<<>>( + v_buf, V, V_bias, + sequence_id_map, valid_word_num, + batch_size, seq_len, + head_num, size_per_head, + seq_len*size_per_head, + weight_amax, input_deQFactor_div127_ptr, + out_scale_ptr, use_ORDER_COL32_2R_4R4); +} + +template +void add_V_bias_transform_rebuild_padding_kernelLauncher(int8_t *v_buf, const int32_t *V, const float *V_bias, + const int* sequence_id_map, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float* weight_amax, + const 
float *input_deQFactor_div127_ptr, + const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_V_bias_transform_rebuild_padding_kernelLauncher(int8_t *v_buf, const int32_t *V, const half *V_bias, + const int* sequence_id_map, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float* weight_amax, + const float *input_deQFactor_div127_ptr, + const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + //add bias into V & rebuild padding //input matrix a matrix of m = valid_word_num, n = head_num*size_per_head, CUBLASLT_ORDER_COL32 //output matrixes are a series of sub-matrixes with size of m = size_per_head, n = seq_len , CUBLASLT_ORDER_COL4_4R2_8C or CUBLASLT_ORDER_COL32_2R_4R4 @@ -1201,6 +1729,149 @@ void add_V_bias_transform_rebuild_padding(int8_t *v_buf_, const int8_t *V, const buf_ptr4[(blockIdx.z*stride + (col << 5)*size_per_head + row) >> 2] = dataTmp; } +//add bias into V & rebuild padding +//input matrix a matrix of m = valid_word_num, n = head_num*size_per_head, CUBLASLT_ORDER_COL32 +//output matrixes are a series of sub-matrixes with size of m = size_per_head, n = seq_len_padded , CUBLASLT_ORDER_COL4_4R2_8C or CUBLASLT_ORDER_COL32_2R_4R4 +//only for int8_t IO +//seq_len, size_per_head must be a multiple of 32 +//grid = (size_per_head/32, seq_len_padded/32, batch_size*head_num) +//block = (8, 32); +//using char4 +//per tensor quantization for weight +template +__global__ +void add_V_bias_transform_rebuild_padding_varlen(int8_t *v_buf_, const int8_t *V, const T *V_bias, const int* sequence_id_map, const int valid_word_num, + const int batch_size, const int seq_len, const int seq_len_padded, + const int head_num, const int size_per_head, int stride, + const float *deQFactor_ptr, const float *out_scale_ptr, bool use_ORDER_COL32_2R_4R4) +{ + __shared__ int8_t shm[32][33]; + const char4* data_ptr = (const char4*)V; + char4* buf_ptr4 = (char4*) v_buf_; + const T* bias_ptr = V_bias; + + int threadIdx4 = threadIdx.x << 2; + + //for src of (seq_len, size_per_head) + int batch_id = blockIdx.z/head_num; + int head_id = blockIdx.z%head_num; + int word_id = (blockIdx.y << 5) + threadIdx.y; + int id_in_size = (blockIdx.x << 5) + threadIdx4; + + //for shm row-major + int sh_col = threadIdx4; + int sh_row = threadIdx.y; + + //for V layout (batch_size*seq_len, head_num*size_per_head) + int col; + int row = word_id < seq_len ? 
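/* Editor's note (sketch, not part of this kernel): sequence_id_map maps a (batch_id, word_id)
   position of the padded [batch_size, seq_len] view to its compacted row inside the
   [valid_word_num, head_num*size_per_head] COL32 input; entries for removed padding words are
   assumed to be -1, and word_id >= seq_len yields -1 directly, so those positions are written
   out as zeros below. A hypothetical host-side construction of such a map, for illustration only
   (the real map is produced elsewhere in this file):

       // mask[b*seq_len + s] != 0 marks a real token
       void build_sequence_id_map_ref(const int* mask, int batch, int seq_len, int* map) {
           int compact = 0;
           for (int i = 0; i < batch * seq_len; ++i)
               map[i] = mask[i] ? compact++ : -1;
       }
*/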
__ldg(sequence_id_map + batch_id*seq_len + word_id) : -1; + + if (row != -1){ + col = head_id*size_per_head + id_in_size; + int inIdx = ((col & 0xffffffe0)*valid_word_num + ((row << 5) + (col&31))) >> 2; + + float tmp; + float scale; + + const float deQFactor = __ldg(deQFactor_ptr); + const float out_scale = __ldg(out_scale_ptr); + + char4 dataTmp = __ldg(data_ptr + inIdx); + + scale = dataTmp.x * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr + col)); + shm[sh_row][sh_col] = float_to_int8_rn(tmp*out_scale); + + scale = dataTmp.y * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr+col+1)); + shm[sh_row][sh_col+1] = float_to_int8_rn(tmp*out_scale); + + scale = dataTmp.z * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr+col+2)); + shm[sh_row][sh_col+2] = float_to_int8_rn(tmp*out_scale); + + scale = dataTmp.w * deQFactor; + tmp = scale + static_cast(__ldg(bias_ptr+col+3)); + shm[sh_row][sh_col+3] = float_to_int8_rn(tmp*out_scale); + } + else{ + shm[sh_row][sh_col] = shm[sh_row][sh_col + 1] = shm[sh_row][sh_col + 2] = shm[sh_row][sh_col + 3] = 0; + } + __syncthreads(); + + char4 dataTmp; + dataTmp.x = shm[sh_col][sh_row]; + dataTmp.y = shm[sh_col+1][sh_row]; + dataTmp.z = shm[sh_col+2][sh_row]; + dataTmp.w = shm[sh_col+3][sh_row]; + + //for dst of (size_per_head, seq_len_padded) + word_id = (blockIdx.y << 5) + threadIdx4; + id_in_size = (blockIdx.x << 5) + threadIdx.y; + col = (word_id >> 5); + + if (use_ORDER_COL32_2R_4R4) + { + int row_in_tile = id_in_size & 31; + int col_in_tile = word_id & 31; + row = ( + //COL32_2R_4R4 + ((id_in_size >> 5) << 10) + + //(((row%8)/2*4+row/8)*2+row%2)*32+col + (((((((row_in_tile&7)>>1)<<2)+(row_in_tile>>3))<<1)+(row_in_tile&1))<<5)+col_in_tile + ); + } + else + { + row = ( + //COL4 + ////id_in_size/8 is the number of tile of (8 rows 32 columns) -- column-major + ////id_in_size%2 is even row, otherwise odd row + ////word_id%COL32_/8 is the number tile of (8 rows 8 columns) + (((id_in_size & 0xfffffff8) + ((id_in_size&1) << 2) + ((word_id&31) >> 3)) << 5) + + ////word_id%8 >= 4 is the right half of (8 rows 8 columns) tile + ////(id_in_size%8/2) is (the row id of alternating 4 rows) - 1 + (((((word_id&7) >= 4)?4:0) + ((id_in_size&7) >> 1)) << 2) + + ////word_id%4 is the id of 4 cols + (word_id&3) + ); + } + + buf_ptr4[(blockIdx.z*stride + (col << 5)*size_per_head + row) >> 2] = dataTmp; +} + + +template +void add_V_bias_transform_rebuild_padding_kernelLauncher(int8_t *v_buf, const int8_t *V, const T *V_bias, + const int* sequence_id_map, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float *deQFactor_ptr, const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream) +{ + int seq_len_padded = (seq_len + 31)/32*32; + add_V_bias_transform_rebuild_padding_varlen<<>>( + v_buf, V, V_bias, sequence_id_map, valid_word_num, + batch_size, seq_len, seq_len_padded, head_num, size_per_head, seq_len_padded*size_per_head, + deQFactor_ptr, out_scale_ptr, use_ORDER_COL32_2R_4R4); +} + +template +void add_V_bias_transform_rebuild_padding_kernelLauncher(int8_t *v_buf, const int8_t *V, const float *V_bias, + const int* sequence_id_map, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float *deQFactor_ptr, const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_V_bias_transform_rebuild_padding_kernelLauncher(int8_t *v_buf, const int8_t *V, const half 
*V_bias, + const int* sequence_id_map, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float *deQFactor_ptr, const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + __global__ void trt_add_QKV_bias(half2* Q, const half2* bias_Q, half2* K, const half2* bias_K, half2* V, const half2* bias_V, half2* q_buf_, half2* k_buf_, half2* v_buf_, @@ -1252,33 +1923,213 @@ void OpenMultiHeadAttention::trt_add_QKV_bias_kernelLauncher( head_num_, size_per_head_ / 2); } -template -void OpenMultiHeadAttention::fused_multiHeadAttr_kernelLauncher() +// add bias and then transform from +// 3 * ([valid_word_num, head*size] + CUBLASLT_ORDER_COL32) -> [valid_word_num, head, 3, size] + row-major +// input is INT32 && per axis quantization for weight +// output is INT8 && per tensor quantization +// grid((head*size + 31)/32, (valid_word_num + 31)/32, 3) +// block(8, 32) +// size should be a multiple of 4 +//using char4 as output, int4 as input +template +__global__ +void trt_add_QKV_bias_COL32_int32IInt8O(char4* output, const int4* QKV, + const T* bias_Q, const T* bias_K, const T* bias_V, + const float *input_deQFactor_div127_ptr, + const float *q_weight_amax, + const float *k_weight_amax, + const float *v_weight_amax, + const float qkv_output_scale, const int valid_word_num, const int head_num, + const int size_per_head, const int head_num_x_size_per_head) { - trt_add_QKV_bias_kernelLauncher(param_.self_attention.query_weight.bias, - param_.self_attention.key_weight.bias, - param_.self_attention.value_weight.bias); - - - const int B = param_.trt_seqlen_size - 1; - const int maxS = from_seq_len_; - int S = 384; - if (maxS <= 64) + const int qkv_id = blockIdx.z; + const int seq_id = (blockIdx.y << 5) + threadIdx.y; + const int threadIdx4 = threadIdx.x << 2; + int hidden_id = (blockIdx.x << 5) + threadIdx4; + const int size_id = hidden_id % size_per_head; + const int head_id = hidden_id / size_per_head; + + const bool qual = (seq_id < valid_word_num) && (hidden_id < head_num_x_size_per_head); + if (qual) { - S = 64; + const float* weight_amax = qkv_id == 0 ? q_weight_amax : (qkv_id == 1 ? k_weight_amax : v_weight_amax); + const float input_deQFactor_div127 = __ldg(input_deQFactor_div127_ptr); + + const T* bias_ptr = (qkv_id == 0) ? bias_Q : ((qkv_id == 1) ? 
bias_K : bias_V); + + const int input_id = (qkv_id * valid_word_num * head_num_x_size_per_head + ((hidden_id & 0xffffffe0)*valid_word_num + (seq_id << 5) + (hidden_id&31))) >> 2; + + char4 tmp; + const int4 tmp_int4 = __ldg(QKV+input_id); + + tmp.x = float_to_int8_rn((static_cast(tmp_int4.x) * __ldg(weight_amax+hidden_id) * input_deQFactor_div127 + static_cast(__ldg(bias_ptr + hidden_id))) * qkv_output_scale); + + hidden_id += 1; + tmp.y = float_to_int8_rn((static_cast(tmp_int4.y) * __ldg(weight_amax+hidden_id) * input_deQFactor_div127 + static_cast(__ldg(bias_ptr + hidden_id))) * qkv_output_scale); + + hidden_id += 1; + tmp.z = float_to_int8_rn((static_cast(tmp_int4.z) * __ldg(weight_amax+hidden_id) * input_deQFactor_div127 + static_cast(__ldg(bias_ptr + hidden_id))) * qkv_output_scale); + + hidden_id += 1; + tmp.w = float_to_int8_rn((static_cast(tmp_int4.w) * __ldg(weight_amax+hidden_id) * input_deQFactor_div127 + static_cast(__ldg(bias_ptr + hidden_id))) * qkv_output_scale); + + //const int output_id = (seq_id * 3 * head_num_x_size_per_head + head_id * 3 * size_per_head + qkv_id * size_per_head + size_id) >> 2; + const int output_id = ((seq_id * head_num_x_size_per_head + head_id * size_per_head) * 3 + qkv_id * size_per_head + size_id) >> 2; + + output[output_id] = tmp; } - else if (maxS <= 96) +} + +template +void OpenMultiHeadAttention::trt_add_QKV_bias_COL32_int32Iint8O_kernelLauncher( + int8_t* output, + const int32_t* Q, + const DataType_* bias_Q, + const DataType_* bias_K, + const DataType_* bias_V, + const float *input_deQFactor_div127_ptr, + const float * q_weight_amax, + const float * k_weight_amax, + const float * v_weight_amax, + const float qkv_output_scale) +{ + int head_num_x_size_per_head = head_num_*size_per_head_; + dim3 grid((head_num_x_size_per_head + 31)/32, (param_.valid_word_num + 31)/32, 3); + dim3 block(8, 32); + + assert(size_per_head_ % 4 == 0); + + trt_add_QKV_bias_COL32_int32IInt8O<<>>((char4*)output, (const int4*)Q, + bias_Q, bias_K, bias_V, + input_deQFactor_div127_ptr, + q_weight_amax, + k_weight_amax, + v_weight_amax, + qkv_output_scale, param_.valid_word_num, + head_num_, size_per_head_, head_num_x_size_per_head); +} + +// Add bias, and then transform from +// 3 * ([valid_word_num, head*size] + CUBLASLT_ORDER_COL32) -> [valid_word_num, head, 3, size] + row-major +// grid((head*size + 31)/32, (valid_word_num + 31)/32, 3) +// block(8, 32) +// size should be a multiple of 4 +template +__global__ +void trt_add_QKV_bias_COL32_int8IO(char4* output, const char4* QKV, + const T* bias_Q, const T* bias_K, const T* bias_V, + const float *q_input_deQFactor_ptr, const float *k_input_deQFactor_ptr, + const float *v_input_deQFactor_ptr, const float qkv_output_scale, + const int valid_word_num, const int head_num, const int size_per_head, + const int head_num_x_size_per_head) +{ + const int qkv_id = blockIdx.z; + const int seq_id = (blockIdx.y << 5) + threadIdx.y; + const int threadIdx4 = threadIdx.x << 2; + const int hidden_id = (blockIdx.x << 5) + threadIdx4; + const int size_id = hidden_id % size_per_head; + const int head_id = hidden_id / size_per_head; + + const bool qual = (seq_id < valid_word_num) && (hidden_id < head_num_x_size_per_head); + if (qual) { - S = 96; + const float *input_deQFactor_ptr = (qkv_id == 0) ? q_input_deQFactor_ptr : ((qkv_id == 1) ? k_input_deQFactor_ptr : v_input_deQFactor_ptr); + const float input_deQFactor = __ldg(input_deQFactor_ptr); + + const T* bias_ptr = (qkv_id == 0) ? bias_Q : ((qkv_id == 1) ? 
bias_K : bias_V); + + const int input_id = (qkv_id * valid_word_num * head_num_x_size_per_head + ((hidden_id & 0xffffffe0)*valid_word_num + (seq_id << 5) + (hidden_id&31))) >> 2; + + char4 tmp = __ldg(QKV+input_id); + + tmp.x = float_to_int8_rn((static_cast(tmp.x) * input_deQFactor + static_cast(__ldg(bias_ptr + hidden_id))) * qkv_output_scale); + + tmp.y = float_to_int8_rn((static_cast(tmp.y) * input_deQFactor + static_cast(__ldg(bias_ptr + hidden_id + 1))) * qkv_output_scale); + + tmp.z = float_to_int8_rn((static_cast(tmp.z) * input_deQFactor + static_cast(__ldg(bias_ptr + hidden_id + 2))) * qkv_output_scale); + + tmp.w = float_to_int8_rn((static_cast(tmp.w) * input_deQFactor + static_cast(__ldg(bias_ptr + hidden_id + 3))) * qkv_output_scale); + + //const int output_id = (seq_id * 3 * head_num_x_size_per_head + head_id * 3 * size_per_head + qkv_id * size_per_head + size_id) >> 2; + const int output_id = ((seq_id * head_num_x_size_per_head + head_id * size_per_head) * 3 + qkv_id * size_per_head + size_id) >> 2; + + output[output_id] = tmp; } - else if (maxS <= 128) +} + +template +void OpenMultiHeadAttention::trt_add_QKV_bias_COL32_int8IO_kernelLauncher( + int8_t* output, + const int8_t* Q, + const DataType_* bias_Q, + const DataType_* bias_K, + const DataType_* bias_V, + const float *q_input_deQFactor_ptr, + const float *k_input_deQFactor_ptr, + const float *v_input_deQFactor_ptr, + const float qkv_output_scale) +{ + int head_num_x_size_per_head = head_num_*size_per_head_; + dim3 grid((head_num_x_size_per_head + 31)/32, (param_.valid_word_num + 31)/32, 3); + dim3 block(8, 32); + + assert(size_per_head_ % 4 == 0); + + trt_add_QKV_bias_COL32_int8IO<<>>((char4*)output, (const char4*)Q, + bias_Q, bias_K, bias_V, + q_input_deQFactor_ptr, k_input_deQFactor_ptr, v_input_deQFactor_ptr, + qkv_output_scale, param_.valid_word_num, + head_num_, size_per_head_, head_num_x_size_per_head); +} + +template +void OpenMultiHeadAttention::int8_fused_multiHeadAttr_kernelLauncher(const void* Q, + const float *q_deQFactor_ptr, const float *k_deQFactor_ptr, const float *v_deQFactor_ptr, + const float mScaleQkv, const int S) +{ + + if (int8_mode_ == 1) { - S = 128; + trt_add_QKV_bias_COL32_int32Iint8O_kernelLauncher((int8_t*)q_buf_, + (const int32_t*)Q, + param_.self_attention.query_weight.bias, + param_.self_attention.key_weight.bias, + param_.self_attention.value_weight.bias, + param_.amaxList+2, query_weight_amax_list, + key_weight_amax_list, value_weight_amax_list, + 1.0f/mScaleQkv); } - else if (maxS <= 256) + else if (int8_mode_ == 2) { - S = 256; + trt_add_QKV_bias_COL32_int8IO_kernelLauncher((int8_t*)q_buf_, + (const int8_t*)Q, + param_.self_attention.query_weight.bias, + param_.self_attention.key_weight.bias, + param_.self_attention.value_weight.bias, + q_deQFactor_ptr, k_deQFactor_ptr, v_deQFactor_ptr, + 1.0f/mScaleQkv + ); } + + const int B = param_.trt_seqlen_size - 1; + dispatcher_int8->setup(S, B); + dispatcher_int8->run((int8_t*)q_buf_, nullptr, param_.trt_seqlen_offset, trt_attn_workspace_, (int8_t*)transpose_dst_int_buf_, param_.stream); + + //transpose_dst_int_buf_ is [batch*seqlen, hidden_dim] row-major + rowMajorToCOL32_kernelLauncher((int8_t*)(param_.attr_out), (const int8_t*)transpose_dst_int_buf_, param_.valid_word_num, head_num_*size_per_head_, param_.stream); +} + + +template +void OpenMultiHeadAttention::fused_multiHeadAttr_kernelLauncher(const int S) +{ + + trt_add_QKV_bias_kernelLauncher(param_.self_attention.query_weight.bias, + param_.self_attention.key_weight.bias, + 
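/* Editor's note on the fused-MHA paths (summary sketch, not part of the call above):
   int8_fused_multiHeadAttr_kernelLauncher drives the TensorRT fused attention plugin in three
   steps: (1) add the QKV bias and repack the COL32 GEMM output into the interleaved
   [valid_word_num, head, 3, size_per_head] row-major int8 layout the plugin expects, scaled by
   1/mScaleQkv (int32 input + per-axis weight amax in int8_mode_==1, int8 input + per-tensor
   deQ factors in int8_mode_==2); (2) dispatcher_int8->setup(S, B) / run(...) executes the fused
   attention on that buffer and writes a row-major int8 result; (3) rowMajorToCOL32_kernelLauncher
   converts the result back to COL32 for the following int8 projection GEMM. Rough call order,
   names as in this file, buffers and error handling omitted:

       // trt_add_QKV_bias_COL32_int8IO_kernelLauncher(q_buf_, Q, bias_Q, bias_K, bias_V, ...);
       // dispatcher_int8->setup(S, B);
       // dispatcher_int8->run(q_buf_, nullptr, param_.trt_seqlen_offset, trt_attn_workspace_, dst, stream);
       // rowMajorToCOL32_kernelLauncher(param_.attr_out, dst, param_.valid_word_num, head_num_*size_per_head_, stream);
*/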
param_.self_attention.value_weight.bias); + + + const int B = param_.trt_seqlen_size - 1; dispatcher_fp16->setup(S, B); dispatcher_fp16->run(q_buf_, nullptr, param_.trt_seqlen_offset, trt_attn_workspace_, param_.attr_out, param_.stream); } @@ -1394,6 +2245,34 @@ void add_QKV_bias_rebuild_padding(T* Q, const T* bias_Q, T* K, const T* bias_K, v_buf_[tgt_id] = V[src_id] + bias_V[tid]; } +template +void add_QKV_bias_rebuild_padding_kernelLauncher(T* Q, const T* bias_Q, T* K, const T* bias_K, T* V, const T* bias_V, T* q_buf, T* k_buf, T* v_buf, + const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num, + const int* mask_offset, cudaStream_t stream) +{ + const int k = head_num*size_per_head; + + if(std::is_same::value) + { + add_QKV_bias_rebuild_padding<<>>(Q, bias_Q, K, bias_K, + V, bias_V, q_buf, k_buf, v_buf, + batch_size, seq_len, head_num, size_per_head, mask_offset); + } + else + { + add_QKV_bias_rebuild_padding<<>>((half2*)Q, (const half2*)bias_Q, + (half2*)K, (const half2*)bias_K, (half2*)V, (const half2*)bias_V, + (half2*)q_buf, (half2*)k_buf, (half2*)v_buf, + batch_size, seq_len, head_num, size_per_head / 2, mask_offset); + } +} + +template +void add_QKV_bias_rebuild_padding_kernelLauncher(float* Q, const float* bias_Q, float* K, const float* bias_K, float* V, const float* bias_V, float* q_buf, float* k_buf, float* v_buf, const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num, const int* mask_offset, cudaStream_t stream); + +template +void add_QKV_bias_rebuild_padding_kernelLauncher(half* Q, const half* bias_Q, half* K, const half* bias_K, half* V, const half* bias_V, half* q_buf, half* k_buf, half* v_buf, const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int valid_word_num, const int* mask_offset, cudaStream_t stream); + template __global__ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, @@ -1619,13 +2498,13 @@ void softmax_kernel_v3_LE32(T* qk_buf_, const T* attr_mask, const int batch_size } } -//int_buf are a series of sub-matrixes of m = seq_len, n = seq_len, CUBLASLT_ORDER_COL32 +//input are a series of sub-matrixes of m = seq_len, n = seq_len, CUBLASLT_ORDER_COL32 //grid = (seq_len, batch_size, head_num) //block.x = max(32, (seq_len/4 + 31)/32*32) //for int32_t I; int8 O; template __global__ -void softmax_COL32(int8_t* qk_buf_, const int32_t* int_buf, const T* attr_mask, const int batch_size, +void softmax_COL32(int8_t* output, const int32_t* input, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const float scalar1a, const float *scalar1b, const float *scalar1c, const float *amax_ptr, const int head_num_x_seq_len, const int seq_len_x_seq_len) { @@ -1634,7 +2513,7 @@ void softmax_COL32(int8_t* qk_buf_, const int32_t* int_buf, const T* attr_mask, int mask_id; int threadIdx4 = threadIdx.x << 2; - char4* buf4Ptr = (char4 *)qk_buf_; + char4* buf4Ptr = (char4 *)output; bool qual = threadIdx4 < seq_len; for (int seq_id = blockIdx.x ; seq_id < seq_len ; seq_id += gridDim.x){ @@ -1654,10 +2533,10 @@ void softmax_COL32(int8_t* qk_buf_, const int32_t* int_buf, const T* attr_mask, float4 floatTmp4 = {0.0f, 0.0f, 0.0f, 0.0f}; if (qual){ - floatTmp4.x = static_cast(__ldg(int_buf + inIdx)) * scalar1; - floatTmp4.y = static_cast(__ldg(int_buf+inIdx+1)) * scalar1; - floatTmp4.z = static_cast(__ldg(int_buf+inIdx+2)) * scalar1; - floatTmp4.w = 
static_cast(__ldg(int_buf+inIdx+3)) * scalar1; + floatTmp4.x = static_cast(__ldg(input + inIdx)) * scalar1; + floatTmp4.y = static_cast(__ldg(input+inIdx+1)) * scalar1; + floatTmp4.z = static_cast(__ldg(input+inIdx+2)) * scalar1; + floatTmp4.w = static_cast(__ldg(input+inIdx+3)) * scalar1; } float mask_val, max_val; @@ -1728,42 +2607,45 @@ void softmax_COL32(int8_t* qk_buf_, const int32_t* int_buf, const T* attr_mask, } } -//int_buf are a series of sub-matrixes of m = seq_len, n = seq_len, CUBLASLT_ORDER_COL32 +//input are a series of sub-matrixes of m = seq_len, n = seq_len_padded, CUBLASLT_ORDER_COL32 +//seq_len_padded = (seq_len+31)/32*32 //grid = (seq_len, batch_size, head_num) -//block.x = max(32, (seq_len/4 + 31)/32*32) +//block.x = max(32, (seq_len_padded/4 + 31)/32*32) //for int8_t IO; template __global__ -void softmax_COL32(int8_t* qk_buf_, const int8_t* int_buf, const T* attr_mask, const int batch_size, - const int head_num, const int seq_len, const float scalar1a, const float *scalar1b, - const float *amax_ptr, const int head_num_x_seq_len, const int seq_len_x_seq_len) +void softmax_COL32_varlen(int8_t* output, const int8_t* input, const T* attr_mask, const int batch_size, + const int head_num, const int seq_len, const int seq_len_padded, + const float scalar1a, const float *scalar1b, const float *amax_ptr, + const int seq_len_x_seq_len, const int seq_len_x_seq_len_padded) { const float amax = __ldg(amax_ptr); const float scalar1 = scalar1a * __ldg(scalar1b); int mask_id; int threadIdx4 = threadIdx.x << 2; - char4* buf4Ptr = (char4 *)qk_buf_; - const char4* inBuf4Ptr = (const char4*)int_buf; + char4* buf4Ptr = (char4 *)output; + const char4* inBuf4Ptr = (const char4*)input; - bool qual = threadIdx4 < seq_len; + const bool qual = threadIdx4 < seq_len; + const bool qual_padded = threadIdx4 < seq_len_padded; for (int seq_id = blockIdx.x ; seq_id < seq_len ; seq_id += gridDim.x){ char4 tmp4 = {0, 0, 0, 0}; - int inIdx = ((blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len) + + int inIdx = ((blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len_padded) + (threadIdx4 & 0xffffffe0) * seq_len + (seq_id << 5) + (threadIdx4 & 31)) >> 2; - //set softmax of padding word to 0 + //set softmax of padding word in rows to 0 const float mask_in_seq = static_cast(__ldg(attr_mask+(blockIdx.y*seq_len_x_seq_len + seq_id))); if (mask_in_seq < 0.1f){ - if (qual) + if (qual_padded) buf4Ptr[inIdx] = tmp4; continue; } + //set softmax of padding word in cols to 0 float4 floatTmp4 = {0.0f, 0.0f, 0.0f, 0.0f}; - if (qual){ tmp4 = __ldg(inBuf4Ptr + inIdx); floatTmp4.x = static_cast(tmp4.x) * scalar1; @@ -1828,26 +2710,101 @@ void softmax_COL32(int8_t* qk_buf_, const int8_t* int_buf, const T* attr_mask, c } __syncthreads(); - if (qual){ + if (qual_padded){ - tmp4.x = float_to_int8_rn(floatTmp4.x*s_sum); - tmp4.y = float_to_int8_rn(floatTmp4.y*s_sum); - tmp4.z = float_to_int8_rn(floatTmp4.z*s_sum); - tmp4.w = float_to_int8_rn(floatTmp4.w*s_sum); + tmp4.x = qual ? float_to_int8_rn(floatTmp4.x*s_sum) : static_cast(0); + tmp4.y = qual ? float_to_int8_rn(floatTmp4.y*s_sum) : static_cast(0); + tmp4.z = qual ? float_to_int8_rn(floatTmp4.z*s_sum) : static_cast(0); + tmp4.w = qual ? 
float_to_int8_rn(floatTmp4.w*s_sum) : static_cast(0); buf4Ptr[inIdx] = tmp4; } } } -//int_buf are a series of sub-matrixes of m = seq_len, n = seq_len, CUBLASLT_ORDER_COL32 +//input are a series of sub-matrixes of m = seq_len, n = seq_len_padded, CUBLASLT_ORDER_COL32 +//seq_len_padded = (seq_len+31)/32*32 +//grid = (seq_len, batch_size, head_num) +//block.x = max(32, (seq_len_padded + 31)/32*32) +//for int8_t IO, I/O with int8_t element; +template +__global__ +void softmax_COL32_perElement_varlen(int8_t* output, const int8_t* input, const T* attr_mask, const int batch_size, + const int head_num, const int seq_len, const int seq_len_padded, + const float scalar1a, const float *scalar1b, const float *amax_ptr, + const int seq_len_x_seq_len, const int seq_len_x_seq_len_padded) +{ + const float amax = __ldg(amax_ptr); + const float scalar1 = scalar1a * __ldg(scalar1b); + int mask_id; + const int tidx = threadIdx.x; + + const bool qual = tidx < seq_len; + const bool qual_padded = tidx < seq_len_padded; + for (int seq_id = blockIdx.x ; seq_id < seq_len ; seq_id += gridDim.x){ + + int8_t tmp = 0; + int inIdx = ((blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len_padded) + + (tidx & 0xffffffe0) * seq_len + + (seq_id << 5) + (tidx & 31)); + + //set softmax of padding word in rows to 0 + const float mask_in_seq = static_cast(__ldg(attr_mask+(blockIdx.y*seq_len_x_seq_len + seq_id))); + if (mask_in_seq < 0.1f){ + if (qual_padded) + output[inIdx] = tmp; + continue; + } + + //set softmax of padding word in cols to 0 + float floatTmp = qual ? (static_cast(__ldg(input + inIdx)) * scalar1) : 0.0f; + + float mask_val, max_val; + max_val = -1e20f; + + __shared__ float s_max, s_sum; + + if (qual){ + mask_id = tidx + blockIdx.y * seq_len_x_seq_len + seq_id * seq_len; + mask_val = (1.0f - static_cast(__ldg(attr_mask+mask_id))) * -10000.0f; + floatTmp = floatTmp + mask_val; + } + + max_val = blockDim.x <= 32 ? warpReduceMax(floatTmp) : blockReduceMax(floatTmp); + + if (threadIdx.x == 0){ + s_max = max_val; + } + __syncthreads(); + + float sum_val = 0.0f; + + floatTmp = qual ? __expf(floatTmp - s_max) : floatTmp; + + sum_val = blockDim.x <= 32 ? warpReduceSum(floatTmp) : blockReduceSum(floatTmp); + + if (threadIdx.x == 0){ + s_sum = __fdividef(127.0f, (sum_val + 1e-6f)); + s_sum = __fdividef(s_sum, amax); + } + __syncthreads(); + + if (qual_padded){ + tmp = qual ? float_to_int8_rn(floatTmp*s_sum) : static_cast(0); + output[inIdx] = tmp; + } + } +} + + +//input are a series of sub-matrixes of m = seq_len, n = seq_len, CUBLASLT_ORDER_COL32 //grid = (seq_len, batch_size, head_num) //block.x = (seq_len + 31)/32 //for int32_t I; int8 O; //for seq_len <= 32 template __global__ -void softmax_COL32_LE32(int8_t* qk_buf_, const int32_t* int_buf, const T* attr_mask, const int batch_size, +void softmax_COL32_LE32(int8_t* output, const int32_t* input, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const float scalar1a, const float *scalar1b, const float *scalar1c, const float *amax_ptr, const int head_num_x_seq_len, const int seq_len_x_seq_len) { @@ -1865,11 +2822,11 @@ void softmax_COL32_LE32(int8_t* qk_buf_, const int32_t* int_buf, const T* attr_m float mask_in_seq = static_cast(__ldg(attr_mask+(blockIdx.y*seq_len_x_seq_len + seq_id))); if (mask_in_seq < 0.1f){ if (qual) - qk_buf_[inIdx] = 0; + output[inIdx] = 0; continue; } - float floatTmp = qual ? static_cast(__ldg(int_buf + inIdx)) * scalar1 : 0.0f; + float floatTmp = qual ? 
static_cast(__ldg(input + inIdx)) * scalar1 : 0.0f; float mask_val, max_val; @@ -1899,53 +2856,55 @@ void softmax_COL32_LE32(int8_t* qk_buf_, const int32_t* int_buf, const T* attr_m if (qual){ - qk_buf_[inIdx] = float_to_int8_rn(floatTmp*s_sum); + output[inIdx] = float_to_int8_rn(floatTmp*s_sum); } } } -//int_buf are a series of sub-matrixes of m = seq_len, n = seq_len, CUBLASLT_ORDER_COL32 +//input are a series of sub-matrixes of m = seq_len, n = seq_len_padded, CUBLASLT_ORDER_COL32 +//seq_len_padded = (seq_len+31)/32*32 +//attr_mask is [batch_size, seq_len, seq_len] //grid = (seq_len, batch_size, head_num) -//block.x = (seq_len + 31)/32 +//block.x = seq_len_padded //for int8_t IO; -//for seq_len <= 32 +//for seq_len_padded == 32 template __global__ -void softmax_COL32_LE32(int8_t* qk_buf_, const int8_t* int_buf, const T* attr_mask, const int batch_size, - const int head_num, const int seq_len, const float scalar1a, const float *scalar1b, - const float *amax_ptr, const int head_num_x_seq_len, const int seq_len_x_seq_len) +void softmax_COL32_LE32_varlen(int8_t* output, const int8_t* input, const T* attr_mask, const int batch_size, + const int head_num, const int seq_len, const int seq_len_padded, + const float scalar1a, const float *scalar1b, const float *amax_ptr, + const int seq_len_x_seq_len, const int seq_len_x_seq_len_padded) { const float amax = __ldg(amax_ptr); const float scalar1 = scalar1a * __ldg(scalar1b); int mask_id; int threadIdxx = threadIdx.x; - bool qual = threadIdxx < seq_len; + const bool qual = threadIdxx < seq_len; + const bool qual_padded = threadIdxx < seq_len_padded; for (int seq_id = blockIdx.x ; seq_id < seq_len ; seq_id += gridDim.x){ - int inIdx = (blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len) + + int inIdx = (blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len_padded) + (threadIdxx & 0xffffffe0) * seq_len + (seq_id << 5) + (threadIdxx & 31); - //set softmax of padding word to 0 + //set softmax of padding word in rows to 0 float mask_in_seq = static_cast(__ldg(attr_mask+(blockIdx.y*seq_len_x_seq_len + seq_id))); if (mask_in_seq < 0.1f){ - if (qual) - qk_buf_[inIdx] = 0; + if (qual_padded) + output[inIdx] = 0; continue; } - - float floatTmp = qual ? static_cast(__ldg(int_buf + inIdx)) * scalar1 : 0.0f; - float mask_val, max_val; - __shared__ float s_max, s_sum; + //set softmax of padding word in cols to 0 + float floatTmp = qual ? static_cast(__ldg(input + inIdx)) * scalar1 : 0.0f; mask_id = qual ? threadIdxx + blockIdx.y * seq_len_x_seq_len + seq_id * seq_len : 0; mask_val = qual ? (1.0f - static_cast(__ldg(attr_mask+mask_id))) * -10000.0f : 0.0f; floatTmp = qual ? floatTmp + mask_val : 0.0f; max_val = qual ? floatTmp : -1e20f; - max_val = blockDim.x <= 32 ? warpReduceMax(max_val) : blockReduceMax(max_val); + max_val = warpReduceMax(max_val); if (threadIdx.x == 0){ s_max = max_val; @@ -1963,20 +2922,20 @@ void softmax_COL32_LE32(int8_t* qk_buf_, const int8_t* int_buf, const T* attr_ma __syncthreads(); - if (qual){ - qk_buf_[inIdx] = float_to_int8_rn(floatTmp*s_sum); + if (qual_padded){ + output[inIdx] = qual ? 
float_to_int8_rn(floatTmp*s_sum) : static_cast(0); } } } -//int_buf are a series of sub-matrixes of m = seq_len, n = seq_len, CUBLASLT_ORDER_COL32 +//input are a series of sub-matrixes of m = seq_len, n = seq_len, CUBLASLT_ORDER_COL32 //grid = (seq_len, batch_size, head_num) //block.x = max(32, (seq_len/2 + 31)/32*32) //for int32_t I; int8 O; //for seq_len in (32, 64] template __global__ -void softmax_COL32_LE64(int8_t* qk_buf_, const int32_t* int_buf, const T* attr_mask, const int batch_size, +void softmax_COL32_LE64(int8_t* output, const int32_t* input, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const float scalar1a, const float *scalar1b, const float *scalar1c, const float *amax_ptr, const int head_num_x_seq_len, const int seq_len_x_seq_len) { @@ -1985,7 +2944,7 @@ void softmax_COL32_LE64(int8_t* qk_buf_, const int32_t* int_buf, const T* attr_m int mask_id; int threadIdx2 = threadIdx.x << 1; - char2* buf2Ptr = (char2 *)qk_buf_; + char2* buf2Ptr = (char2 *)output; bool qual = threadIdx2 < seq_len; for (int seq_id = blockIdx.x ; seq_id < seq_len ; seq_id += gridDim.x){ @@ -2004,8 +2963,8 @@ void softmax_COL32_LE64(int8_t* qk_buf_, const int32_t* int_buf, const T* attr_m float2 floatTmp2 = {0.0f, 0.0f}; if (qual){ - floatTmp2.x = static_cast(__ldg(int_buf + inIdx)) * scalar1; - floatTmp2.y = static_cast(__ldg(int_buf + inIdx + 1)) * scalar1; + floatTmp2.x = static_cast(__ldg(input + inIdx)) * scalar1; + floatTmp2.y = static_cast(__ldg(input + inIdx + 1)) * scalar1; } float mask_val, max_val; @@ -2058,40 +3017,44 @@ void softmax_COL32_LE64(int8_t* qk_buf_, const int32_t* int_buf, const T* attr_m } } -//int_buf are a series of sub-matrixes of m = seq_len, n = seq_len, CUBLASLT_ORDER_COL32 +//input are a series of sub-matrixes of m = seq_len, n = seq_len_padded, CUBLASLT_ORDER_COL32 +//seq_len_padded = (seq_len+31)/32*32 //grid = (seq_len, batch_size, head_num) -//block.x = max(32, (seq_len/2 + 31)/32*32) +//block.x = 32 //for int8_t IO //for seq_len in (32, 64] template __global__ -void softmax_COL32_LE64(int8_t* qk_buf_, const int8_t* int_buf, const T* attr_mask, const int batch_size, - const int head_num, const int seq_len, const float scalar1a, const float *scalar1b, - const float *amax_ptr, const int head_num_x_seq_len, const int seq_len_x_seq_len) +void softmax_COL32_LE64_varlen(int8_t* output, const int8_t* input, const T* attr_mask, const int batch_size, + const int head_num, const int seq_len, const int seq_len_padded, + const float scalar1a, const float *scalar1b, const float *amax_ptr, + const int seq_len_x_seq_len, const int seq_len_x_seq_len_padded) { const float amax = __ldg(amax_ptr); const float scalar1 = scalar1a * __ldg(scalar1b); int mask_id; int threadIdx2 = threadIdx.x << 1; - char2* buf2Ptr = (char2 *)qk_buf_; - const char2* inBuf2Ptr = (const char2 *)int_buf; + char2* buf2Ptr = (char2 *)output; + const char2* inBuf2Ptr = (const char2 *)input; - bool qual = threadIdx2 < seq_len; + const bool qual = threadIdx2 < seq_len; + const bool qual_padded = threadIdx2 < seq_len_padded; for (int seq_id = blockIdx.x ; seq_id < seq_len ; seq_id += gridDim.x){ char2 tmp2 = {0, 0}; - int inIdx = ((blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len) + + int inIdx = ((blockIdx.y * head_num + blockIdx.z) * (seq_len_x_seq_len_padded) + (threadIdx2 & 0xffffffe0) * seq_len + (seq_id << 5) + (threadIdx2 & 31)) >> 1; - //set softmax of padding word to 0 + //set softmax of padding word in rows to 0 float mask_in_seq = 
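The index arithmetic shared by these kernels addresses CUBLASLT_ORDER_COL32 data: columns are grouped in tiles of 32 and stored row-major inside each tile, with seq_len rounded up to a multiple of 32 in the padded variants. A minimal host-side sketch of the same element offset (names are illustrative only):

static inline int round_up_to_32(int x) { return (x + 31) / 32 * 32; }

// Element offset of (row, col) inside the (batch, head) sub-matrix of shape
// m = seq_len x n = seq_len_padded, matching the inIdx computation in the kernels above.
static inline long col32_qk_offset(int batch_id, int head_id, int head_num,
                                   int row, int col, int seq_len, int seq_len_padded)
{
    const long sub_matrix = (long)(batch_id * head_num + head_id) * seq_len * seq_len_padded;
    return sub_matrix
         + (long)(col & 0xffffffe0) * seq_len  // which tile of 32 columns
         + ((long)row << 5)                    // row offset inside the tile (32 elements per row)
         + (col & 31);                         // column inside the tile
}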
static_cast(__ldg(attr_mask+(blockIdx.y*seq_len_x_seq_len + seq_id))); if (mask_in_seq < 0.1f){ - if (qual) + if (qual_padded) buf2Ptr[inIdx] = tmp2; continue; } + //set softmax of padding word in cols to 0 float2 floatTmp2 = {0.0f, 0.0f}; if (qual){ tmp2 = __ldg(inBuf2Ptr + inIdx); @@ -2117,7 +3080,7 @@ void softmax_COL32_LE64(int8_t* qk_buf_, const int8_t* int_buf, const T* attr_ma max_val = fmaxf(floatTmp2.x, floatTmp2.y); } - max_val = blockDim.x <= 32 ? warpReduceMax(max_val) : blockReduceMax(max_val); + max_val = warpReduceMax(max_val); if (threadIdx.x == 0){ s_max = max_val; @@ -2133,7 +3096,7 @@ void softmax_COL32_LE64(int8_t* qk_buf_, const int8_t* int_buf, const T* attr_ma sum_val += floatTmp2.y; } - sum_val = blockDim.x <= 32 ? warpReduceSum(sum_val) : blockReduceSum(sum_val); + sum_val = warpReduceSum(sum_val); if (threadIdx.x == 0){ s_sum = __fdividef(127.0f, (sum_val + 1e-6f)); @@ -2141,15 +3104,121 @@ void softmax_COL32_LE64(int8_t* qk_buf_, const int8_t* int_buf, const T* attr_ma } __syncthreads(); - if (qual){ - tmp2.x = float_to_int8_rn(floatTmp2.x*s_sum); - tmp2.y = float_to_int8_rn(floatTmp2.y*s_sum); + if (qual_padded){ + tmp2.x = qual ? float_to_int8_rn(floatTmp2.x*s_sum) : static_cast(0); + tmp2.y = qual ? float_to_int8_rn(floatTmp2.y*s_sum) : static_cast(0); buf2Ptr[inIdx] = tmp2; } } } +template +void softmax_COL32_kernelLauncher(int8_t* output, const int32_t* input, const T* attr_mask, + const int batch_size, const int head_num, const int seq_len, + const float scalar1a, const float *scalar1b, const float *scalar1c, + const float *amax_ptr, cudaStream_t stream) +{ + dim3 grid, block; + grid.x = seq_len; + grid.y = batch_size; + grid.z = head_num; + + if (seq_len <= 32){ + if (batch_size * head_num > 960) + grid.x = ceil(float(seq_len)/32.0f); + block.x = (seq_len + 31)/32*32; + softmax_COL32_LE32<<>>(output, input, attr_mask, batch_size, head_num, + seq_len, scalar1a, scalar1b, scalar1c, + amax_ptr, seq_len*head_num, seq_len*seq_len); + } + else if (seq_len <= 64){ + assert(seq_len % 2 == 0); + block.x = (seq_len/2 + 31)/32*32; + if (batch_size * head_num > 960) + grid.x = ceil(float(seq_len)/32.0f); + softmax_COL32_LE64<<>>(output, input, attr_mask, batch_size, head_num, + seq_len, scalar1a, scalar1b, scalar1c, + amax_ptr, seq_len*head_num, seq_len*seq_len); + } + else + { + assert(seq_len % 4 == 0); + block.x = (seq_len/4 + 31)/32*32; + softmax_COL32<<>>(output, input, attr_mask, batch_size, head_num, + seq_len, scalar1a, scalar1b, scalar1c, + amax_ptr, seq_len*head_num, seq_len*seq_len); + } +} + +template +void softmax_COL32_kernelLauncher(int8_t* output, const int32_t* input, const float* attr_mask, + const int batch_size, const int head_num, const int seq_len, + const float scalar1a, const float *scalar1b, const float *scalar1c, + const float *amax_ptr, cudaStream_t stream); + +template +void softmax_COL32_kernelLauncher(int8_t* output, const int32_t* input, const half* attr_mask, + const int batch_size, const int head_num, const int seq_len, + const float scalar1a, const float *scalar1b, const float *scalar1c, + const float *amax_ptr, cudaStream_t stream); + +template +void softmax_COL32_kernelLauncher(int8_t* output, const int8_t* input, const T* attr_mask, + const int batch_size, const int head_num, const int seq_len, + const float scalar1a, const float *scalar1b, const float *amax_ptr, + cudaStream_t stream) +{ + dim3 grid, block; + grid.x = seq_len; + grid.y = batch_size; + grid.z = head_num; + const int seq_len_padded = (seq_len + 31)/32*32; + + if 
(seq_len <= 32){ + if (batch_size * head_num > 960) + grid.x = ceil(float(seq_len)/32.0f); + block.x = seq_len_padded; + softmax_COL32_LE32_varlen<<>>(output, input, attr_mask, batch_size, head_num, + seq_len, seq_len_padded, scalar1a, scalar1b, amax_ptr, + seq_len*seq_len, seq_len*seq_len_padded); + } + else if (seq_len <= 64 && (seq_len % 2 == 0)){ + block.x = 32; + if (batch_size * head_num > 960) + grid.x = ceil(float(seq_len)/32.0f); + softmax_COL32_LE64_varlen<<>>(output, input, attr_mask, batch_size, head_num, + seq_len, seq_len_padded, scalar1a, scalar1b, amax_ptr, + seq_len*seq_len, seq_len*seq_len_padded); + } + else if (seq_len > 64 && (seq_len % 4 == 0)) + { + block.x = (seq_len_padded/4 + 31)/32*32; + softmax_COL32_varlen<<>>(output, input, attr_mask, batch_size, head_num, + seq_len, seq_len_padded, scalar1a, scalar1b, amax_ptr, + seq_len*seq_len, seq_len*seq_len_padded); + } + else + { + block.x = (seq_len_padded + 31)/32*32; + softmax_COL32_perElement_varlen<<>>(output, input, attr_mask, batch_size, head_num, + seq_len, seq_len_padded, scalar1a, scalar1b, amax_ptr, + seq_len*seq_len, seq_len*seq_len_padded); + } +} + +template +void softmax_COL32_kernelLauncher(int8_t* output, const int8_t* input, const float* attr_mask, + const int batch_size, const int head_num, const int seq_len, + const float scalar1a, const float *scalar1b, + const float *amax_ptr, cudaStream_t stream); + +template +void softmax_COL32_kernelLauncher(int8_t* output, const int8_t* input, const half* attr_mask, + const int batch_size, const int head_num, const int seq_len, + const float scalar1a, const float *scalar1b, + const float *amax_ptr, cudaStream_t stream); + template __global__ @@ -2181,6 +3250,32 @@ void transpose(half* src, half* dst, dst_ptr[target_id] = src_ptr[tid]; } +template +void transpose_kernelLauncher(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head, cudaStream_t stream) +{ + dim3 grid, block; + if (std::is_same::value) + { + const int seq_per_block = 1; + grid.x = batch_size * head_num * seq_len / seq_per_block; + block.x = seq_per_block * size_per_head; + transpose<<>>(src, dst, batch_size, seq_len, head_num, size_per_head); + } + else + { + const int seq_per_block = 4; + grid.x = batch_size * head_num * seq_len / seq_per_block; + block.x = seq_per_block * size_per_head / 2; + assert(grid.x * seq_per_block == batch_size * head_num * seq_len); + transpose<<>>(src, dst, batch_size, seq_len, head_num, size_per_head / 2); + } +} + +template +void transpose_kernelLauncher(float* src, float* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head, cudaStream_t stream); + +template +void transpose_kernelLauncher(half* src, half* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head, cudaStream_t stream); template __global__ @@ -2203,6 +3298,38 @@ void transpose_rebuild_padding(T* src, T* dst, const int batch_size, const int s head_id * seq_len * size_per_head + src_seq_id * size_per_head + hidden_id]; } +template +void transpose_rebuild_padding_kernelLauncher(T* src, T* dst, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const int* mask_offset, cudaStream_t stream) +{ + int k = head_num * size_per_head; + if (std::is_same::value) + { + transpose_rebuild_padding<<>>(src, dst, + batch_size, seq_len, head_num, size_per_head, mask_offset); + } + else + { + transpose_rebuild_padding<<>>( + (half2*)src, 
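A compact restatement of the kernel selection the INT8-IO launcher above performs; this helper is purely illustrative and not part of the patch:

enum class SoftmaxCol32Kernel { LE32_varlen, LE64_varlen, Char4_varlen, PerElement_varlen };

// Mirrors the branches in softmax_COL32_kernelLauncher (int8 input): short sequences use the
// single-warp kernels, longer multiples of four use the char4 kernel, and everything else
// falls back to the per-element kernel.
static inline SoftmaxCol32Kernel pick_softmax_col32_kernel(int seq_len)
{
    if (seq_len <= 32)                      return SoftmaxCol32Kernel::LE32_varlen;
    if (seq_len <= 64 && seq_len % 2 == 0)  return SoftmaxCol32Kernel::LE64_varlen;
    if (seq_len > 64 && seq_len % 4 == 0)   return SoftmaxCol32Kernel::Char4_varlen;
    return SoftmaxCol32Kernel::PerElement_varlen;
}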
(half2*)dst, + batch_size, seq_len, head_num, size_per_head / 2, mask_offset); + } +} + +template +void transpose_rebuild_padding_kernelLauncher(float* src, float* dst, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const int* mask_offset, cudaStream_t stream); + +template +void transpose_rebuild_padding_kernelLauncher(half* src, half* dst, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const int* mask_offset, cudaStream_t stream); + template __global__ void rebuild_sequence_length_padding(const T* src, T* tgt, const int* mask_offset, @@ -2219,405 +3346,6 @@ __global__ void rebuild_sequence_length_padding(const T* src, T* tgt, } } -template -void OpenMultiHeadAttention::multiHeadAttr_nofuse_kernelLauncher( - cudaStream_t stream, - cublasHandle_t cublas_handle, - cublasLtHandle_t cublaslt_handle, - DataType_* Q, - const DataType_* bias_Q, - DataType_* K, - const DataType_* bias_K, - DataType_* V, - const DataType_* bias_V, - const DataType_* attr_mask, - DataType_* dst, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int int8_mode_, - const DataType_ scalar) -{ - const int k = head_num * size_per_head; - - dim3 grid; - dim3 block; - - - if (int8_mode_ != 0) - { - //var for int8 - const float*Qbias_amax_ptr, *Kbias_amax_ptr, *Vbias_amax_ptr, *bmm1_amax_ptr, *Softmax_amax_ptr, *bmm2_amax_ptr, *in_amax_ptr, *Q_aftergemm_amax_ptr, *K_aftergemm_amax_ptr, *V_aftergemm_amax_ptr; - Qbias_amax_ptr = param_.amaxList + 8; - Kbias_amax_ptr = param_.amaxList + 16; - Vbias_amax_ptr = param_.amaxList + 24; - Softmax_amax_ptr = param_.amaxList + 32; - bmm2_amax_ptr = param_.amaxList + 36; - Q_aftergemm_amax_ptr = param_.amaxList + 4; - K_aftergemm_amax_ptr = param_.amaxList + 12; - V_aftergemm_amax_ptr = param_.amaxList + 20; - bmm1_amax_ptr = param_.amaxList + 28; - in_amax_ptr = param_.amaxList; - - assert(seq_len % COL32_ == 0 && size_per_head%COL32_ == 0); - - if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len){ - if (int8_mode_ == 1) - { - add_QK_bias_transform<<>>((int8_t*)q_buf_, (int8_t*)k_buf_, (const int32_t*) Q, bias_Q, (const int32_t*) K, - bias_K, batch_size * seq_len, batch_size, seq_len, head_num, size_per_head, - seq_len*size_per_head, query_weight_amax_list, in_amax_ptr+2, key_weight_amax_list, - in_amax_ptr+2, Qbias_amax_ptr+3, Kbias_amax_ptr+3, use_ORDER_COL32_2R_4R4_); - add_V_bias_transform<<>>((int8_t*)v_buf_, (const int32_t *)V, bias_V, batch_size, seq_len, - head_num, size_per_head, seq_len*size_per_head, value_weight_amax_list, - in_amax_ptr+2, Vbias_amax_ptr+3, use_ORDER_COL32_2R_4R4_); - } - else - { - add_QK_bias_transform<<>>((int8_t*)q_buf_, (int8_t*)k_buf_, (const int8_t*) Q, bias_Q, (const int8_t*) K, - bias_K, batch_size * seq_len, batch_size, seq_len, head_num, size_per_head, - seq_len*size_per_head, Q_aftergemm_amax_ptr+1, K_aftergemm_amax_ptr+1, - Qbias_amax_ptr+3, Kbias_amax_ptr+3, use_ORDER_COL32_2R_4R4_); - add_V_bias_transform<<>>((int8_t*)v_buf_, (const int8_t *)V, bias_V, batch_size, seq_len, - head_num, size_per_head, seq_len*size_per_head, - V_aftergemm_amax_ptr+1, Vbias_amax_ptr+3, use_ORDER_COL32_2R_4R4_); - } - } - else{ - cudaMemset(sequence_id_map_, -1, batch_size * seq_len * sizeof(int)); - mappingRemovePaddingData<<>>(sequence_id_map_, param_.sequence_id_offset, param_.valid_word_num); - if (int8_mode_ == 1) - { - 
add_QK_bias_transform_rebuild_padding<<>>((int8_t*)q_buf_, (int8_t*)k_buf_, (const int32_t*) Q, bias_Q, - (const int32_t*) K, bias_K, param_.sequence_id_offset, param_.valid_word_num, - batch_size * seq_len, batch_size, seq_len, head_num, size_per_head, seq_len*size_per_head, - query_weight_amax_list, in_amax_ptr+2, key_weight_amax_list, in_amax_ptr+2, - Qbias_amax_ptr+3, Kbias_amax_ptr+3, use_ORDER_COL32_2R_4R4_); - - add_V_bias_transform_rebuild_padding<<>>((int8_t*)v_buf_, (const int32_t *)V, bias_V, sequence_id_map_, - param_.valid_word_num, batch_size, seq_len, head_num, - size_per_head, seq_len*size_per_head, value_weight_amax_list, - in_amax_ptr+2, Vbias_amax_ptr+3, use_ORDER_COL32_2R_4R4_); - } - else - { - add_QK_bias_transform_rebuild_padding<<>>((int8_t*)q_buf_, (int8_t*)k_buf_, (const int8_t*) Q, bias_Q, - (const int8_t*) K, bias_K, param_.sequence_id_offset, param_.valid_word_num, - batch_size * seq_len, batch_size, seq_len, head_num, size_per_head, seq_len*size_per_head, - Q_aftergemm_amax_ptr+1, K_aftergemm_amax_ptr+1, - Qbias_amax_ptr+3, Kbias_amax_ptr+3, use_ORDER_COL32_2R_4R4_); - - add_V_bias_transform_rebuild_padding<<>>((int8_t*)v_buf_, (const int8_t *)V, bias_V, sequence_id_map_, - param_.valid_word_num, batch_size, seq_len, head_num, - size_per_head, seq_len*size_per_head, - V_aftergemm_amax_ptr+1, Vbias_amax_ptr+3, use_ORDER_COL32_2R_4R4_); - } - } - - int batchCount = batch_size * head_num; - grid.x = seq_len; - grid.y = batch_size; - grid.z = head_num; - - if (int8_mode_ == 1) - { - cublasLtMM_withAlgo(qk_int_buf_, batchCount, seq_len, seq_len, size_per_head, - size_per_head*seq_len, size_per_head*seq_len, seq_len*seq_len, - (int8_t*)q_buf_, (int8_t*)k_buf_, cublaslt_handle, stream, cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_, true); - - if (seq_len <= 32){ - if (batch_size * head_num > 960) - grid.x = ceil(float(seq_len)/32.0f); - block.x = (seq_len + 31)/32*32; - softmax_COL32_LE32<<>>((int8_t*)qk_buf_, qk_int_buf_, attr_mask, batch_size, head_num, - seq_len, float(scalar), Qbias_amax_ptr + 1, Kbias_amax_ptr + 1, - Softmax_amax_ptr, seq_len*head_num, seq_len*seq_len); - } - else if (seq_len <= 64){ - assert(seq_len % 2 == 0); - block.x = (seq_len/2 + 31)/32*32; - if (batch_size * head_num > 960) - grid.x = ceil(float(seq_len)/32.0f); - softmax_COL32_LE64<<>>((int8_t*)qk_buf_, qk_int_buf_, attr_mask, batch_size, head_num, - seq_len, float(scalar), Qbias_amax_ptr + 1, Kbias_amax_ptr + 1, - Softmax_amax_ptr, seq_len*head_num, seq_len*seq_len); - } - else - { - assert(seq_len % 4 == 0); - block.x = (seq_len/4 + 31)/32*32; - softmax_COL32<<>>((int8_t*)qk_buf_, qk_int_buf_, attr_mask, batch_size, head_num, - seq_len, float(scalar), Qbias_amax_ptr + 1, Kbias_amax_ptr + 1, - Softmax_amax_ptr, seq_len*head_num, seq_len*seq_len); - } - - cublasLtMM_withAlgo(transpose_dst_int_buf_, batchCount, seq_len, size_per_head, seq_len, - seq_len*seq_len, size_per_head*seq_len, size_per_head*seq_len, (int8_t*)qk_buf_, - (int8_t*)v_buf_, cublaslt_handle, stream, cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_, true); - - if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) - { - transpose_COL32_kernelLauncher((int8_t*)dst, (const int*)transpose_dst_int_buf_, batch_size, seq_len, head_num, - size_per_head, Vbias_amax_ptr+1, Softmax_amax_ptr+1, bmm2_amax_ptr+3, stream); - } - else - { - transpose_COL32_rebuild_padding_kernelLauncher((int8_t*)dst, (const int*)transpose_dst_int_buf_, sequence_id_map_, - param_.valid_word_num, batch_size, seq_len, head_num, 
size_per_head, - Vbias_amax_ptr+1, Softmax_amax_ptr+1, bmm2_amax_ptr+3, stream); - } - - } - else - { - cublasLtMM_withAlgo_int8IO((int8_t*)qk_int_buf_, batchCount, seq_len, seq_len, size_per_head, - size_per_head*seq_len, size_per_head*seq_len, seq_len*seq_len, - param_.int8O_gemm_deQ_scale_list[3], - (int8_t*)q_buf_, (int8_t*)k_buf_, cublaslt_handle, stream, cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_, true); - - if (seq_len <= 32){ - if (batch_size * head_num > 960) - grid.x = ceil(float(seq_len)/32.0f); - block.x = (seq_len + 31)/32*32; - softmax_COL32_LE32<<>>((int8_t*)qk_buf_, (int8_t*)qk_int_buf_, attr_mask, batch_size, head_num, - seq_len, float(scalar), bmm1_amax_ptr + 1, - Softmax_amax_ptr, seq_len*head_num, seq_len*seq_len); - } - else if (seq_len <= 64){ - assert(seq_len % 2 == 0); - block.x = (seq_len/2 + 31)/32*32; - if (batch_size * head_num > 960) - grid.x = ceil(float(seq_len)/32.0f); - softmax_COL32_LE64<<>>((int8_t*)qk_buf_, (int8_t*)qk_int_buf_, attr_mask, batch_size, head_num, - seq_len, float(scalar), bmm1_amax_ptr + 1, - Softmax_amax_ptr, seq_len*head_num, seq_len*seq_len); - } - else - { - assert(seq_len % 4 == 0); - block.x = (seq_len/4 + 31)/32*32; - softmax_COL32<<>>((int8_t*)qk_buf_, (int8_t*)qk_int_buf_, attr_mask, batch_size, head_num, - seq_len, float(scalar), bmm1_amax_ptr + 1, - Softmax_amax_ptr, seq_len*head_num, seq_len*seq_len); - } - - cublasLtMM_withAlgo_int8IO((int8_t*)transpose_dst_int_buf_, batchCount, seq_len, size_per_head, seq_len, - seq_len*seq_len, size_per_head*seq_len, size_per_head*seq_len, param_.int8O_gemm_deQ_scale_list[4], (int8_t*)qk_buf_, - (int8_t*)v_buf_, cublaslt_handle, stream, cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_, true); - if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) - { - transpose_COL32_kernelLauncher((int8_t*)dst, (const int8_t*)transpose_dst_int_buf_, batch_size, seq_len, head_num, - size_per_head, bmm2_amax_ptr+1, bmm2_amax_ptr+3, stream); - } - else - { - transpose_COL32_rebuild_padding_kernelLauncher((int8_t*)dst, (const int8_t*)transpose_dst_int_buf_, sequence_id_map_, - param_.valid_word_num, batch_size, seq_len, head_num, size_per_head, - bmm2_amax_ptr+1, - bmm2_amax_ptr+3, stream); - } - } - } - //FP32/FP16 - else{ - if(OpType_ == OperationType::FP32) - { - if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) - { - const int m = batch_size * seq_len; - const int word_per_block = 1; - assert(k <= 1024); - assert(m / word_per_block * 3 <= 65536); - - dim3 grid(m / word_per_block * 3); - dim3 block(k); - add_QKV_bias<<>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf_, k_buf_, v_buf_, - batch_size, seq_len, head_num, size_per_head, word_per_block); - } - else - { - add_QKV_bias_rebuild_padding<<>>(Q, bias_Q, K, bias_K, - V, bias_V, q_buf_, k_buf_, v_buf_, - batch_size, seq_len, head_num, size_per_head, param_.sequence_id_offset); - } - } - else - { - if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) - { - const int word_per_block = 1; - grid.x = batch_size * seq_len / word_per_block; - block.x = head_num * size_per_head * word_per_block / 2; - - assert(block.x <= 1024); - - add_QKV_bias<<>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf_, k_buf_, - v_buf_, batch_size, seq_len, head_num, size_per_head / 2, word_per_block); - } - else - { - add_QKV_bias_rebuild_padding<<>>((half2*)Q, (const half2*)bias_Q, - (half2*)K, (const half2*)bias_K, (half2*)V, (const half2*)bias_V, - (half2*)q_buf_, (half2*)k_buf_, 
(half2*)v_buf_, - batch_size, seq_len, head_num, size_per_head / 2, param_.sequence_id_offset); - } - } - - DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; - - check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle, - CUBLAS_OP_T, CUBLAS_OP_N, - seq_len, seq_len, size_per_head, - &alpha, - k_buf_, AType_, size_per_head, seq_len * size_per_head, - q_buf_, BType_, size_per_head, seq_len * size_per_head, - &beta, - qk_buf_, CType_, seq_len, seq_len * seq_len, - batch_size * head_num, - computeType_, - static_cast(cublasAlgo_[1]))); - - //deal with odd seq_len - if (seq_len % 2 != 0){ - if(seq_len <= 32) - block.x = 32; - else if(seq_len > 32 && seq_len <= 64) - block.x = 64; - else if(seq_len > 64 && seq_len <= 128) - block.x = 128; - else if(seq_len > 128 && seq_len <= 256) - block.x = 256; - else if(seq_len > 256 && seq_len <= 512) - block.x = 512; - else - block.x = 1024; - - if(batch_size * head_num <= 120) - { - grid.x = batch_size * head_num * seq_len; - softmax_kernel_v2<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scalar); - } - else - { - grid.x = batch_size * head_num; - softmax_kernel<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scalar); - } - } - //deal with even seq_len - else{ - grid.x = seq_len; - if (batch_size * head_num > 360) - grid.x = ceil(float(seq_len)/32.0f); - grid.y = batch_size; - grid.z = head_num; - if (seq_len <= 32){ - block.x = 32; - softmax_kernel_v3_LE32<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scalar); - } - else{ - if (OpType_ == OperationType::FP16){ - block.x = (seq_len/2 + 31)/32*32; - softmax_kernel_v3<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scalar); - } - else{ - block.x = (seq_len + 31)/32*32; - softmax_kernel_v3<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scalar); - } - } - grid.x = grid.y = grid.z = 1; - } - - check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - size_per_head, seq_len, seq_len, - &alpha, - v_buf_, AType_, size_per_head, seq_len * size_per_head, - qk_buf_, BType_, seq_len, seq_len * seq_len, - &beta, - transpose_dst_, CType_, size_per_head, seq_len * size_per_head, - batch_size * head_num, - computeType_, - static_cast(cublasAlgo_[2]))); - - /* for half2 only */ - if(OpType_ == OperationType::FP16) - { - if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) - { - const int seq_per_block = 4; - grid.x = batch_size * head_num * seq_len / seq_per_block; - block.x = seq_per_block * size_per_head / 2; - - assert(grid.x * seq_per_block == batch_size * head_num * seq_len); - - transpose<<>>(transpose_dst_, dst, - batch_size, seq_len, head_num, size_per_head / 2); - } - else - { - transpose_rebuild_padding<<>>( - (half2*)transpose_dst_, (half2*)dst, - batch_size, seq_len, head_num, size_per_head / 2, param_.sequence_id_offset); - } - } - else - { - if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) - { - const int seq_per_block = 1; - grid.x = batch_size * head_num * seq_len / seq_per_block; - block.x = seq_per_block * size_per_head; - transpose<<>>(transpose_dst_, dst, - batch_size, seq_len, head_num, size_per_head); - } - else - { - transpose_rebuild_padding<<>>(transpose_dst_, dst, - batch_size, seq_len, head_num, size_per_head, param_.sequence_id_offset); - } - } - } -} - -template void OpenMultiHeadAttention::multiHeadAttr_nofuse_kernelLauncher( - cudaStream_t stream, - cublasHandle_t handle, - cublasLtHandle_t cublaslt_handle, - float* Q, - const float* 
bias_Q, - float* K, - const float* bias_K, - float* V, - const float* bias_V, - const float* attr_mask, - float* dst, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int int8_mode_, - const float scalar); - -template void OpenMultiHeadAttention::multiHeadAttr_nofuse_kernelLauncher( - cudaStream_t stream, - cublasHandle_t handle, - cublasLtHandle_t cublaslt_handle, - half* Q, - const half* bias_Q, - half* K, - const half* bias_K, - half* V, - const half* bias_V, - const half* attr_mask, - half* dst, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int int8_mode_, - const half scalar); - template void OpenMultiHeadAttention::trt_add_QKV_bias_kernelLauncher( const float* bias_Q, const float* bias_K, @@ -2628,8 +3356,63 @@ template void OpenMultiHeadAttention::trt_add_QKV_bias_kern const half* bias_K, const half* bias_V); -template void OpenMultiHeadAttention::fused_multiHeadAttr_kernelLauncher(); -template void OpenMultiHeadAttention::fused_multiHeadAttr_kernelLauncher(); +template void OpenMultiHeadAttention::trt_add_QKV_bias_COL32_int8IO_kernelLauncher( + int8_t* output, + const int8_t* Q, + const float* bias_Q, + const float* bias_K, + const float* bias_V, + const float *q_input_deQFactor_ptr, + const float *k_input_deQFactor_ptr, + const float *v_input_deQFactor_ptr, + const float qkv_output_scale); + +template void OpenMultiHeadAttention::trt_add_QKV_bias_COL32_int8IO_kernelLauncher( + int8_t* output, + const int8_t* Q, + const half* bias_Q, + const half* bias_K, + const half* bias_V, + const float *q_input_deQFactor_ptr, + const float *k_input_deQFactor_ptr, + const float *v_input_deQFactor_ptr, + const float qkv_output_scale); + +template void OpenMultiHeadAttention::trt_add_QKV_bias_COL32_int32Iint8O_kernelLauncher( + int8_t* output, + const int32_t* Q, + const float* bias_Q, + const float* bias_K, + const float* bias_V, + const float *input_deQFactor_div127_ptr, + const float * q_weight_amax, + const float * k_weight_amax, + const float * v_weight_amax, + const float qkv_output_scale); + +template void OpenMultiHeadAttention::trt_add_QKV_bias_COL32_int32Iint8O_kernelLauncher( + int8_t* output, + const int32_t* Q, + const half* bias_Q, + const half* bias_K, + const half* bias_V, + const float *input_deQFactor_div127_ptr, + const float * q_weight_amax, + const float * k_weight_amax, + const float * v_weight_amax, + const float qkv_output_scale); + +template void OpenMultiHeadAttention::fused_multiHeadAttr_kernelLauncher(const int S); +template void OpenMultiHeadAttention::fused_multiHeadAttr_kernelLauncher(const int S); + +template void OpenMultiHeadAttention::int8_fused_multiHeadAttr_kernelLauncher( + const void* Q, + const float *q_deQFactor_ptr, const float *k_deQFactor_ptr, const float *v_deQFactor_ptr, + const float mScaleQkv, const int S); +template void OpenMultiHeadAttention::int8_fused_multiHeadAttr_kernelLauncher( + const void* Q, + const float *q_deQFactor_ptr, const float *k_deQFactor_ptr, const float *v_deQFactor_ptr, + const float mScaleQkv, const int S); __global__ void trt_add_QKV_bias_2(const half2* Q, const half2* bias_Q, @@ -2661,7 +3444,7 @@ void trt_add_QKV_bias_2(const half2* Q, const half2* bias_Q, size_id] = V[ seq_id * blockDim.x + threadIdx.x] + bias_V[threadIdx.x]; } -void add_QKV_bias_transpose_kernelLauncher( +void trt_add_QKV_bias_transpose_debug_kernelLauncher( const half* query_buf, const half* bias_Q, const half* key_buf, const half* bias_K, const 
half* value_buf, const half* bias_V, diff --git a/fastertransformer/cuda/open_attention.h b/fastertransformer/cuda/open_attention.h index 9b3e356bd..89a393e50 100644 --- a/fastertransformer/cuda/open_attention.h +++ b/fastertransformer/cuda/open_attention.h @@ -19,12 +19,15 @@ #pragma once -#include "fastertransformer/allocator.h" +#include "fastertransformer/utils/allocator.h" #include "fastertransformer/cuda/multi_head_attention.h" +#include "fastertransformer/cuda/attention_kernels.cuh" +#include "fastertransformer/cuda/transformer_kernels.cuh" #include "fastertransformer/cuda/cuda_kernels.h" #include "fastertransformer/cuda/cuda_int8_kernels.h" #include "fastertransformer/gemm_test/encoder_gemm_func.h" #include "fastertransformer/gemm_test/encoder_igemm_func.h" +#include "fastertransformer/utils/functions.h" #include #include #include @@ -33,6 +36,120 @@ namespace fastertransformer{ namespace cuda{ +void trt_add_QKV_bias_transpose_debug_kernelLauncher( + const half* query_buf, const half* bias_Q, + const half* key_buf, const half* bias_K, + const half* value_buf, const half* bias_V, + half* context_buf, + const int valid_word_num, + const int head_num, const int size_per_head, + cudaStream_t stream); // Used to debug the trt_add_QKV_bias kernel + +template +void add_QK_bias_transform_kernelLauncher(int8_t *q_buf, int8_t *k_buf, const int32_t* Q, const T* bias_Q, + const int32_t* K, const T* bias_K, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const float * q_weight_amax, const float *q_input_deQFactor_div127_ptr, + const float * k_weight_amax, const float *k_input_deQFactor_div127_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_QK_bias_transform_kernelLauncher(int8_t *q_buf, int8_t *k_buf, const int8_t* Q, const T* bias_Q, + const int8_t* K, const T* bias_K, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const float *q_input_deQFactor_ptr, const float *k_input_deQFactor_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_V_bias_transform_kernelLauncher(int8_t *v_buf, const int32_t *V, const T *V_bias, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float* weight_amax, + const float *input_deQFactor_div127_ptr, const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_V_bias_transform_kernelLauncher(int8_t *v_buf, const int8_t *V, const T *V_bias, const int batch_size, + const int seq_len, const int head_num, const int size_per_head, + const float *input_deQFactor_ptr, const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +void mappingRemovePaddingData_kernelLauncher(const int batch_size, const int seq_len, + const int valid_word_num, int *mapping, + const int* sequence_id_offset, cudaStream_t stream); + +template +void add_QK_bias_transform_rebuild_padding_kernelLauncher(int8_t *q_buf, int8_t *k_buf, + const int32_t* Q, const T* bias_Q, + const int32_t* K, const T* bias_K, + const int* sequence_id_offset, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float * q_weight_amax, + const float *q_input_deQFactor_div127_ptr, + const float * k_weight_amax, + const float *k_input_deQFactor_div127_ptr, + const float 
*q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_QK_bias_transform_rebuild_padding_kernelLauncher(int8_t *q_buf, int8_t *k_buf, const int8_t* Q, const T* bias_Q, + const int8_t* K, const T* bias_K, const int* sequence_id_offset, + const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float *q_deQFactor_ptr, const float *k_deQFactor_ptr, + const float *q_output_scale_ptr, const float *k_output_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_V_bias_transform_rebuild_padding_kernelLauncher(int8_t *v_buf, const int32_t *V, const T *V_bias, + const int* sequence_id_map, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float* weight_amax, + const float *input_deQFactor_div127_ptr, + const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void add_V_bias_transform_rebuild_padding_kernelLauncher(int8_t *v_buf, const int8_t *V, const T *V_bias, + const int* sequence_id_map, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const float *deQFactor_ptr, const float *out_scale_ptr, + bool use_ORDER_COL32_2R_4R4, cudaStream_t stream); + +template +void softmax_COL32_kernelLauncher(int8_t* qk_buf, const int32_t* qk_int_buf, const T* attr_mask, + const int batch_size, const int head_num, const int seq_len, + const float scalar1a, const float *scalar1b, const float *scalar1c, + const float *amax_ptr, cudaStream_t stream); + +template +void softmax_COL32_kernelLauncher(int8_t* qk_buf, const int8_t* qk_int_buf, const T* attr_mask, + const int batch_size, const int head_num, const int seq_len, + const float scalar1a, const float *scalar1b, const float *amax_ptr, + cudaStream_t stream); + +template +void add_QKV_bias_rebuild_padding_kernelLauncher(T* Q, const T* bias_Q, T* K, const T* bias_K, + T* V, const T* bias_V, T* q_buf, T* k_buf, T* v_buf, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, const int valid_word_num, + const int* mask_offset, cudaStream_t stream); + +template +void transpose_kernelLauncher(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head, cudaStream_t stream); + +template +void transpose_rebuild_padding_kernelLauncher(T* src, T* dst, const int valid_word_num, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const int* mask_offset, cudaStream_t stream); + template class OpenMultiHeadAttentionTraits; @@ -42,6 +159,7 @@ class OpenMultiHeadAttentionTraits public: typedef float DataType; static cudaDataType_t const computeType = CUDA_R_32F; + static cudaDataType_t const scaleType = CUDA_R_32F; static cudaDataType_t const AType = CUDA_R_32F; static cudaDataType_t const BType = CUDA_R_32F; static cudaDataType_t const CType = CUDA_R_32F; @@ -54,6 +172,7 @@ class OpenMultiHeadAttentionTraits public: typedef half DataType; static cudaDataType_t const computeType = CUDA_R_16F; + static cudaDataType_t const scaleType = CUDA_R_16F; static cudaDataType_t const AType = CUDA_R_16F; static cudaDataType_t const BType = CUDA_R_16F; static cudaDataType_t const CType = CUDA_R_16F; @@ -76,10 +195,10 @@ class OpenMultiHeadAttention: IMultiHeadAttention IAllocator* allocator_ = NULL; MultiHeadInitParam param_; - int 
cublasAlgo_[4]; - std::map cublasLtAlgoMap_; - std::map cublasAlgoMap_; - std::map isFuseQKVMap_; + //algo for batch matrix multiplication in unfused mha + int cublasBmmAlgo_[2]; + std::map cublasAlgoMap_; + std::map parameterMap_; bool is_fuse_QKV_; DataType_* buf_ = NULL; @@ -96,19 +215,22 @@ class OpenMultiHeadAttention: IMultiHeadAttention DataType_** qkv_input_; DataType_** qkv_buf_; + void* cublas_workspace_; + void* trt_attn_workspace_; const float *query_weight_amax_list, *key_weight_amax_list, *value_weight_amax_list; + int sm_; int batch_size_; int from_seq_len_; int to_seq_len_; int head_num_; int size_per_head_; - int mSM_; //int8_mode == 0 -- not use int8 - //int8_mode == 1 -- use int8 without quantized residual - //int8_mode == 2 -- use int8 with quantized residual + //int8_mode == 1 -- use int8; without quantized residual; when (batch*seqLen >= 512) or (seqLen % 32 !=0 ), using trt fused mha + //int8_mode == 2 -- use int8; with quantized residual; with trt fused mha + //int8_mode == 3 -- use int8; with quantized residual; without trt fused mha int int8_mode_ = 0; int* sequence_id_map_; int* Q_int_buf_; @@ -119,140 +241,70 @@ class OpenMultiHeadAttention: IMultiHeadAttention bool allow_gemm_test_ = false; bool use_ORDER_COL32_2R_4R4_ = false; - std::unique_ptr dispatcher_fp16; + std::unique_ptr dispatcher_fp16, dispatcher_int8; public: - void readAlgoFromConfig(int int8_mode, int batch_size, int seq_len, int head_num, int size_per_head) + void getCublasBmmAlgoFromMap() { - if (int8_mode != 0) - { - cublasLtAlgoMap_.clear(); - FILE* fd = fopen(IGEMM_CONFIG, "r"); - if (fd == NULL) - return; - int batchCount2, m2, n2, k2, algoId, customOption, tile, splitK_val, swizzle, reductionScheme, workspaceSize, stages; - int batch_size, seq_len, head_num, size_per_head; - while(fscanf(fd,"%d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d\n", &batch_size, &seq_len, &head_num, &size_per_head, &batchCount2, &m2, &n2, &k2, &algoId, &customOption, &tile, &splitK_val, &swizzle, &reductionScheme, &workspaceSize, &stages)!=EOF) - { - char mark[256]; - sprintf(mark, "%d_%d_%d_%d", batchCount2, m2, n2, k2); - std::string markStr(mark); - //workspaceSize should be zero - if (cublasLtAlgoMap_.find(markStr) == cublasLtAlgoMap_.end() && workspaceSize == 0) - { - cublasLtAlgoMap_[markStr].algoId = algoId; - cublasLtAlgoMap_[markStr].customOption = customOption; - cublasLtAlgoMap_[markStr].tile = tile; - cublasLtAlgoMap_[markStr].splitK_val = splitK_val; - cublasLtAlgoMap_[markStr].swizzle = swizzle; - cublasLtAlgoMap_[markStr].reductionScheme = reductionScheme; - cublasLtAlgoMap_[markStr].workspaceSize = workspaceSize; - cublasLtAlgoMap_[markStr].stages = stages; - } - } - fclose(fd); - } + int batchCount, m, n, k, dataType; + if (std::is_same::value) + dataType = HALF_DATATYPE; else + dataType = FLOAT_DATATYPE; + //bmm1 + batchCount = batch_size_*head_num_; + m = from_seq_len_; + n = from_seq_len_; + k = size_per_head_; + char mark[256]; + sprintf(mark, "%d_%d_%d_%d_%d", batchCount, n, m, k, dataType); + if (cublasAlgoMap_.find(mark) != cublasAlgoMap_.end()) { - cublasAlgoMap_.clear(); - isFuseQKVMap_.clear(); - FILE* fd = fopen(GEMM_CONFIG, "r"); - if (fd == NULL) - return; - int batchCount2, m2, n2, k2, is_fp16, algoId; - int batch_size, seq_len, head_num, size_per_head; - float runtime; - while(fscanf(fd,"%d %d %d %d ### %d %d %d %d %d %d %f\n", &batch_size, &seq_len, &head_num, &size_per_head, &batchCount2, &m2, &n2, &k2, &is_fp16, &algoId, &runtime)!=EOF){ - char mark[256]; - sprintf(mark, 
"%d_%d_%d_%d_%d", batchCount2, m2, n2, k2, is_fp16); - std::string markStr(mark); - cublasAlgoMap_[markStr] = algoId; - if (batchCount2 == 1 || batchCount2 == 3) - isFuseQKVMap_[markStr] = runtime; - } - fclose(fd); + cublasBmmAlgo_[0] = cublasAlgoMap_[mark].algoId; } - } - - void getBestAlgoFromMap(int batch_size, int seq_len, int head_num, int size_per_head, int is_fp16) - { - int m = batch_size * seq_len; - int n = head_num * size_per_head; - int k = n; - char mark[256]; - int foundAlgo = 0; - float split_time = -1.0, fuse_time = -1.0; - sprintf(mark, "1_%d_%d_%d_%d", m, n, k, is_fp16); - std::string markStr(mark); - if (cublasAlgoMap_.find(markStr) != cublasAlgoMap_.end()) + else { - cublasAlgo_[0] = cublasAlgoMap_[markStr]; - foundAlgo += 1; - if (isFuseQKVMap_.find(markStr) != isFuseQKVMap_.end()) - split_time = isFuseQKVMap_[markStr]; + cublasBmmAlgo_[0] = dataType == FLOAT_DATATYPE ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP; } - if (foundAlgo == 1) + //bmm2 + batchCount = batch_size_*head_num_; + m = from_seq_len_; + n = size_per_head_; + k = from_seq_len_; + sprintf(mark, "%d_%d_%d_%d_%d", batchCount, n, m, k, dataType); + if (cublasAlgoMap_.find(mark) != cublasAlgoMap_.end()) { - sprintf(mark, "%d_%d_%d_%d_%d", batch_size_*head_num_, from_seq_len_, from_seq_len_, size_per_head_, is_fp16); - std::string markStr(mark); - if (cublasAlgoMap_.find(markStr) != cublasAlgoMap_.end()) - { - cublasAlgo_[1] = cublasAlgoMap_[markStr]; - foundAlgo += 1; - } - if (foundAlgo == 2) - { - sprintf(mark, "%d_%d_%d_%d_%d", batch_size_*head_num_, from_seq_len_, size_per_head_, from_seq_len_, is_fp16); - std::string markStr(mark); - if (cublasAlgoMap_.find(markStr) != cublasAlgoMap_.end()) - { - cublasAlgo_[2] = cublasAlgoMap_[markStr]; - foundAlgo += 1; - } - if (foundAlgo == 3) - { - sprintf(mark, "3_%d_%d_%d_%d", m, n, k, is_fp16); - std::string markStr(mark); - if (cublasAlgoMap_.find(markStr) != cublasAlgoMap_.end()) - { - cublasAlgo_[3] = cublasAlgoMap_[markStr]; - foundAlgo += 1; - if (isFuseQKVMap_.find(markStr) != isFuseQKVMap_.end()) - fuse_time = isFuseQKVMap_[markStr]; - } - } - } + cublasBmmAlgo_[1] = cublasAlgoMap_[mark].algoId; } - if(foundAlgo != 4) + else { - printf("[WARNING][OpenMultiHeadAttention] Loading GEMM algorithms error, using default GEMM algorithms!\n"); - if(is_fp16 == 0) - { - cublasAlgo_[0] = -1; - cublasAlgo_[1] = -1; - cublasAlgo_[2] = -1; - cublasAlgo_[3] = -1; - } - else - { - cublasAlgo_[0] = 99; - cublasAlgo_[1] = 99; - cublasAlgo_[2] = 99; - cublasAlgo_[3] = 99; - } - is_fuse_QKV_ = false; + cublasBmmAlgo_[1] = dataType == FLOAT_DATATYPE ? 
CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP; } + } + + void judgeFusedQKV() + { + is_fuse_QKV_ = false; + int m, n, k, dataType; + if (std::is_same::value) + dataType = HALF_DATATYPE; else + dataType = FLOAT_DATATYPE; + + m = batch_size_*from_seq_len_; + n = head_num_*size_per_head_; + k = head_num_*size_per_head_; + char mark[256], mark2[256]; + sprintf(mark, "1_%d_%d_%d_%d", n, m, k, dataType); + sprintf(mark2, "3_%d_%d_%d_%d", n, m, k, dataType); + if ( + cublasAlgoMap_.find(mark) != cublasAlgoMap_.end() && + cublasAlgoMap_.find(mark2) != cublasAlgoMap_.end() && + 3*cublasAlgoMap_[mark].exec_time > cublasAlgoMap_[mark2].exec_time + ) { - is_fuse_QKV_ = false; - if ((split_time > 0) && - (fuse_time > 0) && - (3*split_time > fuse_time) - ) - { is_fuse_QKV_ = true; - } } } @@ -274,9 +326,39 @@ class OpenMultiHeadAttention: IMultiHeadAttention } } + size_t get_workspace_size() + { + size_t size = 0; + + const int buf_size = batch_size_ * head_num_ * from_seq_len_ * size_per_head_; + const int qk_buf_size = batch_size_ * head_num_ * from_seq_len_ * from_seq_len_; + const int seq_len_padded = (from_seq_len_ + 31)/32*32; + const int padded_buf_size = batch_size_ * head_num_ * seq_len_padded * size_per_head_; + const int padded_qk_buf_size = batch_size_ * head_num_ * seq_len_padded * seq_len_padded; + + if(int8_mode_ != 0) + { + //query_buf_(Q_int_buf_) key_buf_(K_int_buf_) value_buf_(V_int_buf_) qk_int_buf_ transpose_dst_(transpose_dst_int_buf_) + size = sizeof(int) * (4*buf_size + padded_qk_buf_size) + + //int8 q_buf_ k_buf_ v_buf_ qk_buf_ + sizeof(int8_t) * (3*padded_buf_size + padded_qk_buf_size) + + //sequence_id_map + (batch_size_*from_seq_len_)*sizeof(int) + + //trt_attn_workspace_ + (dispatcher_int8.get() ? dispatcher_int8->getWorkspaceSize() : 0); + + } + else + { + size = sizeof(DataType_) * (buf_size * 7 + qk_buf_size) + sizeof(DataType_*) * 9 + + (dispatcher_fp16.get() ? 
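getCublasBmmAlgoFromMap() and judgeFusedQKV() above both consult cublasAlgoMap_, keyed by a "batchCount_n_m_k_dataType" string read from the GEMM config. A minimal stand-in for those two lookups; AlgoInfo and the helper names are placeholders rather than the repository's real types:

#include <cstdio>
#include <map>
#include <string>

struct AlgoInfo { int algoId; float exec_time; };
using AlgoMap = std::map<std::string, AlgoInfo>;

static std::string algo_key(int batch_count, int n, int m, int k, int data_type)
{
    char mark[256];
    std::snprintf(mark, sizeof(mark), "%d_%d_%d_%d_%d", batch_count, n, m, k, data_type);
    return mark;
}

// Batched-GEMM algo lookup with a fallback, as in getCublasBmmAlgoFromMap().
static int pick_bmm_algo(const AlgoMap& algos, int batch_count, int n, int m, int k,
                         int data_type, int default_algo)
{
    auto it = algos.find(algo_key(batch_count, n, m, k, data_type));
    return it != algos.end() ? it->second.algoId : default_algo;
}

// Fuse the Q/K/V projections into one batchCount == 3 GEMM only when three single GEMMs would
// be slower, i.e. 3 * t(single) > t(fused); e.g. 3 * 0.20 ms > 0.50 ms, so fuse.
static bool should_fuse_qkv(const AlgoMap& algos, int n, int m, int k, int data_type)
{
    auto split = algos.find(algo_key(1, n, m, k, data_type));
    auto fused = algos.find(algo_key(3, n, m, k, data_type));
    return split != algos.end() && fused != algos.end()
        && 3.0f * split->second.exec_time > fused->second.exec_time;
}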
dispatcher_fp16->getWorkspaceSize() : 0); + } + return size; + } + //allocate buffer for OpenMultiHeadAttention //read config again if hasChangedConfig == true - void allocateBuffer(IAllocator* allocator, int batch_size, int from_seq_len, int to_seq_len, + void allocateBuffer(IAllocator* allocator, void* cublas_workspace, int batch_size, int from_seq_len, int to_seq_len, int head_num, int size_per_head, bool hasChangedConfig, bool use_trt_kernel) { #ifndef NDEBUG @@ -287,9 +369,7 @@ class OpenMultiHeadAttention: IMultiHeadAttention printf("[ERROR][OpenMultiHeadAttention][allocateBuffer] allocator == NULL!\n"); exit(-1); } - - int buf_size = batch_size_ * head_num_ * from_seq_len_ * size_per_head_; - int qk_buf_size = batch_size_ * head_num_ * from_seq_len_ * from_seq_len_; + try { //only allocate new buffer when buf_ is empty @@ -307,20 +387,28 @@ class OpenMultiHeadAttention: IMultiHeadAttention to_seq_len_ = to_seq_len; head_num_ = head_num; size_per_head_ = size_per_head; + cublas_workspace_ = cublas_workspace; + + const int buf_size = batch_size_ * head_num_ * from_seq_len_ * size_per_head_; + const int qk_buf_size = batch_size_ * head_num_ * from_seq_len_ * from_seq_len_; - int buf_size = batch_size_ * head_num_ * from_seq_len_ * size_per_head_; - int qk_buf_size = batch_size_ * head_num_ * from_seq_len_ * from_seq_len_; if (int8_mode_ != 0) { - buf_ = (DataType_*) allocator_->malloc( - //query_buf_(Q_int_buf_) key_buf_(K_int_buf_) value_buf_(V_int_buf_) qk_int_buf_ transpose_dst_(transpose_dst_int_buf_) - sizeof(int) * (4*buf_size + qk_buf_size) + - //q_buf_ k_buf_ v_buf_ - sizeof(DataType_) * (3*buf_size + qk_buf_size) + - //for fused qkv pointer - sizeof(DataType_*) * 9 + - //sequence_id_map - (batch_size_*from_seq_len_)*sizeof(int), false); + if ((int8_mode_ == 1 && (batch_size_*from_seq_len_ >= 512 || (from_seq_len_ % 32 != 0))) || int8_mode_ == 2) + { + if (use_trt_kernel && (sm_ == kSM_86 || sm_ == kSM_80 || sm_ == kSM_75 || sm_ == kSM_72) && size_per_head_ == 64) + { + //try + { + dispatcher_int8.reset(new FusedMHARunnerInt8v2(head_num_, size_per_head_, sm_)); + } + } + } + const int seq_len_padded = (from_seq_len_ + 31)/32*32; + const int padded_buf_size = batch_size_ * head_num_ * seq_len_padded * size_per_head_; + const int padded_qk_buf_size = batch_size_ * head_num_ * seq_len_padded * seq_len_padded; + + buf_ = (DataType_*) allocator_->malloc(get_workspace_size(), false); if (buf_ == NULL) throw std::runtime_error(std::string("Allocator failed to allocate internal buffer.")); Q_int_buf_ = (int *)(buf_); @@ -328,26 +416,19 @@ class OpenMultiHeadAttention: IMultiHeadAttention V_int_buf_ = K_int_buf_ + buf_size; transpose_dst_int_buf_ = V_int_buf_ + buf_size; qk_int_buf_ = transpose_dst_int_buf_ + buf_size; - q_buf_ = (DataType_*)(qk_int_buf_ + qk_buf_size); - k_buf_ = q_buf_ + buf_size; - v_buf_ = k_buf_ + buf_size; - qk_buf_ = v_buf_ + buf_size; - qkv_kernel_ = (DataType_**)(qk_buf_ + qk_buf_size); - qkv_input_ = qkv_kernel_ + 3; - qkv_buf_ = qkv_input_ + 3; - sequence_id_map_ = (int*)(qkv_buf_ + 3); - if (int8_mode_ == 2) { - K_int_buf_ = (int*)((int8_t*)Q_int_buf_ + buf_size); - V_int_buf_ = (int*)((int8_t*)K_int_buf_ + buf_size); - } + q_buf_ = (DataType_*)(qk_int_buf_ + padded_qk_buf_size); + //the actual size is calculated with int8_t datatype + k_buf_ = (DataType_*)((int8_t*)q_buf_ + padded_buf_size); + v_buf_ = (DataType_*)((int8_t*)k_buf_ + padded_buf_size); + qk_buf_ = (DataType_*)((int8_t*)v_buf_ + padded_buf_size); + sequence_id_map_ = (int*)((int8_t*)qk_buf_ + 
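The INT8 branch of get_workspace_size() above can be summarized by the following standalone sketch; the helper is hypothetical, and trt_ws_bytes stands for dispatcher_int8->getWorkspaceSize():

#include <cstddef>
#include <cstdint>

static size_t int8_attention_workspace_bytes(int batch, int heads, int seq_len,
                                             int size_per_head, size_t trt_ws_bytes)
{
    const int    seq_pad    = (seq_len + 31) / 32 * 32;
    const size_t buf        = (size_t)batch * heads * seq_len * size_per_head;
    const size_t pad_buf    = (size_t)batch * heads * seq_pad * size_per_head;
    const size_t pad_qk_buf = (size_t)batch * heads * seq_pad * seq_pad;
    return sizeof(int)    * (4 * buf + pad_qk_buf)      // Q/K/V/transpose int32 buffers + qk_int_buf_
         + sizeof(int8_t) * (3 * pad_buf + pad_qk_buf)  // q_buf_, k_buf_, v_buf_ + qk_buf_
         + sizeof(int)    * (size_t)batch * seq_len     // sequence_id_map_
         + trt_ws_bytes;                                // trt_attn_workspace_
}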
padded_qk_buf_size); + trt_attn_workspace_ = (void*)(sequence_id_map_ + (batch_size_*from_seq_len_)); } else { - if (use_trt_kernel && (mSM_ == kSM_86 || mSM_ == kSM_80 || mSM_ == kSM_75 || mSM_ == kSM_72) && size_per_head_ == 64) - dispatcher_fp16.reset(new FusedMHARunnerFP16v2(head_num_, size_per_head_, mSM_)); - buf_ = (DataType_*) allocator_->malloc(sizeof(DataType_) * (buf_size * 7 + qk_buf_size) + - sizeof(DataType_*) * 9 + - (dispatcher_fp16.get() ? dispatcher_fp16->getWorkspaceSize() : 0), false); + if (use_trt_kernel && (sm_ == kSM_70 || sm_ == kSM_86 || sm_ == kSM_80 || sm_ == kSM_75 || sm_ == kSM_72) && size_per_head_ == 64) + dispatcher_fp16.reset(new FusedMHARunnerFP16v2(head_num_, size_per_head_, sm_)); + buf_ = (DataType_*) allocator_->malloc(get_workspace_size(), false); if (buf_ == NULL) throw std::runtime_error(std::string("Allocator failed to allocate internal buffer.")); query_buf_ = buf_; @@ -369,32 +450,27 @@ class OpenMultiHeadAttention: IMultiHeadAttention //if config changes, read config again if (hasChangedConfig) { + int isConfigExist = -1; if (int8_mode_ != 0) { - int isConfigExist = access(IGEMM_CONFIG, 0); - if (isConfigExist == -1) - printf("[WARNING][OpenMultiHeadAttention] %s is not found; using default GEMM algo\n", IGEMM_CONFIG); - else - { - readAlgoFromConfig(int8_mode_, batch_size_, from_seq_len_, head_num_, size_per_head_); - } + isConfigExist = access(IGEMM_CONFIG, 0); } else { - int isConfigExist = access(GEMM_CONFIG, 0); - if (isConfigExist == -1) - printf("[WARNING][OpenMultiHeadAttention] %s is not found; using default GEMM algo\n", GEMM_CONFIG); - else - { - readAlgoFromConfig(int8_mode_, batch_size_, from_seq_len_, head_num_, size_per_head_); - } + isConfigExist = access(GEMM_CONFIG, 0); + } + if (isConfigExist == -1) + printf("[WARNING][OpenMultiHeadAttention] %s is not found; using default GEMM algo\n", int8_mode_ != 0 ? IGEMM_CONFIG : GEMM_CONFIG); + else + { + readAlgoFromConfig(int8_mode_, cublasAlgoMap_, parameterMap_, false); } } if (int8_mode_ == 0) { - int is_fp16 = (sizeof(DataType_) == sizeof(half) ? 1 : 0); - getBestAlgoFromMap(batch_size_, from_seq_len_, head_num_, size_per_head_, is_fp16); + getCublasBmmAlgoFromMap(); + judgeFusedQKV(); } } catch(std::runtime_error& error) @@ -404,49 +480,308 @@ class OpenMultiHeadAttention: IMultiHeadAttention } //Ctor - OpenMultiHeadAttention(int int8_mode=0, bool allow_gemm_test=false, bool use_ORDER_COL32_2R_4R4=false) : - int8_mode_(int8_mode), allow_gemm_test_(allow_gemm_test), use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) + OpenMultiHeadAttention(int int8_mode=0, bool allow_gemm_test=false, bool use_ORDER_COL32_2R_4R4=false, int sm = 75) : + int8_mode_(int8_mode), allow_gemm_test_(allow_gemm_test), use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4), sm_(sm) { #ifndef NDEBUG PRINT_FUNC_NAME_(); #endif - mSM_ = getSMVersion(); + //sm_ = getSMVersion(); - int is_fp16 = (sizeof(DataType_) == sizeof(half) ? 1 : 0); try { - if (int8_mode_ != 0){ - int isConfigExist = access(IGEMM_CONFIG, 0); + int isConfigExist = -1; + if (int8_mode_ != 0) + isConfigExist = access(IGEMM_CONFIG, 0); + else + isConfigExist = access(GEMM_CONFIG, 0); - if (isConfigExist == -1) + if (isConfigExist == -1) + { + if (!allow_gemm_test_) { - if (!allow_gemm_test_) - { - printf("[WARNING][OpenMultiHeadAttention] %s is not found; using default GEMM algo\n", IGEMM_CONFIG); - } + printf("[WARNING][OpenMultiHeadAttention] %s is not found; using default GEMM algo\n", int8_mode_ != 0 ? 
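The conditions under which allocateBuffer() above creates the TensorRT fused-MHA runners can be restated as small predicates; these helpers are illustrative only, and the kSM_* constants and runner classes are the ones already referenced in this diff:

// INT8 path: FusedMHARunnerInt8v2 is created for int8_mode 2, or for int8_mode 1 when the
// workload is large or seq_len is not a multiple of 32; int8_mode 3 never uses the fused runner.
static bool use_int8_fused_mha(int int8_mode, int batch_size, int seq_len,
                               int size_per_head, int sm, bool use_trt_kernel)
{
    const bool sm_ok   = (sm == kSM_72 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86);
    const bool mode_ok = (int8_mode == 1 && (batch_size * seq_len >= 512 || seq_len % 32 != 0))
                       || int8_mode == 2;
    return use_trt_kernel && sm_ok && size_per_head == 64 && mode_ok;
}

// Non-INT8 path: FusedMHARunnerFP16v2 additionally accepts kSM_70.
static bool use_fp16_fused_mha(int size_per_head, int sm, bool use_trt_kernel)
{
    const bool sm_ok = (sm == kSM_70 || sm == kSM_72 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86);
    return use_trt_kernel && sm_ok && size_per_head == 64;
}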
IGEMM_CONFIG : GEMM_CONFIG); } - else + } + else + { + readAlgoFromConfig(int8_mode_, cublasAlgoMap_, parameterMap_, false); + } + } + catch(std::runtime_error& error) + { + throw error; + } + } + + OpenMultiHeadAttention(const OpenMultiHeadAttention *attention) + { +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + sm_ = attention->sm_; + int8_mode_ = attention->int8_mode_; + allow_gemm_test_ = attention->allow_gemm_test_; + + for(int i = 0; i < 2; i++) cublasBmmAlgo_[i] = attention->cublasBmmAlgo_[i]; + cublasAlgoMap_ = attention->cublasAlgoMap_; + parameterMap_ = attention->parameterMap_; + is_fuse_QKV_ = attention->is_fuse_QKV_; + use_ORDER_COL32_2R_4R4_ = attention->use_ORDER_COL32_2R_4R4_; + } + + void multiHeadAttr_nofuse_kernelLauncher( + cudaStream_t stream, + cublasHandle_t cublas_handle, + cublasLtHandle_t cublaslt_handle, + DataType_* Q, + const DataType_* bias_Q, + DataType_* K, + const DataType_* bias_K, + DataType_* V, + const DataType_* bias_V, + const DataType_* attr_mask, + DataType_* dst, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int int8_mode_, + const DataType_ scalar) + { + const int k = head_num * size_per_head; + + if (int8_mode_ != 0) + { + //var for int8 + const float*Qbias_amax_ptr, *Kbias_amax_ptr, *Vbias_amax_ptr, *bmm1_amax_ptr, *Softmax_amax_ptr, *bmm2_amax_ptr, *in_amax_ptr, *Q_aftergemm_amax_ptr, *K_aftergemm_amax_ptr, *V_aftergemm_amax_ptr; + Qbias_amax_ptr = param_.amaxList + 8; + Kbias_amax_ptr = param_.amaxList + 16; + Vbias_amax_ptr = param_.amaxList + 24; + Softmax_amax_ptr = param_.amaxList + 32; + bmm2_amax_ptr = param_.amaxList + 36; + Q_aftergemm_amax_ptr = param_.amaxList + 4; + K_aftergemm_amax_ptr = param_.amaxList + 12; + V_aftergemm_amax_ptr = param_.amaxList + 20; + bmm1_amax_ptr = param_.amaxList + 28; + in_amax_ptr = param_.amaxList; + + if (size_per_head % 32 != 0) + { + printf("[ERROR][FT][multiHeadAttr_nofuse_kernelLauncher] int8 unfused mha kernel only works when size_per_head %% 32 == 0.\n"); + exit(-1); + } + if ((seq_len % 32 != 0) && int8_mode_ == 1) + { + printf("[ERROR][FT][multiHeadAttr_nofuse_kernelLauncher] int8 mode1 unfused mha kernel only works when seq_len %% 32 == 0.\n"); + exit(-1); + } + const int seq_len_padded = (seq_len + 31)/32*32; + + if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) + { + if (int8_mode_ == 1) + { + add_QK_bias_transform_kernelLauncher((int8_t*)q_buf_, (int8_t*)k_buf_, + (const int32_t*) Q, bias_Q, (const int32_t*) K, bias_K, + batch_size, seq_len, head_num, size_per_head, + query_weight_amax_list, in_amax_ptr+2, + key_weight_amax_list, in_amax_ptr+2, + Qbias_amax_ptr+3, Kbias_amax_ptr+3, + use_ORDER_COL32_2R_4R4_, stream); + add_V_bias_transform_kernelLauncher((int8_t*)v_buf_, + (const int32_t *)V, bias_V, + batch_size, seq_len, head_num, size_per_head, + value_weight_amax_list, in_amax_ptr+2, Vbias_amax_ptr+3, + use_ORDER_COL32_2R_4R4_, stream); + } + else if (int8_mode_ == 2 || int8_mode_ == 3) { - readAlgoFromConfig(int8_mode_, batch_size_, from_seq_len_, head_num_, size_per_head_); + add_QK_bias_transform_kernelLauncher((int8_t*)q_buf_, (int8_t*)k_buf_, + (const int8_t*) Q, bias_Q, (const int8_t*)K, bias_K, + batch_size, seq_len, head_num, size_per_head, + Q_aftergemm_amax_ptr+1, K_aftergemm_amax_ptr+1, + Qbias_amax_ptr+3, Kbias_amax_ptr+3, + use_ORDER_COL32_2R_4R4_, stream); + add_V_bias_transform_kernelLauncher((int8_t*)v_buf_, (const int8_t *)V, bias_V, + batch_size, seq_len, head_num, 
size_per_head, + V_aftergemm_amax_ptr+1, Vbias_amax_ptr+3, + use_ORDER_COL32_2R_4R4_, stream); } } else{ - int isConfigExist = access(GEMM_CONFIG, 0); + mappingRemovePaddingData_kernelLauncher(batch_size, seq_len, + param_.valid_word_num, sequence_id_map_, + param_.sequence_id_offset, stream); + // if we use remove padding, then initialize the q_buf_, k_buf_ and v_buf_ to prevent bugs. v_buf_ will be properly initiaized in add_V_bias_transform_rebuild_padding_kernelLauncher() + cudaMemsetAsync((int8_t*)q_buf_, 0, 2 * batch_size_ * seq_len_padded * head_num * size_per_head * sizeof(int8_t), param_.stream); + if (int8_mode_ == 1) + { + + add_QK_bias_transform_rebuild_padding_kernelLauncher((int8_t*)q_buf_, (int8_t*)k_buf_, + (const int32_t*)Q, bias_Q, + (const int32_t*)K, bias_K, + param_.sequence_id_offset, param_.valid_word_num, + batch_size, seq_len, + head_num, size_per_head, + query_weight_amax_list, in_amax_ptr+2, + key_weight_amax_list, in_amax_ptr+2, + Qbias_amax_ptr+3, Kbias_amax_ptr+3, + use_ORDER_COL32_2R_4R4_, stream); - if (isConfigExist == -1) + add_V_bias_transform_rebuild_padding_kernelLauncher((int8_t*)v_buf_, (const int32_t *)V, bias_V, + sequence_id_map_, param_.valid_word_num, + batch_size, seq_len, head_num, size_per_head, + value_weight_amax_list, in_amax_ptr+2, Vbias_amax_ptr+3, + use_ORDER_COL32_2R_4R4_, stream); + } + else if (int8_mode_ == 2 || int8_mode_ == 3) { - if (!allow_gemm_test_) - { - printf("[WARNING][OpenMultiHeadAttention] %s is not found; using default GEMM algo\n", GEMM_CONFIG); - } + add_QK_bias_transform_rebuild_padding_kernelLauncher((int8_t*)q_buf_, (int8_t*)k_buf_, + (const int8_t*)Q, bias_Q, + (const int8_t*)K, bias_K, + param_.sequence_id_offset, param_.valid_word_num, + batch_size, seq_len, head_num, size_per_head, + Q_aftergemm_amax_ptr+1, K_aftergemm_amax_ptr+1, + Qbias_amax_ptr+3, Kbias_amax_ptr+3, + use_ORDER_COL32_2R_4R4_, stream); + + add_V_bias_transform_rebuild_padding_kernelLauncher((int8_t*)v_buf_, (const int8_t *)V, bias_V, + sequence_id_map_, param_.valid_word_num, + batch_size, seq_len, head_num, size_per_head, + V_aftergemm_amax_ptr+1, Vbias_amax_ptr+3, + use_ORDER_COL32_2R_4R4_, stream); + } + } + + int batchCount = batch_size * head_num; + + if (int8_mode_ == 1) + { + cublasLtMM_withAlgo(qk_int_buf_, batchCount, seq_len, seq_len, size_per_head, + size_per_head*seq_len, size_per_head*seq_len, seq_len*seq_len, + (int8_t*)q_buf_, (int8_t*)k_buf_, cublaslt_handle, stream, cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); + + softmax_COL32_kernelLauncher((int8_t*)qk_buf_, qk_int_buf_, attr_mask, + batch_size, head_num, seq_len, + float(scalar), Qbias_amax_ptr + 1, Kbias_amax_ptr + 1, + Softmax_amax_ptr, stream); + + cublasLtMM_withAlgo(transpose_dst_int_buf_, batchCount, seq_len, size_per_head, seq_len, + seq_len*seq_len, size_per_head*seq_len, size_per_head*seq_len, (int8_t*)qk_buf_, + (int8_t*)v_buf_, cublaslt_handle, stream, cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); + + if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) + { + transpose_COL32_kernelLauncher((int8_t*)dst, (const int*)transpose_dst_int_buf_, batch_size, seq_len, head_num, + size_per_head, Vbias_amax_ptr+1, Softmax_amax_ptr+1, bmm2_amax_ptr+3, stream); + } + else + { + transpose_COL32_rebuild_padding_kernelLauncher((int8_t*)dst, (const int*)transpose_dst_int_buf_, sequence_id_map_, + param_.valid_word_num, batch_size, seq_len, head_num, size_per_head, + Vbias_amax_ptr+1, Softmax_amax_ptr+1, bmm2_amax_ptr+3, stream); + } + } + else if 
(int8_mode_ == 2 || int8_mode_ == 3) + { + cublasLtMM_withAlgo_int8IO((int8_t*)qk_int_buf_, batchCount, seq_len, seq_len_padded, size_per_head, + size_per_head*seq_len, size_per_head*seq_len_padded, seq_len*seq_len_padded, + param_.int8O_gemm_deQ_scale_list[3], + (int8_t*)q_buf_, (int8_t*)k_buf_, cublaslt_handle, stream, cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); + + softmax_COL32_kernelLauncher((int8_t*)qk_buf_, (int8_t*)qk_int_buf_, attr_mask, + batch_size, head_num, seq_len, + float(scalar), bmm1_amax_ptr + 1, Softmax_amax_ptr, + stream); + + cublasLtMM_withAlgo_int8IO((int8_t*)transpose_dst_int_buf_, batchCount, seq_len, size_per_head, seq_len_padded, + seq_len*seq_len_padded, size_per_head*seq_len_padded, size_per_head*seq_len, param_.int8O_gemm_deQ_scale_list[4], (int8_t*)qk_buf_, + (int8_t*)v_buf_, cublaslt_handle, stream, cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); + if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) + { + transpose_COL32_kernelLauncher((int8_t*)dst, (const int8_t*)transpose_dst_int_buf_, batch_size, seq_len, head_num, + size_per_head, bmm2_amax_ptr+1, bmm2_amax_ptr+3, stream); } else { - readAlgoFromConfig(int8_mode_, batch_size_, from_seq_len_, head_num_, size_per_head_); + transpose_COL32_rebuild_padding_kernelLauncher((int8_t*)dst, (const int8_t*)transpose_dst_int_buf_, sequence_id_map_, + param_.valid_word_num, batch_size, seq_len, head_num, size_per_head, + bmm2_amax_ptr+1, bmm2_amax_ptr+3, stream); } + } + } + //FP32/FP16 + else + { + if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) + { + add_QKV_bias_transpose_kernelLauncher(q_buf_, k_buf_, v_buf_, + Q, bias_Q, + K, bias_K, + V, bias_V, + batch_size_, seq_len, + head_num, + size_per_head, stream); + } + else + { + // if we use remove padding, then initialize the q_buf_, k_buf_ and v_buf_ to prevent bugs. 
+ cudaMemsetAsync(q_buf_, 0, 3 * batch_size_ * seq_len * head_num * size_per_head * sizeof(DataType_), param_.stream); - + add_QKV_bias_rebuild_padding_kernelLauncher(Q, bias_Q, K, bias_K, V, bias_V, q_buf_, k_buf_, v_buf_, + batch_size, seq_len, head_num, size_per_head, param_.valid_word_num, param_.sequence_id_offset, stream); + } + + DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; + + check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + seq_len, seq_len, size_per_head, + &alpha, + k_buf_, AType_, size_per_head, seq_len * size_per_head, + q_buf_, BType_, size_per_head, seq_len * size_per_head, + &beta, + qk_buf_, CType_, seq_len, seq_len * seq_len, + batch_size * head_num, + computeType_, + static_cast(cublasBmmAlgo_[0]))); + + attn_softmax_kernelLauncher(qk_buf_, attr_mask, batch_size, seq_len, head_num, scalar, stream); + + check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + size_per_head, seq_len, seq_len, + &alpha, + v_buf_, AType_, size_per_head, seq_len * size_per_head, + qk_buf_, BType_, seq_len, seq_len * seq_len, + &beta, + transpose_dst_, CType_, size_per_head, seq_len * size_per_head, + batch_size * head_num, + computeType_, + static_cast(cublasBmmAlgo_[1]))); + + if(param_.sequence_id_offset == nullptr || param_.valid_word_num == batch_size * seq_len) + { + transpose_kernelLauncher(transpose_dst_, dst, batch_size, seq_len, head_num, size_per_head, stream); + } + else + { + transpose_rebuild_padding_kernelLauncher(transpose_dst_, dst, param_.valid_word_num, + batch_size, seq_len, head_num, size_per_head, + param_.sequence_id_offset, stream); } + } + } + + void forward() + { +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + try + { + forward(param_.from_tensor, param_.to_tensor); } catch(std::runtime_error& error) { @@ -454,8 +789,24 @@ class OpenMultiHeadAttention: IMultiHeadAttention } } - void forward() + void forward(const DataType_* from_tensor, const DataType_* to_tensor) { + if(param_.sequence_id_offset != nullptr && param_.valid_word_num != batch_size_ * from_seq_len_) + { + is_fuse_QKV_ = false; + } + + if(is_fuse_QKV_ == true && int8_mode_ == 0) + { + // For tensorrt, we cannot get the pointer of from tensor until enqueue + const DataType_* hA[] {param_.self_attention.query_weight.kernel, + param_.self_attention.key_weight.kernel, + param_.self_attention.value_weight.kernel, + from_tensor, to_tensor, to_tensor, + query_buf_, key_buf_, value_buf_}; + cudaMemcpyAsync((void*)qkv_kernel_, hA, sizeof(DataType_*) * 9, cudaMemcpyHostToDevice, param_.stream); + } + #ifndef NDEBUG PRINT_FUNC_NAME_(); #endif @@ -467,6 +818,15 @@ class OpenMultiHeadAttention: IMultiHeadAttention try { if (int8_mode_ != 0){ + //K_int_buf_ V_int_buf_ should point to correct buffer according to param_.valid_word_num + if (int8_mode_ == 1) { + K_int_buf_ = (int*)Q_int_buf_ + param_.valid_word_num * head_num_ * size_per_head_; + V_int_buf_ = (int*)K_int_buf_ + param_.valid_word_num * head_num_ * size_per_head_; + } else if (int8_mode_ == 2 || int8_mode_ == 3){ + K_int_buf_ = (int*)((int8_t*)Q_int_buf_ + param_.valid_word_num * head_num_ * size_per_head_); + V_int_buf_ = (int*)((int8_t*)K_int_buf_ + param_.valid_word_num * head_num_ * size_per_head_); + } + int fusedINT8QKV = 0; const int8_t* Q_weight = (const int8_t*)(param_.self_attention.query_weight.kernel); const int8_t* K_weight = (const int8_t*)(param_.self_attention.key_weight.kernel); @@ -485,63 +845,72 @@ class OpenMultiHeadAttention: IMultiHeadAttention 
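The unfused FP32/FP16 path above issues the per-head Q·K^T score computation as one strided batched GEMM over batch_size * head_num matrices, leaving the 1/sqrt(size_per_head) scaling to the softmax kernel that follows. A minimal FP32 sketch of the same call pattern, using plain cublasSgemmStridedBatched instead of the templated cublasGemmStridedBatchedEx wrapper and with illustrative buffer names, would look like this:

#include <cublas_v2.h>

// Sketch only (not part of the patch): q_buf/k_buf hold one seq_len x size_per_head
// block per head, back to back, with size_per_head as the leading dimension;
// qk_buf receives one seq_len x seq_len score matrix per head.
cublasStatus_t batched_qk_scores(cublasHandle_t handle,
                                 const float* q_buf, const float* k_buf,
                                 float* qk_buf,
                                 int batch_size, int head_num,
                                 int seq_len, int size_per_head)
{
    const float alpha = 1.0f, beta = 0.0f;   // scaling by 1/sqrt(d) happens in the softmax
    return cublasSgemmStridedBatched(
        handle,
        CUBLAS_OP_T, CUBLAS_OP_N,            // scores = K^T * Q, one GEMM per head
        seq_len, seq_len, size_per_head,
        &alpha,
        k_buf, size_per_head, (long long)seq_len * size_per_head,
        q_buf, size_per_head, (long long)seq_len * size_per_head,
        &beta,
        qk_buf, seq_len, (long long)seq_len * seq_len,
        batch_size * head_num);
}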
cublasLtMM_withAlgo(Q_int_buf_, 1, m, n, k, 0, 0, 0, param_.int8_from_tensor, Q_weight, param_.cublaslt_handle, param_.stream, - cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); cublasLtMM_withAlgo(K_int_buf_, 1, m, n, k, 0, 0, 0, param_.int8_from_tensor, K_weight, param_.cublaslt_handle, param_.stream, - cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); cublasLtMM_withAlgo(V_int_buf_, 1, m, n, k, 0, 0, 0, param_.int8_from_tensor, V_weight, param_.cublaslt_handle, param_.stream, - cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); } else{ int strideFactor = (fusedINT8QKV == 1) ? (sizeof(DataType_)/sizeof(int8_t)) : 1; cublasLtMM_withAlgo(Q_int_buf_, 3, m, n, k, 0, n*k*strideFactor, n*m, param_.int8_from_tensor, Q_weight, - param_.cublaslt_handle, param_.stream, cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + param_.cublaslt_handle, param_.stream, cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); } } - else + else if (int8_mode_ == 2 || int8_mode_ == 3) { if (fusedINT8QKV == 0){ cublasLtMM_withAlgo_int8IO((int8_t*)Q_int_buf_, 1, m, n, k, 0, 0, 0, param_.int8O_gemm_deQ_scale_list[0], param_.int8_from_tensor, Q_weight, param_.cublaslt_handle, param_.stream, - cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); cublasLtMM_withAlgo_int8IO((int8_t*)K_int_buf_, 1, m, n, k, 0, 0, 0, param_.int8O_gemm_deQ_scale_list[1], param_.int8_from_tensor, K_weight, param_.cublaslt_handle, param_.stream, - cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); cublasLtMM_withAlgo_int8IO((int8_t*)V_int_buf_, 1, m, n, k, 0, 0, 0, param_.int8O_gemm_deQ_scale_list[2], param_.int8_from_tensor, V_weight, param_.cublaslt_handle, param_.stream, - cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); } else{ int strideFactor = (fusedINT8QKV == 1) ? (sizeof(DataType_)/sizeof(int8_t)) : 1; cublasLtMM_withAlgo_int8IO((int8_t*)Q_int_buf_, 3, m, n, k, 0, n*k*strideFactor, n*m, param_.int8O_gemm_deQ_scale_list[0], param_.int8_from_tensor, Q_weight, - param_.cublaslt_handle, param_.stream, cublasLtAlgoMap_, use_ORDER_COL32_2R_4R4_); + param_.cublaslt_handle, param_.stream, cublasAlgoMap_, use_ORDER_COL32_2R_4R4_); } } - if (fusedINT8QKV != 0) { - if (int8_mode_ == 1) { - K_int_buf_ = (int*)Q_int_buf_ + param_.valid_word_num * head_num_ * size_per_head_; - V_int_buf_ = (int*)K_int_buf_ + param_.valid_word_num * head_num_ * size_per_head_; - } else { - K_int_buf_ = (int*)((int8_t*)Q_int_buf_ + param_.valid_word_num * head_num_ * size_per_head_); - V_int_buf_ = (int*)((int8_t*)K_int_buf_ + param_.valid_word_num * head_num_ * size_per_head_); - } + int S; + if(dispatcher_int8.get()) + S = dispatcher_int8->getSFromMaxSeqLen(from_seq_len_); + if(dispatcher_int8.get() && dispatcher_int8->isValid(S) && param_.trt_seqlen_offset != nullptr) + { + // This function is only used when we satisfy the following conditions: + // 1. INT8 + // 2. 
GPU SM >= 75 + int8_fused_multiHeadAttr_kernelLauncher((const void*)Q_int_buf_, + param_.amaxList + 4 + 1, + param_.amaxList + 12 + 1, + param_.amaxList + 20 + 1, + param_.trt_fused_mha_amax_list[0]/127.0f, + S + ); } + else + { - DataType_ scalar = 1 / sqrtf(size_per_head_ * 1.0f); - multiHeadAttr_nofuse_kernelLauncher( + DataType_ scalar = 1 / sqrtf(size_per_head_ * 1.0f); + multiHeadAttr_nofuse_kernelLauncher( param_.stream, param_.cublas_handle, param_.cublaslt_handle, @@ -559,10 +928,12 @@ class OpenMultiHeadAttention: IMultiHeadAttention size_per_head_, int8_mode_, scalar); + } } else{ if(is_fuse_QKV_ == true) { + int algoId = getAlgoIdFromMap(cublasAlgoMap_, 3, n, m, k, AType_ == CUDA_R_16F ? HALF_DATATYPE : FLOAT_DATATYPE); check_cuda_error(cublasGemmBatchedEx(param_.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, @@ -573,87 +944,78 @@ class OpenMultiHeadAttention: IMultiHeadAttention (void* const*)qkv_buf_, CType_, n, 3, computeType_, - static_cast(cublasAlgo_[3]))); + static_cast(algoId))); } else { - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.query_weight.kernel, AType_, n, - param_.from_tensor, BType_, k, - &beta, - query_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); + cublasMM_cublasLtMM_wrapper(param_.cublaslt_handle, param_.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, &alpha, + param_.self_attention.query_weight.kernel, AType_, n, + from_tensor, BType_, k, + &beta, (DataType_ *)query_buf_, CType_, n, + param_.stream, cublasAlgoMap_, sm_, cublas_workspace_); #ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); #endif - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.key_weight.kernel, AType_, n, - param_.to_tensor, BType_, k, - &beta, - key_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); + cublasMM_cublasLtMM_wrapper(param_.cublaslt_handle, param_.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, &alpha, + param_.self_attention.key_weight.kernel, AType_, n, + to_tensor, BType_, k, + &beta, (DataType_ *)key_buf_, CType_, n, + param_.stream, cublasAlgoMap_, sm_, cublas_workspace_); #ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); #endif - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.value_weight.kernel, AType_, n, - param_.to_tensor, BType_, k, - &beta, - value_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); + cublasMM_cublasLtMM_wrapper(param_.cublaslt_handle, param_.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, &alpha, + param_.self_attention.value_weight.kernel, AType_, n, + to_tensor, BType_, k, + &beta, (DataType_ *)value_buf_, CType_, n, + param_.stream, cublasAlgoMap_, sm_, cublas_workspace_); } #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif - if(dispatcher_fp16.get() && OpType_==OperationType::FP16) + int S; + if(dispatcher_fp16.get()) + S = dispatcher_fp16->getSFromMaxSeqLen(from_seq_len_); + if(dispatcher_fp16.get() && OpType_==OperationType::FP16 && dispatcher_fp16->isValid(S) && param_.trt_seqlen_offset != nullptr) { // This function is only used when we satisfy the following conditions: // 1. FP16 // 2. 
GPU SM >= 72 - fused_multiHeadAttr_kernelLauncher(); + // 3. Temporally add seqlen <= 384 limitation because the current fused mha cannot handle seqlen > 384. + fused_multiHeadAttr_kernelLauncher(S); } else { DataType_ scalar = 1 / sqrtf(size_per_head_ * 1.0f); - multiHeadAttr_nofuse_kernelLauncher( - param_.stream, - param_.cublas_handle, - param_.cublaslt_handle, - query_buf_, - param_.self_attention.query_weight.bias, - key_buf_, - param_.self_attention.key_weight.bias, - value_buf_, - param_.self_attention.value_weight.bias, - param_.attr_mask, - param_.attr_out, - batch_size_, - from_seq_len_, - head_num_, - size_per_head_, - int8_mode_, - scalar); + multiHeadAttr_nofuse_kernelLauncher( + param_.stream, + param_.cublas_handle, + param_.cublaslt_handle, + query_buf_, + param_.self_attention.query_weight.bias, + key_buf_, + param_.self_attention.key_weight.bias, + value_buf_, + param_.self_attention.value_weight.bias, + param_.attr_mask, + param_.attr_out, + batch_size_, + from_seq_len_, + head_num_, + size_per_head_, + int8_mode_, + scalar); } } } @@ -663,31 +1025,39 @@ class OpenMultiHeadAttention: IMultiHeadAttention } } - void fused_multiHeadAttr_kernelLauncher(); + void fused_multiHeadAttr_kernelLauncher(const int S); + + void int8_fused_multiHeadAttr_kernelLauncher(const void* Q, + const float *q_deQFactor_ptr, const float *k_deQFactor_ptr, const float *v_deQFactor_ptr, + const float mScaleQkv, const int S); - void multiHeadAttr_nofuse_kernelLauncher( - cudaStream_t stream, - cublasHandle_t handle, - cublasLtHandle_t cublaslt_handle, - DataType_* Q, + void trt_add_QKV_bias_kernelLauncher( + const DataType_* bias_Q, + const DataType_* bias_K, + const DataType_* bias_V); + + void trt_add_QKV_bias_COL32_int8IO_kernelLauncher( + int8_t* output, + const int8_t* Q, const DataType_* bias_Q, - DataType_* K, const DataType_* bias_K, - DataType_* V, const DataType_* bias_V, - const DataType_* attr_mask, - DataType_* dst, - const int batch_size, - const int seq_len, - const int head_num, - const int size_per_head, - const int int8_mode_, - const DataType_ scalar); + const float *q_input_deQFactor_ptr, + const float *k_input_deQFactor_ptr, + const float *v_input_deQFactor_ptr, + const float qkv_output_scale); - void trt_add_QKV_bias_kernelLauncher( + void trt_add_QKV_bias_COL32_int32Iint8O_kernelLauncher( + int8_t* output, + const int32_t* Q, const DataType_* bias_Q, const DataType_* bias_K, - const DataType_* bias_V); + const DataType_* bias_V, + const float *input_deQFactor_div127_ptr, + const float * q_weight_amax, + const float * k_weight_amax, + const float * v_weight_amax, + const float qkv_output_scale); void initialize(MultiHeadInitParam param) { @@ -700,35 +1070,11 @@ class OpenMultiHeadAttention: IMultiHeadAttention query_weight_amax_list = param_.amaxList + ACTIVATION_AMAX_NUM; key_weight_amax_list = query_weight_amax_list + hidden_dim; value_weight_amax_list = key_weight_amax_list + hidden_dim; + if (dispatcher_int8.get()) + { + dispatcher_int8.get()->setScaleList(param_.trt_fused_mha_amax_list[0]/127.0f, param_.trt_fused_mha_amax_list[1]/127.0f, param_.trt_fused_mha_amax_list[2]/127.0f); + } } - if(is_fuse_QKV_ == true && param_.from_tensor != nullptr) - { - // For tensorrt, we cannot get the pointer of from tensor until enqueue - const DataType_* hA[] {param_.self_attention.query_weight.kernel, - param_.self_attention.key_weight.kernel, - param_.self_attention.value_weight.kernel, - param_.from_tensor, param_.from_tensor, param_.from_tensor, - query_buf_, key_buf_, value_buf_}; 
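The pointer-table trick used for the fused QKV projection (both in the removed block here and in its replacement inside forward()) stages nine device pointers, the Q/K/V weights, the input tensors, and the three output buffers, into a single device-resident array so the three projection GEMMs can be launched as one batched call; the table has to be copied to the device once the input pointer is known, which is why a single cudaMemcpyAsync of the host pointer array precedes the GEMM. A minimal FP32 sketch of that pattern, with illustrative names and cublasSgemmBatched standing in for the templated cublasGemmBatchedEx call, would be:

#include <cublas_v2.h>
#include <cuda_runtime.h>

// Sketch only: d_ptr_table is a device buffer with room for 9 pointers
// (3 weight matrices, 3 input tensors, 3 output buffers). For self-attention
// the from and to tensors are the same pointer, mirroring the hA[] layout above.
cublasStatus_t fused_qkv_projection(cublasHandle_t handle, cudaStream_t stream,
                                    const float* w_q, const float* w_k, const float* w_v,
                                    const float* from_tensor, const float* to_tensor,
                                    float* q_out, float* k_out, float* v_out,
                                    const float** d_ptr_table,
                                    int n, int m, int k)
{
    const float* h_ptrs[9] = { w_q, w_k, w_v,
                               from_tensor, to_tensor, to_tensor,
                               q_out, k_out, v_out };
    // One host-to-device copy of the pointer table per enqueue.
    cudaError_t err = cudaMemcpyAsync((void*)d_ptr_table, h_ptrs, sizeof(h_ptrs),
                                      cudaMemcpyHostToDevice, stream);
    if (err != cudaSuccess) return CUBLAS_STATUS_INTERNAL_ERROR;

    const float alpha = 1.0f, beta = 0.0f;
    return cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                              n, m, k, &alpha,
                              d_ptr_table,                 // A: the three weights, lda = n
                              n,
                              d_ptr_table + 3,             // B: the three inputs, ldb = k
                              k,
                              &beta,
                              (float**)(d_ptr_table + 6),  // C: q/k/v outputs, ldc = n
                              n,
                              3);
}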
- cudaMemcpyAsync((void*)qkv_kernel_, hA, sizeof(DataType_*) * 9, cudaMemcpyHostToDevice, param_.stream); - } - } - void trt_initialize(DataType_* from_tensor, DataType_* to_tensor, DataType_* attr_mask, cudaStream_t stream, - cublasHandle_t cublas_handle) - { - param_.from_tensor = from_tensor; - param_.to_tensor = to_tensor; - param_.attr_mask = attr_mask; - param_.stream = stream; - param_.cublas_handle = cublas_handle; - if(is_fuse_QKV_ == true) - { - const DataType_* hA[] {param_.self_attention.query_weight.kernel, - param_.self_attention.key_weight.kernel, - param_.self_attention.value_weight.kernel, - param_.from_tensor, param_.from_tensor, param_.from_tensor, - query_buf_, key_buf_, value_buf_}; - cudaMemcpyAsync((void*)qkv_kernel_, hA, sizeof(DataType_*) * 9, cudaMemcpyHostToDevice, param_.stream); - } } ~OpenMultiHeadAttention() override @@ -745,15 +1091,7 @@ class OpenMultiHeadAttention: IMultiHeadAttention } } }; - -void add_QKV_bias_transpose_kernelLauncher( - const half* query_buf, const half* bias_Q, - const half* key_buf, const half* bias_K, - const half* value_buf, const half* bias_V, - half* context_buf, - const int valid_word_num, - const int head_num, const int size_per_head, - cudaStream_t stream); // Used to debug the trt_add_QKV_bias kernel - + }//namespace cuda }//namespace fastertransformer + diff --git a/fastertransformer/cuda/open_decoder.cu b/fastertransformer/cuda/open_decoder.cu index f63a09c1a..b277daf8e 100644 --- a/fastertransformer/cuda/open_decoder.cu +++ b/fastertransformer/cuda/open_decoder.cu @@ -1,1843 +1,1505 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * Open sourced multi-head attention - **/ - -#include "fastertransformer/open_decoder.h" -#include "cub/cub.cuh" - -namespace fastertransformer{ - -const int WARP_SIZE = 32; -const bool ATTENION_OPT = true; -const int ATTENTION_BLOCK_SIZE = 256; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -using Copy_half_t = - typename std::conditional::type - >::type - >::type; - -template -using Copy_t = Copy_half_t; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/** - masked multi-head attention - */ -#define FINAL_MASK 0xffffffff -template -__inline__ __device__ -T warpReduceSum(T val) -{ - for(int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); - return val; -} -/* Calculate the sum of all elements in a block */ -template - __inline__ __device__ -T blockReduceSum(T val) -{ - static __shared__ T shared[32]; - // __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - val = warpReduceSum(val); - - if(lane == 0) - shared[wid] = val; - - __syncthreads(); - - val = (threadIdx.x < (blockDim.x >> 5 )) ? 
shared[lane] : (T)(0.0f); - val = warpReduceSum(val); - - return val; -} - -/* gelu activation */ -template -__inline__ __device__ -T gelu(T x) -{ - float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -/* gelu activation for half2 */ -template <> -__inline__ __device__ -half2 gelu(half2 val) -{ - half2 val_pow3 = __hmul2(val, __hmul2(val, val)); - float2 tmp_pow = __half22float2(val_pow3); - float2 tmp = __half22float2(val); - - tmp.x = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); - tmp.y = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); - return __hmul2(val, __float22half2_rn(tmp)); -} - - -template -__global__ -void add_bias_gelu(T* out, const T* bias, int m, int n) -{ - for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) - { - T reg_bias = __ldg(&bias[id % n]); - T val = out[id] + reg_bias; - out[id] = (T)(gelu(val)); - } -} - -template <> - __global__ -void add_bias_gelu(half* out, const half* bias, int m, int n) -{ - half2* out_ptr = (half2*) out; - const half2* bias_ptr = (half2*) bias; - - for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) - { - half2 reg_bias = __ldg(&bias_ptr[id % n]); - half2 val = out_ptr[id] + reg_bias; - out_ptr[id] = gelu(val); - } -} - -template -__global__ -void add_bias_relu(T* out, const T* bias, int m, int n) -{ - for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) - { - T reg_bias = __ldg(&bias[id % n]); - T val = out[id] + reg_bias; - out[id] = (T)(val > 0.0f ? val : 0.0f); - } -} - -template <> - __global__ -void add_bias_relu(half* out, const half* bias, int m, int n) -{ - half2* out_ptr = (half2*) out; - const half2* bias_ptr = (half2*) bias; - - for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) - { - half2 reg_bias = __ldg(&bias_ptr[id % n]); - half2 val = out_ptr[id] + reg_bias; - val.x = val.x > (half)0.0f ? val.x : (half)0.0f; - val.y = val.y > (half)0.0f ? val.y : (half)0.0f; - out_ptr[id] = val; - } -} - -template - __inline__ __device__ -T warpReduceMax(T val) -{ - for(int mask = 16; mask > 0; mask >>= 1) - val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); - return val; -} -/* Calculate the maximum of all elements in a block */ -template - __inline__ __device__ -T blockReduceMax(T val) -{ - static __shared__ T shared[32]; -// __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx - - val = warpReduceMax(val); // get maxx in each warp - - if(lane == 0) // record in-warp maxx by warp Idx - shared[wid] = val; - - __syncthreads(); - - - val = (threadIdx.x < (blockDim.x >> 5 )) ? 
shared[lane] : (T)-1e20f; - val = warpReduceMax(val); - - return val; -} - -template -__global__ -void masked_attention_kernel_opt( - T* __restrict key_buf, T* __restrict value_buf, - T* __restrict query_buf, const T* __restrict self_Q_bias, - T* __restrict key_cache, const T* __restrict self_K_bias, - T* __restrict value_cache, const T* __restrict self_V_bias, - T* __restrict context_buf, int batch_size, int head_num, const int step, const T scalar) -{ - typedef Copy_t copy_t; - const int elems_per_thread = size_per_head / WARP_SIZE; - - union Access_t - { - copy_t v; - T x[elems_per_thread]; // supported size 1,2,4 - }; - typedef struct Float_n_t - { - T x[elems_per_thread]; // supported size 1,2,4 - } float_n_t; - - __shared__ float_n_t sq[block_sz]; - - extern __shared__ float logits[]; // use to store the logits from [0~step] - - const int tid = threadIdx.x; - const int warp_num = block_sz / WARP_SIZE; - const int bid = blockIdx.x; - const int head_id = blockIdx.x % head_num; - const int warp_id = tid / WARP_SIZE; // warp_id in block - const int lane_id = tid % WARP_SIZE; // lane_id in warp - - typedef cub::BlockReduce MaxValBlockReduce; - typedef cub::BlockReduce BlockReduce; - __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; - __shared__ typename BlockReduce::TempStorage block_temp_storage; - __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; - - int qkv_id = bid * size_per_head; - int qkv_bias_id = head_id * size_per_head; - - query_buf = &query_buf[qkv_id]; - key_buf = &key_buf[qkv_id]; - value_buf = &value_buf[qkv_id]; - self_K_bias = &self_K_bias[qkv_bias_id]; - key_cache = &key_cache[qkv_id]; - self_Q_bias = &self_Q_bias[qkv_bias_id]; - self_V_bias = &self_V_bias[qkv_bias_id]; - value_cache = &value_cache[qkv_id]; - context_buf = &context_buf[qkv_id]; - - Access_t bias_r, query_buf_r; - Access_t key_val_r, key_buf_r; - Access_t value_val_r, value_buf_r; - - // each warp will have its own copy of sq - query_buf_r.v = *((copy_t *)query_buf + lane_id); - key_buf_r.v = *((copy_t *)key_buf + lane_id); - bias_r.v = *((copy_t *)self_Q_bias + lane_id); - float qb_r[elems_per_thread]; - for (int i = 0; i < elems_per_thread; ++i) - { - qb_r[i] = (float)query_buf_r.x[i] + (float)bias_r.x[i]; - } - - //offset for each step - int offset = batch_size * head_num * size_per_head; - bias_r.v = *((copy_t *) self_K_bias + lane_id); - for(int ite = warp_id; ite < step; ite += warp_num) - { - key_val_r.v = *((copy_t *)&key_cache[ite * offset] + lane_id); - //for the last step, we should update K + bias_K to the cache - if(ite == step - 1) - { - for (int i = 0; i < elems_per_thread; i++) - { - key_val_r.x[i] = (float)key_buf_r.x[i] + (float)bias_r.x[i]; - } - *((copy_t *)&key_cache[ite * offset] + lane_id) = key_val_r.v; - } - float val = 0.f; - for (int i = 0; i < elems_per_thread; i++) - { - val = val + (float)key_val_r.x[i] * qb_r[i] * (float)scalar; - } - float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val); - if (lane_id == 0) - { - logits[ite] = qk; - } - } - __syncthreads(); - - __shared__ float s_max_val, s_sum; - - float local_i = -1e20f; - for(int i = tid; i < step; i += blockDim.x) - local_i = max(local_i, logits[i]); - - float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max()); - if(tid == 0) - s_max_val = max_val; - __syncthreads(); - - - float local_o = 0.0f; - for(int i = tid; i < step; i += blockDim.x) - { - logits[i] = __expf(logits[i] - s_max_val); - local_o += logits[i]; - } - float val = 
BlockReduce(block_temp_storage).Sum(local_o); - - if(tid == 0) - s_sum = val + 1e-6; - __syncthreads(); - - float s_sum_inverse = __fdividef(1.0f, s_sum); - for(int i = tid; i < step; i += blockDim.x) - { - logits[i] = logits[i] * s_sum_inverse; - } - __syncthreads(); - - // This optimization introduces discrepancy because of different order in FP32 summation - float sum_r[elems_per_thread] = {0.f}; - bias_r.v = *((copy_t *) self_V_bias + lane_id); - value_buf_r.v = *((copy_t *)value_buf + lane_id); - - for(int ite = warp_id; ite < step; ite += warp_num) - { - value_val_r.v = *((copy_t *)&value_cache[ite * offset] + lane_id); - //for the last step, we should update K + bias_K to the cache - if(ite == step - 1) - { - for (int i = 0; i < elems_per_thread; i++) - { - value_val_r.x[i] = (float)value_buf_r.x[i] + (float)bias_r.x[i]; - } - *((copy_t *)&value_cache[ite * offset] + lane_id) = value_val_r.v; - } - for (int i = 0; i < elems_per_thread; ++i) - { - sum_r[i] += (float)value_val_r.x[i] * logits[ite]; - } - } - for (int i = 0; i < elems_per_thread; i++) - { - sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i]; - } - __syncthreads(); - if (warp_id == 0) - { - #pragma unroll - for (int j = 1; j < warp_num; j++) - { - for (int i = 0; i < elems_per_thread; ++i) - { - sum_r[i] = sum_r[i] + (float)sq[j * WARP_SIZE + tid].x[i]; - } - } - } - __syncthreads(); - #pragma unroll - for (int i = 0; i < elems_per_thread; i++) - { - value_val_r.x[i] = sum_r[i]; - } - if (warp_id == 0) - { - *((copy_t *)context_buf + lane_id) = value_val_r.v; - } -} - -// only use for compile -template -__global__ -void masked_attention_kernel_opt_half2( - float* __restrict key_buf, float* __restrict value_buf, - float* __restrict query_buf, const float* __restrict self_Q_bias, - float* __restrict key_cache, const float* __restrict self_K_bias, - float* __restrict value_cache, const float* __restrict self_V_bias, - float* __restrict context_buf, int batch_size, int head_num, const int step, const float scalar) {} - -template -__global__ -void masked_attention_kernel_opt_half2( - half* __restrict key_buf, half* __restrict value_buf, - half* __restrict query_buf, const half* __restrict self_Q_bias, - half* __restrict key_cache, const half* __restrict self_K_bias, - half* __restrict value_cache, const half* __restrict self_V_bias, - half* __restrict context_buf, int batch_size, int head_num, const int step, const half scalar) -{ - half2* key_buf_ptr = (half2*)key_buf; - half2* value_buf_ptr = (half2*)value_buf; - half2* query_buf_ptr = (half2*)query_buf; - half2* key_cache_ptr = (half2*)key_cache; - half2* value_cache_ptr = (half2*)value_cache; - const half2* self_Q_bias_ptr = (const half2*)self_Q_bias; - const half2* self_K_bias_ptr = (const half2*)self_K_bias; - const half2* self_V_bias_ptr = (const half2*)self_V_bias; - half2* context_buf_ptr = (half2*)context_buf; - - typedef Copy_t copy_t; - const int elems_per_thread = size_per_head / 2 / WARP_SIZE; - - union Access_t - { - copy_t v; - half2 x[elems_per_thread]; // supported size 1,2,4 - }; - typedef struct Half_n_t - { - half2 x[elems_per_thread]; // supported size 1,2,4 - } half_n_t; - - __shared__ half_n_t sq[block_sz]; - - extern __shared__ float logits[]; // use to store the logits from [0~step] - - const int tid = threadIdx.x; - const int warp_num = block_sz / WARP_SIZE; - const int bid = blockIdx.x; - const int head_id = blockIdx.x % head_num; - const int warp_id = tid / WARP_SIZE; // warp_id in block - const int lane_id = tid % WARP_SIZE; // lane_id in warp - 
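The removed masked-attention kernels and their opt variants all normalize the attention logits the same way: subtract the row maximum, exponentiate, and divide by the accumulated sum (most of them guard the denominator with 1e-6). A small host-side reference of what each row of logits ends up holding, illustrative only and not part of the kernels, is:

#include <algorithm>
#include <cmath>
#include <vector>

// Reference for the per-row logits post-processing done in shared memory by
// the attention kernels: max-subtracted softmax with an epsilon-guarded sum,
// matching "s_sum = val + 1e-6" followed by "logits[i] *= 1 / s_sum".
inline void softmax_reference(std::vector<float>& logits)
{
    float max_val = -1e20f;
    for (float v : logits) max_val = std::max(max_val, v);

    float sum = 1e-6f;
    for (float& v : logits) {
        v = std::exp(v - max_val);
        sum += v;
    }
    for (float& v : logits) v /= sum;
}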
- typedef cub::BlockReduce MaxValBlockReduce; - typedef cub::BlockReduce BlockReduce; - __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; - __shared__ typename BlockReduce::TempStorage block_temp_storage; - __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; - - int qkv_id = bid * size_per_head / 2; - int qkv_bias_id = head_id * size_per_head / 2; - - query_buf_ptr = &query_buf_ptr[qkv_id]; - key_buf_ptr = &key_buf_ptr[qkv_id]; - value_buf_ptr = &value_buf_ptr[qkv_id]; - self_K_bias_ptr = &self_K_bias_ptr[qkv_bias_id]; - key_cache_ptr = &key_cache_ptr[qkv_id]; - self_Q_bias_ptr = &self_Q_bias_ptr[qkv_bias_id]; - self_V_bias_ptr = &self_V_bias_ptr[qkv_bias_id]; - value_cache_ptr = &value_cache_ptr[qkv_id]; - context_buf_ptr = &context_buf_ptr[qkv_id]; - - Access_t bias_r, query_buf_r; - Access_t key_val_r, key_buf_r; - Access_t value_val_r, value_buf_r; - - // each warp will have its own copy of sq - query_buf_r.v = *((copy_t *)query_buf_ptr + lane_id); - key_buf_r.v = *((copy_t *)key_buf_ptr + lane_id); - bias_r.v = *((copy_t *)self_Q_bias_ptr + lane_id); - half2 qb_r[elems_per_thread]; - for (int i = 0; i < elems_per_thread; ++i) - { - qb_r[i] = __hadd2(query_buf_r.x[i], bias_r.x[i]); - } - - //offset for each step - int offset = batch_size * head_num * size_per_head / 2; - bias_r.v = *((copy_t *) self_K_bias + lane_id); - for(int ite = warp_id; ite < step; ite += warp_num) - { - key_val_r.v = *((copy_t *)&key_cache_ptr[ite * offset] + lane_id); - //for the last step, we should update K + bias_K to the cache - if(ite == step - 1) - { - for (int i = 0; i < elems_per_thread; i++) - { - key_val_r.x[i] = __hadd2(key_buf_r.x[i], bias_r.x[i]); - } - *((copy_t *)&key_cache_ptr[ite * offset] + lane_id) = key_val_r.v; - } - float val = 0.f; - for (int i = 0; i < elems_per_thread; i++) - { - half2 val2 = __hmul2(key_val_r.x[i], qb_r[i]); - val = val + (float)((val2.x + val2.y) * scalar); - } - float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val); - if (lane_id == 0) - { - logits[ite] = qk; - } - } - __syncthreads(); - - __shared__ float s_max_val, s_sum; - float local_i = -1e20f; - for(int i = tid; i < step; i += blockDim.x) - local_i = max(local_i, logits[i]); - - float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max()); - if(tid == 0) - s_max_val = max_val; - __syncthreads(); - - float local_o = 0.0f; - for(int i = tid; i < step; i += blockDim.x) - { - logits[i] = __expf(logits[i] - s_max_val); - local_o += logits[i]; - } - float val = BlockReduce(block_temp_storage).Sum(local_o); - - if(tid == 0) - s_sum = val + 1e-6; - __syncthreads(); - - float s_sum_inverse = __fdividef(1.0f, s_sum); - for(int i = tid; i < step; i += blockDim.x) - { - logits[i] = logits[i] * s_sum_inverse; - } - __syncthreads(); - - // This optimization introduces discrepancy because of different order in FP32 summation - half2 sum_r[elems_per_thread]; - for(int i = 0; i < elems_per_thread; i++) - { - sum_r[i].x = (half)0.f; - sum_r[i].y = (half)0.f; - } - bias_r.v = *((copy_t *) self_V_bias_ptr + lane_id); - value_buf_r.v = *((copy_t *)value_buf_ptr + lane_id); - - for(int ite = warp_id; ite < step; ite += warp_num) - { - value_val_r.v = *((copy_t *)&value_cache_ptr[ite * offset] + lane_id); - //for the last step, we should update K + bias_K to the cache - if(ite == step - 1) - { - for (int i = 0; i < elems_per_thread; i++) - { - value_val_r.x[i] = __hadd2(value_buf_r.x[i], bias_r.x[i]); - } - *((copy_t *)&value_cache_ptr[ite * offset] + 
lane_id) = value_val_r.v; - } - for (int i = 0; i < elems_per_thread; ++i) - { - half2 logit2_val; - logit2_val.x = (half)logits[ite]; - logit2_val.y = (half)logits[ite]; - sum_r[i] = __hadd2(sum_r[i], __hmul2(value_val_r.x[i], logit2_val)); - } - } - for (int i = 0; i < elems_per_thread; i++) - { - sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i]; - } - __syncthreads(); - if (warp_id == 0) - { - #pragma unroll - for (int j = 1; j < warp_num; j++) - { - for (int i = 0; i < elems_per_thread; ++i) - { - sum_r[i] = __hadd2(sum_r[i], sq[j * WARP_SIZE + tid].x[i]); - } - } - } - __syncthreads(); - #pragma unroll - for (int i = 0; i < elems_per_thread; i++) - { - value_val_r.x[i] = sum_r[i]; - } - if (warp_id == 0) - { - *((copy_t *)context_buf_ptr + lane_id) = value_val_r.v; - } -} - -template -__global__ -void masked_attention_kernel( - T* key_buf, T* value_buf, - T* query_buf, const T* self_Q_bias, - T* key_cache, const T* self_K_bias, T* value_cache, const T* self_V_bias, - T* context_buf, int batch_size, int head_num, int size_per_head, const int step, const T scalar) -{ - extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; - T* sq = reinterpret_cast(s_buf); - T* logits = reinterpret_cast(&sq[size_per_head]); - - int tid = threadIdx.x; - int bid = blockIdx.x / head_num; - int head_id = blockIdx.x % head_num; - - int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; - int qkv_bias_id = head_id * size_per_head + tid; - - if(tid < size_per_head) - sq[tid] = query_buf[qkv_id] + self_Q_bias[qkv_bias_id]; - __syncthreads(); - - //offset for each step - int offset = batch_size * head_num * size_per_head; - for(int ite = 0; ite < step; ++ite) - { - T key = tid < size_per_head ? key_cache[ite * offset + qkv_id] : (T)0.0f; - //for the last step, we should update K + bias_K to the cache - if(ite == step - 1 && tid < size_per_head) - { - key = key_buf[qkv_id] + self_K_bias[qkv_bias_id]; - key_cache[ite * offset + qkv_id] = key; - } - - T val = (tid < size_per_head) ? key * sq[tid] * scalar : (T)(0.0f); - T qk = blockReduceSum(val); - if(threadIdx.x == 0) - logits[ite] = qk; - __syncthreads(); //try to remove - } - __syncthreads(); //try to remove - - __shared__ float s_max_val, s_sum; - float local_i = tid < step ? (float)logits[tid] : -1e20f; - float max_val = blockReduceMax(local_i); - if(tid == 0) - s_max_val = max_val; - __syncthreads(); - - local_i -= s_max_val; - float local_o = tid < step ? 
__expf(local_i) : 0.0f; - float val = blockReduceSum(local_o); - - if(tid == 0) - s_sum = val + 1e-6; - __syncthreads(); - - if(tid < step) - logits[tid] = local_o / s_sum; - __syncthreads(); - - if(tid < size_per_head) - { - T sum = (T)0.0f; - for(int ite = 0; ite < step; ++ite) - { - T value = value_cache[ite * offset + qkv_id]; - //for the last step, we should update K + bias_K to the cache - if(ite == step - 1) - { - value = value_buf[qkv_id] + self_V_bias[qkv_bias_id]; - value_cache[ite * offset + qkv_id] = value; - } - sum += value * logits[ite]; - } - context_buf[qkv_id] = sum; - } -} - -template -__global__ -void masked_attention_kernel_v2(T* query_buf, const T* self_Q_bias, - T* key_cache, const T* self_K_bias, T* value_cache, const T* self_V_bias, - T* context_buf, int batch_size, int head_num, int size_per_head, const int step, const T scalar) -{ - extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; - T* sq = reinterpret_cast(s_buf); - T* logits = reinterpret_cast(&sq[size_per_head]); - - int tid = threadIdx.x; - int bid = blockIdx.x / head_num; - int head_id = blockIdx.x % head_num; - - int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; - int qkv_bias_id = head_id * size_per_head + tid; - - if(tid < size_per_head) - sq[tid] = query_buf[qkv_id] + self_Q_bias[qkv_bias_id]; - __syncthreads(); - - int warp_size = 32; - int offset = batch_size * head_num * size_per_head; - int warp_ite = size_per_head / warp_size; - - T qk = (T)0.0f; - - //each warp process one step - int step_id = threadIdx.x >> 5; - if(step_id < step) - { - for(int wite = 0; wite < warp_ite; ++wite) - { - T key = key_cache[step_id * offset + bid * head_num * size_per_head + head_id * size_per_head - + tid % warp_size + wite * warp_size]; - //for the last step, we should update K + bias_K to the cache - if(step_id == step - 1) - { - key += self_K_bias[bid * head_num * size_per_head + head_id * size_per_head + - tid % warp_size + wite * warp_size]; - key_cache[step_id * offset + bid * head_num * size_per_head + head_id * size_per_head - + tid % warp_size + wite * warp_size] = key; - } - qk += key * sq[tid % warp_size + wite * warp_size]; - } - - qk = warpReduceSum(qk * scalar); - if(threadIdx.x % warp_size == 0) - { - logits[step_id] = qk; - printf("step_id %d %f\n", step_id, qk); - } - - } - __syncthreads(); - - __shared__ float s_max_val, s_sum; - float local_i = tid < step ? (float)logits[tid] : -1e20f; - float max_val = blockReduceMax(local_i); - if(tid == 0) - s_max_val = max_val; - __syncthreads(); - - local_i -= s_max_val; - float local_o = tid < step ? 
__expf(local_i) : 0.0f; - float val = blockReduceSum(local_o); - - if(tid == 0) - s_sum = val; - __syncthreads(); - if(tid < step) - logits[tid] = local_o / s_sum; - __syncthreads(); - - - if(tid < size_per_head) - { - T sum = (T)0.0f; - for(int ite = 0; ite < step; ++ite) - { - T value = value_cache[ite * offset + qkv_id]; - //for the last step, we should update K + bias_K to the cache - if(ite == step - 1) - { - value += self_V_bias[qkv_bias_id]; - value_cache[ite * offset + qkv_id] = value; - } - sum += value * logits[ite]; - } - context_buf[qkv_id] = sum; - } -} - -template -void masked_attention_dispatch( - T* key_buf, T* value_buf, - T* query_buf, const T* self_Q_bias, - T* key_cache, const T* self_K_bias, T* value_cache, const T* self_V_bias, - T* context_buf, int batch_size, int head_num, int size_per_head, const int step, cudaStream_t stream) - { - const int block_sz = ATTENTION_BLOCK_SIZE; - T scalar = (T)(1.f / sqrtf(size_per_head * 1.0f)); - - dim3 grid(batch_size * head_num); - - int cond = size_per_head * ((ATTENION_OPT)? 1:0); - switch (cond) - { - case 32: - masked_attention_kernel_opt<32, block_sz, T><<>>( - key_buf, value_buf, - query_buf, self_Q_bias, key_cache, self_K_bias, value_cache, self_V_bias, context_buf, - batch_size, head_num, step, scalar); - break; - case 64: - if(sizeof(T) == 2) - masked_attention_kernel_opt_half2<64, block_sz><<>>( - key_buf, value_buf, - query_buf, self_Q_bias, key_cache, self_K_bias, value_cache, self_V_bias, context_buf, - batch_size, head_num, step, scalar); - else - masked_attention_kernel_opt<64, block_sz, T><<>>( - key_buf, value_buf, - query_buf, self_Q_bias, - key_cache, self_K_bias, - value_cache, self_V_bias, - context_buf, - batch_size, head_num, step, scalar); - break; - case 128: - if(sizeof(T) == 2) - masked_attention_kernel_opt_half2<128, block_sz><<>>( - key_buf, value_buf, - query_buf, self_Q_bias, key_cache, self_K_bias, value_cache, self_V_bias, context_buf, - batch_size, head_num, step, scalar); - else - masked_attention_kernel_opt<128, block_sz, T><<>>( - key_buf, value_buf, - query_buf, self_Q_bias, key_cache, self_K_bias, value_cache, self_V_bias, context_buf, - batch_size, head_num, step, scalar); - break; - default: - // default path - int block_size = 128; - - //suppose size_per_head <= 128 - if(step <= 64) - block_size = 64; - else if(step <= 128 && step > size_per_head) - block_size = 128; - else if(step > 128 && step <= 256) - block_size = 256; - else if(step > 256 && step <= 512) - block_size = 512; - else - block_size = 1024; - - if((int)block_size < size_per_head) - block_size = size_per_head; - - assert(block_size <= 1024); - dim3 block(block_size); - T scalar = 1 / sqrtf(size_per_head * 1.0f); - - - int shared_size = sizeof(T) * (size_per_head + step); - masked_attention_kernel<<>>( - key_buf, value_buf, - query_buf, self_Q_bias, - key_cache, self_K_bias, - value_cache, self_V_bias, - context_buf, batch_size, - head_num, size_per_head, step, scalar); - } - } - -template -void OpenDecoder::masked_multi_head_attention( - const DataType_* from_tensor, - DataType_* key_cache_, - DataType_* value_cache_, - DataType_* decoder_output, - const int step) -{ - int m = batch_size_; - int n = hidden_units_; - int k = hidden_units_; - - DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; - - if(is_fuse_QKV == true) - { - check_cuda_error(cublasGemmBatchedEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - (const void* const*) qkv_kernel_, AType_, n, - (const void* const*) qkv_input_, 
BType_, k, - &beta, - (void* const*)qkv_buf_, CType_, n, - 3, - computeType_, - static_cast(cublasAlgo_[4]))); - } - else - { - key_buf_ = key_cache_ + (step - 1) * m * n; - value_buf_ = value_cache_ + (step - 1) * m * n; - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.query_weight.kernel , AType_, n, - from_tensor, BType_, k, - &beta, - query_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.key_weight.kernel, AType_, n, - from_tensor, BType_, k, - &beta, - key_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.value_weight.kernel, AType_, n, - from_tensor, BType_, k, - &beta, - value_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); - } - - masked_attention_dispatch( - key_buf_, value_buf_, - query_buf_, param_.self_attention.query_weight.bias, - key_cache_, param_.self_attention.key_weight.bias, - value_cache_, param_.self_attention.value_weight.bias, - context_buf_, batch_size_, - head_num_, size_per_head_, step, param_.stream); - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.self_attention.attention_output_weight.kernel, AType_, n, - context_buf_, BType_, k, - &beta, - decoder_output, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); -} - -template -__global__ -void cross_attention_kernel_opt( - T* __restrict query_buf, const T* __restrict Q_bias, - T* __restrict key_cache, const T* __restrict K_bias, - T* __restrict value_cache, const T* __restrict V_bias, - const int* length_per_sample, T* __restrict context_buf, - int batch_size, int head_num, const int step, const int seq_len, const float scalar) -{ - typedef Copy_t copy_t; - const int elems_per_thread = size_per_head / WARP_SIZE; - union Access_t - { - copy_t v; - T x[elems_per_thread]; // supported size 1,2,4 - }; - typedef struct Float_n_t - { - float x[elems_per_thread]; // supported size 1,2,4 - } float_n_t; - - __shared__ float_n_t sq[block_sz]; - extern __shared__ float logits[]; // use to store the logits from [0~step] - - const int warp_id = threadIdx.x / WARP_SIZE; - const int warp_num = block_sz / WARP_SIZE; - - typedef cub::BlockReduce MaxValBlockReduce; - typedef cub::BlockReduce BlockReduce; - __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; - __shared__ typename BlockReduce::TempStorage block_temp_storage; - - __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; - - const int tid = threadIdx.x; - const int bid = blockIdx.x / head_num; - const int head_id = blockIdx.x % head_num; - - int length = __ldg(&length_per_sample[bid]); - - const int lane_id = tid % WARP_SIZE; - - int qkv_id = bid * head_num * size_per_head + head_id * size_per_head; - int qkv_bias_id = head_id * size_per_head; - - int key_value_id = bid * (seq_len * head_num * size_per_head) + - + head_id * size_per_head; - - query_buf = &query_buf[qkv_id]; - K_bias = &K_bias[qkv_bias_id]; - key_cache = &key_cache[key_value_id]; - Q_bias = &Q_bias[qkv_bias_id]; - V_bias = &V_bias[qkv_bias_id]; - value_cache = &value_cache[key_value_id]; - context_buf = &context_buf[qkv_id]; - - Access_t bias_r, key_val_r, query_buf_r; - - // each warp will have its own 
copy of sq - query_buf_r.v = *((copy_t *)query_buf + lane_id); - bias_r.v = *((copy_t *)Q_bias + lane_id); - float qb_r[elems_per_thread]; - for (int i = 0; i < elems_per_thread; ++i) - { - qb_r[i] = (float)query_buf_r.x[i] + (float)bias_r.x[i]; - } - - //offset for each step - int offset = head_num * size_per_head; - - bias_r.v = *((copy_t *) K_bias + lane_id); - for(int ite = warp_id; ite < length; ite += warp_num) - { - key_val_r.v = *((copy_t *)&key_cache[ite * offset] + lane_id); - - //For the first step, we should add bias to key memory cache. - //The KV memory cache only need to be updated at the first step. - if (step == 1) - { - for (int i = 0; i < elems_per_thread; i++) - { - key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; - } - *((copy_t *)&key_cache[ite * offset] + lane_id) = key_val_r.v; - } - float val = 0.f; - for (int i = 0; i < elems_per_thread; i++) - { - val = val + (float)key_val_r.x[i] * qb_r[i] * scalar; - } - float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val); - if (lane_id == 0) - { - logits[ite] = qk; - } - } - __syncthreads(); - - __shared__ float s_max_val, s_sum; - float local_i = -1e20f; - for(int i = tid; i < length; i += blockDim.x) - local_i = max(local_i, logits[i]); - - float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max()); - if(tid == 0) - s_max_val = max_val; - __syncthreads(); - - float local_o = 0.0f; - for(int i = tid; i < length; i += blockDim.x) - { - logits[i] = __expf(logits[i] - s_max_val); - local_o += logits[i]; - } - float val = BlockReduce(block_temp_storage).Sum(local_o); - - if(tid == 0) - s_sum = val + 1e-6; - __syncthreads(); - - float s_sum_inverse = __fdividef(1.0f, s_sum); - for(int i = tid; i < length; i += blockDim.x) - { - logits[i] = logits[i] * s_sum_inverse; - } - __syncthreads(); - - // This optimization introduces discrepancy because of different order in FP32 summation - float sum_r[elems_per_thread] = {0.f}; - bias_r.v = *((copy_t *) V_bias + lane_id); - for(int ite = warp_id; ite < length; ite += warp_num) - { - key_val_r.v = *((copy_t *)&value_cache[ite * offset] + lane_id); - - //For the first step, we should add bias to key memory cache. 
- if(step == 1) - { - for (int i = 0; i < elems_per_thread; i++) - { - key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; - } - *((copy_t *)&value_cache[ite * offset] + lane_id) = key_val_r.v; - } - for (int i = 0; i < elems_per_thread; ++i) - { - sum_r[i] += (float)key_val_r.x[i] * logits[ite]; - } - } - for (int i = 0; i < elems_per_thread; i++) - { - sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i]; - } - __syncthreads(); - if (threadIdx.x < WARP_SIZE) - { - #pragma unroll - for (int j = 1; j < warp_num; j++) - { - for (int i = 0; i < elems_per_thread; ++i) - { - sum_r[i] = sum_r[i] + (float)sq[j * WARP_SIZE + threadIdx.x].x[i]; - } - } - } - __syncthreads(); - #pragma unroll - for (int i = 0; i < elems_per_thread; i++) - { - key_val_r.x[i] = sum_r[i]; - } - if (threadIdx.x < WARP_SIZE) - { - *((copy_t *)context_buf + lane_id) = key_val_r.v; - } -} - -template -__global__ -void cross_attention_kernel( - T* query_buf, const T* Q_bias, - T* key_cache, const T* K_bias, - T* value_cache, const T* V_bias, - const int* length_per_sample, T* context_buf, - int batch_size, int head_num, int size_per_head, int step, const int seq_len, const T scalar) -{ - int tid = threadIdx.x; - int bid = blockIdx.x / head_num; - int head_id = blockIdx.x % head_num; - - extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; - T* sq = reinterpret_cast(s_buf); - T* logits = reinterpret_cast(&sq[size_per_head]); - - int length = __ldg(&length_per_sample[bid]); - - int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; - int qkv_bias_id = head_id * size_per_head + tid; - - if(tid < size_per_head) - sq[tid] = query_buf[qkv_id] + Q_bias[qkv_bias_id]; - __syncthreads(); - - for(int ite = 0; ite < length; ++ite) - { - int key_id = bid * (seq_len * head_num * size_per_head) + ite * (head_num * size_per_head) - + head_id * size_per_head + tid; - - T key = tid < size_per_head ? key_cache[key_id] : (T)(0.0f); - - //For the first step, we should add bias to key memory cache. - //The KV memory cache only need to be updated at the first step. - if(step == 1 && tid < size_per_head) - { - key += K_bias[head_id * size_per_head + tid]; - key_cache[key_id] = key; - } - - T val = (tid < size_per_head) ? key * sq[tid] * scalar : (T)(0.0f); - T qk = blockReduceSum(val); - if(threadIdx.x == 0) - logits[ite] = qk; - __syncthreads(); //try to remove - } - __syncthreads(); - - __shared__ float s_max_val, s_sum; - - float local_i = tid < length ? (float)logits[tid] : -1e20f; - float max_val = blockReduceMax(local_i); - if(tid == 0) - s_max_val = max_val; - __syncthreads(); - - local_i -= s_max_val; - float local_o = tid < length ? 
__expf(local_i) : 0.0f; - float val = blockReduceSum(local_o); - - if(tid == 0) - s_sum = val + 1e-6; - __syncthreads(); - if(tid < length) - logits[tid] = local_o / s_sum; - __syncthreads(); - - if(tid < size_per_head) - { - T sum = (T)0.0f; - for(int ite = 0; ite < length; ++ite) - { - int value_id = bid * seq_len * head_num * size_per_head + ite * head_num * size_per_head - + head_id * size_per_head + tid; - - T value = value_cache[value_id]; - - //for the first step, we should add bias to key memory cache - if(step == 1) - { - value += V_bias[head_id * size_per_head + tid]; - value_cache[value_id] = value; - } - sum += value * logits[ite]; - } - context_buf[bid * head_num * size_per_head + head_id * size_per_head + tid] = sum; - } -} - -template -void cross_attention_dispatch(T* query_buf, const T* Q_bias, - T* key_cache, const T* K_bias, T* value_cache, const T* V_bias, const int* length, - T* context_buf, int batch_size, int head_num, int size_per_head, int step, int seq_len, cudaStream_t stream) - { - const int block_sz = ATTENTION_BLOCK_SIZE; - float scalar = 1.f / sqrtf(size_per_head * 1.0f); - - dim3 grid(batch_size * head_num); - - int cond = size_per_head * ((ATTENION_OPT)? 1:0); - switch (cond) - { - case 32: - cross_attention_kernel_opt<<>>( - query_buf, Q_bias, key_cache, K_bias, value_cache, V_bias, length, context_buf, - batch_size, head_num, step, seq_len, scalar); - break; - case 64: - cross_attention_kernel_opt<<>>( - query_buf, Q_bias, key_cache, K_bias, value_cache, V_bias, length, context_buf, - batch_size, head_num, step, seq_len, scalar); - break; - case 128: - cross_attention_kernel_opt<<>>( - query_buf, Q_bias, key_cache, K_bias, value_cache, V_bias, length, context_buf, - batch_size, head_num, step, seq_len, scalar); - break; - default: - // default path - - int block_size = 128; - - if(seq_len <= 64) - block_size = 64; - else if(seq_len <= 128 && seq_len > size_per_head) - block_size = 128; - else if(seq_len > 128 && seq_len <= 256) - block_size = 256; - else if(seq_len > 256 && seq_len <= 512) - block_size = 512; - else - block_size = 1024; - - if(block_size < size_per_head) - block_size = size_per_head; - - assert(block_size <= 1024); - dim3 block(block_size); - - int shared_size = sizeof(T) * (size_per_head + seq_len); - cross_attention_kernel<<>>( - query_buf, Q_bias, - key_cache, K_bias, - value_cache, V_bias, - length, context_buf, - batch_size, - head_num, size_per_head, step, seq_len, scalar); - } - } - -/* attention with source sentence */ -template -void OpenDecoder::cross_multi_head_attention( - const DataType_* from_tensor, - const DataType_* memory_tensor, - DataType_* key_mem_cache, - DataType_* value_mem_cache, - DataType_* decoder_output, - const int* length, - const int seq_len, - const int step) -{ - int m = batch_size_; - int n = hidden_units_; - int k = hidden_units_; - - DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; - - //reuse the query_buf - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.cross_attention.query_weight.kernel, AType_, n, - from_tensor, BType_, k, - &beta, - query_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); - - if(step == 1) - { - m *= seq_len; - k = memory_hidden_units_; - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.cross_attention.key_weight.kernel, AType_, n, - memory_tensor, BType_, k, - &beta, - key_mem_cache, CType_, n, - computeType_, - 
static_cast(cublasAlgo_[1]))); - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.cross_attention.value_weight.kernel, AType_, n, - memory_tensor, BType_, k, - &beta, - value_mem_cache, CType_, n, - computeType_, - static_cast(cublasAlgo_[1]))); - k = hidden_units_; - } - - cross_attention_dispatch( - query_buf_, param_.cross_attention.query_weight.bias, - key_mem_cache, param_.cross_attention.key_weight.bias, - value_mem_cache, param_.cross_attention.value_weight.bias, - length, context_buf_, batch_size_, - head_num_, size_per_head_, step, seq_len, param_.stream); - - m = batch_size_; - n = head_num_ * size_per_head_; - k = n; - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - param_.cross_attention.attention_output_weight.kernel, AType_, n, - context_buf_, BType_, k, - &beta, - decoder_output, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); -} - -template -__global__ -void decoder_norm1_kernel_generalize(const T* __restrict input, - const T* __restrict gamma, - const T* __restrict beta, - T* output, - int m, int n) -{ - const int tid = threadIdx.x; - - __shared__ float s_mean; - __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - - float local_sum = 0.0f; - for(int i = tid; i < n; i+= blockDim.x) - { - local_sum += (float)(__ldg(&input[blockIdx.x * n + i])); - } - - mean = blockReduceSum(local_sum); - - if(threadIdx.x == 0) - s_mean = mean / n; - __syncthreads(); - - float local_var_sum = 0.0f; - for(int i = tid; i < n; i+= blockDim.x) - { - float diff = (float)(__ldg(&input[blockIdx.x * n + i])) - s_mean; - local_var_sum += diff * diff; - } - variance = blockReduceSum(local_var_sum); - - if(threadIdx.x == 0) - s_variance = rsqrtf(variance / n + 1e-6); - - __syncthreads(); - - for(int i = tid; i < n; i+= blockDim.x) - { - output[blockIdx.x * n + i] = - (T)((( (float)input[blockIdx.x * n + i] - s_mean) * s_variance) * (float)(__ldg(&gamma[i])) + (float)(__ldg(&beta[i]))); - } -} - -template -__global__ -void decoder_norm1_kernel(const T* __restrict input, - const T* __restrict gamma, - const T* __restrict beta, - T* output, - int m, int n) -{ - int tid = threadIdx.x; - - __shared__ float s_mean; - __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - - float local_out = tid < n ? (float)(__ldg(&input[blockIdx.x * n + tid])) : 0.0f; - - mean = blockReduceSum(local_out); - - if(threadIdx.x == 0) - s_mean = mean / n; - __syncthreads(); - - variance = blockReduceSum(tid < n ? 
(local_out - s_mean) * (local_out - s_mean) : 0.0f); - - if(threadIdx.x == 0) - s_variance = rsqrtf(variance / n + 1e-6); - - __syncthreads(); - - if(tid < n) - output[blockIdx.x * n + tid] = - (T)(((local_out - s_mean) * s_variance) * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); -} - -template <> -__global__ -void decoder_norm1_kernel(const half* __restrict input, - const half* __restrict gamma, - const half* __restrict beta, - half* output, - int m, int n) -{ - const int tid = threadIdx.x; - __shared__ float s_mean; - __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float2 local_out_fp2; - - const half2* input_ptr = (const half2*)input; - const half2* gamma_ptr = (const half2*)gamma; - const half2* beta_ptr = (const half2*)beta; - half2* output_ptr = (half2*)output; - - float local_out = 0.0f; - int id = blockIdx.x * blockDim.x + tid; - if(tid < blockDim.x) - { - local_out_fp2 = __half22float2(__ldg(&input_ptr[id])); - local_out += local_out_fp2.x; - local_out += local_out_fp2.y; - } - - mean = blockReduceSum(local_out); - if(tid == 0) - s_mean = mean / n; - __syncthreads(); - - variance = blockReduceSum(tid < blockDim.x ? - (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean) + (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean) - : 0.0f); - if(tid == 0) - s_variance = rsqrtf(variance / n + 1e-6); - __syncthreads(); - - if(tid < blockDim.x) - { - float2 gamma_val = __half22float2(__ldg(&gamma_ptr[tid])); - float2 beta_val = __half22float2(__ldg(&beta_ptr[tid])); - local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; - local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; - output_ptr[id] = __float22half2_rn(local_out_fp2); - } -} - -template -__global__ -void decoder_norm2_kernel_generalize(const T* __restrict input, - const T* __restrict gamma, - const T* __restrict beta, - const T* __restrict bias, - T* output, T* norm_output, - int m, int n) -{ - int tid = threadIdx.x; - - __shared__ float s_mean; - __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - - float local_sum = 0.0f; - for(int i = tid; i < n; i+= blockDim.x) - { - float local_out = (float)(__ldg(&input[blockIdx.x * n + i])); - local_out += (float)(output[blockIdx.x * n + i]); - local_out += (float)(__ldg(&bias[i])); - output[blockIdx.x * n + i] = (T)local_out; - local_sum += local_out; - } - - mean = blockReduceSum(local_sum); - - if(threadIdx.x == 0) - s_mean = mean / n; - __syncthreads(); - - float local_var_sum = 0.0f; - for(int i = tid; i < n; i+= blockDim.x) - { - float diff = (float)(__ldg(&output[blockIdx.x * n + i])) - s_mean; - local_var_sum += diff * diff; - } - variance = blockReduceSum(local_var_sum); - - if(threadIdx.x == 0) - s_variance = rsqrtf(variance / n + 1e-6); - __syncthreads(); - - for(int i = tid; i < n; i+= blockDim.x) - { - norm_output[blockIdx.x * n + i] = - (T)((( (float)output[blockIdx.x * n + i] - s_mean) * s_variance) * (float)(__ldg(&gamma[i])) + (float)(__ldg(&beta[i]))); - } -} - -template -__global__ -void decoder_norm2_kernel(const T* __restrict input, - const T* __restrict gamma, - const T* __restrict beta, - const T* __restrict bias, - T* output, T* norm_output, - int m, int n) -{ - int tid = threadIdx.x; - - __shared__ float s_mean; - __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - - float local_out = 0.0f; - if(tid < n) - { - local_out = (float)(__ldg(&input[blockIdx.x * n + tid])); - local_out += 
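The half specializations above process two values per thread through half2 so that every access stays at least 32 bits wide, converting to float2 for the arithmetic and back with __float22half2_rn. A minimal grid-stride example of that round trip, assuming an even length and using my own kernel name:

#include <cuda_fp16.h>

// Elementwise y[i] += bias[i] on fp16 data, two values per thread via half2.
// Assumes n is even and both pointers are at least 4-byte aligned.
__global__ void add_bias_half2(half* y, const half* bias, int n)
{
    half2*       y2 = reinterpret_cast<half2*>(y);
    const half2* b2 = reinterpret_cast<const half2*>(bias);
    const int half_n = n / 2;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < half_n;
         i += blockDim.x * gridDim.x)
    {
        // Convert to fp32, accumulate, convert back with round-to-nearest,
        // mirroring the float2/half2 round trip in the kernels above.
        float2 v = __half22float2(y2[i]);
        float2 b = __half22float2(__ldg(&b2[i]));
        v.x += b.x;
        v.y += b.y;
        y2[i] = __float22half2_rn(v);
    }
}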
(float)(output[blockIdx.x * n + tid]); - local_out += (float)(__ldg(&bias[tid])); - output[blockIdx.x * n + tid] = (T)local_out; - } - - mean = blockReduceSum(local_out); - if(threadIdx.x == 0) - s_mean = mean / n; - __syncthreads(); - - variance = blockReduceSum(tid < n ? (local_out - s_mean) * (local_out - s_mean) : 0.0f); - if(threadIdx.x == 0) - s_variance = rsqrtf(variance / n + 1e-6); - __syncthreads(); - - if(tid < n) - norm_output[blockIdx.x * n + tid] = - (T)((local_out - s_mean) * s_variance * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); -} - -template <> -__global__ -void decoder_norm2_kernel(const half* __restrict input, - const half* __restrict gamma, - const half* __restrict beta, - const half* __restrict bias, - half* output, half* norm_output, - int m, int n) -{ - const int tid = threadIdx.x; - __shared__ float s_mean; - __shared__ float s_variance; - float mean = 0.0f; - float variance = 0.0f; - float2 local_out_fp2; - - const half2* input_ptr = (const half2*)input; - const half2* gamma_ptr = (const half2*)gamma; - const half2* beta_ptr = (const half2*)beta; - const half2* bias_ptr = (const half2*)bias; - half2* output_ptr = (half2*)output; - half2* norm_output_ptr = (half2*)norm_output; - - float local_out = 0.0f; - int id = blockIdx.x * blockDim.x + tid; - if(tid < blockDim.x) - { - output_ptr[id] = __hadd2(__hadd2(output_ptr[id], __ldg(&input_ptr[id])), __ldg(&bias_ptr[tid])); - local_out_fp2 = __half22float2(output_ptr[id]); - local_out += local_out_fp2.x; - local_out += local_out_fp2.y; - } - - mean = blockReduceSum(local_out); - if(tid == 0) - s_mean = mean / n; - __syncthreads(); - - variance = blockReduceSum(tid < blockDim.x ? - (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean) + (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean) - : 0.0f); - if(tid == 0) - s_variance = rsqrtf(variance / n + 1e-6); - __syncthreads(); - - if(tid < blockDim.x) - { - float2 gamma_val = __half22float2(__ldg(&gamma_ptr[tid])); - float2 beta_val = __half22float2(__ldg(&beta_ptr[tid])); - local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; - local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; - norm_output_ptr[id] = __float22half2_rn(local_out_fp2); - } -} - -template -void OpenDecoder::decoder_norm1( - const DataType_* input, - const DataType_* gamma, - const DataType_* beta, - DataType_* output, - int m, int n) -{ - dim3 grid(m); - dim3 block(min(n, 1024)); - - /* For general cases, n is equal to hidden_units, e.g., 512/1024. - Since we have warp shuffle inside the code, block.x % 32 should be 0. - */ - if(n % 32 != 0) - block.x = 1024; - - block.x = block.x / (4 / sizeof(DataType_)); // if using half, only need half of block.x - - /* should pay attention to the rsqrt precision*/ - // assert(block.x <= 1024); - // decoder_norm1_kernel<<>>(input, gamma, beta, output, m, n); - decoder_norm1_kernel_generalize<<>>(input, gamma, beta, output, m, n); // For gpt-3 -} - -template -void OpenDecoder::decoder_norm2( - const DataType_* input, - const DataType_* gamma, - const DataType_* beta, - const DataType_* bias, - DataType_* output, - DataType_* norm_output, - int m, int n) -{ - dim3 grid(m); - dim3 block(min(n, 1024)); - - - /* For general cases, n is equal to hidden_units, e.g., 512/1024. - Since we have warp shuffle inside the code, block.x % 32 should be 0. 
- */ - - if(n % 32 != 0) - block.x = 1024; - - block.x = block.x / (4 / sizeof(DataType_)); // if using half, only need half of block.x - - /* should pay attention to the rsqrt precision*/ - // assert(block.x <= 1024); - // decoder_norm2_kernel<<>>(input, gamma, beta, bias, output, norm_output, m, n); - decoder_norm2_kernel_generalize<<>>(input, gamma, beta, bias, output, norm_output, m, n); // For gpt-3 -} - -template -void OpenDecoder::ffn( - const DataType_* input, - DataType_* ffn_inner, - DataType_* output, - const int m, - const int inner_size, - const int n, - ActivationType activation_type) -{ - int m1 = m, k1 = n, n1 = inner_size; - DataType_ alpha = (DataType_)1.0f; - DataType_ beta = (DataType_)0.0f; - - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n1, m1, k1, - &alpha, - param_.ffn.intermediate_weight.kernel, AType_, n1, - input, BType_, k1, - &beta, - ffn_inner, CType_, n1, - computeType_, - static_cast(cublasAlgo_[2]))); - - // dim3 grid(min(m1, 65536)); - // dim3 block(min(n1 / 4, 1024)); - - // // TODO remove this limitation - // // assert(block.x <= 1024); - - // if(activation_type == ActivationType::RELU) - // add_bias_relu<<>>(ffn_inner, param_.ffn.intermediate_weight.bias, m1, n1); - // else if(activation_type == ActivationType::GELU) - // add_bias_gelu<<>>(ffn_inner, param_.ffn.intermediate_weight.bias, m1, n1); - - dim3 block(min((int)(n1 / 4 / (4 / sizeof(DataType_))), 1024)); - dim3 grid(min(m1 * n1 / block.x, 65536)); - - if(activation_type == ActivationType::RELU) - add_bias_relu<<>>(ffn_inner, param_.ffn.intermediate_weight.bias, m1, n1 / (4 / sizeof(DataType_))); - else if(activation_type == ActivationType::GELU) - add_bias_gelu<<>>(ffn_inner, param_.ffn.intermediate_weight.bias, m1, n1 / (4 / sizeof(DataType_))); - - - int m2 = m, n2 = n, k2 = inner_size; - check_cuda_error(cublasGemmEx(param_.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n2, m2, k2, - &alpha, - param_.ffn.output_weight.kernel, AType_, n2, - ffn_inner, BType_, k2, - &beta, - output, CType_, n2, - computeType_, - static_cast(cublasAlgo_[3]))); -} - -template -__global__ -void add_bias_input_kernel(T* output, const T* input, const T* bias, const int m, const int n) -{ - // original kernel, which only supports cases of n <= 1024. 
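The FFN above is two GEMMs with a fused bias-plus-activation kernel in between; for GPT the activation is GELU. The add_bias_gelu kernel itself is defined elsewhere in the repository, so the block below is only a plausible FP32 sketch using the common tanh approximation of GELU and a grid-stride loop; the kernel name and the unvectorized indexing are my simplifications.

#include <cuda_runtime.h>
#include <math.h>

__device__ __forceinline__ float gelu_approx(float x)
{
    // tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    const float c = 0.7978845608028654f;  // sqrt(2 / pi)
    return 0.5f * x * (1.0f + tanhf(c * (x + 0.044715f * x * x * x)));
}

// out is [m x n] row-major, bias has length n; grid-stride over all m*n elements.
__global__ void add_bias_gelu_f32(float* out, const float* __restrict__ bias,
                                  int m, int n)
{
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < m * n;
         i += blockDim.x * gridDim.x)
    {
        out[i] = gelu_approx(out[i] + __ldg(&bias[i % n]));
    }
}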
- int id = blockIdx.x * n + threadIdx.x; - output[id] = output[id] + input[id] + __ldg(&bias[threadIdx.x]); -} - - -template -__global__ -void add_bias_input_kernel_generalize(T* output, const T* input, const T* bias, const int m, const int n) -{ - // TODO For GPT-3 - // This kernel can run with any block size and grid size - // Since the hidden dimension of GPT-3 would be larger than 1024 - const int bid = blockIdx.x; - const int blocks_per_row = n / blockDim.x; - const int col_index = (bid % blocks_per_row) * blockDim.x + threadIdx.x; - T bias_val = __ldg(&bias[col_index]); - for(int index = bid * blockDim.x + threadIdx.x; index < m * n; index += blockDim.x * gridDim.x) - { - output[index] = output[index] + input[index] + bias_val; - } -} - -template -void OpenDecoder::add_bias_input(DataType_* output, const DataType_* input, const int m, const int n) -{ - dim3 grid(min(m, 65536)); - dim3 block(min(n, 1024)); - - add_bias_input_kernel_generalize<<>>(output, input, param_.ffn.output_weight.bias, m, n); -} - -template void OpenDecoder::masked_multi_head_attention( - const float* from_tensor, - float* key_cache, - float* value_cache, - float* decoder_output, - const int step); - -template void OpenDecoder::masked_multi_head_attention( - const half* from_tensor, - half* key_cache, - half* value_cache, - half* decoder_output, - const int step); - -template void OpenDecoder::cross_multi_head_attention( - const float* from_tensor, - const float* memory_tensor, - float* key_mem_cache, - float* value_mem_cache, - float* decoder_output, - const int* length, - const int max_seq_len, - const int step); - -template void OpenDecoder::cross_multi_head_attention( - const half* from_tensor, - const half* memory_tensor, - half* key_mem_cache, - half* value_mem_cache, - half* decoder_output, - const int* length, - const int max_seq_len, - const int step); - -template void OpenDecoder::ffn( - const float* input, - float* ffn_inner, - float* otuput, - const int m, - const int inner_size, - const int n, - ActivationType activation_type); - -template void OpenDecoder::ffn( - const half* input, - half* ffn_inner, - half* otuput, - const int m, - const int inner_size, - const int n, - ActivationType activation_type); - -template void OpenDecoder::decoder_norm1( - const float* input, - const float* gamma, - const float* beta, - float* output, - int m, int n); - -template void OpenDecoder::decoder_norm1( - const half* input, - const half* gamma, - const half* beta, - half* output, - int m, int n); - -template void OpenDecoder::decoder_norm2( - const float* input, - const float* gamma, - const float* beta, - const float* bias, - float* output, - float* norm_output, - int m, int n); - -template void OpenDecoder::decoder_norm2( - const half* input, - const half* gamma, - const half* beta, - const half* bias, - half* output, - half* norm_output, - int m, int n); - -template void OpenDecoder::add_bias_input( - float* output, - const float* input, - const int m, - const int n); - -template void OpenDecoder::add_bias_input( - half* output, - const half* input, - const int m, - const int n); - -}//namespace FasterTransformer +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Open sourced multi-head attention + **/ +#include +#include + +#include "fastertransformer/open_decoder.h" +#include "cub/cub.cuh" +#include "fastertransformer/utils/nvtx_utils.h" +#include "masked_multihead_attention.h" + +namespace fastertransformer{ + +const int WARP_SIZE = 32; +const bool ATTENION_OPT = true; +const int ATTENTION_BLOCK_SIZE = 256; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Copy_half_t = + typename std::conditional::type + >::type + >::type; + +template +using Copy_t = Copy_half_t; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + masked multi-head attention +*/ +#define FINAL_MASK 0xffffffff +template +__inline__ __device__ +T warpReduceSum(T val) +{ + for(int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); + return val; +} +/* Calculate the sum of all elements in a block */ +template + __inline__ __device__ +T blockReduceSum(T val) +{ + static __shared__ T shared[32]; + // __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if(lane == 0) + shared[wid] = val; + + __syncthreads(); + + val = (threadIdx.x < (blockDim.x >> 5 )) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + + return val; +} + +template + __inline__ __device__ +T warpReduceMax(T val) +{ + for(int mask = 16; mask > 0; mask >>= 1) + val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); + return val; +} +/* Calculate the maximum of all elements in a block */ +template + __inline__ __device__ +T blockReduceMax(T val) +{ + static __shared__ T shared[32]; +// __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + val = warpReduceMax(val); // get maxx in each warp + + if(lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + + __syncthreads(); + + + val = (threadIdx.x < (blockDim.x >> 5 )) ? 
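The Copy_t alias above picks a wider scalar type (half, int, int2 or int4) so that each of the 32 lanes in a warp can move its share of one head vector with a single vectorized load. The exact widths are not visible in this hunk, so the alias below is an assumption that illustrates the idea rather than the repository's definition.

#include <cuda_fp16.h>
#include <type_traits>

// Illustrative only: choose a copy type whose size equals the number of bytes
// one lane moves per load (sizeof(T) * size_per_head / 32). The thresholds
// here are my assumption, not the repository's exact alias.
template <typename T, int size_per_head>
using copy_t = typename std::conditional<
    sizeof(T) * size_per_head / 32 == 2, half,            // 2 bytes per lane
    typename std::conditional<
        sizeof(T) * size_per_head / 32 == 4, int,         // 4 bytes per lane
        typename std::conditional<
            sizeof(T) * size_per_head / 32 == 8, int2,    // 8 bytes per lane
            int4                                          // 16 bytes per lane
        >::type
    >::type
>::type;

// Example: fp32 with size_per_head == 64 gives 8 bytes per lane, i.e. int2.
static_assert(std::is_same<copy_t<float, 64>, int2>::value, "8-byte copies expected");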
shared[lane] : (T)-1e20f; + val = warpReduceMax(val); + + return val; +} + +template +__global__ +void masked_attention_kernel_opt( + T* __restrict key_buf, T* __restrict value_buf, + T* __restrict query_buf, const T* __restrict self_Q_bias, + T* __restrict key_cache, const T* __restrict self_K_bias, + T* __restrict value_cache, const T* __restrict self_V_bias, + T* __restrict context_buf, const bool* finished, + int batch_size, int head_num, const int step, const T scalar) +{ + if(finished != nullptr && finished[blockIdx.x / head_num] == true) return; + typedef Copy_t copy_t; + const int elems_per_thread = size_per_head / WARP_SIZE; + + union Access_t + { + copy_t v; + T x[elems_per_thread]; // supported size 1,2,4 + }; + typedef struct Float_n_t + { + T x[elems_per_thread]; // supported size 1,2,4 + } float_n_t; + + __shared__ float_n_t sq[block_sz]; + + extern __shared__ float logits[]; // use to store the logits from [0~step] + + const int tid = threadIdx.x; + const int warp_num = block_sz / WARP_SIZE; + const int bid = blockIdx.x; + const int head_id = blockIdx.x % head_num; + const int warp_id = tid / WARP_SIZE; // warp_id in block + const int lane_id = tid % WARP_SIZE; // lane_id in warp + + typedef cub::BlockReduce MaxValBlockReduce; + typedef cub::BlockReduce BlockReduce; + __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; + __shared__ typename BlockReduce::TempStorage block_temp_storage; + __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; + + int qkv_id = bid * size_per_head; + int qkv_bias_id = head_id * size_per_head; + + query_buf = &query_buf[qkv_id]; + key_buf = &key_buf[qkv_id]; + value_buf = &value_buf[qkv_id]; + self_K_bias = &self_K_bias[qkv_bias_id]; + key_cache = &key_cache[qkv_id]; + self_Q_bias = &self_Q_bias[qkv_bias_id]; + self_V_bias = &self_V_bias[qkv_bias_id]; + value_cache = &value_cache[qkv_id]; + context_buf = &context_buf[qkv_id]; + + Access_t bias_r, query_buf_r; + Access_t key_val_r, key_buf_r; + Access_t value_val_r, value_buf_r; + + // each warp will have its own copy of sq + query_buf_r.v = *((copy_t *)query_buf + lane_id); + key_buf_r.v = *((copy_t *)key_buf + lane_id); + bias_r.v = *((copy_t *)self_Q_bias + lane_id); + float qb_r[elems_per_thread]; + for (int i = 0; i < elems_per_thread; ++i) + { + qb_r[i] = (float)query_buf_r.x[i] + (float)bias_r.x[i]; + } + + //offset for each step + int offset = batch_size * head_num * size_per_head; + bias_r.v = *((copy_t *) self_K_bias + lane_id); + for(int ite = warp_id; ite < step; ite += warp_num) + { + key_val_r.v = *((copy_t *)&key_cache[ite * offset] + lane_id); + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1) + { + for (int i = 0; i < elems_per_thread; i++) + { + key_val_r.x[i] = (float)key_buf_r.x[i] + (float)bias_r.x[i]; + } + *((copy_t *)&key_cache[ite * offset] + lane_id) = key_val_r.v; + } + float val = 0.f; + for (int i = 0; i < elems_per_thread; i++) + { + val = val + (float)key_val_r.x[i] * qb_r[i] * (float)scalar; + } + float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val); + if (lane_id == 0) + { + logits[ite] = qk; + } + } + __syncthreads(); + + __shared__ float s_max_val, s_sum; + + float local_i = -1e20f; + for(int i = tid; i < step; i += blockDim.x) + local_i = max(local_i, logits[i]); + + float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max()); + if(tid == 0) + s_max_val = max_val; + __syncthreads(); + + + float local_o = 0.0f; + for(int i = tid; i < step; i += 
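Inside masked_attention_kernel_opt, each lane loads one copy_t chunk and then works on it as elems_per_thread scalars through the Access_t union. A stripped-down illustration of that access pattern for FP32 with four elements per lane (one 128-element row per warp, pointers assumed 16-byte aligned, kernel name mine):

#include <cuda_runtime.h>

// Launch with one warp per 128-element row: every lane loads a float4
// (16 bytes) and can still address the four scalars through the union.
__global__ void warp_vector_copy(const float* __restrict__ src, float* dst)
{
    union Access {
        float4 v;
        float  x[4];
    };

    const int lane_id = threadIdx.x % 32;
    const int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    const float* row_src = src + warp_id * 128;
    float*       row_dst = dst + warp_id * 128;

    Access r;
    r.v = *(reinterpret_cast<const float4*>(row_src) + lane_id);
    for (int i = 0; i < 4; ++i)
        r.x[i] *= 2.0f;                  // scalar work on the in-register values
    *(reinterpret_cast<float4*>(row_dst) + lane_id) = r.v;
}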
blockDim.x) + { + logits[i] = __expf(logits[i] - s_max_val); + local_o += logits[i]; + } + float val = BlockReduce(block_temp_storage).Sum(local_o); + + if(tid == 0) + s_sum = val + 1e-6; + __syncthreads(); + + float s_sum_inverse = __fdividef(1.0f, s_sum); + for(int i = tid; i < step; i += blockDim.x) + { + logits[i] = logits[i] * s_sum_inverse; + } + __syncthreads(); + + // This optimization introduces discrepancy because of different order in FP32 summation + float sum_r[elems_per_thread] = {0.f}; + bias_r.v = *((copy_t *) self_V_bias + lane_id); + value_buf_r.v = *((copy_t *)value_buf + lane_id); + + for(int ite = warp_id; ite < step; ite += warp_num) + { + value_val_r.v = *((copy_t *)&value_cache[ite * offset] + lane_id); + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1) + { + for (int i = 0; i < elems_per_thread; i++) + { + value_val_r.x[i] = (float)value_buf_r.x[i] + (float)bias_r.x[i]; + } + *((copy_t *)&value_cache[ite * offset] + lane_id) = value_val_r.v; + } + for (int i = 0; i < elems_per_thread; ++i) + { + sum_r[i] += (float)value_val_r.x[i] * logits[ite]; + } + } + for (int i = 0; i < elems_per_thread; i++) + { + sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i]; + } + __syncthreads(); + if (warp_id == 0) + { + #pragma unroll + for (int j = 1; j < warp_num; j++) + { + for (int i = 0; i < elems_per_thread; ++i) + { + sum_r[i] = sum_r[i] + (float)sq[j * WARP_SIZE + tid].x[i]; + } + } + } + __syncthreads(); + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) + { + value_val_r.x[i] = sum_r[i]; + } + if (warp_id == 0) + { + *((copy_t *)context_buf + lane_id) = value_val_r.v; + } +} + +template +__global__ +void masked_attention_kernel( + T* key_buf, T* value_buf, + T* query_buf, const T* self_Q_bias, + T* key_cache, const T* self_K_bias, T* value_cache, const T* self_V_bias, + T* context_buf, const bool* finished, + int batch_size, int head_num, int size_per_head, const int step, const T scalar) +{ + if(finished != nullptr && finished[blockIdx.x / head_num] == true) return; + extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; + T* sq = reinterpret_cast(s_buf); + T* logits = reinterpret_cast(&sq[size_per_head]); + + int tid = threadIdx.x; + int bid = blockIdx.x / head_num; + int head_id = blockIdx.x % head_num; + + int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; + int qkv_bias_id = head_id * size_per_head + tid; + + if(tid < size_per_head) + sq[tid] = query_buf[qkv_id] + self_Q_bias[qkv_bias_id]; + __syncthreads(); + + //offset for each step + int offset = batch_size * head_num * size_per_head; + for(int ite = 0; ite < step; ++ite) + { + T key = tid < size_per_head ? key_cache[ite * offset + qkv_id] : (T)0.0f; + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1 && tid < size_per_head) + { + key = key_buf[qkv_id] + self_K_bias[qkv_bias_id]; + key_cache[ite * offset + qkv_id] = key; + } + + T val = (tid < size_per_head) ? key * sq[tid] * scalar : (T)(0.0f); + T qk = blockReduceSum(val); + if(threadIdx.x == 0) + logits[ite] = qk; + __syncthreads(); //try to remove + } + __syncthreads(); //try to remove + + __shared__ float s_max_val, s_sum; + float local_i = tid < step ? (float)logits[tid] : -1e20f; + float max_val = blockReduceMax(local_i); + if(tid == 0) + s_max_val = max_val; + __syncthreads(); + + local_i -= s_max_val; + float local_o = tid < step ? 
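The attention logits are normalized with the usual max-subtraction trick: subtract the block-wide maximum before __expf so nothing overflows, then divide by the sum plus a small epsilon. The host-side reference below restates that computation; the function name is mine.

#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax: shift by the maximum before exponentiating,
// then normalize by the sum plus a small epsilon, as the kernels above do.
std::vector<float> stable_softmax(const std::vector<float>& logits)
{
    float max_val = -1e20f;
    for (float v : logits) max_val = std::max(max_val, v);

    std::vector<float> probs(logits.size());
    float sum = 0.f;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_val);
        sum += probs[i];
    }
    const float inv = 1.f / (sum + 1e-6f);
    for (float& p : probs) p *= inv;
    return probs;
}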
__expf(local_i) : 0.0f; + float val = blockReduceSum(local_o); + + if(tid == 0) + s_sum = val + 1e-6; + __syncthreads(); + + if(tid < step) + logits[tid] = local_o / s_sum; + __syncthreads(); + + if(tid < size_per_head) + { + T sum = (T)0.0f; + for(int ite = 0; ite < step; ++ite) + { + T value = value_cache[ite * offset + qkv_id]; + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1) + { + value = value_buf[qkv_id] + self_V_bias[qkv_bias_id]; + value_cache[ite * offset + qkv_id] = value; + } + sum += value * logits[ite]; + } + context_buf[qkv_id] = sum; + } +} + +template +void masked_attention_dispatch( + T* key_buf, T* value_buf, + T* query_buf, const T* self_Q_bias, + T* key_cache, const T* self_K_bias, T* value_cache, const T* self_V_bias, + T* context_buf, const bool* finished, int max_batch_size, int inference_batch_size, + int head_num, int size_per_head, const int step, const int max_seq_len, cudaStream_t stream) +{ + if (max_seq_len < 0) { + const int block_sz = ATTENTION_BLOCK_SIZE; + T scalar = (T)(1.f / sqrtf(size_per_head * 1.0f)); + + dim3 grid(inference_batch_size * head_num); + + int cond = size_per_head * ((ATTENION_OPT)? 1:0); + switch (cond) + { + case 32: + masked_attention_kernel_opt<32, block_sz, T><<>>( + key_buf, value_buf, + query_buf, self_Q_bias, key_cache, self_K_bias, value_cache, self_V_bias, context_buf, finished, + max_batch_size, head_num, step, scalar); + break; + case 64: + masked_attention_kernel_opt<64, block_sz, T><<>>( + key_buf, value_buf, + query_buf, self_Q_bias, + key_cache, self_K_bias, + value_cache, self_V_bias, + context_buf, + finished, + max_batch_size, head_num, step, scalar); + break; + case 128: + masked_attention_kernel_opt<128, block_sz, T><<>>( + key_buf, value_buf, + query_buf, self_Q_bias, key_cache, self_K_bias, value_cache, self_V_bias, context_buf, finished, + max_batch_size, head_num, step, scalar); + break; + default: + // default path + int block_size = 128; + + //suppose size_per_head <= 128 + if(step <= 64) + block_size = 64; + else if(step <= 128 && step > size_per_head) + block_size = 128; + else if(step > 128 && step <= 256) + block_size = 256; + else if(step > 256 && step <= 512) + block_size = 512; + else + block_size = 1024; + + if((int)block_size < size_per_head) + block_size = size_per_head; + + assert(block_size <= 1024); + dim3 block(block_size); + T scalar = 1 / sqrtf(size_per_head * 1.0f); + + + int shared_size = sizeof(T) * (size_per_head + step); + masked_attention_kernel<<>>( + key_buf, value_buf, + query_buf, self_Q_bias, + key_cache, self_K_bias, + value_cache, self_V_bias, + context_buf, finished, max_batch_size, + head_num, size_per_head, step, scalar); + } + } + else { + assert(step > 0); + assert(size_per_head == 32 || size_per_head == 64 || size_per_head == 128); + using DataType = typename std::conditional::type; + // Prepare the parameters. + Masked_multihead_attention_params params; + memset(¶ms, 0, sizeof(params)); + params.q_bias = reinterpret_cast(self_Q_bias); + params.k_bias = reinterpret_cast(self_K_bias); + params.v_bias = reinterpret_cast(self_V_bias); + + // Set the output buffer. + params.out = reinterpret_cast(context_buf); + + // Set the input buffers. 
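In the generic (non-optimized) path the launch block size grows with the decoding step, since each thread owns one cached position, but it is never allowed to drop below size_per_head and is capped at the 1024-thread CUDA limit. The host helper below restates that heuristic exactly as it appears above; only the function name is mine.

#include <algorithm>
#include <cassert>

// Pick a block size for the generic masked-attention kernel.
int pick_block_size(int step, int size_per_head)
{
    int block_size = 128;
    if (step <= 64)                               block_size = 64;
    else if (step <= 128 && step > size_per_head) block_size = 128;
    else if (step > 128 && step <= 256)           block_size = 256;
    else if (step > 256 && step <= 512)           block_size = 512;
    else                                          block_size = 1024;

    block_size = std::max(block_size, size_per_head);
    assert(block_size <= 1024);
    return block_size;
}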
+ params.q = reinterpret_cast(query_buf); + params.k = reinterpret_cast(key_buf); + params.v = reinterpret_cast(value_buf); + params.stride = 0; + params.finished = const_cast(finished); + + params.k_cache = reinterpret_cast(key_cache); + params.v_cache = reinterpret_cast(value_cache); + params.batch_size = inference_batch_size; + params.seq_length = max_seq_len; + params.timestep = step-1; + params.num_heads = head_num; + params.hidden_size_per_head = size_per_head; + params.inv_sqrt_dh = 1.F / sqrtf((float) params.hidden_size_per_head); + + masked_multihead_attention(params, stream); + } +} + +template void masked_attention_dispatch( + float* key_buf, + float* value_buf, + float* query_buf, + const float* self_Q_bias, + float* key_cache, + const float* self_K_bias, + float* value_cache, + const float* self_V_bias, + float* context_buf, + const bool* finished, + int max_batch_size, + int inference_batch_size, + int head_num, + int size_per_head, + const int step, + const int max_seq_size, + cudaStream_t stream); + +template void masked_attention_dispatch( + half* key_buf, + half* value_buf, + half* query_buf, + const half* self_Q_bias, + half* key_cache, + const half* self_K_bias, + half* value_cache, + const half* self_V_bias, + half* context_buf, + const bool* finished, + int max_batch_size, + int inference_batch_size, + int head_num, + int size_per_head, + const int step, + const int max_seq_size, + cudaStream_t stream); + +template +__global__ +void fusedQKV_masked_attention_kernel_opt( + const T* __restrict qkv_buf, const T* __restrict qkv_bias, + T* __restrict key_cache, + T* __restrict value_cache, + T* __restrict context_buf, const bool* finished, int batch_size, int head_num, const int step, const T scalar) +{ + if(finished != nullptr && finished[blockIdx.x / head_num] == true) return; + typedef Copy_t copy_t; + const int elems_per_thread = size_per_head / WARP_SIZE; + + union Access_t + { + copy_t v; + T x[elems_per_thread]; // supported size 1,2,4 + }; + typedef struct Float_n_t + { + T x[elems_per_thread]; // supported size 1,2,4 + } float_n_t; + + __shared__ float_n_t sq[block_sz]; + + extern __shared__ float logits[]; // use to store the logits from [0~step] + + const int tid = threadIdx.x; + const int warp_num = block_sz / WARP_SIZE; + const int bid = blockIdx.x; + const int head_id = blockIdx.x % head_num; + const int warp_id = tid / WARP_SIZE; // warp_id in block + const int lane_id = tid % WARP_SIZE; // lane_id in warp + const int batch_id = bid / head_num; + const int hidden_units = head_num * size_per_head; + + typedef cub::BlockReduce MaxValBlockReduce; + typedef cub::BlockReduce BlockReduce; + __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; + __shared__ typename BlockReduce::TempStorage block_temp_storage; + __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; + + int qkv_id = batch_id * 3 * hidden_units + head_id * size_per_head; + int qkv_bias_id = head_id * size_per_head; + int cache_qkv_id = bid * size_per_head; + + const T* query_buf = qkv_buf + qkv_id; + const T* key_buf = qkv_buf + hidden_units + qkv_id; + const T* value_buf = qkv_buf + 2 * hidden_units + qkv_id; + const T* self_Q_bias = qkv_bias + qkv_bias_id; + const T* self_K_bias = qkv_bias + hidden_units + qkv_bias_id; + const T* self_V_bias = qkv_bias + 2 * hidden_units + qkv_bias_id; + value_cache = value_cache + cache_qkv_id; + key_cache = key_cache + cache_qkv_id; + context_buf = context_buf + cache_qkv_id; + + Access_t bias_r, query_buf_r; + Access_t 
key_val_r, key_buf_r; + Access_t value_val_r, value_buf_r; + + // each warp will have its own copy of sq + query_buf_r.v = *((copy_t *)query_buf + lane_id); + key_buf_r.v = *((copy_t *)key_buf + lane_id); + bias_r.v = *((copy_t *)self_Q_bias + lane_id); + float qb_r[elems_per_thread]; + for (int i = 0; i < elems_per_thread; ++i) + { + qb_r[i] = (float)query_buf_r.x[i] + (float)bias_r.x[i]; + } + + //offset for each step + int offset = batch_size * hidden_units; + bias_r.v = *((copy_t *) self_K_bias + lane_id); + for(int ite = warp_id; ite < step; ite += warp_num) + { + key_val_r.v = *((copy_t *)&key_cache[ite * offset] + lane_id); + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1) + { + for (int i = 0; i < elems_per_thread; i++) + { + key_val_r.x[i] = (float)key_buf_r.x[i] + (float)bias_r.x[i]; + } + *((copy_t *)&key_cache[ite * offset] + lane_id) = key_val_r.v; + } + float val = 0.f; + for (int i = 0; i < elems_per_thread; i++) + { + val = val + (float)key_val_r.x[i] * qb_r[i] * (float)scalar; + } + float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val); + if (lane_id == 0) + { + logits[ite] = qk; + } + } + __syncthreads(); + + __shared__ float s_max_val, s_sum; + + float local_i = -1e20f; + for(int i = tid; i < step; i += blockDim.x) + local_i = max(local_i, logits[i]); + + float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max()); + if(tid == 0) + s_max_val = max_val; + __syncthreads(); + + + float local_o = 0.0f; + for(int i = tid; i < step; i += blockDim.x) + { + logits[i] = __expf(logits[i] - s_max_val); + local_o += logits[i]; + } + float val = BlockReduce(block_temp_storage).Sum(local_o); + + if(tid == 0) + s_sum = val + 1e-6; + __syncthreads(); + + float s_sum_inverse = __fdividef(1.0f, s_sum); + for(int i = tid; i < step; i += blockDim.x) + { + logits[i] = logits[i] * s_sum_inverse; + } + __syncthreads(); + + // This optimization introduces discrepancy because of different order in FP32 summation + float sum_r[elems_per_thread] = {0.f}; + bias_r.v = *((copy_t *) self_V_bias + lane_id); + value_buf_r.v = *((copy_t *)value_buf + lane_id); + + for(int ite = warp_id; ite < step; ite += warp_num) + { + value_val_r.v = *((copy_t *)&value_cache[ite * offset] + lane_id); + //for the last step, we should update K + bias_K to the cache + if(ite == step - 1) + { + for (int i = 0; i < elems_per_thread; i++) + { + value_val_r.x[i] = (float)value_buf_r.x[i] + (float)bias_r.x[i]; + } + *((copy_t *)&value_cache[ite * offset] + lane_id) = value_val_r.v; + } + for (int i = 0; i < elems_per_thread; ++i) + { + sum_r[i] += (float)value_val_r.x[i] * logits[ite]; + } + } + for (int i = 0; i < elems_per_thread; i++) + { + sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i]; + } + __syncthreads(); + if (warp_id == 0) + { + #pragma unroll + for (int j = 1; j < warp_num; j++) + { + for (int i = 0; i < elems_per_thread; ++i) + { + sum_r[i] = sum_r[i] + (float)sq[j * WARP_SIZE + tid].x[i]; + } + } + } + __syncthreads(); + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) + { + value_val_r.x[i] = sum_r[i]; + } + if (warp_id == 0) + { + *((copy_t *)context_buf + lane_id) = value_val_r.v; + } +} + +template +void fusedQKV_masked_attention_dispatch( + const T* qkv_buf, const T* qkv_bias, + T* key_cache, T* value_cache, + T* context_buf, const bool* finished, int max_batch_size, int inference_batch_size, + int head_num, int size_per_head, const int step, const int max_seq_len, cudaStream_t stream) +{ + if (max_seq_len < 0) { + const 
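The fused path assumes the QKV projection writes one packed tensor of shape [batch, 3, head_num, size_per_head], so the query, key and value rows for a given (batch, head) pair are fixed offsets into the same buffer and params.stride is 3 * hidden_units. A small indexing helper that restates that layout (the helper name is mine):

#include <cstddef>

// Packed QKV layout produced by a single GEMM: [batch, 3, head_num, size_per_head].
// qkv_id 0/1/2 selects query/key/value for the given (batch, head) pair.
inline size_t packed_qkv_offset(int batch_id, int qkv_id, int head_id,
                                int head_num, int size_per_head)
{
    const size_t hidden_units = (size_t)head_num * size_per_head;
    return (size_t)batch_id * 3 * hidden_units   // one 3*hidden row per token
         + (size_t)qkv_id * hidden_units         // 0: Q, 1: K, 2: V
         + (size_t)head_id * size_per_head;      // start of this head's vector
}

For example, the key vector for batch b and head h starts at qkv_buf + packed_qkv_offset(b, 1, h, head_num, size_per_head), which is the same arithmetic as key_buf = qkv_buf + hidden_units + qkv_id above.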
int block_sz = ATTENTION_BLOCK_SIZE; + T scalar = (T)(1.f / sqrtf(size_per_head * 1.0f)); + + dim3 grid(inference_batch_size * head_num); + + int cond = size_per_head * ((ATTENION_OPT)? 1:0); + switch (cond) + { + case 32: + fusedQKV_masked_attention_kernel_opt<32, block_sz, T><<>>( + qkv_buf, qkv_bias, + key_cache, value_cache, + context_buf, + finished, + max_batch_size, head_num, step, scalar); + break; + case 64: + fusedQKV_masked_attention_kernel_opt<64, block_sz, T><<>>( + qkv_buf, qkv_bias, + key_cache, + value_cache, + context_buf, + finished, + max_batch_size, head_num, step, scalar); + break; + case 128: + fusedQKV_masked_attention_kernel_opt<128, block_sz, T><<>>( + qkv_buf, qkv_bias, + key_cache, + value_cache, + context_buf, + finished, + max_batch_size, head_num, step, scalar); + break; + default: + assert(false); + } + } + else { + using DataType = typename std::conditional::type; + // Prepare the parameters. + Masked_multihead_attention_params params; + memset(¶ms, 0, sizeof(params)); + int hidden_units = head_num * size_per_head; + params.q_bias = reinterpret_cast(qkv_bias); + params.k_bias = reinterpret_cast(qkv_bias) + hidden_units; + params.v_bias = reinterpret_cast(qkv_bias) + 2 * hidden_units; + + // Set the output buffer. + params.out = reinterpret_cast(context_buf); + + // Set the input buffers. + params.q = reinterpret_cast(qkv_buf); + params.k = reinterpret_cast(qkv_buf) + hidden_units; + params.v = reinterpret_cast(qkv_buf) + 2 * hidden_units; + params.stride = 3 * hidden_units; + params.finished = const_cast(finished); + + params.k_cache = reinterpret_cast(key_cache); + params.v_cache = reinterpret_cast(value_cache); + params.batch_size = inference_batch_size; + params.seq_length = max_seq_len; + params.timestep = step-1; + params.num_heads = head_num; + params.hidden_size_per_head = size_per_head; + params.inv_sqrt_dh = 1.F / sqrtf((float) params.hidden_size_per_head); + + masked_multihead_attention(params, stream); + } +} + +template void fusedQKV_masked_attention_dispatch( + const float* qkv_buf, + const float* qkv_bias, + float* key_cache, + float* value_cache, + float* context_buf, + const bool* finished, + int max_batch_size, + int inference_batch_size, + int head_num, + int size_per_head, + const int step, + const int max_seq_len, + cudaStream_t stream); + +template void fusedQKV_masked_attention_dispatch( + const half* qkv_buf, + const half* qkv_bias, + half* key_cache, + half* value_cache, + half* context_buf, + const bool* finished, + int max_batch_size, + int inference_batch_size, + int head_num, + int size_per_head, + const int step, + const int max_seq_len, + cudaStream_t stream); + +template +void fusedQKV_masked_attention_kernelLauncher( + const T* qkv_buf, + const T* qkv_bias, + T* k_cache, + T* v_cache, + T* output, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int max_seq_len, + cudaStream_t stream) +{ + fusedQKV_masked_attention_dispatch(qkv_buf, + qkv_bias, + k_cache, + v_cache, + output, + nullptr, + batch_size, + batch_size, + head_num, + size_per_head, + seq_len, + max_seq_len, + stream); +} + +template +__global__ void transpose_4d(T* dst, T* src, + const int dim0, + const int dim1, + const int dim2, + const int dim3, + const int dim0_leading_dim, + const int ite) +{ + // transpose from [dim0, dim1, dim2, dim3] to [dim2, X, dim1, dim3] + // where the dimension of X is dim0_leading_dim, and offset is ite * dim0 + for(int i = threadIdx.x + blockIdx.x * blockDim.x; i < dim0 * dim1 * 
dim2 * dim3; i+= blockDim.x * gridDim.x) + { + int index = i; + const int d3 = index % dim3; + index = (index - d3) / dim3; + const int d2 = index % dim2; + index = (index - d2) / dim2; + const int d1 = index % dim1; + index = (index - d1) / dim1; + const int d0 = index % dim0; + index = (index - d0) / dim0; + dst[d2 * dim0_leading_dim * dim1 * dim3 + (d0 + dim0 * ite) * dim1 * dim3 + d1 * dim3 + d3] = src[i]; + } +} + +template<> +__global__ void transpose_4d(half* dst, half* src, + const int dim0, + const int dim1, + const int dim2, + const int dim3, + const int dim0_leading_dim, + const int ite) +{ + half2 *dst_ptr = (half2 *) dst; + half2 *src_ptr = (half2 *) src; + const int half_dim3 = dim3 / 2; + // transpose from [dim0, dim1, dim2, half_dim3] to [dim2, dim0, dim1, half_dim3] + // where the dimension of X is dim0_leading_dim, and offset is ite * dim0 + for(int i = threadIdx.x + blockIdx.x * blockDim.x; i < dim0 * dim1 * dim2 * half_dim3; i+= blockDim.x * gridDim.x) + { + int index = i; + const int d3 = index % half_dim3; + index = (index - d3) / half_dim3; + const int d2 = index % dim2; + index = (index - d2) / dim2; + const int d1 = index % dim1; + index = (index - d1) / dim1; + const int d0 = index % dim0; + index = (index - d0) / dim0; + dst_ptr[d2 * dim0_leading_dim * dim1 * half_dim3 + (d0 + dim0 * ite) * dim1 * half_dim3 + d1 * half_dim3 + d3] = src_ptr[i]; + } +} + +template +void transpose_4d_kernelLauncher(T* dst, T* src, + const int local_batch_size, + const int seq_len, + const int size_per_head, + const int local_hidden_units, + const int local_head_num, + const int batch_size, + const int ite, + cudaStream_t stream) +{ + transpose_4d<<>>( + dst, src, + local_batch_size, local_head_num, + seq_len, size_per_head, batch_size, ite); +} + +template void transpose_4d_kernelLauncher( + float* dst, + float* src, + const int local_batch_size, + const int seq_len, + const int size_per_head, + const int local_hidden_units, + const int local_head_num, + const int batch_size, + const int ite, + cudaStream_t stream); + +template void transpose_4d_kernelLauncher( + half* dst, + half* src, + const int local_batch_size, + const int seq_len, + const int size_per_head, + const int local_hidden_units, + const int local_head_num, + const int batch_size, + const int ite, + cudaStream_t stream); + +#define NEW_TRANSPOSE_BATCH_MAJOR 1 + +template +__global__ void transpose_4d_batch_major_k_cache(T* k_dst, const T* k_src, + const int head_num, + const int size_per_head, + const int seq_len, + const int max_seq_len) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4)? 
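transpose_4d above walks a flat index over [dim0, dim1, dim2, dim3], peels the coordinates off from the fastest dimension outward, and recomposes them for the permuted [dim2, dim0, dim1, dim3] layout. The helper below restates that index arithmetic for the ite = 0, dim0_leading_dim = dim0 case (the extra offset in the kernel handles tensor-parallel partial batches); the function name is mine.

#include <cstddef>

// Flat source index over [dim0, dim1, dim2, dim3] (dim3 fastest) mapped to the
// flat destination index in the permuted [dim2, dim0, dim1, dim3] layout.
inline size_t permuted_index(size_t i, int dim0, int dim1, int dim2, int dim3)
{
    const int d3 = i % dim3;  i /= dim3;
    const int d2 = i % dim2;  i /= dim2;
    const int d1 = i % dim1;  i /= dim1;
    const int d0 = i % dim0;
    return (((size_t)d2 * dim0 + d0) * dim1 + d1) * dim3 + d3;
}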
4 : 8; + + auto key_src = reinterpret_cast(k_src + batch_id * head_num * size_per_head * seq_len + head_id * size_per_head * seq_len); + auto key_dst = reinterpret_cast(k_dst + batch_id * head_num * size_per_head * max_seq_len + head_id * size_per_head * max_seq_len); + + const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; + int size_per_head_div_x = size_per_head / X_ELEMS; + if (out_idx >= head_num * size_per_head_div_x * max_seq_len) return; + + int idx = out_idx; + const int k_seq_len_id = idx % max_seq_len; + idx = (idx - k_seq_len_id) / max_seq_len; + const int k_head_size_id = idx % size_per_head_div_x; + + if (k_seq_len_id < seq_len) + key_dst[out_idx] = key_src[k_seq_len_id * size_per_head_div_x + k_head_size_id]; +} + +template +__global__ void transpose_4d_batch_major_v_cache(T* v_dst, const T* v_src, + const int head_num, + const int size_per_head, + const int seq_len, + const int max_seq_len) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + + // 16 byte loads will handle "x" dimension + auto val_src = reinterpret_cast(v_src + batch_id * head_num * size_per_head * seq_len + head_id * size_per_head * seq_len); + auto val_dst = reinterpret_cast(v_dst + batch_id * head_num * size_per_head * max_seq_len + head_id * size_per_head * max_seq_len); + + // idx is over output dimension L * size_per_head / x for values + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + + constexpr int X_ELEMS = (sizeof(T) == 4)? 4 : 8; + const int size_per_head_div_x = size_per_head / X_ELEMS; + + if (idx >= size_per_head_div_x * seq_len) return; + + val_dst[idx] = val_src[idx]; +} + +template +__global__ void transpose_4d_batch_major(T* k_dst, T* v_dst, + const T* k_src, const T* v_src, + const int head_num, + const int size_per_head, + const int seq_len, + const int max_seq_len) +{ + const int hidden_dim = head_num * size_per_head; + const int x = (sizeof(T) == 4)? 4 : 8; + const int size_per_head_split = size_per_head / x; + const int batch_id = blockIdx.x; + const int seq_id = blockIdx.y; + + for(int id = threadIdx.x; id < head_num * size_per_head_split * x; id += blockDim.x) + { + int tmp_id = id; + int x_id = tmp_id % x; + tmp_id = (tmp_id - x_id) / x; + int size_id = tmp_id % size_per_head_split; + tmp_id = (tmp_id - size_id) / size_per_head_split; + int head_id = tmp_id % head_num; + + // key: [B, head_num, L, size_per_head / x, x] -> [B, head_num, size_per_head / x, L, x] + k_dst[batch_id * hidden_dim * max_seq_len + head_id * size_per_head * max_seq_len + size_id * max_seq_len * x + seq_id * x + x_id] = + k_src[batch_id * hidden_dim * seq_len + head_id * size_per_head * seq_len + seq_id * size_per_head + size_id * x + x_id]; + + // value: [B, head_num, L, size_per_head / x, x] -> [B, head_num, L, size_per_head/x, x] + v_dst[batch_id * hidden_dim * max_seq_len + head_id * size_per_head * max_seq_len + seq_id * size_per_head + size_id * x + x_id] = + v_src[batch_id * hidden_dim * seq_len + head_id * size_per_head * seq_len + seq_id * size_per_head + size_id * x + x_id]; + } +} + +template +void transpose_4d_batch_major_kernelLauncher(T* k_dst, T* v_dst, + const T* k_src, const T* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + cudaStream_t stream) +{ + constexpr int block_sz = 128; +#if NEW_TRANSPOSE_BATCH_MAJOR == 1 + constexpr int x = (sizeof(T) == 4)? 
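transpose_4d_batch_major_k_cache above rewrites one head's K slice from [seq_len, size_per_head] into [size_per_head / x, max_seq_len, x], where x is the number of elements in a 16-byte vector (4 for FP32, 8 for FP16), so the decoding kernel can read aligned 16-byte chunks while scanning along the sequence axis. A scalar statement of that destination mapping for a single batch and head (illustrative, function name mine):

// Map element (seq_id, dh_id) of one head's K slice from the [L, Dh] source
// layout to the [Dh/x, max_L, x] batch-major cache layout, x elements per
// 16-byte vector (x = 4 for float, 8 for half).
inline int k_cache_dst_index(int seq_id, int dh_id, int max_seq_len, int x)
{
    const int vec_id = dh_id / x;   // which 16-byte vector along the head dim
    const int in_vec = dh_id % x;   // position inside that vector
    return (vec_id * max_seq_len + seq_id) * x + in_vec;
}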
4 : 8; + int size = max_seq_len * size_per_head / x; + dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_head_num); + dim3 grid_v((seq_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_head_num); + + transpose_4d_batch_major_k_cache<<>>( + k_dst, k_src, + local_head_num, + size_per_head, + seq_len, + max_seq_len + ); + + transpose_4d_batch_major_v_cache<<>>( + v_dst, v_src, + local_head_num, + size_per_head, + seq_len, + max_seq_len + ); +#else + dim3 grid(local_batch_size, seq_len); + + transpose_4d_batch_major<<>>( + k_dst, v_dst, + k_src, v_src, + local_head_num, + size_per_head, + seq_len, + max_seq_len + ); +#endif +} + +template void transpose_4d_batch_major_kernelLauncher(float* k_dst, float* v_dst, + const float* k_src, const float* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + cudaStream_t stream); + +template void transpose_4d_batch_major_kernelLauncher(half* k_dst, half* v_dst, + const half* k_src, const half* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + cudaStream_t stream); + +template +__global__ +void add_QKV_bias_generalized_2(const T* __restrict QKV, + const T* __restrict bias, + T* q_buf_, T* k_buf_, T* v_buf_, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + const int word_per_block) +{ + // QKV: [batch x sequence length, hidden * 3] + const T* data_ptr; + T* buf_ptr; + + int n = head_num * size_per_head; + const int blocks_per_word = n / blockDim.x; + const int blocks_per_buffer = gridDim.x / 3; + const int qkv_id = blockIdx.x / blocks_per_buffer; + const int block_id_in_buffer = blockIdx.x % blocks_per_buffer; + const int offset = block_id_in_buffer * blockDim.x + threadIdx.x; + const int bias_id = offset % n; + T* buf_ptrs[3] = {q_buf_, k_buf_, v_buf_}; + + const int bid = blockIdx.x; + + for(int index = threadIdx.x; index < n; index += blockDim.x) + { + buf_ptrs[index / n][bid * n + index % n] = QKV[bid * 3 * n + index] + __ldg(&bias[index]); + } +} + +template +__global__ +void cross_attention_kernel_opt( + T* __restrict query_buf, const T* __restrict Q_bias, + T* __restrict key_cache, const T* __restrict K_bias, + T* __restrict value_cache, const T* __restrict V_bias, + const int* length_per_sample, T* __restrict context_buf, + const bool* finished, + int batch_size, int head_num, const int step, const int seq_len, const float scalar) +{ + if(finished != nullptr && finished[blockIdx.x / head_num] == true) return; + typedef Copy_t copy_t; + const int elems_per_thread = size_per_head / WARP_SIZE; + union Access_t + { + copy_t v; + T x[elems_per_thread]; // supported size 1,2,4 + }; + typedef struct Float_n_t + { + float x[elems_per_thread]; // supported size 1,2,4 + } float_n_t; + + __shared__ float_n_t sq[block_sz]; + extern __shared__ float logits[]; // use to store the logits from [0~step] + + const int warp_id = threadIdx.x / WARP_SIZE; + const int warp_num = block_sz / WARP_SIZE; + + typedef cub::BlockReduce MaxValBlockReduce; + typedef cub::BlockReduce BlockReduce; + __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; + __shared__ typename BlockReduce::TempStorage block_temp_storage; + + __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; + + const int tid = threadIdx.x; + const int bid = blockIdx.x / head_num; + const int head_id = blockIdx.x 
% head_num; + + int length = __ldg(&length_per_sample[bid]); + + const int lane_id = tid % WARP_SIZE; + + int qkv_id = bid * head_num * size_per_head + head_id * size_per_head; + int qkv_bias_id = head_id * size_per_head; + + int key_value_id = bid * (seq_len * head_num * size_per_head) + + + head_id * size_per_head; + + query_buf = &query_buf[qkv_id]; + K_bias = &K_bias[qkv_bias_id]; + key_cache = &key_cache[key_value_id]; + Q_bias = &Q_bias[qkv_bias_id]; + V_bias = &V_bias[qkv_bias_id]; + value_cache = &value_cache[key_value_id]; + context_buf = &context_buf[qkv_id]; + + Access_t bias_r, key_val_r, query_buf_r; + + // each warp will have its own copy of sq + query_buf_r.v = *((copy_t *)query_buf + lane_id); + bias_r.v = *((copy_t *)Q_bias + lane_id); + float qb_r[elems_per_thread]; + for (int i = 0; i < elems_per_thread; ++i) + { + qb_r[i] = (float)query_buf_r.x[i] + (float)bias_r.x[i]; + } + + //offset for each step + int offset = head_num * size_per_head; + + bias_r.v = *((copy_t *) K_bias + lane_id); + for(int ite = warp_id; ite < length; ite += warp_num) + { + key_val_r.v = *((copy_t *)&key_cache[ite * offset] + lane_id); + + //For the first step, we should add bias to key memory cache. + //The KV memory cache only need to be updated at the first step. + if (step == 1) + { + for (int i = 0; i < elems_per_thread; i++) + { + key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; + } + *((copy_t *)&key_cache[ite * offset] + lane_id) = key_val_r.v; + } + float val = 0.f; + for (int i = 0; i < elems_per_thread; i++) + { + val = val + (float)key_val_r.x[i] * qb_r[i] * scalar; + } + float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val); + if (lane_id == 0) + { + logits[ite] = qk; + } + } + __syncthreads(); + + __shared__ float s_max_val, s_sum; + float local_i = -1e20f; + for(int i = tid; i < length; i += blockDim.x) + local_i = max(local_i, logits[i]); + + float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max()); + if(tid == 0) + s_max_val = max_val; + __syncthreads(); + + float local_o = 0.0f; + for(int i = tid; i < length; i += blockDim.x) + { + logits[i] = __expf(logits[i] - s_max_val); + local_o += logits[i]; + } + float val = BlockReduce(block_temp_storage).Sum(local_o); + + if(tid == 0) + s_sum = val + 1e-6; + __syncthreads(); + + float s_sum_inverse = __fdividef(1.0f, s_sum); + for(int i = tid; i < length; i += blockDim.x) + { + logits[i] = logits[i] * s_sum_inverse; + } + __syncthreads(); + + // This optimization introduces discrepancy because of different order in FP32 summation + float sum_r[elems_per_thread] = {0.f}; + bias_r.v = *((copy_t *) V_bias + lane_id); + for(int ite = warp_id; ite < length; ite += warp_num) + { + key_val_r.v = *((copy_t *)&value_cache[ite * offset] + lane_id); + + //For the first step, we should add bias to key memory cache. 
+ if(step == 1) + { + for (int i = 0; i < elems_per_thread; i++) + { + key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; + } + *((copy_t *)&value_cache[ite * offset] + lane_id) = key_val_r.v; + } + for (int i = 0; i < elems_per_thread; ++i) + { + sum_r[i] += (float)key_val_r.x[i] * logits[ite]; + } + } + for (int i = 0; i < elems_per_thread; i++) + { + sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i]; + } + __syncthreads(); + if (threadIdx.x < WARP_SIZE) + { + #pragma unroll + for (int j = 1; j < warp_num; j++) + { + for (int i = 0; i < elems_per_thread; ++i) + { + sum_r[i] = sum_r[i] + (float)sq[j * WARP_SIZE + threadIdx.x].x[i]; + } + } + } + __syncthreads(); + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) + { + key_val_r.x[i] = sum_r[i]; + } + if (threadIdx.x < WARP_SIZE) + { + *((copy_t *)context_buf + lane_id) = key_val_r.v; + } +} + +template +__global__ +void cross_attention_kernel( + T* query_buf, const T* Q_bias, + T* key_cache, const T* K_bias, + T* value_cache, const T* V_bias, + const int* length_per_sample, T* context_buf, + const bool* finished, + int batch_size, int head_num, int size_per_head, int step, const int seq_len, const T scalar) +{ + if(finished != nullptr && finished[blockIdx.x / head_num] == true) return; + int tid = threadIdx.x; + int bid = blockIdx.x / head_num; + int head_id = blockIdx.x % head_num; + + extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; + T* sq = reinterpret_cast(s_buf); + T* logits = reinterpret_cast(&sq[size_per_head]); + + int length = __ldg(&length_per_sample[bid]); + + int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; + int qkv_bias_id = head_id * size_per_head + tid; + + if(tid < size_per_head) + sq[tid] = query_buf[qkv_id] + Q_bias[qkv_bias_id]; + __syncthreads(); + + for(int ite = 0; ite < length; ++ite) + { + int key_id = bid * (seq_len * head_num * size_per_head) + ite * (head_num * size_per_head) + + head_id * size_per_head + tid; + + T key = tid < size_per_head ? key_cache[key_id] : (T)(0.0f); + + //For the first step, we should add bias to key memory cache. + //The KV memory cache only need to be updated at the first step. + if(step == 1 && tid < size_per_head) + { + key += K_bias[head_id * size_per_head + tid]; + key_cache[key_id] = key; + } + + T val = (tid < size_per_head) ? key * sq[tid] * scalar : (T)(0.0f); + T qk = blockReduceSum(val); + if(threadIdx.x == 0) + logits[ite] = qk; + __syncthreads(); //try to remove + } + __syncthreads(); + + __shared__ float s_max_val, s_sum; + + float local_i = tid < length ? (float)logits[tid] : -1e20f; + float max_val = blockReduceMax(local_i); + if(tid == 0) + s_max_val = max_val; + __syncthreads(); + + local_i -= s_max_val; + float local_o = tid < length ? 
__expf(local_i) : 0.0f; + float val = blockReduceSum(local_o); + + if(tid == 0) + s_sum = val + 1e-6; + __syncthreads(); + if(tid < length) + logits[tid] = local_o / s_sum; + __syncthreads(); + + if(tid < size_per_head) + { + T sum = (T)0.0f; + for(int ite = 0; ite < length; ++ite) + { + int value_id = bid * seq_len * head_num * size_per_head + ite * head_num * size_per_head + + head_id * size_per_head + tid; + + T value = value_cache[value_id]; + + //for the first step, we should add bias to key memory cache + if(step == 1) + { + value += V_bias[head_id * size_per_head + tid]; + value_cache[value_id] = value; + } + sum += value * logits[ite]; + } + context_buf[bid * head_num * size_per_head + head_id * size_per_head + tid] = sum; + } +} + +template +void cross_attention_dispatch(T* query_buf, const T* Q_bias, + T* key_cache, const T* K_bias, T* value_cache, const T* V_bias, const int* length, + T* context_buf, const bool* finished, + int batch_size, int head_num, int size_per_head, int step, int seq_len, cudaStream_t stream) + { + const int block_sz = ATTENTION_BLOCK_SIZE; + float scalar = 1.f / sqrtf(size_per_head * 1.0f); + + dim3 grid(batch_size * head_num); + + int cond = size_per_head * ((ATTENION_OPT)? 1:0); + switch (cond) + { + case 32: + cross_attention_kernel_opt<<>>( + query_buf, Q_bias, key_cache, K_bias, value_cache, V_bias, length, context_buf, finished, + batch_size, head_num, step, seq_len, scalar); + break; + case 64: + cross_attention_kernel_opt<<>>( + query_buf, Q_bias, key_cache, K_bias, value_cache, V_bias, length, context_buf, finished, + batch_size, head_num, step, seq_len, scalar); + break; + case 128: + cross_attention_kernel_opt<<>>( + query_buf, Q_bias, key_cache, K_bias, value_cache, V_bias, length, context_buf, finished, + batch_size, head_num, step, seq_len, scalar); + break; + default: + // default path + + int block_size = 128; + + if(seq_len <= 64) + block_size = 64; + else if(seq_len <= 128 && seq_len > size_per_head) + block_size = 128; + else if(seq_len > 128 && seq_len <= 256) + block_size = 256; + else if(seq_len > 256 && seq_len <= 512) + block_size = 512; + else + block_size = 1024; + + if(block_size < size_per_head) + block_size = size_per_head; + + assert(block_size <= 1024); + dim3 block(block_size); + + int shared_size = sizeof(T) * (size_per_head + seq_len); + cross_attention_kernel<<>>( + query_buf, Q_bias, + key_cache, K_bias, + value_cache, V_bias, + length, context_buf, finished, + batch_size, + head_num, size_per_head, step, seq_len, scalar); + } + } + +template void cross_attention_dispatch( + float* query_buf, + const float* Q_bias, + float* key_cache, + const float* K_bias, + float* value_cache, + const float* V_bias, + const int* length, + float* context_buf, + const bool* finished, + int batch_size, + int head_num, + int size_per_head, + int step, + int seq_len, + cudaStream_t stream); + +template void cross_attention_dispatch( + half* query_buf, + const half* Q_bias, + half* key_cache, + const half* K_bias, + half* value_cache, + const half* V_bias, + const int* length, + half* context_buf, + const bool* finished, + int batch_size, + int head_num, + int size_per_head, + int step, + int seq_len, + cudaStream_t stream); + +template void fusedQKV_masked_attention_kernelLauncher( + const float* qkv_buf, + const float* qkv_bias, + float* k_cache, + float* v_cache, + float* output, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int max_seq_len, + cudaStream_t stream); + +template void 
fusedQKV_masked_attention_kernelLauncher( + const half* qkv_buf, + const half* qkv_bias, + half* k_cache, + half* v_cache, + half* output, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int max_seq_len, + cudaStream_t stream); + +}//namespace fastertransformer \ No newline at end of file diff --git a/fastertransformer/cuda/open_decoder.cuh b/fastertransformer/cuda/open_decoder.cuh new file mode 100644 index 000000000..ccdde89b8 --- /dev/null +++ b/fastertransformer/cuda/open_decoder.cuh @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace fastertransformer{ + +template +void fusedQKV_masked_attention_dispatch( + const T* qkv_buf, const T* qkv_bias, + T* key_cache, T* value_cache, + T* context_buf, const bool* finished, int max_batch_size, int inference_batch_size, + int head_num, int size_per_head, const int step, const int max_seq_len, cudaStream_t stream); + +template +void masked_attention_dispatch( + T* key_buf, T* value_buf, + T* query_buf, const T* self_Q_bias, + T* key_cache, const T* self_K_bias, T* value_cache, const T* self_V_bias, + T* context_buf, const bool* finished, int max_batch_size, int inference_batch_size, + int head_num, int size_per_head, const int step, const int max_seq_len, cudaStream_t stream); + +template +void cross_attention_dispatch(T* query_buf, const T* Q_bias, + T* key_cache, const T* K_bias, T* value_cache, const T* V_bias, const int* length, + T* context_buf, const bool* finished, + int batch_size, int head_num, int size_per_head, int step, int seq_len, cudaStream_t stream); + +template +void fusedQKV_masked_attention_kernelLauncher( + const T* qkv_buf, + const T* qkv_bias, + T* k_cache, + T* v_cache, + T* output, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int max_seq_len, + cudaStream_t stream +); + +template +void transpose_4d_kernelLauncher(T* dst, T* src, + const int local_batch_size, + const int seq_len, + const int size_per_head, + const int local_hidden_units, + const int local_head_num, + const int batch_size, + const int ite, + cudaStream_t stream +); + +template +void transpose_4d_batch_major_kernelLauncher(T* k_dst, T* v_dst, + const T* k_src, const T* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + cudaStream_t stream); + +} diff --git a/fastertransformer/cuda/topk_kernels.cu b/fastertransformer/cuda/topk_kernels.cu index a52eb39db..d4daca530 100644 --- a/fastertransformer/cuda/topk_kernels.cu +++ b/fastertransformer/cuda/topk_kernels.cu @@ -20,6 +20,24 @@ namespace fastertransformer { +__global__ void ker_curand_setup(curandState_t* state, const int size) + { + // curand_init(clock(), blockIdx.x * blockDim.x + threadIdx.x, 0, &state[blockIdx.x * blockDim.x + threadIdx.x]); + // fix the seed to prevent the seed of 
different gpu are differnet in Tensor Parallel + if(threadIdx.x + blockIdx.x * blockDim.x < size) + curand_init(0, blockIdx.x * blockDim.x + threadIdx.x, 0, &state[blockIdx.x * blockDim.x + threadIdx.x]); + } + +void ker_curand_setupLauncher(curandState_t* state, + DecodingSamplingArguments args, + cudaStream_t stream) + { + dim3 block(256); + dim3 grid((int)(ceil(args.batch_size_ * 1.0 / 256))); + ker_curand_setup<<>>(state, args.batch_size_); + } + + template __launch_bounds__(THREADBLOCK_SIZE) __global__ @@ -82,7 +100,7 @@ void beam_topK_kernelLauncher(const T* log_probs, cudaStream_t stream) { const int batch_size = args.batch_size_; - const int vocab_size = args.vocab_size_; + const int vocab_size = args.vocab_size_padded_; const int candidate_num = args.candidate_num_; const int block_size = 256; switch(candidate_num) @@ -192,8 +210,10 @@ __global__ void topk_stage_1_opt3( T* tmp_log_probs, int* topk_tmp_id_buf, T* topk_tmp_val_buf, + const bool* finished, const int k, - const int vocab_size + const int vocab_size, + const int end_id ) { typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; @@ -210,6 +230,26 @@ __global__ void topk_stage_1_opt3( const bool IS_FP16 = std::is_same::value; const T MAX_T_VAL = (IS_FP16)? HALF_FLT_MAX : FLT_MAX; + if(finished != nullptr && finished[row_id] == true) + { + if(tid < k) + { + const int index = tmp_topk_buf_index + tid; + if(block_lane == 0 && tid == 0) + { + topk_tmp_id_buf[index] = tmp_log_buf_index + end_id; + topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id]; + } + else + { + topk_tmp_id_buf[index] = -1; + topk_tmp_val_buf[index] = -MAX_T_VAL; + + } + } + return; + } + for(int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) { int index = elem_id + tmp_log_buf_index; @@ -281,6 +321,90 @@ __global__ void topk_stage_2_opt3( if(tid < k) ids[batch_id * k + tid] = topk_tmp_id_buf[batch_id * size + s_id[tid]]; } +template +__global__ void topk_stage_2_opt3_sampling(const int* __restrict topk_tmp_id_buf, + T* topk_tmp_val_buf, + T* topk_tmp2_val_buf, + int* ids, + int* sequence_length, + bool* finished_buf, + const int k, + curandState_t* curandstate, + const int end_id, + const int vocab_size) +{ + const int size = k * BLOCKS_PER_BEAM_; + const int tid = threadIdx.x; + const int batch_id = blockIdx.x; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16)? 
HALF_FLT_MAX : FLT_MAX; + + typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + extern __shared__ char array[]; + __shared__ float rand_num; + __shared__ float s_sum; + __shared__ float s_max; + T *s_val = topk_tmp_val_buf + batch_id * size; + int *s_id = (int*)(array); + s_max = (float)0.0f; + s_sum = (float)0.0f; + TopK_2 partial; + + for(int index = tid; index < size; index += BLOCK_SIZE_) + { + topk_tmp2_val_buf[batch_id * size + index] = topk_tmp_val_buf[batch_id * size + index]; + } + __syncthreads(); + T *s_val2 = topk_tmp2_val_buf + batch_id * size; + + for(int ite = 0; ite < k; ite++) + { + partial.init(); + #pragma unroll + for(int i = tid; i < size; i+= BLOCK_SIZE_) + { + partial.insert((float)s_val[i], i); + } + + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); + + if(ite == 0) + s_max = total.u; + + if(tid == 0) + { + s_id[ite] = total.p; + s_val[total.p] = -MAX_T_VAL; + total.u = __expf(total.u - s_max); + s_val2[total.p] = (T)total.u; + s_sum += total.u; + } + __syncthreads(); + } + if(tid == 0) + { + rand_num = (float)curand_uniform(curandstate + blockIdx.x) * s_sum; + for(int i = 0; i < k; i++) + { + rand_num = rand_num - (float)s_val2[s_id[i]]; + if(rand_num <= 0.0f) + { + ids[batch_id] = topk_tmp_id_buf[batch_id * size + s_id[i]] % vocab_size; + break; + } + } + if(finished_buf != nullptr) + { + finished_buf[batch_id] = ids[batch_id] == end_id ? 1 : 0; + if(sequence_length != nullptr) + { + sequence_length[batch_id] = finished_buf[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1; + } + } + } +} + template __global__ void topk_stage_1_opt2_general( const T* __restrict log_probs, @@ -393,7 +517,8 @@ __global__ void topk_stage_2_opt2_general( temp_log_probs, \ topk_tmp_id_buf, \ topk_tmp_val_buf, \ - beam_width, vocab_size); \ + finished, \ + beam_width, vocab_size, end_id); \ topk_stage_2_opt3<<>>( \ topk_tmp_id_buf, \ topk_tmp_val_buf, \ @@ -406,13 +531,15 @@ void topK_kernelLauncher(void* workspace, size_t& workspace_size, T* log_probs, int* ids, + const bool* finished, DecodingBeamsearchArguments args, cudaStream_t stream) { const int batch_size = args.batch_size_; const int beam_width = args.beam_width_; - const int vocab_size = args.vocab_size_; + const int vocab_size = args.vocab_size_padded_; const T diversity_rate = args.beam_search_diversity_rate_; + const int end_id = args.end_id_; const int max_block_per_beam = 8; int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float @@ -487,30 +614,31 @@ template void topK_kernelLauncher(void* workspace, size_t& workspace_size, float* log_probs, int* ids, + const bool* finished, DecodingBeamsearchArguments args, cudaStream_t stream); // Sampling kernels template -__global__ void sampling(int* topk_tmp_id_buf, - T* topk_tmp_val_buf, - int* ids, - int* sequence_length, - bool* finished_buf, - const int candidate_num, - int random_num, - const int end_id, - const int vocab_size) +__global__ void sampling(int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* ids, + int* sequence_length, + bool* finished_buf, + const int candidate_num, + int random_num, + const int end_id, + const int vocab_size) { int tid = threadIdx.x; int bid = blockIdx.x; - __shared__ T sum; - __shared__ T rand_num; + __shared__ float sum; + __shared__ float rand_num; if(tid < candidate_num) { - T max_val = topk_tmp_val_buf[bid * candidate_num]; - topk_tmp_val_buf[bid * candidate_num + tid] = __expf(topk_tmp_val_buf[bid * 
candidate_num + tid] - max_val); + float max_val = topk_tmp_val_buf[bid * candidate_num]; + topk_tmp_val_buf[bid * candidate_num + tid] = (T)__expf((float)topk_tmp_val_buf[bid * candidate_num + tid] - max_val); } if(tid == 0) @@ -518,27 +646,29 @@ __global__ void sampling(int* topk_tmp_id_buf, sum = 0.0f; for(int i = 0; i < candidate_num; i++) { - sum = sum + topk_tmp_val_buf[bid * candidate_num + i]; + sum = sum + (float)topk_tmp_val_buf[bid * candidate_num + i]; } curandState_t local_state; curand_init((T)random_num, bid, 0, &local_state); - rand_num = (T)curand_uniform(&local_state) * sum; + rand_num = (float)curand_uniform(&local_state) * sum; ids[bid] = topk_tmp_id_buf[bid * candidate_num + candidate_num - 1] % vocab_size; for(int i = 0; i < candidate_num; i++) { - rand_num = rand_num - topk_tmp_val_buf[bid * candidate_num + i]; - if(rand_num <= (T)0.0f){ + rand_num = rand_num - (float)topk_tmp_val_buf[bid * candidate_num + i]; + if(rand_num <= 0.0f){ ids[bid] = topk_tmp_id_buf[bid * candidate_num + i] % vocab_size; break; } } - - if(sequence_length != nullptr && finished_buf != nullptr) + if(finished_buf != nullptr) { - sequence_length[bid] = finished_buf[bid] ? sequence_length[bid] : sequence_length[bid] + 1; finished_buf[bid] = ids[bid] == end_id ? 1 : 0; + if(sequence_length != nullptr) + { + sequence_length[bid] = finished_buf[bid] ? sequence_length[bid] : sequence_length[bid] + 1; + } } } } @@ -558,10 +688,15 @@ void topK_sampling_kernel_kernelLauncher(void* workspace, bool* finished_buf, int random_num, DecodingSamplingArguments args, - cudaStream_t stream) + cudaStream_t stream, + const int batch_size) { - const int batch_size = args.batch_size_; - const int vocab_size = args.vocab_size_; + // This function would be called two or more times. + // First time is used to get the workspace size, so we need to put + // max batch size we want to use. + // For other times, we need to put the inference batch size to + // set the grid size we use. 
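For reference, the `_v2` sampling launchers introduced below follow the same two-call protocol described in the comment above: the first call passes a null workspace together with the maximum batch size and only reports `workspace_size`; the second call passes the allocated workspace and the live inference batch size. A minimal host-side sketch of that protocol for `topK_sampling_kernel_kernelLauncher_v2` follows. It is illustrative only and not part of the patch; `args`, `stream`, `max_batch_size`, `infer_batch_size`, and the `d_*` device pointers are assumed to be prepared by the caller, and `args.batch_size_` is assumed to hold the maximum batch size.

// Illustrative two-phase use of the v2 top-k sampling launcher (not part of the patch).
size_t ws_size = 0;
void* ws = nullptr;
curandState_t* curand_states = nullptr;

// 1) Size query: null workspace, sized for the largest batch that will ever be used.
topK_sampling_kernel_kernelLauncher_v2<float>(nullptr, ws_size,
    nullptr, nullptr, nullptr, nullptr, nullptr, args, stream, max_batch_size);
cudaMalloc(&ws, ws_size);
cudaMalloc(&curand_states, sizeof(curandState_t) * max_batch_size);
ker_curand_setupLauncher(curand_states, args, stream);   // fixed seed, see ker_curand_setup above

// 2) Real launch: the grid is sized with the current inference batch size.
topK_sampling_kernel_kernelLauncher_v2<float>(ws, ws_size,
    d_log_probs, d_ids, d_sequence_length, d_finished, curand_states,
    args, stream, infer_batch_size);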
+ const int vocab_size = args.vocab_size_padded_; const int candidate_num = args.candidate_num_; const int end_id = args.end_id_; const int block_size = 256; @@ -573,7 +708,7 @@ void topK_sampling_kernel_kernelLauncher(void* workspace, if(workspace == nullptr) { - workspace_size = sizeof(int) * topk_tmp_ids_buf_size + sizeof(int) * topk_tmp_val_buf_size; + workspace_size = sizeof(int) * topk_tmp_ids_buf_size + sizeof(T) * topk_tmp_val_buf_size; } else { @@ -585,6 +720,8 @@ void topK_sampling_kernel_kernelLauncher(void* workspace, CASE_K(1); CASE_K(2); CASE_K(4); + CASE_K(16); + CASE_K(64); default: printf("[ERROR] Topk kernel does not support candidate_num = %d \n", candidate_num); exit(0); @@ -598,6 +735,91 @@ void topK_sampling_kernel_kernelLauncher(void* workspace, #undef CASE_K +#define CASE_K(K,BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_) \ + case K: \ + topk_stage_1_opt3<<>>( \ + log_probs, \ + temp_log_probs, \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + finished_buf, \ + candidate_num, vocab_size, end_id); \ + topk_stage_2_opt3_sampling<<>>( \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + topk_tmp2_val_buf, \ + ids, \ + sequence_length, \ + finished_buf, \ + candidate_num, \ + curandstate, \ + end_id, \ + vocab_size); \ + break; \ + + +template +void topK_sampling_kernel_kernelLauncher_v2(void* workspace, + size_t& workspace_size, + T* log_probs, + int* ids, + int* sequence_length, + bool* finished_buf, + curandState_t* curandstate, + DecodingSamplingArguments args, + cudaStream_t stream, + const int batch_size) +{ + // Here, we put batch size as an argument because the batch size of initialization + // and inference may be different due to pipelint parallelism. + const int candidate_num = args.candidate_num_; + const int vocab_size = args.vocab_size_padded_; + const int end_id = args.end_id_; + + const int max_block_per_beam = 8; + int temp_log_probs_buf_size = batch_size * vocab_size; // type float + int topk_tmp_ids_buf_size = batch_size * candidate_num * max_block_per_beam; // type int + int topk_tmp_val_buf_size = batch_size * candidate_num * max_block_per_beam; // type float + + // prevent memory misalinged address + temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4; + topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4; + topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4; + + if(workspace == nullptr) + { + workspace_size = sizeof(T) * temp_log_probs_buf_size + + sizeof(int) * topk_tmp_ids_buf_size + + 2 * sizeof(T) * topk_tmp_val_buf_size; + return; + } + else + { + T* temp_log_probs = (T*)workspace; + int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size); + T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size); + T* topk_tmp2_val_buf = (T*)(topk_tmp_val_buf + topk_tmp_val_buf_size); + + switch(candidate_num) + { + CASE_K(1,128,128,8); + CASE_K(4,128,128,8); + CASE_K(8,128,128,8); + CASE_K(16,128,128,8); + CASE_K(32,256,128,8); + CASE_K(64,256,256,8); + default: + printf("[ERROR] Topk kernel does not support candidate_num = %d \n", candidate_num); + exit(0); + break; + } + return; + } +} + +#undef CASE_K + + template void topK_sampling_kernel_kernelLauncher(void* workspace, size_t& workspace_size, float* log_probs, @@ -606,7 +828,8 @@ template void topK_sampling_kernel_kernelLauncher(void* workspace, bool* finished_buf, int random_num, DecodingSamplingArguments args, - cudaStream_t stream); + cudaStream_t stream, + const int batch_size); template void 
topK_sampling_kernel_kernelLauncher(void* workspace, size_t& workspace_size, @@ -616,7 +839,31 @@ template void topK_sampling_kernel_kernelLauncher(void* workspace, bool* finished_buf, int random_num, DecodingSamplingArguments args, - cudaStream_t stream); + cudaStream_t stream, + const int batch_size); + +template void topK_sampling_kernel_kernelLauncher_v2(void* workspace, + size_t& workspace_size, + float* log_probs, + int* ids, + int* sequence_length, + bool* finished_buf, + curandState_t* curandstate, + DecodingSamplingArguments args, + cudaStream_t stream, + const int batch_size); + +template void topK_sampling_kernel_kernelLauncher_v2(void* workspace, + size_t& workspace_size, + half* log_probs, + int* ids, + int* sequence_length, + bool* finished_buf, + curandState_t* curandstate, + DecodingSamplingArguments args, + cudaStream_t stream, + const int batch_size); + __global__ void init_topp_id_val(int* topp_id_val_buf, int* topp_offset_buf, @@ -670,21 +917,61 @@ __global__ void top_p_sampling(T* sorted_log_probs, curandState_t local_state; curand_init((T)random_num, tid, 0, &local_state); T rand_num = (T)curand_uniform(&local_state) * (T)prob_threshold; - ids[tid] = sorted_id_vals[vocab_size - 1]; + ids[tid] = sorted_id_vals[tid * vocab_size]; for(int i = tid * vocab_size; i < tid * vocab_size + vocab_size; i++) { rand_num = rand_num - sorted_log_probs[i]; - if(rand_num <= (T)0.0) + if(rand_num <= (T)0.0f) { ids[tid] = sorted_id_vals[i]; break; } } - if(sequence_length != nullptr && finished_buf != nullptr) + if(finished_buf != nullptr) { - sequence_length[tid] = finished_buf[tid] ? sequence_length[tid] : sequence_length[tid] + 1; finished_buf[tid] = ids[tid] == end_id ? 1 : 0; + if(sequence_length != nullptr) + { + sequence_length[tid] = finished_buf[tid] ? sequence_length[tid] : sequence_length[tid] + 1; + } + } +} + +template +__global__ void top_p_sampling_v2(T* sorted_log_probs, + int* sorted_id_vals, + int* ids, + int* sequence_length, + bool* finished_buf, + const int vocab_size, + curandState_t* curandstate, + const float prob_threshold, + const int end_id, + const int batch_size) +{ + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if(tid < batch_size) + { + T rand_num = (T)curand_uniform(curandstate + tid) * (T)prob_threshold; + ids[tid] = sorted_id_vals[vocab_size - 1]; + for(int i = tid * vocab_size; i < tid * vocab_size + vocab_size; i++) + { + rand_num = rand_num - sorted_log_probs[i]; + if(rand_num <= (T)0.0) + { + ids[tid] = sorted_id_vals[i]; + break; + } + }; + if(finished_buf != nullptr) + { + finished_buf[tid] = ids[tid] == end_id ? 1 : 0; + if(sequence_length != nullptr) + { + sequence_length[tid] = finished_buf[tid] ? 
sequence_length[tid] : sequence_length[tid] + 1; + } + } } } @@ -731,17 +1018,17 @@ void topP_sampling_kernel_kernelLauncher(void* workspace, int* output_ids, int* sequence_length, const int n, - cudaStream_t stream) + cudaStream_t stream, + const int batch_size) { - const int batch_size = args.batch_size_; - const int vocab_size = args.vocab_size_; + const int vocab_size = args.vocab_size_padded_; int sorted_log_prob_buf_size = batch_size * vocab_size; // type T int sorted_id_vals_buf_size = batch_size * vocab_size; // type int sorted_log_prob_buf_size = (int)(ceil(sorted_log_prob_buf_size / 4.)) * 4; sorted_id_vals_buf_size = (int)(ceil(sorted_id_vals_buf_size / 4.)) * 4; void *cub_temp_storage = workspace; - T* sorted_log_probs = (T*)(cub_temp_storage + args.cub_temp_storage_size_); + T* sorted_log_probs = (T*)((char*)cub_temp_storage + args.cub_temp_storage_size_); int* sorted_id_vals = (int*)(sorted_log_probs + sorted_log_prob_buf_size); if(workspace == nullptr) @@ -800,7 +1087,8 @@ template void topP_sampling_kernel_kernelLauncher(void* workspace, int* output_ids, int* sequence_length, const int n, - cudaStream_t stream); + cudaStream_t stream, + const int batch_size); template void topP_sampling_kernel_kernelLauncher(void* workspace, size_t& workspace_size, @@ -813,9 +1101,184 @@ template void topP_sampling_kernel_kernelLauncher(void* workspace, int* output_ids, int* sequence_length, const int n, - cudaStream_t stream); + cudaStream_t stream, + const int batch_size); + +template +__launch_bounds__(THREADBLOCK_SIZE) +__global__ +void beam_topK_kernel_for_topP(const T* log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const int vocab_size, + int* offset_buf, + int* begin_offset_buf, + float p_threshold) +{ + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int thread_id = threadIdx.x; + int block_id = blockIdx.x; + TopK partial; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16)? HALF_FLT_MAX : FLT_MAX; + + #pragma unroll + for(int i = 0; i < MAX_K; ++i) + { + partial.p[i] = -1; + partial.u[i] = -MAX_T_VAL; + } + #pragma unroll + for(int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) + { + int index = elem_id + block_id * vocab_size; + partial.insert(log_probs[index], index); + } + + TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); + + + if(thread_id == 0) + { + begin_offset_buf[block_id] = offset_buf[block_id]; + T sum_prob = (T)(0.0f); + + #pragma unroll + for(int i = 0; i < MAX_K; i++) + { + sum_prob += total.u[i]; + } + + if ((float)sum_prob >= p_threshold) + { + begin_offset_buf[block_id] += vocab_size; + int index = block_id * vocab_size; + + #pragma unroll + for(int i = 0; i < MAX_K; ++i) + { + topk_tmp_id_buf[index + i] = total.p[i]%vocab_size; + topk_tmp_val_buf[index + i] = total.u[i]; + } + } + } +} + +template +void topP_sampling_kernel_kernelLauncher_v2(void* workspace, + size_t& workspace_size, + const T* log_probs, + const int* id_vals, + int* offset_buf, + int* begin_offset_buf, + bool* finished_buf, + curandState_t* curandstate, + DecodingSamplingArguments& args, + int* output_ids, + int* sequence_length, + const int n, + cudaStream_t stream, + const int batch_size) +{ + // Here, we put batch size as an argument because the batch size of initialization + // and inference may be different due to pipelint parallelism. 
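The device-side selection in `top_p_sampling_v2` above amounts to inverse-CDF sampling over one row of sorted probabilities, truncated at the probability threshold. A CPU reference of that rule is sketched here for clarity; it is not part of the patch, `uniform01(rng)` is a placeholder host random utility, and `sorted_probs`/`sorted_ids` stand for one already-sorted row of length `vocab_size`.

// Host-side reference of the per-row top-p pick (illustrative only).
float r = uniform01(rng) * p;          // p = args.probability_threshold_
int chosen = sorted_ids[0];            // fall back to the most probable token
for (int i = 0; i < vocab_size; ++i)
{
    r -= sorted_probs[i];              // probabilities in descending order
    if (r <= 0.0f) { chosen = sorted_ids[i]; break; }
}
// `chosen` is then written to output_ids; finished and sequence_length are
// updated the same way the kernel does when the end_id token is drawn.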
+ const int vocab_size = args.vocab_size_padded_; + const int block_size = 256; + + int sorted_log_prob_buf_size = batch_size * vocab_size; // type T + int sorted_id_vals_buf_size = batch_size * vocab_size; // type int + sorted_log_prob_buf_size = (int)(ceil(sorted_log_prob_buf_size / 4.)) * 4; + sorted_id_vals_buf_size = (int)(ceil(sorted_id_vals_buf_size / 4.)) * 4; + + void *cub_temp_storage = workspace; + T* sorted_log_probs = (T*)((char*)cub_temp_storage + args.cub_temp_storage_size_); + int* sorted_id_vals = (int*)(sorted_log_probs + sorted_log_prob_buf_size); + + + if(workspace == nullptr) + { + cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, + args.cub_temp_storage_size_, + log_probs, + (T*)nullptr, + id_vals, + (int*)nullptr, + vocab_size * batch_size, + batch_size, + begin_offset_buf, + offset_buf + 1, + 0, // begin_bit + sizeof(T)*8, // end_bit = sizeof(KeyT) * 8 + stream); // cudaStream_t + args.cub_temp_storage_size_ = (int)(ceil(args.cub_temp_storage_size_ / 4.)) * 4; + workspace_size = sizeof(T) * sorted_log_prob_buf_size + sizeof(int) * sorted_id_vals_buf_size + args.cub_temp_storage_size_; + } + else + { + beam_topK_kernel_for_topP<<>>(log_probs, \ + sorted_id_vals, sorted_log_probs, vocab_size, offset_buf,begin_offset_buf, args.probability_threshold_); + + cub::DeviceSegmentedRadixSort::SortPairsDescending(cub_temp_storage, + args.cub_temp_storage_size_, + log_probs, + sorted_log_probs, + id_vals, + sorted_id_vals, + n * batch_size, + batch_size, + begin_offset_buf, offset_buf+1, + 0, // begin_bit + sizeof(T)*8, // end_bit = sizeof(KeyT) * 8 + stream); // cudaStream_t + + dim3 block(256); + dim3 grid((int)(ceil(batch_size * 1.0 / 256))); + top_p_sampling_v2<<>>(sorted_log_probs, + sorted_id_vals, + output_ids, + sequence_length, + finished_buf, + n, + curandstate, + args.probability_threshold_, + args.end_id_, + batch_size); + } +} + +template void topP_sampling_kernel_kernelLauncher_v2(void* workspace, + size_t& workspace_size, + const float* log_probs, + const int* id_vals, + int* offset_buf, + int* begin_offset_buf, + bool* finished_buf, + curandState_t* curandstate, + DecodingSamplingArguments& args, + int* output_ids, + int* sequence_length, + const int n, + cudaStream_t stream, + const int batch_size); + +template void topP_sampling_kernel_kernelLauncher_v2(void* workspace, + size_t& workspace_size, + const half* log_probs, + const int* id_vals, + int* offset_buf, + int* begin_offset_buf, + bool* finished_buf, + curandState_t* curandstate, + DecodingSamplingArguments& args, + int* output_ids, + int* sequence_length, + const int n, + cudaStream_t stream, + const int batch_size); template __launch_bounds__(THREADBLOCK_SIZE) @@ -898,7 +1361,8 @@ void topK_topP_sampling_kernel_kernelLauncher(void* workspace, const T* logits, const int random_num, DecodingSamplingArguments& args, - cudaStream_t stream) + cudaStream_t stream, + const int batch_size) { if(workspace == nullptr) { @@ -906,7 +1370,6 @@ void topK_topP_sampling_kernel_kernelLauncher(void* workspace, } else { - const int batch_size = args.batch_size_; const int vocab_size = args.vocab_size_padded_; const int block_size = 256; const T prob_threshold = args.probability_threshold_; @@ -915,6 +1378,8 @@ void topK_topP_sampling_kernel_kernelLauncher(void* workspace, CASE_K(1); CASE_K(2); CASE_K(4); + CASE_K(16); + CASE_K(64); default: printf("[ERROR] Topk kernel does not support candidate_num = %d \n", args.candidate_num_); exit(0); @@ -925,13 +1390,183 @@ void 
topK_topP_sampling_kernel_kernelLauncher(void* workspace, #undef CASE_K +template +__global__ void topk_topp_sampling_kernel_v2(const int* __restrict topk_tmp_id_buf, + T* topk_tmp_val_buf, + T* topk_tmp2_val_buf, + int* ids, + int* sequence_length, + bool* finished_buf, + const int k, + const T prob_threshold, + curandState_t* curandstate, + const int end_id, + const int vocab_size) +{ + const int size = k * BLOCKS_PER_BEAM_; + const int tid = threadIdx.x; + const int batch_id = blockIdx.x; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16)? HALF_FLT_MAX : FLT_MAX; + + typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + extern __shared__ char array[]; + __shared__ float rand_num; + __shared__ float s_max; + __shared__ float s_sum; + T *s_val = topk_tmp_val_buf + batch_id * size; + int *s_id = (int*)(array); + s_max = 0.0f; + s_sum = 0.0f; + TopK_2 partial; + + for(int index = tid; index < size; index += BLOCK_SIZE_) + { + topk_tmp2_val_buf[batch_id * size + index] = topk_tmp_val_buf[batch_id * size + index]; + } + __syncthreads(); + T *s_val2 = topk_tmp2_val_buf + batch_id * size; + + for(int ite = 0; ite < k; ite++) + { + partial.init(); + #pragma unroll + for(int i = tid; i < size; i+= BLOCK_SIZE_) + { + partial.insert((float)s_val[i], i); + } + + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); + + if(ite == 0) + s_max = total.u; + + if(tid == 0) + { + s_id[ite] = total.p; + s_val[total.p] = -MAX_T_VAL; + total.u = __expf(total.u - s_max); + s_val2[total.p] = (T)total.u; + s_sum += total.u; + } + __syncthreads(); + } + if(tid == 0) + { + rand_num = (float)curand_uniform(curandstate + blockIdx.x) * (float)prob_threshold * s_sum; + for(int i = 0; i < k; i++) + { + rand_num = rand_num - (float)s_val2[s_id[i]]; + if(rand_num <= 0.0f) + { + ids[batch_id] = topk_tmp_id_buf[batch_id * size + s_id[i]] % vocab_size; + break; + } + } + if(finished_buf != nullptr) + { + finished_buf[batch_id] = ids[batch_id] == end_id ? 1 : 0; + if(sequence_length != nullptr) + { + sequence_length[batch_id] = finished_buf[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1; + } + } + } +} + +#define CASE_K(K,BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_) \ + case K: \ + topk_stage_1_opt3<<>>( \ + logits, \ + temp_logits, \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + finished_buf, \ + candidate_num, vocab_size, end_id); \ + topk_topp_sampling_kernel_v2<<>>( \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + topk_tmp2_val_buf, \ + output_ids, \ + nullptr, \ + finished_buf, \ + candidate_num, \ + prob_threshold, \ + curandstate, \ + end_id, \ + vocab_size); \ + break; \ + +template +void topK_topP_sampling_kernel_kernelLauncher_v2(void* workspace, + size_t& workspace_size, + int* output_ids, + const T* logits, + bool* finished_buf, + curandState_t* curandstate, + DecodingSamplingArguments& args, + cudaStream_t stream, + const int batch_size) +{ + // Here, we put batch size as an argument because the batch size of initialization + // and inference may be different due to pipelint parallelism. 
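The workspace computed below is a single allocation that is later carved into four consecutive buffers; each element count is rounded up to a multiple of four so every sub-buffer starts on an aligned address. A compact sketch of the layout and the rounding, with `B`, `V`, and `K` as placeholder names for the batch size, padded vocabulary size, and candidate count, is shown here for orientation only; it is not part of the patch.

// Layout of the single workspace allocation (illustrative):
//   [ temp_logits : T, B*V ][ topk_tmp_id_buf : int ][ topk_tmp_val_buf : T ][ topk_tmp2_val_buf : T ]
size_t temp_logits_elems = ((size_t)B * V + 3) / 4 * 4;          // integer form of ceil(x / 4.) * 4
size_t topk_ids_elems    = ((size_t)B * K * 8 + 3) / 4 * 4;      // 8 = max_block_per_beam
size_t topk_vals_elems   = topk_ids_elems;                       // val buffer is allocated twice
size_t bytes = sizeof(float) * temp_logits_elems
             + sizeof(int)   * topk_ids_elems
             + 2 * sizeof(float) * topk_vals_elems;              // for T = float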
+ const int candidate_num = args.candidate_num_; + const int vocab_size = args.vocab_size_padded_; + const int end_id = args.end_id_; + const T prob_threshold = args.probability_threshold_; + + const int max_block_per_beam = 8; + int temp_logits_buf_size = batch_size * vocab_size; // type float + int topk_tmp_ids_buf_size = batch_size * candidate_num * max_block_per_beam; // type int + int topk_tmp_val_buf_size = batch_size * candidate_num * max_block_per_beam; // type float + + // prevent memory misalinged address + temp_logits_buf_size = (int)(ceil(temp_logits_buf_size / 4.)) * 4; + topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4; + topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4; + + if(workspace == nullptr) + { + workspace_size = sizeof(T) * temp_logits_buf_size + + sizeof(int) * topk_tmp_ids_buf_size + + 2 * sizeof(T) * topk_tmp_val_buf_size; + return; + } + else + { + T* temp_logits = (T*)workspace; + int* topk_tmp_id_buf = (int*)(temp_logits + temp_logits_buf_size); + T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size); + T* topk_tmp2_val_buf = (T*)(topk_tmp_val_buf + topk_tmp_val_buf_size); + + switch(candidate_num) + { + CASE_K(1,128,128,8); + CASE_K(4,128,128,8); + CASE_K(8,128,128,8); + CASE_K(16,128,128,8); + CASE_K(32,256,128,8); + CASE_K(64,256,256,8); + default: + printf("[ERROR] Topk kernel does not support candidate_num = %d \n", candidate_num); + exit(0); + break; + } + return; + } +} + +#undef CASE_K + template void topK_topP_sampling_kernel_kernelLauncher(void* workspace, size_t& workspace_size, int* output_ids, const float* logits, const int random_num, DecodingSamplingArguments& args, - cudaStream_t stream); + cudaStream_t stream, + const int batch_size); template void topK_topP_sampling_kernel_kernelLauncher(void* workspace, @@ -940,5 +1575,27 @@ template void topK_topP_sampling_kernel_kernelLauncher(void* workspace, const half* logits, const int random_num, DecodingSamplingArguments& args, - cudaStream_t stream); + cudaStream_t stream, + const int batch_size); + +template void topK_topP_sampling_kernel_kernelLauncher_v2(void* workspace, + size_t& workspace_size, + int* output_ids, + const float* logits, + bool* finished_buf, + curandState_t* curandstate, + DecodingSamplingArguments& args, + cudaStream_t stream, + const int batch_size); + +template void topK_topP_sampling_kernel_kernelLauncher_v2(void* workspace, + size_t& workspace_size, + int* output_ids, + const half* logits, + bool* finished_buf, + curandState_t* curandstate, + DecodingSamplingArguments& args, + cudaStream_t stream, + const int batch_size); + } // end of namespace fastertransformer diff --git a/fastertransformer/cuda/topk_kernels.cuh b/fastertransformer/cuda/topk_kernels.cuh index a4c9e1517..209789c67 100644 --- a/fastertransformer/cuda/topk_kernels.cuh +++ b/fastertransformer/cuda/topk_kernels.cuh @@ -19,7 +19,7 @@ #include #include #include -#include "fastertransformer/arguments.h" +#include "fastertransformer/utils/arguments.h" #include "fastertransformer/cuda/cuda_kernels.h" #include #include @@ -126,14 +126,15 @@ void topK_kernelLauncher(void* workspace, size_t& workspace_size, T* log_probs, int* ids, + const bool* finished, DecodingBeamsearchArguments args, cudaStream_t stream); template void topK_softMax(const T* log_probs, - const float* bias, + const T* bias, const bool* finished, - T* cum_log_probs, + float* cum_log_probs, int* ids, void * tmp_storage, DecodingBeamsearchArguments args, @@ -142,6 +143,22 @@ void topK_softMax(const 
T* log_probs, /* *************************** end of BeamSearch kernel *********************************** */ /* ********************************** Sampling kernel *********************************** */ +void ker_curand_setupLauncher(curandState_t* state, + DecodingSamplingArguments args, + cudaStream_t stream); + + +template +void topK_sampling_kernel_kernelLauncher_v2(void* workspace, + size_t& workspace_size, + T* log_probs, + int* ids, + int* sequence_length, + bool* finished_buf, + curandState_t* curandstate, + DecodingSamplingArguments args, + cudaStream_t stream, + const int batch_size); template void topK_sampling_kernel_kernelLauncher(void* workspace, @@ -152,7 +169,8 @@ void topK_sampling_kernel_kernelLauncher(void* workspace, bool* finished_buf, int random_num, DecodingSamplingArguments args, - cudaStream_t stream); + cudaStream_t stream, + const int batch_size); template void topP_sampling_kernel_kernelLauncher(void* workspace, @@ -166,7 +184,24 @@ void topP_sampling_kernel_kernelLauncher(void* workspace, int* output_ids, int* sequence_length, const int n, - cudaStream_t stream); + cudaStream_t stream, + const int batch_size); + +template +void topP_sampling_kernel_kernelLauncher_v2(void* workspace, + size_t& workspace_size, + const T* log_probs, + const int* id_vals, + int* offset_buf, + int* begin_offset_buf, + bool* finished_buf, + curandState_t* curandstate, + DecodingSamplingArguments& args, + int* output_ids, + int* sequence_length, + const int n, + cudaStream_t stream, + const int batch_size); template void beam_topK_kernelLauncher(const T* log_probs, @@ -182,7 +217,19 @@ void topK_topP_sampling_kernel_kernelLauncher(void* workspace, const T* logits, const int random_num, DecodingSamplingArguments& args, - cudaStream_t stream); + cudaStream_t stream, + const int batch_size); + +template +void topK_topP_sampling_kernel_kernelLauncher_v2(void* workspace, + size_t& workspace_size, + int* output_ids, + const T* logits, + bool* finished_buf, + curandState_t* curandstate, + DecodingSamplingArguments& args, + cudaStream_t stream, + const int batch_size); /* *************************** end of Sampling kernel *********************************** */ diff --git a/fastertransformer/cuda/transformer_kernels.cu b/fastertransformer/cuda/transformer_kernels.cu new file mode 100644 index 000000000..e9c654adc --- /dev/null +++ b/fastertransformer/cuda/transformer_kernels.cu @@ -0,0 +1,604 @@ +/* +* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +#include "fastertransformer/cuda/transformer_kernels.cuh" + +namespace fastertransformer +{ + + +template +__inline__ __device__ +T gelu(T x) +{ + float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template <> +__inline__ __device__ +half2 gelu(half2 val) +{ + half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); + +} + +template +__inline__ __device__ +T warpReduceSum(T val) +{ + for(int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); + return val; +} + +template +__inline__ __device__ +T blockReduceSum(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if(lane == 0) + shared[wid] = val; + __syncthreads(); + + val = (threadIdx.x < (blockDim.x >> 5 )) ? shared[lane] : (T)0.0f; + val = warpReduceSum(val); + return val; +} + +template +__global__ +void add_bias_gelu(T* out, const T* __restrict bias, int m, int n) +{ + for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) + { + T reg_bias = __ldg(&bias[id % n]); + T val = out[id] + reg_bias; + out[id] = (T)(gelu(val)); + } +} + +template <> + __global__ +void add_bias_gelu(half* out, const half* __restrict bias, int m, int n) +{ + half2* out_ptr = (half2*) out; + const half2* bias_ptr = (half2*) bias; + + for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) + { + half2 reg_bias = __ldg(&bias_ptr[id % n]); + half2 val = out_ptr[id] + reg_bias; + out_ptr[id] = gelu(val); + } +} + +template +__global__ +void add_bias_relu(T* out, const T* __restrict bias, int m, int n) +{ + for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) + { + T reg_bias = __ldg(&bias[id % n]); + T val = out[id] + reg_bias; + out[id] = (T)(val > 0.0f ? val : 0.0f); + } +} + +template <> + __global__ +void add_bias_relu(half* out, const half* __restrict bias, int m, int n) +{ + half2* out_ptr = (half2*) out; + const half2* bias_ptr = (half2*) bias; + + for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) + { + half2 reg_bias = __ldg(&bias_ptr[id % n]); + half2 val = out_ptr[id] + reg_bias; + val.x = val.x > (half)0.0f ? val.x : (half)0.0f; + val.y = val.y > (half)0.0f ? 
val.y : (half)0.0f; + out_ptr[id] = val; + } +} + +template +__global__ +void add_bias_input_layernorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n) +{ + int tid = threadIdx.x; + + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + + float local_out = 0.0f; + local_out += (float)(out[blockIdx.x * n + tid] + input[blockIdx.x * n + tid] + __ldg(&bias[tid])); + + mean = blockReduceSum(local_out); + if(threadIdx.x == 0) + s_mean = mean / n; + __syncthreads(); + + variance = blockReduceSum((local_out - s_mean) * (local_out - s_mean)); + if(threadIdx.x == 0) + s_variance = variance / n + 1e-6f; + __syncthreads(); + + out[blockIdx.x * n + tid] = + (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[tid])) + (float)(__ldg(&beta[tid]))); +} + +template <> +__global__ +void add_bias_input_layernorm(half* out, const half* input, const half* bias, + const half* gamma, const half* beta, int m, int n) +{ + + int tid = threadIdx.x; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + float2 local_out_fp2; + + half2* out_ptr = (half2*)out; + const half2* input_ptr = (const half2*)input; + const half2* bias_ptr = (const half2*)bias; + const half2* gamma_ptr = (const half2*)gamma; + const half2* beta_ptr = (const half2*)beta; + + float local_out = 0.0f; + int id = blockIdx.x * n / 2 + tid; + local_out_fp2 = __half22float2(__hadd2(__hadd2(out_ptr[id], input_ptr[id]), __ldg(&bias_ptr[tid]))); + local_out += local_out_fp2.x; + local_out += local_out_fp2.y; + + mean = blockReduceSum(local_out); + if(threadIdx.x == 0) + s_mean = mean / n; + __syncthreads(); + + variance = (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean); + variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean); + variance = blockReduceSum(variance); + if(threadIdx.x == 0) + s_variance = rsqrtf(variance / n + 1e-6f); + __syncthreads(); + + float2 gamma_val = __half22float2(__ldg(&gamma_ptr[tid])); + float2 beta_val = __half22float2(__ldg(&beta_ptr[tid])); + local_out_fp2.x = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; + local_out_fp2.y = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; + out_ptr[id] = __float22half2_rn(local_out_fp2); +} + + +template +__global__ +void add_bias_input_layernorm_v2(T* out, const T* __restrict input, const T* __restrict bias, + const T* __restrict gamma, const T* __restrict beta, int n) +{ + const int ite = 4; + const int tid = threadIdx.x; + const int bid = blockIdx.x; + + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + float local_out[ite]; + + float sum = 0.0f; + #pragma unroll + for(int i = 0; i < ite; i++) + { + int col_id = i * blockDim.x + tid; + int id = bid * n + col_id; + local_out[i] = (float)(out[id] + __ldg(&input[id]) + __ldg(&bias[col_id])); + sum += local_out[i]; + } + + mean = blockReduceSum(sum); + if(tid == 0) + s_mean = mean / n; + __syncthreads(); + + float var = 0.0f; + #pragma unroll + for(int i = 0; i < ite; i++) + { + float diff = local_out[i] - s_mean; + var += diff * diff; + } + + variance = blockReduceSum(var); + if(tid == 0) + s_variance = rsqrtf(variance / n + 1e-6f); + __syncthreads(); + + #pragma unroll + for(int i = 0; i < ite; i++) + { + int col_id = i * blockDim.x + tid; + int id = bid * n + col_id; + out[id] = (T)((local_out[i] - s_mean) * s_variance * (float)__ldg(&gamma[col_id]) + 
(float)__ldg(&beta[col_id])); + } +} + +template <> +__global__ +void add_bias_input_layernorm_v2(half* out, const half* __restrict input, const half* __restrict bias, + const half* __restrict gamma, const half* __restrict beta, int n) +{ + const int ite = 4; + const int tid = threadIdx.x; + const int bid = blockIdx.x; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + half2 local_out_half2[ite]; + + half2* out_ptr = (half2*)out; + const half2* input_ptr = (const half2*)input; + const half2* bias_ptr = (const half2*)bias; + const half2* gamma_ptr = (const half2*)gamma; + const half2* beta_ptr = (const half2*)beta; + + // float sum = 0.0f; + half2 sum = __float2half2_rn(0.0f); + #pragma unroll + for(int i = 0; i < ite; i++) + { + int col_id = i * blockDim.x + tid; + int id = bid * n / 2 + col_id; + local_out_half2[i] = out_ptr[id] + __ldg(&input_ptr[id]) + __ldg(&bias_ptr[col_id]); + sum += local_out_half2[i]; + } + + mean = blockReduceSum((float)(sum.x + sum.y)); + if(threadIdx.x == 0) + s_mean = mean / n; + __syncthreads(); + + float var = 0.0f; + half2 s_mean_2 = __float2half2_rn(s_mean); + #pragma unroll + for(int i = 0; i < ite; i++) + { + local_out_half2[i] = local_out_half2[i] - s_mean_2; + float v1 = (float)local_out_half2[i].x; + float v2 = (float)local_out_half2[i].y; + var += v1 * v1 + v2 * v2; + } + + variance = blockReduceSum(var); + if(threadIdx.x == 0) + s_variance = rsqrtf(variance / n + 1e-6f); + __syncthreads(); + + half2 s_var_2 = __float2half2_rn(s_variance); + #pragma unroll + for(int i = 0; i < ite; i++) + { + int col_id = i * blockDim.x + tid; + int id = bid * n / 2 + col_id; + out_ptr[id] = local_out_half2[i] * s_var_2 * __ldg(&gamma_ptr[col_id]) + __ldg(&beta_ptr[col_id]); + } +} + +template +void add_bias_act_kernelLauncher(T* out, const T* bias, int m, int n, ActivationType activation_type, cudaStream_t stream) +{ + const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 + dim3 block, grid; + if(n / 4 / data_type_factor <= 1024) + { + block.x = n / 4 / data_type_factor; + grid.x = m; + } + else + { + block.x = 1024; + grid.x = ceil(m * n / 1024.); + } + + + if(activation_type == ActivationType::RELU) + add_bias_relu<<>>(out, bias, m, n / data_type_factor); + else if(activation_type == ActivationType::GELU) + add_bias_gelu<<>>(out, bias, m, n / data_type_factor); +} + +template +void add_bias_input_layernorm_kernelLauncher(T* out, const T* input, const T* bias, + const T* gamma, const T* beta, int m, int n, cudaStream_t stream) +{ + dim3 grid(m); + dim3 block(n); + assert(n <= 1024); + if(n == 768 || n == 1024) + add_bias_input_layernorm_v2<<>>(out, input, bias, gamma, beta, n); + else + add_bias_input_layernorm<<>>(out, input, bias, gamma, beta, m, n); +} + +template <> +void add_bias_input_layernorm_kernelLauncher(half* out, const half* input, const half* bias, + const half* gamma, const half* beta, int m, int n, cudaStream_t stream) +{ + dim3 grid(m); + dim3 block(n / 2); + assert(n / 2 <= 1024); + + if(m >= 512 && (n == 768 || n == 1024)) + add_bias_input_layernorm_v2<<>>(out, input, bias, gamma, beta, n); + else + add_bias_input_layernorm<<>>(out, input, bias, gamma, beta, m, n); +} + +template +__global__ +void add_bias_input_layernorm_2(const T* __restrict input, + const T* __restrict gamma, + const T* __restrict beta, + const T* __restrict bias, + T* output, T* norm_output, + int m, int n) +{ + int tid = threadIdx.x; + + __shared__ float s_mean; + __shared__ float s_variance; + float mean 
= 0.0f; + float variance = 0.0f; + + float local_sum = 0.0f; + for(int i = tid; i < n; i+= blockDim.x) + { + float local_out = (float)(__ldg(&input[blockIdx.x * n + i])); + local_out += (float)(output[blockIdx.x * n + i]); + local_out += (float)(__ldg(&bias[i])); + output[blockIdx.x * n + i] = (T)local_out; + local_sum += local_out; + } + + mean = blockReduceSum(local_sum); + + if(threadIdx.x == 0) + s_mean = mean / n; + __syncthreads(); + + float local_var_sum = 0.0f; + for(int i = tid; i < n; i+= blockDim.x) + { + float diff = (float)(__ldg(&output[blockIdx.x * n + i])) - s_mean; + local_var_sum += diff * diff; + } + variance = blockReduceSum(local_var_sum); + + if(threadIdx.x == 0) + s_variance = rsqrtf(variance / n + 1e-6); + __syncthreads(); + + for(int i = tid; i < n; i+= blockDim.x) + { + norm_output[blockIdx.x * n + i] = + (T)((( (float)output[blockIdx.x * n + i] - s_mean) * s_variance) * (float)(__ldg(&gamma[i])) + (float)(__ldg(&beta[i]))); + } +} + +template +void add_bias_input_layernorm_2_kernelLauncher( + const T* input, + const T* gamma, + const T* beta, + const T* bias, + T* output, + T* norm_output, + int m, int n, + cudaStream_t stream) +{ + dim3 grid(m); + dim3 block(min(n, 1024)); + + /* For general cases, n is equal to hidden_units, e.g., 512/1024. + Since we have warp shuffle inside the code, block.x % 32 should be 0. + */ + + if(n % 32 != 0) + block.x = 1024; + + block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x + + /* should pay attention to the rsqrt precision*/ + add_bias_input_layernorm_2<<>>(input, gamma, beta, bias, output, norm_output, m, n); // For gpt-3 +} + +template +__global__ +void add_bias_input(T* output, const T* input, const T* bias, const int m, const int n) +{ + // This kernel can run with any block size and grid size + // Since the hidden dimension of GPT-3 would be larger than 1024 + const int bid = blockIdx.x; + const int blocks_per_row = n / blockDim.x; + const int col_index = (bid % blocks_per_row) * blockDim.x + threadIdx.x; + T bias_val = __ldg(&bias[col_index]); + for(int index = bid * blockDim.x + threadIdx.x; index < m * n; index += blockDim.x * gridDim.x) + { + output[index] = output[index] + input[index] + bias_val; + } +} + +template +void add_bias_input_kernelLauncher(T* output, const T* bias, const T* input, const int m, const int n, cudaStream_t stream) +{ + dim3 grid(min(m, 65536)); + dim3 block(min(n, 1024)); + + add_bias_input<<>>(output, input, bias, m, n); +} + +template +__global__ +void layer_norm_kernel_generalize(const T* __restrict input, + const T* __restrict gamma, + const T* __restrict beta, + T* output, + int m, int n) +{ + const int tid = threadIdx.x; + + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + + float local_sum = 0.0f; + for(int i = tid; i < n; i+= blockDim.x) + { + local_sum += (float)(__ldg(&input[blockIdx.x * n + i])); + } + + mean = blockReduceSum(local_sum); + + if(threadIdx.x == 0) + s_mean = mean / n; + __syncthreads(); + + float local_var_sum = 0.0f; + for(int i = tid; i < n; i+= blockDim.x) + { + float diff = (float)(__ldg(&input[blockIdx.x * n + i])) - s_mean; + local_var_sum += diff * diff; + } + variance = blockReduceSum(local_var_sum); + + if(threadIdx.x == 0) + s_variance = rsqrtf(variance / n + 1e-6); + + __syncthreads(); + + for(int i = tid; i < n; i+= blockDim.x) + { + output[blockIdx.x * n + i] = + (T)((( (float)input[blockIdx.x * n + i] - s_mean) * s_variance) * (float)(__ldg(&gamma[i])) + 
(float)(__ldg(&beta[i]))); + } +} + +template +void layer_norm( + const T* input, + const T* gamma, + const T* beta, + T* output, + int m, int n, + cudaStream_t stream) +{ + dim3 grid(m); + dim3 block(min(n, 1024)); + + /* For general cases, n is equal to hidden_units, e.g., 512/1024. + Since we have warp shuffle inside the code, block.x % 32 should be 0. + */ + if(n % 32 != 0) + block.x = 1024; + + block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x + + /* should pay attention to the rsqrt precision*/ + layer_norm_kernel_generalize<<>>(input, gamma, beta, output, m, n); // For gpt-3 +} + +template void add_bias_act_kernelLauncher( + float* out, const float* bias, int m, int n, ActivationType activation_type, cudaStream_t stream); + +template void add_bias_input_layernorm_kernelLauncher( + float* out, const float* input, const float* bias, const float* gamma, const float* beta, + int m, int n, cudaStream_t stream); + +template void add_bias_act_kernelLauncher( + half* out, const half* bias, int m, int n, ActivationType activation_type, cudaStream_t stream); + +template void add_bias_input_layernorm_kernelLauncher( + half* out, const half* input, const half* bias, const half* gamma, const half* beta, + int m, int n, cudaStream_t stream); + +template void add_bias_input_layernorm_2_kernelLauncher( + const float* input, + const float* gamma, + const float* beta, + const float* bias, + float* output, + float* norm_output, + int m, int n, cudaStream_t stream); + +template void add_bias_input_layernorm_2_kernelLauncher( + const half* input, + const half* gamma, + const half* beta, + const half* bias, + half* output, + half* norm_output, + int m, int n, cudaStream_t stream); + +template void add_bias_input_kernelLauncher( + float* output, + const float* bias, + const float* input, + const int m, + const int n, + cudaStream_t stream); + +template void add_bias_input_kernelLauncher( + half* output, + const half* bias, + const half* input, + const int m, + const int n, + cudaStream_t stream); + +template void layer_norm( + const float* input, + const float* gamma, + const float* beta, + float* output, + int m, int n, + cudaStream_t stream); + +template void layer_norm( + const half* input, + const half* gamma, + const half* beta, + half* output, + int m, int n, + cudaStream_t stream); + +} // namespace fastertransformer \ No newline at end of file diff --git a/fastertransformer/cuda/transformer_kernels.cuh b/fastertransformer/cuda/transformer_kernels.cuh new file mode 100644 index 000000000..5136d83c8 --- /dev/null +++ b/fastertransformer/cuda/transformer_kernels.cuh @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include +#include +#include "fastertransformer/utils/arguments.h" +#include +#include "fastertransformer/cuda/cuda_kernels.h" + +namespace fastertransformer +{ + +template +void add_bias_act_kernelLauncher(T* out, const T* bias, int m, int n, ActivationType activation_type, cudaStream_t stream); + +template +void add_bias_input_layernorm_kernelLauncher(T *out, const T *input_tensor, + const T *bias, const T *gamma, + const T *beta, int m, int n, + cudaStream_t stream); + +template +void add_bias_input_layernorm_2_kernelLauncher(const T *from_tensor, const T *gamma, + const T *beta, const T *bias, + T *output, T *norm_output_buf_, + const int m, const int n, cudaStream_t stream); + +template +void add_bias_input_kernelLauncher(T *output, const T *bias, const T *input, const int m, const int n, cudaStream_t stream); + +template +void layer_norm(const T *from_tensor, const T *gamma, + const T *beta, T *norm_from_tensor_buf_, const int m, const int n, cudaStream_t stream); + +} // namespace fastertransformer \ No newline at end of file diff --git a/fastertransformer/decoding_beamsearch.h b/fastertransformer/decoding_beamsearch.h index f726345d6..92a07398d 100644 --- a/fastertransformer/decoding_beamsearch.h +++ b/fastertransformer/decoding_beamsearch.h @@ -19,11 +19,12 @@ #pragma once -#include "fastertransformer/common.h" -#include "fastertransformer/allocator.h" +#include "fastertransformer/utils/common.h" +#include "fastertransformer/utils/functions.h" +#include "fastertransformer/utils/allocator.h" #include "fastertransformer/open_decoder.h" #include "fastertransformer/cuda/cuda_kernels.h" -#include "fastertransformer/arguments.h" +#include "fastertransformer/utils/arguments.h" #include namespace fastertransformer @@ -42,7 +43,7 @@ class DecodingBeamsearch const cudaDataType_t AType_ = Traits_::AType; const cudaDataType_t BType_ = Traits_::BType; const cudaDataType_t CType_ = Traits_::CType; - int cublasAlgo_[1] = {20}; + std::map cublasAlgoMap_; OpenDecoder *decoder_; DataType_ **K_cache_; @@ -65,6 +66,11 @@ class DecodingBeamsearch void *topK_kernel_workspace = nullptr; size_t topk_workspace_size_ = 0; + void *cublas_workspace_ = nullptr; + + DataType_ *padded_embedding_kernel; + DataType_ *padded_embedding_bias; + DataType_ *tmp_logits_buf_; public: DecodingBeamsearch(const IAllocator &allocator, const int batch_size, @@ -74,8 +80,9 @@ class DecodingBeamsearch const int memory_hidden_units, const int memory_max_seq_len, const int start_id, const int end_id, const float beam_search_diversity_rate = -0.0f, - const bool is_fuse_topk_softMax = false) : allocator_(allocator), - is_fuse_topk_softMax_(is_fuse_topk_softMax) + const bool is_fuse_topk_softMax = true, + const bool is_fuse_qkv = false) : allocator_(allocator), + is_fuse_topk_softMax_(is_fuse_topk_softMax) { #ifndef NDEBUG PRINT_FUNC_NAME_(); @@ -87,10 +94,15 @@ class DecodingBeamsearch args_.size_per_head_ = size_per_head; args_.hidden_units_ = head_num * size_per_head; args_.decoder_layers_ = decoder_layers; - args_.vocab_size_ = vocab_size; args_.start_id_ = start_id; args_.end_id_ = end_id; args_.beam_search_diversity_rate_ = beam_search_diversity_rate; + if(args_.beam_width_ > 16) is_fuse_topk_softMax_ = false; + args_.vocab_size_ = vocab_size; + if(std::is_same::value) + args_.vocab_size_padded_ = vocab_size; + else if(std::is_same::value) + args_.vocab_size_padded_ = (int)(ceil(vocab_size / 8.)) * 8; K_cache_ = new DataType_ *[2]; V_cache_ = new DataType_ *[2]; @@ -98,52 +110,77 @@ class 
DecodingBeamsearch K_mem_cache_ = new DataType_ *[args_.decoder_layers_]; V_mem_cache_ = new DataType_ *[args_.decoder_layers_]; - decoder_ = new OpenDecoder(batch_size * beam_width, memory_max_seq_len, - head_num, size_per_head, memory_hidden_units); + decoder_ = new OpenDecoder(head_num, size_per_head, memory_hidden_units, is_fuse_qkv); + decoder_->set_max_batch_size(batch_size * beam_width); - int from_tensor_size = args_.batch_size_ * args_.beam_width_ * args_.hidden_units_; // type T - int decoder_workspace_size = decoder_->getWorkspaceSize(); // type T - int decoder_normed_result_buffer_size = args_.batch_size_ * args_.beam_width_ * args_.hidden_units_; // type T - int cache_size = args_.batch_size_ * args_.beam_width_ * args_.seq_len_ * args_.hidden_units_; // type T - int mem_cache_size = args_.batch_size_ * args_.beam_width_ * memory_max_seq_len * args_.hidden_units_; // type T + size_t from_tensor_size = args_.batch_size_ * args_.beam_width_ * args_.hidden_units_; // type T + size_t decoder_workspace_size = decoder_->getWorkspaceSize(); // type T + size_t decoder_normed_result_buffer_size = args_.batch_size_ * args_.beam_width_ * args_.hidden_units_; // type T + size_t cache_size = args_.batch_size_ * args_.beam_width_ * args_.seq_len_ * args_.hidden_units_; // type T + size_t mem_cache_size = args_.batch_size_ * args_.beam_width_ * memory_max_seq_len * args_.hidden_units_; // type T - int logits_buf_size = args_.batch_size_ * args_.beam_width_ * args_.vocab_size_; // type float - int cum_log_buf_size = args_.batch_size_ * args_.beam_width_; // type float - int word_ids_buf_size = args_.batch_size_ * args_.beam_width_; //type int - int finished_buf_size = args_.batch_size_ * args_.beam_width_; //type bool - int finished_count_size = (int)(ceil(1 / 32.)) * 32; // type int + size_t logits_buf_size = args_.batch_size_ * args_.beam_width_ * args_.vocab_size_padded_; // type float + size_t cum_log_buf_size = args_.batch_size_ * args_.beam_width_; // type float + size_t word_ids_buf_size = args_.batch_size_ * args_.beam_width_; //type int + size_t finished_buf_size = args_.batch_size_ * args_.beam_width_; //type bool + size_t finished_count_size = (size_t)(ceil(1 / 32.)) * 32; // type int - int storage_size_per_beam = 2 * args_.beam_width_ + SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS * (2 * MAX_K + 2); + size_t storage_size_per_beam = 2 * args_.beam_width_ + SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS * (2 * MAX_K + 2); args_.temp_storage_size_ = args_.batch_size_ * args_.beam_width_ * storage_size_per_beam; // type float + args_.temp_storage_size_ = (size_t)( + ceil(args_.batch_size_ * args_.beam_width_ * args_.beam_width_ / 4.) * 4 * 2 + + ceil(args_.batch_size_ * args_.beam_width_ * SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS * (2 * MAX_K + 2) / 4.) 
* 4 + ); + size_t padded_embedding_kernel_size = args_.hidden_units_ * args_.vocab_size_padded_; + size_t padded_embedding_bias_size = args_.vocab_size_padded_; + if(std::is_same::value || (std::is_same::value && args_.vocab_size_padded_ == args_.vocab_size_)) + { + padded_embedding_kernel_size = 0; + padded_embedding_bias_size = 0; + } // prevent memory misalinged address - logits_buf_size = (int)(ceil(logits_buf_size / 4.)) * 4; - cum_log_buf_size = (int)(ceil(cum_log_buf_size / 4.)) * 4; - word_ids_buf_size = (int)(ceil(word_ids_buf_size / 4.)) * 4; - finished_buf_size = (int)(ceil(finished_buf_size / 32.)) * 32; - args_.temp_storage_size_ = (int)(ceil(args_.temp_storage_size_ / 4.)) * 4; + logits_buf_size = (size_t)(ceil(logits_buf_size / 4.)) * 4; + cum_log_buf_size = (size_t)(ceil(cum_log_buf_size / 4.)) * 4; + word_ids_buf_size = (size_t)(ceil(word_ids_buf_size / 4.)) * 4; + finished_buf_size = (size_t)(ceil(finished_buf_size / 32.)) * 32; + const size_t tmp_logits_buf_size = logits_buf_size; // get workspace size of topk kernel topK_kernelLauncher(topK_kernel_workspace, topk_workspace_size_, logits_buf_, word_ids_buf_, + finished_buf_, args_, 0); - int datatype_buf_size = from_tensor_size * 2 + decoder_workspace_size + + size_t datatype_buf_size = from_tensor_size * 2 + decoder_workspace_size + (cache_size * 4 + mem_cache_size * 2) * args_.decoder_layers_ + decoder_normed_result_buffer_size; buf_ = reinterpret_cast(allocator_.malloc( + ((sizeof(DataType_) == sizeof(half)) ? CUBLAS_WORKSPACE_SIZE : 0) + sizeof(DataType_) * datatype_buf_size + sizeof(float) * (logits_buf_size + cum_log_buf_size) + + sizeof(DataType_) * tmp_logits_buf_size + + sizeof(DataType_) * padded_embedding_kernel_size + + sizeof(float) * padded_embedding_bias_size + sizeof(int) * word_ids_buf_size + sizeof(bool) * finished_buf_size + topk_workspace_size_ + sizeof(float) * args_.temp_storage_size_ + // should be always float sizeof(int) * finished_count_size)); - from_tensor_[0] = (DataType_ *)buf_; + if (sizeof(DataType_) == sizeof(half)) + { + cublas_workspace_ = buf_; + from_tensor_[0] = (DataType_ *)((char*)cublas_workspace_ + CUBLAS_WORKSPACE_SIZE); + } + else + { + cublas_workspace_ = nullptr; + from_tensor_[0] = (DataType_ *)(buf_); + } from_tensor_[1] = (DataType_ *)(from_tensor_[0] + from_tensor_size); for (int i = 0; i < args_.decoder_layers_; ++i) @@ -152,11 +189,22 @@ class DecodingBeamsearch V_mem_cache_[i] = from_tensor_[1] + from_tensor_size + i * mem_cache_size * 2 + mem_cache_size; } - /* We use two-way buffer since we have to update KV buf at the end of each step. */ - K_cache_[0] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 0 * cache_size * args_.decoder_layers_; - K_cache_[1] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 1 * cache_size * args_.decoder_layers_; - V_cache_[0] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 2 * cache_size * args_.decoder_layers_; - V_cache_[1] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 3 * cache_size * args_.decoder_layers_; + if(args_.beam_width_ > 1) + { + /* We use two-way buffer since we have to update KV buf at the end of each step. 
*/ + K_cache_[0] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 0 * cache_size * args_.decoder_layers_; + K_cache_[1] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 1 * cache_size * args_.decoder_layers_; + V_cache_[0] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 2 * cache_size * args_.decoder_layers_; + V_cache_[1] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 3 * cache_size * args_.decoder_layers_; + } + else + { + // if beam width is 1, we only need one buffer + K_cache_[0] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 0 * cache_size * args_.decoder_layers_; + K_cache_[1] = K_cache_[0]; + V_cache_[0] = V_mem_cache_[decoder_layers - 1] + mem_cache_size + 2 * cache_size * args_.decoder_layers_; + V_cache_[1] = V_cache_[0]; + } decoder_buf_ = V_cache_[1] + cache_size * args_.decoder_layers_; decoder_normed_result_buf_ = (decoder_buf_ + decoder_workspace_size); @@ -167,49 +215,45 @@ class DecodingBeamsearch temp_storage_ = (float *)(finished_buf_ + finished_buf_size); finished_count_buf_ = (int *)(temp_storage_ + args_.temp_storage_size_); topK_kernel_workspace = (void*)(finished_count_buf_ + finished_count_size); + padded_embedding_kernel = (DataType_*)((char*)topK_kernel_workspace + topk_workspace_size_); + padded_embedding_bias = (DataType_*)(padded_embedding_kernel + padded_embedding_kernel_size); + tmp_logits_buf_ = (DataType_*)(padded_embedding_bias + padded_embedding_bias_size); h_finished_buf_ = new bool[finished_buf_size]; - FILE *fd = fopen("decoding_gemm_config.in", "r"); - int err = 0; - if (fd == NULL) - printf("[WARNING] decoding_gemm_config.in is not found\n"); - else - { - err = fscanf(fd, "%d", &cublasAlgo_[0]); - fclose(fd); - } - if (err != 1) + int isConfigExist = access("decoding_gemm_config.in", 0); + if (isConfigExist == -1) { - printf("[WARNING] decoding loading GEMM algorithms error, using default GEMM algorithms!\n"); - if (Traits_::OpType == OperationType::FP32) - { - cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT; - } - else - { - cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - } + printf("[WARNING] decoding_gemm_config.in is not found\n"); } else { + readAlgoFromConfig(cublasAlgoMap_, 1); // check that the gemm_config setting is runnable - if (Traits_::OpType == OperationType::FP32) + for (auto iter = cublasAlgoMap_.begin() ; iter != cublasAlgoMap_.end() ; iter++) { - if (cublasAlgo_[0] > CUBLAS_GEMM_ALGO23 || cublasAlgo_[0] < CUBLAS_GEMM_DEFAULT) + int algoId = iter->second.algoId; + int stages = iter->second.stages; + //only check for cublas + if (stages != -1) + continue; + if (Traits_::OpType == OperationType::FP32) { - // the algorithm is not for FP32 - printf("[ERROR] cuBLAS Algorithm %d is not used in FP32. \n", (int)cublasAlgo_[0]); - exit(-1); + if (algoId > CUBLAS_GEMM_ALGO23 || algoId < CUBLAS_GEMM_DEFAULT) + { + // the algorithm is not for FP32 + printf("[ERROR] cuBLAS Algorithm %d is not used in FP32. \n", algoId); + exit(-1); + } } - } - else - { - if (cublasAlgo_[0] > CUBLAS_GEMM_ALGO15_TENSOR_OP || cublasAlgo_[0] < CUBLAS_GEMM_DEFAULT_TENSOR_OP) + else { - // the algorithm is not for FP16 - printf("[ERROR] cuBLAS Algorithm %d is not used in FP16. \n", (int)cublasAlgo_[0]); - exit(-1); + if (algoId > CUBLAS_GEMM_ALGO15_TENSOR_OP || algoId < CUBLAS_GEMM_DEFAULT_TENSOR_OP) + { + // the algorithm is not for FP16 + printf("[ERROR] cuBLAS Algorithm %d is not used in FP16. 
\n", algoId); + exit(-1); + } } } } @@ -224,7 +268,9 @@ class DecodingBeamsearch #endif const int m = args_.batch_size_ * args_.beam_width_; const int k = args_.hidden_units_; - const int n = args_.vocab_size_; + const int n = args_.vocab_size_padded_; + const DataType_* embedding_kernel_ptr = nullptr; + const DataType_* embedding_bias_ptr = nullptr; /* sequence_length initialize to 0 @@ -235,6 +281,7 @@ class DecodingBeamsearch init_kernelLauncher(finished_buf_, decoding_params.sequence_length, word_ids_buf_, cum_log_buf_, args_.start_id_, args_.batch_size_, args_.beam_width_, decoding_params.stream); + #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); @@ -248,9 +295,36 @@ class DecodingBeamsearch // start_id_, batch_size_, beam_width_, decoding_params.stream); #endif + if(std::is_same::value || (std::is_same::value && args_.vocab_size_padded_ == args_.vocab_size_)) + { + embedding_kernel_ptr = (const DataType_ *)decoding_params.embedding_kernel; + embedding_bias_ptr = (const DataType_ *)decoding_params.embedding_bias; + } + else if(std::is_same::value) + { + + kernel_padding_kernelLauncher(padded_embedding_kernel, decoding_params.embedding_kernel, args_.hidden_units_, + args_.vocab_size_, args_.vocab_size_padded_, decoding_params.stream); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + bias_padding_kernelLauncher(padded_embedding_bias, decoding_params.embedding_bias, + args_.vocab_size_, args_.vocab_size_padded_, decoding_params.stream); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + embedding_kernel_ptr = padded_embedding_kernel; + embedding_bias_ptr = padded_embedding_bias; + } + int cache_size = m * args_.seq_len_ * args_.hidden_units_; // type T - for (int step = 1; step <= args_.seq_len_; ++step) + for (uint step = 1; step <= args_.seq_len_; ++step) { //we use two-way buffer int kv_cache_id = step & 0x1; @@ -281,7 +355,7 @@ class DecodingBeamsearch The decoder_buf_ is reused. 
*/ - decoder_->initialize(param[layer], decoder_buf_); + decoder_->initialize(param[layer], decoder_buf_, cublas_workspace_); #ifndef NDEBUG cudaDeviceSynchronize(); @@ -291,34 +365,32 @@ class DecodingBeamsearch K_cache_[kv_cache_id] + layer * cache_size, V_cache_[kv_cache_id] + layer * cache_size, K_mem_cache_[layer], V_mem_cache_[layer], - decoding_params.memory_sequence_length, from_tensor_[out_id], step, - true); + decoding_params.memory_sequence_length, from_tensor_[out_id], step, args_.seq_len_, + true, finished_buf_); #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif } - decoder_->decoder_norm1(from_tensor_[out_id], decoding_params.layernorm.gamma, - decoding_params.layernorm.beta, decoder_normed_result_buf_, m, k); - - float alpha = (float)1.0f; - float beta = (float)0.0f; - - check_cuda_error(cublasGemmEx(decoding_params.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - decoding_params.embedding_kernel, AType_, n, - decoder_normed_result_buf_, BType_, k, - &beta, - logits_buf_, CUDA_R_32F, n, -#ifdef CUDA11_MODE - CUBLAS_COMPUTE_32F_PEDANTIC, -#else - CUDA_R_32F, -#endif - static_cast(cublasAlgo_[0]))); + layer_norm(from_tensor_[out_id], decoding_params.layernorm.gamma, + decoding_params.layernorm.beta, decoder_normed_result_buf_, m, k, decoding_params.stream); + + DataType_ alpha = (DataType_)1.0f; + DataType_ beta = (DataType_)0.0f; + + cublasMM_cublasLtMM_wrapper_decoder(decoding_params.cublaslt_handle, + decoding_params.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + embedding_kernel_ptr, AType_, n, + decoder_normed_result_buf_, BType_, k, + &beta, + tmp_logits_buf_, CType_, n, + decoding_params.stream, cublasAlgoMap_, + cublas_workspace_); + #ifndef NDEBUG cudaDeviceSynchronize(); @@ -328,8 +400,8 @@ class DecodingBeamsearch // Beamsearch if (is_fuse_topk_softMax_ == true) { - topK_softMax(logits_buf_, - decoding_params.embedding_bias, + topK_softMax(tmp_logits_buf_, + embedding_bias_ptr, finished_buf_, cum_log_buf_, word_ids_buf_, @@ -356,7 +428,7 @@ class DecodingBeamsearch } else { - update_logits(logits_buf_, decoding_params.embedding_bias, args_.end_id_, finished_buf_, m, n, decoding_params.stream); + update_logits(logits_buf_, tmp_logits_buf_, embedding_bias_ptr, args_.end_id_, finished_buf_, m, n, decoding_params.stream); #ifndef NDEBUG cudaDeviceSynchronize(); @@ -367,12 +439,12 @@ class DecodingBeamsearch update_logits_kernel_check will compare the results of GPU and CPU. Note that update_logits_kernel_check contains update_logits and uses do not need to call it again. */ - // update_logits_kernel_check(logits_buf_, decoding_params.embedding_bias, args_.end_id_, finished_buf_, m, n, decoding_params.stream); + // update_logits_kernel_check(logits_buf_, tmp_logits_buf_, decoding_params.embedding_bias, args_.end_id_, finished_buf_, m, n, decoding_params.stream); #endif /* adding cum_log_buf_ to logits_buf_ */ broadcast_kernelLauncher(logits_buf_, cum_log_buf_, args_.batch_size_, - args_.beam_width_, args_.vocab_size_, decoding_params.stream); + args_.beam_width_, args_.vocab_size_padded_, decoding_params.stream); #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); @@ -382,13 +454,14 @@ class DecodingBeamsearch broadcast_kernel_check will compare the results of GPU and CPU. Note that broadcast_kernel_check contains broadcast_kernelLauncher and uses do not need to call it again. 
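cublasMM_cublasLtMM_wrapper_decoder picks a cuBLAS or cuBLASLt path according to the loaded cublasAlgoMap_ and the shared cublas_workspace_, but the shape of the logits projection is unchanged from the cublasGemmEx call it replaces: column-major C(n, m) = A(n, k) * B(k, m) with m = batch * beam, k = hidden_units, and n now equal to vocab_size_padded_. For reference, a stripped-down cuBLAS-only version of that projection, adapted from the removed call (the real wrapper writes DataType_ logits into tmp_logits_buf_ and takes the workspace as an extra argument):

#include <cublas_v2.h>

// Sketch of the plain-cuBLAS form of the logits projection that the wrapper replaces.
// A is the (padded) embedding kernel, B is the layer-normed decoder output, C holds
// float logits; CUDA11_MODE is the project's CMake define seen elsewhere in this patch.
cublasStatus_t logits_projection(cublasHandle_t handle,
                                 int n, int m, int k,
                                 const void *embedding_kernel, cudaDataType_t AType,
                                 const void *decoder_normed_result, cudaDataType_t BType,
                                 float *logits, cublasGemmAlgo_t algo)
{
    const float alpha = 1.0f, beta = 0.0f;
    return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                        n, m, k,
                        &alpha,
                        embedding_kernel, AType, n,
                        decoder_normed_result, BType, k,
                        &beta,
                        logits, CUDA_R_32F, n,
#ifdef CUDA11_MODE
                        CUBLAS_COMPUTE_32F_PEDANTIC,
#else
                        CUDA_R_32F,
#endif
                        algo);
}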
*/ - // broadcast_kernel_check(logits_buf_, cum_log_buf_, batch_size_, beam_width_, vocab_size_, decoding_params.stream); + // broadcast_kernel_check(logits_buf_, cum_log_buf_, batch_size_, beam_width_, args_.vocab_size_padded_, decoding_params.stream); #endif topK_kernelLauncher(topK_kernel_workspace, topk_workspace_size_, logits_buf_, word_ids_buf_, + finished_buf_, args_, decoding_params.stream); #ifndef NDEBUG @@ -401,7 +474,7 @@ class DecodingBeamsearch decoding_params.sequence_length, word_ids_buf_, decoding_params.output_ids + (step - 1) * m, - args_.batch_size_, args_.beam_width_, args_.vocab_size_, + args_.batch_size_, args_.beam_width_, args_.vocab_size_padded_, decoding_params.stream, args_.end_id_, finished_count_buf_); } @@ -410,21 +483,27 @@ class DecodingBeamsearch check_cuda_error(cudaGetLastError()); #endif - update_KV_cache_kernelLauncher(K_cache_, V_cache_, - decoding_params.parent_ids + (step - 1) * m, - args_.batch_size_, args_.beam_width_, args_.hidden_units_, step, - cache_size, args_.decoder_layers_, decoding_params.stream); + if(args_.beam_width_ > 1) + { + // chose which self cache to use + int decoder_max_seq_len = (decoder_->getCacheFormat() != 0)? args_.seq_len_ : -1; + update_KV_cache_kernelLauncher(K_cache_, V_cache_, + decoding_params.parent_ids + (step - 1) * m, + finished_buf_, + args_.batch_size_, args_.beam_width_, args_.head_num_, args_.size_per_head_, step, decoder_max_seq_len, + cache_size, args_.decoder_layers_, decoding_params.stream); #ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); /* User can check the update_KV_cache by update_KV_cache_kernel_check. update_KV_cache_kernel_check will compare the results of GPU and CPU. Note that update_KV_cache_kernel_check contains update_KV_cache and uses do not need to call it again. */ - // update_KV_cache_kernel_check(K_cache_, V_cache_, decoding_params.parent_ids + (step - 1) * batch_size_ * beam_width_, batch_size_, beam_width_, hidden_units_, step, cache_size, decoder_layers_, decoding_params.stream); + // update_KV_cache_kernel_check(K_cache_, V_cache_, decoding_params.parent_ids + (step - 1) * batch_size_ * beam_width_, batch_size_, beam_width_, head_num_, size_per_head_, step, cache_size, decoder_layers_, decoding_params.stream); #endif + } // TODO Find a better method to check the is_finished cudaMemcpy(h_finished_buf_, finished_buf_, sizeof(bool) * m, cudaMemcpyDeviceToHost); diff --git a/fastertransformer/decoding_sampling.h b/fastertransformer/decoding_sampling.h index b13ff8c27..171226b2d 100644 --- a/fastertransformer/decoding_sampling.h +++ b/fastertransformer/decoding_sampling.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,11 +19,12 @@ #pragma once -#include "fastertransformer/common.h" -#include "fastertransformer/allocator.h" +#include "fastertransformer/utils/common.h" +#include "fastertransformer/utils/functions.h" +#include "fastertransformer/utils/allocator.h" #include "fastertransformer/open_decoder.h" #include "fastertransformer/cuda/cuda_kernels.h" -#include "fastertransformer/arguments.h" +#include "fastertransformer/utils/arguments.h" #include namespace fastertransformer @@ -42,7 +43,7 @@ class DecodingSampling const cudaDataType_t AType_ = Traits_::AType; const cudaDataType_t BType_ = Traits_::BType; const cudaDataType_t CType_ = Traits_::CType; - int cublasAlgo_[1] = {20}; + std::map cublasAlgoMap_; OpenDecoder *decoder_; DataType_ **K_cache_; @@ -55,7 +56,7 @@ class DecodingSampling DataType_ *logits_buf_; int *word_ids_buf_; bool *finished_buf_; - + void *buf_; int *finished_count_buf_; bool *h_finished_buf_; @@ -64,8 +65,14 @@ class DecodingSampling size_t topk_workspace_size_ = 0; void *topp_workspace_ = nullptr; size_t topp_workspace_size_ = 0; + void *cublas_workspace_ = nullptr; + curandState_t *curandstate_buf_; int *topp_id_vals_buf_; int *topp_offset_buf_; + int *begin_topp_offset_buf_; + + DataType_ *padded_embedding_kernel; + DataType_ *padded_embedding_bias; public: DecodingSampling(const IAllocator &allocator, const int batch_size, @@ -75,7 +82,8 @@ class DecodingSampling const int memory_hidden_units, const int memory_max_seq_len, const int start_id, const int end_id, const int candidate_num = 0, - const float probability_threshold = 0.0) : allocator_(allocator) + const float probability_threshold = 0.0, + const int is_fuse_qkv = false) : allocator_(allocator) { args_.batch_size_ = batch_size; args_.seq_len_ = seq_len; @@ -84,6 +92,11 @@ class DecodingSampling args_.hidden_units_ = head_num * size_per_head; args_.decoder_layers_ = decoder_layers; args_.vocab_size_ = vocab_size; + if(std::is_same::value) + args_.vocab_size_padded_ = vocab_size; + else if(std::is_same::value) + args_.vocab_size_padded_ = (int)(ceil(vocab_size / 8.)) * 8; + args_.candidate_num_ = candidate_num; args_.probability_threshold_ = probability_threshold; args_.start_id_ = start_id; @@ -108,64 +121,90 @@ class DecodingSampling K_mem_cache_ = new DataType_ *[args_.decoder_layers_]; V_mem_cache_ = new DataType_ *[args_.decoder_layers_]; - decoder_ = new OpenDecoder(batch_size, memory_max_seq_len, - head_num, size_per_head, memory_hidden_units); - - int from_tensor_size = args_.batch_size_ * args_.hidden_units_; // type T - int decoder_workspace_size = decoder_->getWorkspaceSize(); // type T - int decoder_normed_result_buffer_size = args_.batch_size_ * args_.hidden_units_; // type T - int cache_size = args_.batch_size_ * args_.seq_len_ * args_.hidden_units_; // type T - int mem_cache_size = args_.batch_size_ * memory_max_seq_len * args_.hidden_units_; // type T - int logits_buf_size = args_.batch_size_ * args_.vocab_size_; // type T - - int word_ids_buf_size = args_.batch_size_; //type int - int finished_buf_size = args_.batch_size_; //type bool - int finished_count_size = (int)(ceil(1 / 32.)) * 32; // type int - - int topp_id_vals_buf_size = args_.batch_size_ * args_.vocab_size_; // type int - int topp_offset_buf_size = args_.batch_size_ + 1; // type int + decoder_ = new OpenDecoder(head_num, size_per_head, memory_hidden_units, is_fuse_qkv); + decoder_->set_max_batch_size(batch_size); + + size_t from_tensor_size = args_.batch_size_ * args_.hidden_units_; // type T + size_t decoder_workspace_size = 
decoder_->getWorkspaceSize(); // type T + size_t decoder_normed_result_buffer_size = args_.batch_size_ * args_.hidden_units_; // type T + size_t cache_size = args_.batch_size_ * args_.seq_len_ * args_.hidden_units_; // type T + size_t mem_cache_size = args_.batch_size_ * memory_max_seq_len * args_.hidden_units_; // type T + size_t logits_buf_size = args_.batch_size_ * args_.vocab_size_padded_; // type T + + size_t word_ids_buf_size = args_.batch_size_; //type int + size_t finished_buf_size = args_.batch_size_; //type bool + size_t finished_count_size = (size_t)(ceil(1 / 32.)) * 32; // type int + + size_t topp_id_vals_buf_size = args_.batch_size_ * args_.vocab_size_padded_; // type int + size_t topp_offset_buf_size = args_.batch_size_ + 1; // type int + size_t begin_topp_offset_buf_size = topp_offset_buf_size; + size_t curandState_size = args_.batch_size_; + size_t padded_embedding_kernel_size = args_.hidden_units_ * args_.vocab_size_padded_; + size_t padded_embedding_bias_size = args_.vocab_size_padded_; + if(std::is_same::value || (std::is_same::value && args_.vocab_size_ == args_.vocab_size_padded_)) + { + padded_embedding_kernel_size = 0; + padded_embedding_bias_size = 0; + } // prevent memory misalinged address - logits_buf_size = (int)(ceil(logits_buf_size / 4.)) * 4; - word_ids_buf_size = (int)(ceil(word_ids_buf_size / 4.)) * 4; - finished_buf_size = (int)(ceil(finished_buf_size / 32.)) * 32; - - topp_id_vals_buf_size = (int)(ceil(topp_id_vals_buf_size / 4.)) * 4; - topp_offset_buf_size = (int)(ceil(topp_offset_buf_size / 4.)) * 4; - topP_sampling_kernel_kernelLauncher(topp_workspace_, + logits_buf_size = (size_t)(ceil(logits_buf_size / 4.)) * 4; + word_ids_buf_size = (size_t)(ceil(word_ids_buf_size / 4.)) * 4; + finished_buf_size = (size_t)(ceil(finished_buf_size / 32.)) * 32; + + topp_id_vals_buf_size = (size_t)(ceil(topp_id_vals_buf_size / 4.)) * 4; + topp_offset_buf_size = (size_t)(ceil(topp_offset_buf_size / 4.)) * 4; + begin_topp_offset_buf_size = topp_offset_buf_size; + + topP_sampling_kernel_kernelLauncher_v2(topp_workspace_, topp_workspace_size_, logits_buf_, topp_id_vals_buf_, topp_offset_buf_, + begin_topp_offset_buf_, finished_buf_, - 0, + curandstate_buf_, args_, nullptr, nullptr, - args_.vocab_size_, - 0); - topK_sampling_kernel_kernelLauncher(topk_workspace_, - topk_workspace_size_, - logits_buf_, - nullptr, - nullptr, - finished_buf_, + args_.vocab_size_padded_, 0, - args_, - 0); - - int datatype_buf_size = from_tensor_size * 2 + decoder_workspace_size + + args_.batch_size_); + + topK_sampling_kernel_kernelLauncher_v2(topk_workspace_, + topk_workspace_size_, + logits_buf_, + nullptr, + nullptr, + finished_buf_, + curandstate_buf_, + args_, + 0, + args_.batch_size_); + + size_t datatype_buf_size = from_tensor_size * 2 + decoder_workspace_size + (cache_size * 4 + mem_cache_size * 2) * args_.decoder_layers_ + decoder_normed_result_buffer_size; buf_ = reinterpret_cast(allocator_.malloc( - sizeof(DataType_) * (datatype_buf_size + logits_buf_size) + + ( (sizeof(DataType_) == sizeof(half)) ? 
CUBLAS_WORKSPACE_SIZE : 0 ) + + sizeof(DataType_) * (datatype_buf_size + logits_buf_size) + + sizeof(DataType_) * (padded_embedding_kernel_size + padded_embedding_bias_size) + sizeof(int) * word_ids_buf_size + sizeof(bool) * finished_buf_size + sizeof(int) * finished_count_size + - sizeof(int) * (topp_id_vals_buf_size + topp_offset_buf_size) + - topp_workspace_size_ + topk_workspace_size_)); - - from_tensor_[0] = (DataType_ *)buf_; + sizeof(int) * (topp_id_vals_buf_size + 2 * topp_offset_buf_size) + + topp_workspace_size_ + topk_workspace_size_ + curandState_size * sizeof(curandState_t))); + + if (sizeof(DataType_) == sizeof(half)) + { + cublas_workspace_ = buf_; + from_tensor_[0] = (DataType_ *)((char*)cublas_workspace_ + CUBLAS_WORKSPACE_SIZE); + } + else + { + cublas_workspace_ = nullptr; + from_tensor_[0] = (DataType_ *)buf_; + } from_tensor_[1] = (DataType_ *)(from_tensor_[0] + from_tensor_size); for (int i = 0; i < args_.decoder_layers_; ++i) @@ -185,52 +224,49 @@ class DecodingSampling finished_buf_ = (bool *)(word_ids_buf_ + word_ids_buf_size); finished_count_buf_ = (int *)(finished_buf_ + finished_buf_size); topp_id_vals_buf_ = (int *)(finished_count_buf_ + finished_count_size); - topp_offset_buf_ = (int *)(topp_id_vals_buf_ + topp_id_vals_buf_size); + begin_topp_offset_buf_ = (int *)(topp_id_vals_buf_ + topp_id_vals_buf_size); + topp_offset_buf_ = (int *)(begin_topp_offset_buf_ + begin_topp_offset_buf_size); topp_workspace_ = (void*)(topp_offset_buf_ + topp_offset_buf_size); - topk_workspace_ = (void*)(topp_workspace_ + topp_workspace_size_); + topk_workspace_ = (void*)((char*)topp_workspace_ + topp_workspace_size_); + padded_embedding_kernel = (DataType_*)((char*)topk_workspace_ + topk_workspace_size_); + padded_embedding_bias = (DataType_*)(padded_embedding_kernel + padded_embedding_kernel_size); + curandstate_buf_ = (curandState_t*)(padded_embedding_bias + padded_embedding_bias_size); h_finished_buf_ = new bool[finished_buf_size]; - FILE *fd = fopen("decoding_gemm_config.in", "r"); - int err = 0; - if (fd == NULL) - printf("[WARNING] decoding_gemm_config.in is not found\n"); - else + int isConfigExist = access("decoding_gemm_config.in", 0); + if (isConfigExist == -1) { - err = fscanf(fd, "%d", &cublasAlgo_[0]); - fclose(fd); - } - if (err != 1) - { - printf("[WARNING] decoding loading GEMM algorithms error, using default GEMM algorithms!\n"); - if (Traits_::OpType == OperationType::FP32) - { - cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT; - } - else - { - cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - } + printf("[WARNING] decoding_gemm_config.in is not found\n"); } else { + readAlgoFromConfig(cublasAlgoMap_, 1); // check that the gemm_config setting is runnable - if (Traits_::OpType == OperationType::FP32) + for (auto iter = cublasAlgoMap_.begin() ; iter != cublasAlgoMap_.end() ; iter++) { - if (cublasAlgo_[0] > CUBLAS_GEMM_ALGO23 || cublasAlgo_[0] < CUBLAS_GEMM_DEFAULT) + int algoId = iter->second.algoId; + int stages = iter->second.stages; + //only check for cublas + if (stages != -1) + continue; + if (Traits_::OpType == OperationType::FP32) { - // the algorithm is not for FP32 - printf("[ERROR] cuBLAS Algorithm %d is not used in FP32. \n", (int)cublasAlgo_[0]); - exit(-1); + if (algoId > CUBLAS_GEMM_ALGO23 || algoId < CUBLAS_GEMM_DEFAULT) + { + // the algorithm is not for FP32 + printf("[ERROR] cuBLAS Algorithm %d is not used in FP32. 
\n", algoId); + exit(-1); + } } - } - else - { - if (cublasAlgo_[0] > CUBLAS_GEMM_ALGO15_TENSOR_OP || cublasAlgo_[0] < CUBLAS_GEMM_DEFAULT_TENSOR_OP) + else { - // the algorithm is not for FP16 - printf("[ERROR] cuBLAS Algorithm %d is not used in FP16. \n", (int)cublasAlgo_[0]); - exit(-1); + if (algoId > CUBLAS_GEMM_ALGO15_TENSOR_OP || algoId < CUBLAS_GEMM_DEFAULT_TENSOR_OP) + { + // the algorithm is not for FP16 + printf("[ERROR] cuBLAS Algorithm %d is not used in FP16. \n", algoId); + exit(-1); + } } } } @@ -245,7 +281,9 @@ class DecodingSampling #endif const int m = args_.batch_size_; const int k = args_.hidden_units_; - const int n = args_.vocab_size_; + const int n = args_.vocab_size_padded_; + const DataType_* embedding_kernel_ptr = nullptr; + const DataType_* embedding_bias_ptr = nullptr; /* sequence_length initialize to 0 @@ -260,24 +298,54 @@ class DecodingSampling } else if (args_.probability_threshold_ != 0.0) { - topp_initialization_kernelLauncher(finished_buf_, - decoding_params.sequence_length, - word_ids_buf_, - topp_id_vals_buf_, - topp_offset_buf_, - args_.vocab_size_, - args_, - decoding_params.stream); + topp_initialization_kernelLauncher_v2(finished_buf_, + decoding_params.sequence_length, + word_ids_buf_, + topp_id_vals_buf_, + topp_offset_buf_, + begin_topp_offset_buf_, + args_.vocab_size_padded_, + args_, + decoding_params.stream); } + ker_curand_setupLauncher(curandstate_buf_, + args_, + decoding_params.stream); #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif + if(std::is_same::value || (std::is_same::value && args_.vocab_size_ == args_.vocab_size_padded_)) + { + embedding_kernel_ptr = (const DataType_ *)decoding_params.embedding_kernel; + embedding_bias_ptr = (const DataType_ *)decoding_params.embedding_bias; + } + else if(std::is_same::value) + { + kernel_padding_kernelLauncher(padded_embedding_kernel, decoding_params.embedding_kernel, args_.hidden_units_, + args_.vocab_size_, args_.vocab_size_padded_, decoding_params.stream); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + bias_padding_kernelLauncher(padded_embedding_bias, decoding_params.embedding_bias, + args_.vocab_size_, args_.vocab_size_padded_, decoding_params.stream); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + embedding_kernel_ptr = padded_embedding_kernel; + embedding_bias_ptr = padded_embedding_bias; + } + int cache_size = args_.batch_size_ * args_.seq_len_ * args_.hidden_units_; // type T - for (int step = 1; step <= args_.seq_len_; ++step) + for (uint step = 1; step <= args_.seq_len_; ++step) { embedding_lookup_sine_position_encoding_kernel_launcher(from_tensor_[0], decoding_params.embedding_table, @@ -305,7 +373,7 @@ class DecodingSampling The decoder_buf_ is reused. 
*/ - decoder_->initialize(param[layer], decoder_buf_); + decoder_->initialize(param[layer], decoder_buf_, cublas_workspace_); #ifndef NDEBUG cudaDeviceSynchronize(); @@ -315,30 +383,32 @@ class DecodingSampling K_cache_[0] + layer * cache_size, V_cache_[0] + layer * cache_size, K_mem_cache_[layer], V_mem_cache_[layer], - decoding_params.memory_sequence_length, from_tensor_[out_id], step, - true); + decoding_params.memory_sequence_length, from_tensor_[out_id], step, args_.seq_len_, + true, finished_buf_); #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif } - decoder_->decoder_norm1(from_tensor_[out_id], decoding_params.layernorm.gamma, - decoding_params.layernorm.beta, decoder_normed_result_buf_, m, k); - + layer_norm(from_tensor_[out_id], decoding_params.layernorm.gamma, + decoding_params.layernorm.beta, decoder_normed_result_buf_, m, k, decoding_params.stream); + DataType_ alpha = (DataType_)1.0f; DataType_ beta = (DataType_)0.0f; - check_cuda_error(cublasGemmEx(decoding_params.cublas_handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n, m, k, - &alpha, - decoding_params.embedding_kernel, AType_, n, - decoder_normed_result_buf_, BType_, k, - &beta, - logits_buf_, CType_, n, - computeType_, - static_cast(cublasAlgo_[0]))); + cublasMM_cublasLtMM_wrapper_decoder(decoding_params.cublaslt_handle, + decoding_params.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + embedding_kernel_ptr, AType_, n, + decoder_normed_result_buf_, BType_, k, + &beta, + logits_buf_, CType_, n, + decoding_params.stream, cublasAlgoMap_, + cublas_workspace_); + #ifndef NDEBUG cudaDeviceSynchronize(); @@ -349,42 +419,55 @@ class DecodingSampling { // top k sampling update_logits_without_softmax(logits_buf_, - decoding_params.embedding_bias_T, + embedding_bias_ptr, args_.end_id_, finished_buf_, m, n, decoding_params.stream); - topK_sampling_kernel_kernelLauncher(topk_workspace_, - topk_workspace_size_, - logits_buf_, - decoding_params.output_ids + (step - 1) * args_.batch_size_, - decoding_params.sequence_length, - finished_buf_, - step, // used as random number - args_, - decoding_params.stream); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + topK_sampling_kernel_kernelLauncher_v2(topk_workspace_, + topk_workspace_size_, + logits_buf_, + decoding_params.output_ids + (step - 1) * args_.batch_size_, + decoding_params.sequence_length, + finished_buf_, + curandstate_buf_, // used as random number + args_, + decoding_params.stream, + args_.batch_size_); } else if (args_.probability_threshold_ != 0.0) { // top p sampling softmax_kernelLauncher(logits_buf_, - decoding_params.embedding_bias_T, + embedding_bias_ptr, args_.end_id_, finished_buf_, - m, n, decoding_params.stream); + m, n, n, decoding_params.stream); - topP_sampling_kernel_kernelLauncher(topp_workspace_, - topp_workspace_size_, - logits_buf_, - topp_id_vals_buf_, - topp_offset_buf_, - finished_buf_, - step, - args_, - decoding_params.output_ids + (step - 1) * args_.batch_size_, - decoding_params.sequence_length, - n, - decoding_params.stream); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + topP_sampling_kernel_kernelLauncher_v2(topp_workspace_, + topp_workspace_size_, + logits_buf_, + topp_id_vals_buf_, + topp_offset_buf_, + begin_topp_offset_buf_, + finished_buf_, + curandstate_buf_, + args_, + decoding_params.output_ids + (step - 1) * args_.batch_size_, + decoding_params.sequence_length, + n, + decoding_params.stream, + 
args_.batch_size_); } word_ids_buf_ = decoding_params.output_ids + (step - 1) * args_.batch_size_; @@ -396,8 +479,8 @@ class DecodingSampling // TODO Find a better method to check the is_finished cudaMemcpy(h_finished_buf_, finished_buf_, sizeof(bool) * args_.batch_size_, cudaMemcpyDeviceToHost); - int sum = 0; - for (int i = 0; i < args_.batch_size_; i++) + uint sum = 0; + for (uint i = 0; i < args_.batch_size_; i++) { sum += (int)h_finished_buf_[i]; } diff --git a/fastertransformer/faster_transformer.h b/fastertransformer/faster_transformer.h deleted file mode 100644 index c969e3dc8..000000000 --- a/fastertransformer/faster_transformer.h +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * c++ interface of Faster Transformer - **/ - -#pragma once - -#include "fastertransformer/bert_encoder_transformer.h" -#include -namespace fastertransformer{ - - - -}//namespace fastertransformer diff --git a/fastertransformer/gemm_test/encoder_gemm_func.cc b/fastertransformer/gemm_test/encoder_gemm_func.cc index 7f074d9cd..6414b4459 100644 --- a/fastertransformer/gemm_test/encoder_gemm_func.cc +++ b/fastertransformer/gemm_test/encoder_gemm_func.cc @@ -15,26 +15,503 @@ */ #include "encoder_gemm_func.h" -#include "fastertransformer/common.h" +#include "fastertransformer/utils/common.h" #include namespace fastertransformer{ -double diffTime(timeval start, timeval end) +// Utility function to print customMatmulPerf_t structure +int printPerfStructure(int batch_size, int seq_len, int head_num, int size_per_head, int m, int n, int k, const customMatmulPerf_t &perf, FILE* fout, int is_fp16, int hasPrint) { + int algoId, tile, swizzle, customOption, numSplitsK, reductionScheme, stages; + + const cublasLtMatmulAlgo_t *matmulAlgo = &perf.algo; + cublasLtMatmulAlgoConfigGetAttribute( matmulAlgo, CUBLASLT_ALGO_CONFIG_ID, &algoId, sizeof(algoId), NULL); + cublasLtMatmulAlgoConfigGetAttribute( matmulAlgo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tile, sizeof(tile), NULL); + cublasLtMatmulAlgoConfigGetAttribute( matmulAlgo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &numSplitsK, sizeof(numSplitsK), NULL); + cublasLtMatmulAlgoConfigGetAttribute( matmulAlgo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &reductionScheme, sizeof(reductionScheme), NULL); + cublasLtMatmulAlgoConfigGetAttribute( matmulAlgo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &swizzle, sizeof(swizzle), NULL); + cublasLtMatmulAlgoConfigGetAttribute( matmulAlgo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption), NULL); +#ifdef CUDA11_MODE + cublasLtMatmulAlgoConfigGetAttribute( matmulAlgo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stages, sizeof(stages), NULL); +#else + stages=0; +#endif + + printf("algo={ Id=%d, tileIdx=%d (%s) splitK=%d reduc=%d swizzle=%d custom=%d stages=%d} status %d " + "time %fms workspace=%d mathMode=%d waves=%f\n", + algoId, tile, matmulTileName[tile], + numSplitsK, reductionScheme, + swizzle, customOption, stages, + perf.status, + 
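printPerfStructure turns a winning cublasLtMatmulAlgo_t back into the integer knobs that the config file stores (algo id, tile, splitK count, reduction scheme, swizzle, custom option and, on CUDA 11, stages) by querying one attribute at a time. A condensed example of that unpacking for two of the fields:

#include <cublasLt.h>

// Sketch: read two configuration knobs back out of a chosen cublasLt algorithm.
void read_algo_config(const cublasLtMatmulAlgo_t *algo, int *algo_id, int *tile_id)
{
    // Each attribute is copied into caller-provided storage; the trailing size_t*
    // (bytes written) is optional and passed as NULL here, as in the surrounding code.
    cublasLtMatmulAlgoConfigGetAttribute(algo, CUBLASLT_ALGO_CONFIG_ID,
                                         algo_id, sizeof(*algo_id), NULL);
    cublasLtMatmulAlgoConfigGetAttribute(algo, CUBLASLT_ALGO_CONFIG_TILE_ID,
                                         tile_id, sizeof(*tile_id), NULL);
}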
perf.time, + (int)perf.workspaceSize, + (int)perf.mathMode, + perf.wavesCount); + if (hasPrint == 0){ + fprintf(fout, "%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d %f\n", batch_size, seq_len, head_num, size_per_head, is_fp16 ? HALF_DATATYPE:FLOAT_DATATYPE, + 1, m, n, k, algoId, customOption, tile, numSplitsK, swizzle, reductionScheme, (int)perf.workspaceSize, stages, perf.time); + return 1; + } + else{ + return hasPrint; + } + +} + +static inline bool +time_compare(const customMatmulPerf_t &perf_a, const customMatmulPerf_t &perf_b) { + return ((perf_a.status == CUBLAS_STATUS_SUCCESS) && (perf_a.time < perf_b.time)); +} + + +static cublasStatus_t +customMatmulRun(cublasLtHandle_t ltHandle, // to get the capabilities (required a GPU) + cublasLtMatmulDesc_t operationDesc, + const void *alpha, /* host or device pointer */ + const void *A, + cublasLtMatrixLayout_t Adesc, + const void *B, + cublasLtMatrixLayout_t Bdesc, + const void *beta, /* host or device pointer */ + const void *C, + cublasLtMatrixLayout_t Cdesc, + void *D, + cublasLtMatrixLayout_t Ddesc, + const cublasLtMatmulAlgo_t &algo, + int kernelRepeats, + void *workSpace, + size_t workSpaceSizeInBytes, + customMatmulPerf_t &perfResults, + cudaStream_t stream, + cudaEvent_t &startEvent, + cudaEvent_t &stopEvent) { - return (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001; + cublasLtMatmulHeuristicResult_t heurResult; + /* Looping over the Algo */ + int repeats = kernelRepeats; + cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck( ltHandle, + operationDesc, + Adesc, + Bdesc, + Cdesc, + Ddesc, + &algo, + &heurResult); + + if (algoStatus == CUBLAS_STATUS_SUCCESS) { + if (heurResult.workspaceSize <= workSpaceSizeInBytes) { + cudaError_t err, err1, err2, err3; + err = cudaEventRecord(startEvent, stream); + for (int loop = 0; loop < repeats; loop++) { + cublasStatus_t oneRunStatus = cublasLtMatmul( ltHandle, + operationDesc, + alpha, + A, Adesc, + B, Bdesc, + beta, + C, Cdesc, + D, Ddesc, + &algo, + workSpace, + workSpaceSizeInBytes, + stream); + if (oneRunStatus != CUBLAS_STATUS_SUCCESS) { + algoStatus = oneRunStatus; + break; + } + } + err1 = cudaEventRecord(stopEvent, stream); + err2 = cudaEventSynchronize(stopEvent); + float time; + err3 = cudaEventElapsedTime(&time, startEvent, stopEvent); + if ((err != cudaSuccess) || (err1 != cudaSuccess) || (err2 != cudaSuccess) || (err3 != cudaSuccess)) { + algoStatus = CUBLAS_STATUS_INTERNAL_ERROR; + } + // For the moment only add successful findings + if (algoStatus == CUBLAS_STATUS_SUCCESS) { + perfResults.algo = algo; + perfResults.time = time/repeats; + perfResults.workspaceSize = heurResult.workspaceSize; + perfResults.wavesCount = heurResult.wavesCount; + } + } + else { + //printf("not enough workspace! %ld\n", heurResult.workspaceSize); + algoStatus = CUBLAS_STATUS_NOT_SUPPORTED; //Not enough workspace + } + } + + return algoStatus; } +template +int LtHgemmCustomFind(cublasLtHandle_t ltHandle, + int batch_size, + int seq_len, + int head_num, + int size_per_head, + int m, + int n, + int k, + const T *alpha, /* host pointer */ + const T *A, + const T *B, + const T *beta, /* host pointer */ + T *C, + void *workSpace, + size_t workSpaceSize, + FILE* fout, + customMatmulPerf_t perfResults[], + int AlgoCombinations) +{ + + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + cudaEvent_t startEvent; + cudaEvent_t stopEvent; + int is_fp16 = (sizeof(T) == sizeof(half) ? 
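customMatmulRun scores each candidate algorithm by bracketing kernelRepeats back-to-back cublasLtMatmul launches between two CUDA events on the measurement stream and dividing the elapsed time by the repeat count. The same timing skeleton in isolation, with the matmul abstracted into a callable:

#include <cuda_runtime.h>

// Sketch: average wall time of `repeats` launches of some stream work, in milliseconds.
template <typename LaunchFn>
float time_on_stream(LaunchFn launch, int repeats, cudaStream_t stream)
{
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start, stream);
    for (int i = 0; i < repeats; ++i)
        launch(stream);                      // e.g. one cublasLtMatmul call
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);              // wait until the last launch has finished

    float ms = 0.f;
    cudaEventElapsedTime(&ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return ms / repeats;
}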
1 : 0); + + cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + + cudaStream_t stream = 0; + // SplitK value that we are going to try when SplitK is supported for a given algo + const int splitKSequenceA[] = {2, 3, 4, 5, 6, 8, 12, 16, 32}; + // Let try a fixed number of combinations + int AlgoCount = 0; + int AlgoCountRestrict = 0; // workspace == 0 + int maxNumTraversal = 50; // max number of traversal + cublasLtMatmulAlgo_t algos[AlgoCombinations]; // 0 <= workspace <= 32MB + cublasLtMatmulAlgo_t algosRestrict[AlgoCombinations]; // workspace == 0 + int kernelRepeats = 100; //number of time the CUDA kernels will be run back to back + int nbAlgoIds = 0; // Number of algorithms actually returned by cublasLtMatmulAlgoGetIds function. + #define ALGO_IDS 100 // Number of algorithms requested. + int algoIdA[ALGO_IDS]; // Array containing the algorithm IDs returned by cublasLtMatmulAlgoGetIds function. + cudaDataType_t Atype, Btype, Ctype, scaleType; +#ifdef CUDA11_MODE + cublasComputeType_t computeType; +#else + cudaDataType_t computeType; +#endif + + if(sizeof(T) == sizeof(float)){ + scaleType = CUDA_R_32F, Atype = CUDA_R_32F, Btype = CUDA_R_32F, Ctype = CUDA_R_32F; +#ifdef CUDA11_MODE + computeType = CUBLAS_COMPUTE_32F; +#else + computeType = CUDA_R_32F; +#endif + }else{ + scaleType = CUDA_R_16F, Atype = CUDA_R_16F, Btype = CUDA_R_16F, Ctype = CUDA_R_16F; +#ifdef CUDA11_MODE + computeType = CUBLAS_COMPUTE_16F; +#else + computeType = CUDA_R_16F; +#endif + } + + // Create operation descriptor; see cublasLtMatmulDescAttributes_t for details about defaults; here we just need to + // set the transforms for A and B +#ifdef CUDA11_MODE + status = cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); // creates a matrix multiply descriptor +#else + status = cublasLtMatmulDescCreate(&operationDesc, computeType); +#endif + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // Create matrix descriptors. We are good with the details here so no need to set any extra attributes + status = cublasLtMatrixLayoutCreate( + &Adesc, Atype, m, k, m); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutCreate( + &Bdesc, Btype, k, n, k); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + status = cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, m); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // Create CUDA event to time the execution time of each algo + if (cudaEventCreate(&startEvent, cudaEventBlockingSync) != cudaSuccess) { + goto CLEANUP; + } + if (cudaEventCreate(&stopEvent, cudaEventBlockingSync) != cudaSuccess) { + goto CLEANUP; + } + + // Request the 100 first AlgoId available + status = cublasLtMatmulAlgoGetIds( ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, ALGO_IDS, algoIdA, &nbAlgoIds); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // Loop over the Algo IDs + for (int idx = 0; (idx < nbAlgoIds) && (AlgoCount < AlgoCombinations); idx++) { + cublasLtMatmulAlgo_t algo; + size_t sizeWritten = 0; + /* Initialize algo structure with given Algp ID */ + status = cublasLtMatmulAlgoInit(ltHandle, computeType, scaleType, Atype, Btype, Ctype, Ctype, algoIdA[idx], &algo); + if (status != CUBLAS_STATUS_SUCCESS) { + continue; + } + // Query the tiles enums supported by that algo + cublasLtMatmulAlgoCapGetAttribute( &algo, CUBLASLT_ALGO_CAP_TILE_IDS, NULL, 0, &sizeWritten); + int nbTiles = int(sizeWritten/sizeof(int)); + int *tileA = new int[ nbTiles == 0 ? 
1:nbTiles]; + if(nbTiles == 0){ + tileA[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; + nbTiles = 1; + } +#ifdef CUDA11_MODE + cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, NULL, 0, &sizeWritten); + int nbStages = int(sizeWritten/sizeof(int)); + std::vector stagesA(nbStages == 0 ? 1 : nbStages); + if (nbStages == 0) { + stagesA[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; + nbStages = 1; + } else { + cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_STAGES_IDS, stagesA.data(), sizeof(int)*nbStages, &sizeWritten); + } +#endif + int splitkSupport, redMask, swizzlingMax, customOptionMax; + // Retrieve Algo Capabilities attributes to be able to setup loop over the different combinations + cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_TILE_IDS, tileA, sizeof(int)*nbTiles, &sizeWritten); + cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, &splitkSupport, sizeof(splitkSupport), &sizeWritten); + cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, &redMask, sizeof(redMask), &sizeWritten); + cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, &swizzlingMax, sizeof(swizzlingMax), &sizeWritten); + cublasLtMatmulAlgoCapGetAttribute(&algo, CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, &customOptionMax, sizeof(customOptionMax), &sizeWritten); + + /* Loop over the different tiles */ + for (int tileIdx = 0; tileIdx < nbTiles; tileIdx++) { +#ifdef CUDA11_MODE + /* Loop over different stages count */ + for (int stagesIdx = 0; stagesIdx < nbStages; stagesIdx++) { + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &stagesA[stagesIdx], sizeof(stagesA[stagesIdx])); +#endif + /* Loop over the different custom option if any */ + for (int customOption = 0; customOption <= customOptionMax; customOption++) { + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption)); + /* Loop over the CTAs swizzling support */ + for (int k = 0; k <= swizzlingMax; k++) { + int splitK_trial = 0; + if (splitkSupport) { + splitK_trial += sizeof(splitKSequenceA) / sizeof(splitKSequenceA[0]); + } + // Loop over the splitK value over a fixed sequence splitKSequenceA in addtion to the case where splitK is not enabled + for (int l = 0; (l < (1 + splitK_trial)) && (AlgoCount < AlgoCombinations); l++) { + /* Setup attribute of the algo to run */ + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &tileA[tileIdx], sizeof(tileA[tileIdx])); + int splitK_val = 0; + int redScheme = CUBLASLT_REDUCTION_SCHEME_NONE; + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitK_val, sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(int)); + + if (l > 0) { // Split-K case + splitK_val = splitKSequenceA[l - 1]; + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &splitKSequenceA[l - 1], sizeof(splitKSequenceA[l - 1])); + /* Going over all the reduction scheme */ + for (redScheme = 1 ; redScheme < (int)CUBLASLT_REDUCTION_SCHEME_MASK && (AlgoCount < AlgoCombinations); redScheme = redScheme << 1) { + if (redScheme & redMask) { + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &redScheme, sizeof(redScheme)); + + cublasLtMatmulHeuristicResult_t heurResult; + cublasStatus_t algoStatus = 
cublasLtMatmulAlgoCheck( ltHandle, + operationDesc, + Adesc, + Bdesc, + Cdesc, + Cdesc, + &algo, + &heurResult); + if (heurResult.workspaceSize > workSpaceSize) { + // printf("not enough workspace! %ld\n", heurResult.workspaceSize); + algoStatus = CUBLAS_STATUS_NOT_SUPPORTED; //Not enough workspace + }else if(heurResult.workspaceSize == 0){ + if(algoStatus == CUBLAS_STATUS_SUCCESS){ + algosRestrict[AlgoCountRestrict++] = algo; + } + } + if(algoStatus == CUBLAS_STATUS_SUCCESS){ + algos[AlgoCount++] = algo; + } + } // end if + } // end for + } else { // Non-splitK case + /* if user preference is ok with workspace */ + if (AlgoCount < AlgoCombinations) { + cublasLtMatmulHeuristicResult_t heurResult; + cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck( ltHandle, + operationDesc, + Adesc, + Bdesc, + Cdesc, + Cdesc, + &algo, + &heurResult); + if (heurResult.workspaceSize > workSpaceSize) { + // printf("not enough workspace! %ld\n", heurResult.workspaceSize); + algoStatus = CUBLAS_STATUS_NOT_SUPPORTED; //Not enough workspace + }else if(heurResult.workspaceSize == 0){ + if(algoStatus == CUBLAS_STATUS_SUCCESS){ + algosRestrict[AlgoCountRestrict++] = algo; + } + } + if(algoStatus == CUBLAS_STATUS_SUCCESS){ + algos[AlgoCount++] = algo; + } + } + } + } // end l + } // end k + } //end customOption +#ifdef CUDA11_MODE + } // end stagesIdx +#endif + } // end tileIdx + delete [] tileA; + } // end idx + + printf("AlgoCount: %d\n", AlgoCount); + if(AlgoCount < maxNumTraversal){ + // 0 <= workspacesize <= 32MB + for(int i=0;i void generate_encoder_gemm_config(int batch_size, int seq_len, int head_num, int size_per_head, - void *buffer, + void *buffer_in, bool isAppend) { + void *cublas_workspace; + void *buffer; + int workSpaceSize; + if (std::is_same::value) + { + //cublas_workspace_ should be the start pointer of cudaMalloc() + //to ensure 16B alignemnet + cublas_workspace = buffer_in; + buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); + workSpaceSize = CUBLAS_WORKSPACE_SIZE; + } + else + { + cublas_workspace = nullptr; + buffer = buffer_in; + workSpaceSize = 0; + } struct cudaDeviceProp prop; check_cuda_error(cudaGetDeviceProperties(&prop, 0)); @@ -42,6 +519,7 @@ void generate_encoder_gemm_config(int batch_size, //check config FILE *fd; + int line_count = 0; if (!isAppend) { fd = fopen(GEMM_CONFIG, "w+"); @@ -55,15 +533,18 @@ void generate_encoder_gemm_config(int batch_size, { config.push_back(std::string(line)); } - if (config.size() >= MAX_CONFIG_NUM*GEMM_NUM) + line_count = config.size(); + if (config.size() >= (MAX_CONFIG_NUM*GEMM_NUM + 1)) // 6 cublas/cublasLt, first row is not included { - int startIdx = config.size() - (MAX_CONFIG_NUM - 1)*GEMM_NUM; + int startIdx = config.size() - ((MAX_CONFIG_NUM - 1)*GEMM_NUM); fclose(fd); fd = fopen(GEMM_CONFIG, "w+"); - for (int i = startIdx ; i < config.size() ; i++) + fprintf(fd, "%s", config[0].c_str()); + for (uint i = startIdx ; i < config.size() ; i++) { fprintf(fd, "%s", config[i].c_str()); } + line_count = config.size() - (GEMM_NUM + 3); } } @@ -73,7 +554,7 @@ void generate_encoder_gemm_config(int batch_size, int K[gemm_num]; int batchCount[gemm_num] = {1,1,1,1,1,1}; char mess[gemm_num][256]; - + //gemm1 M[0] = batch_size * seq_len; K[0] = head_num * size_per_head; @@ -112,6 +593,8 @@ void generate_encoder_gemm_config(int batch_size, cublasHandle_t cublas_handle; check_cuda_error(cublasCreate(&cublas_handle)); + cublasLtHandle_t ltHandle; + check_cuda_error(cublasLtCreate(<Handle)); cudaDataType_t AType; cudaDataType_t BType; @@ 
-142,6 +625,11 @@ void generate_encoder_gemm_config(int batch_size, T beta = (T)0.0f; printf("***Encoder Gemm Testing Begin***\n"); + printf("***Cublas Gemm Testing Begin***\n"); + if (line_count == 0){ + fprintf(fd, "batch_size, seq_len, head_num, size_per_head dataType ### batchCount, n, m, k, algoId, "\ + "customOption, tile, numSplitsK, swizzle, reductionScheme, workspaceSize, stages, exec_time\n"); + } for(int i = 0; i < gemm_num; ++i) { // if(i != 0 && i != 5) continue; @@ -156,14 +644,14 @@ void generate_encoder_gemm_config(int batch_size, // array of pointer for batchedGemm T* harray[9]; harray[0] = (T*)buffer; - harray[1] = (T*)(buffer + sizeof(T) * m * k); - harray[2] = (T*)(buffer + 2 * sizeof(T) * m * k); - harray[3] = (T*)(buffer + 3 * sizeof(T) * m * k); - harray[4] = (T*)(buffer + 3 * sizeof(T) * m * k + sizeof(T) * k * n); - harray[5] = (T*)(buffer + 3 * sizeof(T) * m * k + 2 * sizeof(T) * k * n); - harray[6] = (T*)(buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n); - harray[7] = (T*)(buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + sizeof(T) * m * n); - harray[8] = (T*)(buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + 2 * sizeof(T) * m * n); + harray[1] = (T*)((char*)buffer + sizeof(T) * m * k); + harray[2] = (T*)((char*)buffer + 2 * sizeof(T) * m * k); + harray[3] = (T*)((char*)buffer + 3 * sizeof(T) * m * k); + harray[4] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + sizeof(T) * k * n); + harray[5] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 2 * sizeof(T) * k * n); + harray[6] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n); + harray[7] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + sizeof(T) * m * n); + harray[8] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + 2 * sizeof(T) * m * n); T** darray = 0; check_cuda_error(cudaMalloc((void**)&darray, sizeof(T*) * 9)); @@ -254,10 +742,32 @@ void generate_encoder_gemm_config(int batch_size, int is_fp16 = 0; if (sizeof(T) == sizeof(half)) is_fp16 = 1; - fprintf(fd, "%d %d %d %d ### %d %d %d %d %d %d %f\n", batch_size, seq_len, head_num, size_per_head, batchCount[i], m, n, k, is_fp16, fast_algo, exec_time); + + //for fp16, we compare cublasLt + if(i < 3 && is_fp16 == 1){ + printf("***cublasLt Gemm Testing Beign***\n"); + // Let try a fixed number of combinations + int ALGO_COMBINATIONS = 5000; + customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; + + LtHgemmCustomFind(ltHandle, batch_size, seq_len, head_num, size_per_head, n, m, k, &alpha, d_B, d_A, + &beta, d_C, cublas_workspace, workSpaceSize, fd, perfResults, ALGO_COMBINATIONS); + if(perfResults[0].time < exec_time){ + printPerfStructure(batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, is_fp16, 0); + }else{ + fprintf(fd, "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 %f\n", batch_size, seq_len, head_num, size_per_head, is_fp16 ? HALF_DATATYPE:FLOAT_DATATYPE, + batchCount[i], n, m, k, fast_algo, exec_time); + } + printf("***cublasLt Gemm Testing End***\n"); + } + else + { + fprintf(fd, "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 %f\n", batch_size, seq_len, head_num, size_per_head, is_fp16 ? 
HALF_DATATYPE:FLOAT_DATATYPE, + batchCount[i], n, m, k, fast_algo, exec_time); + } cudaFree(darray); } - + printf("***cublas Gemm Testing End***\n\n"); fclose(fd); printf("***Encoder Gemm Testing End***\n"); return; @@ -275,7 +785,6 @@ size_t calGemmTestBufSizeInByte(int batch_size, int seq_len, int head_num, int s int m = batch_size*seq_len; int n = head_num*size_per_head; int k = n; - int batchCount; size_t size1 = 3*(m*k*sizeof(int8_t) + k*n*sizeof(int8_t) + m*n*sizeof(int)); size_t size2 = batch_size*head_num*(seq_len*size_per_head*sizeof(int8_t) + size_per_head*seq_len*sizeof(int8_t) + seq_len*seq_len*sizeof(int)); @@ -296,8 +805,10 @@ size_t calGemmTestBufSizeInByte(int batch_size, int seq_len, int head_num, int s size_t size3 = (m*k + k*4*n + m*4*n)*wordSize; buf_size_in_byte = size1 > size2 ? size1 : size2; buf_size_in_byte = buf_size_in_byte > size3 ? buf_size_in_byte : size3; + buf_size_in_byte += ((is_fp16 == 1) ? CUBLAS_WORKSPACE_SIZE : 0); } return buf_size_in_byte; } } + diff --git a/fastertransformer/gemm_test/encoder_igemm_func.cc b/fastertransformer/gemm_test/encoder_igemm_func.cc index b08beb3bc..96984a9e6 100644 --- a/fastertransformer/gemm_test/encoder_igemm_func.cc +++ b/fastertransformer/gemm_test/encoder_igemm_func.cc @@ -15,7 +15,7 @@ */ #include "encoder_igemm_func.h" -#include "fastertransformer/common.h" +#include "fastertransformer/utils/common.h" #include namespace fastertransformer{ @@ -90,10 +90,10 @@ int printPerfStructure(int m, int n, int k, const customMatmulPerf_t &perf, FILE (int)perf.workspaceSize, (int)perf.mathMode, perf.wavesCount); - + //chose the fastest algo that does not need workspace if ((int)perf.workspaceSize == 0 && hasPrint == 0){ - fprintf(fout, "%d %d %d %d ### 1 %d %d %d %d %d %d %d %d %d %d %d\n", batch_size_, seq_len_, head_num_, size_per_head_, m, n, k, algoId, customOption, tile, numSplitsK, swizzle, reductionScheme, (int)perf.workspaceSize, stages); + fprintf(fout, "%d %d %d %d %d ### 1 %d %d %d %d %d %d %d %d %d %d %d %f\n", batch_size_, seq_len_, head_num_, size_per_head_, INT8_DATATYPE, m, n, k, algoId, customOption, tile, numSplitsK, swizzle, reductionScheme, (int)perf.workspaceSize, stages, perf.time); return 1; } else{ @@ -128,9 +128,9 @@ int printBatchPerfStructure(int batchCount, int m, int n, int k, const customMat (int)perf.mathMode, perf.wavesCount); - //chose the fastest algo that does not need workspace + //chose the fastest algo that does not need workspace if ((int)perf.workspaceSize == 0 && hasPrint == 0){ - fprintf(fout, "%d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d\n",batch_size_, seq_len_, head_num_, size_per_head_, batchCount, m, n, k, algoId, customOption, tile, numSplitsK, swizzle, reductionScheme, (int)perf.workspaceSize, stages); + fprintf(fout, "%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d %f\n",batch_size_, seq_len_, head_num_, size_per_head_, INT8_DATATYPE, batchCount, m, n, k, algoId, customOption, tile, numSplitsK, swizzle, reductionScheme, (int)perf.workspaceSize, stages, perf.time); return 1; } else{ @@ -209,7 +209,7 @@ customMatmulRun(cublasLtHandle_t ltHandle, // to get the capabilities (required } } else { - printf("not enough workspace! %ld\n", heurResult.workspaceSize); + //printf("not enough workspace! 
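Both gemm tests now emit one line per problem in the extended format announced by the header row (batch/seq/head/size and data type, then "###", then batchCount, the gemm shape, the cublas or cublasLt knobs, workspace size, stages, and the measured time), and readAlgoFromConfig on the decoding side loads those lines into cublasAlgoMap_. That reader is not part of this hunk; a hypothetical parser matching the column order, keyed only by the gemm shape for illustration, could look like:

#include <cstdio>
#include <map>
#include <string>

// Hypothetical reader for the per-line format written above; the real
// readAlgoFromConfig() may key and filter the map differently.
struct GemmAlgoEntry { int algoId, customOption, tile, splitK, swizzle, reduction, workspace, stages; float time; };

void read_gemm_config(const char *path, std::map<std::string, GemmAlgoEntry> &algo_map)
{
    FILE *fd = fopen(path, "r");
    if (fd == NULL) return;
    char line[512];
    fgets(line, sizeof(line), fd);                      // skip the header row
    int b, s, h, d, dtype, batchCount, m, n, k;
    GemmAlgoEntry e;
    while (fgets(line, sizeof(line), fd))
    {
        if (sscanf(line, "%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d %f",
                   &b, &s, &h, &d, &dtype, &batchCount, &m, &n, &k,
                   &e.algoId, &e.customOption, &e.tile, &e.splitK, &e.swizzle,
                   &e.reduction, &e.workspace, &e.stages, &e.time) == 18)
        {
            char key[64];
            snprintf(key, sizeof(key), "%d_%d_%d_%d", batchCount, m, n, k);  // illustrative key
            algo_map[key] = e;
        }
    }
    fclose(fd);
}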
%ld\n", heurResult.workspaceSize); algoStatus = CUBLAS_STATUS_NOT_SUPPORTED; //Not enough workspace } } @@ -732,6 +732,7 @@ int batch_igemm_config(int batchCount, int m, int n, int k, FILE* fout, void* bu fout); //free memory cublasLtDestroy(ltHandle); + return 0; } int igemm_config(int m, int n, int k, FILE* fout, void* buffer){ @@ -760,6 +761,7 @@ int igemm_config(int m, int n, int k, FILE* fout, void* buffer){ fout); cublasLtDestroy(ltHandle); + return 0; } int generate_encoder_igemm_config(int batch_size, int seq_len, int head_num, int size_per_head, void *buffer, bool isAppend) @@ -780,6 +782,7 @@ int generate_encoder_igemm_config(int batch_size, int seq_len, int head_num, int if (!isAppend) { fout = fopen(IGEMM_CONFIG, "w+"); + fprintf(fout, "batch_size seq_len head_num size_per_head dataType ### batchCount m n k algoId customOption tile splitK_val swizzle reductionScheme workspaceSize stages exec_time\n"); } else { @@ -817,14 +820,28 @@ int generate_encoder_igemm_config(int batch_size, int seq_len, int head_num, int m = batch_size*seq_len; k = head_num*size_per_head; n = k; - batch_igemm_config(batchCount,m,n,k,fout,buffer); + if (n%32 != 0 || k%32 != 0) + { + printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); + } + else + { + batch_igemm_config(batchCount,m,n,k,fout,buffer); + } printf("\n-----------------------------\n"); m = seq_len; n = seq_len; k = size_per_head; batchCount = batch_size*head_num; - batch_igemm_config(batchCount,m,n,k,fout,buffer); + if (n%32 != 0 || k%32 != 0) + { + printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); + } + else + { + batch_igemm_config(batchCount,m,n,k,fout,buffer); + } printf("\n-----------------------------\n"); @@ -832,30 +849,58 @@ int generate_encoder_igemm_config(int batch_size, int seq_len, int head_num, int n = size_per_head; k = seq_len; batchCount = batch_size*head_num; - batch_igemm_config(batchCount,m,n,k,fout,buffer); + if (n%32 != 0 || k%32 != 0) + { + printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); + } + else + { + batch_igemm_config(batchCount,m,n,k,fout,buffer); + } printf("\n-----------------------------\n"); m = batch_size*seq_len; n = head_num*size_per_head; k = head_num*size_per_head; - - igemm_config(m,n,k,fout,buffer); + if (n%32 != 0 || k%32 != 0) + { + printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); + } + else + { + igemm_config(m,n,k,fout,buffer); + } printf("\n-----------------------------\n"); n = 4*n; - igemm_config(m,n,k,fout,buffer); + if (n%32 != 0 || k%32 != 0) + { + printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); + } + else + { + igemm_config(m,n,k,fout,buffer); + } printf("\n-----------------------------\n"); n = k; k = 4*n; - igemm_config(m,n,k,fout,buffer); + if (n%32 != 0 || k%32 != 0) + { + printf("[WARNING] For INT8 gemm test, n, k should be multiples of 32 (n = %d, k = %d)\n", n, k); + } + else + { + igemm_config(m,n,k,fout,buffer); + } fclose(fout); printf("\n-----------------------------\n"); printf("***Encoder IGemm Testing End***\n"); + return 0; } } diff --git a/fastertransformer/gemm_test/encoder_igemm_func.h b/fastertransformer/gemm_test/encoder_igemm_func.h index f284156eb..4d4ad1af4 100644 --- a/fastertransformer/gemm_test/encoder_igemm_func.h +++ b/fastertransformer/gemm_test/encoder_igemm_func.h @@ -31,6 +31,7 @@ typedef struct { int algoId, customOption, 
tile, splitK_val, swizzle, reductionScheme, workspaceSize; //only used in cublasLt >= 11.0 int stages; + float exec_time; } cublasLtMatmulAlgo_info; /* Structure to store information about different run trials */ typedef struct { @@ -74,7 +75,6 @@ const char * const matmulTileName[] = { "512x64" , }; -double diffTime(timeval start, timeval end); int generate_encoder_igemm_config(int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend = true); diff --git a/fastertransformer/gpt.h b/fastertransformer/gpt.h new file mode 100644 index 000000000..1b89790d1 --- /dev/null +++ b/fastertransformer/gpt.h @@ -0,0 +1,835 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Decoder transformer + **/ + +#pragma once + +#include "fastertransformer/utils/common.h" +#include "fastertransformer/utils/functions.h" +#include "fastertransformer/utils/allocator.h" +#include "fastertransformer/utils/arguments.h" +#include "fastertransformer/cuda/cuda_kernels.h" +#include "fastertransformer/open_decoder.h" +#include +#include +#include "fastertransformer/utils/nvtx_utils.h" + +namespace fastertransformer +{ + +template +class DecodingGpt +{ +private: + typedef DecoderTransformerTraits Traits_; + typedef typename Traits_::DataType DataType_; + const IAllocator &allocator_; + struct GptArguments args_; + TensorParallelParam t_parallel_param_; + LayerParallelParam l_parallel_param_; + + const cudaDataType_t computeType_ = Traits_::computeType; + const cudaDataType_t AType_ = Traits_::AType; + const cudaDataType_t BType_ = Traits_::BType; + const cudaDataType_t CType_ = Traits_::CType; + std::map cublasAlgoMap_; + + DataType_ *embedding_kernel_padded_; + + OpenDecoder *decoder_; + DataType_ **K_cache_; + DataType_ **V_cache_; + DataType_ *from_tensor_[2]; + DataType_ *decoder_buf_; + DataType_ *decoder_normed_result_buf_; + DataType_ *logits_buf_; + void *buf_; + + void *topk_workspace_ = nullptr; + size_t topk_workspace_size_ = 0; + void *topp_workspace_ = nullptr; + size_t topp_workspace_size_ = 0; + void *topk_topp_workspace_ = nullptr; + size_t topk_topp_workspace_size_ = 0; + void *cublas_workspace_ = nullptr; + int *topp_id_vals_buf_; + int *topp_offset_buf_; + curandState_t *curandstate_buf_; + int *begin_topp_offset_buf_; + + size_t nccl_buf_size_; + DataType_ *nccl_logits_buf_; + + bool *finished_buf_; + bool *h_finished_buf_; + +public: + DecodingGpt(const IAllocator &allocator, const int batch_size, + const int seq_len, + const int head_num, const int size_per_head, + const int vocab_size, const int decoder_layers, + const int start_id, const int end_id, + const int candidate_num = 1, + const float probability_threshold = 0.0, + const float temperature = 1.0, + const int tensor_para_size = 1, + const int layer_para_size = 1, + const bool is_fuse_QKV = true) : allocator_(allocator) + { +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + assert(temperature != 0.0); + assert(candidate_num > 0 
|| probability_threshold > 0.0); + assert(decoder_layers % layer_para_size == 0); + + args_.batch_size_ = batch_size; + args_.seq_len_ = seq_len; + args_.head_num_ = head_num; + args_.size_per_head_ = size_per_head; + args_.hidden_units_ = head_num * size_per_head; + args_.decoder_layers_ = decoder_layers; + args_.vocab_size_ = vocab_size; + args_.start_id_ = start_id; + args_.end_id_ = end_id; + args_.candidate_num_ = candidate_num; + args_.probability_threshold_ = probability_threshold; + args_.temperature_ = temperature; + + K_cache_ = new DataType_ *[1]; + V_cache_ = new DataType_ *[1]; + + decoder_ = new OpenDecoder(args_.head_num_, size_per_head, 0 /* memory_hidden_units */, is_fuse_QKV); + decoder_->set_max_batch_size(args_.batch_size_); + + args_.vocab_size_padded_ = div_up(args_.vocab_size_, 64) * 64; + + size_t from_tensor_size = args_.batch_size_ * args_.hidden_units_; // type T + size_t decoder_workspace_size = (size_t)decoder_->getWorkspaceSize(); // type T + size_t decoder_normed_result_buffer_size = args_.batch_size_ * args_.hidden_units_; // type T + // cache costs lots of memory, so we only store part of them when we use multi-gpu for inference + size_t cache_size = args_.batch_size_ * args_.seq_len_ * args_.hidden_units_ / tensor_para_size; // type T + size_t logits_buf_size = args_.batch_size_ * args_.vocab_size_padded_; // type T + + size_t topp_id_vals_buf_size = args_.batch_size_ * args_.vocab_size_padded_; // type int + size_t topp_offset_buf_size = args_.batch_size_ + 1; + size_t begin_topp_offset_buf_size = topp_offset_buf_size; + size_t curandState_size = args_.batch_size_; + size_t finished_buf_size = args_.batch_size_; + + const int MEM_C = 128; + size_t embedding_kernel_transposed_padded_size = args_.hidden_units_ * args_.vocab_size_padded_; + embedding_kernel_transposed_padded_size = div_up(embedding_kernel_transposed_padded_size, MEM_C) * MEM_C; + + // prevent memory misalinged address + logits_buf_size = (size_t)(ceil(logits_buf_size / 4.)) * 4; + + topp_id_vals_buf_size = (size_t)(ceil(topp_id_vals_buf_size / 4.)) * 4; + topp_offset_buf_size = (size_t)(ceil(topp_offset_buf_size / 4.)) * 4; + begin_topp_offset_buf_size = topp_offset_buf_size; + curandState_size = (size_t)(ceil(curandState_size / 32.)) * 32; + finished_buf_size = (size_t)(ceil(finished_buf_size / 32.)) * 32; + + topP_sampling_kernel_kernelLauncher_v2(topp_workspace_, + topp_workspace_size_, + logits_buf_, + topp_id_vals_buf_, + topp_offset_buf_, + begin_topp_offset_buf_, + nullptr, + curandstate_buf_, + args_, + nullptr, + nullptr, + args_.vocab_size_padded_, + 0, + args_.batch_size_); + + topK_sampling_kernel_kernelLauncher_v2(topk_workspace_, + topk_workspace_size_, + logits_buf_, + nullptr, + nullptr, + nullptr, + curandstate_buf_, + args_, + 0, + args_.batch_size_); + + topK_topP_sampling_kernel_kernelLauncher_v2(topk_topp_workspace_, + topk_topp_workspace_size_, + nullptr, + logits_buf_, + nullptr, + curandstate_buf_, + args_, + 0, + args_.batch_size_); + + size_t datatype_buf_size = from_tensor_size * 2 + decoder_workspace_size + + cache_size * 2 * (args_.decoder_layers_ / layer_para_size) + decoder_normed_result_buffer_size; + + nccl_buf_size_ = args_.batch_size_ * args_.vocab_size_padded_; + nccl_buf_size_ = (size_t)(ceil(nccl_buf_size_ / 4.)) * 4; + + buf_ = reinterpret_cast(allocator_.malloc( + ((sizeof(DataType_) == sizeof(half)) ? 
CUBLAS_WORKSPACE_SIZE : 0) + + sizeof(DataType_) * embedding_kernel_transposed_padded_size + + sizeof(DataType_) * (datatype_buf_size + logits_buf_size) + + sizeof(int) * (topp_id_vals_buf_size + topp_offset_buf_size + begin_topp_offset_buf_size) + + topp_workspace_size_ + topk_workspace_size_ + topk_topp_workspace_size_ + sizeof(DataType_) * nccl_buf_size_ + + finished_buf_size + curandState_size * sizeof(curandState_t))); + + if (sizeof(DataType_) == sizeof(half)) + { + cublas_workspace_ = buf_; + embedding_kernel_padded_ = (DataType_ *)((char*)cublas_workspace_ + CUBLAS_WORKSPACE_SIZE); + } + else + { + cublas_workspace_ = nullptr; + embedding_kernel_padded_ = (DataType_ *)buf_; + } + from_tensor_[0] = (DataType_ *)(embedding_kernel_padded_ + embedding_kernel_transposed_padded_size); + from_tensor_[1] = (DataType_ *)(from_tensor_[0] + from_tensor_size); + + K_cache_[0] = from_tensor_[1] + from_tensor_size + 0 * cache_size * args_.decoder_layers_ / layer_para_size; + V_cache_[0] = from_tensor_[1] + from_tensor_size + 1 * cache_size * args_.decoder_layers_ / layer_para_size; + + decoder_buf_ = V_cache_[0] + cache_size * args_.decoder_layers_ / layer_para_size; + decoder_normed_result_buf_ = (decoder_buf_ + decoder_workspace_size); + logits_buf_ = decoder_normed_result_buf_ + decoder_normed_result_buffer_size; + topp_id_vals_buf_ = (int *)((DataType_*)logits_buf_ + logits_buf_size); + begin_topp_offset_buf_ = (int *)(topp_id_vals_buf_ + topp_id_vals_buf_size); + topp_offset_buf_ = (int *)((int*)begin_topp_offset_buf_ + begin_topp_offset_buf_size); + topp_workspace_ = (void *)((int*)topp_offset_buf_ + topp_offset_buf_size); + topk_workspace_ = (void *)((char*)topp_workspace_ + topp_workspace_size_); + topk_topp_workspace_ = (void *)((char*)topk_workspace_ + topk_workspace_size_); + nccl_logits_buf_ = (DataType_ *)((char*)topk_topp_workspace_ + topk_topp_workspace_size_); + curandstate_buf_ = (curandState_t*)(nccl_logits_buf_ + nccl_buf_size_); + finished_buf_ = (bool*)(curandstate_buf_ + curandState_size); + h_finished_buf_ = new bool[args_.batch_size_]; + + cudaMemset(embedding_kernel_padded_, 0, embedding_kernel_transposed_padded_size * sizeof(DataType_)); + + int isConfigExist = access("decoding_gemm_config.in", 0); + if (isConfigExist == -1) + printf("[WARNING] decoding_gemm_config.in is not found\n"); + else + { + readAlgoFromConfig(cublasAlgoMap_, 1); + // check that the gemm_config setting is runnable + for (auto iter = cublasAlgoMap_.begin() ; iter != cublasAlgoMap_.end() ; iter++) + { + int algoId = iter->second.algoId; + int stages = iter->second.stages; + //only check for cublas + if (stages != -1) + continue; + if (Traits_::OpType == OperationType::FP32) + { + if (algoId > CUBLAS_GEMM_ALGO23 || algoId < CUBLAS_GEMM_DEFAULT) + { + // the algorithm is not for FP32 + printf("[ERROR] cuBLAS Algorithm %d is not used in FP32. \n", algoId); + exit(-1); + } + } + else + { + if (algoId > CUBLAS_GEMM_ALGO15_TENSOR_OP || algoId < CUBLAS_GEMM_DEFAULT_TENSOR_OP) + { + // the algorithm is not for FP16 + printf("[ERROR] cuBLAS Algorithm %d is not used in FP16. 
\n", algoId); + exit(-1); + } + } + } + } + } + + void set_tensor_parallel_param(const TensorParallelParam param) + { + t_parallel_param_ = param; + decoder_->set_tensor_parallel_param(param); + } + + void set_layer_parallel_param(const LayerParallelParam param) + { + l_parallel_param_ = param; + decoder_->set_layer_parallel_param(param); + } + + void forward_context(const DecoderInitParam *decoder_param, + const DecodingInitParam decoding_params) + { + cudaMemsetAsync(decoding_params.output_ids, 0, sizeof(int) * args_.batch_size_ * args_.seq_len_, decoding_params.stream); +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + const int input_len = decoding_params.request_input_len; + const int max_input_len = decoding_params.max_input_len; + const int request_batch_size = decoding_params.request_batch_size; + + // d_start_ids: [batch * seqlen] + if(input_len == 1) + { + cudaMemcpyAsync(decoding_params.output_ids, decoding_params.d_start_ids, + sizeof(int) * request_batch_size, cudaMemcpyDeviceToDevice, decoding_params.stream); + return; + } + const int local_batch_size = ceil(request_batch_size * 1.0 / l_parallel_param_.world_size); + const int m = local_batch_size * input_len; + const int h_1 = args_.hidden_units_; + + DataType_* from_tensor[2]; + DataType_* decoder_output; + DataType_* decoder_workspace; + void *buf = reinterpret_cast(allocator_.malloc( + decoder_->getContextWorkspaceSize(input_len, local_batch_size) + + (m * h_1 + 2 * request_batch_size * input_len * h_1) * sizeof(DataType_) + )); + + from_tensor[0] = (DataType_*) buf; + from_tensor[1] = from_tensor[0] + request_batch_size * input_len * h_1; + decoder_output = from_tensor[1] + request_batch_size * input_len * h_1; + decoder_workspace = decoder_output + m * h_1; + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + if(l_parallel_param_.rank == 0) + { + PUSH_RANGE("Before Transformer/Embedding") + start_id_embedding_position_lookups_kernel_launcher(from_tensor[0], + decoding_params.output_ids, + decoding_params.embedding_table, + decoding_params.position_encoding_table, + decoding_params.d_start_ids, + 1, + input_len, + max_input_len, + request_batch_size, + args_.hidden_units_, + decoding_params.stream); + POP_RANGE +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + + int ite_num = (int)(ceil(request_batch_size * 1.0 / local_batch_size)); + for(int ite = 0; ite < ite_num; ite++) + { + int in_id, out_id; + for (int layer = 0; layer < args_.decoder_layers_; ++layer) + { + if(l_parallel_param_.is_valid(layer)) + { + in_id = layer & 0x1; + out_id = 1 - in_id; + + if(layer == l_parallel_param_.layers_per_group * l_parallel_param_.rank && layer != 0 && l_parallel_param_.world_size > 1) + { + const int size = m * t_parallel_param_.local_hidden_units_; + nccl_recv(from_tensor[in_id] + ite * m * h_1 + size * t_parallel_param_.rank, size, l_parallel_param_.rank - 1, + l_parallel_param_.nccl_comm, decoding_params.stream); + all2all_gather(from_tensor[in_id] + ite * m * h_1, from_tensor[in_id] + ite * m * h_1, size, + t_parallel_param_, decoding_params.stream); + } + + decoder_->initialize(decoder_param[layer], decoder_buf_, cublas_workspace_, false); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + int dummy_decoder_max_seq_len = args_.seq_len_; + // int dummy_decoder_max_seq_len = -1; + size_t cache_offset; + if(dummy_decoder_max_seq_len == -1) + { + cache_offset = (layer - 
l_parallel_param_.layers_per_group * l_parallel_param_.rank) * + args_.batch_size_ * args_.seq_len_ * t_parallel_param_.local_hidden_units_; + } + else + { + cache_offset = (layer - l_parallel_param_.layers_per_group * l_parallel_param_.rank) * + args_.batch_size_ * args_.seq_len_ * t_parallel_param_.local_hidden_units_ + + ite * local_batch_size * args_.seq_len_ * t_parallel_param_.local_hidden_units_; + } + decoder_->forward_context(decoder_workspace, + from_tensor[out_id] + ite * m * h_1, + K_cache_[0] + cache_offset, + V_cache_[0] + cache_offset, + from_tensor[in_id] + ite * m * h_1, + decoding_params.d_attn_mask + ite * local_batch_size * input_len * input_len, + local_batch_size, + input_len, + ite, + dummy_decoder_max_seq_len, + layer == args_.decoder_layers_ - 1); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + if(layer == l_parallel_param_.layers_per_group * (l_parallel_param_.rank + 1) - 1 && layer != args_.decoder_layers_ - 1 && l_parallel_param_.world_size > 1) + { + const int size = m * t_parallel_param_.local_hidden_units_; + nccl_send(from_tensor[out_id] + ite * m * h_1 + size * t_parallel_param_.rank, size, l_parallel_param_.rank + 1, + l_parallel_param_.nccl_comm, decoding_params.stream); + } + } + } // end of for loop of layer + } // end of for loop of ite + allocator_.free(buf); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + + void forward(const DecoderInitParam *decoder_param, + DecodingInitParam decoding_params) + { +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + const int input_len = decoding_params.request_input_len; + const int max_input_len = decoding_params.max_input_len; + const int request_batch_size = decoding_params.request_batch_size; + const int max_len = (decoding_params.request_output_len > 0 && input_len + decoding_params.request_output_len <= args_.seq_len_) ? + input_len + decoding_params.request_output_len : + args_.seq_len_; + + assert(request_batch_size <= args_.batch_size_); + assert(request_batch_size % l_parallel_param_.local_batch_size == 0); + const int m = request_batch_size; + const int k = args_.hidden_units_; + const DataType_* embedding_kernel_ptr = nullptr; + + if (args_.probability_threshold_ != 0.0) + { + topp_initialization_kernelLauncher_v2(finished_buf_, + nullptr, + nullptr, + topp_id_vals_buf_, + topp_offset_buf_, + begin_topp_offset_buf_, + args_.candidate_num_ > 0 ? 
args_.candidate_num_ : args_.vocab_size_padded_, + args_, + decoding_params.stream); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + ker_curand_setupLauncher(curandstate_buf_, + args_, + decoding_params.stream); + + if(std::is_same::value || (std::is_same::value && args_.vocab_size_padded_ == args_.vocab_size_)) + { + embedding_kernel_ptr = (const DataType_ *)decoding_params.embedding_kernel; + } + else + { + cudaMemcpyAsync(embedding_kernel_padded_, decoding_params.embedding_kernel, + sizeof(DataType_) * args_.vocab_size_ * args_.hidden_units_, cudaMemcpyDeviceToDevice, decoding_params.stream); + embedding_kernel_ptr = (const DataType_ *)embedding_kernel_padded_; + } +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + const int local_batch = l_parallel_param_.local_batch_size; + for (size_t step = input_len; step < max_len; ++step) + { + const int ite_num = request_batch_size / local_batch; + for(size_t ite = 0; ite < ite_num; ite++) + { + if(l_parallel_param_.rank == 0 && l_parallel_param_.world_size > 1) + { + if(step != (size_t)input_len) + { + PUSH_RANGE("token/recv") + nccl_recv(decoding_params.output_ids + (step - 1) * m + ite * local_batch, local_batch, + l_parallel_param_.world_size - 1, l_parallel_param_.nccl_comm, decoding_params.stream); + POP_RANGE + } + } + + if(l_parallel_param_.rank < l_parallel_param_.world_size - 1 && l_parallel_param_.world_size > 1) + { + if(step != (size_t)input_len) + { + nccl_broadcast(finished_buf_ + ite * local_batch, local_batch, l_parallel_param_.world_size - 1, l_parallel_param_, decoding_params.stream); + } + } + if(ite == 0) + { + cudaMemcpyAsync(h_finished_buf_, finished_buf_, sizeof(bool) * request_batch_size, cudaMemcpyDeviceToHost, decoding_params.stream); + uint sum = 0; + for (uint i = 0; i < request_batch_size; i++) + { + sum += (int)h_finished_buf_[i]; + } + if (sum == request_batch_size) + break; + } + + int *word_ids_buf_ = decoding_params.output_ids + (step - 1) * m + local_batch * ite; + if(l_parallel_param_.rank == 0) + { + PUSH_RANGE("Before Transformer/Embedding") + embedding_position_lookups_kernel_launcher(from_tensor_[0], + decoding_params.embedding_table, + decoding_params.position_encoding_table, + word_ids_buf_, + local_batch, + args_.hidden_units_, + step, + decoding_params.stream); + POP_RANGE +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + + //we use two-way buffer + int from_id, out_id; + for (int layer = 0; layer < args_.decoder_layers_; ++layer) + { + if(l_parallel_param_.is_valid(layer)) + { + /* + For the first layer (layer-0), from_id is 0. We also stored the embedding lookup + result in from_tensor_[0] + */ + from_id = layer & 0x1; + out_id = 1 - from_id; + + if(layer == l_parallel_param_.layers_per_group * l_parallel_param_.rank && layer != 0 && l_parallel_param_.world_size > 1) + { + const int size = local_batch * t_parallel_param_.local_hidden_units_; + nccl_recv(from_tensor_[from_id] + size * t_parallel_param_.rank, size, l_parallel_param_.rank - 1, + l_parallel_param_.nccl_comm, decoding_params.stream); + all2all_gather(from_tensor_[from_id], from_tensor_[from_id], size, + t_parallel_param_, decoding_params.stream); + } + + /* + We use one decoder_ object to process multiple decoder layers. + + At the beginning of each decoder layer, we initialize the decoder object + with corresponding weights and decoder_buf_. + + The decoder_buf_ is reused. 
+ */ + decoder_->initialize(decoder_param[layer], decoder_buf_, cublas_workspace_, false); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + int dummy_decoder_max_seq_len = args_.seq_len_; + // int dummy_decoder_max_seq_len = -1; + size_t cache_offset; + if(dummy_decoder_max_seq_len == -1) + { + cache_offset = (layer - l_parallel_param_.layers_per_group * l_parallel_param_.rank) * + args_.batch_size_ * args_.seq_len_ * t_parallel_param_.local_hidden_units_ + + ite * local_batch * t_parallel_param_.local_hidden_units_; + } + else + { + cache_offset = (layer - l_parallel_param_.layers_per_group * l_parallel_param_.rank) * + args_.batch_size_ * args_.seq_len_ * t_parallel_param_.local_hidden_units_ + + ite * local_batch * args_.seq_len_ * t_parallel_param_.local_hidden_units_; + } + decoder_->forward(from_tensor_[from_id], + nullptr, // memory_tensor should be nullptr + K_cache_[0] + cache_offset, + V_cache_[0] + cache_offset, + nullptr, nullptr, // key_mem_cache_ and value_mem_cache_ should be nullptr + nullptr, // memory_sequence_length should be nullptr + from_tensor_[out_id], step, dummy_decoder_max_seq_len, + false, + finished_buf_ + ite * local_batch); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + if(layer == l_parallel_param_.layers_per_group * (l_parallel_param_.rank + 1) - 1 && layer != args_.decoder_layers_ - 1 && l_parallel_param_.world_size > 1) + { + const size_t size = local_batch * t_parallel_param_.local_hidden_units_; + nccl_send(from_tensor_[out_id] + size * t_parallel_param_.rank, size, l_parallel_param_.rank + 1, + l_parallel_param_.nccl_comm, decoding_params.stream); + } + } + } + + if(l_parallel_param_.rank == l_parallel_param_.world_size - 1) + { + + layer_norm(from_tensor_[out_id], + decoding_params.layernorm.gamma, + decoding_params.layernorm.beta, + decoder_normed_result_buf_, + local_batch, + k, + decoding_params.stream); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + DataType_ alpha = DataType_(1.0f); + DataType_ beta = DataType_(0.0f); + assert(args_.vocab_size_padded_ % t_parallel_param_.world_size == 0); + int n = args_.vocab_size_padded_ / t_parallel_param_.world_size; + + if(t_parallel_param_.world_size == 1) + { + PUSH_RANGE("After Transformer/GEMM") + cublasMM_cublasLtMM_wrapper_decoder(decoding_params.cublaslt_handle, + decoding_params.cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + n, local_batch, k, + &alpha, + embedding_kernel_ptr, AType_, k, + decoder_normed_result_buf_, BType_, k, + &beta, + logits_buf_, CType_, n, + decoding_params.stream, cublasAlgoMap_, + cublas_workspace_); + POP_RANGE + } + else + { + PUSH_RANGE("After Transformer/GEMM") + cublasMM_cublasLtMM_wrapper_decoder(decoding_params.cublaslt_handle, + decoding_params.cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + n, local_batch, k, + &alpha, + embedding_kernel_ptr + t_parallel_param_.rank * n * k, + AType_, k, + decoder_normed_result_buf_, BType_, k, + &beta, + nccl_logits_buf_ + t_parallel_param_.rank * local_batch * n, + CType_, n, + decoding_params.stream, cublasAlgoMap_, + cublas_workspace_); + POP_RANGE + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + if(t_parallel_param_.world_size == 1) + { + apply_temperature_penalty_kernelLauncher(logits_buf_, + (DataType_) args_.temperature_, + local_batch, + args_.vocab_size_, + n, + decoding_params.stream); + } + else + { + if(t_parallel_param_.rank == 
t_parallel_param_.world_size - 1) + { + apply_temperature_penalty_kernelLauncher(nccl_logits_buf_ + t_parallel_param_.rank * local_batch * n, + (DataType_) args_.temperature_, + local_batch, + args_.vocab_size_ - n * t_parallel_param_.rank, + n, + decoding_params.stream); + } + else + { + apply_temperature_penalty_kernelLauncher(nccl_logits_buf_ + t_parallel_param_.rank * local_batch * n, + (DataType_) args_.temperature_, + local_batch, + n, + n, + decoding_params.stream); + } + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + // reduce and concat the reuslt + if(t_parallel_param_.world_size > 1) + { + PUSH_RANGE("After Transformer/all2all_gather") + all2all_gather(nccl_logits_buf_, nccl_logits_buf_, local_batch * n, + t_parallel_param_, decoding_params.stream); + POP_RANGE + + transpose_axis_01_kernelLauncher(logits_buf_, nccl_logits_buf_, + t_parallel_param_.world_size, local_batch, n, decoding_params.stream); + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + n = args_.vocab_size_padded_; + // Sampling + if(args_.candidate_num_ > 0 && args_.probability_threshold_ == 0.0) + { + PUSH_RANGE("After Transformer/Sampling") + // top k sampling + topK_sampling_kernel_kernelLauncher_v2(topk_workspace_, + topk_workspace_size_, + logits_buf_, + decoding_params.output_ids + step * m + ite * local_batch, + nullptr, + finished_buf_ + ite * local_batch, + curandstate_buf_, // used as random number + args_, + decoding_params.stream, + local_batch); + POP_RANGE + } + else if(args_.candidate_num_ == 0 && args_.probability_threshold_ > 0.0f) + { + PUSH_RANGE("After Transformer/Sampling") + // top p sampling + softmax_kernelLauncher(logits_buf_, + (DataType_*) nullptr, + args_.end_id_, + finished_buf_ + ite * local_batch, + local_batch, + args_.vocab_size_padded_, + args_.vocab_size_, + decoding_params.stream); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + topP_sampling_kernel_kernelLauncher_v2(topp_workspace_, + topp_workspace_size_, + logits_buf_, + topp_id_vals_buf_, + topp_offset_buf_, + begin_topp_offset_buf_, + finished_buf_ + ite * local_batch, + curandstate_buf_, + args_, + decoding_params.output_ids + step * m + ite * local_batch, + nullptr, + n, + decoding_params.stream, + local_batch); + + POP_RANGE + } + else if(args_.candidate_num_ > 0 && args_.probability_threshold_ > 0.0f) + { + PUSH_RANGE("After Transformer/Sampling") + topK_topP_sampling_kernel_kernelLauncher_v2(topk_topp_workspace_, + topk_topp_workspace_size_, + decoding_params.output_ids + step * m + ite * local_batch, + logits_buf_, + finished_buf_ + ite * local_batch, + curandstate_buf_, + args_, + decoding_params.stream, + local_batch); + POP_RANGE + } +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + if(step < (size_t)max_input_len) + { + // Replace the sampled id by start ids + set_start_ids_kernelLauncher(decoding_params.output_ids, decoding_params.d_start_ids, max_input_len, + step, ite, request_batch_size, local_batch, args_.end_id_, decoding_params.stream); + } + + if(l_parallel_param_.rank == l_parallel_param_.world_size - 1 && l_parallel_param_.world_size > 1) + { + PUSH_RANGE("token/send") + nccl_send(decoding_params.output_ids + step * m + ite * local_batch, local_batch, 0, l_parallel_param_.nccl_comm, decoding_params.stream); + POP_RANGE + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + 
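Reviewer note (illustrative, not part of the patch): the sampling section above picks one of three kernels purely from the two runtime knobs candidate_num_ and probability_threshold_, and the constructor already asserts that at least one of them is set. A minimal standalone restatement of that dispatch follows; the enum and helper are hypothetical names used only to summarize the control flow.

    // Illustrative only: restates the control flow of the sampling branch above.
    //   candidate_num_  > 0, probability_threshold_ == 0  -> topK_sampling_kernel_kernelLauncher_v2
    //   candidate_num_ == 0, probability_threshold_  > 0  -> softmax + topP_sampling_kernel_kernelLauncher_v2
    //   both              > 0                             -> topK_topP_sampling_kernel_kernelLauncher_v2
    enum class SamplingMode { TopK, TopP, TopKTopP };

    inline SamplingMode select_sampling_mode(int candidate_num, float probability_threshold)
    {
        if (candidate_num > 0 && probability_threshold == 0.0f) return SamplingMode::TopK;
        if (candidate_num == 0 && probability_threshold > 0.0f) return SamplingMode::TopP;
        return SamplingMode::TopKTopP;
    }
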
if(l_parallel_param_.rank == l_parallel_param_.world_size - 1 && l_parallel_param_.world_size > 1 && step < max_len - 1) + { + nccl_broadcast(finished_buf_ + ite * local_batch, local_batch, l_parallel_param_.world_size - 1, l_parallel_param_, decoding_params.stream); + } +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } // end for ite for loop + } // end for decoding step for loop + if(l_parallel_param_.rank == 0 && l_parallel_param_.world_size > 1) + { + for(size_t ite = 0; ite < request_batch_size / local_batch; ite++) + { + nccl_recv(decoding_params.output_ids + (max_len - 1) * m + ite * local_batch, + local_batch, l_parallel_param_.world_size - 1, + l_parallel_param_.nccl_comm, decoding_params.stream); + } + } + } // end of forward + + virtual ~DecodingGpt() + { + delete[] K_cache_; + delete[] V_cache_; + delete decoder_; + allocator_.free(buf_); + delete [] h_finished_buf_; + } + + inline int get_num_layer() {return args_.decoder_layers_;} + + inline void set_local_batch_size(int local_batch) + { + l_parallel_param_.local_batch_size = local_batch; + decoder_->set_local_batch_size(local_batch); + } +}; + +} //namespace fastertransformer \ No newline at end of file diff --git a/fastertransformer/gpt2.h b/fastertransformer/gpt2.h deleted file mode 100644 index a7a6adeee..000000000 --- a/fastertransformer/gpt2.h +++ /dev/null @@ -1,512 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * Decoder transformer - **/ - -#pragma once - -#include "fastertransformer/common.h" -#include "fastertransformer/allocator.h" -#include "fastertransformer/open_decoder.h" -#include "fastertransformer/cuda/cuda_kernels.h" -#include "fastertransformer/arguments.h" -#include -#include - -#define EMBEDDING_TRANSPOSE_OPT 0 // TODO This feature has bug. 
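Reviewer note before the rest of the deleted gpt2.h body below (illustrative, not part of the patch): the new DecodingGpt class that closes above is meant to be driven in a fixed order, namely construct it, install the tensor/layer parallel parameters, call forward_context() once over the prompt to fill the K/V caches, then call forward() for the token-by-token generation loop. A minimal sketch under stated assumptions: FP16 traits (DataType_ == half), explicit template arguments as guessed from the traits usage above, and DecoderInitParam / DecodingInitParam filled by the caller as in the existing sample code; the function name is hypothetical.

    // Sketch only; the types come from the headers added in this patch, the
    // template arguments and function name are assumptions for illustration.
    void generate_with_gpt_fp16(DecodingGpt<OperationType::FP16>& gpt,
                                const DecoderInitParam<half>* layer_params,   // one entry per decoder layer
                                DecodingInitParam<half> decoding_params,
                                const TensorParallelParam& tensor_param,      // NCCL groups prepared by the caller
                                const LayerParallelParam& layer_param)
    {
        gpt.set_tensor_parallel_param(tensor_param);
        gpt.set_layer_parallel_param(layer_param);
        gpt.forward_context(layer_params, decoding_params);  // run the prompt once and fill the K/V caches
        gpt.forward(layer_params, decoding_params);          // then decode token by token with sampling
    }
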
- -namespace fastertransformer -{ - -template -class DecodingGpt2 -{ -private: - typedef DecoderTransformerTraits Traits_; - typedef typename Traits_::DataType DataType_; - const IAllocator &allocator_; - struct Gpt2Arguments args_; - - const cudaDataType_t computeType_ = Traits_::computeType; - const cudaDataType_t AType_ = Traits_::AType; - const cudaDataType_t BType_ = Traits_::BType; - const cudaDataType_t CType_ = Traits_::CType; - int cublasAlgo_[1] = {20}; - - DataType_ *embedding_kernel_transposed_padded_; - - OpenDecoder *decoder_; - DataType_ **K_cache_; - DataType_ **V_cache_; - DataType_ *from_tensor_[2]; - DataType_ *decoder_buf_; - DataType_ *decoder_normed_result_buf_; - DataType_ *logits_buf_; - void *buf_; - - void *topk_workspace_ = nullptr; - size_t topk_workspace_size_ = 0; - void *topp_workspace_ = nullptr; - size_t topp_workspace_size_ = 0; - void *topk_topp_workspace_ = nullptr; - size_t topk_topp_workspace_size_ = 0; - int *topp_id_vals_buf_; - int *topp_offset_buf_; - -public: - DecodingGpt2(const IAllocator &allocator, const int batch_size, - const int seq_len, - const int head_num, const int size_per_head, - const int vocab_size, const int decoder_layers, - const int start_id, const int end_id, - const int *start_ids = nullptr, const int start_len = -1, - const int candidate_num = 1, - const float probability_threshold = 0.0, - const float temperature = 1.0) : allocator_(allocator) - { -#ifndef NDEBUG - PRINT_FUNC_NAME_(); -#endif - assert(temperature != 0.0); - assert(candidate_num > 0 || probability_threshold > 0.0); - - args_.batch_size_ = batch_size; - args_.seq_len_ = seq_len; - args_.head_num_ = head_num; - args_.size_per_head_ = size_per_head; - args_.hidden_units_ = head_num * size_per_head; - args_.decoder_layers_ = decoder_layers; - args_.vocab_size_ = vocab_size; - args_.start_id_ = start_id; - args_.end_id_ = end_id; - args_.candidate_num_ = candidate_num; - args_.probability_threshold_ = probability_threshold; - args_.temperature_ = temperature; - - // Convert the start_ids to 2D and transpose the - // start_ids from [batch_size, start_len] to [start_len, batch_size] - if (start_ids != nullptr && start_len > 0) - { - args_.start_len_ = start_len; - args_.start_ids_ = new int*[start_len]; - for(int i = 0; i < start_len; i++) - { - args_.start_ids_[i] = new int[batch_size]; - for(int j = 0; j < batch_size; j++) - { - args_.start_ids_[i][j] = start_ids[j * start_len + i]; - } - } - } - else - { - // fill the start_ids by start_id - args_.start_len_ = 1; - args_.start_ids_ = new int*[start_len]; - args_.start_ids_[0] = new int[batch_size]; - for(int j = 0; j < batch_size; j++) - { - args_.start_ids_[0][j] = args_.start_id_; - } - } - - K_cache_ = new DataType_ *[1]; - V_cache_ = new DataType_ *[1]; - - decoder_ = new OpenDecoder(batch_size * 1, 0 /* memory_max_seq_len */, - head_num, size_per_head, 0 /* memory_hidden_units */ ); - -#if EMBEDDING_TRANSPOSE_OPT == 1 - args_.vocab_size_padded_ = div_up(args_.vocab_size_, 8) * 8; -#else - args_.vocab_size_padded_ = args_.vocab_size_; -#endif - - int from_tensor_size = args_.batch_size_ * args_.hidden_units_; // type T - int decoder_workspace_size = decoder_->getWorkspaceSize(); // type T - int decoder_normed_result_buffer_size = args_.batch_size_ * args_.hidden_units_; // type T - int cache_size = args_.batch_size_ * args_.seq_len_ * args_.hidden_units_; // type T - int logits_buf_size = args_.batch_size_ * args_.vocab_size_padded_; // type T - - int topp_id_vals_buf_size = args_.batch_size_ * 
args_.vocab_size_; // type int - int topp_offset_buf_size = args_.batch_size_ + 1; - - const int MEM_C = 128; - /*from_tensor_size = div_up(from_tensor_size, MEM_C) * MEM_C; - decoder_workspace_size = div_up(decoder_workspace_size, MEM_C) * MEM_C; - decoder_normed_result_buffer_size = div_up(decoder_normed_result_buffer_size, MEM_C) * MEM_C; - cache_size = div_up(cache_size, MEM_C) * MEM_C; - - logits_buf_size = div_up(logits_buf_size, MEM_C) * MEM_C; - cum_log_buf_size = div_up(cum_log_buf_size, MEM_C) * MEM_C; - finished_buf_size = div_up(finished_buf_size, MEM_C) * MEM_C; - - topk_ids_buf_size = div_up(topk_ids_buf_size, MEM_C) * MEM_C; - topk_val_buf_size = div_up(topk_val_buf_size, MEM_C) * MEM_C; - args_.temp_storage_size_ = div_up(args_.temp_storage_size_, MEM_C) * MEM_C; */ - - int embedding_kernel_transposed_padded_size = args_.hidden_units_ * args_.vocab_size_padded_; - embedding_kernel_transposed_padded_size = div_up(embedding_kernel_transposed_padded_size, MEM_C) * MEM_C; - - // prevent memory misalinged address - logits_buf_size = (int)(ceil(logits_buf_size / 4.)) * 4; - - topp_id_vals_buf_size = (int)(ceil(topp_id_vals_buf_size / 4.)) * 4; - topp_offset_buf_size = (int)(ceil(topp_offset_buf_size / 4.)) * 4; - - topP_sampling_kernel_kernelLauncher(topp_workspace_, - topp_workspace_size_, - logits_buf_, - topp_id_vals_buf_, - topp_offset_buf_, - nullptr, - 0, - args_, - nullptr, - nullptr, - args_.vocab_size_, - 0); - topK_sampling_kernel_kernelLauncher(topk_workspace_, - topk_workspace_size_, - logits_buf_, - nullptr, - nullptr, - nullptr, - 0, - args_, - 0); - topK_topP_sampling_kernel_kernelLauncher(topk_topp_workspace_, - topk_topp_workspace_size_, - nullptr, - logits_buf_, - 0, - args_, - 0); - - int datatype_buf_size = from_tensor_size * 2 + decoder_workspace_size + - cache_size * 2 * args_.decoder_layers_ + decoder_normed_result_buffer_size; - - buf_ = reinterpret_cast(allocator_.malloc( -#if EMBEDDING_TRANSPOSE_OPT == 1 - sizeof(DataType_) * embedding_kernel_transposed_padded_size + -#endif - sizeof(DataType_) * (datatype_buf_size + logits_buf_size) + - sizeof(int) * (topp_id_vals_buf_size + topp_offset_buf_size) + - topp_workspace_size_ + topk_workspace_size_ + topk_topp_workspace_size_)); - -#if EMBEDDING_TRANSPOSE_OPT == 1 - embedding_kernel_transposed_padded_ = (DataType_ *)buf_; - from_tensor_[0] = (DataType_ *)(embedding_kernel_transposed_padded_ + embedding_kernel_transposed_padded_size); -#else - from_tensor_[0] = (DataType_ *)buf_; -#endif - from_tensor_[1] = (DataType_ *)(from_tensor_[0] + from_tensor_size); - - /* We use two-way buffer since we have to update KV buf at the end of each step. 
*/ - K_cache_[0] = from_tensor_[1] + from_tensor_size + 0 * cache_size * args_.decoder_layers_; - V_cache_[0] = from_tensor_[1] + from_tensor_size + 1 * cache_size * args_.decoder_layers_; - - decoder_buf_ = V_cache_[0] + cache_size * args_.decoder_layers_; - decoder_normed_result_buf_ = (decoder_buf_ + decoder_workspace_size); - logits_buf_ = decoder_normed_result_buf_ + decoder_normed_result_buffer_size; - topp_id_vals_buf_ = (int *)(logits_buf_ + logits_buf_size); - topp_offset_buf_ = (int *)(topp_id_vals_buf_ + topp_id_vals_buf_size); - topp_workspace_ = (void *)(topp_offset_buf_ + topp_offset_buf_size); - topk_workspace_ = (void *)(topp_workspace_ + topp_workspace_size_); - topk_topp_workspace_ = (void *)(topk_workspace_ + topk_workspace_size_); - -#if EMBEDDING_TRANSPOSE_OPT == 1 - cudaMemset(embedding_kernel_transposed_padded_, 0, embedding_kernel_transposed_padded_size * sizeof(DataType_)); -#endif - - cudaDeviceSynchronize(); - - FILE *fd = fopen("decoding_gemm_config.in", "r"); - int err = 0; - if (fd == NULL) - printf("[WARNING] decoding_gemm_config.in is not found\n"); - else - { - err = fscanf(fd, "%d", &cublasAlgo_[0]); - fclose(fd); - } - if (err != 1) - { - printf("[WARNING] decoding loading GEMM algorithms error, using default GEMM algorithms!\n"); - if (Traits_::OpType == OperationType::FP32) - { - cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT; - } - else - { - cublasAlgo_[0] = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - } - } - else - { - // check that the gemm_config setting is runnable - if (Traits_::OpType == OperationType::FP32) - { - if (cublasAlgo_[0] > CUBLAS_GEMM_ALGO23 || cublasAlgo_[0] < CUBLAS_GEMM_DEFAULT) - { - // the algorithm is not for FP32 - printf("[ERROR] cuBLAS Algorithm %d is not used in FP32. \n", (int)cublasAlgo_[0]); - exit(-1); - } - } - else - { - if (cublasAlgo_[0] > CUBLAS_GEMM_ALGO15_TENSOR_OP || cublasAlgo_[0] < CUBLAS_GEMM_DEFAULT_TENSOR_OP) - { - // the algorithm is not for FP16 - printf("[ERROR] cuBLAS Algorithm %d is not used in FP16. \n", (int)cublasAlgo_[0]); - exit(-1); - } - } - } - } - - void forward(const DecoderInitParam *param, - DecodingInitParam decoding_params) - { - -#ifndef NDEBUG - PRINT_FUNC_NAME_(); -#endif - const int m = args_.batch_size_; - const int k = args_.hidden_units_; - const int n = args_.vocab_size_; - - /* - sequence_length initialize to 0 - finished: false - word_ids: start_id_ - cum_log_probs (for eacm beam, the first element is 0). e.g., [0 -inf -inf -inf][0 -inf -inf -inf] - */ - - /* Initialize the first output_ids */ - - check_cuda_error(cudaMemcpyAsync(decoding_params.output_ids, args_.start_ids_[0], m*sizeof(int), cudaMemcpyHostToDevice, decoding_params.stream)); - if (args_.probability_threshold_ != 0.0) - { - topp_initialization_kernelLauncher(nullptr, - nullptr, - nullptr, - topp_id_vals_buf_, - topp_offset_buf_, - args_.candidate_num_ > 0 ? 
args_.candidate_num_ : args_.vocab_size_, - args_, - decoding_params.stream); - } - -#if EMBEDDING_TRANSPOSE_OPT == 1 - transpose(embedding_kernel_transposed_padded_, decoding_params.embedding_kernel, 1, - args_.vocab_size_, args_.hidden_units_, 0, decoding_params.stream); -#endif -#ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -#endif - - int cache_size = m * args_.seq_len_ * args_.hidden_units_; // type T - - bool do_beamsearch = false; - for (int step = 1; step < args_.seq_len_; ++step) - { - int *word_ids_buf_ = decoding_params.output_ids + (step - 1) * m; - do_beamsearch = step >= args_.start_len_; - //we use two-way buffer - embedding_position_lookups_kernel_launcher(from_tensor_[0], - decoding_params.embedding_table, - decoding_params.position_encoding_table, - word_ids_buf_, - m, - args_.hidden_units_, - step, - decoding_params.stream); -#ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -#endif - int from_id, out_id; - for (int layer = 0; layer < args_.decoder_layers_; ++layer) - { - /* - For the first layer (layer-0), from_id is 0. We also stored the embedding lookup - result in from_tensor_[0] - */ - from_id = layer & 0x1; - out_id = 1 - from_id; - - /* - We use one decoder_ object to process multiple decoder layers. - - At the beginning of each decoder layer, we initialize the decoder object - with corresponding weights and decoder_buf_. - - The decoder_buf_ is reused. - */ - decoder_->initialize(param[layer], decoder_buf_); - -#ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -#endif - decoder_->forward(from_tensor_[from_id], - nullptr, // memory_tensor should be nullptr - K_cache_[0] + layer * cache_size, - V_cache_[0] + layer * cache_size, - nullptr, nullptr, // key_mem_cache_ and value_mem_cache_ should be nullptr - nullptr, // memory_sequence_length should be nullptr - from_tensor_[out_id], step, - false); - -#ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -#endif - } - decoder_->decoder_norm1(from_tensor_[out_id], decoding_params.layernorm.gamma, - decoding_params.layernorm.beta, decoder_normed_result_buf_, m, k); - -#ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -#endif - - DataType_ alpha = DataType_(1.0f); - DataType_ beta = DataType_(0.0f); - - cublasGemmAlgo_t cublasAlgo = static_cast(cublasAlgo_[0]); - check_cuda_error(cublasGemmEx(decoding_params.cublas_handle, -#if EMBEDDING_TRANSPOSE_OPT == 1 - CUBLAS_OP_N, CUBLAS_OP_N, - args_.vocab_size_padded_, m, k, - &alpha, - embedding_kernel_transposed_padded_, - AType_, args_.vocab_size_padded_, //n -#else - CUBLAS_OP_T, CUBLAS_OP_N, - n, m, k, - &alpha, - decoding_params.embedding_kernel, - AType_, k, -#endif - decoder_normed_result_buf_, BType_, k, - &beta, - logits_buf_, CType_, -#if EMBEDDING_TRANSPOSE_OPT == 1 - args_.vocab_size_padded_, -#else - n, -#endif - computeType_, - cublasAlgo)); - -#ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -#endif - - apply_temperature_penalty_kernelLauncher(logits_buf_, - (DataType_) args_.temperature_, - m, - n, - decoding_params.stream); - int random_num = rand(); - if (do_beamsearch) - { - // Sampling - if(args_.candidate_num_ > 0 && args_.probability_threshold_ == 0.0) - { - // top k sampling - topK_sampling_kernel_kernelLauncher(topk_workspace_, - topk_workspace_size_, - logits_buf_, - decoding_params.output_ids + step * m, - nullptr, - nullptr, - random_num, - args_, - 
decoding_params.stream); - } - else if(args_.candidate_num_ == 0 && args_.probability_threshold_ > 0.0f) - { - // top p sampling - softmax_kernelLauncher(logits_buf_, - (DataType_*) nullptr, - args_.end_id_, - nullptr, - m, - n, - decoding_params.stream); -#ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -#endif - topP_sampling_kernel_kernelLauncher(topp_workspace_, - topp_workspace_size_, - logits_buf_, - topp_id_vals_buf_, - topp_offset_buf_, - nullptr, - random_num, - args_, - decoding_params.output_ids + step * m, - nullptr, - n, - decoding_params.stream); - } - else if(args_.candidate_num_ > 0 && args_.probability_threshold_ > 0.0f) - { - topK_topP_sampling_kernel_kernelLauncher(topk_topp_workspace_, - topk_topp_workspace_size_, - decoding_params.output_ids + step * m, - logits_buf_, - random_num, - args_, - decoding_params.stream); - } -#ifndef NDEBUG - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -#endif - } - else - { - // else of do_beamsearch (set pre-determined word ids) - check_cuda_error(cudaMemcpyAsync(decoding_params.output_ids + step*m, args_.start_ids_[step], - m*sizeof(int), cudaMemcpyHostToDevice, decoding_params.stream)); - } - } // end for decoding step for llop - } // end of forward - - virtual ~DecodingGpt2() - { - delete[] K_cache_; - delete[] V_cache_; - delete decoder_; - allocator_.free(buf_); - for(int i = 0; i < args_.start_len_; i++) - { - delete [] args_.start_ids_[i]; - } - delete [] args_.start_ids_; - } -}; - -} //namespace fastertransformer diff --git a/fastertransformer/open_decoder.h b/fastertransformer/open_decoder.h index 8b0ce610c..dfdce555c 100644 --- a/fastertransformer/open_decoder.h +++ b/fastertransformer/open_decoder.h @@ -18,18 +18,31 @@ **/ #pragma once -#include "fastertransformer/allocator.h" -#include "fastertransformer/common.h" -#include "fastertransformer/common_structure.h" +#include "fastertransformer/cuda/cuda_kernels.h" +#include "fastertransformer/cuda/attention_kernels.cuh" +#include "fastertransformer/cuda/transformer_kernels.cuh" +#include "fastertransformer/cuda/open_decoder.cuh" +#include "fastertransformer/utils/nvtx_utils.h" +#include "fastertransformer/utils/allocator.h" +#include "fastertransformer/utils/common.h" +#include "fastertransformer/utils/functions.h" +#include "fastertransformer/utils/common_structure.h" #include #include #include +#include "fastertransformer/utils/nvtx_utils.h" +#include "fastertransformer/utils/nccl_utils.h" + +// use new attention implementation with [B, H, Dh/x, L, x] cache format for the keys + // and [B, H, L, Dh] for values + + #define USE_CACHE_BATCH_MAJOR_ATTENTION 1 namespace fastertransformer { template -class DecoderInitParam +class DecoderInitParam : public AbstractParam { public: /* weights for masked_multi_head_attention */ @@ -42,7 +55,11 @@ class DecoderInitParam LayerNormWeight ffn_layernorm; FFNWeight ffn; cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; cudaStream_t stream; + + int request_batch_size = -1; + int request_max_mem_seq_len = -1; }; template @@ -65,15 +82,16 @@ class OpenDecoder typedef DecoderTransformerTraits Traits_; typedef typename Traits_::DataType DataType_; DecoderInitParam param_; + TensorParallelParam t_parallel_param_; + LayerParallelParam l_parallel_param_; const cudaDataType_t computeType_ = Traits_::computeType; const cudaDataType_t AType_ = Traits_::AType; const cudaDataType_t BType_ = Traits_::BType; const cudaDataType_t CType_ = Traits_::CType; - int cublasAlgo_[5]; + std::map 
cublasAlgoMap_; - int batch_size_; - int max_seq_len_; + int max_batch_size_ = -1; int head_num_; int size_per_head_; int hidden_units_; @@ -88,115 +106,142 @@ class OpenDecoder DataType_ **qkv_kernel_; DataType_ **qkv_input_; DataType_ **qkv_buf_; + void* cublas_workspace_ = nullptr; - bool is_fuse_QKV; - + bool is_fuse_QKV_in_batched_gemm_; + const bool is_fuse_QKV_in_normal_gemm_; public: - OpenDecoder(int batch_size, int seq_len, - int head_num, int size_per_head, - int memory_hidden_units) : batch_size_(batch_size), - max_seq_len_(seq_len), head_num_(head_num), - size_per_head_(size_per_head), - memory_hidden_units_(memory_hidden_units) + + void judgeFusedQKV() + { + is_fuse_QKV_in_batched_gemm_ = false; + int m, n, k, dataType; + if (std::is_same::value) + dataType = HALF_DATATYPE; + else + dataType = FLOAT_DATATYPE; + + m = l_parallel_param_.local_batch_size; + n = t_parallel_param_.local_hidden_units_; + k = hidden_units_; + char mark[256], mark2[256]; + sprintf(mark, "1_%d_%d_%d_%d", n, m, k, dataType); + sprintf(mark2, "3_%d_%d_%d_%d", n, m, k, dataType); + if ( + cublasAlgoMap_.find(mark) != cublasAlgoMap_.end() && + cublasAlgoMap_.find(mark2) != cublasAlgoMap_.end() && + 3*cublasAlgoMap_[mark].exec_time > cublasAlgoMap_[mark2].exec_time + ) + { + is_fuse_QKV_in_batched_gemm_ = true; + } + } + + + OpenDecoder(int head_num, int size_per_head, + int memory_hidden_units, + bool is_fuse_QKV_in_normal_gemm = false) : + head_num_(head_num), + size_per_head_(size_per_head), + memory_hidden_units_(memory_hidden_units), + is_fuse_QKV_in_normal_gemm_(is_fuse_QKV_in_normal_gemm) { #ifndef NDEBUG PRINT_FUNC_NAME_(); #endif hidden_units_ = head_num_ * size_per_head_; + t_parallel_param_.local_head_num_ = head_num_; + t_parallel_param_.local_hidden_units_ = hidden_units_; - FILE *fd = fopen("decoding_gemm_config.in", "r"); - int err = 0; - if (fd == NULL) + int isConfigExist = access("decoding_gemm_config.in", 0); + if (isConfigExist == -1) { printf("[WARNING] decoding_gemm_config.in is not found\n"); } else { - // First number is a setting for gemm in Decoding, which computes the embedding output. - // so we need to skip the number - float split_time, fused_time; - err = fscanf(fd, "%*d %*f %d %f %d %*f %d %*f %d %*f %d %f", &cublasAlgo_[0], &split_time, &cublasAlgo_[1], - &cublasAlgo_[2], &cublasAlgo_[3], &cublasAlgo_[4], &fused_time); - is_fuse_QKV = fused_time < split_time * 3 ? true : false; - fclose(fd); - } - if (err != 7) - { - // printf("[WARNING] decoder loading GEMM algorithms error, using default GEMM algorithms!\n"); - int default_algo; - if (Traits_::OpType == OperationType::FP32) - { - default_algo = CUBLAS_GEMM_DEFAULT; - } - else - { - default_algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - } - for (int i = 0; i < 5; i++) - cublasAlgo_[i] = default_algo; - is_fuse_QKV = false; - } - else - { + readAlgoFromConfig(cublasAlgoMap_); // check that the gemm_config setting is runnable - if (Traits_::OpType == OperationType::FP32) + for (auto iter = cublasAlgoMap_.begin() ; iter != cublasAlgoMap_.end() ; iter++) { - for (int i = 0; i < 5; i++) + int algoId = iter->second.algoId; + int stages = iter->second.stages; + //only check for cublas + if (stages != -1) + continue; + if (Traits_::OpType == OperationType::FP32) { - if (cublasAlgo_[i] > CUBLAS_GEMM_ALGO23 || cublasAlgo_[i] < CUBLAS_GEMM_DEFAULT) + if (algoId > CUBLAS_GEMM_ALGO23 || algoId < CUBLAS_GEMM_DEFAULT) { // the algorithm is not for FP32 - printf("[ERROR] cuBLAS Algorithm %d is not used in FP32. 
\n", (int)cublasAlgo_[i]); + printf("[ERROR] cuBLAS Algorithm %d is not used in FP32. \n", algoId); exit(-1); } } - } - else - { - for (int i = 0; i < 5; i++) + else { - if (cublasAlgo_[i] > CUBLAS_GEMM_ALGO15_TENSOR_OP || cublasAlgo_[i] < CUBLAS_GEMM_DEFAULT_TENSOR_OP) + if (algoId > CUBLAS_GEMM_ALGO15_TENSOR_OP || algoId < CUBLAS_GEMM_DEFAULT_TENSOR_OP) { // the algorithm is not for FP16 - printf("[ERROR] cuBLAS Algorithm %d is not used in FP16. \n", (int)cublasAlgo_[i]); + printf("[ERROR] cuBLAS Algorithm %d is not used in FP16. \n", algoId); exit(-1); } } } } + judgeFusedQKV(); + } + + inline void set_max_batch_size(int batch_size) + { + max_batch_size_ = batch_size; } int getWorkspaceSize() { - int buf_size = batch_size_ * hidden_units_; - return 13 * buf_size + sizeof(DataType_ *) * 9; + assert(max_batch_size_ != -1); + return 13 * max_batch_size_ * hidden_units_ + sizeof(DataType_ *) * 9; + } + + void set_tensor_parallel_param(const TensorParallelParam param) + { + t_parallel_param_ = param; } - void initialize(DecoderInitParam param, DataType_ *buf) + void set_layer_parallel_param(const LayerParallelParam param) + { + l_parallel_param_ = param; + } + + void initialize(DecoderInitParam param, DataType_ *buf, void *cublas_workapsce, bool set_local_batch = true) { #ifndef NDEBUG - // PRINT_FUNC_NAME_(); + PRINT_FUNC_NAME_(); #endif param_ = param; - const int buf_size = batch_size_ * hidden_units_; + if(l_parallel_param_.local_batch_size == -1 || set_local_batch == true) l_parallel_param_.local_batch_size = param_.request_batch_size; + const int buf_size = max_batch_size_ * hidden_units_; + //cublas_workspace_ should be the start pointer of cudaMalloc() + //to ensure 16B alignemnet + cublas_workspace_ = cublas_workapsce; norm_from_tensor_buf_ = buf; query_buf_ = buf + buf_size; //store the query values (from_tensor * Q) in both masked and multi-head attention - key_buf_ = buf + 2 * buf_size; - value_buf_ = buf + 3 * buf_size; - context_buf_ = buf + 4 * buf_size; //store the context result (softmax(qk)v) in both masked and multi-head attention + key_buf_ = query_buf_ + buf_size; + value_buf_ = key_buf_ + buf_size; + context_buf_ = value_buf_ + buf_size; //store the context result (softmax(qk)v) in both masked and multi-head attention - masked_output_buf_ = buf + 5 * buf_size; //masked_attention_output - norm_masked_output_buf_ = buf + 6 * buf_size; //norm(masked_attention_output) + masked_output_buf_ = context_buf_ + buf_size; //masked_attention_output + norm_masked_output_buf_ = masked_output_buf_ + buf_size; //norm(masked_attention_output) - cross_output_buf_ = buf + 7 * buf_size; //mutli-head attention_output - norm_cross_output_buf_ = buf + 8 * buf_size; //norm(multi-head attention_output) - ffn_inner_buf_ = buf + 9 * buf_size; //4 buf size to store inner product + cross_output_buf_ = norm_masked_output_buf_ + buf_size; //mutli-head attention_output + norm_cross_output_buf_ = cross_output_buf_ + buf_size; //norm(multi-head attention_output) + ffn_inner_buf_ = norm_cross_output_buf_ + buf_size; //4 buf size to store inner product qkv_kernel_ = (DataType_ **)(ffn_inner_buf_ + 4 * buf_size); qkv_input_ = qkv_kernel_ + 3; qkv_buf_ = qkv_input_ + 3; - if (is_fuse_QKV == true) + if (is_fuse_QKV_in_normal_gemm_ == false && is_fuse_QKV_in_batched_gemm_ == true) { const DataType_ *hA[]{param_.self_attention.query_weight.kernel, param_.self_attention.key_weight.kernel, @@ -211,30 +256,32 @@ class OpenDecoder DataType_ *key_cache_, DataType_ *value_cache_, DataType_ *key_mem_cache_, DataType_ 
*value_mem_cache_, const int *memory_sequence_length, DataType_ *decoder_output, const int step, - const bool is_cross_attention) + const int decoder_max_seq_len, const bool is_cross_attention, const bool* finished = nullptr) { #ifndef NDEBUG - // PRINT_FUNC_NAME_(); + PRINT_FUNC_NAME_(); #endif - const int m = batch_size_; - const int n = hidden_units_; - + const int m = l_parallel_param_.local_batch_size; try { /* masked multi-head attention */ /* layernorm(from_tensor) -> norm_from_tensor_buf_ */ - decoder_norm1(from_tensor, - param_.self_layernorm.gamma, - param_.self_layernorm.beta, - norm_from_tensor_buf_, - m, - n); + + layer_norm(from_tensor, + param_.self_layernorm.gamma, + param_.self_layernorm.beta, + norm_from_tensor_buf_, + m, + hidden_units_, + param_.stream); #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif - masked_multi_head_attention(norm_from_tensor_buf_, key_cache_, value_cache_, masked_output_buf_, step); + PUSH_RANGE("Transformer/slf_attn") + masked_multi_head_attention(norm_from_tensor_buf_, key_cache_, value_cache_, masked_output_buf_, finished, step, decoder_max_seq_len); + POP_RANGE #ifndef NDEBUG cudaDeviceSynchronize(); @@ -248,12 +295,12 @@ class OpenDecoder masked_output_buf_ + from_tensor -> masked_output_buf_ norm(masked_output_buf_) -> norm_masked_output_buf_ */ - decoder_norm2(from_tensor, - param_.cross_layernorm.gamma, - param_.cross_layernorm.beta, - param_.self_attention.attention_output_weight.bias, - masked_output_buf_, - norm_masked_output_buf_, m, n); + add_bias_input_layernorm_2_kernelLauncher(from_tensor, + param_.cross_layernorm.gamma, + param_.cross_layernorm.beta, + param_.self_attention.attention_output_weight.bias, + masked_output_buf_, + norm_masked_output_buf_, m, hidden_units_, param_.stream); #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); @@ -262,7 +309,7 @@ class OpenDecoder /* cross attention with memory */ cross_multi_head_attention(norm_masked_output_buf_, memory_tensor, key_mem_cache_, value_mem_cache_, cross_output_buf_, - memory_sequence_length, max_seq_len_, step); + memory_sequence_length, finished, param_.request_max_mem_seq_len, step); #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); @@ -271,42 +318,45 @@ class OpenDecoder cross_output_buf_ + bias + masked_output_buf_ -> cross_output_buf_ norm(cross_otuput_buf) -> normed_last_context (input for ffn) */ - decoder_norm2(masked_output_buf_, - param_.ffn_layernorm.gamma, - param_.ffn_layernorm.beta, - param_.cross_attention.attention_output_weight.bias, - cross_output_buf_, - norm_cross_output_buf_, m, n); + add_bias_input_layernorm_2_kernelLauncher(masked_output_buf_, + param_.ffn_layernorm.gamma, + param_.ffn_layernorm.beta, + param_.cross_attention.attention_output_weight.bias, + cross_output_buf_, + norm_cross_output_buf_, m, hidden_units_, param_.stream); #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif - ffn(norm_cross_output_buf_, ffn_inner_buf_, decoder_output, m, 4 * n, n, ActivationType::RELU); + ffn(norm_cross_output_buf_, ffn_inner_buf_, decoder_output, m, 4 * t_parallel_param_.local_hidden_units_, hidden_units_, ActivationType::RELU); #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif - add_bias_input(decoder_output, cross_output_buf_, m, n); + add_bias_input_kernelLauncher(decoder_output, param_.ffn.output_weight.bias, cross_output_buf_, m, hidden_units_, param_.stream); } else { - 
decoder_norm2(from_tensor, - param_.ffn_layernorm.gamma, - param_.ffn_layernorm.beta, - param_.self_attention.attention_output_weight.bias, - masked_output_buf_, - norm_masked_output_buf_, m, n); + add_bias_input_layernorm_2_kernelLauncher(from_tensor, + param_.ffn_layernorm.gamma, + param_.ffn_layernorm.beta, + param_.self_attention.attention_output_weight.bias, + masked_output_buf_, + norm_masked_output_buf_, m, hidden_units_, param_.stream); #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif // For GPT-2 decoder - ffn(norm_masked_output_buf_, ffn_inner_buf_, decoder_output, m, 4 * n, n, ActivationType::GELU); + PUSH_RANGE("Transformer/MLP") + ffn(norm_masked_output_buf_, ffn_inner_buf_, decoder_output, m, 4 * t_parallel_param_.local_hidden_units_, hidden_units_, ActivationType::GELU); + POP_RANGE + #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif - add_bias_input(decoder_output, masked_output_buf_, m, n); + add_bias_input_kernelLauncher(decoder_output, param_.ffn.output_weight.bias, masked_output_buf_, m, hidden_units_, param_.stream); } #ifndef NDEBUG cudaDeviceSynchronize(); @@ -319,26 +369,575 @@ class OpenDecoder throw error; } } + + size_t getContextWorkspaceSize(const int seq_len, const int local_batch_size) + { + const size_t m = local_batch_size * seq_len; + const size_t qk_buf_size = (size_t)(ceil(local_batch_size * t_parallel_param_.local_head_num_ * seq_len * seq_len / 4.)) * 4; + const size_t attn_work_space_size = 3 * m * hidden_units_ /* Q, K, V */ + + 3 * m * t_parallel_param_.local_hidden_units_ /* q_buf, k_buf, v_buf */ + + qk_buf_size + + 2 * m * t_parallel_param_.local_hidden_units_ /* trans_attn, attn */; + return (m * hidden_units_ * 3 + + attn_work_space_size + + m * t_parallel_param_.local_hidden_units_ * 4 /* ffn buffer */ ) * sizeof(DataType_); + } + + // use to compute the context of gpt model + void forward_context(DataType_* workspace, + DataType_ *decoder_output, + DataType_ *key_cache_, + DataType_ *value_cache_, + const DataType_ *from_tensor, + const DataType_ *d_attn_mask, + const int local_batch_size, + const int seq_len, + const int ite, + const int max_seq_len, + const bool is_final) + { +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + try + { + const int m = local_batch_size * seq_len; + const int qk_buf_size = (int)(ceil(local_batch_size * t_parallel_param_.local_head_num_ * seq_len * seq_len / 4.)) * 4; + const int attn_work_space_size = 3 * m * hidden_units_ /* Q, K, V */ + + 3 * m * t_parallel_param_.local_hidden_units_ /* q_buf, k_buf, v_buf */ + + qk_buf_size + + 2 * m * t_parallel_param_.local_hidden_units_ /* trans_attn, attn */; + + // set workspace + DataType_* norm_from_tensor_buf = (DataType_*)workspace; + DataType_* attention_workspace = norm_from_tensor_buf + m * hidden_units_; + DataType_* masked_output_buf = attention_workspace + attn_work_space_size; + DataType_* norm_masked_output_buf = masked_output_buf + m * hidden_units_; + DataType_* ffn_inner_buf = norm_masked_output_buf + m * hidden_units_; + + layer_norm(from_tensor, + param_.self_layernorm.gamma, + param_.self_layernorm.beta, + norm_from_tensor_buf, + m, + hidden_units_, + param_.stream); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + PUSH_RANGE("Transformer/slf_attn") + unfused_masked_multi_head_attention(attention_workspace, + norm_from_tensor_buf, + key_cache_, + value_cache_, + masked_output_buf, + d_attn_mask, + local_batch_size, + seq_len, + ite, + 
max_seq_len, + is_final); + if(is_final) return; + POP_RANGE +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + add_bias_input_layernorm_2_kernelLauncher(from_tensor, + param_.ffn_layernorm.gamma, + param_.ffn_layernorm.beta, + param_.self_attention.attention_output_weight.bias, + masked_output_buf, + norm_masked_output_buf, m, hidden_units_, param_.stream); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + // For GPT decoder + PUSH_RANGE("Transformer/MLP"); + ffn(norm_masked_output_buf, ffn_inner_buf, decoder_output, m, 4 * t_parallel_param_.local_hidden_units_, hidden_units_, ActivationType::GELU); + POP_RANGE +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + add_bias_input_kernelLauncher(decoder_output, param_.ffn.output_weight.bias, masked_output_buf, m, hidden_units_, param_.stream); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + catch (std::runtime_error &error) + { + throw error; + } + } + void masked_multi_head_attention(const DataType_ *from_tensor, DataType_ *key_cache_, - DataType_ *value_cache_, DataType_ *decoder_output, const int step); + DataType_ *value_cache_, DataType_ *decoder_output, + const bool* finished, const int step, const int max_seq_len) + { + int m = l_parallel_param_.local_batch_size; + int n = t_parallel_param_.local_hidden_units_; + int k = hidden_units_; + + // chose which attention to use + int decoder_max_seq_len = (getCacheFormat() != 0)? max_seq_len : -1; + + DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; + if(is_fuse_QKV_in_normal_gemm_ == true) + { + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + 3*n, m, k, + &alpha, + param_.self_attention.query_weight.kernel , AType_, 3*n, + from_tensor, BType_, k, + &beta, + query_buf_, CType_, 3*n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + fusedQKV_masked_attention_dispatch( + query_buf_, + param_.self_attention.query_weight.bias, + key_cache_, + value_cache_, + context_buf_, finished, param_.request_batch_size, l_parallel_param_.local_batch_size, + t_parallel_param_.local_head_num_, size_per_head_, step, decoder_max_seq_len, param_.stream); + } + else + { + if(is_fuse_QKV_in_batched_gemm_ == true) + { + cublasGemmAlgo_t cublasAlgo = static_cast(getAlgoIdFromMap(cublasAlgoMap_, 3, n, m, k, std::is_same::value ? 
FLOAT_DATATYPE : HALF_DATATYPE)); + check_cuda_error(cublasGemmBatchedEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + (const void* const*) qkv_kernel_, AType_, n, + (const void* const*) qkv_input_, BType_, k, + &beta, + (void* const*)qkv_buf_, CType_, n, + 3, + computeType_, + cublasAlgo)); + } + else + { + + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.query_weight.kernel , AType_, n, + from_tensor, BType_, k, + &beta, + query_buf_, CType_, n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.key_weight.kernel, AType_, n, + from_tensor, BType_, k, + &beta, + key_buf_, CType_, n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.value_weight.kernel, AType_, n, + from_tensor, BType_, k, + &beta, + value_buf_, CType_, n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + } + masked_attention_dispatch( + key_buf_, value_buf_, + query_buf_, param_.self_attention.query_weight.bias, + key_cache_, param_.self_attention.key_weight.bias, + value_cache_, param_.self_attention.value_weight.bias, + context_buf_, finished, param_.request_batch_size, l_parallel_param_.local_batch_size, + t_parallel_param_.local_head_num_, size_per_head_, step, decoder_max_seq_len, param_.stream); + } + + k = t_parallel_param_.local_hidden_units_; + n = hidden_units_; + + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.attention_output_weight.kernel, AType_, n, + context_buf_, BType_, k, + &beta, + decoder_output, CType_, n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + PUSH_RANGE("Transformer/slf_attn/all2all_reduce") + all2all_reduce_sum(decoder_output, decoder_output, m*n, + t_parallel_param_, param_.stream); + POP_RANGE + } + + /* attention with source sentence */ void cross_multi_head_attention(const DataType_ *from_tensor, const DataType_ *memory_tensor, DataType_ *key_mem_cache_, DataType_ *value_mem_cache_, - DataType_ *decoder_output, const int *memory_sequence_length, - const int max_seq_len, const int step); + DataType_ *decoder_output, const int *memory_sequence_length, const bool* finished, + const int max_seq_len, const int step) + { + int m = param_.request_batch_size; + int n = t_parallel_param_.local_hidden_units_; + int k = hidden_units_; + + DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; + + //reuse the query_buf + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.cross_attention.query_weight.kernel, AType_, n, + from_tensor, BType_, k, + &beta, + query_buf_, CType_, n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + if(step == 1) + { + m *= max_seq_len; + k = memory_hidden_units_; + + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.cross_attention.key_weight.kernel, AType_, n, + memory_tensor, BType_, k, + &beta, + key_mem_cache_, CType_, n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + 
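// The value projection of the encoder memory below mirrors the key projection above:
+      // both run only at step == 1 and are cached in key_mem_cache_ / value_mem_cache_,
+      // so later decoding steps reuse the caches in cross_attention_dispatch instead of
+      // recomputing these GEMMs.
+      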
cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.cross_attention.value_weight.kernel, AType_, n, + memory_tensor, BType_, k, + &beta, + value_mem_cache_, CType_, n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + k = t_parallel_param_.local_hidden_units_; + } + + cross_attention_dispatch( + query_buf_, param_.cross_attention.query_weight.bias, + key_mem_cache_, param_.cross_attention.key_weight.bias, + value_mem_cache_, param_.cross_attention.value_weight.bias, + memory_sequence_length, context_buf_, finished, param_.request_batch_size, + head_num_, size_per_head_, step, max_seq_len, param_.stream); + + m = param_.request_batch_size; + n = hidden_units_; + k = t_parallel_param_.local_hidden_units_; + + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.cross_attention.attention_output_weight.kernel, AType_, n, + context_buf_, BType_, k, + &beta, + decoder_output, CType_, n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + } + void ffn(const DataType_ *input, DataType_ *ffn_inner, DataType_ *output, - const int m, const int inner_size, const int n, ActivationType activation_type); + const int m, const int inner_size, const int n, ActivationType activation_type) + { + int m1 = m, k1 = n, n1 = inner_size; + DataType_ alpha = (DataType_)1.0f; + DataType_ beta = (DataType_)0.0f; + + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n1, m1, k1, + &alpha, + param_.ffn.intermediate_weight.kernel, AType_, n1, + input, BType_, k1, + &beta, + ffn_inner, CType_, n1, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + add_bias_act_kernelLauncher(ffn_inner, param_.ffn.intermediate_weight.bias, m1, inner_size, activation_type, param_.stream); + + int m2 = m, n2 = n, k2 = inner_size; + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n2, m2, k2, + &alpha, + param_.ffn.output_weight.kernel, AType_, n2, + ffn_inner, BType_, k2, + &beta, + output, CType_, n2, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + PUSH_RANGE("Transformer/MLP/all2all_reduce") + all2all_reduce_sum(output, output, m*n, + t_parallel_param_, param_.stream); + POP_RANGE + } + + void unfused_masked_multi_head_attention(DataType_ *workspace, + const DataType_* from_tensor, + DataType_* key_cache_, + DataType_* value_cache_, + DataType_* decoder_output, + const DataType_* attr_mask, + const int local_batch_size, + const int seq_len, + const int ite, + const int max_seq_len, + const bool is_final) + { + const DataType_ scalar = 1 / sqrtf(size_per_head_ * 1.0f); + const int m = local_batch_size * seq_len; + + const int qk_buf_size = (int)(ceil(local_batch_size * t_parallel_param_.local_head_num_ * seq_len * seq_len / 4.)) * 4; + + DataType_* Q = workspace; + DataType_* K = Q + m * hidden_units_; + DataType_* V = K + m * hidden_units_; + DataType_* q_buf = V + m * hidden_units_; + DataType_* k_buf = q_buf + m * t_parallel_param_.local_hidden_units_; + DataType_* v_buf = k_buf + m * t_parallel_param_.local_hidden_units_; + DataType_* qk_buf = v_buf + m * t_parallel_param_.local_hidden_units_; + DataType_* attn_trans_out = qk_buf + qk_buf_size; + DataType_* attn_out = attn_trans_out + m * t_parallel_param_.local_hidden_units_; + + DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f; + + 
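// When the Q/K/V projection weights are packed into a single [hidden_units_, 3 * local_hidden_units_]
+    // matrix (is_fuse_QKV_in_normal_gemm_), one GEMM produces Q, K and V together and the fused
+    // bias/transpose kernel splits them into q_buf/k_buf/v_buf; otherwise three separate GEMMs run.
+    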
if(is_fuse_QKV_in_normal_gemm_ == true) + { + const int n = t_parallel_param_.local_hidden_units_; + const int k = hidden_units_; + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + 3*n, m, k, + &alpha, + param_.self_attention.query_weight.kernel , AType_, 3*n, + from_tensor, BType_, k, + &beta, + Q, CType_, 3*n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + add_fusedQKV_bias_transpose_kernelLauncher( + q_buf, k_buf, v_buf, + Q, param_.self_attention.query_weight.bias, + local_batch_size, seq_len, + t_parallel_param_.local_head_num_, + size_per_head_, param_.stream); + } + else + { + const int n = t_parallel_param_.local_hidden_units_; + const int k = hidden_units_; + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.query_weight.kernel , AType_, n, + from_tensor, BType_, k, + &beta, + Q, CType_, n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.key_weight.kernel , AType_, n, + from_tensor, BType_, k, + &beta, + K, CType_, n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.value_weight.kernel , AType_, n, + from_tensor, BType_, k, + &beta, + V, CType_, n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + add_QKV_bias_transpose_kernelLauncher(q_buf, k_buf, v_buf, + Q, param_.self_attention.query_weight.bias, + K, param_.self_attention.key_weight.bias, + V, param_.self_attention.value_weight.bias, + local_batch_size, seq_len, + t_parallel_param_.local_head_num_, + size_per_head_, param_.stream); + } + + // !!! need to implement cget_cache_config + if(max_seq_len == -1 || USE_CACHE_BATCH_MAJOR_ATTENTION == 0 ) + { + transpose_4d_kernelLauncher(key_cache_, k_buf, + local_batch_size, + seq_len, + size_per_head_, + t_parallel_param_.local_hidden_units_, + t_parallel_param_.local_head_num_, + param_.request_batch_size, + ite, + param_.stream); + + transpose_4d_kernelLauncher(value_cache_, v_buf, + local_batch_size, + seq_len, + size_per_head_, + t_parallel_param_.local_hidden_units_, + t_parallel_param_.local_head_num_, + param_.request_batch_size, + ite, + param_.stream); + } + else if (USE_CACHE_BATCH_MAJOR_ATTENTION == 1) + { + // Use batch major + // put k/v_buf from shape [B, H, L, Dh] + // to cache [B, H, Dh/x, L, x] and [B, H, L, Dh/x, x] + transpose_4d_batch_major_kernelLauncher(key_cache_, value_cache_, + k_buf, v_buf, + local_batch_size, + seq_len, + max_seq_len, + size_per_head_, + t_parallel_param_.local_head_num_, + param_.stream); + } + else + { + printf("[ERROR] Can not decide on the cache config \n"); + exit(-1); + } - void decoder_norm1(const DataType_ *from_tensor, const DataType_ *gamma, - const DataType_ *beta, DataType_ *norm_from_tensor_buf_, const int m, const int n); + if(is_final) return; + + cublasGemmAlgo_t cublasAlgo = static_cast(getAlgoIdFromMap(cublasAlgoMap_, local_batch_size * t_parallel_param_.local_head_num_, seq_len, seq_len, size_per_head_, std::is_same::value ? 
FLOAT_DATATYPE : HALF_DATATYPE)); + + check_cuda_error(cublasGemmStridedBatchedEx(param_.cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + seq_len, seq_len, size_per_head_, + &alpha, + k_buf, AType_, size_per_head_, seq_len * size_per_head_, + q_buf, BType_, size_per_head_, seq_len * size_per_head_, + &beta, + qk_buf, CType_, seq_len, seq_len * seq_len, + local_batch_size * t_parallel_param_.local_head_num_, + computeType_, + cublasAlgo)); + + attn_softmax_kernelLauncher(qk_buf, + attr_mask, + local_batch_size, + seq_len, + t_parallel_param_.local_head_num_, + scalar, + param_.stream); + + cublasAlgo = static_cast(getAlgoIdFromMap(cublasAlgoMap_, local_batch_size * t_parallel_param_.local_head_num_, size_per_head_, seq_len, seq_len, std::is_same::value ? FLOAT_DATATYPE : HALF_DATATYPE)); + + check_cuda_error(cublasGemmStridedBatchedEx(param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + size_per_head_, seq_len, seq_len, + &alpha, + v_buf, AType_, size_per_head_, seq_len * size_per_head_, + qk_buf, BType_, seq_len, seq_len * seq_len, + &beta, + attn_trans_out, CType_, size_per_head_, seq_len * size_per_head_, + local_batch_size * t_parallel_param_.local_head_num_, + computeType_, + cublasAlgo)); + + transpose_kernelLauncher(attn_out, + attn_trans_out, + local_batch_size, + seq_len, + t_parallel_param_.local_head_num_, + size_per_head_, + param_.stream); - void decoder_norm2(const DataType_ *from_tensor, const DataType_ *gamma, - const DataType_ *beta, const DataType_ *bias, - DataType_ *output, DataType_ *norm_output_buf_, - const int m, const int n); + { + const int k = t_parallel_param_.local_hidden_units_; + const int n = hidden_units_; + + cublasMM_cublasLtMM_wrapper_decoder(param_.cublaslt_handle, + param_.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + param_.self_attention.attention_output_weight.kernel, AType_, n, + attn_out, BType_, k, + &beta, + decoder_output, CType_, n, + param_.stream, cublasAlgoMap_, + cublas_workspace_); + + PUSH_RANGE("Transformer/slf_attn/all2all_reduce") + all2all_reduce_sum(decoder_output, decoder_output, m*n, + t_parallel_param_, param_.stream); + POP_RANGE + } + } - void add_bias_input(DataType_ *output, const DataType_ *input, const int m, const int n); + int getCacheFormat() + { + int x = (Traits_::OpType == OperationType::FP32)? 4 : 8; + return (USE_CACHE_BATCH_MAJOR_ATTENTION == 1 && size_per_head_ % x == 0)? x : 0; + } ~OpenDecoder() { @@ -355,5 +954,11 @@ class OpenDecoder norm_cross_output_buf_ = nullptr; ffn_inner_buf_ = nullptr; } + + inline void set_local_batch_size(int local_batch) + { + l_parallel_param_.local_batch_size = local_batch; + } }; + } //namespace fastertransformer diff --git a/fastertransformer/standard_encoder.h b/fastertransformer/standard_encoder.h new file mode 100644 index 000000000..98aa1430e --- /dev/null +++ b/fastertransformer/standard_encoder.h @@ -0,0 +1,787 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+/**
+ * Standard Encoder transformer
+ **/
+
+#pragma once
+
+#include <cuda_runtime.h>
+#include "fastertransformer/utils/allocator.h"
+#include "fastertransformer/utils/common_structure.h"
+#include "fastertransformer/cuda/cuda_kernels.h"
+#include "fastertransformer/cuda/cuda_int8_kernels.h"
+#include "fastertransformer/cuda/open_attention.h"
+#include "fastertransformer/gemm_test/encoder_gemm_func.h"
+#include "fastertransformer/gemm_test/encoder_igemm_func.h"
+#include "fastertransformer/utils/functions.h"
+
+namespace fastertransformer
+{
+
+template <typename T>
+class EncoderInitParam
+{
+public:
+  const T *from_tensor = nullptr;
+  const T *to_tensor = nullptr;
+
+  LayerNormWeight<T> input_layernorm;
+  AttentionWeight<T> self_attention;
+  const T *attr_mask = nullptr;
+  LayerNormWeight<T> self_layernorm;
+
+  FFNWeight<T> ffn;
+
+  T *transformer_out;
+  cublasHandle_t cublas_handle = nullptr;
+  cublasLtHandle_t cublaslt_handle = nullptr;
+  cudaStream_t stream = 0;
+
+  const int* sequence_id_offset = nullptr;
+  int valid_word_num = -1;
+  int layer_idx = 0;
+  int layer_num = 12;
+
+  //Part 1:
+  // First 80 are for activation amaxs. For each activation amax, there are 4 values: amax, amax/127.0f, amax/127.0f/127.0f, 127.0f/amax -- input_amax 0-3 , Q_aftergemm_amax 4-7, Qbias_amax 8-11, K_aftergemm_amax 12-15, Kbias_amax 16-19, V_aftergemm_amax 20-23, Vbias_amax 24-27, bmm1_amax 28-31, Softmax_amax 32-35, bmm2_amax 36-39, Proj_aftergemm_scale 40-43, ProjBiasNorm_amax 44-47, FC1_aftergemm_amax 48-51, F1Bias_amax 52-55, FC2_aftergemm_amax 56-59, F2BiasNorm_amax 60-63, reserve 64-79
+  //Part 2:
+  // Kernel amaxs, for each kernel amax list, there are output_channel values : query_weight_amax_list, key_weight_amax_list, value_weight_amax_list, proj_weight_amax_list, FC1_weight_amax_list, FC2_weight_amax_list
+  //Part 3:
+  // Int8 gemm deQFactor list (8 values): Q_deQ_scale, K_deQ_scale, V_deQ_scale, bmm1_deQ_scale, bmm2_deQ_scale, FC0_deQ_scale, FC1_deQ_scale, FC2_deQ_scale
+  //Part 4:
+  // Amax used in trt fused mha kernel (3 values) : QKVbias_amax, Softmax_amax, bmm2_amax
+  const float *amaxList = nullptr;
+  const int* trt_seqlen_offset = nullptr;
+  int trt_seqlen_size = -1;
+};
+
+template <OperationType OpType_, template <OperationType> class MultiHeadAttention_>
+class OpenEncoderTraits;
+
+template