From 8666bc1278ead62a130809653701cef47ec3356a Mon Sep 17 00:00:00 2001 From: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Date: Fri, 10 Nov 2023 04:44:22 -0800 Subject: [PATCH 1/2] Update TensorRT-LLM --- README.md | 1 + benchmarks/python/mem_monitor.py | 3 +- cpp/CMakeLists.txt | 47 +- cpp/cmake/modules/parse_make_options.cmake | 28 + .../tensorrt_llm/batch_manager/GptManager.h | 1 + .../batch_manager/kvCacheConfig.h | 7 +- .../batch_manager/kvCacheManager.h | 5 +- .../tensorrt_llm/runtime/decodingInput.h | 4 +- .../tensorrt_llm/runtime/generationInput.h | 2 +- .../tensorrt_llm/runtime/gptDecoderBatch.h | 5 +- cpp/include/tensorrt_llm/runtime/gptSession.h | 7 +- .../runtime/iStatefulGptDecoder.h | 4 +- cpp/include/tensorrt_llm/runtime/iTensor.h | 44 +- .../tensorrt_llm/runtime/promptTuningParams.h | 2 +- cpp/tensorrt_llm/CMakeLists.txt | 8 - .../libtensorrt_llm_batch_manager_static.a | 3 - ...sorrt_llm_batch_manager_static.pre_cxx11.a | 3 - .../aarch64-linux-gnu/version.txt | 3 - .../libtensorrt_llm_batch_manager_static.a | 4 +- ...sorrt_llm_batch_manager_static.pre_cxx11.a | 4 +- .../x86_64-linux-gnu/version.txt | 4 +- .../fmhaRunner.cpp | 20 +- .../fmhaRunner.h | 9 +- .../fused_multihead_attention_common.h | 3 +- .../fused_multihead_attention_v2.h | 1 + .../kernels/decoderMaskedMultiheadAttention.h | 11 +- .../decoderMaskedMultiheadAttentionLaunch.h | 184 +++--- .../decoderMaskedMultiheadAttentionTemplate.h | 166 +++-- cpp/tensorrt_llm/kernels/gptKernels.cu | 18 +- cpp/tensorrt_llm/kernels/gptKernels.h | 3 + .../kernels/preQuantScaleKernel.cu | 34 +- .../kernels/preQuantScaleKernel.h | 4 + .../kernels/unfusedAttentionKernels.cu | 22 +- .../kernels/unfusedAttentionKernels.h | 13 +- .../kernels/unfusedAttentionKernels_2.cu | 137 ++-- .../kernels/weightOnlyBatchedGemv/common.h | 82 ++- .../kernels/weightOnlyBatchedGemv/kernel.h | 202 ++++-- .../weightOnlyBatchedGemv/kernelLauncher.cu | 84 +-- .../weightOnlyBatchedGemv/kernelLauncher.h | 3 +- .../weightOnlyBatchedGemvBs1Int4b.cu | 56 +- .../weightOnlyBatchedGemvBs1Int8b.cu | 56 +- .../weightOnlyBatchedGemvBs2Int4b.cu | 56 +- .../weightOnlyBatchedGemvBs2Int8b.cu | 56 +- .../weightOnlyBatchedGemvBs3Int4b.cu | 56 +- .../weightOnlyBatchedGemvBs3Int8b.cu | 56 +- .../weightOnlyBatchedGemvBs4Int4b.cu | 56 +- .../weightOnlyBatchedGemvBs4Int8b.cu | 56 +- .../layers/baseBeamSearchLayer.cu | 24 +- cpp/tensorrt_llm/layers/baseBeamSearchLayer.h | 6 +- .../layers/dynamicDecodeLayer.cpp | 3 +- cpp/tensorrt_llm/layers/dynamicDecodeLayer.h | 5 +- .../bertAttentionPlugin.cpp | 2 +- .../gptAttentionCommon/gptAttentionCommon.cpp | 95 +-- .../gptAttentionCommon/gptAttentionCommon.h | 16 +- .../gptAttentionPlugin/gptAttentionPlugin.cpp | 28 +- .../gptAttentionPlugin/gptAttentionPlugin.h | 30 +- .../weightOnlyGroupwiseQuantMatmulPlugin.cpp | 72 +- .../weightOnlyQuantMatmulPlugin.cpp | 122 ++-- cpp/tensorrt_llm/pybind/bindings.cpp | 6 +- .../pybind/runtime/generationInput.cpp | 4 +- cpp/tensorrt_llm/runtime/gptDecoder.cpp | 4 +- cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp | 29 +- cpp/tensorrt_llm/runtime/gptSession.cpp | 37 +- cpp/tensorrt_llm/runtime/iTensor.cpp | 34 +- .../runtime/promptTuningParams.cpp | 2 +- cpp/tensorrt_llm/runtime/runtimeBuffers.cpp | 40 +- cpp/tensorrt_llm/runtime/runtimeBuffers.h | 9 +- .../runtime/statefulGptDecoder.cpp | 17 +- cpp/tensorrt_llm/runtime/statefulGptDecoder.h | 7 +- cpp/tensorrt_llm/runtime/tensorView.h | 10 +- cpp/tensorrt_llm/runtime/tllmBuffers.h | 39 +- cpp/tensorrt_llm/runtime/torchView.h | 20 +- cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp | 18 +- cpp/tensorrt_llm/thop/dynamicDecodeOp.h | 8 +- cpp/tests/CMakeLists.txt | 7 +- .../weightOnly/weightOnlyKernelTest.cpp | 429 ++++++++++++ cpp/tests/runtime/gptDecoderBatchTest.cpp | 8 +- cpp/tests/runtime/gptDecoderTest.cpp | 2 +- cpp/tests/runtime/gptSessionTest.cpp | 1 + cpp/tests/runtime/iTensorTest.cpp | 96 ++- cpp/tests/runtime/samplingTest.cpp | 2 +- cpp/tests/runtime/tllmBuffersTest.cpp | 104 +-- docs/source/gpt_attention.md | 28 + docs/source/gpt_runtime.md | 58 +- examples/baichuan/build.py | 30 +- examples/baichuan/run.py | 10 +- examples/baichuan/summarize.py | 16 +- examples/baichuan/weight.py | 7 + examples/bloom/build.py | 20 +- examples/bloom/run.py | 9 +- examples/bloom/summarize.py | 9 +- examples/chatglm/build.py | 14 + examples/chatglm/summarize.py | 9 +- examples/enc_dec/build.py | 10 +- examples/falcon/build.py | 31 + examples/falcon/run.py | 9 +- examples/falcon/summarize.py | 9 +- examples/gpt/build.py | 34 +- examples/gpt/run.py | 10 +- examples/gpt/summarize.py | 13 +- examples/gpt/weight.py | 29 +- examples/gptj/README.md | 3 +- examples/gptj/build.py | 32 + examples/gptj/run.py | 10 +- examples/gptj/summarize.py | 9 +- examples/gptneox/build.py | 21 +- examples/gptneox/run.py | 10 +- examples/gptneox/summarize.py | 9 +- examples/internlm/build.py | 10 + examples/llama/README.md | 27 + examples/llama/build.py | 35 +- examples/llama/requirements.txt | 2 +- examples/llama/run.py | 14 +- examples/llama/summarize.py | 9 +- examples/llama/summarize_long.py | 304 +++++++++ examples/llama/weight.py | 617 ++++++++---------- examples/mpt/build.py | 10 + examples/opt/build.py | 20 +- examples/opt/summarize.py | 9 +- scripts/build_wheel.py | 1 + tensorrt_llm/_utils.py | 6 + tensorrt_llm/functional.py | 20 +- tensorrt_llm/layers/attention.py | 53 +- tensorrt_llm/models/baichuan/model.py | 20 +- tensorrt_llm/models/bloom/model.py | 10 +- tensorrt_llm/models/chatglm/model.py | 24 +- tensorrt_llm/models/falcon/model.py | 11 +- tensorrt_llm/models/generation_mixin.py | 15 + tensorrt_llm/models/gpt/model.py | 11 +- tensorrt_llm/models/gptj/model.py | 11 +- tensorrt_llm/models/gptneox/model.py | 10 +- tensorrt_llm/models/llama/model.py | 11 +- tensorrt_llm/models/opt/model.py | 10 +- tensorrt_llm/models/quantized/ammo.py | 19 +- tensorrt_llm/parameter.py | 10 +- tensorrt_llm/plugin/plugin.py | 6 + tensorrt_llm/profiler.py | 57 +- tensorrt_llm/quantization/layers.py | 17 +- tensorrt_llm/runtime/generation.py | 101 ++- tensorrt_llm/runtime/kv_cache_manager.py | 10 +- tensorrt_llm/runtime/session.py | 15 +- tests/attention/test_gpt_attention.py | 44 +- tests/attention/test_gpt_attention_IFB.py | 37 +- tests/model/test_bloom.py | 10 + tests/model/test_falcon.py | 7 + tests/model/test_gpt.py | 18 +- tests/model/test_gptj.py | 11 + tests/model/test_gptneox.py | 6 + tests/model/test_llama.py | 6 + tests/model/test_mistral.py | 577 ++++++++++++++++ tests/test_graph_rewriter.py | 9 +- tests/test_kv_cache_manager.py | 2 + tests/test_layer.py | 12 + 153 files changed, 3853 insertions(+), 1809 deletions(-) create mode 100644 cpp/cmake/modules/parse_make_options.cmake delete mode 100644 cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a delete mode 100644 cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a delete mode 100644 cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt create mode 100644 cpp/tests/kernels/weightOnly/weightOnlyKernelTest.cpp create mode 100644 examples/llama/summarize_long.py create mode 100644 tests/model/test_mistral.py diff --git a/README.md b/README.md index 6319bbfa4..b422987ea 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ TensorRT-LLM [![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/) [![cuda](https://img.shields.io/badge/cuda-12.2-green)](https://developer.nvidia.com/cuda-downloads) [![trt](https://img.shields.io/badge/TRT-9.1-green)](https://developer.nvidia.com/tensorrt) +[![version](https://img.shields.io/badge/release-0.5.0-green)](./setup.py) [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE) [Architecture](./docs/source/architecture.md)   |   [Results](./docs/source/performance.md)   |   [Examples](./examples/)   |   [Documentation](./docs/source/) diff --git a/benchmarks/python/mem_monitor.py b/benchmarks/python/mem_monitor.py index eb9353d5a..132e23c5b 100644 --- a/benchmarks/python/mem_monitor.py +++ b/benchmarks/python/mem_monitor.py @@ -18,7 +18,8 @@ def get_memory_info(handle): - mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle, + version=pynvml.nvmlMemory_v2) total = round(mem_info.total / 1024 / 1024 / 1024, 2) used = round(mem_info.used / 1024 / 1024 / 1024, 2) free = round(mem_info.used / 1024 / 1024 / 1024, 2) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index ad30374a1..98ad371c7 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -22,6 +22,7 @@ include(CheckLanguage) include(cmake/modules/set_ifndef.cmake) include(cmake/modules/find_library_create_target.cmake) include(cmake/modules/resolve_dirs.cmake) +include(cmake/modules/parse_make_options.cmake) project(tensorrt_llm LANGUAGES CXX) @@ -246,6 +247,22 @@ endif() set(COMMON_HEADER_DIRS ${PROJECT_SOURCE_DIR} ${CUDAToolkit_INCLUDE_DIR}) message(STATUS "COMMON_HEADER_DIRS: ${COMMON_HEADER_DIRS}") +if(NOT WIN32 AND NOT DEFINED USE_CXX11_ABI) + find_package(Python3 COMPONENTS Interpreter Development REQUIRED) + execute_process( + COMMAND ${Python3_EXECUTABLE} "-c" + "import torch; print(torch.compiled_with_cxx11_abi(),end='');" + RESULT_VARIABLE _PYTHON_SUCCESS + OUTPUT_VARIABLE USE_CXX11_ABI) + # Convert the bool variable to integer. + if(USE_CXX11_ABI) + set(USE_CXX11_ABI 1) + else() + set(USE_CXX11_ABI 0) + endif() + message(STATUS "USE_CXX11_ABI is set by python Torch to ${USE_CXX11_ABI}") +endif() + if(BUILD_PYT) # Build TORCH_CUDA_ARCH_LIST set(TORCH_CUDA_ARCH_LIST "") @@ -304,27 +321,39 @@ print(os.path.dirname(torch.__file__),end='');" message(STATUS "TORCH_CXX_FLAGS: ${TORCH_CXX_FLAGS}") add_compile_options(${TORCH_CXX_FLAGS}) add_compile_definitions(TORCH_CUDA=1) + + if(DEFINED USE_CXX11_ABI) + parse_make_options(${TORCH_CXX_FLAGS} "TORCH_CXX_FLAGS") + if(DEFINED TORCH_CXX_FLAGS__GLIBCXX_USE_CXX11_ABI + AND NOT ${TORCH_CXX_FLAGS__GLIBCXX_USE_CXX11_ABI} EQUAL ${USE_CXX11_ABI}) + message( + WARNING + "The libtorch compilation options _GLIBCXX_USE_CXX11_ABI=${TORCH_CXX_FLAGS__GLIBCXX_USE_CXX11_ABI} " + "found by CMake conflict with the project setting USE_CXX11_ABI=${USE_CXX11_ABI}, and the project " + "setting will be discarded.") + endif() + endif() + +elseif(NOT WIN32) + if(NOT USE_CXX11_ABI) + add_compile_options("-D_GLIBCXX_USE_CXX11_ABI=0") + endif() + message(STATUS "Build without PyTorch, USE_CXX11_ABI=${USE_CXX11_ABI}") endif() file(STRINGS "${TRT_INCLUDE_DIR}/NvInferVersion.h" VERSION_STRINGS REGEX "#define NV_TENSORRT_.*") foreach(TYPE MAJOR MINOR PATCH BUILD) - string(REGEX MATCH "NV_TENSORRT_${TYPE} [0-9]" TRT_TYPE_STRING - ${VERSION_STRINGS}) - string(REGEX MATCH "[0-9]" TRT_${TYPE} ${TRT_TYPE_STRING}) -endforeach(TYPE) - -foreach(TYPE MAJOR MINOR PATCH) - string(REGEX MATCH "NV_TENSORRT_SONAME_${TYPE} [0-9]" TRT_TYPE_STRING + string(REGEX MATCH "NV_TENSORRT_${TYPE} [0-9]+" TRT_TYPE_STRING ${VERSION_STRINGS}) - string(REGEX MATCH "[0-9]" TRT_SO_${TYPE} ${TRT_TYPE_STRING}) + string(REGEX MATCH "[0-9]+" TRT_${TYPE} ${TRT_TYPE_STRING}) endforeach(TYPE) set(TRT_VERSION "${TRT_MAJOR}.${TRT_MINOR}.${TRT_PATCH}" CACHE STRING "TensorRT project version") set(TRT_SOVERSION - "${TRT_SO_MAJOR}" + "${TRT_MAJOR}" CACHE STRING "TensorRT library so version") message( STATUS diff --git a/cpp/cmake/modules/parse_make_options.cmake b/cpp/cmake/modules/parse_make_options.cmake new file mode 100644 index 000000000..7c0240afe --- /dev/null +++ b/cpp/cmake/modules/parse_make_options.cmake @@ -0,0 +1,28 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & +# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# + +function(parse_make_options options result) + foreach(option ${options}) + string(REGEX REPLACE "(-D|-)" "" option ${option}) + string(REPLACE "=" ";" option ${option}) + list(GET option 0 option_name) + list(GET option 1 option_value) + set(${result}_${option_name} + ${option_value} + PARENT_SCOPE) + endforeach() +endfunction() diff --git a/cpp/include/tensorrt_llm/batch_manager/GptManager.h b/cpp/include/tensorrt_llm/batch_manager/GptManager.h index 89d799445..945095fe6 100644 --- a/cpp/include/tensorrt_llm/batch_manager/GptManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/GptManager.h @@ -86,6 +86,7 @@ class GptManager std::shared_ptr mTrtGptModel; SizeType mMaxInputLen; SizeType mMaxOutputLen; + SizeType mMaxKvCacheLen; SizeType mMaxNumSequences; std::optional mTerminateReqId; diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h index 2d159d7b4..ba8d41270 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h @@ -29,14 +29,17 @@ class KvCacheConfig public: using SizeType = tensorrt_llm::runtime::SizeType; - explicit KvCacheConfig( - std::optional maxTokens = std::nullopt, std::optional freeGpuMemoryFraction = std::nullopt) + explicit KvCacheConfig(std::optional maxTokens = std::nullopt, + std::optional maxKvCacheLength = std::nullopt, + std::optional freeGpuMemoryFraction = std::nullopt) : maxTokens{maxTokens} + , maxKvCacheLength{maxKvCacheLength} , freeGpuMemoryFraction{freeGpuMemoryFraction} { } std::optional maxTokens; + std::optional maxKvCacheLength; std::optional freeGpuMemoryFraction; static constexpr auto kDefaultGpuMemFraction = 0.85f; diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 690d337ff..6bccd129b 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -217,7 +217,7 @@ class KVCacheManager KVCacheManager(SizeType numLayers, SizeType numHeads, SizeType numKvHeads, SizeType hiddenSize, SizeType tokensPerBlock, SizeType maxNumBlocks, SizeType maxBatchSize, SizeType maxBeamWidth, - SizeType maxBlocksPerSeq, nvinfer1::DataType dtype, CudaStreamPtr stream); + SizeType maxBlocksPerSeq, SizeType maxKvCacheLength, nvinfer1::DataType dtype, CudaStreamPtr stream); void startScheduling(); @@ -330,6 +330,9 @@ class KVCacheManager SizeType mMaxBeamWidth; // Maximum number of blocks per sequence SizeType mMaxBlocksPerSeq; + // Maximum kv cache length per sequence + // Enable cyclic kv cache when it exceeds + SizeType mMaxKvCacheLength; // Pools std::vector mPools; // Block manager diff --git a/cpp/include/tensorrt_llm/runtime/decodingInput.h b/cpp/include/tensorrt_llm/runtime/decodingInput.h index 405f1705d..bcee6c037 100644 --- a/cpp/include/tensorrt_llm/runtime/decodingInput.h +++ b/cpp/include/tensorrt_llm/runtime/decodingInput.h @@ -29,9 +29,10 @@ class DecodingInput public: using TensorPtr = std::shared_ptr; - DecodingInput(SizeType maxLength, SizeType batchSize, TensorPtr logits, TensorPtr endIds) + DecodingInput(SizeType maxLength, SizeType maxKvCacheLength, SizeType batchSize, TensorPtr logits, TensorPtr endIds) : step{maxLength} , maxLength{maxLength} + , maxKvCacheLength{maxKvCacheLength} , batchSize{batchSize} , logits{std::move(logits)} , endIds{std::move(endIds)} @@ -43,6 +44,7 @@ class DecodingInput // mandatory parameters SizeType step; SizeType maxLength; + SizeType maxKvCacheLength; SizeType batchSize; TensorPtr logits; // [batchSize, beamWidth, vocabSizePadded], on gpu TensorPtr endIds; // [batchSize * beamWidth], on gpu diff --git a/cpp/include/tensorrt_llm/runtime/generationInput.h b/cpp/include/tensorrt_llm/runtime/generationInput.h index 840bc247a..8ca5bda05 100644 --- a/cpp/include/tensorrt_llm/runtime/generationInput.h +++ b/cpp/include/tensorrt_llm/runtime/generationInput.h @@ -54,7 +54,7 @@ class GenericGenerationInput bool packed; // indicates if ids are packed or padded to maxInputLength // optional parameters - TensorPtr embeddingBiasOpt; // [vocabSizePadded], on gpu + TensorPtr embeddingBias; // [vocabSizePadded], on gpu TensorPtr badWordsList; // [2, badWordsLength] or [batchSize, 2, badWordsLength], on gpu TensorPtr stopWordsList; // [batchSize, 2, stopWordsLength], on gpu std::optional maxNewTokens; // max number of tokens to generate diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatch.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatch.h index b28304877..7ea774880 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatch.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatch.h @@ -44,8 +44,8 @@ class GptDecoderBatch : public IGptDecoderBatch GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream); //! Setup the decoder before calling `forward()` - void setup( - SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, nvinfer1::DataType dtype) override; + void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength, + nvinfer1::DataType dtype) override; //! @brief Initialize the decoder at `batchIdx` with a new `request`. void newRequest( @@ -166,6 +166,7 @@ class GptDecoderBatch : public IGptDecoderBatch std::vector mMaxNewTokens; std::vector mBeamWidths; SizeType mMaxSequenceLength{}; + SizeType mMaxKvCacheLength{}; SizeType mActualBatchSize{}; }; } // namespace tensorrt_llm::runtime diff --git a/cpp/include/tensorrt_llm/runtime/gptSession.h b/cpp/include/tensorrt_llm/runtime/gptSession.h index fc490e3d9..f6ef79145 100644 --- a/cpp/include/tensorrt_llm/runtime/gptSession.h +++ b/cpp/include/tensorrt_llm/runtime/gptSession.h @@ -140,10 +140,10 @@ class GptSession void createContexts(SizeType numBatchesCtx, SizeType numBatchesGen, bool useCudaGraphs); void createBuffers(SizeType numMicroBatches); - void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength, + void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches); - void createKvCacheManager( - SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength, KvCacheConfig const& config); + void createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength, + SizeType maxSequenceLength, KvCacheConfig const& config); void createCustomAllReduceWorkspace(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength); void executeContextStep(std::vector const& microBatches, @@ -258,6 +258,7 @@ class GptSession std::vector> mIpcMemoryHandles; SizeType mDecoderMaxSequenceLength{}; + SizeType mDecoderMaxKvCacheLength{}; LoggerPtr mLogger; std::shared_ptr mRuntime; diff --git a/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h b/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h index 0f046b726..5f6697173 100644 --- a/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h +++ b/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h @@ -73,8 +73,8 @@ class IStatefulGptDecoder using TensorPtr = std::shared_ptr; //! Setup the decoder before calling `forward()`, also calls reshapeBuffers - virtual void setup( - SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, nvinfer1::DataType dtype) + virtual void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, + SizeType maxSequenceLength, nvinfer1::DataType dtype) = 0; //! @brief Initialize the decoder with new batch of inputs. diff --git a/cpp/include/tensorrt_llm/runtime/iTensor.h b/cpp/include/tensorrt_llm/runtime/iTensor.h index a5847b4ef..fcd415549 100644 --- a/cpp/include/tensorrt_llm/runtime/iTensor.h +++ b/cpp/include/tensorrt_llm/runtime/iTensor.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -48,6 +49,9 @@ class ITensor : virtual public IBuffer using UniqueConstPtr = std::unique_ptr; using SharedConstPtr = std::shared_ptr; using Shape = nvinfer1::Dims; + using DimType = std::remove_reference_t; + + ~ITensor() override = default; //! //! \brief Returns the tensor dimensions. @@ -59,7 +63,13 @@ class ITensor : virtual public IBuffer //! virtual void reshape(Shape const& dims) = 0; - ~ITensor() override = default; + void resize(std::size_t newSize) override + { + if (newSize == getSize()) + return; + + reshape(makeShape({castSize(newSize)})); + } //! //! \brief Not allowed to copy. @@ -101,18 +111,7 @@ class ITensor : virtual public IBuffer //! \param dim The dimension that should be removed ("squeezed"). //! \return A new shape without the unit dimension. //! - static Shape squeeze(Shape const& shape, SizeType dim) - { - TLLM_CHECK_WITH_INFO(shape.nbDims > 0, "Cannot squeeze 1-dimensional tensor"); - TLLM_CHECK_WITH_INFO( - dim < shape.nbDims, common::fmtstr("Invalid index %d, tensor has %d dimensions", dim, shape.nbDims)); - TLLM_CHECK_WITH_INFO(shape.d[dim] == 1, "Can only squeeze dimension of size 1"); - - Shape newDims{shape.nbDims - 1}; - std::copy(shape.d, shape.d + dim, newDims.d); - std::copy(shape.d + dim + 1, shape.d + shape.nbDims, newDims.d + dim); - return newDims; - } + static Shape squeeze(Shape const& shape, SizeType dim); //! //! \brief Add a *unit* dimension to `shape` at the specified position. @@ -121,17 +120,7 @@ class ITensor : virtual public IBuffer //! \param dim The dimension where unit dimension should be added. //! \return A new shape with the added unit dimension. //! - static Shape unsqueeze(Shape const& shape, SizeType dim) - { - TLLM_CHECK_WITH_INFO(dim <= shape.nbDims && dim >= 0, - common::fmtstr("Invalid dim %d, tensor has %d dimensions", dim, shape.nbDims)); - - Shape newDims{shape.nbDims + 1}; - std::copy(shape.d, shape.d + dim, newDims.d); - newDims.d[dim] = 1; - std::copy(shape.d + dim, shape.d + shape.nbDims, newDims.d + dim + 1); - return newDims; - } + static Shape unsqueeze(Shape const& shape, SizeType dim); //! //! \brief Removes the given *unit* dimensions from this tensor. @@ -251,6 +240,13 @@ class ITensor : virtual public IBuffer protected: ITensor() = default; + + static DimType castSize(size_t newSize) + { + TLLM_CHECK_WITH_INFO( + newSize <= std::numeric_limits::max(), "New size is too large. Use reshape() instead."); + return static_cast(newSize); + } }; //! \brief Utility function to print a shape. diff --git a/cpp/include/tensorrt_llm/runtime/promptTuningParams.h b/cpp/include/tensorrt_llm/runtime/promptTuningParams.h index 3690165f5..99c233b84 100644 --- a/cpp/include/tensorrt_llm/runtime/promptTuningParams.h +++ b/cpp/include/tensorrt_llm/runtime/promptTuningParams.h @@ -71,7 +71,7 @@ class PromptTuningParams : public GenericPromptTuningParams // Function assumes that the first numContextRequests requests in the batch are context requests void fillTasksTensor(TensorPtr tasksHost, const SizeType batchSize, const SizeType numContextRequests, const std::vector& reqBeamWidths, const std::vector& reqPromptLengths, - BufferManager& manager, bool packedInput); + BufferManager const& manager, bool packedInput); }; } // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/CMakeLists.txt b/cpp/tensorrt_llm/CMakeLists.txt index 2b3796033..0720649d2 100644 --- a/cpp/tensorrt_llm/CMakeLists.txt +++ b/cpp/tensorrt_llm/CMakeLists.txt @@ -84,14 +84,6 @@ if(BUILD_BATCH_MANAGER) else() add_library(${BATCH_MANAGER_TARGET} STATIC IMPORTED) if(NOT WIN32) # Linux - execute_process( - COMMAND ${Python3_EXECUTABLE} "-c" - "import torch; print(torch.compiled_with_cxx11_abi(),end='');" - RESULT_VARIABLE _PYTHON_SUCCESS - OUTPUT_VARIABLE USE_CXX11_ABI) - - message(STATUS "USE_CXX11_ABI: ${USE_CXX11_ABI}") - if(USE_CXX11_ABI) set(BATCH_MANAGER_LIB_LOC "${CMAKE_CURRENT_SOURCE_DIR}/batch_manager/${BATCH_MANAGER_TARGET_ARCH}/libtensorrt_llm_batch_manager_static.a" diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a deleted file mode 100644 index 6131fa3c3..000000000 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f591dd181613b14f7ded3ba3e167d14073564254bc46db8c4bd9636d6d896b16 -size 1611436 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a deleted file mode 100644 index 138428aad..000000000 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:21d17a9fa736d033ad77270a0fbcdd09c27dfab3f871d92a5ffa0cb744fa48fd -size 1623126 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt deleted file mode 100644 index 8b007588a..000000000 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt +++ /dev/null @@ -1,3 +0,0 @@ -e1dc326c0c45864b9e7963b4d92d322f libtensorrt_llm_batch_manager_static.a -d2e9d76efe6b4173270aa6b494dfe59c libtensorrt_llm_batch_manager_static.pre_cxx11.a -07363ea7a6fdd6eeedc1670dedeeaedff7f9a848 commit diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a index f30db7d14..e5ed0eca9 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:3fe444bf079ce35262b932302806b372ccb677182969e3bba45698343e5e350f -size 1523444 +oid sha256:abdce9bc64cecddb39ed14809eefc8bcf7164524a6dd20ec7c8167229f3c22a3 +size 1557782 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index 130b4932c..6ca65959a 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:99641389fdf26f6324b7465df0b61b74946787a6a147d145de23b444261e6e5f -size 1524188 +oid sha256:a9109b506e993a041ea238f992bec2a5064dffd9c0a7af10cca0d4d96c5047a9 +size 1557482 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt index 7bf295098..bc433dbfe 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -b10b0e00d0132b04969d779af45d73d0 libtensorrt_llm_batch_manager_static.a -3ad06255afdaa8450c133d1d1bc486c4 libtensorrt_llm_batch_manager_static.pre_cxx11.a +25d1ebdd5977208c25023329c621e970 libtensorrt_llm_batch_manager_static.a +5cb1a7a13db34fcaee6b89fcdc1212ce libtensorrt_llm_batch_manager_static.pre_cxx11.a diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp index b9607d2de..44d130cf3 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp @@ -99,8 +99,8 @@ class FusedMHARunnerV2::mhaImpl ~mhaImpl() {} - void setup(const int b, const int s, const int total_seqlen, const bool has_alibi, const bool scale_alibi, - const int tp_size, const int tp_rank) + void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, const bool has_alibi, + const bool scale_alibi, const int tp_size, const int tp_rank) { const float inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling)); // Note that we apply scales and bias in the order of @@ -170,11 +170,19 @@ class FusedMHARunnerV2::mhaImpl launch_params.use_tma = true; } + // alibi. if (has_alibi) { params.has_alibi = true; params.alibi_params = AlibiParams(mNumHeads, s, tp_size, tp_rank, scale_after_alibi); } + + // Sliding_window_causal mask. + if (s > sliding_window_size && launch_params.attention_mask_type == ContextAttentionMaskType::CAUSAL) + { + params.sliding_window_size = sliding_window_size; + launch_params.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL; + } } // NOTE: assume that heads_interleaved = false (b, s, 3, h, d), and sequences are padded/non-padded @@ -271,8 +279,6 @@ class FusedMHARunnerV2::mhaImpl { // BF16 FMHA only accumulates on FP32 launch_params.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc; - // sliding_window_causal is disabled temporally. - // TODO (perkzz): It will be enabled when the sliding window attention is fully supported. launch_params.attention_mask_type = causal_mask ? ContextAttentionMaskType::CAUSAL : ContextAttentionMaskType::PADDING; params.h_kv = num_kv_heads; @@ -360,10 +366,10 @@ FusedMHARunnerV2::FusedMHARunnerV2( FusedMHARunnerV2::~FusedMHARunnerV2() = default; -void FusedMHARunnerV2::setup(const int b, const int s, const int total_seqlen, const bool has_alibi, - const bool scale_alibi, const int tp_size, const int tp_rank) +void FusedMHARunnerV2::setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, + const bool has_alibi, const bool scale_alibi, const int tp_size, const int tp_rank) { - pimpl->setup(b, s, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank); + pimpl->setup(b, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank); } bool FusedMHARunnerV2::fmha_supported() diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h index 0e9c721b4..008d1cc64 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h @@ -47,8 +47,8 @@ class MHARunner virtual ~MHARunner() = default; - virtual void setup(const int b, const int s, const int total_seqlen, const bool has_alibi = false, - const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0) + virtual void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, + const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0) = 0; static bool fmha_supported(const int headSize, const int sm); @@ -80,8 +80,9 @@ class FusedMHARunnerV2 : public MHARunner ~FusedMHARunnerV2(); // for pimpl - void setup(const int b, const int s, const int total_seqlen, const bool has_alibi = false, - const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0) override; + void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, + const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, + const int tp_rank = 0) override; bool fmha_supported() override; diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h index 27dd70d36..7fdc2b011 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h @@ -49,8 +49,7 @@ enum class ContextAttentionMaskType { PADDING, CAUSAL, - // The past attention length is limited. - LIMITED_LENGTH_CAUSAL + SLIDING_WINDOW_CAUSAL }; constexpr int32_t kSM_70 = 70; diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.h index 226279959..91d52cf3c 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.h @@ -282,6 +282,7 @@ class FusedMultiHeadAttentionXMMAKernelV2 const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; const CUfunction func = findIter->second.mDeviceFunction; + void* kernelParams[] = {¶ms, nullptr}; if (!forceUnroll) diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h index 5a82561fe..ebba4f39b 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h @@ -102,9 +102,12 @@ struct Multihead_attention_params_base int batch_size = 0; // The beam width int beam_width = 0; - // The sequence length. - // TODO: change name max_seq_len - int memory_max_len = 0; + // By default, max_kv_cache_length == cyclic_kv_cache_length + // unless each layer has different cyclic kv cache length. + // Max cache capacity (used to allocate KV cache) + int max_kv_cache_length = 0; + // Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens) + int cyclic_kv_cache_length = 0; // The number of heads (H). int num_heads = 0; // Controls MHA/MQA/GQA @@ -148,7 +151,7 @@ struct Multihead_attention_params_base bool fp8_kv_cache = false; // Multi-block setups - bool multi_block_mode = false; + mutable bool multi_block_mode = false; // Number of streaming processors on the device. // Tune block size to maximum occupancy. diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h index 775ce9db9..abb0f5d83 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h @@ -44,8 +44,8 @@ inline size_t smem_size_in_bytes(const Multihead_attention_params::Type; // The amount of shared memory needed to store the Q*K^T values in float. const int max_timesteps = DO_CROSS_ATTENTION - ? params.memory_max_len - : min((DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep), params.memory_max_len); + ? params.cyclic_kv_cache_length + : min((DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep), params.cyclic_kv_cache_length); const auto qk_elts = static_cast(divUp(max_timesteps + 1, 4)); // explicit cast because of the sign const auto qk_sz = qk_elts * 16; @@ -90,29 +90,31 @@ inline size_t smem_size_in_bytes(const Multihead_attention_params -inline size_t multi_block_grid_setup(const Multihead_attention_params& params, - int threads_per_block, int tlength, bool do_multi_block) +inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params& params, + int blocks_per_sm, int block_size, int tlength, bool do_multi_block) { if (!do_multi_block) { - return 1; + params.multi_block_mode = false; + return; } - auto constexpr threads_per_value = mmha::threads_per_value(mmha::dh_max(Dh)); + params.seq_len_tile + = mmha::divUp(params.multi_processor_count * blocks_per_sm, params.batch_size * params.num_heads); - // Make sure: seq_len_tile * threads_per_value <= threads_per_block (for multi_block_mode) - params.seq_len_tile = std::floor(threads_per_block / threads_per_value); + const int threads_per_value = mmha::threads_per_value(mmha::dh_max(Dh)); + // Make sure that each block at least processes one loop of kv (unroll size is default at 8). + const int seq_len_per_kv_loop = mmha::divUp(block_size, threads_per_value) * 8; + const int max_seq_len_tile = std::min(mmha::divUp(tlength + 1, seq_len_per_kv_loop), params.max_seq_len_tile); - assert(params.seq_len_tile <= params.max_seq_len_tile); + params.seq_len_tile = std::min(params.seq_len_tile, max_seq_len_tile); - params.timesteps_per_block = mmha::divUp(tlength, params.seq_len_tile); + // We should consider the new timestep. + params.timesteps_per_block = mmha::divUp(tlength + 1, params.seq_len_tile); -#ifndef ENABLE_MULTI_BLOCK_OPTION - do_multi_block = false; -#endif + params.multi_block_mode = (params.seq_len_tile > 1); - // Return the sequence length tile if using multi block modes. - return params.seq_len_tile; + grid.z = params.seq_len_tile; } #define MMHA_LAUNCH_CHECK(DYNAMIC_THDS_PER_BLOCK) \ @@ -121,58 +123,58 @@ inline size_t multi_block_grid_setup(const Multihead_attention_params= 46 * 1024) \ { \ - cudaError_t res = cudaFuncSetAttribute(mmha::masked_multihead_attention_kernel, \ + cudaError_t res = cudaFuncSetAttribute( \ + mmha::masked_multihead_attention_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \ TLLM_CHECK_WITH_INFO( \ res == cudaSuccess, "Sequence Length is too long for the MMHA kernel (not enough shared memory)."); \ } \ TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&available_blocks, \ - mmha::masked_multihead_attention_kernel, \ + mmha::masked_multihead_attention_kernel, \ DYNAMIC_THDS_PER_BLOCK, dynamic_smem_sz)); -#define MMHA_KERNEL(DYNAMIC_THDS_PER_BLOCK) \ +#define MMHA_KERNEL(DYNAMIC_THDS_PER_BLOCK, ENABLE_MULTI_BLOCK) \ std::size_t const dynamic_smem_sz{ \ - mmha::smem_size_in_bytes(params, DYNAMIC_THDS_PER_BLOCK)}; \ + mmha::smem_size_in_bytes(params, DYNAMIC_THDS_PER_BLOCK)}; \ /* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */ \ if (dynamic_smem_sz >= 46 * 1024) \ { \ cudaError_t res = cudaFuncSetAttribute( \ mmha::masked_multihead_attention_kernel, \ + KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, ENABLE_MULTI_BLOCK>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \ TLLM_CHECK_WITH_INFO( \ res == cudaSuccess, "Sequence Length is too long for the MMHA kernel (not enough shared memory)."); \ } \ mmha::masked_multihead_attention_kernel \ + KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, ENABLE_MULTI_BLOCK> \ <<>>(params, kv_cache_buffer); // if resources are not enough to launch 512 threads per block, we will fallback to 256. -#define MMHA_LAUNCH_512_BLOCKSIZE() \ - int available_blocks = -1; \ +#define MMHA_512_BLOCKSIZE_CHECK() \ MMHA_LAUNCH_CHECK(512); \ if (available_blocks <= 0) \ { \ - MMHA_KERNEL(256); \ + MMHA_LAUNCH_CHECK(256); \ + dynamic_block_size = 256; \ } \ else \ { \ - MMHA_KERNEL(512); \ + dynamic_block_size = 512; \ } // if resources are not enough to launch 1024 threads per block, we will fallback to 512. -#define MMHA_LAUNCH_1024_BLOCKSIZE() \ - int available_blocks = -1; \ +#define MMHA_1024_BLOCKSIZE_CHECK() \ MMHA_LAUNCH_CHECK(1024); \ - if (available_blocks <= 0) \ + if (available_blocks > 0) \ { \ - MMHA_LAUNCH_512_BLOCKSIZE(); \ + dynamic_block_size = 1024; \ } \ else \ { \ - MMHA_KERNEL(1024); \ + MMHA_512_BLOCKSIZE_CHECK(); \ } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -182,55 +184,83 @@ template ( - params, THDS_PER_BLOCK, tlength, DO_MULTI_BLOCK)}; - dim3 grid{static_cast(params.num_heads), static_cast(params.batch_size), - static_cast(seq_len_tile)}; + dim3 grid{static_cast(params.num_heads), static_cast(params.batch_size), 1}; - if (DO_MULTI_BLOCK) + const int kernel_total_blocks = params.batch_size * params.num_heads; + // Don't tune the block size if batchxhead is large enough. + // The max number of warps we can launch per SM is 32 limited by registers. + if (kernel_total_blocks >= params.multi_processor_count * 4) { - MMHA_KERNEL(THDS_PER_BLOCK); + MMHA_KERNEL(THDS_PER_BLOCK, false); + return; } - else + + // Tune block size based on batchxhead to increase occupancy. + int num_blocks_per_sm = -1; + // Set 0 dynamic shared memory size as we need the number of available blocks limited by registers. + // Dynamic shared memory is fixed for different block size. + TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, + mmha::masked_multihead_attention_kernel, + THDS_PER_BLOCK, 0)); + + int block_size_factor + = min(mmha::divUp(params.multi_processor_count * num_blocks_per_sm, kernel_total_blocks), num_blocks_per_sm); + // Max block size is 1024. + int dynamic_block_size = min(THDS_PER_BLOCK * block_size_factor, 1024); + + // Check if resources are enough for launch. + int available_blocks = -1; + if (dynamic_block_size < 512) { - const int kernel_total_blocks = params.batch_size * params.num_heads; - // Don't tune the block size if batchxhead is large enough. - // The max number of warps we can launch per SM is 32 limited by registers. - if (kernel_total_blocks >= params.multi_processor_count * 4) + MMHA_LAUNCH_CHECK(256); + dynamic_block_size = 256; + } + else if (dynamic_block_size < 1024) + { + MMHA_512_BLOCKSIZE_CHECK(); + } + else if (dynamic_block_size == 1024) + { + MMHA_1024_BLOCKSIZE_CHECK(); + } + + // If blocks with larger block size already fill all SMs, then disable the multi blocks mode. + mmha::multi_block_grid_setup(grid, params, dynamic_block_size, available_blocks, tlength, DO_MULTI_BLOCK); + + // Launch kernels based on the valid block size. + switch (dynamic_block_size) + { + case 256: + if (params.multi_block_mode) { - MMHA_KERNEL(THDS_PER_BLOCK); - return; + MMHA_KERNEL(256, true); } - - // Tune block size based on batchxhead to increase occupancy. - int num_blocks_per_sm = -1; - // Set 0 dynamic shared memory size as we need the number of available blocks limited by registers. - // Dynamic shared memory is fixed for different block size. - TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, - mmha::masked_multihead_attention_kernel, - THDS_PER_BLOCK, 0)); - - int block_size_factor = min( - mmha::divUp(params.multi_processor_count * num_blocks_per_sm, kernel_total_blocks), num_blocks_per_sm); - // Max block size is 1024. - const int dynamic_block_size = min(THDS_PER_BLOCK * block_size_factor, 1024); - - // Make sure number of threads per block is power of 2. - if (dynamic_block_size <= 256) + else + { + MMHA_KERNEL(256, false); + } + break; + case 512: + if (params.multi_block_mode) + { + MMHA_KERNEL(512, true); + } + else { - MMHA_KERNEL(256); + MMHA_KERNEL(512, false); } - else if (dynamic_block_size <= 512) + break; + case 1024: + if (params.multi_block_mode) { - // Check if the kernel with new block size can be launched in terms of resources. - MMHA_LAUNCH_512_BLOCKSIZE(); + MMHA_KERNEL(1024, true); } - else if (dynamic_block_size <= 1024) + else { - // Check if the kernel with new block size can be launched in terms of resources. - MMHA_LAUNCH_1024_BLOCKSIZE(); + MMHA_KERNEL(1024, false); } + break; } } @@ -263,23 +293,15 @@ void mmha_launch_kernel_dispatch( const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream) { int const tlength = params.timestep; - if (tlength < 1024) + if (params.multi_block_mode) { - mmha_launch_kernel_dispatch_8bits_kv_cache( + mmha_launch_kernel_dispatch_8bits_kv_cache( params, kv_cache_buffer, stream, tlength); } else { - if (params.multi_block_mode) - { - mmha_launch_kernel_dispatch_8bits_kv_cache( - params, kv_cache_buffer, stream, tlength); - } - else - { - mmha_launch_kernel_dispatch_8bits_kv_cache( - params, kv_cache_buffer, stream, tlength); - } + mmha_launch_kernel_dispatch_8bits_kv_cache( + params, kv_cache_buffer, stream, tlength); } } diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h index ab7db6ff2..df47baa75 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h @@ -1244,10 +1244,10 @@ template < // The number of threads per value. unsigned THREADS_PER_VALUE = threads_per_value(dh_max(Dh)), // The unroll factor for loading from K cache. - unsigned K_LOOP_UNROLL = 8, - // The unroll factor for loading from V cache. // Set it default to 4 for higher occupancy (by reducing registers usage). - unsigned V_LOOP_UNROLL = 4> + unsigned K_LOOP_UNROLL = 4, + // The unroll factor for loading from V cache. + unsigned V_LOOP_UNROLL = 8> __global__ void masked_multihead_attention_kernel( Multihead_attention_params params, KVCacheBuffer kvCacheBuffer) { @@ -1271,10 +1271,11 @@ __global__ void masked_multihead_attention_kernel( static_assert(Dh_MAX >= WARP_SIZE); static_assert(Dh_MAX >= Dh); - // The maximum sequence length in the kv_cache, i.e., an upper bound on L. + // The maximum sequence length in the cyclic kv_cache, i.e., an upper bound on L. // Note that the maximum sequence length supported by the model might be greater than this. - const auto max_seq_len = static_cast(params.memory_max_len); - assert(max_seq_len > 0); + // Note max_kv_cache_length is maximum of cyclic_kv_cache_length among all layers. + // By default, you can assume that they are the same. + const auto cyclic_kv_cache_len = static_cast(params.cyclic_kv_cache_length); // The current timestep (including paddings). // It is only used to calculate the smem stride. const auto timestep = static_cast(DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep); @@ -1298,8 +1299,7 @@ __global__ void masked_multihead_attention_kernel( #ifndef MMHA_USE_FP32_ACCUM_FOR_LOGITS if (sizeof(Tk) != 4) { - // TODO - change to tlength - const auto max_timesteps = DO_CROSS_ATTENTION ? max_seq_len : min(timestep, max_seq_len); + const auto max_timesteps = DO_CROSS_ATTENTION ? cyclic_kv_cache_len : min(timestep, cyclic_kv_cache_len); logits_smem_ += divUp(max_timesteps + 1, 4u) * 16; } Tk* logits_smem = reinterpret_cast(logits_smem_); @@ -1345,6 +1345,7 @@ __global__ void masked_multihead_attention_kernel( // Use alignment for safely casting the shared buffers as Qk_vec_k and K_vec_k. // Shared memory to store Q inputs. __shared__ __align__(mmha::const_max(sizeof(Qk_vec_k), sizeof(K_vec_k))) Tk q_smem[Dh_MAX]; + __shared__ __align__(mmha::const_max(sizeof(Qk_vec_k), sizeof(K_vec_k))) Tk k_smem[Dh_MAX]; // Make sure the hidden dimension per head is a multiple of the number of threads per value. static_assert(Dh_MAX % THREADS_PER_VALUE == 0); // trivially satisfied since THREADS_PER_VALUE == Dh_MAX / p @@ -1420,8 +1421,17 @@ __global__ void masked_multihead_attention_kernel( const int tlength = DO_CROSS_ATTENTION ? params.memory_length_per_sample[batch_beam_idx] - 1 : (params.length_per_sample ? (params.length_per_sample[batch_beam_idx] - 1) : static_cast(timestep)); + // We will use cyclic kv cache when it exceeds the limit. + // The length position for storing new key and value. + const int cyclic_tlength = tlength % cyclic_kv_cache_len; + // The actual kv cache length. + // tlength is the past length actually. + const int kv_loop_length = min(tlength, cyclic_kv_cache_len); // The context length for beam searching optimization (all points to beam 0). - const int input_length = params.input_lengths[batch_beam_idx]; + // TODO: with cyclic kv cache, we set it 0 for now (will optimize in the future) + // as context kv cache might be overwritten by the new kv cache + const int beam0_context_length + = HAS_BEAMS && tlength > cyclic_kv_cache_len ? 0 : params.input_lengths[batch_beam_idx]; // The offset in the Q and K buffer also accounts for the batch. const auto qk_vec_idx = tidx * QK_VEC_SIZE; @@ -1474,8 +1484,8 @@ __global__ void masked_multihead_attention_kernel( if constexpr (DO_CROSS_ATTENTION) { const auto k_idx = QK_VEC_SIZE * tidx; - const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(tlength, hi, Dh, k_idx); - Tcache* k_cache = reinterpret_cast(kvCacheBuffer.getKBlockPtr(batch_beam_idx, tlength)); + const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi, Dh, k_idx); + Tcache* k_cache = reinterpret_cast(kvCacheBuffer.getKBlockPtr(batch_beam_idx, cyclic_tlength)); k = vec_conversion(*reinterpret_cast(&k_cache[inBlockIdx])); } @@ -1572,7 +1582,7 @@ __global__ void masked_multihead_attention_kernel( const bool do_rotary = is_valid_qk_vec && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; T* q_smem_ = reinterpret_cast(smem_); - T* k_smem = q_smem_ + params.rotary_embedding_dim; + T* k_smem_ = q_smem_ + params.rotary_embedding_dim; const int half_rotary_dim = params.rotary_embedding_dim / 2; const int half_idx = qk_vec_idx / half_rotary_dim; @@ -1586,7 +1596,7 @@ __global__ void masked_multihead_attention_kernel( *reinterpret_cast(q_smem_ + half_idx * smem_pitch + intra_half_idx) = q; if (HANDLE_KV) { - *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; + *reinterpret_cast(k_smem_ + half_idx * smem_pitch + intra_half_idx) = k; } } @@ -1599,12 +1609,12 @@ __global__ void masked_multihead_attention_kernel( mmha::vec_from_smem_transpose(q, q_smem_, transpose_idx, smem_pitch); if (HANDLE_KV) { - mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + mmha::vec_from_smem_transpose(k, k_smem_, transpose_idx, smem_pitch); mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, tlength); - mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + mmha::write_smem_transpose(k, k_smem_, transpose_idx, smem_pitch); } else { @@ -1621,7 +1631,7 @@ __global__ void masked_multihead_attention_kernel( q = *reinterpret_cast(q_smem_ + half_idx * smem_pitch + intra_half_idx); if (HANDLE_KV) { - k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); + k = *reinterpret_cast(k_smem_ + half_idx * smem_pitch + intra_half_idx); } } @@ -1631,7 +1641,7 @@ __global__ void masked_multihead_attention_kernel( } // For the same reason as HANDLE_KV, no compute needed in Cross-Attention's 1st step - + // Store Q K vectors to shared memory, and calculate QK. if (qk_vec_idx < Dh_MAX) { @@ -1658,31 +1668,10 @@ __global__ void masked_multihead_attention_kernel( reinterpret_cast(&q_smem[qk_vec_idx])[0] = is_valid_qk_vec ? q : zero_q; } - // Write the K values to the global memory cache. - // - // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory - // system. We designed it this way as it allows much better memory loads (and there are many - // more loads) + the stores are really "write and forget" since we won't need the ack before - // the end of the kernel. There's plenty of time for the transactions to complete. - - // For MQA/GQA mode, write only with the first Q head of each group per KV head. - if (HANDLE_KV && hi == (hi_kv * qhead_per_kv) && (IS_Dh_MAX || is_valid_qk_vec)) - { - // Trigger the stores to global memory. - const auto k_idx = QK_VEC_SIZE * tidx; - const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(tlength, hi_kv, Dh, k_idx); - // The base pointer for the value in the cache buffer. - Tcache* k_cache = reinterpret_cast(kvCacheBuffer.getKBlockPtr(batch_beam_idx, tlength)); - - if constexpr (ENABLE_8BITS_CACHE) - { - store_8bits_kv_cache_vec(reinterpret_cast(k_cache), k, inBlockIdx, kv_scale_orig_quant); - } - else - { - *reinterpret_cast(&k_cache[inBlockIdx]) = vec_conversion(k); - } - } + // Store the K values to shared memory. + // We store K values from shared memory to global memory + // when the target position of K cache in global memory has been accessed (in the case of cyclic kv cache) + reinterpret_cast(&k_smem[qk_vec_idx])[0] = k; // Compute \sum_i Q[i] * K^T[i] for the current timestep. qk = dot(q, k); @@ -1736,7 +1725,9 @@ __global__ void masked_multihead_attention_kernel( } else { - qk_smem[tlength] = qk; + // We need to store the qk result to the end of the qk_smem for cyclic kv cache (+ 1 for smem memory + // allocation) because the previous cache will still write to the new_cache_pos of qk_smem. + qk_smem[kv_loop_length] = qk; } } @@ -1778,17 +1769,22 @@ __global__ void masked_multihead_attention_kernel( // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). // Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible. - const int context_length = HAS_BEAMS ? input_length : tlength; + const int context_length = HAS_BEAMS ? beam0_context_length : kv_loop_length; const auto context_ti_end = MULTI_BLOCK_FLAG ? divUp(timesteps_per_block, UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP : divUp(static_cast(context_length), UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP; // The generation ti_end. - const auto generation_ti_end = MULTI_BLOCK_FLAG ? divUp(timesteps_per_block, K_PER_WARP) * K_PER_WARP - : divUp(static_cast(tlength), K_PER_WARP) * K_PER_WARP; + const auto generation_ti_end = MULTI_BLOCK_FLAG + ? divUp(timesteps_per_block, K_PER_WARP) * K_PER_WARP + : divUp(static_cast(kv_loop_length), K_PER_WARP) * K_PER_WARP; // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. - const auto bi_seq_len_offset = static_cast(batch_beam_idx) * max_seq_len; + // Note max_kv_cache_length is maximum of cyclic_kv_cache_length among all layers. + // By default, you can assume that they are the same. + const auto bi_seq_len_offset = static_cast(batch_beam_idx) * params.max_kv_cache_length; + // Beam indices are based on the max_kv_cache_length while each layer may have different cyclic_kv_cache_length + // So we need to rebuild the beam_indices if max_kv_cache_length is not equal to cyclic_kv_cache_length. const int* beam_indices = HAS_BEAMS ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; const auto c_tile_times_timesteps_per_block = c_tile * timesteps_per_block; // 0 if !MULTI_BLOCK_FLAG @@ -1940,11 +1936,11 @@ __global__ void masked_multihead_attention_kernel( } // Handle generation key cache with beam searching. - // Note that it may be overlapped with the context key loop, but it won't impact the correctness. - if (HAS_BEAMS && (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > input_length)) + // Note that it may be overlapped with the context key loop, but it won't impact the corretness. + if (HAS_BEAMS && (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > beam0_context_length)) { // The input length; - const int input_length_ = MULTI_BLOCK_FLAG ? input_length % timesteps_per_block : input_length; + const int input_length_ = MULTI_BLOCK_FLAG ? beam0_context_length % timesteps_per_block : beam0_context_length; // The beginning of the generation. const int generation_start_ti = k_idx.x + input_length_ / K_PER_WARP * K_PER_WARP; @@ -1960,7 +1956,7 @@ __global__ void masked_multihead_attention_kernel( for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i) { const int jj = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE); - const int valid_time_now = min(time_now, tlength - 1); + const int valid_time_now = min(time_now, kv_loop_length - 1); int beam_offset = beam_indices[valid_time_now]; const int seqIdx = batch_idx * beam_width + beam_offset; // Base pointer to k cache block for beam's batch, before offsetting with indirection buffer @@ -1971,7 +1967,7 @@ __global__ void masked_multihead_attention_kernel( } // Is it active? - const bool is_active = time_now >= input_length && time_now < tlength; + const bool is_active = time_now >= context_length && time_now < kv_loop_length; if (implicit_rel_attn_bias) { @@ -2092,6 +2088,34 @@ __global__ void masked_multihead_attention_kernel( // Make sure the products are in shared memory. __syncthreads(); + // After the syncthreads, the target k position (cyclic kv cache) should also have been used by the k loop. + // Write the K values to the global memory cache. + // + // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory + // system. We designed it this way as it allows much better memory loads (and there are many + // more loads) + the stores are really "write and forget" since we won't need the ack before + // the end of the kernel. There's plenty of time for the transactions to complete. + + // For MQA/GQA mode, write only with the first Q head of each group per KV head. + if (HANDLE_KV && hi == (hi_kv * qhead_per_kv) && qk_vec_idx < Dh) + { + // Trigger the stores to global memory. + Qk_vec_k k_vec = *reinterpret_cast(&k_smem[qk_vec_idx]); + const auto k_idx = QK_VEC_SIZE * tidx; + const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi_kv, Dh, k_idx); + // The base pointer for the value in the cache buffer. + Tcache* k_cache = reinterpret_cast(kvCacheBuffer.getKBlockPtr(batch_beam_idx, cyclic_tlength)); + + if constexpr (ENABLE_8BITS_CACHE) + { + store_8bits_kv_cache_vec(reinterpret_cast(k_cache), k_vec, inBlockIdx, kv_scale_orig_quant); + } + else + { + *reinterpret_cast(&k_cache[inBlockIdx]) = vec_conversion(k_vec); + } + } + // The warps finalize the reduction. qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; #pragma unroll @@ -2107,7 +2131,7 @@ __global__ void masked_multihead_attention_kernel( float sum = 0.f; // Each thread will handle one float (either qk_smem/logit). - const int logit_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : tlength; + const int logit_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length; for (int ti = tidx; ti <= logit_loop_end; ti += THREADS_PER_BLOCK) { @@ -2123,13 +2147,13 @@ __global__ void masked_multihead_attention_kernel( else { // Not supported yet: multi-block mode with FP8_MHA - if (time_now < tlength && ti != timesteps_per_block) + if (time_now < kv_loop_length && ti != timesteps_per_block) { float logit = __expf(qk_smem[ti] - qk_max); sum += logit; qk_smem[ti] = logit; } - else if (time_now == tlength) + else if (time_now == kv_loop_length) { float logit = __expf(qk_current_smem[0] - qk_max); sum += logit; @@ -2149,7 +2173,7 @@ __global__ void masked_multihead_attention_kernel( #endif // MMHA_FP8_SCALE_P_INSTEAD_OF_V float inv_sum = __fdividef(logit_scale, sum + 1.e-6f); - const int normlization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : tlength; + const int normlization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length; for (int ti = tidx; ti <= normlization_loop_end; ti += THREADS_PER_BLOCK) { const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti; @@ -2161,11 +2185,11 @@ __global__ void masked_multihead_attention_kernel( else { // no scaling factor inv_sum applied here, will apply the scaling factor after all blocks finished - if (time_now < tlength && ti != timesteps_per_block) + if (time_now < kv_loop_length && ti != timesteps_per_block) { convert_from_float(&logits_smem[ti], qk_smem[ti]); } - else if (time_now == tlength) + else if (time_now == kv_loop_length) { convert_from_float(&logits_current_smem[0], qk_current_smem[0]); } @@ -2198,7 +2222,7 @@ __global__ void masked_multihead_attention_kernel( V_vec_k v_bias; zero(v_bias); // if( vo == params.timestep % V_PER_ITER ) { - if (is_valid_vi && HANDLE_KV && vo == tlength % V_PER_ITER) + if (is_valid_vi && HANDLE_KV && vo == kv_loop_length % V_PER_ITER) { // Trigger the loads from the V bias buffer. if (params.v_bias != nullptr) @@ -2236,9 +2260,9 @@ __global__ void masked_multihead_attention_kernel( // Handle both context and generation value cache without beam searching. // Explicit batching of LDGs (by V_LOOP_UNROLL) as it doesn't depend on indirection tables. // Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible. - const int context_length = HAS_BEAMS ? input_length : tlength; + const int context_length = HAS_BEAMS ? beam0_context_length : kv_loop_length; int context_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : context_length; - int generation_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : tlength; + int generation_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length; for (int ti = vo; ti < context_v_loop_end; ti += UNROLLED_V_PER_ITER) { V_vec_m v_vec_cache[V_LOOP_UNROLL]; @@ -2247,7 +2271,7 @@ __global__ void masked_multihead_attention_kernel( { // Fetch offset based on cache_indir when beam sampling int time_idx = ti + v_loop * V_PER_ITER + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0); - time_idx = min(time_idx, tlength - 1); + time_idx = min(time_idx, kv_loop_length - 1); int rowIdx = batch_idx * beam_width; const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi); @@ -2278,16 +2302,17 @@ __global__ void masked_multihead_attention_kernel( // Handle generation value cache with beam searching. if (HAS_BEAMS) { - const auto generation_start_ti = MULTI_BLOCK_FLAG ? vo : (vo + (input_length / V_PER_ITER) * V_PER_ITER); + const auto generation_start_ti + = MULTI_BLOCK_FLAG ? vo : (vo + (beam0_context_length / V_PER_ITER) * V_PER_ITER); // Only the last few blocks need to handle the generation value cache. - if (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > input_length) + if (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > beam0_context_length) { for (int ti = generation_start_ti; ti < generation_v_loop_end; ti += V_PER_ITER) { // Fetch offset based on cache_indir when beam sampling int time_idx = ti + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0); int local_time_idx = ti; - if (time_idx < input_length || (MULTI_BLOCK_FLAG && time_idx >= tlength)) + if (time_idx < beam0_context_length || (MULTI_BLOCK_FLAG && time_idx >= kv_loop_length)) { continue; } @@ -2307,13 +2332,16 @@ __global__ void masked_multihead_attention_kernel( } } + // Make sure we can overwrite the v cache if using cyclic kv cache. + __syncthreads(); + // Get the c_tile_id that handles the current timestep. const int ctile_idx = tlength / timesteps_per_block; // One group of threads computes the product(s) for the current timestep. - if (vo == tlength % V_PER_ITER && is_valid_vi && (!MULTI_BLOCK_FLAG || (c_tile == ctile_idx))) + if (vo == kv_loop_length % V_PER_ITER && is_valid_vi && (!MULTI_BLOCK_FLAG || (c_tile == ctile_idx))) { - const int tokenIdx = tlength; + const int tokenIdx = cyclic_tlength; const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(tokenIdx, hi_kv, Dh, vi); // The base pointer for the value in the cache buffer. Tcache* v_cache_base = reinterpret_cast(kvCacheBuffer.getBlockPtr(v_cache_base_row_ptr, tokenIdx)); @@ -2380,7 +2408,7 @@ __global__ void masked_multihead_attention_kernel( // out = fma(logits_smem[params.timestep], cast_to_float(v), out); if (!MULTI_BLOCK_FLAG) { - out = fma(logits_smem[tlength], cast_to_float(v), out); + out = fma(logits_smem[kv_loop_length], cast_to_float(v), out); } else { @@ -2390,7 +2418,7 @@ __global__ void masked_multihead_attention_kernel( // out = fma(logits_smem[params.timestep], v, out); if (!MULTI_BLOCK_FLAG) { - out = fma(logits_smem[tlength], v, out); + out = fma(logits_smem[kv_loop_length], v, out); } else { // MULTI_BLOCK_FLAG // Not supported yet: multi-block mode with FP8_MHA diff --git a/cpp/tensorrt_llm/kernels/gptKernels.cu b/cpp/tensorrt_llm/kernels/gptKernels.cu index 01af8b0af..e341a03de 100644 --- a/cpp/tensorrt_llm/kernels/gptKernels.cu +++ b/cpp/tensorrt_llm/kernels/gptKernels.cu @@ -133,8 +133,8 @@ __global__ void computePaddingOffsets(int* paddingOffsets, const int* seqOffsets // This kernel computes the attention mask. We must compute this on-the-fly in the future. template -__global__ void computeAttentionMask( - AttentionMaskDataType* attentionMask, const int* seqOffsets, int maxSeqLength, AttentionMaskType attentionMaskType) +__global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const int* seqOffsets, int maxSeqLength, + int maxKvCacheLength, AttentionMaskType attentionMaskType) { // The index of the sequence in the batch. int batchIdx = blockIdx.y; @@ -173,12 +173,22 @@ __global__ void computeAttentionMask( break; case AttentionMaskType::CAUSAL: isValid = rowIdx < seqLength && colIdx < seqLength && colIdx <= rowIdx; + // Sliding_window_causal when there are not enough kv cache. + isValid = isValid && colIdx >= max(0, rowIdx - maxKvCacheLength); // seq_length==4, max_seq_len==5 // 1 0 0 0 0 // 1 1 0 0 0 // 1 1 1 0 0 // 1 1 1 1 0 // 0 0 0 0 0 + + // seq_length==6, max_seq_len==6, max_kv_cache_length = 2 + // 1 0 0 0 0 0 + // 1 1 0 0 0 0 + // 1 1 1 0 0 0 + // 0 1 1 1 0 0 + // 0 0 1 1 1 0 + // 0 0 0 1 1 1 break; case AttentionMaskType::BIDIRECTIONAL: // clang-format off @@ -222,8 +232,8 @@ void invokeBuildDecoderInfo(const BuildDecoderInfoParams& params, cudaStream_ blocksPerSeq *= 2; } dim3 grid(blocksPerSeq, params.batchSize); - computeAttentionMask<<>>( - params.attentionMask, params.seqOffsets, params.maxSeqLength, params.attentionMaskType); + computeAttentionMask<<>>(params.attentionMask, params.seqOffsets, + params.maxSeqLength, params.maxKvCacheLength, params.attentionMaskType); } } diff --git a/cpp/tensorrt_llm/kernels/gptKernels.h b/cpp/tensorrt_llm/kernels/gptKernels.h index d2bdb370a..b30edfa80 100644 --- a/cpp/tensorrt_llm/kernels/gptKernels.h +++ b/cpp/tensorrt_llm/kernels/gptKernels.h @@ -73,6 +73,9 @@ struct BuildDecoderInfoParams int batchSize; // The maximum length of a sequence; it includes input and output. int maxSeqLength; + // The kv cache capacity. + // We will apply the limited_length_causal mask when there are not enough kv cache. + int maxKvCacheLength; // The number of tokens in total. It's \sum_{ii=0}^{batchSize} seqLengths[ii]. int numTokens; // The type of attention. diff --git a/cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu b/cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu index 08205ab57..89dd731c2 100644 --- a/cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu +++ b/cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu @@ -4,6 +4,24 @@ namespace tensorrt_llm { namespace kernels { +namespace +{ +template +struct Vec2Type; + +template <> +struct Vec2Type +{ + using type = half2; +}; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) +template <> +struct Vec2Type<__nv_bfloat16> +{ + using type = __nv_bfloat162; +}; +#endif +}; // namespace template __global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols) @@ -21,13 +39,18 @@ __global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T* for (int i = 0; i < kProcessRows; ++i) { *reinterpret_cast(act_vec) = reinterpret_cast(act + i * cols)[col_offset]; - if constexpr (std::is_same_v && kElems % 2 == 0) + if constexpr ((std::is_same_v +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + || std::is_same_v +#endif + ) &&(kElems % 2 == 0)) { + using Vec2 = typename Vec2Type::type; #pragma unroll for (int j = 0; j < kElems; j += 2) { - *reinterpret_cast(act_vec + j) - = __hmul2(*reinterpret_cast(act_vec + j), *reinterpret_cast(scale + j)); + *reinterpret_cast(act_vec + j) + = __hmul2(*reinterpret_cast(act_vec + j), *reinterpret_cast(scale + j)); } } else @@ -35,7 +58,7 @@ __global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T* #pragma unroll for (int j = 0; j < kElems; ++j) { - act_vec[j] *= scale[j]; + act_vec[j] = static_cast(static_cast(act_vec[j]) * static_cast(scale[j])); } } reinterpret_cast(smoothed_act + i * cols)[col_offset] = *reinterpret_cast(act_vec); @@ -85,6 +108,9 @@ void apply_per_channel_scale_kernel_launcher( T * smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream) INSTANTIATE_PREQUANT_SCALE(half); +#if defined(ENABLE_BF16) +INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16); +#endif } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/preQuantScaleKernel.h b/cpp/tensorrt_llm/kernels/preQuantScaleKernel.h index 75db9a8c2..25bda2c41 100644 --- a/cpp/tensorrt_llm/kernels/preQuantScaleKernel.h +++ b/cpp/tensorrt_llm/kernels/preQuantScaleKernel.h @@ -20,6 +20,10 @@ #include #include +#if defined(ENABLE_BF16) +#include +#endif + #include #include diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu index cf577c971..2b86b605b 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu @@ -1629,7 +1629,7 @@ INSTANTIATE_TRANSPOSE_4D(half); template __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCacheBuffer kvCacheBuffer, - const int headNum, const int sizePerHead, const int seqLen, const float* kvScaleOrigQuant, + const int headNum, const int sizePerHead, const int seqLen, const int maxKvCacheLen, const float* kvScaleOrigQuant, const int* sequence_lengths) { // We allow only fp32/fp16/bf16 as input types @@ -1655,14 +1655,20 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac } // Get linear token index - const int tokenIdx = idx / sizePerHeadDivX; + int tokenIdx = idx / sizePerHeadDivX; + // Apply cyclic kv cache if tokenIdx >= max_kv_cache_length. + // which means we will drop the tokens in the beginning if seqLen > max_kv_cache_length. + const int tokenIdxLowerBound = max(sequence_lengths[batchIdx] - maxKvCacheLen, 0); // Get channel index const int channelIdx = idx % sizePerHeadDivX; - if (tokenIdx >= sequence_lengths[batchIdx]) + if (tokenIdx >= sequence_lengths[batchIdx] || tokenIdx < tokenIdxLowerBound) { return; } + // Apply cyclic kv cache if tokenIdx >= max_kv_cache_length. + tokenIdx = tokenIdx % maxKvCacheLen; + // Get pointer to the dst block given sequence, head and token ids auto valDst = handle_k ? reinterpret_cast(kvCacheBuffer.getKBlockPtr(batchIdx, tokenIdx)) : reinterpret_cast(kvCacheBuffer.getVBlockPtr(batchIdx, tokenIdx)); @@ -1697,7 +1703,7 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac template void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kvTable, const int localBatchSize, - const int seqLen, const int maxSeqLen, const int sizePerHead, const int localHeadNum, + const int seqLen, const int maxKvCacheLen, const int sizePerHead, const int localHeadNum, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const int* sequence_lengths, cudaStream_t stream) { // Block handles both K and V tile. @@ -1710,25 +1716,25 @@ void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kv if (cache_type == KvCacheDataType::INT8) { transpose4dBatchMajorKVCache<<>>( - kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, kvScaleOrigQuant, sequence_lengths); + kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, maxKvCacheLen, kvScaleOrigQuant, sequence_lengths); } #ifdef ENABLE_FP8 else if (cache_type == KvCacheDataType::FP8) { transpose4dBatchMajorKVCache<<>>( - kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, kvScaleOrigQuant, sequence_lengths); + kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, maxKvCacheLen, kvScaleOrigQuant, sequence_lengths); } #endif // ENABLE_FP8 else { transpose4dBatchMajorKVCache<<>>( - kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, kvScaleOrigQuant, sequence_lengths); + kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, maxKvCacheLen, kvScaleOrigQuant, sequence_lengths); } } #define INSTANTIATE_TRANSPOSE_4D_BATCH_MAJOR_KV_CACHE_TYPE(T, KVCacheBuffer) \ template void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kvTable, \ - const int localBatchSize, const int seqLen, const int maxSeqLen, const int sizePerHead, \ + const int localBatchSize, const int seqLen, const int maxKvCacheLen, const int sizePerHead, \ const int localHeadNum, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, \ const int* sequence_lengths, cudaStream_t stream) diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h index d392d4b5e..8664229b0 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h @@ -105,16 +105,17 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const template void invokeTranspose4dBatchMajor(const T* k_src, const T* v_src, KVCacheBuffer& kvTable, const int local_batch_size, - const int seq_len, const int max_seq_len, const int size_per_head, const int local_head_num, + const int seq_len, const int max_kv_cache_len, const int size_per_head, const int local_head_num, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const int* sequence_lengths, cudaStream_t stream); template void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens, - const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num, - const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base, - const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale, - const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale, - const int int8_mode, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream); + const int* padding_offset, const int batch_size, const int seq_len, const int cyclic_kv_cache_len, + const int token_num, const int head_num, const int kv_head_num, const int size_per_head, + const int rotary_embedding_dim, const float rotary_embedding_base, const RotaryScalingType rotary_scale_type, + const float rotary_embedding_scale, const int rotary_embedding_max_positions, + const PositionEmbeddingType position_embedding_type, const float* scale, const int int8_mode, + const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream); template void invokeAddRelativeAttentionBiasUnaligned(T* qk_buf, const BT* relative_attention_bias, const int batch_size, diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels_2.cu b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels_2.cu index a6c3650fc..84d978247 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels_2.cu +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels_2.cu @@ -65,9 +65,9 @@ struct Vec_t<__nv_bfloat16> template __global__ void applyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer kvCacheBuffer, const T* __restrict qkv_bias, const int* seq_lens, const int* padding_offset, const float* kvScaleOrigQuant, const int batch_size, - const int seq_len, const int head_num, const int kv_head_num, const int size_per_head, - const int rotary_embedding_dim, float rotary_embedding_base, RotaryScalingType const rotary_scale_type, - float rotary_embedding_scale, const int rotary_embedding_max_positions, + const int seq_len, const int cyclic_kv_cache_len, const int head_num, const int kv_head_num, + const int size_per_head, const int rotary_embedding_dim, float rotary_embedding_base, + RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type) { // This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and @@ -222,9 +222,11 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer kvCacheBuffer, } const int channelIdx{tidx}; - auto kDst = reinterpret_cast(kvCacheBuffer.getKBlockPtr(batch_idx, token_idx_in_seq)); - auto vDst = reinterpret_cast(kvCacheBuffer.getVBlockPtr(batch_idx, token_idx_in_seq)); - int inBlockIdx = kvCacheBuffer.getKVLocalIdx(token_idx_in_seq, kv_head_idx, sizePerHeadDivX, channelIdx); + const bool valid_kv_cache_pos = token_idx_in_seq >= (actual_seq_len - cyclic_kv_cache_len); + const int token_idx_in_kv_cache = token_idx_in_seq % cyclic_kv_cache_len; + auto kDst = reinterpret_cast(kvCacheBuffer.getKBlockPtr(batch_idx, token_idx_in_kv_cache)); + auto vDst = reinterpret_cast(kvCacheBuffer.getVBlockPtr(batch_idx, token_idx_in_kv_cache)); + int inBlockIdx = kvCacheBuffer.getKVLocalIdx(token_idx_in_kv_cache, kv_head_idx, sizePerHeadDivX, channelIdx); if (!is_masked) { *reinterpret_cast(&QKV[src_q_idx]) = q; @@ -233,48 +235,24 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer kvCacheBuffer, *reinterpret_cast(&QKV[src_k_idx]) = k; *reinterpret_cast(&QKV[src_v_idx]) = v; - if (ENABLE_8BITS_CACHE) + if (valid_kv_cache_pos) { - inBlockIdx = inBlockIdx * vec_size; - // Cast float scale to dst data type. - using T_scale = typename mmha::kv_cache_scale_type_t::Type; - T_scale scaleOrigQuant; - mmha::convert_from_float(&scaleOrigQuant, kvScaleOrigQuant[0]); - // Store 8bits kv cache. - mmha::store_8bits_kv_cache_vec(kDst, k, inBlockIdx, scaleOrigQuant); - mmha::store_8bits_kv_cache_vec(vDst, v, inBlockIdx, scaleOrigQuant); - } - else - { - reinterpret_cast(kDst)[inBlockIdx] = k; - reinterpret_cast(vDst)[inBlockIdx] = v; - } - } - } - else if (is_seq_masked && !is_head_size_masked) - { - // Set padding to zero in case of potential nan generated. - *reinterpret_cast(&QKV[src_q_idx]) = zero; - if ((head_num == kv_head_num) || (head_idx == (kv_head_idx * qheads_per_kv_head))) - { - *reinterpret_cast(&QKV[src_k_idx]) = zero; - *reinterpret_cast(&QKV[src_v_idx]) = zero; - - if (ENABLE_8BITS_CACHE) - { - inBlockIdx = inBlockIdx * vec_size; - // Cast float scale to dst data type. - using T_scale = typename mmha::kv_cache_scale_type_t::Type; - T_scale scaleOrigQuant; - mmha::convert_from_float(&scaleOrigQuant, kvScaleOrigQuant[0]); - // Store 8bits kv cache. - mmha::store_8bits_kv_cache_vec(kDst, zero, inBlockIdx, scaleOrigQuant); - mmha::store_8bits_kv_cache_vec(vDst, zero, inBlockIdx, scaleOrigQuant); - } - else - { - reinterpret_cast(kDst)[inBlockIdx] = zero; - reinterpret_cast(vDst)[inBlockIdx] = zero; + if (ENABLE_8BITS_CACHE) + { + inBlockIdx = inBlockIdx * vec_size; + // Cast float scale to dst data type. + using T_scale = typename mmha::kv_cache_scale_type_t::Type; + T_scale scaleOrigQuant; + mmha::convert_from_float(&scaleOrigQuant, kvScaleOrigQuant[0]); + // Store 8bits kv cache. + mmha::store_8bits_kv_cache_vec(kDst, k, inBlockIdx, scaleOrigQuant); + mmha::store_8bits_kv_cache_vec(vDst, v, inBlockIdx, scaleOrigQuant); + } + else + { + reinterpret_cast(kDst)[inBlockIdx] = k; + reinterpret_cast(vDst)[inBlockIdx] = v; + } } } } @@ -282,11 +260,12 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer kvCacheBuffer, template void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens, - const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num, - const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base, - const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale, - const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale, - const float* kvScaleOrigQuant, const int int8_mode, cudaStream_t stream) + const int* padding_offset, const int batch_size, const int seq_len, const int cyclic_kv_cache_len, + const int token_num, const int head_num, const int kv_head_num, const int size_per_head, + const int rotary_embedding_dim, const float rotary_embedding_base, const RotaryScalingType rotary_scale_type, + const float rotary_embedding_scale, const int rotary_embedding_max_positions, + const PositionEmbeddingType position_embedding_type, const float* scale, const float* kvScaleOrigQuant, + const int int8_mode, cudaStream_t stream) { TLLM_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with RoPE"); // TODO // To implement rotary embeddings, each thread processes two QKV elems: @@ -298,26 +277,27 @@ void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, KVCacheBuffer& kvTable, co if (qkv_bias != nullptr) { applyBiasRopeUpdateKVCache<<>>(QKV, kvTable, - qkv_bias, seq_lens, padding_offset, kvScaleOrigQuant, batch_size, seq_len, head_num, kv_head_num, - size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, - rotary_embedding_max_positions, position_embedding_type); + qkv_bias, seq_lens, padding_offset, kvScaleOrigQuant, batch_size, seq_len, cyclic_kv_cache_len, head_num, + kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, + rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type); } else { applyBiasRopeUpdateKVCache<<>>(QKV, kvTable, - qkv_bias, seq_lens, padding_offset, kvScaleOrigQuant, batch_size, seq_len, head_num, kv_head_num, - size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, - rotary_embedding_max_positions, position_embedding_type); + qkv_bias, seq_lens, padding_offset, kvScaleOrigQuant, batch_size, seq_len, cyclic_kv_cache_len, head_num, + kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, + rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type); } } template void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens, - const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num, - const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base, - const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale, - const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale, - const int int8_mode, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream) + const int* padding_offset, const int batch_size, const int seq_len, const int cyclic_kv_cache_len, + const int token_num, const int head_num, const int kv_head_num, const int size_per_head, + const int rotary_embedding_dim, const float rotary_embedding_base, const RotaryScalingType rotary_scale_type, + const float rotary_embedding_scale, const int rotary_embedding_max_positions, + const PositionEmbeddingType position_embedding_type, const float* scale, const int int8_mode, + const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream) { // Block handles both K and V tile. constexpr int x = (sizeof(T) == 4) ? 4 : 8; @@ -326,36 +306,37 @@ void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* q if (cache_type == KvCacheDataType::INT8) { invokeApplyBiasRopeUpdateKVCacheDispatch(QKV, kvTable, qkv_bias, seq_lens, - padding_offset, batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim, - rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions, - position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream); + padding_offset, batch_size, seq_len, cyclic_kv_cache_len, token_num, head_num, kv_head_num, size_per_head, + rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, + rotary_embedding_max_positions, position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream); } #ifdef ENABLE_FP8 else if (cache_type == KvCacheDataType::FP8) { invokeApplyBiasRopeUpdateKVCacheDispatch(QKV, kvTable, qkv_bias, seq_lens, - padding_offset, batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim, - rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions, - position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream); + padding_offset, batch_size, seq_len, cyclic_kv_cache_len, token_num, head_num, kv_head_num, size_per_head, + rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, + rotary_embedding_max_positions, position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream); } #endif // ENABLE_FP8 else { invokeApplyBiasRopeUpdateKVCacheDispatch(QKV, kvTable, qkv_bias, seq_lens, padding_offset, - batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim, - rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions, - position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream); + batch_size, seq_len, cyclic_kv_cache_len, token_num, head_num, kv_head_num, size_per_head, + rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, + rotary_embedding_max_positions, position_embedding_type, scale, kvScaleOrigQuant, int8_mode, stream); } } #define INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(T, KVCacheBuffer) \ template void invokeApplyBiasRopeUpdateKVCache(T* QKV, KVCacheBuffer& kvTable, const T* qkv_bias, \ - const int* seq_lens, const int* padding_offset, const int batch_size, const int seq_len, const int token_num, \ - const int head_num, const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, \ - const float rotary_embedding_base, const RotaryScalingType rotary_scale_type, \ - const float rotary_embedding_scale, const int rotary_embedding_max_positions, \ - const PositionEmbeddingType position_embedding_type, const float* scale, const int int8_mode, \ - const KvCacheDataType cache_type, const float* kvScaleOrigQuant, cudaStream_t stream) + const int* seq_lens, const int* padding_offset, const int batch_size, const int seq_len, \ + const int cyclic_kv_cache_len, const int token_num, const int head_num, const int kv_head_num, \ + const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base, \ + const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale, \ + const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, \ + const float* scale, const int int8_mode, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, \ + cudaStream_t stream) INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(float, KVBlockArray); INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(float, KVLinearBuffer); diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h index 415f2d7b3..62576338a 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h @@ -19,6 +19,9 @@ #include #include #include +#if defined(ENABLE_BF16) +#include +#endif #include #include #include @@ -27,34 +30,6 @@ namespace tensorrt_llm { namespace kernels { -struct WeightOnlyParams -{ - const uint8_t* qweight; - const half* scales; - const half* zeros; - const half* in; - const half* bias; - half* out; - const int m; - const int n; - const int k; - const int group_size; - - WeightOnlyParams(const uint8_t* _qweight, const half* _scales, const half* _zeros, const half* _in, - const half* _bias, half* _out, const int _m, const int _n, const int _k, const int _group_size) - : qweight(_qweight) - , scales(_scales) - , zeros(_zeros) - , in(_in) - , bias(_bias) - , out(_out) - , m(_m) - , n(_n) - , k(_k) - , group_size(_group_size) - { - } -}; enum class WeightOnlyQuantType { Int4b, @@ -70,12 +45,61 @@ struct WeightOnlyPerChannel; template struct WeightOnlyGroupWise; -enum class WeightOnlyActivationType +enum class WeightOnlyActivationFunctionType { Gelu, Relu, Identity, InvalidType }; + +enum class WeightOnlyActivationType +{ + FP16, + BF16 +}; + +struct WeightOnlyParams +{ + // ActType is fp16 or bf16 + using ActType = void; + using WeiType = uint8_t; + + const uint8_t* qweight; + const ActType* scales; + const ActType* zeros; + const ActType* in; + const ActType* bias; + ActType* out; + const int m; + const int n; + const int k; + const int group_size; + WeightOnlyQuantType quant_type; + WeightOnlyType weight_only_type; + WeightOnlyActivationFunctionType act_func_type; + WeightOnlyActivationType act_type; + + WeightOnlyParams(const uint8_t* _qweight, const ActType* _scales, const ActType* _zeros, const ActType* _in, + const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k, const int _group_size, + const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type, + const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type) + : qweight(_qweight) + , scales(_scales) + , zeros(_zeros) + , in(_in) + , bias(_bias) + , out(_out) + , m(_m) + , n(_n) + , k(_k) + , group_size(_group_size) + , quant_type(_quant_type) + , weight_only_type(_weight_only_type) + , act_func_type(_act_func_type) + , act_type(_act_type) + { + } +}; } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 9a58c352b..21d087956 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -22,11 +22,51 @@ namespace tensorrt_llm { namespace kernels { -template -struct WeightLayoutDetails; +template +struct ActTypeDetails; template <> -struct WeightLayoutDetails +struct ActTypeDetails +{ + using CutlassType = cutlass::half_t; + using Vec2 = half2; + + __device__ __forceinline__ static Vec2 to_vec2(half v) + { + return __half2half2(v); + } +}; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) +template <> +struct ActTypeDetails<__nv_bfloat16> +{ + using CutlassType = cutlass::bfloat16_t; + using Vec2 = __nv_bfloat162; + + __device__ __forceinline__ static Vec2 to_vec2(__nv_bfloat16 v) + { + return __bfloat162bfloat162(v); + } +}; +#endif + +template +struct ConverterSelector +{ + static_assert(QType == WeightOnlyQuantType::Int4b || QType == WeightOnlyQuantType::Int8b); + + using WeiType = std::conditional_t; + static constexpr int kConvertCount = QType == WeightOnlyQuantType::Int4b ? 8 : 4; + using Converter + = cutlass::FastInterleavedAndBiasedNumericArrayConverter::CutlassType, WeiType, + kConvertCount>; +}; + +template +struct WeightOnlyDetails; + +template +struct WeightOnlyDetails { // Every four rows of the original weights are interleaved into a row with stride of 64, so if each thread // processes 32 elements(for int4, we can use ldg.128 to load weights), then every group of two adjacent threads @@ -49,16 +89,6 @@ struct WeightLayoutDetails static constexpr int kShuffleContinous = 4; static constexpr int kShuffleStrided = 4; - // The rearrangement here counteracts the effect of cutlass::add_bias_and_interleave_int4s_inplace - // Input int8 data layout - // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) - // - // Converted fp16 data layout - // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits) - static constexpr int kConvertCount = 8; - using Converter - = cutlass::FastInterleavedAndBiasedNumericArrayConverter; - // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the // corresponding address in shared memory template @@ -85,8 +115,8 @@ struct WeightLayoutDetails } }; -template <> -struct WeightLayoutDetails +template +struct WeightOnlyDetails { // Every two rows of the original weights are interleaved into a row with stride of 64, so if each thread // processes 16 elements(for int8, we can use ldg.128 to load weights), then every group of four adjacent threads @@ -109,15 +139,6 @@ struct WeightLayoutDetails static constexpr int kShuffleContinous = 2; static constexpr int kShuffleStrided = 4; - // The rearrangement here counteracts the effect of cutlass::add_bias_and_interleave_int8s_inplace - // Input int8 data layout - // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) - // - // Converted fp16 data layout - // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits) - static constexpr int kConvertCount = 4; - using Converter = cutlass::FastInterleavedAndBiasedNumericArrayConverter; - // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the // corresponding address in shared memory template @@ -145,10 +166,10 @@ struct WeightLayoutDetails } }; -template +template struct WeightOnlyKernelDetails { - using Layout = WeightLayoutDetails; + using Layout = WeightOnlyDetails; static constexpr int kElemBits = Layout::kElemBits; static constexpr int kInterleave = Layout::kInterleave; @@ -159,8 +180,20 @@ struct WeightOnlyKernelDetails static constexpr int kShuffleContinous = Layout::kShuffleContinous; static constexpr int kShuffleStrided = Layout::kShuffleStrided; - using Converter = typename Layout::Converter; - static constexpr int kConvertCount = Layout::kConvertCount; + // The rearrangement here counteracts the effect of cutlass::add_bias_and_interleave_int4/8s_inplace + // Input int8 data layout + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + // + // Converted fp16/bf16 data layout + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits) + + // Input int8 data layout + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) + // + // Converted fp16/bf16 data layout + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits) + static constexpr int kConvertCount = ConverterSelector::kConvertCount; + using Converter = typename ConverterSelector::Converter; // Use ldg128 load data from global memory static constexpr int kAccessSize = 128; @@ -175,8 +208,8 @@ struct WeightOnlyKernelDetails static constexpr int kConvertIters = kElemsPerThread / kConvertCount; // Each thread loads 16(int8b)/32(int4b) quantized weight elements each time through ldg128 - // So more times of ldg128 are needed to load the same number of fp16 activation elements. - static constexpr int kActivationElemNumPerAccess = kAccessSize / (sizeof(half) * 8); + // So more times of ldg128 are needed to load the same number of fp16/bf16 activation elements. + static constexpr int kActivationElemNumPerAccess = kAccessSize / (sizeof(ActType) * 8); static constexpr int kActivationAccessNum = kElemsPerThread / kActivationElemNumPerAccess; }; @@ -197,11 +230,11 @@ struct WeightOnlyProperties> static constexpr int kGroupSize = GS; }; -template +template struct WeightOnlyScaleLoader { - using ElemType = half; - using Details = WeightOnlyKernelDetails; + using ElemType = ActType; + using Details = WeightOnlyKernelDetails; static constexpr bool kIsFineGrained = WeightOnlyProperties::kIsFineGrained; static constexpr int kGroupSize = WeightOnlyProperties::kGroupSize; @@ -258,19 +291,20 @@ struct WeightOnlyScaleLoader } }; -template class ActOp, bool Zero, bool Bias, - int NPerBlock, int Batch, int BlockSize> -__global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* scales, const half* zeros, const half* in, - const half* bias, half* out, const int n, const int k) +template class ActOp, + bool Zero, bool Bias, int NPerBlock, int Batch, int BlockSize> +__device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* scales, const ActType* zeros, + const ActType* in, const ActType* bias, ActType* out, const int n, const int k) { static_assert(NPerBlock == 1 || (NPerBlock % 2 == 0)); - using Details = WeightOnlyKernelDetails; + using ActType2 = typename ActTypeDetails::Vec2; + using Details = WeightOnlyKernelDetails; using Converter = typename Details::Converter; using AccType = typename Details::AccessType; using CvtSrcType = typename Converter::source_type; using CvtResType = typename Converter::result_type; - using ScaleLoader = WeightOnlyScaleLoader; + using ScaleLoader = WeightOnlyScaleLoader; extern __shared__ uint8_t shmem[]; constexpr int Interleave = Details::kInterleave; constexpr int WarpSize = 32; @@ -286,20 +320,20 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca float(*sm)[Num * Interleave] = reinterpret_cast(shmem); - // In order to take advantage of hfma2, we use fp16 for accumulation within threads and fp32 for accumulation + // In order to take advantage of hfma2, we use fp16/bf16 for accumulation within threads and fp32 for accumulation // between threads. - half accumulator[Num]; + ActType accumulator[Num]; for (int i = 0; i < Num; ++i) { - accumulator[i] = __float2half_rn(0.f); + accumulator[i] = static_cast(0.f); } // Iteration in k dimensions for (int local_k = tid * Details::kElemsPerThread; local_k < k * Interleave; local_k += BlockSize * Details::kElemsPerThread) { - half weights_f16[Details::kElemsPerThread * NPerBlock]; - half scale[NPerBlock], zero[NPerBlock]; + ActType weights_f16[Details::kElemsPerThread * NPerBlock]; + ActType scale[NPerBlock], zero[NPerBlock]; #pragma unroll for (int idx = 0; idx < NPerBlock; ++idx) { @@ -308,7 +342,7 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca load(weights_quantized, qweight + idx * Interleave * k / Details::kElemsPerByte + local_k / Details::kElemsPerByte); scale_loader.load(scale[idx], zero[idx], idx); - half weights_vec[Details::kElemsPerThread]; + ActType weights_vec[Details::kElemsPerThread]; #pragma unroll for (int i = 0; i < Details::kConvertIters; ++i) { @@ -325,9 +359,10 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca { // Dequantize the weights and arrange the shuffled elements back to the correct order in the // register array - half2 v = *reinterpret_cast(weights_vec + i * Details::kShuffleBasicTile + ActType2 v = *reinterpret_cast(weights_vec + i * Details::kShuffleBasicTile + j * Details::kShuffleContinous * Details::kShuffleBasicTile); - v = __hfma2(v, __half2half2(scale[idx]), __half2half2(zero[idx])); + v = __hfma2( + v, ActTypeDetails::to_vec2(scale[idx]), ActTypeDetails::to_vec2(zero[idx])); weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile + j * Details::kShuffleBasicTile + 0) * NPerBlock @@ -344,7 +379,7 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca #pragma unroll for (int b = 0; b < Batch; ++b) { - half in_v[Details::kElemsPerThread]; + ActType in_v[Details::kElemsPerThread]; #pragma unroll for (int idx = 0; idx < Details::kActivationAccessNum; ++idx) { @@ -355,11 +390,12 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca // Perform vector inner product and accumulate if constexpr (NPerBlock == 1) { - half2 v = __float2half2_rn(0.f); + ActType2 v = ActTypeDetails::to_vec2(static_cast(0.f)); #pragma unroll for (int y = 0; y < Details::kElemsPerThread; y += 2) { - v = __hfma2(*reinterpret_cast(weights_f16 + y), *reinterpret_cast(in_v + y), v); + v = __hfma2( + *reinterpret_cast(weights_f16 + y), *reinterpret_cast(in_v + y), v); } accumulator[b] += __hadd(v.x, v.y); } @@ -371,9 +407,10 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca #pragma unroll for (int y = 0; y < Details::kElemsPerThread; ++y) { - *reinterpret_cast(accumulator + b * NPerBlock + x * 2) - = __hfma2(*reinterpret_cast(weights_f16 + y * NPerBlock + x * 2), - __half2half2(in_v[y]), *reinterpret_cast(accumulator + b * NPerBlock + x * 2)); + *reinterpret_cast(accumulator + b * NPerBlock + x * 2) + = __hfma2(*reinterpret_cast(weights_f16 + y * NPerBlock + x * 2), + ActTypeDetails::to_vec2(in_v[y]), + *reinterpret_cast(accumulator + b * NPerBlock + x * 2)); } } } @@ -384,7 +421,7 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca #pragma unroll for (int i = 0; i < Num; ++i) { - reses[i] = __half2float(accumulator[i]); + reses[i] = static_cast(accumulator[i]); } // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the @@ -403,27 +440,64 @@ __global__ void weight_only_batched_gemv(const uint8_t* qweight, const half* sca float bias_v = 0.f; if constexpr (Bias) { - bias_v = __half2float(bias[n_start_id + nid]); + bias_v = static_cast(bias[n_start_id + nid]); } int b = i / NPerBlock / Interleave; - out[b * n + n_start_id + nid] = __float2half_rn(ActOp::apply(v + bias_v)); + out[b * n + n_start_id + nid] = static_cast(ActOp::apply(v + bias_v)); } } +template class ActOp, + bool Zero, bool Bias, int NPerBlock, int Batch, int BlockSize> +__global__ void weight_only_batched_gemv_wrapper(const uint8_t* qweight, const ActType* scales, const ActType* zeros, + const ActType* in, const ActType* bias, ActType* out, const int n, const int k) +{ + if constexpr (std::is_same_v) + { + weight_only_batched_gemv( + qweight, scales, zeros, in, bias, out, n, k); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + else if (std::is_same_v) + { + weight_only_batched_gemv( + qweight, scales, zeros, in, bias, out, n, k); + } +#endif +} + template class ActOp, bool Zero, bool Bias, int NPerBlock, int Batch, int BlockSize> struct WeightOnlyBatchedGemvKernelLauncher { - static constexpr int kInterleave = WeightLayoutDetails::kInterleave; - static void run(const WeightOnlyParams& params, cudaStream_t stream) { - dim3 grid(params.n / NPerBlock / kInterleave); - dim3 block(BlockSize); - int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave; - weight_only_batched_gemv - <<>>( - params.qweight, params.scales, params.zeros, params.in, params.bias, params.out, params.n, params.k); + if (params.act_type == WeightOnlyActivationType::FP16) + { + constexpr int kInterleave = WeightOnlyDetails::kInterleave; + dim3 grid(params.n / NPerBlock / kInterleave); + dim3 block(BlockSize); + int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave; + weight_only_batched_gemv_wrapper<<>>(params.qweight, reinterpret_cast(params.scales), + reinterpret_cast(params.zeros), reinterpret_cast(params.in), + reinterpret_cast(params.bias), reinterpret_cast(params.out), params.n, params.k); + } +#if defined(ENABLE_BF16) + else if (params.act_type == WeightOnlyActivationType::BF16) + { + constexpr int kInterleave = WeightOnlyDetails::kInterleave; + dim3 grid(params.n / NPerBlock / kInterleave); + dim3 block(BlockSize); + int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave; + weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch, + BlockSize><<>>(params.qweight, + reinterpret_cast(params.scales), + reinterpret_cast(params.zeros), reinterpret_cast(params.in), + reinterpret_cast(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out), + params.n, params.k); + } +#endif } }; } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.cu index f04b2d354..06f07473c 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.cu +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.cu @@ -55,21 +55,24 @@ void select_zero_bias(const WeightOnlyParams& params, cudaStream_t stream) } template -void select_activation(WeightOnlyActivationType atype, const WeightOnlyParams& params, cudaStream_t stream) +void select_activation(const WeightOnlyParams& params, cudaStream_t stream) { - switch (atype) + switch (params.act_func_type) { - case WeightOnlyActivationType::Gelu: + // Currently, activation function is not called in the plugin +#if 0 + case WeightOnlyActivationFunctionType::Gelu: { select_zero_bias(params, stream); break; } - case WeightOnlyActivationType::Relu: + case WeightOnlyActivationFunctionType::Relu: { select_zero_bias(params, stream); break; } - case WeightOnlyActivationType::Identity: +#endif + case WeightOnlyActivationFunctionType::Identity: { select_zero_bias(params, stream); break; @@ -83,18 +86,15 @@ void select_activation(WeightOnlyActivationType atype, const WeightOnlyParams& p } template -void select_quant_type( - WeightOnlyQuantType qtype, WeightOnlyActivationType atype, const WeightOnlyParams& params, cudaStream_t stream) +void select_quant_type(const WeightOnlyParams& params, cudaStream_t stream) { - if (qtype == WeightOnlyQuantType::Int4b) + if (params.quant_type == WeightOnlyQuantType::Int4b) { - select_activation( - atype, params, stream); + select_activation(params, stream); } - else if (qtype == WeightOnlyQuantType::Int8b) + else if (params.quant_type == WeightOnlyQuantType::Int8b) { - select_activation( - atype, params, stream); + select_activation(params, stream); } else { @@ -103,16 +103,15 @@ void select_quant_type( } template -void select_groupwise_weight_only(WeightOnlyQuantType qtype, WeightOnlyType wtype, WeightOnlyActivationType atype, - const WeightOnlyParams& params, cudaStream_t stream) +void select_groupwise_weight_only(const WeightOnlyParams& params, cudaStream_t stream) { - if (wtype == WeightOnlyType::GroupWise && params.group_size == 64) + if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 64) { - select_quant_type, N_PER_BLOCK, BATCH, BLOCK_SIZE>(qtype, atype, params, stream); + select_quant_type, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); } - else if (wtype == WeightOnlyType::GroupWise && params.group_size == 128) + else if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 128) { - select_quant_type, N_PER_BLOCK, BATCH, BLOCK_SIZE>(qtype, atype, params, stream); + select_quant_type, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); } else { @@ -120,33 +119,40 @@ void select_groupwise_weight_only(WeightOnlyQuantType qtype, WeightOnlyType wtyp } } -void weight_only_batched_gemv_launcher(WeightOnlyQuantType qtype, WeightOnlyType wtype, WeightOnlyActivationType atype, - const WeightOnlyParams& params, cudaStream_t stream) +void weight_only_batched_gemv_launcher(const WeightOnlyParams& params, cudaStream_t stream) { - if (wtype == WeightOnlyType::PerChannel) + assert(params.act_func_type == WeightOnlyActivationFunctionType::Identity); + assert(params.weight_only_type == WeightOnlyType::GroupWise + || (params.weight_only_type == WeightOnlyType::PerChannel && params.bias == nullptr + && params.zeros == nullptr)); + if (params.weight_only_type == WeightOnlyType::PerChannel) { - if (qtype == WeightOnlyQuantType::Int4b) + if (params.quant_type == WeightOnlyQuantType::Int4b) { switch (params.m) { case 1: { - select_activation(atype, params, stream); + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); break; } case 2: { - select_activation(atype, params, stream); + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); break; } case 3: { - select_activation(atype, params, stream); + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); break; } case 4: { - select_activation(atype, params, stream); + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); break; } default: @@ -156,28 +162,32 @@ void weight_only_batched_gemv_launcher(WeightOnlyQuantType qtype, WeightOnlyType } } } - else if (qtype == WeightOnlyQuantType::Int8b) + else if (params.quant_type == WeightOnlyQuantType::Int8b) { switch (params.m) { case 1: { - select_activation(atype, params, stream); + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); break; } case 2: { - select_activation(atype, params, stream); + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); break; } case 3: { - select_activation(atype, params, stream); + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); break; } case 4: { - select_activation(atype, params, stream); + WeightOnlyBatchedGemvKernelLauncher::run(params, stream); break; } default: @@ -188,28 +198,28 @@ void weight_only_batched_gemv_launcher(WeightOnlyQuantType qtype, WeightOnlyType } } } - else if (wtype == WeightOnlyType::GroupWise) + else if (params.weight_only_type == WeightOnlyType::GroupWise) { switch (params.m) { case 1: { - select_groupwise_weight_only<2, 1, 256>(qtype, wtype, atype, params, stream); + select_groupwise_weight_only<2, 1, 256>(params, stream); break; } case 2: { - select_groupwise_weight_only<2, 2, 256>(qtype, wtype, atype, params, stream); + select_groupwise_weight_only<2, 2, 256>(params, stream); break; } case 3: { - select_groupwise_weight_only<2, 3, 128>(qtype, wtype, atype, params, stream); + select_groupwise_weight_only<2, 3, 128>(params, stream); break; } case 4: { - select_groupwise_weight_only<2, 4, 128>(qtype, wtype, atype, params, stream); + select_groupwise_weight_only<2, 4, 128>(params, stream); break; } default: diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h index b4b032105..dad8b2e50 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h @@ -20,7 +20,6 @@ namespace tensorrt_llm { namespace kernels { -void weight_only_batched_gemv_launcher(WeightOnlyQuantType qtype, WeightOnlyType wtype, WeightOnlyActivationType atype, - const WeightOnlyParams& params, cudaStream_t stream); +void weight_only_batched_gemv_launcher(const WeightOnlyParams& params, cudaStream_t stream); } } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu index cb9ea68fd..923d6b976 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu @@ -21,46 +21,9 @@ namespace tensorrt_llm namespace kernels { -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, false, 2, 1, 256>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 1, 256>; template struct WeightOnlyBatchedGemvKernelLauncher, @@ -69,22 +32,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, false, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, false, 2, 1, 256>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 1, 256>; template struct WeightOnlyBatchedGemvKernelLauncher, diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu index 59270fdd7..55c05c612 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu @@ -21,46 +21,9 @@ namespace tensorrt_llm namespace kernels { -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, false, 2, 1, 256>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 1, 256>; template struct WeightOnlyBatchedGemvKernelLauncher, @@ -69,22 +32,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, false, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, false, 2, 1, 256>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 1, 256>; template struct WeightOnlyBatchedGemvKernelLauncher, diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu index 1302e8dce..fb15f85db 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu @@ -20,46 +20,9 @@ namespace tensorrt_llm namespace kernels { -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, false, 2, 2, 256>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 2, 256>; template struct WeightOnlyBatchedGemvKernelLauncher, @@ -68,22 +31,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, false, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, false, 2, 2, 256>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 2, 256>; template struct WeightOnlyBatchedGemvKernelLauncher, diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu index 72a515fe2..d064e4b38 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu @@ -20,46 +20,9 @@ namespace tensorrt_llm namespace kernels { -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, false, 2, 2, 256>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 2, 256>; template struct WeightOnlyBatchedGemvKernelLauncher, @@ -68,22 +31,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, false, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, false, 2, 2, 256>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 2, 256>; template struct WeightOnlyBatchedGemvKernelLauncher, diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu index 4224bdac2..c9a7100c5 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu @@ -21,46 +21,9 @@ namespace tensorrt_llm namespace kernels { -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, false, 2, 3, 128>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 3, 128>; template struct WeightOnlyBatchedGemvKernelLauncher, @@ -69,22 +32,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, false, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, false, 2, 3, 128>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 3, 128>; template struct WeightOnlyBatchedGemvKernelLauncher, diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu index 032aea0cb..09d67ab5c 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu @@ -21,46 +21,9 @@ namespace tensorrt_llm namespace kernels { -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, false, 2, 3, 128>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 3, 128>; template struct WeightOnlyBatchedGemvKernelLauncher, @@ -69,22 +32,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, false, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, false, 2, 3, 128>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 3, 128>; template struct WeightOnlyBatchedGemvKernelLauncher, diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu index b3049c70f..001bfd392 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu @@ -20,46 +20,9 @@ namespace tensorrt_llm namespace kernels { -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, false, 2, 4, 128>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 4, 128>; template struct WeightOnlyBatchedGemvKernelLauncher, @@ -68,22 +31,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, false, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, false, 2, 4, 128>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 4, 128>; template struct WeightOnlyBatchedGemvKernelLauncher, diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu index 66cebb38b..6fb6fe9cc 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu @@ -21,46 +21,9 @@ namespace tensorrt_llm namespace kernels { -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, GeluActivation, - false, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, ReluActivation, - false, false, 2, 4, 128>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 4, 128>; template struct WeightOnlyBatchedGemvKernelLauncher, @@ -69,22 +32,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher; template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, false, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - GeluActivation, false, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - ReluActivation, false, false, 2, 4, 128>; + template struct WeightOnlyBatchedGemvKernelLauncher, IdentityActivation, true, true, 2, 4, 128>; template struct WeightOnlyBatchedGemvKernelLauncher, diff --git a/cpp/tensorrt_llm/layers/baseBeamSearchLayer.cu b/cpp/tensorrt_llm/layers/baseBeamSearchLayer.cu index 13df17934..d8c0ba0bf 100644 --- a/cpp/tensorrt_llm/layers/baseBeamSearchLayer.cu +++ b/cpp/tensorrt_llm/layers/baseBeamSearchLayer.cu @@ -29,39 +29,40 @@ namespace layers __global__ void update_indir_cache_kernel(int* tgt_indir_cache, const int* src_indir_cache, const int** parent_ids, const bool* finished, const int* sequence_lengths, const int* input_lengths, int batch_dim, int local_batch_size, - int beam_width, int max_seq_len) + int beam_width, int max_kv_cache_length, int max_seq_len) { int time_step = threadIdx.x + blockIdx.x * blockDim.x; int bb_id = threadIdx.y + blockIdx.y * blockDim.y; const int current_step{sequence_lengths[bb_id] - 1}; // the sequence_lengths is updated, need to minus 1 - const int input_length{input_lengths == nullptr ? 0 : input_lengths[bb_id]}; const int batch_id = bb_id / beam_width; const int beam_id = bb_id % beam_width; - if (bb_id >= beam_width * local_batch_size || time_step < input_length || finished[bb_id]) + if (bb_id >= beam_width * local_batch_size || time_step < (max_seq_len - max_kv_cache_length) || finished[bb_id]) { return; } - int time_step_circ = time_step % max_seq_len; - // FIXME: we will remove all paddings later (@boyang) - // Skip input paddings when updating the indir cache table. + int time_step_circ = time_step % max_kv_cache_length; + // for the parent_ids, we will still keep it for all past tokens (i.e. max_seq_len) const int src_beam = parent_ids[batch_id][beam_id * max_seq_len + current_step]; - const uint32_t tgt_offset = batch_id * beam_width * max_seq_len + beam_id * max_seq_len + time_step_circ; - const uint32_t src_offset = batch_id * beam_width * max_seq_len + src_beam * max_seq_len + time_step_circ; + // for the indir tables, we have the cyclic kv cache. + const uint tgt_offset + = batch_id * beam_width * max_kv_cache_length + beam_id * max_kv_cache_length + time_step_circ; + const uint src_offset + = batch_id * beam_width * max_kv_cache_length + src_beam * max_kv_cache_length + time_step_circ; tgt_indir_cache[tgt_offset] = (time_step == current_step) ? beam_id : src_indir_cache[src_offset]; } void update_indir_cache_kernelLauncher(int* tgt_indir_cache, const int* src_indir_cache, const int** parent_ids, const bool* finished, const int* sequence_lengths, const int* input_lengths, int batch_dim, int local_batch_size, - int beam_width, int max_seq_len, cudaStream_t stream) + int beam_width, int max_seq_len, int max_kv_cache_length, cudaStream_t stream) { const dim3 block(32); // Update indirections steps [input_length[bb_id], sequence_lengths[bb_id]], included const dim3 grid((max_seq_len + block.x - 1) / block.x, local_batch_size * beam_width); update_indir_cache_kernel<<>>(tgt_indir_cache, src_indir_cache, parent_ids, finished, - sequence_lengths, input_lengths, batch_dim, local_batch_size, beam_width, max_seq_len); + sequence_lengths, input_lengths, batch_dim, local_batch_size, beam_width, max_kv_cache_length, max_seq_len); } template @@ -201,7 +202,8 @@ void BaseBeamSearchLayer::forward(BeamSearchOutputParams& outputs, ForwardPar update_indir_cache_kernelLauncher(outputs.tgt_cache_indirection.template getPtr(), params.src_cache_indirection.template getPtr(), outputs.parent_ids_ptr.template getPtr(), outputs.finished->template getPtr(), - sequence_length, input_lengths, batch_size, local_batch_size, beam_width, max_seq_len, stream_); + sequence_length, input_lengths, batch_size, local_batch_size, beam_width, max_seq_len, + params.max_kv_cache_length, stream_); sync_check_cuda_error(); } sync_check_cuda_error(); diff --git a/cpp/tensorrt_llm/layers/baseBeamSearchLayer.h b/cpp/tensorrt_llm/layers/baseBeamSearchLayer.h index 9b6e58b56..3f5592bfa 100644 --- a/cpp/tensorrt_llm/layers/baseBeamSearchLayer.h +++ b/cpp/tensorrt_llm/layers/baseBeamSearchLayer.h @@ -54,15 +54,17 @@ class BaseBeamSearchLayer : public BaseLayer class ForwardParams : public SoftmaxParams { public: - ForwardParams( - int step, int ite, tc::Tensor logits, tc::Tensor endIds, tc::Tensor src_cache_indirection, int max_seq_len) + ForwardParams(int step, int ite, tc::Tensor logits, tc::Tensor endIds, tc::Tensor src_cache_indirection, + int max_kv_cache_length, int max_seq_len) : SoftmaxParams(step, ite, std::move(logits), std::move(endIds)) , src_cache_indirection{std::move(src_cache_indirection)} + , max_kv_cache_length{max_kv_cache_length} , max_seq_len{max_seq_len} { } // mandatory parameters + int max_kv_cache_length; int max_seq_len; tc::Tensor src_cache_indirection; // [local_batch_size, beam_width, max_seq_len] diff --git a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp index cf20b42df..fbc0de7cd 100644 --- a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp +++ b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp @@ -295,7 +295,8 @@ void DynamicDecodeLayer::forward(OutputParams& outputs, ForwardParams const& auto const end_id_offset = end_ids.slice({dynamic_decode_batch_size}, dynamic_ite * dynamic_decode_batch_size); typename BaseBeamSearchLayer::ForwardParams dynamic_decode_input_tensors{step, ite, logits_offset, - end_id_offset, *params.src_cache_indirection, static_cast(max_seq_len)}; + end_id_offset, *params.src_cache_indirection, static_cast(params.max_kv_cache_length), + static_cast(max_seq_len)}; dynamic_decode_input_tensors.embedding_bias = params.embedding_bias; diff --git a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h index 882a39a95..ae7ac8cd2 100644 --- a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h +++ b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h @@ -79,10 +79,12 @@ class DynamicDecodeLayer : public BaseLayer class ForwardParams { public: - ForwardParams(int step, int ite, int maxInputLength, int localBatchSize, tc::Tensor logits, tc::Tensor endIds) + ForwardParams(int step, int ite, int maxInputLength, int maxKvCacheLength, int localBatchSize, + tc::Tensor logits, tc::Tensor endIds) : step{step} , ite{ite} , max_input_length{maxInputLength} + , max_kv_cache_length{maxKvCacheLength} , local_batch_size{localBatchSize} , logits{std::move(logits)} , end_ids{std::move(endIds)} @@ -93,6 +95,7 @@ class DynamicDecodeLayer : public BaseLayer int step; int ite; int max_input_length; + int max_kv_cache_length; int local_batch_size; tc::Tensor logits; // [batch_size, beam_width, vocab_size_padded], on gpu tc::Tensor end_ids; // [batch_size], on gpu diff --git a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp index 71e3f1187..09ede9f42 100644 --- a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp @@ -252,7 +252,7 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc if (mEnableContextFMHA && !mRelativeAttention) { // b, max_seqlen, actual_total_seqlen - mFMHARunner->setup(request_batch_size, request_seq_len, request_batch_size * request_seq_len); + mFMHARunner->setup(request_batch_size, request_seq_len, request_seq_len, request_batch_size * request_seq_len); mFMHARunner->run(const_cast(attention_input), cu_seqlens, context_buf_, stream); } else diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp index 8c221e9af..26df67deb 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp @@ -83,7 +83,8 @@ struct FusedQKVMaskedAttentionDispatchParams float rotary_embedding_scale; int rotary_embedding_max_positions; PositionEmbeddingType position_embedding_type; - int max_seq_len; + int max_kv_cache_length; + int cyclic_kv_cache_length; const int* input_lengths; int step; float q_scaling; @@ -157,7 +158,8 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params= 80) || (mType != DataType::kBF16), - "Unsupported data type, pre SM 80 GPUs do not support bfloat16"); + TLLM_CHECK_WITH_INFO( + (mSM >= 80) || (mType != DataType::kBF16), "Unsupported data type, pre SM 80 GPUs do not support bfloat16"); } const int GPTAttentionPluginCommon::getHeadSize(bool checkInit) const @@ -318,8 +323,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(const void* data, size_t leng mKVCacheQuantMode = tc::QuantMode(kvCacheQuantMode); TLLM_CHECK(d == a + length); - TLLM_CHECK_WITH_INFO((tc::getSMVersion() >= 80) || (mType != DataType::kBF16), - "Unsupported data type, pre SM 80 GPUs do not support bfloat16"); + TLLM_CHECK_WITH_INFO( + (mSM >= 80) || (mType != DataType::kBF16), "Unsupported data type, pre SM 80 GPUs do not support bfloat16"); } size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext( @@ -380,12 +385,12 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForGeneration(DataType type, in size_t generation_workspace_size = 0; const int batch_beam = total_num_seq; - int32_t const maxSeqLenTile = getMaxSeqLenTile(size); + int32_t const maxSeqLenTile = getMaxNumSeqLenTile(); - const size_t partial_out_size = mMultiBlockMode ? size * batch_beam * mNumHeads * mHeadSize * maxSeqLenTile : 0; - const size_t partial_sum_size = mMultiBlockMode ? sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile : 0; - const size_t partial_max_size = mMultiBlockMode ? sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile : 0; - const size_t block_counter_size = mMultiBlockMode ? sizeof(int) * batch_beam * mNumHeads : 0; + const size_t partial_out_size = size * batch_beam * mNumHeads * mHeadSize * maxSeqLenTile; + const size_t partial_sum_size = sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile; + const size_t partial_max_size = sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile; + const size_t block_counter_size = sizeof(int) * batch_beam * mNumHeads; const int NUM_BUFFERS = 4; size_t workspaces[NUM_BUFFERS]; @@ -397,18 +402,13 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForGeneration(DataType type, in return generation_workspace_size; } -int GPTAttentionPluginCommon::getMaxSeqLenTile(int elemSize) const +int GPTAttentionPluginCommon::getMaxNumSeqLenTile(int batch_beam_size) const { if (mMultiBlockMode) { - const int threads_per_value = pow2roundup(getHeadSize()) * elemSize / 16; - - // max_seq_len_tile to make sure: seq_len_tile * threads_per_value <= threads_per_block (for - // multi_block_mode) - const int max_seq_len_tile - = 256 / threads_per_value; // for allocate partial output results memory. Regardless to THDS_PER_BLOCK - // (which may be smaller than 256 like being 128) - return max_seq_len_tile; + // And we allocate the buffer based on the maximum number of blocks per sequence (batch_beam_size = 1). + // Assume we can only have 1 block (large block size like 1024) in SM, and we only want one wave of blocks. + return tc::divUp(mMultiProcessorCount, batch_beam_size * mNumHeads); } return 0; } @@ -439,7 +439,8 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams::Type; kv_cache_buffer = KVCacheBuffer(params.batch_size, 1, - isCrossAttention() ? params.cross_qkv_length : params.max_seq_length, num_kv_heads * head_size * elem_size); + isCrossAttention() ? params.cross_qkv_length : params.max_kv_cache_length, + num_kv_heads * head_size * elem_size); kv_cache_buffer.data = reinterpret_cast(params.key_value_cache); } @@ -524,6 +525,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams(params.attention_input), kv_cache_buffer, const_cast(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr, - params.batch_size, params.input_seq_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(), - mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, mRotaryEmbeddingScale, - mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0, cache_type, - params.kv_scale_orig_quant, stream); - mFMHARunner->setup(params.batch_size, params.input_seq_length, params.num_tokens, isALiBi(), isAliBiWithScale(), - mTpSize, mTpRank); + params.batch_size, params.input_seq_length, params.cyclic_kv_cache_length, params.num_tokens, mNumHeads, + mNumKVHeads, getHeadSize(), mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, + mRotaryEmbeddingScale, mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0, + cache_type, params.kv_scale_orig_quant, stream); + // we will apply limited_length_causal when the max_past_length(cyclic_kv_cache_length) is set. + // the token will pay attention to previous tokens while starting from max(0, rowIdx - cyclic_kv_cache_length); + mFMHARunner->setup(params.batch_size, params.input_seq_length, params.cyclic_kv_cache_length, params.num_tokens, + isALiBi(), isAliBiWithScale(), mTpSize, mTpRank); mFMHARunner->run(const_cast(params.attention_input), cu_seqlens, params.context_buf, stream); } else @@ -611,7 +615,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams 0, relative_attention_bias_stride, max_distance, true /* bidirectional */); + mNumHeads, attention_seq_len_1, + isCrossAttention() ? params.cross_qkv_length : params.cyclic_kv_cache_length, stream, max_distance > 0, + relative_attention_bias_stride, max_distance, true /* bidirectional */); } if (is_qk_buf_float_ == true) @@ -808,6 +813,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams(params.workspace); size_t offset = 0; - int32_t const maxSeqLenTile = getMaxSeqLenTile(sizeof(T)); + // Runtime check to see the actual number of blocks per sequence we need. + int32_t const max_num_seq_len_tiles = getMaxNumSeqLenTile(batch_beam); + const bool enable_multi_block = mMultiBlockMode && max_num_seq_len_tiles > 1; const size_t partial_out_size - = mMultiBlockMode ? sizeof(T) * batch_beam * mNumHeads * mHeadSize * maxSeqLenTile : 0; - const size_t partial_sum_size = mMultiBlockMode ? sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile : 0; - const size_t partial_max_size = mMultiBlockMode ? sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile : 0; - const size_t block_counter_size = mMultiBlockMode ? sizeof(int) * batch_beam * mNumHeads : 0; + = enable_multi_block ? sizeof(T) * batch_beam * mNumHeads * mHeadSize * max_num_seq_len_tiles : 0; + const size_t partial_sum_size + = enable_multi_block ? sizeof(float) * batch_beam * mNumHeads * max_num_seq_len_tiles : 0; + const size_t partial_max_size + = enable_multi_block ? sizeof(float) * batch_beam * mNumHeads * max_num_seq_len_tiles : 0; + const size_t block_counter_size = enable_multi_block ? sizeof(int) * batch_beam * mNumHeads : 0; // Workspace pointer shift T* partial_out = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, partial_out_size)); float* partial_sum = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, partial_sum_size)); float* partial_max = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, partial_max_size)); int* block_counter = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, block_counter_size)); - if (mMultiBlockMode) + if (enable_multi_block) { TLLM_CUDA_CHECK(cudaMemsetAsync(block_counter, 0, block_counter_size, stream)); } @@ -894,7 +904,8 @@ int GPTAttentionPluginCommon::enqueueGeneration( else { using BufferDataType = typename KVCacheBufferDataType::Type; - kv_cache_buffer = KVCacheBuffer(batch_beam, 1, params.max_seq_length, num_kv_heads * head_size * elem_size); + kv_cache_buffer + = KVCacheBuffer(batch_beam, 1, params.max_kv_cache_length, num_kv_heads * head_size * elem_size); kv_cache_buffer.data = reinterpret_cast(params.key_value_cache); } sync_check_cuda_error(); @@ -919,7 +930,8 @@ int GPTAttentionPluginCommon::enqueueGeneration( dispatch_params.size_per_head = getHeadSize(); dispatch_params.rotary_embedding_dim = mRotaryEmbeddingDim; dispatch_params.position_embedding_type = mPositionEmbeddingType; - dispatch_params.max_seq_len = params.max_seq_length; // difference between max_seq_lengths and max_seq_length? + dispatch_params.max_kv_cache_length = params.max_kv_cache_length; + dispatch_params.cyclic_kv_cache_length = params.cyclic_kv_cache_length; dispatch_params.input_lengths = params.context_lengths; dispatch_params.step = step; dispatch_params.q_scaling = q_scaling; @@ -931,8 +943,8 @@ int GPTAttentionPluginCommon::enqueueGeneration( dispatch_params.qkv_scale_out = qkv_scale_out; dispatch_params.attention_out_scale = attention_out_scale; dispatch_params.quant_option = quant_option; - dispatch_params.multi_block_mode = mMultiBlockMode; - dispatch_params.max_seq_len_tile = getMaxSeqLenTile(sizeof(T)); + dispatch_params.multi_block_mode = enable_multi_block; + dispatch_params.max_seq_len_tile = max_num_seq_len_tiles; dispatch_params.partial_out = partial_out; dispatch_params.partial_sum = partial_sum; dispatch_params.partial_max = partial_max; @@ -962,6 +974,7 @@ int GPTAttentionPluginCommon::enqueueGeneration( Cross_multihead_attention_params mmhca_params; fusedQKV_masked_attention_dispatch(mmhca_params, dispatch_params, stream); } + return 0; } diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h index 9c1443bf3..d6d6726e0 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h +++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h @@ -73,7 +73,7 @@ class GPTAttentionPluginCommon : public BasePlugin const int getHeadSize(bool checkInit = true) const; protected: - int getMaxSeqLenTile(int elemSize) const; + int getMaxNumSeqLenTile(int batch_beam_size = 1) const; size_t getWorkspaceSizeForContext( nvinfer1::DataType type, int32_t nbReq, int32_t max_input_length, int32_t cross_qkv_length = 0) const noexcept; // total_num_seq is the sum of beam_width for multiple requests @@ -85,7 +85,12 @@ class GPTAttentionPluginCommon : public BasePlugin T const* attention_input; T const* qkv_bias; int32_t input_seq_length; // padded input length - int32_t max_seq_length; // cache capacity + // By default, max_kv_cache_length == cyclic_kv_cache_length + // unless each layer has different cyclic kv cache length. + // Max cache capacity (used to allocate KV cache) + int32_t max_kv_cache_length; + // Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens) + int32_t cyclic_kv_cache_length; int32_t const* context_lengths; float const* kv_scale_orig_quant; float const* kv_scale_quant_orig; @@ -125,7 +130,12 @@ class GPTAttentionPluginCommon : public BasePlugin T* context_buf; void* key_value_cache; void* block_pointers; - int32_t max_seq_length; // cache capacity + // By default, max_kv_cache_length == cyclic_kv_cache_length + // unless each layer has different cyclic kv cache length. + // Max cache capacity (used to allocate KV cache) + int32_t max_kv_cache_length; + // Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens) + int32_t cyclic_kv_cache_length; int32_t num_requests; int32_t max_blocks_per_sequence; int32_t const* cache_indir; diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp index 9e6cc02dc..0633762c1 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp @@ -84,8 +84,8 @@ nvinfer1::DimsExprs GPTAttentionPlugin::getOutputDimensions( bool GPTAttentionPlugin::supportsFormatCombination( int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept { - if (pos == getSequenceLengthIdx() || pos == getHostPastKeyValueLengthsIdx() || pos == getContextLengthsIdx() - || pos == getCacheIndirIdx() || pos == getRequestTypesIdx()) + if (pos == getSequenceLengthIdx() || pos == getHostPastKeyValueLengthsIdx() || pos == getHostMaxKvCacheLengthIdx() + || pos == getContextLengthsIdx() || pos == getCacheIndirIdx() || pos == getRequestTypesIdx()) { return inOut[pos].type == nvinfer1::DataType::kINT32; } @@ -131,9 +131,6 @@ void GPTAttentionPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept { TLLM_CHECK(mHeadSize > 0); - - // pre-check whether FMHA is supported in order to save memory allocation - mEnableContextFMHA = mEnableContextFMHA && MHARunner::fmha_supported(getHeadSize(), mSM); } size_t GPTAttentionPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, @@ -258,7 +255,15 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 // -- max_encoder_context_len: len of encoder input (in cross attn). Also called encoder_input_seq_length const int beamWidth = inputDesc[getCacheIndirIdx()].dims.d[1]; - const int maxSeqLen = isCrossAttention() ? max_encoder_context_len : inputDesc[getCacheIndirIdx()].dims.d[2]; + + // Commonly, cyclic kv cache length, and max kv cache length will be the same + // unless each layer has different max kv cache length. + // the kv_cache capacity. + const int max_kv_cache_length + = isCrossAttention() ? max_encoder_context_len : inputDesc[getCacheIndirIdx()].dims.d[2]; + // The cyclic_kv_cache_length will determine the cyclic kv cache position of new tokens. + // Note that this cyclic_kv_cache_length might be smaller than the actual kv cache capactity (max_kv_cache_length). + const int cyclic_kv_cache_length = reinterpret_cast(inputs[getHostMaxKvCacheLengthIdx()])[0]; const float* kv_scale_orig_quant = nullptr; const float* kv_scale_quant_orig = nullptr; @@ -308,9 +313,10 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 num_encoder_tokens = inputDesc[getCrossQKVIdx()].dims.d[1]; } - EnqueueContextParams enqueue_params{attention_input, qkv_bias, max_context_len, maxSeqLen, - context_lengths, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, context_buf_, key_value_cache, - block_pointers, batch_size, localNbTokens, max_blocks_per_sequence, workspace}; + EnqueueContextParams enqueue_params{attention_input, qkv_bias, max_context_len, + max_kv_cache_length, cyclic_kv_cache_length, context_lengths, kv_scale_orig_quant, kv_scale_quant_orig, + alibi_slopes, context_buf_, key_value_cache, block_pointers, batch_size, localNbTokens, + max_blocks_per_sequence, workspace}; if (isRelativePosition()) { enqueue_params.relative_attention_bias = static_cast(inputs[getRelativeAttentionBiasIdx()]); @@ -340,8 +346,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 int32_t const past_kv_len = *std::max_element(past_kv_len_list, past_kv_len_list + localNbSeq); EnqueueGenerationParams enqueue_params{attention_input, qkv_bias, sequence_length, past_kv_len, beamWidth, context_lengths, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, - context_buf_, key_value_cache, block_pointers, maxSeqLen, num_requests, max_blocks_per_sequence, - cache_indir, workspace}; + context_buf_, key_value_cache, block_pointers, max_kv_cache_length, cyclic_kv_cache_length, num_requests, + max_blocks_per_sequence, cache_indir, workspace}; if (isRelativePosition()) { enqueue_params.relative_attention_bias = static_cast(inputs[getRelativeAttentionBiasIdx()]); diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h index 45db32bfe..18f19885b 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h @@ -43,9 +43,10 @@ namespace tensorrt_llm::plugins // enable_remove_input_padding // 1. sequence_length [batch_size] // 2. host_past_key_value_lengths [batch_size] (int32) -// 3. context_lengths [batch_size] -// 4. cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch) -// 5. host_request_types [batch_size] int32. 0: context; 1: generation: 2: none. When not in inflight-batching +// 3. host_max_kv_cache_lengths [1] (int32) +// 4. context_lengths [batch_size] +// 5. cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch) +// 6. host_request_types [batch_size] int32. 0: context; 1: generation: 2: none. When not in inflight-batching // mode, // all elements must be identical. // 6. past_key_value_pool [batch_size, 2, local_num_kv_heads, max_seq_len, head_size] or @@ -145,46 +146,51 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon return 2; } - IndexType getContextLengthsIdx() const + IndexType getHostMaxKvCacheLengthIdx() const { return 3; } - IndexType getCacheIndirIdx() const + IndexType getContextLengthsIdx() const { return 4; } - IndexType getRequestTypesIdx() const + IndexType getCacheIndirIdx() const { return 5; } + IndexType getRequestTypesIdx() const + { + return 6; + } + IndexType getKVCacheBlockPointersIdx() const { // NOTE We either provide this tensor when mPagedKVCache is true or PastKeyValue otherwise - return 6; + return 7; } IndexType getPastKeyValueIdx() const { // NOTE We either provide this tensor when mPagedKVCache is false or KVCacheBlockPointers otherwise - return 6; + return 7; } IndexType getKVCacheQuantizationScaleIdx() const { - return 7; + return 8; } IndexType getKVCacheDequantizationScaleIdx() const { - return 8; + return 9; } IndexType getAlibiSlopesIdx() const { - return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7); + return (mKVCacheQuantMode.hasKvCacheQuant() ? 10 : 8); } IndexType getRelativeAttentionBiasIdx() const @@ -216,7 +222,7 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon IndexType getQKVBiasTensorIdx() const { TLLM_CHECK(mQKVBiasEnabled); - return (mKVCacheQuantMode.hasKvCacheQuant() ? 9 : 7) + (isALiBi() ? 1 : 0) + (mRemovePadding ? 1 : 0); + return (mKVCacheQuantMode.hasKvCacheQuant() ? 10 : 8) + (isALiBi() ? 1 : 0) + (mRemovePadding ? 1 : 0); } }; diff --git a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp index 3c1c15512..fcd331564 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp @@ -150,14 +150,32 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua = std::make_shared>(); } - mCudaKernelEnabled - = tensorrt_llm::kernels::isWeightOnlyBatchedGemvEnabled(tensorrt_llm::kernels::WeightOnlyQuantType::Int4b); } +#if defined(ENABLE_BF16) + else if (mType == nvinfer1::DataType::kBF16) + { + if (quant_algo & ZERO) + { + // has zeros + m_weightOnlyGroupwiseGemmRunner + = std::make_shared>(); + } + else + { + // no zeros + m_weightOnlyGroupwiseGemmRunner + = std::make_shared>(); + } + } +#endif else { TLLM_THROW("Unsupported data type"); } - + mCudaKernelEnabled + = tensorrt_llm::kernels::isWeightOnlyBatchedGemvEnabled(tensorrt_llm::kernels::WeightOnlyQuantType::Int4b); mPluginProfiler->setQuantAlgo(mQuantAlgo); mPluginProfiler->setGroupSize(mGroupSize); @@ -295,27 +313,52 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe if (mQuantAlgo & PRE_QUANT_SCALE) { // Apply pre-quant per channel scale on activations - tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher(reinterpret_cast(workspace), - reinterpret_cast(inputs[0]), reinterpret_cast(inputs[mPreQuantScaleInputIdx]), m, - k, stream); + if (mType == nvinfer1::DataType::kHALF) + { + tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher(reinterpret_cast(workspace), + reinterpret_cast(inputs[0]), reinterpret_cast(inputs[mPreQuantScaleInputIdx]), + m, k, stream); + } +#if defined(ENABLE_BF16) + else if (mType == nvinfer1::DataType::kBF16) + { + tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<__nv_bfloat16>( + reinterpret_cast<__nv_bfloat16*>(workspace), reinterpret_cast(inputs[0]), + reinterpret_cast(inputs[mPreQuantScaleInputIdx]), m, k, stream); + } +#endif } const half* zeros_ptr = (mQuantAlgo & ZERO) ? reinterpret_cast(inputs[mZerosInputIdx]) : nullptr; const half* biases_ptr = (mQuantAlgo & BIAS) ? reinterpret_cast(inputs[mBiasesInputIdx]) : nullptr; const half* act_ptr = reinterpret_cast((mQuantAlgo & PRE_QUANT_SCALE) ? workspace : inputs[0]); - TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF, "No valid weightOnlyGropwiseQuantMatmul configuration"); + TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF +#if defined(ENABLE_BF16) + || mType == nvinfer1::DataType::kBF16 +#endif + , + "No valid weightOnlyGropwiseQuantMatmul configuration"); + tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type; + int real_n = n * INT8_INT4_RATIO; + if (mType == nvinfer1::DataType::kHALF) + { + weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16; + } + else if (mType == nvinfer1::DataType::kBF16) + { + weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::BF16; + } if (m < SMALL_M_FAST_PATH && mCudaKernelEnabled) { // Use CUDA kernels for small batch size // The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass kernel // when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights. tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast(inputs[mWeightInputIdx]), - reinterpret_cast(inputs[mScalesInputIdx]), zeros_ptr, act_ptr, biases_ptr, - reinterpret_cast(outputs[0]), m, n * INT8_INT4_RATIO, k, mGroupSize}; - tensorrt_llm::kernels::weight_only_batched_gemv_launcher(tensorrt_llm::kernels::WeightOnlyQuantType::Int4b, - tensorrt_llm::kernels::WeightOnlyType::GroupWise, tensorrt_llm::kernels::WeightOnlyActivationType::Identity, - params, stream); + inputs[mScalesInputIdx], zeros_ptr, act_ptr, biases_ptr, outputs[0], m, real_n, k, mGroupSize, + tensorrt_llm::kernels::WeightOnlyQuantType::Int4b, tensorrt_llm::kernels::WeightOnlyType::GroupWise, + tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type}; + tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream); } else { @@ -326,9 +369,8 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe const auto& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); TLLM_CHECK_WITH_INFO(bestTactic, "No valid weight only groupwise GEMM tactic"); - m_weightOnlyGroupwiseGemmRunner->gemm(act_ptr, reinterpret_cast(weight_ptr), - reinterpret_cast(inputs[mScalesInputIdx]), zeros_ptr, biases_ptr, - reinterpret_cast(outputs[0]), m, n * INT8_INT4_RATIO, k, mGroupSize, *bestTactic, + m_weightOnlyGroupwiseGemmRunner->gemm(act_ptr, weight_ptr, inputs[mScalesInputIdx], zeros_ptr, biases_ptr, + outputs[0], m, real_n, k, mGroupSize, *bestTactic, reinterpret_cast(workspace) + m * k * sizeof(half), ws_bytes, stream); } diff --git a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp index 548f44941..122868347 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp @@ -108,17 +108,46 @@ void WeightOnlyQuantMatmulPlugin::init(nvinfer1::DataType type, WeightTypeId wei { mType = type; mWeightTypeId = weightTypeId; - if (mType == nvinfer1::DataType::kHALF && mWeightTypeId == WeightTypeId::INT8) + if (mWeightTypeId == WeightTypeId::INT8) { - m_weightOnlyGemmRunner = std::make_shared< - CutlassFpAIntBGemmRunner>(); + if (mType == nvinfer1::DataType::kHALF) + { + m_weightOnlyGemmRunner = std::make_shared< + CutlassFpAIntBGemmRunner>(); + } +#if defined(ENABLE_BF16) + else if (mType == nvinfer1::DataType::kBF16) + { + m_weightOnlyGemmRunner = std::make_shared< + CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>>(); + } +#endif + else + { + TLLM_CHECK(false); + } + mCudaKernelEnabled = tensorrt_llm::kernels::isWeightOnlyBatchedGemvEnabled(tensorrt_llm::kernels::WeightOnlyQuantType::Int8b); } - else if (mType == nvinfer1::DataType::kHALF && mWeightTypeId == WeightTypeId::INT4) + else if (mWeightTypeId == WeightTypeId::INT4) { - m_weightOnlyGemmRunner = std::make_shared< - CutlassFpAIntBGemmRunner>(); + if (mType == nvinfer1::DataType::kHALF) + { + m_weightOnlyGemmRunner = std::make_shared< + CutlassFpAIntBGemmRunner>(); + } +#if defined(ENABLE_BF16) + else if (mType == nvinfer1::DataType::kBF16) + { + m_weightOnlyGemmRunner = std::make_shared>(); + } +#endif + else + { + TLLM_CHECK(false); + } mCudaKernelEnabled = tensorrt_llm::kernels::isWeightOnlyBatchedGemvEnabled(tensorrt_llm::kernels::WeightOnlyQuantType::Int4b); } @@ -259,50 +288,49 @@ int WeightOnlyQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDesc* input const int ws_size = m_weightOnlyGemmRunner->getWorkspaceSize(m, n, k); const auto& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); TLLM_CHECK_WITH_INFO(bestTactic, "No valid weight only groupwise GEMM tactic"); - TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF, "No valid weightOnlyQuantMatmul configuration"); - if (mType == nvinfer1::DataType::kHALF && mWeightTypeId == WeightTypeId::INT8) + TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF || +#if defined(ENABLE_BF16) + mType == nvinfer1::DataType::kBF16 +#endif + , + "No valid weightOnlyQuantMatmul configuration"); + + tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type; + tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type; + int real_n; + if (mType == nvinfer1::DataType::kHALF) { - if (m < SMALL_M_FAST_PATH && mCudaKernelEnabled) - { - // Use CUDA kernels for small batch size - // The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass kernel - // when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights. - tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast(inputs[1]), - reinterpret_cast(inputs[2]), nullptr, reinterpret_cast(inputs[0]), nullptr, - reinterpret_cast(outputs[0]), m, n, k, 0}; - tensorrt_llm::kernels::weight_only_batched_gemv_launcher(tensorrt_llm::kernels::WeightOnlyQuantType::Int8b, - tensorrt_llm::kernels::WeightOnlyType::PerChannel, - tensorrt_llm::kernels::WeightOnlyActivationType::Identity, params, stream); - } - else - { - m_weightOnlyGemmRunner->gemm(reinterpret_cast(inputs[0]), - reinterpret_cast(inputs[1]), reinterpret_cast(inputs[2]), - reinterpret_cast(outputs[0]), m, n, k, *bestTactic, reinterpret_cast(workspace), ws_size, - stream); - } + weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16; } - else if (mType == nvinfer1::DataType::kHALF && mWeightTypeId == WeightTypeId::INT4) + else if (mType == nvinfer1::DataType::kBF16) { - if (m < SMALL_M_FAST_PATH && mCudaKernelEnabled) - { - // Use CUDA kernels for small batch size - // The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass kernel - // when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights. - tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast(inputs[1]), - reinterpret_cast(inputs[2]), nullptr, reinterpret_cast(inputs[0]), nullptr, - reinterpret_cast(outputs[0]), m, n * INT8_INT4_RATIO, k, 0}; - tensorrt_llm::kernels::weight_only_batched_gemv_launcher(tensorrt_llm::kernels::WeightOnlyQuantType::Int4b, - tensorrt_llm::kernels::WeightOnlyType::PerChannel, - tensorrt_llm::kernels::WeightOnlyActivationType::Identity, params, stream); - } - else - { - m_weightOnlyGemmRunner->gemm(reinterpret_cast(inputs[0]), - reinterpret_cast(inputs[1]), reinterpret_cast(inputs[2]), - reinterpret_cast(outputs[0]), m, n * INT8_INT4_RATIO, k, *bestTactic, - reinterpret_cast(workspace), ws_size, stream); - } + weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::BF16; + } + if (mWeightTypeId == WeightTypeId::INT8) + { + weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int8b; + real_n = n; + } + else if (mWeightTypeId == WeightTypeId::INT4) + { + weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int4b; + real_n = n * INT8_INT4_RATIO; + } + if (m < SMALL_M_FAST_PATH && mCudaKernelEnabled) + { + // Use CUDA kernels for small batch size + // The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass + // kernel when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights. + tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast(inputs[1]), inputs[2], nullptr, + inputs[0], nullptr, outputs[0], m, real_n, k, 0, weight_only_quant_type, + tensorrt_llm::kernels::WeightOnlyType::PerChannel, + tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type}; + tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream); + } + else + { + m_weightOnlyGemmRunner->gemm(inputs[0], inputs[1], inputs[2], outputs[0], m, real_n, k, *bestTactic, + reinterpret_cast(workspace), ws_size, stream); } return 0; diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index a1060d897..6485b4732 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -60,7 +60,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_readwrite("ids", &tpr::GenerationInput::ids) .def_readwrite("lengths", &tpr::GenerationInput::lengths) .def_readwrite("packed", &tpr::GenerationInput::packed) - .def_readwrite("embedding_bias", &tpr::GenerationInput::embeddingBiasOpt) + .def_readwrite("embedding_bias", &tpr::GenerationInput::embeddingBias) .def_readwrite("bad_words_list", &tpr::GenerationInput::badWordsList) .def_readwrite("stop_words_list", &tpr::GenerationInput::stopWordsList) .def_readwrite("max_new_tokens", &tpr::GenerationInput::maxNewTokens) @@ -75,9 +75,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_readwrite("context_logits", &tpr::GenerationOutput::contextLogits); py::class_(m, "KvCacheConfig") - .def(py::init, std::optional>(), py::arg("max_tokens") = py::none(), + .def(py::init, std::optional, std::optional>(), + py::arg("max_tokens") = py::none(), py::arg("max_kv_cache_length") = py::none(), py::arg("free_gpu_memory_fraction") = py::none()) .def_readwrite("max_tokens", &tb::kv_cache_manager::KvCacheConfig::maxTokens) + .def_readwrite("max_kv_cache_length", &tb::kv_cache_manager::KvCacheConfig::maxKvCacheLength) .def_readwrite("free_gpu_memory_fraction", &tb::kv_cache_manager::KvCacheConfig::freeGpuMemoryFraction); py::class_(m, "GptSessionConfig") diff --git a/cpp/tensorrt_llm/pybind/runtime/generationInput.cpp b/cpp/tensorrt_llm/pybind/runtime/generationInput.cpp index bef4ee167..5eb54f280 100644 --- a/cpp/tensorrt_llm/pybind/runtime/generationInput.cpp +++ b/cpp/tensorrt_llm/pybind/runtime/generationInput.cpp @@ -40,8 +40,8 @@ std::shared_ptr GenerationInput::toTrtLlm() const { auto input = std::make_shared( endId, padId, tr::TorchView::of(ids.value()), tr::TorchView::of(lengths.value()), packed); - if (embeddingBiasOpt) - input->embeddingBiasOpt = tr::TorchView::of(embeddingBiasOpt.value()); + if (embeddingBias) + input->embeddingBias = tr::TorchView::of(embeddingBias.value()); if (badWordsList) input->badWordsList = tr::TorchView::of(badWordsList.value()); if (stopWordsList) diff --git a/cpp/tensorrt_llm/runtime/gptDecoder.cpp b/cpp/tensorrt_llm/runtime/gptDecoder.cpp index cdb237935..5d8209c27 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoder.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoder.cpp @@ -90,8 +90,8 @@ typename tl::DynamicDecodeLayer::ForwardParams prepareInputs(DecodingInput co TLLM_CHECK(input.logits->getDataType() == TRTDataType::value); auto constexpr ite = 0; // no pipeline parallelism - typename tl::DynamicDecodeLayer::ForwardParams forwardParams{input.step, ite, input.maxLength, input.batchSize, - tcc::toTllmTensor(*input.logits), tcc::toTllmTensor(*input.endIds)}; + typename tl::DynamicDecodeLayer::ForwardParams forwardParams{input.step, ite, input.maxLength, + input.maxKvCacheLength, input.batchSize, tcc::toTllmTensor(*input.logits), tcc::toTllmTensor(*input.endIds)}; if (input.cacheIndirection) { diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp index ec2dc1b02..158bf05b2 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp @@ -83,7 +83,7 @@ GptDecoderBatch::GptDecoderBatch( auto& dInput = mJointDecodingInput; auto dummyLogits = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType); auto endIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType); - dInput = std::make_unique(0, 0, std::move(dummyLogits), std::move(endIds)); + dInput = std::make_unique(0, 0, 0, std::move(dummyLogits), std::move(endIds)); dInput->sequenceLimitLength = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); dInput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); @@ -104,8 +104,8 @@ GptDecoderBatch::GptDecoderBatch( TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); } -void GptDecoderBatch::setup( - SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, nvinfer1::DataType dtype) +void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, + SizeType maxSequenceLength, nvinfer1::DataType dtype) { TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(maxBatchSize > 0); @@ -114,6 +114,7 @@ void GptDecoderBatch::setup( mActualBatchSize = maxBatchSize; mMaxSequenceLength = maxSequenceLength; + mMaxKvCacheLength = maxKvCacheLength; auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize}); auto const maxBatchSizeXmaxBeamWidth = ITensor::makeShape({maxBatchSize, maxBeamWidth}); @@ -211,7 +212,8 @@ void GptDecoderBatch::newRequest( TensorPtr endIdTensorPtr{ITensor::slice(constPointerCast(dJointInput.endIds), batchIdx, localBatchSize)}; kernels::invokeFill(*endIdTensorPtr, endId, *stream); - dInput = std::make_unique(inputLength, localBatchSize, dJointInput.logits, endIdTensorPtr); + dInput = std::make_unique( + inputLength, mMaxKvCacheLength, localBatchSize, dJointInput.logits, endIdTensorPtr); // Here, we need to add leading 1 dimension since decoderInput expects batchSize as leading dim // and decoder_batch::Request doesn't have batch dimension @@ -458,17 +460,30 @@ void GptDecoderBatch::newBatch(GenerationInput const& inputs, SamplingConfig con } auto request = decoder_batch::Request{inputView, inputs.maxNewTokens, inputs.endId, inputs.padId}; - if (inputs.embeddingBiasOpt) + if (inputs.embeddingBias) { TLLM_THROW("newBatch doesn't support embeddingBias yet."); } if (inputs.badWordsList) { - TLLM_THROW("newBatch doesn't support badWordsList yet."); + auto const& shape = inputs.badWordsList->getShape(); + if (shape.nbDims == 2) + { + request.badWordsList = inputs.badWordsList; + } + else + { + assert(shape.nbDims == 3); + TensorPtr badWordsListView = ITensor::slice(inputs.badWordsList, batchIdx, 1); + badWordsListView->squeeze(0); + request.badWordsList = badWordsListView; + } } if (inputs.stopWordsList) { - TLLM_THROW("newBatch doesn't support stopWordsList yet."); + TensorPtr stopWordsListView = ITensor::slice(inputs.stopWordsList, batchIdx, 1); + stopWordsListView->squeeze(0); + request.stopWordsList = stopWordsListView; } newRequest(batchIdx, request, extractSamplingConfig(samplingConfig, batchIdx)); } diff --git a/cpp/tensorrt_llm/runtime/gptSession.cpp b/cpp/tensorrt_llm/runtime/gptSession.cpp index 68c70d377..e79cbad52 100644 --- a/cpp/tensorrt_llm/runtime/gptSession.cpp +++ b/cpp/tensorrt_llm/runtime/gptSession.cpp @@ -117,8 +117,8 @@ void GptSession::createBuffers(SizeType numMicroBatches) TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); } -void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength, - nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches) +void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength, + SizeType maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches) { TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); auto const vocabSize = mModelConfig.getVocabSize(); @@ -133,14 +133,14 @@ void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType mDecoders.emplace_back(std::make_shared(vocabSize, vocabSizePadded, stream)); else mDecoders.emplace_back(std::make_shared(vocabSize, vocabSizePadded, stream)); - mDecoders.back()->setup(batchSize, beamWidth, maxSequenceLength, logitsType); + mDecoders.back()->setup(batchSize, beamWidth, maxKvCacheLength, maxSequenceLength, logitsType); } TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); } -void GptSession::createKvCacheManager( - SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength, KvCacheConfig const& config) +void GptSession::createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength, + SizeType maxSequenceLength, KvCacheConfig const& config) { TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); auto const localNbLayers = mModelConfig.getNbLayers(mWorldConfig.getPipelineParallelism()); @@ -168,8 +168,9 @@ void GptSession::createKvCacheManager( auto const maxNumBlocks = tc::ceilDiv(maxNumTokens, tokensPerBlock); auto const maxBlocksPerSeq = tc::ceilDiv(maxSequenceLength, tokensPerBlock); - mKvCacheManager = std::make_shared(localNbLayers, nbHeads, nbKvHeads, hiddenSize, - tokensPerBlock, maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, kvDtype, mRuntime->getStreamPtr()); + mKvCacheManager + = std::make_shared(localNbLayers, nbHeads, nbKvHeads, hiddenSize, tokensPerBlock, + maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, maxKvCacheLength, kvDtype, mRuntime->getStreamPtr()); TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); } @@ -232,6 +233,9 @@ void GptSession::setup(Config const& sessionConfig) auto const maxBatchSize = sessionConfig.maxBatchSize; auto const maxBeamWidth = sessionConfig.maxBeamWidth; auto const maxSequenceLength = sessionConfig.maxSequenceLength; + auto const maxKvCacheLength = sessionConfig.kvCacheConfig.maxKvCacheLength.has_value() + ? std::min(sessionConfig.kvCacheConfig.maxKvCacheLength.value(), maxSequenceLength) + : maxSequenceLength; mMicroBatchConfig = MicroBatchConfig(maxBatchSize, mWorldConfig.getPipelineParallelism(), sessionConfig.genMicroBatchSize, sessionConfig.ctxMicroBatchSize); @@ -244,16 +248,18 @@ void GptSession::setup(Config const& sessionConfig) // gptDecoderBatch does not resize buffers, but allows smaller batchSize and beamWidth. // TODO refactor batch manager to remove dependency on maxSequenceLength. mDecoderMaxSequenceLength = maxSequenceLength; + mDecoderMaxKvCacheLength = maxKvCacheLength; if (mModelConfig.usePagedKvCache()) { - createKvCacheManager(maxBatchSize, maxBeamWidth, maxSequenceLength, sessionConfig.kvCacheConfig); + createKvCacheManager( + maxBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength, sessionConfig.kvCacheConfig); } if (mWorldConfig.isLastPipelineParallelRank()) { auto const logitsType = mRuntime->getEngine().getTensorDataType("logits"); - createDecoders(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxSequenceLength, logitsType, + createDecoders(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength, logitsType, sessionConfig.decoderPerRequest, mMicroBatchConfig.numGenBatches); } @@ -272,8 +278,8 @@ void GptSession::setup(Config const& sessionConfig) for (auto& buffers : mBuffers) { // we don't know maxInputLength yet and ignore it for pre-allocation - buffers->generationConfig - = RuntimeBuffers::GenerationConfig{mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxSequenceLength}; + buffers->generationConfig = RuntimeBuffers::GenerationConfig{ + mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxKvCacheLength, maxSequenceLength}; buffers->reshape(mModelConfig, mWorldConfig); } TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); @@ -403,8 +409,9 @@ std::vector splitInputs(GenerationInput const& inputs, SizeType auto const offset = microBatchOffsets[batchId]; auto const batchSize = microBatchOffsets[batchId + 1] - offset; - if (inputs.embeddingBiasOpt) - batch.embeddingBiasOpt = inputs.embeddingBiasOpt; + if (inputs.embeddingBias) + batch.embeddingBias = inputs.embeddingBias; + if (inputs.badWordsList) { auto const& shape = inputs.badWordsList->getShape(); @@ -414,7 +421,7 @@ std::vector splitInputs(GenerationInput const& inputs, SizeType } else { - assert(nbDims == 3); + assert(shape.nbDims == 3); batch.badWordsList = ITensor::slice(inputs.badWordsList, offset, batchSize); } } @@ -524,7 +531,7 @@ void GptSession::generateBatched( auto const& microBatchInputs = microBatches.at(microBatchId); auto& buffers = *mBuffers.at(microBatchId); buffers.initFromInput(*microBatchInputs.ids, microBatchInputs.lengths, microBatchInputs.packed, beamWidth, - mDecoderMaxSequenceLength, manager); + mDecoderMaxKvCacheLength, mDecoderMaxSequenceLength, manager); buffers.reshape(mModelConfig, mWorldConfig); buffers.reset(manager); } diff --git a/cpp/tensorrt_llm/runtime/iTensor.cpp b/cpp/tensorrt_llm/runtime/iTensor.cpp index 42b57f7de..46b8eea14 100644 --- a/cpp/tensorrt_llm/runtime/iTensor.cpp +++ b/cpp/tensorrt_llm/runtime/iTensor.cpp @@ -22,7 +22,6 @@ #include "tensorrt_llm/runtime/tensorView.h" #include -#include #include using namespace tensorrt_llm::runtime; @@ -44,11 +43,8 @@ nvinfer1::Dims ITensor::makeShape(std::initializer_list const& dims) { TLLM_CHECK_WITH_INFO(dims.size() <= nvinfer1::Dims::MAX_DIMS, "Number of dimensions is too large"); nvinfer1::Dims shape{}; - shape.nbDims = dims.size(); - for (std::size_t i = 0; i < dims.size(); ++i) - { - shape.d[i] = std::data(dims)[i]; - } + shape.nbDims = static_cast(dims.size()); + std::copy(dims.begin(), dims.end(), shape.d); return shape; } @@ -97,6 +93,32 @@ ITensor::UniquePtr ITensor::wrap(void* data, nvinfer1::DataType type, nvinfer1:: return result; } +ITensor::Shape ITensor::squeeze(Shape const& shape, SizeType dim) +{ + TLLM_CHECK_WITH_INFO(0 < shape.nbDims, "Cannot squeeze 1-dimensional tensor"); + TLLM_CHECK_WITH_INFO( + dim < shape.nbDims, tc::fmtstr("Invalid index %d, tensor has %d dimensions", dim, shape.nbDims)); + TLLM_CHECK_WITH_INFO(shape.d[dim] == 1, "Can only squeeze dimension of size 1"); + + Shape newDims{shape.nbDims - 1}; + std::copy(shape.d, shape.d + dim, newDims.d); + std::copy(shape.d + dim + 1, shape.d + shape.nbDims, newDims.d + dim); + return newDims; +} + +ITensor::Shape ITensor::unsqueeze(Shape const& shape, SizeType dim) +{ + TLLM_CHECK_WITH_INFO(shape.nbDims < Shape::MAX_DIMS, "Too many dimensions to unsqueeze"); + TLLM_CHECK_WITH_INFO( + 0 <= dim && dim <= shape.nbDims, common::fmtstr("Invalid dim %d, tensor has %d dimensions", dim, shape.nbDims)); + + Shape newDims{shape.nbDims + 1}; + std::copy(shape.d, shape.d + dim, newDims.d); + newDims.d[dim] = 1; + std::copy(shape.d + dim, shape.d + shape.nbDims, newDims.d + dim + 1); + return newDims; +} + namespace { template diff --git a/cpp/tensorrt_llm/runtime/promptTuningParams.cpp b/cpp/tensorrt_llm/runtime/promptTuningParams.cpp index 60074ad02..e78115fcf 100644 --- a/cpp/tensorrt_llm/runtime/promptTuningParams.cpp +++ b/cpp/tensorrt_llm/runtime/promptTuningParams.cpp @@ -21,7 +21,7 @@ namespace tensorrt_llm::runtime void PromptTuningParams::fillTasksTensor(TensorPtr tasksHost, const SizeType batchSize, const SizeType numContextRequests, const std::vector& reqBeamWidths, - const std::vector& reqPromptLengths, BufferManager& manager, bool packedInput) + const std::vector& reqPromptLengths, BufferManager const& manager, bool packedInput) { auto const& tasksHostShape = tasksHost->getShape(); TLLM_CHECK_WITH_INFO(tasksHostShape.nbDims == 1, "tasksHost expected to have dimension [batchSize]"); diff --git a/cpp/tensorrt_llm/runtime/runtimeBuffers.cpp b/cpp/tensorrt_llm/runtime/runtimeBuffers.cpp index 6d7ab16d7..4aed18c02 100644 --- a/cpp/tensorrt_llm/runtime/runtimeBuffers.cpp +++ b/cpp/tensorrt_llm/runtime/runtimeBuffers.cpp @@ -29,7 +29,8 @@ using namespace tensorrt_llm::runtime; namespace tc = tensorrt_llm::common; RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITensor const& inputIds, - ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth, SizeType const maxSequenceLength) + ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth, SizeType const maxKvCacheLength, + SizeType const maxSequenceLength) { TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); auto const batchSize = static_cast(inputLengthsHost.getSize()); @@ -57,7 +58,7 @@ RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITe "generated."); TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); - return GenerationConfig{batchSize, beamWidth, maxInputLength, maxSequenceLength, inputLengthSum}; + return GenerationConfig{batchSize, beamWidth, maxInputLength, maxKvCacheLength, maxSequenceLength, inputLengthSum}; } void RuntimeBuffers::clear() @@ -154,6 +155,10 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon if (modelConfig.useGptAttentionPlugin()) { pastKeyValueLengths = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32); + for (SizeType i = 0; i < modelConfig.getNbLayers(); ++i) + { + maxKvCacheLengths.emplace_back(manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32)); + } } else { @@ -179,7 +184,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon } void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked, - SizeType beamWidth, SizeType maxSequenceLength, BufferManager& manager) + SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength, BufferManager& manager) { contextLengthsDevice = inputLengths; contextLengthsHost->reshape(inputLengths->getShape()); @@ -187,7 +192,7 @@ void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inp manager.getStream().synchronize(); // wait for context lengths to be copied to host generationConfig = RuntimeBuffers::GenerationConfig::fromInput( - inputIds, *contextLengthsHost, inputPacked, beamWidth, maxSequenceLength); + inputIds, *contextLengthsHost, inputPacked, beamWidth, maxKvCacheLength, maxSequenceLength); } void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig) @@ -197,7 +202,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons auto const batchSize = generationConfig.batchSize; auto const beamWidth = generationConfig.beamWidth; auto const maxInputLength = generationConfig.maxInputLength; - auto const maxSeqLength = generationConfig.maxSeqLength; + auto const maxKvCacheLength = generationConfig.maxKvCacheLength; if (worldConfig.isLastPipelineParallelRank() && !modelConfig.computeContextLogits()) { @@ -207,15 +212,15 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons lastTokenIds->reshape(ITensor::makeShape({batchSize})); - auto kvCacheReserve - = ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxSeqLength, modelConfig.getSizePerHead()}); + auto kvCacheReserve = ITensor::makeShape( + {batchSize, 2, modelConfig.getNbKvHeads(), maxKvCacheLength, modelConfig.getSizePerHead()}); auto kvCacheShape = ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxInputLength, modelConfig.getSizePerHead()}); if (modelConfig.usePagedKvCache()) { auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism()); auto const tokensPerBlock = modelConfig.getTokensPerBlock(); - auto const maxBlocksPerSeq = (maxSeqLength + tokensPerBlock - 1) / tokensPerBlock; + auto const maxBlocksPerSeq = (maxKvCacheLength + tokensPerBlock - 1) / tokensPerBlock; // reserve batchSize * beamWidth and resize to batchSize auto cacheBlockPointersShape = ITensor::makeShape({localNbLayers, batchSize * beamWidth, 2, maxBlocksPerSeq}); @@ -233,6 +238,10 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons if (modelConfig.useGptAttentionPlugin()) { pastKeyValueLengths->reshape(ITensor::makeShape({batchSize})); + for (SizeType i = 0; i < modelConfig.getNbLayers(); ++i) + { + maxKvCacheLengths[i]->reshape(ITensor::makeShape({1})); + } requestTypes->reshape(ITensor::makeShape({batchSize})); } else @@ -243,7 +252,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons utils::reshapeBufferVector(presentKeysVals, kvCacheShape); } - auto const cacheIndirShape = ITensor::makeShape({batchSize, beamWidth, maxSeqLength}); + auto const cacheIndirShape = ITensor::makeShape({batchSize, beamWidth, maxKvCacheLength}); cacheIndirectionDecoderInput->reshape(cacheIndirShape); cacheIndirectionDecoderOutput->reshape(cacheIndirShape); @@ -327,6 +336,7 @@ std::vector RuntimeBuffers::split( if (modelConfig.useGptAttentionPlugin()) { buffers.pastKeyValueLengths = ITensor::slice(pastKeyValueLengths, offset, batchSize); + buffers.maxKvCacheLengths = maxKvCacheLengths; buffers.requestTypes = ITensor::slice(requestTypes, offset, batchSize); } else @@ -523,6 +533,12 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c TLLM_CHECK(requestTypes->getSize() == static_cast(batchSize)); std::fill_n(RequestTypesPtr, batchSize, 0); + // Set maxKvCacheLengths buffer to the same value currently. + for (auto layer = 0; layer < modelConfig.getNbLayers(); ++layer) + { + bufferCast(*maxKvCacheLengths[layer])[0] = generationConfig.maxKvCacheLength; + } + auto const& inputShape = inputIds->getShape(); auto const contextLengthsHostPtr = bufferCast(*contextLengthsHost); auto const modelVariant = modelConfig.getModelVariant(); @@ -788,6 +804,12 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu inputBuffers.insert_or_assign("host_request_types", requestTypes); inputBuffers.insert_or_assign("sequence_length", sequenceLengths); + for (SizeType i = 0; i < modelConfig.getNbLayers(); ++i) + { + std::string name = "host_max_kv_cache_length_" + std::to_string(i); + inputBuffers.insert_or_assign(name, maxKvCacheLengths[i]); + } + if (modelConfig.usePackedInput()) { inputBuffers.insert_or_assign("host_context_lengths", contextLengthsHost); diff --git a/cpp/tensorrt_llm/runtime/runtimeBuffers.h b/cpp/tensorrt_llm/runtime/runtimeBuffers.h index 72b59ef36..efa669328 100644 --- a/cpp/tensorrt_llm/runtime/runtimeBuffers.h +++ b/cpp/tensorrt_llm/runtime/runtimeBuffers.h @@ -49,10 +49,11 @@ class RuntimeBuffers GenerationConfig() = default; explicit GenerationConfig(SizeType batchSize, SizeType beamWidth, SizeType maxInputLength, - SizeType maxSeqLength, SizeType inputLengthSum = SizeType(0)) + SizeType maxKvCacheLength, SizeType maxSeqLength, SizeType inputLengthSum = SizeType(0)) : batchSize{batchSize} , beamWidth{beamWidth} , maxInputLength{maxInputLength} + , maxKvCacheLength{maxKvCacheLength} , maxSeqLength{maxSeqLength} , inputLengthSum{inputLengthSum} { @@ -61,11 +62,12 @@ class RuntimeBuffers SizeType batchSize{}; SizeType beamWidth{}; SizeType maxInputLength{}; + SizeType maxKvCacheLength{}; SizeType maxSeqLength{}; SizeType inputLengthSum{}; // Initialized only if inputPacked is set to true in fromInput. static GenerationConfig fromInput(ITensor const& inputIds, ITensor const& inputLengths, bool inputPacked, - SizeType beamWidth, SizeType maxSequenceLength); + SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength); }; public: @@ -88,6 +90,7 @@ class RuntimeBuffers std::vector presentKeysVals; std::vector presentKeysValsAlt; // without attention plugin + std::vector maxKvCacheLengths; // with attention plugin, host tensor TensorPtr kvCacheBlockPointersHost; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2] TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2] @@ -119,7 +122,7 @@ class RuntimeBuffers void create(TllmRuntime& runtime, GptModelConfig const& modelConfig, WorldConfig const& worldConfig); void initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked, SizeType beamWidth, - SizeType maxSequenceLength, BufferManager& manager); + SizeType maxKvCacheLength, SizeType maxSequenceLength, BufferManager& manager); //! \brief Reshape buffers based on current GenerationConfig void reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig); diff --git a/cpp/tensorrt_llm/runtime/statefulGptDecoder.cpp b/cpp/tensorrt_llm/runtime/statefulGptDecoder.cpp index fb438ab93..60b00963d 100644 --- a/cpp/tensorrt_llm/runtime/statefulGptDecoder.cpp +++ b/cpp/tensorrt_llm/runtime/statefulGptDecoder.cpp @@ -41,7 +41,7 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS auto& dInput = mDecodingInput; auto dummyLogits = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType); auto endIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType); - dInput = std::make_unique(0, 0, std::move(dummyLogits), std::move(endIds)); + dInput = std::make_unique(0, 0, 0, std::move(dummyLogits), std::move(endIds)); dInput->sequenceLimitLength = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); dInput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); @@ -61,17 +61,18 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); } -void StatefulGptDecoder::setup( - SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, nvinfer1::DataType dtype) +void StatefulGptDecoder::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, + SizeType maxSequenceLength, nvinfer1::DataType dtype) { TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); mDecoder = IGptDecoder::create(dtype, mVocabSize, mVocabSizePadded, mStream); - reshapeBuffers(maxBatchSize, maxBeamWidth, maxSequenceLength); + reshapeBuffers(maxBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength); TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); } -void StatefulGptDecoder::reshapeBuffers(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength) +void StatefulGptDecoder::reshapeBuffers( + SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength) { TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(batchSize > 0); @@ -79,6 +80,7 @@ void StatefulGptDecoder::reshapeBuffers(SizeType batchSize, SizeType beamWidth, TLLM_CHECK(maxSequenceLength > 0); mMaxSequenceLength = maxSequenceLength; + mMaxKvCacheLength = maxKvCacheLength; auto const batchSizeShape = ITensor::makeShape({batchSize}); auto const batchSizeXbeamWidth = ITensor::makeShape({batchSize, beamWidth}); @@ -129,7 +131,7 @@ void StatefulGptDecoder::newBatch(GenerationInput const& inputs, SamplingConfig auto const batchSize = inputLengthsShape.d[0]; auto const beamWidth = samplingConfig.beamWidth; - reshapeBuffers(batchSize, beamWidth, mMaxSequenceLength); + reshapeBuffers(batchSize, beamWidth, mMaxKvCacheLength, mMaxSequenceLength); mDecoder->setup(samplingConfig, batchSize); // sanity checks, should always be true after reshape @@ -159,9 +161,10 @@ void StatefulGptDecoder::newBatch(GenerationInput const& inputs, SamplingConfig // inputs auto& dInput = *mDecodingInput; dInput.maxLength = maxInputLength; + dInput.maxKvCacheLength = mMaxKvCacheLength; dInput.batchSize = batchSize; kernels::invokeFill(const_cast(*dInput.endIds), endId, *stream); - dInput.embeddingBias = inputs.embeddingBiasOpt; + dInput.embeddingBias = inputs.embeddingBias; dInput.badWordsList = inputs.badWordsList; dInput.stopWordsList = inputs.stopWordsList; auto inputLengthsView = ITensor::view(dInput.lengths, ITensor::makeShape({batchSize * beamWidth})); diff --git a/cpp/tensorrt_llm/runtime/statefulGptDecoder.h b/cpp/tensorrt_llm/runtime/statefulGptDecoder.h index 0276518ce..5244dceb3 100644 --- a/cpp/tensorrt_llm/runtime/statefulGptDecoder.h +++ b/cpp/tensorrt_llm/runtime/statefulGptDecoder.h @@ -39,8 +39,8 @@ class StatefulGptDecoder : public IStatefulGptDecoder StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream); //! Setup the decoder before calling `forward()` - void setup( - SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, nvinfer1::DataType dtype) override; + void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength, + nvinfer1::DataType dtype) override; //! @brief Initialize the decoder with new batch of inputs. void newBatch(GenerationInput const& input, SamplingConfig const& samplingConfig) override; @@ -72,7 +72,7 @@ class StatefulGptDecoder : public IStatefulGptDecoder } private: - void reshapeBuffers(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength); + void reshapeBuffers(SizeType batchSize, SizeType beamWidth, SizeType mMaxKvCacheLength, SizeType maxSequenceLength); private: std::size_t const mVocabSize; @@ -90,5 +90,6 @@ class StatefulGptDecoder : public IStatefulGptDecoder SizeType mNbSteps; SizeType mMaxSequenceLength{}; + SizeType mMaxKvCacheLength{}; }; } // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/runtime/tensorView.h b/cpp/tensorrt_llm/runtime/tensorView.h index a47333b9e..e16efe992 100644 --- a/cpp/tensorrt_llm/runtime/tensorView.h +++ b/cpp/tensorrt_llm/runtime/tensorView.h @@ -64,15 +64,7 @@ class TensorView : virtual public ITensor, public BufferView void resize(std::size_t newSize) override { - if (newSize != getSize()) - { - using dimType = std::remove_reference_t; - auto constexpr max_size = std::numeric_limits::max(); - TLLM_CHECK_WITH_INFO(newSize <= max_size, "New size is too large. Use reshape() instead."); - Base::resize(newSize); - mDims.nbDims = 1; - mDims.d[0] = static_cast(newSize); - } + ITensor::resize(newSize); } void release() override diff --git a/cpp/tensorrt_llm/runtime/tllmBuffers.h b/cpp/tensorrt_llm/runtime/tllmBuffers.h index f274928ee..795d2f3e5 100644 --- a/cpp/tensorrt_llm/runtime/tllmBuffers.h +++ b/cpp/tensorrt_llm/runtime/tllmBuffers.h @@ -101,7 +101,7 @@ class CudaAllocatorAsync : public BaseAllocator(mCudaStream), "Undefined CUDA stream"); } - CudaStreamPtr getCudaStream() const + [[nodiscard]] CudaStreamPtr getCudaStream() const { return mCudaStream; } @@ -236,13 +236,14 @@ class GenericBuffer : virtual public IBuffer //! //! \brief Construct an empty buffer. //! - explicit GenericBuffer(nvinfer1::DataType type, TAllocator allocator = {}) + explicit GenericBuffer(nvinfer1::DataType type, TAllocator allocator = {}) // NOLINT(*-pro-type-member-init) : GenericBuffer{0, type, std::move(allocator)} {}; //! //! \brief Construct a buffer with the specified allocation size in number of elements. //! - explicit GenericBuffer(std::size_t size, nvinfer1::DataType type, TAllocator allocator = {}) + explicit GenericBuffer( // NOLINT(*-pro-type-member-init) + std::size_t size, nvinfer1::DataType type, TAllocator allocator = {}) : GenericBuffer{size, size, type, std::move(allocator)} {}; GenericBuffer(GenericBuffer&& buf) noexcept @@ -280,21 +281,21 @@ class GenericBuffer : virtual public IBuffer //! void* data() override { - return mBuffer; + return TLLM_LIKELY(mSize > 0) ? mBuffer : nullptr; } //! //! \brief Returns pointer to underlying array. //! - const void* data() const override + [[nodiscard]] void const* data() const override { - return mBuffer; + return TLLM_LIKELY(mSize > 0) ? mBuffer : nullptr; } //! //! \brief Returns the size (in number of elements) of the buffer. //! - std::size_t getSize() const override + [[nodiscard]] std::size_t getSize() const override { return mSize; } @@ -302,7 +303,7 @@ class GenericBuffer : virtual public IBuffer //! //! \brief Returns the capacity of the buffer. //! - std::size_t getCapacity() const override + [[nodiscard]] std::size_t getCapacity() const override { return mCapacity; } @@ -310,7 +311,7 @@ class GenericBuffer : virtual public IBuffer //! //! \brief Returns the type of the buffer. //! - nvinfer1::DataType getDataType() const override + [[nodiscard]] nvinfer1::DataType getDataType() const override { return mType; } @@ -318,7 +319,7 @@ class GenericBuffer : virtual public IBuffer //! //! \brief Returns the memory type of the buffer. //! - MemoryType getMemoryType() const override + [[nodiscard]] MemoryType getMemoryType() const override { return mAllocator.getMemoryType(); } @@ -328,11 +329,7 @@ class GenericBuffer : virtual public IBuffer //! void resize(std::size_t newSize) override { - if (newSize == 0) - { - release(); - } - else if (mCapacity < newSize) + if (mCapacity < newSize) { mAllocator.deallocate(mBuffer, toBytes(mCapacity)); mBuffer = mAllocator.allocate(toBytes(newSize)); @@ -444,7 +441,7 @@ class GenericTensor : virtual public ITensor, public GenericBuffer return *this; } - nvinfer1::Dims const& getShape() const override + [[nodiscard]] nvinfer1::Dims const& getShape() const override { return mDims; } @@ -457,15 +454,7 @@ class GenericTensor : virtual public ITensor, public GenericBuffer void resize(std::size_t newSize) override { - if (newSize != getSize()) - { - using dimType = std::remove_reference_t; - auto constexpr max_size = std::numeric_limits::max(); - TLLM_CHECK_WITH_INFO(newSize <= max_size, "New size is too large. Use reshape() instead."); - Base::resize(newSize); - mDims.nbDims = 1; - mDims.d[0] = static_cast(newSize); - } + ITensor::resize(newSize); } void release() override diff --git a/cpp/tensorrt_llm/runtime/torchView.h b/cpp/tensorrt_llm/runtime/torchView.h index 93d6cdbf4..682dd646e 100644 --- a/cpp/tensorrt_llm/runtime/torchView.h +++ b/cpp/tensorrt_llm/runtime/torchView.h @@ -42,16 +42,12 @@ class TorchView : virtual public ITensor void* data() override { - if (getSize() == 0) - return nullptr; - return mTensor.data_ptr(); + return TLLM_LIKELY(getSize() > 0) ? mTensor.data_ptr() : nullptr; } [[nodiscard]] void const* data() const override { - if (getSize() == 0) - return nullptr; - return mTensor.data_ptr(); + return TLLM_LIKELY(getSize() > 0) ? mTensor.data_ptr() : nullptr; } [[nodiscard]] size_t getSize() const override @@ -76,17 +72,7 @@ class TorchView : virtual public ITensor void resize(std::size_t newSize) override { - TLLM_CHECK(newSize <= getCapacity()); - - if (newSize != getSize()) - { - using dimType = std::remove_reference_t; - auto constexpr max_size = std::numeric_limits::max(); - TLLM_CHECK_WITH_INFO(newSize <= max_size, "New size is too large. Use reshape() instead."); - mTensor.resize_({static_cast(newSize)}); - mDims.nbDims = 1; - mDims.d[0] = static_cast(newSize); - } + ITensor::resize(newSize); } void release() override diff --git a/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp b/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp index be7938849..8598a49ac 100644 --- a/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp +++ b/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp @@ -138,7 +138,7 @@ void FtDynamicDecode::setup(size_t batch_size, size_t beam_width, th::optiona template void FtDynamicDecode::forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size) - int step, int max_input_length, uint64_t ite, int local_batch_size, th::Tensor end_id, + int step, int max_input_length, int max_kv_cache_length, uint64_t ite, int local_batch_size, th::Tensor end_id, th::optional embedding_bias_opt, th::optional input_lengths_opt, th::optional sequence_limit_length_opt, th::optional stop_words_list_opt, th::optional bad_words_list_opt, th::optional no_repeat_ngram_size_opt, @@ -156,8 +156,8 @@ void FtDynamicDecode::forward(th::Tensor& logits, // (batch_size, beam_width, { auto const& logits_converted = convert_tensor(logits); auto const& end_ids_converted = convert_tensor(end_id); - typename tensorrt_llm::layers::DynamicDecodeLayer::ForwardParams forwardParams{ - step, static_cast(ite), max_input_length, local_batch_size, logits_converted, end_ids_converted}; + typename tensorrt_llm::layers::DynamicDecodeLayer::ForwardParams forwardParams{step, static_cast(ite), + max_input_length, max_kv_cache_length, local_batch_size, logits_converted, end_ids_converted}; safeUpdate(src_cache_indirection_opt, forwardParams.src_cache_indirection); safeUpdate(sequence_limit_length_opt, forwardParams.sequence_limit_length); @@ -272,8 +272,9 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional top_p_reset_ids_opt); } -th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max_input_length, int64_t ite, - int64_t local_batch_size, th::Tensor end_id, th::optional embedding_bias_opt, +th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max_input_length, + int64_t max_kv_cache_length, int64_t ite, int64_t local_batch_size, th::Tensor end_id, + th::optional embedding_bias_opt, th::optional input_lengths_opt, // length of input contexts. th::optional sequence_limit_length_opt, th::optional stop_words_list_opt, th::optional bad_words_list_opt, th::optional no_repeat_ngram_size_opt, @@ -339,9 +340,10 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max dynamic_decode_->forward( // Inputs - logits, static_cast(step), static_cast(max_input_length), static_cast(ite), - static_cast(local_batch_size), end_id, embedding_bias_opt, input_lengths_opt, sequence_limit_length_opt, - stop_words_list_opt, bad_words_list_opt, no_repeat_ngram_size_opt, src_cache_indirection_opt, + logits, static_cast(step), static_cast(max_input_length), static_cast(max_kv_cache_length), + static_cast(ite), static_cast(local_batch_size), end_id, embedding_bias_opt, input_lengths_opt, + sequence_limit_length_opt, stop_words_list_opt, bad_words_list_opt, no_repeat_ngram_size_opt, + src_cache_indirection_opt, // Outputs output_token_ids, newTokens, should_stop, finished_opt, seuqence_lengths_opt, cum_log_probs_opt, output_log_probs_opt, parent_ids_opt, tgt_cache_indirection_opt, beam_hyps_output_ids_tgt_opt, diff --git a/cpp/tensorrt_llm/thop/dynamicDecodeOp.h b/cpp/tensorrt_llm/thop/dynamicDecodeOp.h index 47db12397..477f52897 100644 --- a/cpp/tensorrt_llm/thop/dynamicDecodeOp.h +++ b/cpp/tensorrt_llm/thop/dynamicDecodeOp.h @@ -39,7 +39,7 @@ class IFtDynamicDecode = 0; virtual void forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size) - int step, int max_input_length, uint64_t ite, int local_batch_size, th::Tensor end_id, + int step, int max_input_length, int max_kv_cache_length, uint64_t ite, int local_batch_size, th::Tensor end_id, th::optional embedding_bias_opt, th::optional input_lengths_opt, th::optional sequence_limit_length_opt, th::optional stop_words_list_opt, th::optional bad_words_list_opt, th::optional no_repeat_ngram_size_opt, @@ -77,7 +77,7 @@ class FtDynamicDecode : public IFtDynamicDecode th::optional top_p_reset_ids_opt) override; void forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size) - int step, int max_input_length, uint64_t ite, int local_batch_size, th::Tensor end_id, + int step, int max_input_length, int max_kv_cache_length, uint64_t ite, int local_batch_size, th::Tensor end_id, th::optional embedding_bias_opt, th::optional input_lengths_opt, th::optional sequence_limit_length_opt, th::optional stop_words_list_opt, th::optional bad_words_list_opt, th::optional no_repeat_ngram_size_opt, @@ -121,8 +121,8 @@ class DynamicDecodeOp : public th::jit::CustomClassHolder th::optional top_p_reset_ids_opt); th::Tensor forward(th::Tensor logits, // (batch_size, beam_width, vocab_size) - int64_t step, int64_t max_input_length, int64_t ite, int64_t local_batch_size, th::Tensor end_id, - th::optional embedding_bias_opt, + int64_t step, int64_t max_input_length, int64_t max_kv_cache_length, int64_t ite, int64_t local_batch_size, + th::Tensor end_id, th::optional embedding_bias_opt, th::optional input_lengths_opt, // length of input contexts. th::optional sequence_limit_length_opt, th::optional stop_words_list_opt, th::optional bad_words_list_opt, th::optional no_repeat_ngram_size_opt, diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index a3fa915c8..07b64c81b 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -75,7 +75,10 @@ add_gtest(bufferManagerTest runtime/bufferManagerTest.cpp) add_gtest(runtimeKernelTest runtime/runtimeKernelTest.cpp) add_gtest(samplingTest runtime/samplingTest.cpp) add_gtest(iTensorTest runtime/iTensorTest.cpp) -add_gtest(torchTest runtime/torchTest.cpp) +if(${BUILD_PYT}) + add_gtest(torchTest runtime/torchTest.cpp) + target_link_libraries(torchTest PUBLIC ${TORCH_LIBRARIES}) +endif() set(SAMPLING_KERNEL_TEST_SRC kernels/sampling/samplingTest.cpp kernels/sampling/samplingTopKTest.cpp @@ -83,7 +86,7 @@ set(SAMPLING_KERNEL_TEST_SRC kernels/sampling/samplingPenaltyTest.cpp kernels/sampling/samplingUtilsTest.cu) add_gtest(samplingKernelsTest "${SAMPLING_KERNEL_TEST_SRC}") -target_link_libraries(torchTest PUBLIC ${TORCH_LIBRARIES}) +add_gtest(weightOnlyKernelTest kernels/weightOnly/weightOnlyKernelTest.cpp) if(BUILD_BATCH_MANAGER) add_subdirectory(batch_manager) diff --git a/cpp/tests/kernels/weightOnly/weightOnlyKernelTest.cpp b/cpp/tests/kernels/weightOnly/weightOnlyKernelTest.cpp new file mode 100644 index 000000000..28bcb8c04 --- /dev/null +++ b/cpp/tests/kernels/weightOnly/weightOnlyKernelTest.cpp @@ -0,0 +1,429 @@ +#include +#include +#include + +#include "cutlass/numeric_types.h" +#include "tensorrt_llm/common/quantization.h" +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h" +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h" +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using tensorrt_llm::kernels::WeightOnlyParams; +using tensorrt_llm::kernels::WeightOnlyType; +using tensorrt_llm::kernels::WeightOnlyQuantType; +using tensorrt_llm::kernels::WeightOnlyActivationType; +using tensorrt_llm::kernels::WeightOnlyActivationFunctionType; +template +struct AType; + +template <> +struct AType +{ + using CudaKernelAType = half; + using CutlassKernelAType = half; +}; +#if defined(ENABLE_BF16) +template <> +struct AType +{ + using CudaKernelAType = __nv_bfloat16; + using CutlassKernelAType = __nv_bfloat16; +}; +#endif +template +struct BType; + +template <> +struct BType +{ + using CudaKernelBType = uint8_t; + using CutlassKernelBType = cutlass::uint4b_t; + static constexpr int elemsPerByte = 2; +}; + +template <> +struct BType +{ + using CudaKernelBType = uint8_t; + using CutlassKernelBType = uint8_t; + static constexpr int elemsPerByte = 1; +}; +struct CutlassKernel; +struct CudaKernel; + +template +float benchmark_perchannel(void* act, void* weight, void* scales, void* zeros, void* bias, void* out, int m, int n, + int k, int group_size, int warmup, int iter) +{ + assert(zeros == nullptr && bias == nullptr && group_size == 0); + cudaStream_t s; + cudaStreamCreate(&s); + cudaEvent_t begin, end; + cudaEventCreate(&begin); + cudaEventCreate(&end); + if constexpr (std::is_same_v) + { + WeightOnlyParams params{reinterpret_cast(weight), scales, zeros, act, bias, out, m, n, k, group_size, + BFlag, WeightOnlyType::PerChannel, WeightOnlyActivationFunctionType::Identity, AFlag}; + for (int i = 0; i < warmup; ++i) + { + tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s); + } + cudaEventRecord(begin, s); + for (int i = 0; i < iter; ++i) + { + tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s); + } + } + else if (std::is_same_v) + { + tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner::CutlassKernelAType, + typename BType::CutlassKernelBType, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY> + gemm; + auto configs = gemm.getConfigs(); + int ws_bytes = gemm.getWorkspaceSize(m, n, k); + char* ws_ptr = nullptr; + if (ws_bytes) + cudaMalloc(&ws_ptr, ws_bytes); + float fast_time = 1e8; + auto best_config = configs[0]; + for (auto& config : configs) + { + for (int i = 0; i < 2; ++i) + { + gemm.gemm(act, weight, scales, out, m, n, k, config, ws_ptr, ws_bytes, s); + } + cudaEventRecord(begin, s); + for (int i = 0; i < 5; ++i) + { + gemm.gemm(act, weight, scales, out, m, n, k, config, ws_ptr, ws_bytes, s); + } + cudaEventRecord(end, s); + cudaEventSynchronize(end); + float time; + cudaEventElapsedTime(&time, begin, end); + fast_time = std::min(fast_time, time); + if (time < fast_time) + { + fast_time = time; + best_config = config; + } + } + + for (int i = 0; i < warmup; ++i) + { + gemm.gemm(act, weight, scales, out, m, n, k, best_config, ws_ptr, ws_bytes, s); + } + cudaProfilerStart(); + cudaEventRecord(begin, s); + for (int i = 0; i < iter; ++i) + { + gemm.gemm(act, weight, scales, out, m, n, k, best_config, ws_ptr, ws_bytes, s); + } + if (ws_ptr) + cudaFree(ws_ptr); + } + + cudaEventRecord(end, s); + cudaEventSynchronize(end); + float time; + cudaEventElapsedTime(&time, begin, end); + cudaEventDestroy(begin); + cudaEventDestroy(end); + cudaStreamDestroy(s); + return time / iter; +} + +template +float benchmark_groupwise(void* act, void* weight, void* scales, void* zeros, void* bias, void* out, int m, int n, + int k, int group_size, int warmup, int iter) +{ + assert(zeros && bias && (group_size == 64 || group_size == 128)); + cudaStream_t s; + cudaStreamCreate(&s); + cudaEvent_t begin, end; + cudaEventCreate(&begin); + cudaEventCreate(&end); + if constexpr (std::is_same_v) + { + WeightOnlyParams params{reinterpret_cast(weight), scales, zeros, act, bias, out, m, n, k, group_size, + BFlag, WeightOnlyType::GroupWise, WeightOnlyActivationFunctionType::Identity, AFlag}; + for (int i = 0; i < warmup; ++i) + { + tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s); + } + cudaEventRecord(begin, s); + for (int i = 0; i < iter; ++i) + { + tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s); + } + } + else if (std::is_same_v) + { + tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner::CutlassKernelAType, + typename BType::CutlassKernelBType, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS> + gemm; + auto configs = gemm.getConfigs(); + int ws_bytes = gemm.getWorkspaceSize(m, n, k); + char* ws_ptr = nullptr; + if (ws_bytes) + cudaMalloc(&ws_ptr, ws_bytes); + float fast_time = 1e8; + auto best_config = configs[0]; + for (auto& config : configs) + { + for (int i = 0; i < 2; ++i) + { + gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, config, ws_ptr, ws_bytes, s); + } + cudaEventRecord(begin, s); + for (int i = 0; i < 5; ++i) + { + gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, config, ws_ptr, ws_bytes, s); + } + cudaEventRecord(end, s); + cudaEventSynchronize(end); + float time; + cudaEventElapsedTime(&time, begin, end); + fast_time = std::min(fast_time, time); + if (time < fast_time) + { + fast_time = time; + best_config = config; + } + } + + for (int i = 0; i < warmup; ++i) + { + gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, best_config, ws_ptr, ws_bytes, s); + } + cudaProfilerStart(); + cudaEventRecord(begin, s); + for (int i = 0; i < iter; ++i) + { + gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, best_config, ws_ptr, ws_bytes, s); + } + if (ws_ptr) + cudaFree(ws_ptr); + } + + cudaEventRecord(end, s); + cudaEventSynchronize(end); + float time; + cudaEventElapsedTime(&time, begin, end); + cudaEventDestroy(begin); + cudaEventDestroy(end); + cudaStreamDestroy(s); + return time / iter; +} + +struct CudaBuffer +{ + void* _data; + int _size; + + CudaBuffer(int size_in_bytes) + : _size(size_in_bytes) + { + cudaMalloc(&_data, _size); + } + + template + T* data() + { + return reinterpret_cast(_data); + } + + void copy_to(void* dst) + { + cudaMemcpy(dst, _data, _size, cudaMemcpyDeviceToHost); + } + + void copy_from(void* src) + { + cudaMemcpy(_data, src, _size, cudaMemcpyHostToDevice); + } + + ~CudaBuffer() + { + cudaFree(_data); + } +}; + +template +float compare(void* _pa, void* _pb, int size, float scale) +{ + auto pa = reinterpret_cast(_pa); + auto pb = reinterpret_cast(_pb); + float max_diff = 0.f, tot_diff = 0.f; + float max_val = 0.f; + int diff_cnt = 0; + float threshold = 1e-7; + for (int n = 0; n < size; ++n) + { + float va = static_cast(pa[n]); + float vb = static_cast(pb[n]); + max_val = std::max(max_val, vb); + float diff = std::abs(va - vb); + if (diff > threshold) + { + max_diff = std::max(max_diff, diff); + tot_diff += diff; + ++diff_cnt; + } + } + float diff_thres = max_val * scale; +#if defined(ENABLE_BF16) + if constexpr (std::is_same_v) + { + // bfloat16 has fewer mantissa digits than float16, so the cumulative error will be larger. + diff_thres *= 2.f; + } + else +#endif + { + diff_thres *= 1.5f; + } + printf("max diff %f (diff threshold %f), avg diff %f, diff cnt %d/%d\n", max_diff, diff_thres, tot_diff / diff_cnt, + diff_cnt, size); + return max_diff <= diff_thres; +} + +template +void random_fill(std::vector& vec, T2 minv, T2 maxv) +{ + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dis(static_cast(minv), static_cast(maxv)); + for (auto& v : vec) + { + v = static_cast(dis(gen)); + } +} + +template +bool benchmark(int m, int n, int k, int group_size, int warmup, int iter) +{ + printf("benchmark mnk (%d, %d, %d) ", m, n, k); + if (AFlag == WeightOnlyActivationType::FP16) + { + printf("FP16 Activation "); + } + else + { + printf("BF16 Activation "); + } + if (BFlag == WeightOnlyQuantType::Int8b) + { + printf("Int8b "); + } + else + { + printf("Int4b "); + } + if (group_size == 0) + { + printf("PerChannel Weight Only\n"); + } + else + { + printf("GroupWise%d Weight Only\n", group_size); + } + using AT = typename AType::CudaKernelAType; + using BT = typename BType::CudaKernelBType; + constexpr int elem_per_byte = BType::elemsPerByte; + CudaBuffer d_act(m * k * sizeof(AT)); + CudaBuffer d_weight(k * n * sizeof(uint8_t) / elem_per_byte); + CudaBuffer d_scales(n * k * sizeof(AT)); + CudaBuffer d_zeros(n * k * sizeof(AT)); + CudaBuffer d_bias(n * sizeof(AT)); + CudaBuffer d_out(m * n * sizeof(AT)); + std::vector h_act(m * k); + std::vector h_weight(k * n); + std::vector h_scales(n * k), h_zeros(n * k), h_bias(n); + std::vector h_out1(m * n), h_out2(m * n); + + random_fill(h_act, -1.f, 1.f); + random_fill(h_scales, -1.f, 1.f); + + for (uint8_t& v : h_weight) + { + v = rand() % 256; + } + + d_act.copy_from(h_act.data()); + d_weight.copy_from(h_weight.data()); + d_scales.copy_from(h_scales.data()); + d_zeros.copy_from(h_zeros.data()); + d_bias.copy_from(h_bias.data()); + + void* p_zeros = nullptr; + void* p_bias = nullptr; + if (group_size == 64 || group_size == 128) + { + p_zeros = d_zeros.data(); + p_bias = d_bias.data(); + } + + float time1, time2; + time1 = benchmark_perchannel( + d_act.data(), d_weight.data(), d_scales.data(), p_zeros, p_bias, d_out.data(), m, n, k, 0, warmup, iter); + d_out.copy_to(h_out1.data()); + time2 = benchmark_perchannel( + d_act.data(), d_weight.data(), d_scales.data(), p_zeros, p_bias, d_out.data(), m, n, k, 0, warmup, iter); + d_out.copy_to(h_out2.data()); + float quant_scale = 1.f / (1 << (8 / elem_per_byte - 1)); + bool pass = compare(h_out1.data(), h_out2.data(), m * n, quant_scale); + printf( + "cuda kernel cost time %.6f, cutlass kernel cost time %.6f, cuda speedup %.3f\n", time1, time2, time2 / time1); + return pass; +} + +TEST(Kernel, WeightOnly) +{ + bool pass; + int warmup = 10, iter = 30; + std::vector ms{1, 2, 4}; + std::vector ns{512, 1024, 2048, 4096}; + std::vector ks{512, 1024, 2048, 4096}; + std::vector gss{0, 64, 128}; + for (auto m : ms) + { + for (auto n : ns) + { + for (auto k : ks) + { + for (auto gs : gss) + { + pass = benchmark( + m, n, k, gs, warmup, iter); + EXPECT_TRUE(pass); + pass = benchmark( + m, n, k, gs, warmup, iter); + EXPECT_TRUE(pass); +#if defined(ENABLE_BF16) + pass = benchmark( + m, n, k, gs, warmup, iter); + EXPECT_TRUE(pass); + pass = benchmark( + m, n, k, gs, warmup, iter); + EXPECT_TRUE(pass); +#endif + } + } + } + } +} diff --git a/cpp/tests/runtime/gptDecoderBatchTest.cpp b/cpp/tests/runtime/gptDecoderBatchTest.cpp index b4537e73b..34b0c14ae 100644 --- a/cpp/tests/runtime/gptDecoderBatchTest.cpp +++ b/cpp/tests/runtime/gptDecoderBatchTest.cpp @@ -120,8 +120,10 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector con SizeType constexpr maxInputLength{8}; SizeType constexpr maxNewTokens{2}; auto constexpr maxSeqLength = maxInputLength + maxNewTokens; + // We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache). + auto const maxKvCacheLength = maxSeqLength; - decoder.setup(batchSize, maxBeamWidth, maxSeqLength, modelConfig.getDataType()); + decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, modelConfig.getDataType()); std::vector const inputLengths{4, 5, 6, 7}; std::vector tiledInputLengths; @@ -240,8 +242,10 @@ void testDecoderWavefront( SizeType constexpr maxInputLength{8}; SizeType constexpr maxNewTokens{8}; auto constexpr maxSeqLength = maxInputLength + maxNewTokens; + // We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache). + auto const maxKvCacheLength = maxSeqLength; - decoder.setup(batchSize, maxBeamWidth, maxSeqLength, modelConfig.getDataType()); + decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, modelConfig.getDataType()); std::vector const inputLengths{4, 5, 6, 7}; std::vector tiledInputLengths; diff --git a/cpp/tests/runtime/gptDecoderTest.cpp b/cpp/tests/runtime/gptDecoderTest.cpp index 63d688ac4..3d1b96f2f 100644 --- a/cpp/tests/runtime/gptDecoderTest.cpp +++ b/cpp/tests/runtime/gptDecoderTest.cpp @@ -71,7 +71,7 @@ void testDecoder(nvinfer1::DataType const dtype, SamplingConfig const& samplingC auto endIds = std::shared_ptr(manager.copyFrom(endIdsVec, ITensor::makeShape({batchSize, beamWidth}), MemoryType::kGPU)); - DecodingInput inputs{maxInputLength, batchSize, logits, endIds}; + DecodingInput inputs{maxInputLength, maxSeqLength, batchSize, logits, endIds}; std::vector sequenceLimitLengthsVec(batchSize, maxSeqLength); inputs.sequenceLimitLength = manager.copyFrom(sequenceLimitLengthsVec, ITensor::makeShape({batchSize}), MemoryType::kGPU); diff --git a/cpp/tests/runtime/gptSessionTest.cpp b/cpp/tests/runtime/gptSessionTest.cpp index 01b53a389..9c4c69755 100644 --- a/cpp/tests/runtime/gptSessionTest.cpp +++ b/cpp/tests/runtime/gptSessionTest.cpp @@ -739,6 +739,7 @@ void testChatGlmSession(fs::path const& modelPath, std::string const& modelName, samplingConfig.randomSeed = std::vector{1ull}; samplingConfig.topK = std::vector{1}; samplingConfig.topP = std::vector{1.0f}; + samplingConfig.lengthPenalty = std::vector{1.0f}; auto const padId = modelIds.padId; auto const endId = modelIds.endId; diff --git a/cpp/tests/runtime/iTensorTest.cpp b/cpp/tests/runtime/iTensorTest.cpp index c2be7b5e5..499665bac 100644 --- a/cpp/tests/runtime/iTensorTest.cpp +++ b/cpp/tests/runtime/iTensorTest.cpp @@ -21,12 +21,30 @@ #include "tensorrt_llm/runtime/iTensor.h" using namespace tensorrt_llm::runtime; -using namespace ::testing; +namespace tc = tensorrt_llm::common; -namespace +TEST(ITensorTest, SqueezeTensor) { + auto dims = ITensor::makeShape({16, 1, 4}); + auto constexpr dataType = nvinfer1::DataType::kFLOAT; + ITensor::SharedPtr tensor{BufferManager::cpu(dims, dataType)}; + + auto squeezeDim = 0; + EXPECT_THROW(tensor->squeeze(squeezeDim), std::runtime_error); + squeezeDim = 1; + auto squeezed = ITensor::view(tensor, ITensor::squeeze(dims, squeezeDim)); + + EXPECT_EQ(squeezed->getSize(), tensor->getSize()); + EXPECT_EQ(squeezed->getShape().nbDims, tensor->getShape().nbDims - 1); + EXPECT_EQ(squeezed->getShape().d[0], tensor->getShape().d[0]); + EXPECT_EQ(squeezed->getShape().d[1], tensor->getShape().d[2]); + + EXPECT_NO_THROW(squeezed->release()); + EXPECT_EQ(squeezed->data(), nullptr); + EXPECT_NE(tensor->data(), nullptr); +} -TEST(iTensorTest, UnsqueezeShape) +TEST(ITensorTest, UnsqueezeShape) { auto oldShape = ITensor::makeShape({2, 3, 4, 5}); { @@ -66,7 +84,7 @@ TEST(iTensorTest, UnsqueezeShape) { try { - auto shape = ITensor::unsqueeze(oldShape, invalidDim); + ITensor::unsqueeze(oldShape, invalidDim); FAIL() << "Expected failure"; } catch (tensorrt_llm::common::TllmException const& e) @@ -80,13 +98,12 @@ TEST(iTensorTest, UnsqueezeShape) } } -TEST(iTensorTest, UnsqueezeTensor) +TEST(ITensorTest, UnsqueezeTensor) { auto oldShape = ITensor::makeShape({2, 3, 4, 5}); - BufferManager manager(std::make_shared()); { - auto tensor = manager.cpu(oldShape, nvinfer1::DataType::kINT32); + auto tensor = BufferManager::cpu(oldShape, nvinfer1::DataType::kINT32); tensor->unsqueeze(0); auto shape = tensor->getShape(); @@ -98,7 +115,7 @@ TEST(iTensorTest, UnsqueezeTensor) EXPECT_EQ(shape.d[4], 5); } { - auto tensor = manager.cpu(oldShape, nvinfer1::DataType::kINT32); + auto tensor = BufferManager::cpu(oldShape, nvinfer1::DataType::kINT32); tensor->unsqueeze(1); auto shape = tensor->getShape(); @@ -111,7 +128,7 @@ TEST(iTensorTest, UnsqueezeTensor) } { - auto tensor = manager.cpu(oldShape, nvinfer1::DataType::kINT32); + auto tensor = BufferManager::cpu(oldShape, nvinfer1::DataType::kINT32); tensor->unsqueeze(4); auto shape = tensor->getShape(); @@ -128,7 +145,7 @@ TEST(iTensorTest, UnsqueezeTensor) { try { - auto tensor = manager.cpu(oldShape, nvinfer1::DataType::kINT32); + auto tensor = BufferManager::cpu(oldShape, nvinfer1::DataType::kINT32); tensor->unsqueeze(invalidDim); FAIL() << "Expected failure"; } @@ -143,4 +160,61 @@ TEST(iTensorTest, UnsqueezeTensor) } } -} // namespace +TEST(ITensorTest, TensorView) +{ + auto const dims = ITensor::makeShape({16, 1, 4}); + auto constexpr dataType = nvinfer1::DataType::kFLOAT; + ITensor::SharedPtr tensor = BufferManager::cpu(dims, dataType); + + auto const viewDims = ITensor::makeShape({16, 1, 2}); + + auto view = ITensor::view(tensor, viewDims); + EXPECT_EQ(view->getSize(), tensor->getSize() / 2); + EXPECT_EQ(view->getShape().nbDims, tensor->getShape().nbDims); + EXPECT_EQ(view->getShape().d[2], tensor->getShape().d[2] / 2); + + EXPECT_NO_THROW(view->release()); + EXPECT_EQ(view->data(), nullptr); + EXPECT_NE(tensor->data(), nullptr); +} + +TEST(ITensorTest, TensorSlice) +{ + auto dims = ITensor::makeShape({16, 8, 4}); + auto constexpr dataType = nvinfer1::DataType::kFLOAT; + ITensor::SharedPtr tensor{BufferManager::cpu(dims, dataType)}; + auto offset = dims.d[0] / 4; + auto slice = ITensor::slice(tensor, offset); + auto const sizeSlice = 3 * tensor->getSize() / 4; + EXPECT_EQ(slice->getShape().d[0], dims.d[0] - offset); + EXPECT_EQ(slice->getSize(), sizeSlice); + EXPECT_EQ(slice->getCapacity(), sizeSlice); + EXPECT_EQ(static_cast(slice->data()) - static_cast(tensor->data()), + offset * ITensor::volume(dims) / dims.d[0] * BufferDataType(dataType).getSize()); + + auto dimsNew = ITensor::makeShape({12, 32}); + EXPECT_EQ(ITensor::volume(dimsNew), sizeSlice); + EXPECT_NO_THROW(slice->reshape(dimsNew)); + EXPECT_EQ(slice->getShape().d[1], dimsNew.d[1]); + dimsNew.d[0] = 6; + EXPECT_LT(ITensor::volume(dimsNew), sizeSlice); + EXPECT_NO_THROW(slice->reshape(dimsNew)); + EXPECT_EQ(slice->getShape().d[0], dimsNew.d[0]); + dimsNew.d[0] = 16; + EXPECT_GT(ITensor::volume(dimsNew), sizeSlice); + EXPECT_THROW(slice->reshape(dimsNew), std::runtime_error); + + EXPECT_NO_THROW(slice->resize(sizeSlice)); + EXPECT_NO_THROW(slice->resize(sizeSlice / 2)); + EXPECT_EQ(slice->getShape().d[0], sizeSlice / 2); + EXPECT_THROW(slice->resize(sizeSlice * 2), std::runtime_error); + EXPECT_NO_THROW(slice->release()); + EXPECT_EQ(slice->data(), nullptr); + EXPECT_NE(tensor->data(), nullptr); + + std::shared_ptr constTensor{tensor}; + auto constSlice = ITensor::slice(constTensor, offset); + EXPECT_EQ(constSlice->getShape().d[0], dims.d[0] - offset); + auto uniqueSlice = ITensor::slice(std::move(constSlice), 1); + EXPECT_EQ(uniqueSlice->getShape().d[0], dims.d[0] - offset - 1); +} diff --git a/cpp/tests/runtime/samplingTest.cpp b/cpp/tests/runtime/samplingTest.cpp index 4d554fe55..e2824c545 100644 --- a/cpp/tests/runtime/samplingTest.cpp +++ b/cpp/tests/runtime/samplingTest.cpp @@ -104,7 +104,7 @@ typename tl::DynamicDecodeLayer::OutputParams dynamicDecodeTest(BufferMan ddLayer.setup(batchSize, beamWidth, setupParams); typename tl::DynamicDecodeLayer::ForwardParams forwardParams( - step, ite, maxInputLength, localBatchSize, logits, endIds); + step, ite, maxInputLength, static_cast(maxSeqLength), localBatchSize, logits, endIds); forwardParams.no_repeat_ngram_size = noRepeatNgramSize; typename tl::DynamicDecodeLayer::OutputParams outputParams(outputIds); diff --git a/cpp/tests/runtime/tllmBuffersTest.cpp b/cpp/tests/runtime/tllmBuffersTest.cpp index a3998b3d4..40c207d64 100644 --- a/cpp/tests/runtime/tllmBuffersTest.cpp +++ b/cpp/tests/runtime/tllmBuffersTest.cpp @@ -36,9 +36,6 @@ class TllmBuffersTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-t void SetUp() override { mDeviceCount = tc::getDeviceCount(); - - if (mDeviceCount == 0) - GTEST_SKIP(); } void TearDown() override {} @@ -48,6 +45,9 @@ class TllmBuffersTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-t TEST_F(TllmBuffersTest, Stream) { + if (mDeviceCount == 0) + GTEST_SKIP(); + CudaStream stream{}; EXPECT_NE(stream.get(), nullptr); auto ptr = std::make_shared(); @@ -109,6 +109,9 @@ TEST_F(TllmBuffersTest, HostAllocator) TEST_F(TllmBuffersTest, CudaAllocatorAsync) { + if (mDeviceCount == 0) + GTEST_SKIP(); + auto streamPtr = std::make_shared(); auto constexpr size = 1024; CudaAllocatorAsync allocator{streamPtr}; @@ -171,6 +174,9 @@ void testBuffer(IBuffer& buffer, std::int32_t typeSize) TEST_F(TllmBuffersTest, DeviceBuffer) { + if (mDeviceCount == 0) + GTEST_SKIP(); + auto streamPtr = std::make_shared(); auto constexpr size = 1024; CudaAllocatorAsync allocator{streamPtr}; @@ -186,6 +192,9 @@ TEST_F(TllmBuffersTest, DeviceBuffer) TEST_F(TllmBuffersTest, DeviceTensor) { + if (mDeviceCount == 0) + GTEST_SKIP(); + auto streamPtr = std::make_shared(); nvinfer1::Dims constexpr dims{3, 16, 8, 4}; CudaAllocatorAsync allocator{streamPtr}; @@ -228,91 +237,11 @@ TEST_F(TllmBuffersTest, BufferSlice) EXPECT_EQ(uniqueSlice->getSize(), sizeSlice - 1); } -TEST_F(TllmBuffersTest, TensorSlice) -{ - auto dims = ITensor::makeShape({16, 8, 4}); - HostAllocator allocator{}; - auto constexpr dataType = nvinfer1::DataType::kFLOAT; - auto tensor = std::make_shared(dims, dataType, allocator); - auto offset = dims.d[0] / 4; - auto slice = ITensor::slice(tensor, offset); - auto const sizeSlice = 3 * tensor->getSize() / 4; - EXPECT_EQ(slice->getShape().d[0], dims.d[0] - offset); - EXPECT_EQ(slice->getSize(), sizeSlice); - EXPECT_EQ(slice->getCapacity(), sizeSlice); - EXPECT_EQ(static_cast(slice->data()) - static_cast(tensor->data()), - offset * ITensor::volume(dims) / dims.d[0] * BufferDataType(dataType).getSize()); - - auto dimsNew = ITensor::makeShape({12, 32}); - EXPECT_EQ(ITensor::volume(dimsNew), sizeSlice); - EXPECT_NO_THROW(slice->reshape(dimsNew)); - EXPECT_EQ(slice->getShape().d[1], dimsNew.d[1]); - dimsNew.d[0] = 6; - EXPECT_LT(ITensor::volume(dimsNew), sizeSlice); - EXPECT_NO_THROW(slice->reshape(dimsNew)); - EXPECT_EQ(slice->getShape().d[0], dimsNew.d[0]); - dimsNew.d[0] = 16; - EXPECT_GT(ITensor::volume(dimsNew), sizeSlice); - EXPECT_THROW(slice->reshape(dimsNew), std::runtime_error); - - EXPECT_NO_THROW(slice->resize(sizeSlice)); - EXPECT_NO_THROW(slice->resize(sizeSlice / 2)); - EXPECT_EQ(slice->getShape().d[0], sizeSlice / 2); - EXPECT_THROW(slice->resize(sizeSlice * 2), std::runtime_error); - EXPECT_NO_THROW(slice->release()); - EXPECT_EQ(slice->data(), nullptr); - EXPECT_NE(tensor->data(), nullptr); - - std::shared_ptr constTensor{tensor}; - auto constSlice = ITensor::slice(constTensor, offset); - EXPECT_EQ(constSlice->getShape().d[0], dims.d[0] - offset); - auto uniqueSlice = ITensor::slice(std::move(constSlice), 1); - EXPECT_EQ(uniqueSlice->getShape().d[0], dims.d[0] - offset - 1); -} - -TEST_F(TllmBuffersTest, TensorSqueeze) -{ - auto dims = ITensor::makeShape({16, 1, 4}); - HostAllocator allocator{}; - auto constexpr dataType = nvinfer1::DataType::kFLOAT; - auto tensor = std::make_shared(dims, dataType, allocator); - - auto squeezeDim = 0; - EXPECT_THROW(tensor->squeeze(squeezeDim), std::runtime_error); - squeezeDim = 1; - auto squeezed = ITensor::view(tensor, ITensor::squeeze(dims, squeezeDim)); - - EXPECT_EQ(squeezed->getSize(), tensor->getSize()); - EXPECT_EQ(squeezed->getShape().nbDims, tensor->getShape().nbDims - 1); - EXPECT_EQ(squeezed->getShape().d[0], tensor->getShape().d[0]); - EXPECT_EQ(squeezed->getShape().d[1], tensor->getShape().d[2]); - - EXPECT_NO_THROW(squeezed->release()); - EXPECT_EQ(squeezed->data(), nullptr); - EXPECT_NE(tensor->data(), nullptr); -} - -TEST_F(TllmBuffersTest, TensorView) -{ - auto const dims = ITensor::makeShape({16, 1, 4}); - HostAllocator allocator{}; - auto constexpr dataType = nvinfer1::DataType::kFLOAT; - auto tensor = std::make_shared(dims, dataType, allocator); - - auto const viewDims = ITensor::makeShape({16, 1, 2}); - - auto view = ITensor::view(tensor, viewDims); - EXPECT_EQ(view->getSize(), tensor->getSize() / 2); - EXPECT_EQ(view->getShape().nbDims, tensor->getShape().nbDims); - EXPECT_EQ(view->getShape().d[2], tensor->getShape().d[2] / 2); - - EXPECT_NO_THROW(view->release()); - EXPECT_EQ(view->data(), nullptr); - EXPECT_NE(tensor->data(), nullptr); -} - TEST_F(TllmBuffersTest, BufferOutput) { + if (mDeviceCount == 0) + GTEST_SKIP(); + auto streamPtr = std::make_shared(); CudaAllocatorAsync allocator{streamPtr}; for (std::size_t size : {0, 16}) @@ -331,6 +260,9 @@ TEST_F(TllmBuffersTest, BufferOutput) TEST_F(TllmBuffersTest, TensorOutput) { + if (mDeviceCount == 0) + GTEST_SKIP(); + auto streamPtr = std::make_shared(); nvinfer1::Dims constexpr dims{3, 16, 8, 4}; CudaAllocatorAsync allocator{streamPtr}; diff --git a/docs/source/gpt_attention.md b/docs/source/gpt_attention.md index 4e1a34cfa..d32d12b42 100644 --- a/docs/source/gpt_attention.md +++ b/docs/source/gpt_attention.md @@ -164,6 +164,34 @@ the MHA/MQA kernel. The scaling factor to dequantize those values is stored in the `kv_quant_orig_scale` tensor. That tensor contains a single value (per tensor scaling). + +## Sliding Window Attention, Cyclic (Rolling Buffer) KV Cache + +TensorRT-LLM has a feature called `Cyclic KV Cache`, which treats the kv cache +as a circular buffer. This means that it only stores the kv cache for the last N +tokens, where N is determined by the `max_kv_cache_length` parameter in +`GenerationSession.setup`. You can see examples of this in the `run.py` or +`summarize.py` files. When the cache is full, new tokens’ kv cache will +overwrite the "least recently used" caches. + +In the context phase, if the input length surpasses the `max_kv_cache_length`, +`Sliding Window Attention` will be activated. This serves the same function as +the `sliding window_size`. + +This feature helps to reduce the memory footprint of the kv cache when +dealing with very long sequences. + +_Note that when using beam search, cyclic kv cache may not perform as well as +full kv cache when the current step exceeds `max_kv_cache_length`. +This issue will be addressed in future releases._ + +_The experimental feature, which allows different `max_kv_cache_length` values +for each layer, is also supported. To utilize this feature, simply provide an +`int32 torch.Tensor` with a shape of `[num_layers]` to the `GenerationSession.setup`. +This tensor will serve as the buffer for `max_kv_cache_length`, +setting unique values for each layer. However, it’s important to note that the +memory allocation for the kv cache still relies on the buffer’s maximum value._ + ## Beam-Search The GPT attention operator supports beam-search. In the context phase, a single diff --git a/docs/source/gpt_runtime.md b/docs/source/gpt_runtime.md index 035ed77ae..87782b02e 100644 --- a/docs/source/gpt_runtime.md +++ b/docs/source/gpt_runtime.md @@ -44,7 +44,8 @@ optional object to log information, warnings and errors: using namespace tensorrt_llm::runtime; -GptSession session(modelConfig, // Description of the model, +GptSession session(sessionConfig, // Configuration of the session, + modelConfig, // Description of the model, worldConfig, // Description of the environment, engineBuffer, // The compiled TensorRT engine (const void*), engineSize, // The size in bytes of the TensorRT engine (size_t), @@ -56,6 +57,35 @@ associated size (in bytes) of that buffer. There exist other overloaded versions that take `std::vector` or `std::string` arguments to encapsulate the engine. +#### Session Configuration + +The session configuration is an instance of the +[`GptSession::Config`](source:cpp/include/tensorrt_llm/runtime/gptSession.h) class. +The constructor of this class requires three arguments: + + * `maxBatchSize`, the maximum number of sequences in a batch, + * `maxBeamWidth`, the maximum width of the beams in beam-search, + * `maxSequenceLength`, the length of the longest input sequence, + +Additionally, the class encapsulates the following optional parameters +(they are declared as public member variables and can be accessed directly): + + * `decoderPerRequest`, whether the session will use a different decoder per + request. It must be set to `true` when running in-flight batching, + * `cudaGraphMode`, whether the session will use CUDA graphs for the engine + execution in generation phase, + * `kvCacheConfig` encapsulates parameters to configure paged KV cache, when the paged KV cache is enabled in the engine: + * `maxTokens`, the maximum number of tokens that will have to be + stored in the paged KV cache, + * `freeGpuMemoryFraction`, the fraction of free GPU memory that will be + reserved for paged KV cache, + * `ctxMicroBatchSize`, the micro batch size to be used in context phase. + Batches entered in `GptSession::generation` will be split into smaller + micro batches of this size, + * `genMicroBatchSize`, the micro batch size to be used in generation phase, + Batches entered in `GptSession::generation` will be split into smaller + micro batches of this size. + #### Model Configuration The model configuration is an instance of the @@ -152,7 +182,7 @@ MPI_Comm_rank(MPI_COMM_WORLD, &rank); tensorrt_llm::runtime::WorldConfig worldConfig(tensorParallelism, pipelineParallelism, rank); // Create the GPT session (as shown above). -tensorrt_llm::runtime::GptSession session(modelConfig, worldConfig, ...); +tensorrt_llm::runtime::GptSession session(sessionConfig, modelConfig, worldConfig, ...); ``` For simplicity, TensorRT-LLM provides users with the following simplified API: @@ -169,22 +199,6 @@ installed on the system (talk to your system administrator if needed): mpirun -n 2 ... ``` -### Setup - -***GptSession*** - -The `GptSession::setup` member function must be called to prepare the runtime -to execute the inference on a batch of input sequences. That member function -takes four arguments: - - * `batchSize`, the number of sequences in the batch, - * `beamWidth`, the width of the beams in beam-search, - * `maxSequenceLength`, the length of the longest input sequence, - * `decoderPerRequest`, is the session asked to use a different decoder per - request. It must be set to `true` when running in-flight batching, - * `maxTokensInPagedKvCache`, the maximum number of tokens that will have to be - stored in the KV cache when the paged KV cache is enabled. - ### Generation The `GptSession::generate` member function performs the generation loop. Given @@ -230,10 +244,10 @@ populates an instance of the sequences). It can be set to the same value as `endId`, * `ids`, is the tensor of input IDs. That tensor must be allocated on the GPU. When the input tensor is padded, the shape of `ids` is `[batchSize, - maxInputLength]`, where `batchSize` and `maxInputLength` correspond to the - arguments passed to the `GptSession::setup` member function. When the input - is packed, the shape of `ids` is `[numTokens]`, where `numTokens` is the sum - of the lengths of the different sequences in the batch, + maxInputLength]`, where `batchSize` and `maxInputLength` must respect the + maximum sizes in `sessionConfig` passed to the `GptSession` constructor. + When the input is packed, the shape of `ids` is `[numTokens]`, where + `numTokens` is the sum of the lengths of the different sequences in the batch, * `lengths`, is the tensor of input sequence lengths. That tensor must be allocated on the GPU and contain `batchSize` values, * `packed`, indicates if the `ids` tensor is packed or padded. In this diff --git a/examples/baichuan/build.py b/examples/baichuan/build.py index 45c7f483f..4e2074b75 100644 --- a/examples/baichuan/build.py +++ b/examples/baichuan/build.py @@ -126,6 +126,11 @@ def parse_arguments(): type=str, default='float16', choices=['float32', 'bfloat16', 'float16']) + parser.add_argument('--logits_dtype', + type=str, + default='float32', + choices=['float16', 'float32']) + parser.add_argument( '--timing_cache', type=str, @@ -163,6 +168,14 @@ def parse_arguments(): parser.add_argument('--enable_context_fmha_fp32_acc', default=False, action='store_true') + parser.add_argument( + '--multi_block_mode', + default=False, + action='store_true', + help= + 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ + It is beneifical when batchxnum_heads cannot fully utilize GPU.' + ) parser.add_argument('--parallel_build', default=False, action='store_true') parser.add_argument('--visualize', default=False, action='store_true') parser.add_argument('--enable_debug_output', @@ -252,6 +265,14 @@ def parse_arguments(): default=None, help='Define the max number of tokens supported by the engine') + parser.add_argument( + '--strongly_typed', + default=False, + action="store_true", + help= + 'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.' + ) + args = parser.parse_args() assert not ( @@ -378,7 +399,8 @@ def build_rank_engine(builder: Builder, dtype=dtype, mlp_hidden_size=args.inter_size, mapping=mapping, - quant_mode=args.quant_mode) + quant_mode=args.quant_mode, + logits_dtype=args.logits_dtype) if args.use_smooth_quant or args.use_weight_only: tensorrt_llm_baichuan = quantize_model(tensorrt_llm_baichuan, args.quant_mode) @@ -408,6 +430,7 @@ def build_rank_engine(builder: Builder, elif args.bin_model_dir is not None: load_from_binary(tensorrt_llm_baichuan, args.bin_model_dir, + args.model_version, mapping, fp16=(args.dtype == 'float16'), multi_query_mode=False) @@ -432,6 +455,8 @@ def build_rank_engine(builder: Builder, if args.enable_context_fmha_fp32_acc: network.plugin_config.set_context_fmha( ContextFMHAType.enabled_with_fp32_acc) + if args.multi_block_mode: + network.plugin_config.enable_mmha_multi_block_mode() if args.use_weight_only: network.plugin_config.set_weight_only_quant_matmul_plugin( dtype='float16') @@ -514,7 +539,8 @@ def build(rank, args): max_output_len=args.max_output_len, max_num_tokens=args.max_num_tokens, int8=int8_trt_flag, - quant_mode=args.quant_mode) + quant_mode=args.quant_mode, + strongly_typed=args.strongly_typed) engine_name = get_engine_name(model_name, args.dtype, args.world_size, cur_rank) engine = build_rank_engine(builder, builder_config, engine_name, diff --git a/examples/baichuan/run.py b/examples/baichuan/run.py index 6c05bd11c..2e8b7beee 100644 --- a/examples/baichuan/run.py +++ b/examples/baichuan/run.py @@ -104,6 +104,12 @@ def parse_input(input_text: str, input_file: str, tokenizer, end_id: int, def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--max_output_len', type=int, required=True) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--log_level', type=str, default='error') parser.add_argument('--model_version', type=str, @@ -179,6 +185,7 @@ def generate( output_csv: str = None, output_npy: str = None, tokenizer_dir: str = None, + max_kv_cache_len: int = None, num_beams: int = 1, ): tensorrt_llm.logger.set_level(log_level) @@ -231,7 +238,8 @@ def generate( decoder.setup(input_lengths.size(0), max_input_length, max_output_len, - beam_width=num_beams) + beam_width=num_beams, + max_kv_cache_length=max_kv_cache_len) output_ids = decoder.decode(input_ids, input_lengths, sampling_config) torch.cuda.synchronize() diff --git a/examples/baichuan/summarize.py b/examples/baichuan/summarize.py index b825123ef..723048aae 100644 --- a/examples/baichuan/summarize.py +++ b/examples/baichuan/summarize.py @@ -182,10 +182,12 @@ def summarize_tensorrt_llm(datapoint): end_id=end_id, pad_id=pad_id, top_k=top_k, num_beams=num_beams) with torch.no_grad(): - tensorrt_llm_baichuan.setup(batch_size, - max_context_length=max_length, - max_new_tokens=output_len, - beam_width=num_beams) + tensorrt_llm_baichuan.setup( + batch_size, + max_context_length=max_length, + max_new_tokens=output_len, + beam_width=num_beams, + max_kv_cache_length=args.max_kv_cache_len) if tensorrt_llm_baichuan.remove_input_padding: output_ids = tensorrt_llm_baichuan.decode_batch( line_encoded, sampling_config) @@ -381,6 +383,12 @@ def summarize_hf(datapoint): parser.add_argument('--engine_dir', type=str, default='baichuan_outputs') parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--max_ite', type=int, default=20) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--check_accuracy', action='store_true') parser.add_argument('--tensorrt_llm_rouge1_threshold', type=float, diff --git a/examples/baichuan/weight.py b/examples/baichuan/weight.py index 7c3bc687c..c8e5adc3a 100644 --- a/examples/baichuan/weight.py +++ b/examples/baichuan/weight.py @@ -206,6 +206,7 @@ def gen_suffix(rank, use_smooth_quant, quant_per_channel): def load_from_binary(tensorrt_llm_baichuan: BaichuanForCausalLM, dir_path, + model_version, mapping=Mapping(), fp16=False, multi_query_mode=False): @@ -313,6 +314,12 @@ def set_smoother(module, dir_path, base_name, shape, rank): # share input embedding lm_head_weight = fromfile(dir_path, 'lm_head.weight.bin', [vocab_size, n_embd]) + if model_version.startswith('v2'): + # baichuan v2 models use NormHead + tensorrt_llm.logger.info( + f'Normalizing lm_head.weight for {model_version}') + lm_head_weight = lm_head_weight / np.linalg.norm( + lm_head_weight, axis=1, keepdims=True) if vocab_size % mapping.tp_size != 0: # padding diff --git a/examples/bloom/build.py b/examples/bloom/build.py index ab4cffd80..75ceef728 100644 --- a/examples/bloom/build.py +++ b/examples/bloom/build.py @@ -158,6 +158,14 @@ def parse_arguments(): parser.add_argument('--enable_context_fmha_fp32_acc', default=False, action='store_true') + parser.add_argument( + '--multi_block_mode', + default=False, + action='store_true', + help= + 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ + It is beneifical when batchxnum_heads cannot fully utilize GPU.' + ) parser.add_argument( '--use_layernorm_plugin', nargs='?', @@ -261,6 +269,13 @@ def parse_arguments(): default=False, choices=['float16', 'float32', 'bfloat16'], help="Activates the lookup plugin which enables embedding sharing.") + parser.add_argument( + '--strongly_typed', + default=False, + action="store_true", + help= + 'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.' + ) args = parser.parse_args() logger.set_level(args.log_level) @@ -395,6 +410,8 @@ def build_rank_engine(builder: Builder, if args.enable_context_fmha_fp32_acc: network.plugin_config.set_context_fmha( ContextFMHAType.enabled_with_fp32_acc) + if args.multi_block_mode: + network.plugin_config.enable_mmha_multi_block_mode() # Quantization plugins. if args.use_smooth_quant: network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype) @@ -476,7 +493,8 @@ def build(rank, args): max_input_len=args.max_input_len, max_output_len=args.max_output_len, int8=int8_trt_flag, - quant_mode=args.quant_mode) + quant_mode=args.quant_mode, + strongly_typed=args.strongly_typed) builder_config.trt_builder_config.builder_optimization_level = 1 engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size, cur_rank) diff --git a/examples/bloom/run.py b/examples/bloom/run.py index d8acf25f4..45ae793a3 100644 --- a/examples/bloom/run.py +++ b/examples/bloom/run.py @@ -31,6 +31,12 @@ def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--max_output_len', type=int, required=True) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--log_level', type=str, default='error') parser.add_argument('--engine_dir', type=str, default='bloom_outputs') parser.add_argument('--tokenizer_dir', @@ -91,7 +97,8 @@ def parse_arguments(): runtime_mapping) decoder.setup(input_ids.size(0), max_context_length=input_ids.size(1), - max_new_tokens=args.max_output_len) + max_new_tokens=args.max_output_len, + max_kv_cache_length=args.max_kv_cache_len) output_ids = decoder.decode(input_ids, input_lengths, sampling_config) torch.cuda.synchronize() diff --git a/examples/bloom/summarize.py b/examples/bloom/summarize.py index 3d485b355..0db3b859d 100644 --- a/examples/bloom/summarize.py +++ b/examples/bloom/summarize.py @@ -170,7 +170,8 @@ def summarize_tensorrt_llm(datapoint): tensorrt_llm_bloom.setup(line_encoded.size(0), max_context_length=line_encoded.size(1), max_new_tokens=output_len, - beam_width=num_beams) + beam_width=num_beams, + max_kv_cache_length=args.max_kv_cache_len) output_ids = tensorrt_llm_bloom.decode( line_encoded, @@ -358,6 +359,12 @@ def summarize_hf(datapoint): parser.add_argument('--engine_dir', type=str, default='bloom_outputs') parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--max_ite', type=int, default=20) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--check_accuracy', action='store_true') parser.add_argument('--tensorrt_llm_rouge1_threshold', type=float, diff --git a/examples/chatglm/build.py b/examples/chatglm/build.py index 88a8687cb..414b2ea10 100644 --- a/examples/chatglm/build.py +++ b/examples/chatglm/build.py @@ -76,6 +76,10 @@ def parse_arguments(args): type=str, default='float16', choices=['float32', 'float16', 'bfloat16']) + parser.add_argument('--logits_dtype', + type=str, + default='float32', + choices=['float16', 'float32']) parser.add_argument( '--timing_cache', type=str, @@ -141,6 +145,14 @@ def parse_arguments(args): parser.add_argument('--enable_context_fmha_fp32_acc', default=False, action='store_true') + parser.add_argument( + '--multi_block_mode', + default=False, + action='store_true', + help= + 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ + It is beneifical when batchxnum_heads cannot fully utilize GPU.' + ) parser.add_argument('--gpus_per_node', type=int, default=8) parser.add_argument('--builder_opt', type=int, default=None) parser.add_argument( @@ -413,6 +425,8 @@ def build_rank_engine(builder: Builder, if args.enable_context_fmha_fp32_acc: network.plugin_config.set_context_fmha( ContextFMHAType.enabled_with_fp32_acc) + if args.multi_block_mode: + network.plugin_config.enable_mmha_multi_block_mode() if args.remove_input_padding: network.plugin_config.enable_remove_input_padding() if args.paged_kv_cache: diff --git a/examples/chatglm/summarize.py b/examples/chatglm/summarize.py index 086bc68c3..daf759c79 100644 --- a/examples/chatglm/summarize.py +++ b/examples/chatglm/summarize.py @@ -209,7 +209,8 @@ def eval_tensorrt_llm(datapoint, eval_type='summarize'): tensorrt_llm_gpt.setup(batch_size, max_context_length=max_length, max_new_tokens=output_len, - beam_width=num_beams) + beam_width=num_beams, + max_kv_cache_length=args.max_kv_cache_len) if tensorrt_llm_gpt.remove_input_padding: output_ids = tensorrt_llm_gpt.decode_batch( @@ -439,6 +440,12 @@ def eval_hf(datapoint, eval_type='summarize'): parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--max_ite', type=int, default=20) parser.add_argument('--output_len', type=int, default=100) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--check_accuracy', action='store_true', default=True) parser.add_argument('--tensorrt_llm_rouge1_threshold', type=float, diff --git a/examples/enc_dec/build.py b/examples/enc_dec/build.py index ce8a537a5..4f0faeb8b 100644 --- a/examples/enc_dec/build.py +++ b/examples/enc_dec/build.py @@ -130,6 +130,14 @@ def parse_arguments(args, component): choices=['float16', 'float32', 'bfloat16'], help="Activates the lookup plugin which enables embedding sharing.") + parser.add_argument( + '--strongly_typed', + default=False, + action="store_true", + help= + 'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.' + ) + args = parser.parse_args(args) logger.set_level(args.log_level) @@ -325,7 +333,7 @@ def build(rank, args): cross_attention=(args.component == 'decoder'), has_position_embedding=args.has_position_embedding, has_token_type_embedding=args.has_token_type_embedding, - ) + strongly_typed=args.strongly_typed) engine_name = get_engine_name(MODEL_NAME, args.dtype, world_size, cur_rank) diff --git a/examples/falcon/build.py b/examples/falcon/build.py index fd3f4159f..498239eb6 100644 --- a/examples/falcon/build.py +++ b/examples/falcon/build.py @@ -33,6 +33,7 @@ from tensorrt_llm.models import quantize_model from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType +from tensorrt_llm.profiler import check_gpt_mem_usage from tensorrt_llm.quantization import QuantMode from weight import get_scaling_factors # isort:skip @@ -219,6 +220,14 @@ def parse_arguments(): parser.add_argument('--enable_context_fmha_fp32_acc', default=False, action='store_true') + parser.add_argument( + '--multi_block_mode', + default=False, + action='store_true', + help= + 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ + It is beneifical when batchxnum_heads cannot fully utilize GPU.' + ) parser.add_argument('--visualize', default=False, action='store_true') parser.add_argument('--load_by_shard', action='store_true', @@ -458,6 +467,8 @@ def build_rank_engine(builder: Builder, if args.enable_context_fmha_fp32_acc: network.plugin_config.set_context_fmha( ContextFMHAType.enabled_with_fp32_acc) + if args.multi_block_mode: + network.plugin_config.enable_mmha_multi_block_mode() if args.world_size > 1: network.plugin_config.set_nccl_plugin(args.dtype, @@ -545,6 +556,26 @@ def build(rank, args): assert engine is not None, \ f'Failed to build engine for rank {cur_rank}' + local_num_kv_heads = (args.n_kv_head + args.world_size - + 1) // args.world_size + kv_dtype = str_dtype_to_trt(args.dtype) + if args.quant_mode.has_int8_kv_cache(): + kv_dtype = str_dtype_to_trt('int8') + elif args.quant_mode.has_fp8_kv_cache(): + kv_dtype = str_dtype_to_trt('fp8') + check_gpt_mem_usage( + engine=engine, + kv_dtype=kv_dtype, + use_gpt_attention_plugin=args.use_gpt_attention_plugin, + paged_kv_cache=args.paged_kv_cache, + max_batch_size=args.max_batch_size, + max_beam_width=args.max_beam_width, + max_input_len=args.max_input_len, + max_output_len=args.max_output_len, + local_num_kv_heads=local_num_kv_heads, + head_size=args.n_embd / args.n_head, + num_layers=args.n_layer) + if cur_rank == 0: # Use in-memory timing cache for multiple builder passes. if not args.parallel_build: diff --git a/examples/falcon/run.py b/examples/falcon/run.py index ace1a0c26..1d1d248a9 100644 --- a/examples/falcon/run.py +++ b/examples/falcon/run.py @@ -31,6 +31,12 @@ def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--max_output_len', type=int, required=True) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--log_level', type=str, default='error') parser.add_argument('--engine_dir', type=str, default='falcon_outputs') parser.add_argument('--tokenizer_dir', @@ -216,7 +222,8 @@ def main(): decoder.setup(input_ids.size(0), max_context_length=input_ids.size(1), max_new_tokens=args.max_output_len, - beam_width=args.num_beams) + beam_width=args.num_beams, + max_kv_cache_length=args.max_kv_cache_len) output_ids = decoder.decode(input_ids, input_lengths, sampling_config) torch.cuda.synchronize() diff --git a/examples/falcon/summarize.py b/examples/falcon/summarize.py index 2f06f4d4a..dbd3cf70d 100644 --- a/examples/falcon/summarize.py +++ b/examples/falcon/summarize.py @@ -208,7 +208,8 @@ def summarize_tensorrt_llm(datapoint): tensorrt_llm_falcon.setup(batch_size, max_context_length=max_length, max_new_tokens=output_len, - beam_width=num_beams) + beam_width=num_beams, + max_kv_cache_length=args.max_kv_cache_len) if tensorrt_llm_falcon.remove_input_padding: output_ids = tensorrt_llm_falcon.decode_batch( @@ -413,6 +414,12 @@ def summarize_hf(datapoint): parser.add_argument('--engine_dir', type=str, default='falcon_outputs') parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--max_ite', type=int, default=20) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--check_accuracy', action='store_true') parser.add_argument('--tensorrt_llm_rouge1_threshold', type=float, diff --git a/examples/gpt/build.py b/examples/gpt/build.py index 36c2e24c9..7e07cbb0c 100644 --- a/examples/gpt/build.py +++ b/examples/gpt/build.py @@ -29,6 +29,7 @@ from tensorrt_llm.models import quantize_model from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType +from tensorrt_llm.profiler import check_gpt_mem_usage from tensorrt_llm.quantization import QuantMode from weight import load_from_ft, parse_ft_config, check_embedding_share # isort:skip @@ -136,6 +137,14 @@ def parse_arguments(args): parser.add_argument('--enable_context_fmha_fp32_acc', default=False, action='store_true') + parser.add_argument( + '--multi_block_mode', + default=False, + action='store_true', + help= + 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ + It is beneifical when batchxnum_heads cannot fully utilize GPU.' + ) parser.add_argument('--gpus_per_node', type=int, default=8) parser.add_argument('--builder_opt', type=int, default=None) parser.add_argument( @@ -482,6 +491,8 @@ def build_rank_engine(builder: Builder, if args.enable_context_fmha_fp32_acc: network.plugin_config.set_context_fmha( ContextFMHAType.enabled_with_fp32_acc) + if args.multi_block_mode: + network.plugin_config.enable_mmha_multi_block_mode() if args.remove_input_padding: network.plugin_config.enable_remove_input_padding() if args.paged_kv_cache: @@ -555,6 +566,7 @@ def build(rank, args): int8_trt_flag = args.quant_mode.has_act_or_weight_quant() or ( args.paged_kv_cache == False and args.quant_mode.has_int8_kv_cache()) + num_kv_heads = 1 if args.multi_query_mode else args.n_head builder_config = builder.create_builder_config( name=MODEL_NAME, precision=args.dtype, @@ -563,7 +575,7 @@ def build(rank, args): parallel_build=args.parallel_build, num_layers=args.n_layer, num_heads=args.n_head, - num_kv_heads=1 if args.multi_query_mode else args.n_head, + num_kv_heads=num_kv_heads, hidden_size=args.n_embd, vocab_size=args.vocab_size, hidden_act=args.hidden_act, @@ -589,6 +601,26 @@ def build(rank, args): cur_rank, args) assert engine is not None, f'Failed to build engine for rank {cur_rank}' + local_num_kv_heads = (num_kv_heads + args.world_size - + 1) // args.world_size + kv_dtype = str_dtype_to_trt(args.dtype) + if args.quant_mode.has_int8_kv_cache(): + kv_dtype = str_dtype_to_trt('int8') + elif args.quant_mode.has_fp8_kv_cache(): + kv_dtype = str_dtype_to_trt('fp8') + check_gpt_mem_usage( + engine=engine, + kv_dtype=kv_dtype, + use_gpt_attention_plugin=args.use_gpt_attention_plugin, + paged_kv_cache=args.paged_kv_cache, + max_batch_size=args.max_batch_size, + max_beam_width=args.max_beam_width, + max_input_len=args.max_input_len, + max_output_len=args.max_output_len, + local_num_kv_heads=local_num_kv_heads, + head_size=args.n_embd / args.n_head, + num_layers=args.n_layer) + if cur_rank == 0: # Use in-memory timing cache for multiple builder passes. if not args.parallel_build: diff --git a/examples/gpt/run.py b/examples/gpt/run.py index f94b7acc3..5cf4ebab3 100644 --- a/examples/gpt/run.py +++ b/examples/gpt/run.py @@ -184,6 +184,12 @@ def print_output(output_ids, input_lengths, sequence_lengths, tokenizer, def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--max_output_len', type=int, required=True) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--log_level', type=str, default='error') parser.add_argument('--engine_dir', type=str, default='gpt_outputs') parser.add_argument('--input_text', @@ -234,6 +240,7 @@ def generate( output_npy: str = None, tokenizer_path: str = 'gpt2', vocab_file=None, + max_kv_cache_len: int = None, num_beams: int = 1, prompt_table: Path = None, tasks: str = None, @@ -290,7 +297,8 @@ def generate( decoder.setup(input_lengths.size(0), max_input_length, max_output_len, - beam_width=num_beams) + beam_width=num_beams, + max_kv_cache_length=max_kv_cache_len) ptuning_args = [] if model_config.max_prompt_embedding_table_size == 0 else ptuning_setup( prompt_table, dtype, model_config.hidden_size, tasks, input_ids, diff --git a/examples/gpt/summarize.py b/examples/gpt/summarize.py index 5182e7ab1..75528adca 100644 --- a/examples/gpt/summarize.py +++ b/examples/gpt/summarize.py @@ -154,6 +154,8 @@ def main(args): model.cuda() if args.data_type == 'fp16': model.half() + elif args.data_type == 'bf16': + model.bfloat16() def eval_tensorrt_llm(datapoint, eval_type='summarize'): batch_size = len(datapoint) @@ -207,7 +209,8 @@ def eval_tensorrt_llm(datapoint, eval_type='summarize'): tensorrt_llm_gpt.setup(batch_size, max_context_length=max_length, max_new_tokens=output_len, - beam_width=num_beams) + beam_width=num_beams, + max_kv_cache_length=args.max_kv_cache_len) if tensorrt_llm_gpt.remove_input_padding: outputs = tensorrt_llm_gpt.decode_batch( @@ -503,7 +506,7 @@ def eval_hf(datapoint, eval_type='summarize'): parser.add_argument('--test_trt_llm', action='store_true') parser.add_argument('--data_type', type=str, - choices=['fp32', 'fp16'], + choices=['fp32', 'fp16', 'bf16'], default='fp32') parser.add_argument('--dataset_path', type=str, default='') parser.add_argument('--log_level', type=str, default='info') @@ -511,6 +514,12 @@ def eval_hf(datapoint, eval_type='summarize'): parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--max_ite', type=int, default=20) parser.add_argument('--output_len', type=int, default=100) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--check_accuracy', action='store_true') parser.add_argument('--tensorrt_llm_rouge1_threshold', type=float, diff --git a/examples/gpt/weight.py b/examples/gpt/weight.py index 37296a06f..3d7c9a41b 100644 --- a/examples/gpt/weight.py +++ b/examples/gpt/weight.py @@ -20,8 +20,9 @@ import torch import tensorrt_llm -from tensorrt_llm._utils import (pad_vocab_size, str_dtype_to_np, - str_dtype_to_torch, torch_to_numpy) +from tensorrt_llm._utils import (numpy_to_torch, pad_vocab_size, + str_dtype_to_np, str_dtype_to_torch, + torch_to_numpy) from tensorrt_llm.functional import is_gated_activation from tensorrt_llm.models import GPTLMHeadModel from tensorrt_llm.quantization import QuantMode @@ -259,11 +260,11 @@ def set_smoothquant_scale_factors(module, is_qkv=True) elif use_weight_only: processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(t), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() + numpy_to_torch(t), plugin_weight_only_quant_type) + dst.value = torch_to_numpy(processed_torch_weights) scales = tensorrt_llm_gpt.layers[ i].attention.qkv.per_channel_scale - scales.value = torch_weight_scales.numpy() + scales.value = torch_to_numpy(torch_weight_scales) else: dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) if bias: @@ -305,11 +306,11 @@ def set_smoothquant_scale_factors(module, [1, n_embd // tensor_parallel], dtype=np.float32) elif use_weight_only: processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(t), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() + numpy_to_torch(t), plugin_weight_only_quant_type) + dst.value = torch_to_numpy(processed_torch_weights) scales = tensorrt_llm_gpt.layers[ i].attention.dense.per_channel_scale - scales.value = torch_weight_scales.numpy() + scales.value = torch_to_numpy(torch_weight_scales) else: dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) @@ -355,10 +356,10 @@ def set_smoothquant_scale_factors(module, elif use_weight_only: dst = tensorrt_llm_gpt.layers[i].mlp.fc.weight processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(t), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() + numpy_to_torch(t), plugin_weight_only_quant_type) + dst.value = torch_to_numpy(processed_torch_weights) scales = tensorrt_llm_gpt.layers[i].mlp.fc.per_channel_scale - scales.value = torch_weight_scales.numpy() + scales.value = torch_to_numpy(torch_weight_scales) else: tensorrt_llm_gpt.layers[ i].mlp.fc.weight.value = np.ascontiguousarray( @@ -403,10 +404,10 @@ def set_smoothquant_scale_factors(module, elif use_weight_only: dst = tensorrt_llm_gpt.layers[i].mlp.proj.weight processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(t), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() + numpy_to_torch(t), plugin_weight_only_quant_type) + dst.value = torch_to_numpy(processed_torch_weights) scales = tensorrt_llm_gpt.layers[i].mlp.proj.per_channel_scale - scales.value = torch_weight_scales.numpy() + scales.value = torch_to_numpy(torch_weight_scales) else: tensorrt_llm_gpt.layers[i].mlp.proj.weight.value = ( np.ascontiguousarray(np.transpose(t, [1, 0]))) diff --git a/examples/gptj/README.md b/examples/gptj/README.md index a66bbcd8b..1eee8d72a 100644 --- a/examples/gptj/README.md +++ b/examples/gptj/README.md @@ -117,7 +117,6 @@ python build.py --model_dir gptj_model \ --dtype float16 \ --use_gpt_attention_plugin float16 \ --enable_context_fmha \ - --enable_two_optimization_profiles \ --output_dir gptj_engine_fp8_quantized \ --enable_fp8 \ --fp8_kv_cache \ @@ -220,7 +219,7 @@ The linear layer in the AWQ int4 weight only quantized weights should have 3 par To run a TensorRT-LLM GPT-J model: ```bash -python3 run.py --max_output_len=50 --engine_dir=gptj_engine +python3 run.py --max_output_len=50 --engine_dir=gptj_engine --hf_model_location=gptj_model ``` ## Summarization using the GPT-J model diff --git a/examples/gptj/build.py b/examples/gptj/build.py index 9a4c0702b..61e701270 100644 --- a/examples/gptj/build.py +++ b/examples/gptj/build.py @@ -26,12 +26,14 @@ load_from_bin_gpt_j, load_from_hf_gpt_j, parse_config) import tensorrt_llm +from tensorrt_llm._utils import str_dtype_to_trt from tensorrt_llm.builder import Builder from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.models import quantize_model from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType +from tensorrt_llm.profiler import check_gpt_mem_usage from tensorrt_llm.quantization import QuantMode MODEL_NAME = "gptj" @@ -123,6 +125,14 @@ def parse_arguments(args): parser.add_argument('--enable_context_fmha_fp32_acc', default=False, action='store_true') + parser.add_argument( + '--multi_block_mode', + default=False, + action='store_true', + help= + 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ + It is beneifical when batchxnum_heads cannot fully utilize GPU.' + ) parser.add_argument('--gpus_per_node', type=int, default=8) parser.add_argument( '--output_dir', @@ -392,6 +402,8 @@ def build_rank_engine(builder: Builder, if args.enable_context_fmha_fp32_acc: network.plugin_config.set_context_fmha( ContextFMHAType.enabled_with_fp32_acc) + if args.multi_block_mode: + network.plugin_config.enable_mmha_multi_block_mode() if args.use_weight_only: if args.per_group: network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin( @@ -480,6 +492,26 @@ def build(rank, args): cur_rank, args) assert engine is not None, f'Failed to build engine for rank {cur_rank}' + local_num_kv_heads = (args.n_head + args.world_size - + 1) // args.world_size + kv_dtype = str_dtype_to_trt(args.dtype) + if args.quant_mode.has_int8_kv_cache(): + kv_dtype = str_dtype_to_trt('int8') + elif args.quant_mode.has_fp8_kv_cache(): + kv_dtype = str_dtype_to_trt('fp8') + check_gpt_mem_usage( + engine=engine, + kv_dtype=kv_dtype, + use_gpt_attention_plugin=args.use_gpt_attention_plugin, + paged_kv_cache=args.paged_kv_cache, + max_batch_size=args.max_batch_size, + max_beam_width=args.max_beam_width, + max_input_len=args.max_input_len, + max_output_len=args.max_output_len, + local_num_kv_heads=local_num_kv_heads, + head_size=args.n_embd / args.n_head, + num_layers=args.n_layer) + if cur_rank == 0: # Use in-memory timing cache for multiple builder passes. if not args.parallel_build: diff --git a/examples/gptj/run.py b/examples/gptj/run.py index f3d78ac94..87a997820 100644 --- a/examples/gptj/run.py +++ b/examples/gptj/run.py @@ -149,6 +149,12 @@ def print_output(output_ids, cum_log_probs, input_lengths, sequence_lengths, def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--max_output_len', type=int, required=True) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--log_level', type=str, default='error') parser.add_argument('--engine_dir', type=str, default='gpt_outputs') parser.add_argument('--num_beams', type=int, default=1) @@ -190,6 +196,7 @@ def generate( output_csv: str = None, output_npy: str = None, hf_model_location: str = 'gptj', + max_kv_cache_len: int = None, num_beams: int = 1, min_length: int = 1, ): @@ -232,7 +239,8 @@ def generate( decoder.setup(input_lengths.size(0), max_input_length, max_output_len, - beam_width=num_beams) + beam_width=num_beams, + max_kv_cache_length=max_kv_cache_len) outputs = decoder.decode(input_ids, input_lengths, diff --git a/examples/gptj/summarize.py b/examples/gptj/summarize.py index 4066a94be..b490d3bb7 100644 --- a/examples/gptj/summarize.py +++ b/examples/gptj/summarize.py @@ -174,7 +174,8 @@ def summarize_tensorrt_llm(datapoint): tensorrt_llm_gpt.setup(batch_size, max_context_length=max_length, max_new_tokens=output_len, - beam_width=num_beams) + beam_width=num_beams, + max_kv_cache_length=args.max_kv_cache_len) if tensorrt_llm_gpt.remove_input_padding: output_ids = tensorrt_llm_gpt.decode_batch( @@ -396,6 +397,12 @@ def summarize_hf(datapoint): parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--max_ite', type=int, default=20) parser.add_argument('--output_len', type=int, default=100) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--check_accuracy', action='store_true') parser.add_argument('--tensorrt_llm_rouge1_threshold', type=float, diff --git a/examples/gptneox/build.py b/examples/gptneox/build.py index 01c55c32f..ed5b4ac44 100644 --- a/examples/gptneox/build.py +++ b/examples/gptneox/build.py @@ -175,6 +175,14 @@ def parse_arguments(): parser.add_argument('--enable_context_fmha_fp32_acc', default=False, action='store_true') + parser.add_argument( + '--multi_block_mode', + default=False, + action='store_true', + help= + 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ + It is beneifical when batchxnum_heads cannot fully utilize GPU.' + ) parser.add_argument('--gpus_per_node', type=int, default=8) parser.add_argument( '--output_dir', @@ -204,6 +212,14 @@ def parse_arguments(): 'Note: embedding sharing is only enabled when --embedding_sharding_dim=0' ) + parser.add_argument( + '--strongly_typed', + default=False, + action="store_true", + help= + 'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.' + ) + args = parser.parse_args() logger.set_level(args.log_level) @@ -317,6 +333,8 @@ def build_rank_engine(builder: Builder, if args.enable_context_fmha_fp32_acc: network.plugin_config.set_context_fmha( ContextFMHAType.enabled_with_fp32_acc) + if args.multi_block_mode: + network.plugin_config.enable_mmha_multi_block_mode() if args.use_weight_only_quant_matmul_plugin: network.plugin_config.set_weight_only_quant_matmul_plugin( dtype=args.use_weight_only_quant_matmul_plugin) @@ -385,7 +403,8 @@ def build(rank, args): max_input_len=args.max_input_len, int8=args.use_weight_only_quant_matmul_plugin or args.use_weight_only_groupwise_quant_matmul_plugin, - max_output_len=args.max_output_len) + max_output_len=args.max_output_len, + strongly_typed=args.strongly_typed) engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size, cur_rank) diff --git a/examples/gptneox/run.py b/examples/gptneox/run.py index 3d2fb811e..701a0759f 100644 --- a/examples/gptneox/run.py +++ b/examples/gptneox/run.py @@ -28,6 +28,12 @@ def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--max_output_len', type=int, required=True) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--log_level', type=str, default='error') parser.add_argument('--engine_dir', type=str, default='gptneox_outputs') parser.add_argument('--tokenizer_dir', @@ -95,7 +101,9 @@ def parse_arguments(): runtime_mapping, debug_mode=False) if remove_input_padding: - decoder.setup(1, torch.max(input_lengths).item(), args.max_output_len) + decoder.setup(1, + torch.max(input_lengths).item(), args.max_output_len, + args.max_kv_cache_len) else: decoder.setup(input_ids.size(0), input_ids.size(1), args.max_output_len) output_ids = decoder.decode(input_ids, input_lengths, sampling_config) diff --git a/examples/gptneox/summarize.py b/examples/gptneox/summarize.py index 4664f619e..b9f6d0a45 100644 --- a/examples/gptneox/summarize.py +++ b/examples/gptneox/summarize.py @@ -166,7 +166,8 @@ def summarize_tensorrt_llm(datapoint): tensorrt_llm_gpt.setup(batch_size, max_context_length=max_length, max_new_tokens=output_len, - beam_width=num_beams) + beam_width=num_beams, + max_kv_cache_length=args.max_kv_cache_len) if tensorrt_llm_gpt.remove_input_padding: output_ids = tensorrt_llm_gpt.decode_batch( @@ -361,6 +362,12 @@ def summarize_hf(datapoint): parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--max_ite', type=int, default=20) parser.add_argument('--output_len', type=int, default=100) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--check_accuracy', action='store_true') parser.add_argument('--tensorrt_llm_rouge1_threshold', type=float, diff --git a/examples/internlm/build.py b/examples/internlm/build.py index f72706319..7d08b7f36 100644 --- a/examples/internlm/build.py +++ b/examples/internlm/build.py @@ -184,6 +184,14 @@ def parse_arguments(): parser.add_argument('--enable_context_fmha_fp32_acc', default=False, action='store_true') + parser.add_argument( + '--multi_block_mode', + default=False, + action='store_true', + help= + 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ + It is beneifical when batchxnum_heads cannot fully utilize GPU.' + ) parser.add_argument('--visualize', default=False, action='store_true') parser.add_argument('--enable_debug_output', default=False, @@ -597,6 +605,8 @@ def build_rank_engine(builder: Builder, if args.enable_context_fmha_fp32_acc: network.plugin_config.set_context_fmha( ContextFMHAType.enabled_with_fp32_acc) + if args.multi_block_mode: + network.plugin_config.enable_mmha_multi_block_mode() if args.use_weight_only: if args.per_group: network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin( diff --git a/examples/llama/README.md b/examples/llama/README.md index d7df10f86..ba80b9a39 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -431,6 +431,33 @@ mpirun -n 2 --allow-run-as-root \ --engine_dir ./tmp/llama/30B/trt_engines/fp16/2-gpu/ ``` +#### Mistral v1.0 +Mistral v1.0 is compatible with LLaMA interface and can be built and run using the same instructions. +Setting `--max_input_len`, corresponding to the `max_position_embeddings` in the original Mistral config explicitly regulates context size. +The `--max_kv_cache_len` parameter is set to the `sliding_window` value in the config and regulates both sliding window attention in the context phase and rolling buffer cache in the generation phase. + +```bash +# Build Mistral 7B with max input length 32256 +python build.py --model_dir ./tmp/mistral/7B/ \ + --dtype float16 \ + --remove_input_padding \ + --use_gpt_attention_plugin float16 \ + --enable_context_fmha \ + --use_gemm_plugin float16 \ + --output_dir ./tmp/mistral/7B/trt_engines/fp16/1-gpu/ \ + --max_input_len 32256 + +# Run Mistral 7B fp16 inference with sliding window/cache size 4096 +python3 run.py --max_output_len=50 \ + --tokenizer_dir ./tmp/llama/7B/ \ + --engine_dir=./tmp/llama/7B/trt_engines/fp16/1-gpu/ \ + --max_kv_cache_len=4096 +``` + +Note that if you are comparing TRT-LLM with Huggingface, +you should install `transformers` with version >= 3.34.1 in order to have Mistral model supported. +And upgrade `flash-attn` package by `pip install --upgrade flash-attn` or you may see wrong results generated by the huggingface implementation. + ## Running CodeLlama Those examples can be used to build and run the CodeLlama models. All 7b, 13b, and 34b sizes and variants are supported. diff --git a/examples/llama/build.py b/examples/llama/build.py index 5a4b524f5..c460627f3 100644 --- a/examples/llama/build.py +++ b/examples/llama/build.py @@ -35,6 +35,7 @@ from tensorrt_llm.models import quantize_model from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType +from tensorrt_llm.profiler import check_gpt_mem_usage from tensorrt_llm.quantization import QuantMode from weight import parse_ft_config # isort:skip @@ -184,6 +185,14 @@ def parse_arguments(): parser.add_argument('--enable_context_fmha_fp32_acc', default=False, action='store_true') + parser.add_argument( + '--multi_block_mode', + default=False, + action='store_true', + help= + 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ + It is beneifical when batchxnum_heads cannot fully utilize GPU.' + ) parser.add_argument('--visualize', default=False, action='store_true') parser.add_argument('--enable_debug_output', default=False, @@ -411,7 +420,7 @@ def parse_arguments(): args.n_kv_head = hf_config.num_key_value_heads args.n_layer = hf_config.num_hidden_layers args.n_positions = hf_config.max_position_embeddings - args.vocab_size = hf_config.vocab_size + args.vocab_size = hf_config.vocab_size if args.vocab_size is None else args.vocab_size args.hidden_act = hf_config.hidden_act args.rms_norm_eps = hf_config.rms_norm_eps elif args.meta_ckpt_dir is not None: @@ -437,7 +446,7 @@ def parse_arguments(): args.n_head = n_head args.n_layer = n_layer args.n_positions = n_positions - args.vocab_size = vocab_size + args.vocab_size = vocab_size if args.vocab_size is None else args.vocab_size args.hidden_act = hidden_act args.rms_norm_eps = 1e-06 logger.warning("Set rms_norm_eps to 1e-06 directly.") @@ -593,6 +602,8 @@ def build_rank_engine(builder: Builder, if args.enable_context_fmha_fp32_acc: network.plugin_config.set_context_fmha( ContextFMHAType.enabled_with_fp32_acc) + if args.multi_block_mode: + network.plugin_config.enable_mmha_multi_block_mode() if args.use_weight_only: if args.per_group: network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin( @@ -697,6 +708,26 @@ def build(rank, args): cur_rank, args) assert engine is not None, f'Failed to build engine for rank {cur_rank}' + local_num_kv_heads = (args.n_kv_head + args.world_size - + 1) // args.world_size + kv_dtype = str_dtype_to_trt(args.dtype) + if args.quant_mode.has_int8_kv_cache(): + kv_dtype = str_dtype_to_trt('int8') + elif args.quant_mode.has_fp8_kv_cache(): + kv_dtype = str_dtype_to_trt('fp8') + check_gpt_mem_usage( + engine=engine, + kv_dtype=kv_dtype, + use_gpt_attention_plugin=args.use_gpt_attention_plugin, + paged_kv_cache=args.paged_kv_cache, + max_batch_size=args.max_batch_size, + max_beam_width=args.max_beam_width, + max_input_len=args.max_input_len, + max_output_len=args.max_output_len, + local_num_kv_heads=local_num_kv_heads, + head_size=args.n_embd / args.n_head, + num_layers=args.n_layer) + if cur_rank == 0: # Use in-memory timing cache for multiple builder passes. if not args.parallel_build: diff --git a/examples/llama/requirements.txt b/examples/llama/requirements.txt index 926de5f08..a6789a325 100644 --- a/examples/llama/requirements.txt +++ b/examples/llama/requirements.txt @@ -1,3 +1,3 @@ -datasets==2.14.5 +datasets==2.14.6 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/llama/run.py b/examples/llama/run.py index 2091a9b20..75512abea 100644 --- a/examples/llama/run.py +++ b/examples/llama/run.py @@ -196,6 +196,12 @@ def print_output(output_ids, input_lengths, max_output_len, tokenizer, def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--max_output_len', type=int, required=True) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--log_level', type=str, default='error') parser.add_argument('--engine_dir', type=str, default='llama_outputs') parser.add_argument('--tokenizer_dir', @@ -248,6 +254,7 @@ def generate( output_csv: str = None, output_npy: str = None, tokenizer_dir: str = None, + max_kv_cache_len: int = None, num_beams: int = 1, streaming: bool = False, streaming_interval: int = 5, @@ -293,8 +300,11 @@ def generate( print(input_ids) max_input_length = torch.max(input_lengths).item() - decoder.setup(input_lengths.size(0), max_input_length, max_output_len, - num_beams) + decoder.setup(input_lengths.size(0), + max_input_length, + max_output_len, + num_beams, + max_kv_cache_length=max_kv_cache_len) ptuning_args = [] if model_config.max_prompt_embedding_table_size == 0 else ptuning_setup( prompt_table, dtype, model_config.hidden_size, tasks, input_ids, diff --git a/examples/llama/summarize.py b/examples/llama/summarize.py index 5c3d8d486..9fc883c1d 100644 --- a/examples/llama/summarize.py +++ b/examples/llama/summarize.py @@ -200,7 +200,8 @@ def summarize_tensorrt_llm(datapoint): tensorrt_llm_llama.setup(batch_size, max_context_length=max_length, max_new_tokens=output_len, - beam_width=num_beams) + beam_width=num_beams, + max_kv_cache_length=args.max_kv_cache_len) if tensorrt_llm_llama.remove_input_padding: output_ids = tensorrt_llm_llama.decode_batch( @@ -390,6 +391,12 @@ def summarize_hf(datapoint): choices=['fp32', 'fp16'], default='fp16') parser.add_argument('--dataset_path', type=str, default='') + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--log_level', type=str, default='info') parser.add_argument('--engine_dir', type=str, default='llama_outputs') parser.add_argument('--batch_size', type=int, default=1) diff --git a/examples/llama/summarize_long.py b/examples/llama/summarize_long.py new file mode 100644 index 000000000..0ca1e41fa --- /dev/null +++ b/examples/llama/summarize_long.py @@ -0,0 +1,304 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os + +import torch +from datasets import load_dataset, load_metric +from summarize import TRTLLaMA +from transformers import AutoModelForCausalLM, LlamaTokenizer + +import tensorrt_llm +import tensorrt_llm.profiler as profiler +from tensorrt_llm.logger import logger + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--hf_model_location', + type=str, + default='/code/tensorrt_llm/models/Mistral-7B-v0.1') + parser.add_argument('--test_hf', action='store_true') + parser.add_argument('--test_trt_llm', action='store_true') + parser.add_argument('--data_type', + type=str, + choices=['fp16'], + default='fp16') + parser.add_argument('--dataset_path', + type=str, + default='/code/tensorrt_llm/data') + parser.add_argument('--max_kv_cache_len', + type=int, + default=4096, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') + parser.add_argument( + '--max_input_len', + type=int, + default=6400, + help='The max input length TensorRT-LLM engine was built with') + parser.add_argument('--log_level', type=str, default='info') + parser.add_argument('--max_ite', type=int, default=5) + parser.add_argument( + '--engine_dir', + type=str, + default='/code/tensorrt_llm/mistral_trtllm/llama_style_merge_long_v2') + parser.add_argument('--batch_size', type=int, default=1) + parser.add_argument('--num_beams', type=int, default=1) + parser.add_argument('--top_k', type=int, default=1) + parser.add_argument('--output_len', type=int, default=128) + parser.add_argument('--temperature', type=float, default=1) + parser.add_argument('--check_accuracy', action='store_true') + parser.add_argument('--tensorrt_llm_rouge1_threshold', + type=float, + default=15.0) + + args = parser.parse_args() + return args + + +def get_long_texts(dataset_openweb): + for datapoint in dataset_openweb["train"]: + text = datapoint["text"] + approximate_tokens = len(text.split()) + if (approximate_tokens > args.max_kv_cache_len) and ( + approximate_tokens < args.max_input_len): + yield text + + +def prepare_prompt(text): + text = text.replace("\n", " ") + text = text + '\n TL;DR: ' + text = text.strip() + text = text.replace(" n't", "n't") + return text + + +def summarize_hf(datapoint, tokenizer, hf_model, args): + + line_encoded = tokenizer(datapoint, + return_tensors='pt', + padding=True, + truncation=True)["input_ids"].type(torch.int32) + + line_encoded = line_encoded.cuda() + + with torch.no_grad(): + output = hf_model.generate(line_encoded, + max_new_tokens=args.output_len, + temperature=args.temperature, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + num_beams=args.num_beams, + top_k=args.top_k, + do_sample=True, + early_stopping=True) + + tokens_list = output[:, len(line_encoded[0]):].tolist() + output = output.reshape([args.batch_size, args.num_beams, -1]) + output_lines_list = [ + tokenizer.batch_decode(output[:, i, len(line_encoded[0]):], + skip_special_tokens=True) + for i in range(args.num_beams) + ] + + return output_lines_list, tokens_list + + +def summarize_tensorrt_llm(datapoint, tokenizer, tensorrt_llm_llama, args): + line_encoded = [] + input_id = tokenizer.encode(datapoint, + return_tensors='pt').type(torch.int32) + line_encoded.append(input_id) + input_lengths = [] + input_lengths.append(input_id.shape[-1]) + max_length = max(input_lengths) + + pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0] + end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0] + + if tensorrt_llm_llama.remove_input_padding: + line_encoded = [ + torch.tensor(t, dtype=torch.int32).cuda() for t in line_encoded + ] + else: + # do padding, should move outside the profiling to prevent the overhead + for i in range(args.batch_size): + pad_size = max_length - input_lengths[i] + + pad = torch.ones([1, pad_size]).type(torch.int32) * pad_id + line_encoded[i] = torch.cat( + [torch.tensor(line_encoded[i], dtype=torch.int32), pad], + axis=-1) + + line_encoded = torch.cat(line_encoded, axis=0).cuda() + + input_lengths = torch.tensor(input_lengths, dtype=torch.int32).cuda() + + sampling_config = tensorrt_llm.runtime.SamplingConfig( + end_id=end_id, + pad_id=pad_id, + top_k=args.top_k, + num_beams=args.num_beams) + + with torch.no_grad(): + tensorrt_llm_llama.setup(batch_size=args.batch_size, + max_context_length=max_length, + max_new_tokens=args.output_len, + beam_width=args.num_beams, + max_kv_cache_length=args.max_kv_cache_len) + logger.info(f"Generation session set up with the parameters: \ + batch_size: {tensorrt_llm_llama.batch_size}, \ + max_context_length: {tensorrt_llm_llama.max_context_length}, \ + max_new_tokens: {tensorrt_llm_llama.max_new_tokens}, \ + beam_width: {tensorrt_llm_llama.beam_width}, \ + max_kv_cache_length: {tensorrt_llm_llama.max_kv_cache_length}") + + if tensorrt_llm_llama.remove_input_padding: + output_ids = tensorrt_llm_llama.decode_batch( + line_encoded, sampling_config) + else: + output_ids = tensorrt_llm_llama.decode( + line_encoded, + input_lengths, + sampling_config, + ) + torch.cuda.synchronize() + + logger.info(f"Decoded output of shape{output_ids.shape}") + + # Extract a list of tensors of shape beam_width x output_ids. + if tensorrt_llm_llama.mapping.is_first_pp_rank(): + output_beams_list = [ + tokenizer.batch_decode(output_ids[batch_idx, :, + input_lengths[batch_idx]:], + skip_special_tokens=True) + for batch_idx in range(args.batch_size) + ] + return output_beams_list, output_ids[:, :, max_length:].tolist() + return [], [] + + +def main(args): + runtime_rank = tensorrt_llm.mpi_rank() + logger.set_level(args.log_level) + + profiler.start('load tokenizer') + tokenizer = LlamaTokenizer.from_pretrained(args.hf_model_location, + legacy=False, + padding_side='left') + profiler.stop('load tokenizer') + tensorrt_llm.logger.info( + f'Load tokenizer takes: {profiler.elapsed_time_in_sec("load tokenizer")} sec' + ) + tokenizer.pad_token = tokenizer.eos_token + + dataset_openweb = load_dataset("stas/openwebtext-10k", + cache_dir=args.dataset_path) + long_texts = get_long_texts(dataset_openweb) # generator + + # get datapoints + try: + datapoints = [ + prepare_prompt(next(long_texts)) for i in range(args.max_ite) + ] + except StopIteration: + logger.warning( + f"No test data of sufficient length ({args.max_kv_cache_len}). Try decreasing the max_kv_cache_len parameter" + ) + return + + if args.test_trt_llm: + config_path = os.path.join(args.engine_dir, 'config.json') + with open(config_path, 'r') as f: + config = json.load(f) + tensorrt_llm_llama = TRTLLaMA(args, config) + + trt_llm_summary = [] + for ite in range(args.max_ite): + trt_llm_summary.append( + summarize_tensorrt_llm(datapoints[ite], tokenizer, + tensorrt_llm_llama, args)[0]) + + if runtime_rank == 0: + logger.info( + "---------------------------------------------------------") + logger.info("TRT LLM Generated : ") + logger.info(f" Article : {datapoints[0]}") + logger.info(f"\n Summary : {trt_llm_summary[0]}") + logger.info( + "---------------------------------------------------------") + + del tensorrt_llm_llama + + test_hf = args.test_hf and runtime_rank == 0 # only run hf on rank 0 + if test_hf: + profiler.start('load HF model') + hf_model = AutoModelForCausalLM.from_pretrained( + args.hf_model_location, + torch_dtype=torch.float16, + use_flash_attention_2=True) + profiler.stop('load HF model') + tensorrt_llm.logger.info( + f'Load HF model takes: {profiler.elapsed_time_in_sec("load HF model")} sec' + ) + hf_model.cuda() + + hf_summary = [] + for ite in range(args.max_ite): + hf_summary.append( + summarize_hf(datapoints[ite], tokenizer, hf_model, args)[0]) + logger.info("---------------------------------------------------------") + logger.info("HF Generated : ") + logger.info(f" Article : {datapoints[0]}") + logger.info(f"\n Summary : {hf_summary[0]}") + logger.info("---------------------------------------------------------") + + # no ground truth, compare with hf + if runtime_rank == 0 and args.test_hf and args.test_trt_llm: + + metric_tensorrt_llm = [ + load_metric("rouge") for _ in range(args.num_beams) + ] + + for i in range(args.num_beams): + metric_tensorrt_llm[i].seed = 0 + + for ite in range(args.max_ite): + for batch_idx in range(len(trt_llm_summary[0])): + for beam_idx in range(args.num_beams): + metric_tensorrt_llm[beam_idx].add_batch( + predictions=[trt_llm_summary[ite][batch_idx][beam_idx]], + references=[hf_summary[ite][beam_idx][batch_idx]]) + + for beam_idx in range(args.num_beams): + logger.info(f"TensorRT-LLM beam {beam_idx} result") + computed_metrics_tensorrt_llm = metric_tensorrt_llm[ + beam_idx].compute() + for key in computed_metrics_tensorrt_llm.keys(): + logger.info( + f' {key} : {computed_metrics_tensorrt_llm[key].mid[2]*100}' + ) + + if args.check_accuracy and beam_idx == 0: + assert computed_metrics_tensorrt_llm['rouge1'].mid[ + 2] * 100 > args.tensorrt_llm_rouge1_threshold + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/examples/llama/weight.py b/examples/llama/weight.py index d347c341e..0211c5f89 100644 --- a/examples/llama/weight.py +++ b/examples/llama/weight.py @@ -14,7 +14,6 @@ # limitations under the License. import configparser import time -from operator import attrgetter from pathlib import Path from typing import Dict, List, Optional, Union @@ -829,19 +828,42 @@ def load_from_gptq_llama(tensorrt_llm_llama, 'Loading weights from groupwise GPTQ LLaMA safetensors...') tik = time.time() - if quant_ckpt_path.endswith(".safetensors"): - groupwise_qweight_safetensors = safe_open(quant_ckpt_path, - framework="pt", - device=0) - model_params = { - key: groupwise_qweight_safetensors.get_tensor(key) - for key in groupwise_qweight_safetensors.keys() - } - elif quant_ckpt_path.endswith(".pt"): - model_params = torch.load(quant_ckpt_path, - map_location=torch.device('cpu')) - else: - assert False, "Quantized checkpoint format not supported!" + gptq_llama = safe_open(quant_ckpt_path, framework="pt", device=0) + gptq_prefix = "model." + gptq_suffix_list = [".qweight", ".qzeros", ".scales"] + gptq_key_list = [ + "embed_tokens.weight", # vocab_embedding + "lm_head.weight", # lm_head + "norm.weight", # ln_f + "self_attn.", # attention.qkv + "_proj", # qkv suffix + "self_attn.o_proj", # attention.dense + "mlp.up_proj", # mlp.gate + "mlp.down_proj", # mlp.proj + "mlp.gate_proj", # mlp.fc + "input_layernorm.weight", # input_layernorm + "post_attention_layernorm.weight", # post_layernorm + ] + split_sym = "." + + packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4 + preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm + torch_dtype = str_dtype_to_torch(dtype) + + def load(key, no_prefix=0): + if no_prefix: + return gptq_llama.get_tensor(key) + else: + return gptq_llama.get_tensor(gptq_prefix + key) + + def torch_split(v, dim): + if v.shape[dim] % mapping.tp_size != 0: + tensorrt_llm.logger.error( + "Current weight shape is invalid for mapping.tp_size=" + + str(mapping.tp_size)) + assert False, "Invalid TP size" + return v.split(v.shape[dim] // mapping.tp_size, + dim=dim)[mapping.tp_rank] def unpack_int32_into_int8(w_packed): # Unpack inputs packed in int32/float32 into uint4 and store them in int8 format @@ -853,19 +875,18 @@ def unpack_int32_into_int8(w_packed): w_unpacked[:, 1::2] = w_packed_int4x2 // 16 return w_unpacked.contiguous() - def preprocess_groupwise_weight_params(weight_name, - qweight_int32=None, - qzeros_int32=None, - scales_fp16=None): - if weight_name is not None: - qweight_int32 = model_params[weight_name].cpu() - qzeros_int32 = model_params[weight_name[:-7] + 'qzeros'].cpu() - scales_fp16 = model_params[weight_name[:-7] + 'scales'].cpu() + def process_and_assign_weight(mOp, v, tp_dim=0): + if tp_dim == -1: + qweight_int32, qzeros_int32, scales_fp16 = [ + item.cpu() for item in v + ] + else: + qweight_int32, qzeros_int32, scales_fp16 = [ + torch_split(item, tp_dim).cpu() for item in v + ] - UINT4_TO_INT4_FLAG = 1 - GPTQ_FLAG = 1 - packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4 - preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm + USE_UINT4_INPUT = 1 # Set to true if checkpoint store UINT4 weights + USE_GPTQ_FOR_LLAMA = 1 # GPTQ-for-LLaMA added 1 to zeros qweight_unpacked_int8 = unpack_int32_into_int8( qweight_int32.T).T.contiguous() - 8 @@ -873,148 +894,87 @@ def preprocess_groupwise_weight_params(weight_name, torch.quint4x2).view(torch.int8) # zeros = zeros * scales qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32) - zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 8 * UINT4_TO_INT4_FLAG - - GPTQ_FLAG) * scales_fp16 + if not USE_UINT4_INPUT: + # Correcting UINT4 values back to INT4 order + mask_negative = qzeros_unpacked_int32[qzeros_unpacked_int32 < 0] + mask_positive = qzeros_unpacked_int32[qzeros_unpacked_int32 >= 0] + qzeros_unpacked_int32 = qzeros_unpacked_int32 + 16 * mask_negative - 16 * mask_positive + zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 8 * USE_UINT4_INPUT - + USE_GPTQ_FOR_LLAMA) * scales_fp16 zeros_x_scales_fp16 = zeros_x_scales_fp16.half() # return processed interleaved weight, original scales and zeros * scales - return qweight_interleaved.contiguous(), scales_fp16.contiguous( - ), zeros_x_scales_fp16.contiguous() + mOp.qweight.value = qweight_interleaved.cpu().numpy() + mOp.scale.value = scales_fp16.cpu().numpy() + mOp.zero.value = zeros_x_scales_fp16.cpu().numpy() - layer_ids = [ - extract_layer_idx(key) for key in groupwise_qweight_safetensors.keys() - ] - layer_ids = [ - int(layer_idx) for layer_idx in layer_ids if layer_idx is not None - ] - num_hidden_layers = max(layer_ids) + 1 - num_kv_heads = tensorrt_llm_llama.num_kv_heads - mha_mode = (num_kv_heads == tensorrt_llm_llama.num_heads) - suffixs = ['qweight', 'qzeros', 'scales'] + # Load weights from GPTQ checkpoint into TRT-LLM module + # 1. vocab_embedding + v = load(gptq_key_list[0]) + if mapping.is_first_pp_rank(): + tensorrt_llm_llama.vocab_embedding.weight.value = v.to( + torch_dtype).cpu().numpy() + + # 2. lm_head + v = load(gptq_key_list[1], "no_prefix") + if mapping.is_last_pp_rank(): + tensorrt_llm_llama.lm_head.weight.value = torch_split( + v, 0).to(torch_dtype).cpu().numpy() + + # 3. ln_f + v = load(gptq_key_list[2]) + if mapping.is_last_pp_rank(): + tensorrt_llm_llama.ln_f.weight.value = v.to(torch_dtype).cpu().numpy() + # 4. Weights inside each layer + num_hidden_layers = tensorrt_llm_llama.num_layers layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size layers_range = list( range(mapping.pp_rank * layers_per_pipeline_stage, (mapping.pp_rank + 1) * layers_per_pipeline_stage, 1)) for l in layers_range: - prefix = f'model.layers.{l}.self_attn.' - split_qkv_suf = [] - - for suf in suffixs: - q_part = model_params[prefix + 'q_proj.' + suf].cpu() - k_part = model_params[prefix + 'k_proj.' + suf].cpu() - v_part = model_params[prefix + 'v_proj.' + suf].cpu() - q_part = q_part.split(q_part.shape[1] // mapping.tp_size, - dim=1)[mapping.tp_rank] - k_part = k_part.split(k_part.shape[1] // mapping.tp_size, - dim=1)[mapping.tp_rank] - v_part = v_part.split(v_part.shape[1] // mapping.tp_size, - dim=1)[mapping.tp_rank] - split_qkv = torch.cat([q_part, k_part, v_part], dim=1) - split_qkv_suf.append(split_qkv) - - th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( - None, split_qkv_suf[0], split_qkv_suf[1], split_qkv_suf[2]) - - idx = l - mapping.pp_rank * layers_per_pipeline_stage - tensorrt_llm_llama.layers[ - idx].attention.qkv.qweight.value = th_qweight.numpy() - tensorrt_llm_llama.layers[ - idx].attention.qkv.scale.value = th_zero.numpy() - tensorrt_llm_llama.layers[ - idx].attention.qkv.zero.value = th_scale.numpy() + layer_idx = l - mapping.pp_rank * layers_per_pipeline_stage + prefix = "layers" + split_sym + str(layer_idx) + split_sym + tensorrt_llm.logger.info(f'Process weights in layer: {layer_idx}') + layer = tensorrt_llm_llama.layers[layer_idx] - torch_dtype = str_dtype_to_torch(dtype) + # 4.1 attention.qkv + qkv_weight_list = [] + for suf in gptq_suffix_list: + qkv_list = [] + for comp in ["q", "k", "v"]: + comp_part = load(prefix + gptq_key_list[3] + comp + + gptq_key_list[4] + suf) + comp_part = torch_split(comp_part, 1) + qkv_list.append(comp_part) + qkv_weight_list.append(torch.cat(qkv_list, dim=1)) - for k, v in model_params.items(): - if isinstance(v, list): - v = [torch_to_numpy(vv.to(torch_dtype).detach().cpu()) for vv in v] - else: - v = torch_to_numpy(v.to(torch_dtype).detach().cpu()) - if 'model.embed_tokens.weight' in k: - if mapping.is_first_pp_rank(): - tensorrt_llm_llama.vocab_embedding.weight.value = v - elif 'model.norm.weight' in k: - if mapping.is_last_pp_rank(): - tensorrt_llm_llama.ln_f.weight.value = v - elif 'lm_head.weight' in k: - if mapping.is_last_pp_rank(): - tensorrt_llm_llama.lm_head.weight.value = np.ascontiguousarray( - split(v, mapping.tp_size, mapping.tp_rank)) - else: - layer_idx = extract_layer_idx(k) - if layer_idx is None: - continue - idx = int(layer_idx) - if idx not in layers_range: - continue - idx = idx - mapping.pp_rank * layers_per_pipeline_stage + process_and_assign_weight(layer.attention.qkv, qkv_weight_list) - if 'input_layernorm.weight' in k: - tensorrt_llm_llama.layers[idx].input_layernorm.weight.value = v - elif 'post_attention_layernorm.weight' in k: - tensorrt_llm_llama.layers[idx].post_layernorm.weight.value = v - elif 'self_attn.o_proj.qweight' in k: - split_v_suf = [] - for suf in suffixs: - v = model_params[k[:-7] + suf].cpu() - split_v = v.split(v.shape[0] // mapping.tp_size, - dim=0)[mapping.tp_rank] - split_v_suf.append(split_v) - th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( - None, split_v_suf[0], split_v_suf[1], split_v_suf[2]) - tensorrt_llm_llama.layers[ - idx].attention.dense.qweight.value = th_qweight.numpy() - tensorrt_llm_llama.layers[ - idx].attention.dense.scale.value = th_zero.numpy() - tensorrt_llm_llama.layers[ - idx].attention.dense.zero.value = th_scale.numpy() - elif 'mlp.up_proj.qweight' in k: - split_v_suf = [] - for suf in suffixs: - v = model_params[k[:-7] + suf].cpu() - split_v = v.split(v.shape[1] // mapping.tp_size, - dim=1)[mapping.tp_rank] - split_v_suf.append(split_v) - th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( - None, split_v_suf[0], split_v_suf[1], split_v_suf[2]) - tensorrt_llm_llama.layers[ - idx].mlp.gate.qweight.value = th_qweight.numpy() - tensorrt_llm_llama.layers[ - idx].mlp.gate.scale.value = th_zero.numpy() - tensorrt_llm_llama.layers[ - idx].mlp.gate.zero.value = th_scale.numpy() - elif 'mlp.down_proj.qweight' in k: - split_v_suf = [] - for suf in suffixs: - v = model_params[k[:-7] + suf].cpu() - split_v = v.split(v.shape[0] // mapping.tp_size, - dim=0)[mapping.tp_rank] - split_v_suf.append(split_v) - th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( - None, split_v_suf[0], split_v_suf[1], split_v_suf[2]) - tensorrt_llm_llama.layers[ - idx].mlp.proj.qweight.value = th_qweight.numpy() - tensorrt_llm_llama.layers[ - idx].mlp.proj.scale.value = th_zero.numpy() - tensorrt_llm_llama.layers[ - idx].mlp.proj.zero.value = th_scale.numpy() - elif 'mlp.gate_proj.qweight' in k: - split_v_suf = [] - for suf in suffixs: - v = model_params[k[:-7] + suf].cpu() - split_v = v.split(v.shape[1] // mapping.tp_size, - dim=1)[mapping.tp_rank] - split_v_suf.append(split_v) - th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( - None, split_v_suf[0], split_v_suf[1], split_v_suf[2]) - tensorrt_llm_llama.layers[ - idx].mlp.fc.qweight.value = th_qweight.numpy() - tensorrt_llm_llama.layers[ - idx].mlp.fc.scale.value = th_zero.numpy() - tensorrt_llm_llama.layers[ - idx].mlp.fc.zero.value = th_scale.numpy() + # 4.2 attention.dense + v = [load(prefix + gptq_key_list[5] + suf) for suf in gptq_suffix_list] + process_and_assign_weight(layer.attention.dense, v, 0) + + # 4.3 mlp.gate + v = [load(prefix + gptq_key_list[6] + suf) for suf in gptq_suffix_list] + process_and_assign_weight(layer.mlp.gate, v, 1) + + # 4.4 mlp.proj + v = [load(prefix + gptq_key_list[7] + suf) for suf in gptq_suffix_list] + process_and_assign_weight(layer.mlp.proj, v, 0) + + # 4.5 mlp.fc + v = [load(prefix + gptq_key_list[8] + suf) for suf in gptq_suffix_list] + process_and_assign_weight(layer.mlp.fc, v, 1) + + # 4.6 input_layernorm + v = load(prefix + gptq_key_list[9]) + layer.input_layernorm.weight.value = v.to(torch_dtype).cpu().numpy() + + # 4.7 post_layernorm + v = load(prefix + gptq_key_list[10]) + layer.post_layernorm.weight.value = v.to(torch_dtype).cpu().numpy() tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) @@ -1028,36 +988,74 @@ def load_from_awq_llama(tensorrt_llm_llama: LLaMAForCausalLM, dtype="float16", ft_model_dir=None): tensorrt_llm.logger.info( - 'Loading weights from groupwise AWQ LLaMA safetensors...') + 'Loading weights from groupwise AWQ LLaMA checkpoint...') tik = time.time() - if quant_ckpt_path.endswith(".safetensors"): - groupwise_qweight_safetensors = safe_open(quant_ckpt_path, - framework="pt", - device=0) - awq_llama = { - key: groupwise_qweight_safetensors.get_tensor(key) - for key in groupwise_qweight_safetensors.keys() - } - elif quant_ckpt_path.endswith(".pt"): - awq_llama = torch.load(quant_ckpt_path, - map_location=torch.device('cpu')) + if quant_ckpt_path.endswith(".pt"): + awq_llama = torch.load(quant_ckpt_path) + awq_prefix = "model." + awq_suffix_list = [ + ".weight", + ".weight_quantizer._amax", + ".input_quantizer._pre_quant_scale", + ] + awq_key_list = [ + "embed_tokens.weight", # vocab_embedding + "lm_head", # lm_head + "norm.weight", # ln_f + "self_attn.", # attention.qkv + "_proj", # qkv suffix + "self_attn.o_proj", # attention.dense + "mlp.up_proj", # mlp.gate + "mlp.down_proj", # mlp.proj + "mlp.gate_proj", # mlp.fc + "input_layernorm.weight", # input_layernorm + "post_attention_layernorm.weight", # post_layernorm + ] + split_sym = "." + + def load(key): + if "lm_head" in key: + v = awq_llama[key] + else: + v = awq_llama[awq_prefix + key] + return v + + group_size = load("layers.0.self_attn.o_proj.weight").numel() // load( + "layers.0.self_attn.o_proj.weight_quantizer._amax").numel() + elif quant_ckpt_path.endswith(".npz"): + awq_llama = np.load(quant_ckpt_path) + awq_prefix = "_np:" + awq_suffix_list = [ + ":weight", + ":weights_scaling_factor", + ":prequant_scaling_factor", + ] + awq_key_list = [ + "vocab_embedding:weight", # vocab_embedding + "lm_head", # lm_head + "final_layernorm:weight", # ln_f + "attention:qkv:", # attention.qkv + "", # qkv suffix + "attention:dense", # attention.dense + "mlp:gate", # mlp.gate + "mlp:proj", # mlp.proj + "mlp:fc", # mlp.fc + "input_layernorm:weight", # input_layernorm + "post_layernorm:weight", # post_layernorm + ] + split_sym = ":" + + def load(key): + v = torch.from_numpy(awq_llama[awq_prefix + key]) + if "weights_scaling_factor" in key: + v *= 7 # For AMMO *.npz checkpoints + return v + + group_size = load("layers:0:attention:dense:weight").numel() // load( + "layers:0:attention:dense:weights_scaling_factor").numel() else: - assert False, "Quantized checkpoint format not supported!" - - group_size = awq_llama["model.layers.0.self_attn.o_proj.weight"].numel( - ) // awq_llama[ - "model.layers.0.self_attn.o_proj.weight_quantizer._amax"].numel() - - awq_llama_block_names = [ - "input_layernorm.weight", - "post_attention_layernorm.weight", - ] - - tensorrt_llm_llama_block_names = [ - "input_layernorm.weight", - "post_layernorm.weight", - ] + assert False, "Unsupported AWQ quantized checkpoint format" quant_mode = getattr(tensorrt_llm_llama, 'quant_mode', QuantMode(0)) # Int8 KV cache @@ -1076,86 +1074,72 @@ def fromfile(dir_path, name, shape=None, dtype=None): return t return None + def torch_split(v, dim): + if v.shape[dim] % mapping.tp_size != 0: + tensorrt_llm.logger.error( + "Current weight shape is invalid for mapping.tp_size=" + + str(mapping.tp_size)) + assert False, "Invalid TP size" + return v.split(v.shape[dim] // mapping.tp_size, + dim=dim)[mapping.tp_rank] + def AWQ_quantize_pack_preprocess(weight, scale): - scale = scale.repeat_interleave(group_size, dim=0) - weight = weight / scale + weight /= scale.repeat_interleave(group_size, dim=0) qweight_int8 = torch.clamp(torch.round(weight.cuda()).char(), -8, 7) - int4_weight = packer(qweight_int8.cpu()) - int4_weight = preprocessor(int4_weight, torch.quint4x2) + int4_weight = preprocessor(packer(qweight_int8.cpu()), torch.quint4x2) return int4_weight.view(torch.int8).cpu().numpy() - def process_and_assign_weight(awq_llama, mPrefix, mOp, tp_dim=0): - weight = awq_llama[mPrefix + ".weight"].T.contiguous() + def process_and_assign_weight(mOp, v, tp_dim=0): + weight = v[0].T.contiguous() [k, n] = weight.shape - weight = weight.split(weight.shape[tp_dim] // mapping.tp_size, - dim=tp_dim)[mapping.tp_rank] - amax = awq_llama[mPrefix + ".weight_quantizer._amax"].reshape( - (n, int(k / group_size))).T.contiguous() - amax = amax.split(amax.shape[tp_dim] // mapping.tp_size, - dim=tp_dim)[mapping.tp_rank] - pre_quant_scale = awq_llama[ - mPrefix + ".input_quantizer._pre_quant_scale"].reshape((1, k)) + weight = torch_split(weight, tp_dim) + amax = v[1].reshape((n, k // group_size)).T.contiguous() + amax = torch_split(amax, tp_dim) + pre_quant_scale = v[2].reshape((1, k)) if tp_dim == 0: - pre_quant_scale = pre_quant_scale.split(k // mapping.tp_size, - dim=1)[mapping.tp_rank] + pre_quant_scale = torch_split(pre_quant_scale, 1) scale = amax / 8.0 mOp.qweight.value = AWQ_quantize_pack_preprocess(weight, scale) mOp.scale.value = scale.to(torch_dtype).cpu().numpy() mOp.pre_quant_scale.value = pre_quant_scale.to( torch_dtype).cpu().numpy() - def deSmooth(weight, pre_quant_scale): - [k, n] = weight.shape - pre_quant_scale = pre_quant_scale.repeat( - (n, 1)).transpose(1, 0).contiguous() - weight = weight * pre_quant_scale - return weight - - def reSmooth(weight, pre_quant_scale): - [k, n] = weight.shape - pre_quant_scale = pre_quant_scale.repeat( - (n, 1)).transpose(1, 0).contiguous() - weight = weight / pre_quant_scale - return weight - - def get_scale(weight): - weight = weight.T.contiguous() - [n, k] = weight.shape - weight = weight.reshape(n, int(k / group_size), group_size) - weight = torch.abs(weight.reshape(-1, group_size)) - amax, idx = weight.max(1) - amax = amax.reshape(n, int(k / group_size)).T.contiguous() - return amax / 8 - def reSmooth_and_get_scale(weight, pre_quant_scale, avg_pre_quant_scale): - weight = deSmooth(weight, pre_quant_scale) - weight = reSmooth(weight, avg_pre_quant_scale) - scale = get_scale(weight) + # deSmooth and reSmooth + [k, n] = weight.shape + if quant_ckpt_path.endswith("pt"): + # NPZ files are already re-smoothed + weight *= pre_quant_scale.repeat((n, 1)).transpose(1, + 0).contiguous() + weight /= avg_pre_quant_scale.repeat( + (n, 1)).transpose(1, 0).contiguous() + + # Get scale + weight_t = weight.T.contiguous() + weight_t = weight_t.reshape(n, k // group_size, group_size) + weight_t = torch.abs(weight_t.reshape(-1, group_size)) + amax, idx = weight_t.max(1) + amax = amax.reshape(n, k // group_size).T.contiguous() + scale = amax / 8 return weight, scale - def process_and_assign_qkv_weight(awq_llama, prefix, mOp): - q_weight = awq_llama[prefix + "self_attn.q_proj.weight"].T.contiguous() - k_weight = awq_llama[prefix + "self_attn.k_proj.weight"].T.contiguous() - v_weight = awq_llama[prefix + "self_attn.v_proj.weight"].T.contiguous() - k = q_weight.shape[0] - - q_weight = q_weight.split(q_weight.shape[1] // mapping.tp_size, - dim=1)[mapping.tp_rank] - k_weight = k_weight.split(k_weight.shape[1] // mapping.tp_size, - dim=1)[mapping.tp_rank] - v_weight = v_weight.split(v_weight.shape[1] // mapping.tp_size, - dim=1)[mapping.tp_rank] - - q_pre_quant_scale = awq_llama[ - prefix + - "self_attn.q_proj.input_quantizer._pre_quant_scale"].reshape((1, k)) - k_pre_quant_scale = awq_llama[ - prefix + - "self_attn.k_proj.input_quantizer._pre_quant_scale"].reshape((1, k)) - v_pre_quant_scale = awq_llama[ - prefix + - "self_attn.v_proj.input_quantizer._pre_quant_scale"].reshape((1, k)) - + def process_and_assign_qkv_weight(prefix, mOp): + q_weight = load(prefix + "q" + awq_key_list[4] + + awq_suffix_list[0]).T.contiguous() + k_weight = load(prefix + "k" + awq_key_list[4] + + awq_suffix_list[0]).T.contiguous() + v_weight = load(prefix + "v" + awq_key_list[4] + + awq_suffix_list[0]).T.contiguous() + dim_k = q_weight.shape[0] + q_weight = torch_split(q_weight, 1) + k_weight = torch_split(k_weight, 1) + v_weight = torch_split(v_weight, 1) + q_pre_quant_scale = load(prefix + "q" + awq_key_list[4] + + awq_suffix_list[2]).reshape((1, dim_k)) + k_pre_quant_scale = load(prefix + "k" + awq_key_list[4] + + awq_suffix_list[2]).reshape((1, dim_k)) + v_pre_quant_scale = load(prefix + "v" + awq_key_list[4] + + awq_suffix_list[2]).reshape((1, dim_k)) qkv_pre_quant_scale = (q_pre_quant_scale + k_pre_quant_scale + v_pre_quant_scale) / 3.0 q_weight, q_scale = reSmooth_and_get_scale(q_weight, q_pre_quant_scale, @@ -1164,7 +1148,6 @@ def process_and_assign_qkv_weight(awq_llama, prefix, mOp): qkv_pre_quant_scale) v_weight, v_scale = reSmooth_and_get_scale(v_weight, v_pre_quant_scale, qkv_pre_quant_scale) - qkv_weights = torch.cat((q_weight, k_weight, v_weight), dim=1) qkv_scale = torch.cat((q_scale, k_scale, v_scale), dim=1) @@ -1173,68 +1156,72 @@ def process_and_assign_qkv_weight(awq_llama, prefix, mOp): mOp.qweight.value = AWQ_quantize_pack_preprocess(qkv_weights, qkv_scale) mOp.scale.value = qkv_scale.to(torch_dtype).cpu().numpy() - # Check if we need to pad vocab - v = awq_llama.get('model.embed_tokens.weight') - [vocab_size, k] = v.shape - pad_vocab = False - pad_vocab_size = vocab_size - if vocab_size % 64 != 0: - pad_vocab = True - pad_vocab_size = int((vocab_size + 63) / 64) * 64 - if pad_vocab: - new_v = torch.zeros([pad_vocab_size, k]) - new_v[:vocab_size, :] = v - v = new_v + # Load weights from AWQ checkpoint into TRT-LLM module + # 1. vocab_embedding + v = load(awq_key_list[0]) + # TRT-LLM requires vocab_size to be multiple of 64 for successful GEMM + if v[0].shape[0] % 64 != 0: + v = torch.nn.functional.pad(v, [0, 0, 0, 64 - v.shape[0] % 64]) if mapping.is_first_pp_rank(): tensorrt_llm_llama.vocab_embedding.weight.value = v.to( torch_dtype).cpu().numpy() - layer_ids = [extract_layer_idx(key) for key in awq_llama.keys()] - layer_ids = [ - int(layer_idx) for layer_idx in layer_ids if layer_idx is not None - ] + # 2. lm_head + v = [load(awq_key_list[1] + suf) for suf in awq_suffix_list] + if v[0].shape[0] % 64 != 0: + v[0] = torch.nn.functional.pad(v[0], [0, 0, 0, 64 - v[0].shape[0] % 64]) + v[1] = torch.nn.functional.pad(v[1], [0, 0, 0, 64 - v[1].shape[0] % 64], + value=1) + if mapping.is_last_pp_rank(): + process_and_assign_weight(tensorrt_llm_llama.lm_head, v, 1) + + # 3. ln_f + v = load(awq_key_list[2]) + if mapping.is_last_pp_rank(): + tensorrt_llm_llama.ln_f.weight.value = v.to(torch_dtype).cpu().numpy() - num_hidden_layers = max(layer_ids) + 1 + # 4. Weights inside each layer + num_hidden_layers = tensorrt_llm_llama.num_layers layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size layers_range = list( range(mapping.pp_rank * layers_per_pipeline_stage, (mapping.pp_rank + 1) * layers_per_pipeline_stage, 1)) - for layer_idx in layers_range: - prefix = "model.layers." + str(layer_idx) + "." + for l in layers_range: + layer_idx = l - mapping.pp_rank * layers_per_pipeline_stage + prefix = "layers" + split_sym + str(layer_idx) + split_sym tensorrt_llm.logger.info(f'Process weights in layer: {layer_idx}') - for idx, awq_attr in enumerate(awq_llama_block_names): - v = awq_llama[prefix + awq_attr] - layer = attrgetter(tensorrt_llm_llama_block_names[idx])( - tensorrt_llm_llama.layers[layer_idx]) - setattr(layer, 'value', v.to(torch_dtype).cpu().numpy()) - - # Attention QKV Linear - # concatenate the Q, K, V layers weights. - process_and_assign_qkv_weight( - awq_llama, prefix, - tensorrt_llm_llama.layers[layer_idx].attention.qkv) - - # Attention Dense (out_proj) Linear - mPrefix = prefix + "self_attn.o_proj" - mOp = tensorrt_llm_llama.layers[layer_idx].attention.dense - process_and_assign_weight(awq_llama, mPrefix, mOp, 0) - - # MLP up_proj (mlp.gate) Linear - mPrefix = prefix + "mlp.up_proj" - mOp = tensorrt_llm_llama.layers[layer_idx].mlp.gate - process_and_assign_weight(awq_llama, mPrefix, mOp, 1) - - # MLP down_proj (mlp.proj) Linear - mPrefix = prefix + "mlp.down_proj" - mOp = tensorrt_llm_llama.layers[layer_idx].mlp.proj - process_and_assign_weight(awq_llama, mPrefix, mOp, 0) - - # MLP gate_proj (mlp.fc) Linear - mPrefix = prefix + "mlp.gate_proj" - mOp = tensorrt_llm_llama.layers[layer_idx].mlp.fc - process_and_assign_weight(awq_llama, mPrefix, mOp, 1) + layer = tensorrt_llm_llama.layers[layer_idx] + + # 4.1 attention.qkv + process_and_assign_qkv_weight(prefix + awq_key_list[3], + layer.attention.qkv) + + # 4.2 attention.dense + v = [load(prefix + awq_key_list[5] + suf) for suf in awq_suffix_list] + process_and_assign_weight(layer.attention.dense, v, 0) + + # 4.3 mlp.gate + v = [load(prefix + awq_key_list[6] + suf) for suf in awq_suffix_list] + process_and_assign_weight(layer.mlp.gate, v, 1) + + # 4.4 mlp.proj + v = [load(prefix + awq_key_list[7] + suf) for suf in awq_suffix_list] + process_and_assign_weight(layer.mlp.proj, v, 0) + # 4.5 mlp.fc + v = [load(prefix + awq_key_list[8] + suf) for suf in awq_suffix_list] + process_and_assign_weight(layer.mlp.fc, v, 1) + + # 4.6 input_layernorm + v = load(prefix + awq_key_list[9]) + layer.input_layernorm.weight.value = v.to(torch_dtype).cpu().numpy() + + # 4.7 post_layernorm + v = load(prefix + awq_key_list[10]) + layer.post_layernorm.weight.value = v.to(torch_dtype).cpu().numpy() + + # 4.8 attention.kv_quant_orig_scale / kv_quant_orig_scale if use_int8_kv_cache: assert ft_model_dir, "You must pass --ft_model_dir to tell TRT-LLM where to look for scales of INT8 kv cache." t = fromfile( @@ -1242,40 +1229,8 @@ def process_and_assign_qkv_weight(awq_llama, prefix, mOp): '.attention.query_key_value.scale_y_quant_orig.bin', [1], np.float32) assert t is not None, f"{ft_model_dir} does not contain model.layers.{layer_idx}.attention.query_key_value.scale_y_quant_orig.bin" - tensorrt_llm_llama.layers[ - layer_idx].attention.kv_orig_quant_scale.value = 1.0 / t - tensorrt_llm_llama.layers[ - layer_idx].attention.kv_quant_orig_scale.value = t - - v = awq_llama['model.norm.weight'] - if mapping.is_last_pp_rank(): - tensorrt_llm_llama.ln_f.weight.value = v.to(torch_dtype).cpu().numpy() - - #lm_head - if pad_vocab: - weight = awq_llama['lm_head.weight'] - [vocab_size, k] = weight.shape - new_weight = torch.zeros([pad_vocab_size, k]) - new_weight[:vocab_size, :] = weight - new_weight = new_weight.T.contiguous() - amax = awq_llama['lm_head.weight_quantizer._amax'].reshape( - [vocab_size, k // group_size]) - new_amax = torch.ones([pad_vocab_size, k // group_size]) - new_amax[:vocab_size, :] = amax - new_amax = new_amax.T.contiguous() - new_scale = new_amax / 8 - tensorrt_llm_llama.lm_head.qweight.value = AWQ_quantize_pack_preprocess( - new_weight, new_scale) - tensorrt_llm_llama.lm_head.scale.value = new_scale.to( - torch_dtype).cpu().numpy() - tensorrt_llm_llama.lm_head.pre_quant_scale.value = awq_llama[ - 'lm_head.input_quantizer._pre_quant_scale'].to( - torch_dtype).cpu().numpy() - else: - mPrefix = "lm_head" - mOp = tensorrt_llm_llama.lm_head - if mapping.is_last_pp_rank(): - process_and_assign_weight(awq_llama, mPrefix, mOp, 1) + layer.attention.kv_orig_quant_scale.value = 1.0 / t + layer.attention.kv_quant_orig_scale.value = t tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) diff --git a/examples/mpt/build.py b/examples/mpt/build.py index 2f88cd320..f214f7aee 100644 --- a/examples/mpt/build.py +++ b/examples/mpt/build.py @@ -136,6 +136,14 @@ def parse_arguments(args): parser.add_argument('--enable_context_fmha_fp32_acc', default=False, action='store_true') + parser.add_argument( + '--multi_block_mode', + default=False, + action='store_true', + help= + 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ + It is beneifical when batchxnum_heads cannot fully utilize GPU.' + ) parser.add_argument('--gpus_per_node', type=int, default=8) parser.add_argument('--builder_opt', type=int, default=None) parser.add_argument( @@ -478,6 +486,8 @@ def build_rank_engine(builder: Builder, if args.enable_context_fmha_fp32_acc: network.plugin_config.set_context_fmha( ContextFMHAType.enabled_with_fp32_acc) + if args.multi_block_mode: + network.plugin_config.enable_mmha_multi_block_mode() if args.remove_input_padding: network.plugin_config.enable_remove_input_padding() if args.paged_kv_cache: diff --git a/examples/opt/build.py b/examples/opt/build.py index 7f974800d..809361798 100644 --- a/examples/opt/build.py +++ b/examples/opt/build.py @@ -108,6 +108,14 @@ def parse_arguments(): parser.add_argument('--enable_context_fmha_fp32_acc', default=False, action='store_true') + parser.add_argument( + '--multi_block_mode', + default=False, + action='store_true', + help= + 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ + It is beneifical when batchxnum_heads cannot fully utilize GPU.' + ) parser.add_argument('--gpus_per_node', type=int, default=8) parser.add_argument( '--output_dir', @@ -184,6 +192,13 @@ def parse_arguments(): help= "Activates the lookup plugin which enables embedding sharing. It is also required for language modeling embedding weight sharing." ) + parser.add_argument( + '--strongly_typed', + default=False, + action="store_true", + help= + 'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.' + ) args = parser.parse_args() if args.use_weight_only: args.quant_mode = QuantMode.use_weight_only( @@ -282,6 +297,8 @@ def build_rank_engine(builder: Builder, if args.enable_context_fmha_fp32_acc: network.plugin_config.set_context_fmha( ContextFMHAType.enabled_with_fp32_acc) + if args.multi_block_mode: + network.plugin_config.enable_mmha_multi_block_mode() if args.use_weight_only: assert (args.dtype == 'float16') network.plugin_config.set_weight_only_quant_matmul_plugin( @@ -348,7 +365,8 @@ def build(rank, args): max_output_len=args.max_output_len, use_prompt_tuning=args.max_prompt_embedding_table_size > 0, int8=(args.quant_mode.has_act_or_weight_quant() - or args.quant_mode.has_int8_kv_cache())) + or args.quant_mode.has_int8_kv_cache()), + strongly_typed=args.strongly_typed) engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size, cur_rank) diff --git a/examples/opt/summarize.py b/examples/opt/summarize.py index fb21bda77..6f81a269d 100644 --- a/examples/opt/summarize.py +++ b/examples/opt/summarize.py @@ -166,7 +166,8 @@ def summarize_tensorrt_llm(datapoint): tensorrt_llm_gpt.setup(batch_size, max_context_length=max_length, max_new_tokens=output_len, - beam_width=num_beams) + beam_width=num_beams, + max_kv_cache_length=args.max_kv_cache_len) if tensorrt_llm_gpt.remove_input_padding: output_ids = tensorrt_llm_gpt.decode_batch( @@ -358,6 +359,12 @@ def summarize_hf(datapoint): parser.add_argument('--engine_dir', type=str, default='gpt_outputs') parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--max_ite', type=int, default=20) + parser.add_argument('--max_kv_cache_len', + type=int, + default=None, + help='The max kv cache length. \ + If the final sequence length exceeds the kv cache length, we will enable cyclic kv cache. \ + If it is set to None, we will use the max sequence length.') parser.add_argument('--check_accuracy', action='store_true') parser.add_argument('--tensorrt_llm_rouge1_threshold', type=float, diff --git a/scripts/build_wheel.py b/scripts/build_wheel.py index d2af07edd..a93098d42 100755 --- a/scripts/build_wheel.py +++ b/scripts/build_wheel.py @@ -160,6 +160,7 @@ def main(build_type: str = "Release", if cpp_only: assert not install, "Installing is not supported for cpp_only builds" + assert not python_bindings, "Python bindings are not supported for cpp_only builds" return pkg_dir = project_dir / "tensorrt_llm" diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index b5f0ae419..3540466e7 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -34,6 +34,12 @@ def torch_to_numpy(x: torch.Tensor): return x.view(torch.int16).cpu().numpy().view(np_bfloat16) +def numpy_to_torch(x): + if x.dtype != np_bfloat16: + return torch.tensor(x) + return torch.tensor(x.view(np.int16)).view(torch.bfloat16) + + fp32_array = partial(np.array, dtype=np.float32) fp16_array = partial(np.array, dtype=np.float16) int32_array = partial(np.array, dtype=np.int32) diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 23dd159d4..0779639f9 100644 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -3051,6 +3051,7 @@ def gpt_attention( past_key_value: Tensor, sequence_length: Tensor, host_past_key_value_lengths: Tensor, + host_max_kv_cache_lengths: Tensor, context_lengths: Tensor, cache_indirection: Tensor, host_request_types: Tensor, @@ -3065,7 +3066,6 @@ def gpt_attention( rotary_embedding_max_positions: int = 1024, position_embedding_type: PositionEmbeddingType = PositionEmbeddingType. learned_absolute, - multi_block_mode: bool = False, kv_orig_quant_scale: Tensor = None, kv_quant_orig_scale: Tensor = None, kv_cache_quant_mode: QuantMode = None, @@ -3114,6 +3114,11 @@ def gpt_attention( host past_key_value_length: Tensor An INT32 tensor of shape [batch_size]. + host max_kv_cache_lengths: Tensor + An INT32 tensor of shape [1]. + by default, the max_kv_cache_length is determined by the shape of cache_indir_table. + And we support flexible max_kv_cache_length (or max_past_length) for each layer. + context_lengths: Tensor The tensor that stores the context-phase sequence length of each request. Its shape is [batch_size]. See QKV Input in doc/functional.py, @@ -3171,10 +3176,6 @@ def gpt_attention( * PositionEmbeddingType.alibi * PositionEmbeddingType.alibi_with_scale - multi_block_mode: bool - Do we enable multi-block for the masked MHA. See Generation Phase - in docs/gpt_attention.md, - kv_orig_quant_scale: Tensor The tensor to store the scaling factor for quantization to INT8/FP8 in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache in @@ -3252,6 +3253,7 @@ def gpt_attention( assert host_context_lengths is not None or not default_net( ).plugin_config.remove_input_padding assert isinstance(max_context_length, int) + assert host_max_kv_cache_lengths is not None paged_kv_cache_flag = default_net().plugin_config.paged_kv_cache @@ -3308,8 +3310,9 @@ def gpt_attention( np.int32), trt.PluginFieldType.INT32) multi_block_mode = trt.PluginField( - "multi_block_mode", np.array(np.int8(multi_block_mode), dtype=np.int8), - trt.PluginFieldType.INT8) + "multi_block_mode", + np.array(np.int8(default_net().plugin_config.multi_block_mode), + dtype=np.int8), trt.PluginFieldType.INT8) tp_size = trt.PluginField("tp_size", np.array(tp_size, dtype=np.int32), trt.PluginFieldType.INT32) tp_rank = trt.PluginField("tp_rank", np.array(tp_rank, dtype=np.int32), @@ -3360,6 +3363,7 @@ def gpt_attention( tensor, sequence_length, host_past_key_value_lengths, + host_max_kv_cache_lengths, context_lengths, cache_indirection, host_request_types, @@ -3404,7 +3408,7 @@ def gpt_attention( if kv_cache_quant_mode.has_int8_kv_cache() and not paged_kv_cache_flag: # past key value - layer.get_input(6).set_dynamic_range(-127, 127) + layer.get_input(7).set_dynamic_range(-127, 127) # present key value layer.get_output(1).set_dynamic_range(-127, 127) diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py index c7cddac87..7745a8b24 100644 --- a/tensorrt_llm/layers/attention.py +++ b/tensorrt_llm/layers/attention.py @@ -19,7 +19,7 @@ import tensorrt as trt from .._common import default_net, precision -from .._utils import numpy_fp32_to_bf16 +from .._utils import numpy_fp32_to_bf16, trt_dtype_to_np from ..functional import (AttentionMaskType, PositionEmbeddingType, RotaryScalingType, Tensor, bert_attention, cast, clip, concat, constant, embedding, expand_dims, expand_mask, @@ -36,17 +36,15 @@ class RopeEmbeddingUtils: @staticmethod - def create_sinusoidal_positions( - num_pos: int, - dim: int, - theta: float = 10000.0, - ): - inv_freq = 1.0 / (theta**(np.arange(0, dim, 2) / dim)).astype( - np.float32) + def create_sinusoidal_positions(num_pos: int, + dim: int, + theta: float = 10000.0, + dtype=np.float32): + inv_freq = 1.0 / (theta**(np.arange(0, dim, 2) / dim)).astype(dtype) sinusoid_inp = np.einsum("i , j -> i j", - np.arange(num_pos, dtype=np.float32), + np.arange(num_pos, dtype=dtype), inv_freq, - dtype=np.float32) + dtype=dtype) concat = np.concatenate((np.sin(sinusoid_inp), np.cos(sinusoid_inp)), axis=1) return np.expand_dims(concat, axis=0).astype(np.float32) @@ -64,7 +62,9 @@ def rotate_every_two(tensor: Tensor) -> Tensor: x2 = slice(tensor, [0, 0, 0, 1], shape_tensor, [1, 1, 1, 2]) x1 = expand_dims(x1, 4) x2 = expand_dims(x2, 4) - zero = constant(np.ascontiguousarray(np.zeros([1], dtype=np.float32))) + zero = constant( + np.ascontiguousarray(np.zeros([1], + dtype=trt_dtype_to_np(x2.dtype)))) x2 = zero - x2 x = concat([x2, x1], 4) return view( @@ -86,7 +86,9 @@ def rotate_half(tensor: Tensor) -> Tensor: x1 = slice(tensor, [0, 0, 0, 0], shape_tensor, [1, 1, 1, 1]) x2 = slice(tensor, concat([0, 0, 0, last_dim]), shape_tensor, [1, 1, 1, 1]) - zero = constant(np.ascontiguousarray(np.zeros([1], dtype=np.float32))) + zero = constant( + np.ascontiguousarray(np.zeros([1], + dtype=trt_dtype_to_np(x2.dtype)))) x2 = zero - x2 x = concat([x2, x1], 3) return x @@ -273,11 +275,13 @@ class KeyValueCacheParams: def __init__(self, past_key_value: List[Tensor] = None, host_past_key_value_lengths: Tensor = None, + host_max_kv_cache_lengths: List[Tensor] = None, kv_cache_block_pointers: List[Tensor] = None, cache_indirection: Tensor = None, past_key_value_length: Tensor = None): self.past_key_value = past_key_value self.host_past_key_value_lengths = host_past_key_value_lengths + self.host_max_kv_cache_lengths = host_max_kv_cache_lengths self.kv_cache_block_pointers = kv_cache_block_pointers self.cache_indirection = cache_indirection # self.past_key_value_length = past_key_value_length @@ -292,10 +296,18 @@ def get_first_kv_cache_block_pointers(self): return None return self.kv_cache_block_pointers[0] + def fill_none_tensor_list(self, list_size): + if self.past_key_value is None: + self.past_key_value = tuple([None] * list_size) + if self.host_max_kv_cache_lengths is None: + self.host_max_kv_cache_lengths = tuple([None] * list_size) + def is_valid(self, gpt_attention_plugin): if gpt_attention_plugin: if self.host_past_key_value_lengths is None: return False + if self.host_max_kv_cache_lengths is None: + return False if self.cache_indirection is None: return False @@ -323,7 +335,6 @@ def __init__( tp_group=None, tp_size=1, tp_rank=0, - multi_block_mode=False, quant_mode: QuantMode = QuantMode(0), q_scaling=1.0, cross_attention=False, @@ -365,7 +376,6 @@ def __init__( # - True, inv_sqrt_Dh * Q*K^T + inv_sqrt_Dh * alibi_bias self.scale_alibi_bias = position_embedding_type == PositionEmbeddingType.alibi_with_scale self.position_embedding_type = position_embedding_type - self.multi_block_mode = multi_block_mode self.relative_attention = relative_attention self.max_distance = max_distance self.rotary_embedding_base = rotary_embedding_base @@ -387,7 +397,9 @@ def __init__( rotary_embedding_percentage) self.rotary_enabled = True self.embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions( - self.max_position_embeddings, self.rotary_embedding_dim) + self.max_position_embeddings, + self.rotary_embedding_dim, + ) self.quant_mode = quant_mode if use_int8_kv_cache: @@ -529,6 +541,8 @@ def forward( sequence_length=attention_params.sequence_length, host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, + host_max_kv_cache_lengths=kv_cache_params. + host_max_kv_cache_lengths, context_lengths=attention_params.context_lengths, cache_indirection=kv_cache_params.cache_indirection, host_request_types=attention_params.host_request_types, @@ -542,7 +556,6 @@ def forward( rotary_embedding_scale=self.rotary_embedding_scale, rotary_embedding_max_positions=self.max_position_embeddings, position_embedding_type=self.position_embedding_type, - multi_block_mode=self.multi_block_mode, kv_orig_quant_scale=kv_orig_quant_scale, kv_quant_orig_scale=kv_quant_orig_scale, kv_cache_quant_mode=self.quant_mode, @@ -613,6 +626,10 @@ def transpose_for_scores(x, else: embed_positions = constant(self.embed_positions) + if default_net().strongly_typed and (embed_positions.dtype != + value.dtype): + embed_positions = cast(embed_positions, value.dtype) + if self.rotary_embedding_dim is not None: # When shape(hidden_states, 1) > 1(Context phase), the embedding start from 0, # otherwise (Generation phase) move start to position @@ -808,6 +825,10 @@ def quantize_tensor(x, scale): attention_probs = softmax(attention_scores, dim=-1) + if default_net().strongly_typed and (attention_probs.dtype != + value.dtype): + attention_probs = cast(attention_probs, value.dtype) + context = matmul(attention_probs, value).permute([0, 2, 1, 3]) context = context.view( concat([shape(context, 0), diff --git a/tensorrt_llm/models/baichuan/model.py b/tensorrt_llm/models/baichuan/model.py index a9cf262e2..2bd7d1f17 100644 --- a/tensorrt_llm/models/baichuan/model.py +++ b/tensorrt_llm/models/baichuan/model.py @@ -164,15 +164,15 @@ def forward(self, hidden_states = self.vocab_embedding(input_ids) - if kv_cache_params.past_key_value is None: - kv_cache_params.past_key_value = tuple([None] * len(self.layers)) + kv_cache_params.fill_none_tensor_list(len(self.layers)) if use_cache: presents = [] - for layer, past, pointer in zip( + for layer, past, pointer, max_kv_cache_length in zip( self.layers, kv_cache_params.past_key_value, - kv_cache_params.kv_cache_block_pointers): + kv_cache_params.kv_cache_block_pointers, + kv_cache_params.host_max_kv_cache_lengths): hidden_states = layer( hidden_states, use_cache=use_cache, @@ -181,6 +181,7 @@ def forward(self, past_key_value=[past], host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, + host_max_kv_cache_lengths=max_kv_cache_length, kv_cache_block_pointers=[pointer], cache_indirection=kv_cache_params.cache_indirection), attention_params=attention_params) @@ -208,6 +209,7 @@ def __init__(self, max_position_embeddings, position_embedding_type, dtype, + logits_dtype='float32', mlp_hidden_size=None, mapping=Mapping(), quant_mode=QuantMode(0)): @@ -232,6 +234,12 @@ def __init__(self, elif quant_mode.has_fp8_kv_cache(): self.kv_dtype = str_dtype_to_trt('fp8') + if isinstance(logits_dtype, str): + self._logits_dtype = str_dtype_to_trt(logits_dtype) + else: + assert isinstance(logits_dtype, trt.DataType) + self._logits_dtype = logits_dtype + self.quant_mode = quant_mode super().__init__(num_layers, num_heads, num_kv_heads, hidden_size, @@ -268,7 +276,7 @@ def forward(self, # [batch_size, hidden_size] -> [batch_size, vocab_size] lm_logits = self.lm_head(hidden_states) - lm_logits.mark_output('logits', self.dtype) + lm_logits.mark_output('logits', self._logits_dtype) if use_cache and default_net().plugin_config.paged_kv_cache == False: for i, present in enumerate(presents): @@ -324,6 +332,8 @@ def prepare_inputs(self, past_key_value=model_inputs['past_key_value'], host_past_key_value_lengths=model_inputs[ 'host_past_key_value_lengths'], + host_max_kv_cache_lengths=model_inputs[ + 'host_max_kv_cache_lengths'], kv_cache_block_pointers=model_inputs[ 'kv_cache_block_pointers_list'], cache_indirection=model_inputs['cache_indirection'], diff --git a/tensorrt_llm/models/bloom/model.py b/tensorrt_llm/models/bloom/model.py index 880c1c223..c12eed21d 100644 --- a/tensorrt_llm/models/bloom/model.py +++ b/tensorrt_llm/models/bloom/model.py @@ -185,13 +185,14 @@ def forward(self, hidden_states = self.embedding(input_ids) hidden_states = self.ln_embed(hidden_states) - if kv_cache_params.past_key_value is None: - kv_cache_params.past_key_value = tuple([None] * len(self.layers)) + kv_cache_params.fill_none_tensor_list(len(self.layers)) if use_cache: presents = [] - for layer, past in zip(self.layers, kv_cache_params.past_key_value): + for layer, past, max_kv_cache_length in zip( + self.layers, kv_cache_params.past_key_value, + kv_cache_params.host_max_kv_cache_lengths): hidden_states = layer( hidden_states, use_cache=use_cache, @@ -200,6 +201,7 @@ def forward(self, past_key_value=[past], host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, + host_max_kv_cache_lengths=max_kv_cache_length, cache_indirection=kv_cache_params.cache_indirection), attention_params=attention_params) @@ -347,6 +349,8 @@ def prepare_inputs(self, past_key_value=model_inputs['past_key_value'], host_past_key_value_lengths=model_inputs[ 'host_past_key_value_lengths'], + host_max_kv_cache_lengths=model_inputs[ + 'host_max_kv_cache_lengths'], cache_indirection=model_inputs['cache_indirection'], ), AttentionParams( diff --git a/tensorrt_llm/models/chatglm/model.py b/tensorrt_llm/models/chatglm/model.py index 5c0d18fd7..9eea959fc 100644 --- a/tensorrt_llm/models/chatglm/model.py +++ b/tensorrt_llm/models/chatglm/model.py @@ -70,7 +70,6 @@ def __init__(self, layer_id, args): tp_group=args.mapping.tp_group, tp_size=args.mapping.tp_size, tp_rank=args.mapping.rank, - multi_block_mode=args.multi_block_mode, quant_mode=args.quant_mode, q_scaling=1.0, cross_attention=False, @@ -193,20 +192,24 @@ def forward( hidden_states = self.embedding(input_ids) + kv_cache_params.fill_none_tensor_list(len(self.layers)) + if self.use_cache: presents = [] - for layer, past_key_value, kv_cache_block_pointers in zip( + for layer, past, pointer, max_kv_cache_length in zip( self.layers, kv_cache_params.past_key_value, - kv_cache_params.kv_cache_block_pointers): + kv_cache_params.kv_cache_block_pointers, + kv_cache_params.host_max_kv_cache_lengths): layer_output = layer( hidden_states, position_ids, kv_cache_params=KeyValueCacheParams( - past_key_value=[past_key_value], - kv_cache_block_pointers=[kv_cache_block_pointers], + past_key_value=[past], + kv_cache_block_pointers=[pointer], host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, + host_max_kv_cache_lengths=max_kv_cache_length, cache_indirection=kv_cache_params.cache_indirection, ), attention_params=attention_params, @@ -232,7 +235,6 @@ def __init__(self, **args): argNamespace.__setattr__(key, value) assert "model_version" in args.keys(), "model_version not set" # Other default values - argNamespace.multi_block_mode = False argNamespace.norm_epsilon = 1.0e-5 argNamespace.tokens_per_block = 64 argNamespace.use_cache = True @@ -273,6 +275,12 @@ def init(self, args): elif args.quant_mode.has_fp8_kv_cache(): self.kv_dtype = str_dtype_to_trt('fp8') + if isinstance(args.logits_dtype, str): + self._logits_dtype = str_dtype_to_trt(args.logits_dtype) + else: + assert isinstance(args.logits_dtype, trt.DataType) + self._logits_dtype = args.logits_dtype + self.hidden_size = args.hidden_size self.mapping = args.mapping self.max_num_tokens = args.max_output_len + args.max_input_len @@ -317,7 +325,7 @@ def forward( default_net().plugin_config.remove_input_padding) lm_logits = self.lm_head(hidden_states) - lm_logits.mark_output('logits', self.dtype) + lm_logits.mark_output('logits', self._logits_dtype) if self.use_cache and default_net( ).plugin_config.paged_kv_cache == False: @@ -373,6 +381,8 @@ def prepare_inputs( past_key_value=model_inputs['past_key_value'], host_past_key_value_lengths=model_inputs[ 'host_past_key_value_lengths'], + host_max_kv_cache_lengths=model_inputs[ + 'host_max_kv_cache_lengths'], kv_cache_block_pointers=model_inputs[ 'kv_cache_block_pointers_list'], cache_indirection=model_inputs['cache_indirection'], diff --git a/tensorrt_llm/models/falcon/model.py b/tensorrt_llm/models/falcon/model.py index 178a79ee4..d363cbf1f 100644 --- a/tensorrt_llm/models/falcon/model.py +++ b/tensorrt_llm/models/falcon/model.py @@ -244,8 +244,7 @@ def forward(self, hidden_states=None, all_reduce_workspace=None): - if kv_cache_params.past_key_value is None: - kv_cache_params.past_key_value = tuple([None] * len(self.layers)) + kv_cache_params.fill_none_tensor_list(len(self.layers)) if use_cache: presents = [] @@ -255,9 +254,10 @@ def forward(self, else: hidden_states = recv(hidden_states, self.mapping.prev_pp_rank()) - for layer, past, pointer in zip( + for layer, past, pointer, max_kv_cache_length in zip( self.layers, kv_cache_params.past_key_value, - kv_cache_params.kv_cache_block_pointers): + kv_cache_params.kv_cache_block_pointers, + kv_cache_params.host_max_kv_cache_lengths): hidden_states = layer( hidden_states, use_cache=use_cache, @@ -266,6 +266,7 @@ def forward(self, past_key_value=[past], host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, + host_max_kv_cache_lengths=max_kv_cache_length, kv_cache_block_pointers=[pointer], cache_indirection=kv_cache_params.cache_indirection), attention_params=attention_params, @@ -448,6 +449,8 @@ def prepare_inputs(self, past_key_value=model_inputs['past_key_value'], host_past_key_value_lengths=model_inputs[ 'host_past_key_value_lengths'], + host_max_kv_cache_lengths=model_inputs[ + 'host_max_kv_cache_lengths'], kv_cache_block_pointers=model_inputs[ 'kv_cache_block_pointers_list'], cache_indirection=model_inputs['cache_indirection']), diff --git a/tensorrt_llm/models/generation_mixin.py b/tensorrt_llm/models/generation_mixin.py index 61be1f3b6..06962b95b 100644 --- a/tensorrt_llm/models/generation_mixin.py +++ b/tensorrt_llm/models/generation_mixin.py @@ -318,6 +318,7 @@ def prepare_basic_inputs(self, context_lengths = None host_context_lengths = None host_past_key_value_lengths = None + host_max_kv_cache_lengths = None attention_mask = None cache_indirection = None host_request_types = None @@ -378,6 +379,19 @@ def prepare_basic_inputs(self, ]), ) + if use_gpt_attention_plugin: + host_max_kv_cache_lengths = [] + for i in range(num_layers): + host_kv_cache_length_tensor = Tensor( + name=f'host_max_kv_cache_length_{i}', + dtype=trt.int32, + shape=[1], + dim_range=OrderedDict([ + ('scalar', + [1, 1] if enable_two_optimization_profiles else [1]) + ])) + host_max_kv_cache_lengths.append(host_kv_cache_length_tensor) + cache_indirection = Tensor( name='cache_indirection', dtype=trt.int32, @@ -460,6 +474,7 @@ def prepare_basic_inputs(self, 'attention_mask': attention_mask, 'sequence_length': sequence_length, 'host_past_key_value_lengths': host_past_key_value_lengths, + 'host_max_kv_cache_lengths': host_max_kv_cache_lengths, 'past_key_value': past_key_value, 'last_token_ids': last_token_ids, 'cache_indirection': cache_indirection, diff --git a/tensorrt_llm/models/gpt/model.py b/tensorrt_llm/models/gpt/model.py index 98580565a..be2678390 100644 --- a/tensorrt_llm/models/gpt/model.py +++ b/tensorrt_llm/models/gpt/model.py @@ -287,15 +287,15 @@ def forward(self, prompt_vocab_size, workspace=workspace) - if kv_cache_params.past_key_value is None: - kv_cache_params.past_key_value = tuple([None] * len(self.layers)) + kv_cache_params.fill_none_tensor_list(len(self.layers)) if use_cache: presents = [] - for layer, past, pointer in zip( + for layer, past, pointer, max_kv_cache_length in zip( self.layers, kv_cache_params.past_key_value, - kv_cache_params.kv_cache_block_pointers): + kv_cache_params.kv_cache_block_pointers, + kv_cache_params.host_max_kv_cache_lengths): hidden_states = layer( hidden_states, use_cache=use_cache, @@ -304,6 +304,7 @@ def forward(self, past_key_value=[past], host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, + host_max_kv_cache_lengths=max_kv_cache_length, kv_cache_block_pointers=[pointer], cache_indirection=kv_cache_params.cache_indirection), attention_params=attention_params, @@ -491,6 +492,8 @@ def prepare_inputs(self, past_key_value=model_inputs['past_key_value'], host_past_key_value_lengths=model_inputs[ 'host_past_key_value_lengths'], + host_max_kv_cache_lengths=model_inputs[ + 'host_max_kv_cache_lengths'], kv_cache_block_pointers=model_inputs[ 'kv_cache_block_pointers_list'], cache_indirection=model_inputs['cache_indirection'], diff --git a/tensorrt_llm/models/gptj/model.py b/tensorrt_llm/models/gptj/model.py index 4fdeee9db..8c61d8a12 100644 --- a/tensorrt_llm/models/gptj/model.py +++ b/tensorrt_llm/models/gptj/model.py @@ -146,15 +146,15 @@ def forward(self, hidden_states = self.embedding(input_ids) - if kv_cache_params.past_key_value is None: - kv_cache_params.past_key_value = tuple([None] * len(self.layers)) + kv_cache_params.fill_none_tensor_list(len(self.layers)) if use_cache: presents = [] - for layer, past, pointer in zip( + for layer, past, pointer, max_kv_cache_length in zip( self.layers, kv_cache_params.past_key_value, - kv_cache_params.kv_cache_block_pointers): + kv_cache_params.kv_cache_block_pointers, + kv_cache_params.host_max_kv_cache_lengths): hidden_states = layer( hidden_states, use_cache=use_cache, @@ -162,6 +162,7 @@ def forward(self, past_key_value=[past], host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, + host_max_kv_cache_lengths=max_kv_cache_length, kv_cache_block_pointers=[pointer], cache_indirection=kv_cache_params.cache_indirection), attention_params=attention_params) @@ -304,6 +305,8 @@ def prepare_inputs(self, past_key_value=model_inputs['past_key_value'], host_past_key_value_lengths=model_inputs[ 'host_past_key_value_lengths'], + host_max_kv_cache_lengths=model_inputs[ + 'host_max_kv_cache_lengths'], kv_cache_block_pointers=model_inputs[ 'kv_cache_block_pointers_list'], cache_indirection=model_inputs['cache_indirection'], diff --git a/tensorrt_llm/models/gptneox/model.py b/tensorrt_llm/models/gptneox/model.py index 0202cc59d..2c9592a3b 100644 --- a/tensorrt_llm/models/gptneox/model.py +++ b/tensorrt_llm/models/gptneox/model.py @@ -153,13 +153,14 @@ def forward(self, attention_params=None): hidden_states = self.embedding(input_ids) - if kv_cache_params.past_key_value is None: - kv_cache_params.past_key_value = tuple([None] * len(self.layers)) + kv_cache_params.fill_none_tensor_list(len(self.layers)) if use_cache: presents = [] - for layer, past in zip(self.layers, kv_cache_params.past_key_value): + for layer, past, max_kv_cache_length in zip( + self.layers, kv_cache_params.past_key_value, + kv_cache_params.host_max_kv_cache_lengths): hidden_states = layer( hidden_states, use_cache=use_cache, @@ -167,6 +168,7 @@ def forward(self, past_key_value=[past], host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, + host_max_kv_cache_lengths=max_kv_cache_length, cache_indirection=kv_cache_params.cache_indirection), attention_params=attention_params) @@ -297,6 +299,8 @@ def prepare_inputs(self, max_batch_size, max_input_len, max_new_tokens, past_key_value=model_inputs['past_key_value'], host_past_key_value_lengths=model_inputs[ 'host_past_key_value_lengths'], + host_max_kv_cache_lengths=model_inputs[ + 'host_max_kv_cache_lengths'], cache_indirection=model_inputs['cache_indirection'], ), AttentionParams( diff --git a/tensorrt_llm/models/llama/model.py b/tensorrt_llm/models/llama/model.py index 896ca77fa..548aa947c 100644 --- a/tensorrt_llm/models/llama/model.py +++ b/tensorrt_llm/models/llama/model.py @@ -221,8 +221,7 @@ def forward( prompt_vocab_size: Optional[Tensor] = None, ): - if kv_cache_params.past_key_value is None: - tuple([None] * len(self.layers)) + kv_cache_params.fill_none_tensor_list(len(self.layers)) if use_cache: presents = [] @@ -239,9 +238,10 @@ def forward( hidden_states = recv(hidden_states, self.mapping.prev_pp_rank()) self.register_network_output(f"embd", hidden_states) - for layer, past, pointer in zip( + for layer, past, pointer, max_kv_cache_length in zip( self.layers, kv_cache_params.past_key_value, - kv_cache_params.kv_cache_block_pointers): + kv_cache_params.kv_cache_block_pointers, + kv_cache_params.host_max_kv_cache_lengths): hidden_states = layer( hidden_states, use_cache=use_cache, @@ -250,6 +250,7 @@ def forward( past_key_value=[past], host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, + host_max_kv_cache_lengths=max_kv_cache_length, kv_cache_block_pointers=[pointer], cache_indirection=kv_cache_params.cache_indirection), attention_params=attention_params, @@ -449,6 +450,8 @@ def prepare_inputs( past_key_value=model_inputs['past_key_value'], host_past_key_value_lengths=model_inputs[ 'host_past_key_value_lengths'], + host_max_kv_cache_lengths=model_inputs[ + 'host_max_kv_cache_lengths'], kv_cache_block_pointers=model_inputs[ 'kv_cache_block_pointers_list'], cache_indirection=model_inputs['cache_indirection'], diff --git a/tensorrt_llm/models/opt/model.py b/tensorrt_llm/models/opt/model.py index f2469c972..08ebe208a 100644 --- a/tensorrt_llm/models/opt/model.py +++ b/tensorrt_llm/models/opt/model.py @@ -166,13 +166,14 @@ def forward(self, prompt_embedding_table, prompt_tasks, prompt_vocab_size) - if kv_cache_params.past_key_value is None: - kv_cache_params.past_key_value = tuple([None] * len(self.layers)) + kv_cache_params.fill_none_tensor_list(len(self.layers)) if use_cache: presents = [] - for layer, past in zip(self.layers, kv_cache_params.past_key_value): + for layer, past, max_kv_cache_length in zip( + self.layers, kv_cache_params.past_key_value, + kv_cache_params.host_max_kv_cache_lengths): hidden_states = layer( hidden_states, use_cache=use_cache, @@ -181,6 +182,7 @@ def forward(self, past_key_value=[past], host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, + host_max_kv_cache_lengths=max_kv_cache_length, cache_indirection=kv_cache_params.cache_indirection), attention_params=attention_params) if use_cache: @@ -332,6 +334,8 @@ def prepare_inputs(self, past_key_value=model_inputs['past_key_value'], host_past_key_value_lengths=model_inputs[ 'host_past_key_value_lengths'], + host_max_kv_cache_lengths=model_inputs[ + 'host_max_kv_cache_lengths'], cache_indirection=model_inputs['cache_indirection'], ), AttentionParams( diff --git a/tensorrt_llm/models/quantized/ammo.py b/tensorrt_llm/models/quantized/ammo.py index cce07f427..fecdd34eb 100644 --- a/tensorrt_llm/models/quantized/ammo.py +++ b/tensorrt_llm/models/quantized/ammo.py @@ -107,16 +107,13 @@ def quantize_and_export(model: torch.nn.Module, if export_path: with torch.inference_mode(): - if qformat == "int4_awq": - torch.save(model.state_dict(), export_path) - else: - export_model_config( - model, - model_type, - torch.float16, - quantization=qformat, - export_dir=export_path, - inference_tensor_parallel=tensor_parallel_size, - ) + export_model_config( + model, + model_type, + torch.float16, + quantization=qformat, + export_dir=export_path, + inference_tensor_parallel=tensor_parallel_size, + ) logger.info(f"Quantized model exported to :{export_path}") return model diff --git a/tensorrt_llm/parameter.py b/tensorrt_llm/parameter.py index b14a2953f..f78bfa57a 100644 --- a/tensorrt_llm/parameter.py +++ b/tensorrt_llm/parameter.py @@ -46,18 +46,18 @@ def __init__(self, else: v_range = 0.1 - # value ~ U[-1, 1] if dtype == trt.DataType.INT8: - value = torch.randint(-128, - 128, (shape), + value = torch.randint(int(-128 * v_range), + int(128 * v_range), (shape), dtype=trt_dtype_to_torch(dtype), device='cuda') + # value ~ U[int(-128 * v_range), int(128 * v_range)] else: value = torch.randn( (shape), dtype=trt_dtype_to_torch(dtype), device='cuda') * 2 - 1 - # value ~ U[-v_range, v_range] - value = value * v_range + # value ~ N[-v_range, v_range] + value = value * v_range self._value = self._regularize_value(value) @property diff --git a/tensorrt_llm/plugin/plugin.py b/tensorrt_llm/plugin/plugin.py index 7f0048e52..91a1e9eab 100644 --- a/tensorrt_llm/plugin/plugin.py +++ b/tensorrt_llm/plugin/plugin.py @@ -59,6 +59,7 @@ def __init__(self) -> None: def init(self): self.bert_attention_plugin = False self.gpt_attention_plugin = False + self.multi_block_mode = False self.identity_plugin = False self.gemm_plugin = False self.smooth_quant_gemm_plugin = False @@ -111,6 +112,11 @@ def set_gpt_attention_plugin(self, dtype='float16'): self.gpt_attention_plugin = dtype return self + def enable_mmha_multi_block_mode(self): + self.multi_block_mode = True + logger.info(f"Generation Multi Block Mode Enabled") + return self + def set_bert_attention_plugin(self, dtype='float16'): self.bert_attention_plugin = dtype return self diff --git a/tensorrt_llm/profiler.py b/tensorrt_llm/profiler.py index 25c296f42..b9ff6176c 100644 --- a/tensorrt_llm/profiler.py +++ b/tensorrt_llm/profiler.py @@ -16,6 +16,8 @@ from functools import partial from typing import Literal, Optional, Union +import tensorrt as trt + try: import psutil except ImportError: @@ -24,8 +26,11 @@ import pynvml except ImportError: pynvml = None +import traceback + import torch +from tensorrt_llm.builder import _is_building from tensorrt_llm.logger import logger @@ -165,7 +170,7 @@ def device_memory_info( self, device: Optional[Union[torch.device, int]] = None, ) -> int: - index = torch._utils._get_device_index(device, optional=True) + index = device.index if isinstance(device, torch.device) else device if index not in self.device_handles: handle = pynvml.nvmlDeviceGetHandleByIndex(index) self.device_handles[index] = handle @@ -267,3 +272,53 @@ def print_memory_usage(tag: Optional[str] = None, def print_peak_memory_usage(unit: MemoryMonitor.UnitType = 'GiB'): if _default_memory_monitor is not None: _default_memory_monitor.print_peak_memory_usage(unit=unit) + + +@_is_building +def check_gpt_mem_usage(engine, kv_dtype, use_gpt_attention_plugin, + paged_kv_cache, max_batch_size, max_beam_width, + max_input_len, max_output_len, local_num_kv_heads, + head_size, num_layers): + # Get the amount of memory + runtime = trt.Runtime(logger.trt_logger) + activation_size = 0 + try: + cuda_engine = runtime.deserialize_cuda_engine(engine) + assert cuda_engine is not None + activation_size = cuda_engine.device_memory_size / 1024 / 1024 + del cuda_engine + except Exception: + logger.warning( + f'Exception when deserializing engine: {traceback.format_exc()}') + logger.warning(f'Activation memory size will be regarded as 0.') + logger.info(f'Activation memory size: {activation_size:.2f} MiB') + weights_size = engine.nbytes / 1024 / 1024 + logger.info(f'Weights memory size: {weights_size:.2f} MiB') + kv_cache_size = max_batch_size * max_beam_width * 2 * local_num_kv_heads * ( + max_input_len + + max_output_len) * (head_size) * num_layers * kv_dtype.itemsize + # without plugin, we need two set of kv cache buffers, + # one for inputs, and the other for outputs. + if not use_gpt_attention_plugin: + kv_cache_size *= 2 + kv_cache_size = kv_cache_size / 1024 / 1024 + logger.info(f'Max KV Cache memory size: {kv_cache_size:.2f} MiB') + est_memory_size = activation_size + weights_size + kv_cache_size + logger.info( + f'Estimated max memory usage on runtime: {est_memory_size:.2f} MiB') + _, _, total_mem = device_memory_info(torch.cuda.current_device()) + if est_memory_size > total_mem: + logger.warning( + f'Engine is successfully built, but GPU Memory ({total_mem:.2f} GB)', + ' may not be enough when inferencing on max shape.') + if paged_kv_cache: + logger.warning( + f'Since paged_kv_cache is enabled, the max KV Cache ', + 'memory size is a estimate for very extreme cases, ', + 'it\'s possible that most cases won\'t meet OOM.') + else: + logger.warning( + f'Enabling `--paged_kv_cache` could help reduce the ', + 'GPU memory usage on runtime.') + + return est_memory_size diff --git a/tensorrt_llm/quantization/layers.py b/tensorrt_llm/quantization/layers.py index d482aa986..1a34790be 100644 --- a/tensorrt_llm/quantization/layers.py +++ b/tensorrt_llm/quantization/layers.py @@ -634,6 +634,9 @@ def __init__(self, def forward(self, hidden_states, workspace=None): inter = self.fc(hidden_states) inter = ACT2FN[self.hidden_act](inter) + if default_net( + ).strongly_typed and inter.dtype != self.proj.smoother.value: + inter = cast(inter, self.proj.smoother.value) inter = inter / self.proj.smoother.value if self.quant_mode.has_act_and_weight_quant(): if self.quant_mode.has_act_static_scaling(): @@ -966,6 +969,9 @@ def forward(self, hidden_states, workspace=None): inter = ACT2FN[self.hidden_act](inter) gate = self.gate(hidden_states) inter_x_gate = inter * gate + if default_net( + ).strongly_typed and inter_x_gate.dtype != self.proj.smoother.value.dtype: + inter_x_gate = cast(inter_x_gate, self.proj.smoother.value.dtype) inter_x_gate = inter_x_gate / self.proj.smoother.value if self.quant_mode.has_act_and_weight_quant(): if self.quant_mode.has_act_static_scaling(): @@ -997,7 +1003,6 @@ def __init__(self, tp_group=None, tp_size=1, tp_rank=0, - multi_block_mode=False, scale_alibi_bias=False, paged_kv_cache=False, quant_mode=QuantMode(0)): @@ -1027,7 +1032,6 @@ def __init__(self, self.scale_alibi_bias = scale_alibi_bias self.position_embedding_type = position_embedding_type - self.multi_block_mode = multi_block_mode self.paged_kv_cache = paged_kv_cache self.rotary_embedding_dim = 0 @@ -1087,8 +1091,6 @@ def forward(self, qkv = self.qkv(hidden_states) else: raise ValueError("smooth_quant_gemm_plugin is not set") - if not default_net().plugin_config.gpt_attention_plugin: - raise ValueError("gpt_attention_plugin is not set") alibi_slopes = None if self.position_embedding_type == PositionEmbeddingType.alibi: @@ -1125,6 +1127,8 @@ def forward(self, sequence_length=attention_params.sequence_length, host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, + host_max_kv_cache_lengths=kv_cache_params. + host_max_kv_cache_lengths, context_lengths=attention_params.context_lengths, cache_indirection=kv_cache_params.cache_indirection, host_request_types=attention_params.host_request_types, @@ -1134,7 +1138,6 @@ def forward(self, q_scaling=self.q_scaling, rotary_embedding_dim=self.rotary_embedding_dim, position_embedding_type=self.position_embedding_type, - multi_block_mode=self.multi_block_mode, kv_orig_quant_scale=kv_quant_scale, kv_quant_orig_scale=kv_dequant_scale, kv_cache_quant_mode=self.quant_mode, @@ -1238,7 +1241,9 @@ def merge_caches(): if use_cache and self.quant_mode.has_int8_kv_cache(): past_key_value = quantize_tensor( past_key_value, self.kv_quantization_scale.value) - + if default_net( + ).strongly_typed and context.dtype != self.dense.smoother.value.dtype: + context = cast(context, self.dense.smoother.value.dtype) context = context / self.dense.smoother.value if self.quant_mode.has_act_and_weight_quant(): if self.quant_mode.has_act_static_scaling(): diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index 6606afdc6..244ed7227 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -390,10 +390,11 @@ def __init__(self, if model_config.gpt_attention_plugin: expected_tensor_names += [ - 'sequence_length', - 'context_lengths', - 'host_request_types', - 'host_past_key_value_lengths', + 'sequence_length', 'context_lengths', 'host_request_types', + 'host_past_key_value_lengths' + ] + expected_tensor_names += [ + f'host_max_kv_cache_length_{i}' for i in range(self.num_layers) ] if model_config.remove_input_padding: expected_tensor_names.append('host_context_lengths') @@ -698,6 +699,7 @@ def __setup_decoder(self, input_ids: torch.Tensor, device=padded_input_ids.device)), axis=-1) + # Note: we still allocate max_seq_length size of parent ids (not max_kv_cache_length). self.parent_ids = torch.zeros( (batch_size, scfg.num_beams, self.max_seq_length), dtype=torch.int32, @@ -784,6 +786,7 @@ def setup(self, max_context_length: int, max_new_tokens: int, beam_width: int = 1, + max_kv_cache_length: Optional[int] = None, encoder_max_input_length: Optional[int] = None): # Store these params related to buffer size to check against # the input shape with the params given in decode() @@ -793,6 +796,50 @@ def setup(self, self.max_seq_length = max_context_length + max_new_tokens self.beam_width = beam_width self.encoder_max_input_length = encoder_max_input_length + if max_kv_cache_length is None: + self.max_kv_cache_length = self.max_seq_length + logger.info( + "The max_kv_cache_length is not set, we will use max_seq_length by default." + ) + self.host_max_kv_cache_lengths = [ + torch.ones((1, ), dtype=torch.int32) * self.max_kv_cache_length + for i in range(self.num_layers) + ] + elif isinstance(max_kv_cache_length, int): + if max_kv_cache_length > self.max_seq_length: + logger.warning( + "The value of max_kv_cache_length should ideally not exceed max_seq_length. " + "Therefore, it has been adjusted to match the value of max_seq_length." + ) + self.max_kv_cache_length = min(max_kv_cache_length, + self.max_seq_length) + self.host_max_kv_cache_lengths = [ + torch.ones((1, ), dtype=torch.int32) * self.max_kv_cache_length + for i in range(self.num_layers) + ] + elif isinstance(max_kv_cache_length, torch.Tensor): + self.max_kv_cache_length = int( + torch.max(max_kv_cache_length).item()) + if self.max_kv_cache_length > self.max_seq_length: + logger.warning( + "The value of max_kv_cache_length should ideally not exceed max_seq_length. " + "Therefore, it has been adjusted to match the value of max_seq_length." + ) + self.max_kv_cache_length = min(self.max_kv_cache_length, + self.max_seq_length) + if max_kv_cache_length.shape[0] != self.num_layers: + logger.error( + "max_kv_cache_length tensor's size is not equal to num_layers!" + ) + assert False + self.host_max_kv_cache_lengths = [ + torch.minimum( + max_kv_cache_length.to(torch.int32)[i], + torch.IntTensor([self.max_seq_length])) + for i in range(self.num_layers) + ] + else: + assert False, "invalid max_kv_cache_length!" self.buffer = {} if self.mapping.is_last_pp_rank(): @@ -810,7 +857,7 @@ def setup(self, if self.paged_kv_cache: blocks = batch_size * beam_width * math.ceil( - self.max_seq_length / self.tokens_per_block) + self.max_kv_cache_length / self.tokens_per_block) cache_shape = ( blocks, 2, @@ -823,7 +870,7 @@ def setup(self, batch_size, 2, self.num_heads_kv, - self.max_seq_length, + self.max_kv_cache_length, self.head_size, ) if self.cross_attention: @@ -1021,7 +1068,7 @@ def _get_context_shape_buffer(self, f'past_key_value_{idx}': key_value_cache, f'present_key_value_{idx}': - key_value_cache + key_value_cache, }) if self.cross_attention: cross_cache_shape = self.buffer[ @@ -1047,6 +1094,14 @@ def _get_context_shape_buffer(self, 'host_past_key_value_lengths': (batch_size, ), 'host_request_types': host_request_types.shape, }) + for idx in range(self.first_layer, self.last_layer): + ctx_shape.update({ + f'host_max_kv_cache_length_{idx}': (1, ), + }) + ctx_buffer.update({ + f'host_max_kv_cache_length_{idx}': + self.host_max_kv_cache_lengths[idx], + }) ctx_buffer.update({ 'sequence_length': self.sequence_length_buffer, @@ -1229,6 +1284,14 @@ def _get_next_step_shape_buffer(self, 'host_request_types': host_request_types.shape }) + for idx in range(self.first_layer, self.last_layer): + next_step_shape.update({ + f'host_max_kv_cache_length_{idx}': (1, ), + }) + next_step_buffer.update({ + f'host_max_kv_cache_length_{idx}': + self.host_max_kv_cache_lengths[idx], + }) next_step_buffer.update({ # Sequence lengths are not used in the context phase actually. 'sequence_length': @@ -1595,14 +1658,15 @@ def handle_per_step( (batch_size, beam_width, -1)).to(self.decoder_logits_dtype) decode_step = step + max_context_length should_stop = self.dynamic_decoder.forward( - next_token_logits, decode_step, max_context_length, ite, - batch_size, self.end_ids, self.embedding_bias_opt, - context_lengths, sequence_limit_lengths, stop_words_list, - bad_words_list, no_repeat_ngram_size, - this_src_cache_indirection, self.output_ids, - self.new_tokens, self.finished, self.sequence_length_buffer, - self.cum_log_probs, self.log_probs, self.parent_ids, - this_tgt_cache_indirection, self.beam_hyps_output_ids_tgt, + next_token_logits, decode_step, max_context_length, + self.max_kv_cache_length, ite, batch_size, self.end_ids, + self.embedding_bias_opt, context_lengths, + sequence_limit_lengths, stop_words_list, bad_words_list, + no_repeat_ngram_size, this_src_cache_indirection, + self.output_ids, self.new_tokens, self.finished, + self.sequence_length_buffer, self.cum_log_probs, + self.log_probs, self.parent_ids, this_tgt_cache_indirection, + self.beam_hyps_output_ids_tgt, self.beam_hyps_sequence_lengths_tgt, self.beam_hyps_cum_log_probs, self.beam_hyps_normed_scores, self.beam_hyps_log_probs, self.beam_hyps_min_normed_scores, @@ -1851,7 +1915,7 @@ def decode(self, torch.full(( batch_size, beam_width, - self.max_seq_length, + self.max_kv_cache_length, ), 0, dtype=torch.int32, @@ -1859,7 +1923,7 @@ def decode(self, torch.full(( batch_size, beam_width, - self.max_seq_length, + self.max_kv_cache_length, ), 0, dtype=torch.int32, @@ -1875,7 +1939,7 @@ def decode(self, # Init KV cache block manager if self.paged_kv_cache: - max_blocks_per_seq = math.ceil(self.max_seq_length / + max_blocks_per_seq = math.ceil(self.max_kv_cache_length / self.tokens_per_block) blocks = batch_size * beam_width * max_blocks_per_seq memory_pools = [ @@ -1885,6 +1949,7 @@ def decode(self, self.kv_cache_manager = KVCacheManager(memory_pools, blocks, self.tokens_per_block, max_blocks_per_seq, + self.max_kv_cache_length, beam_width) # Add sequences to the manager diff --git a/tensorrt_llm/runtime/kv_cache_manager.py b/tensorrt_llm/runtime/kv_cache_manager.py index 1a7d0bb5b..e1ace98f5 100644 --- a/tensorrt_llm/runtime/kv_cache_manager.py +++ b/tensorrt_llm/runtime/kv_cache_manager.py @@ -238,6 +238,7 @@ def __init__(self, blocks: int, tokens_per_block: int, max_blocks_per_seq: int, + max_kv_cache_len: int, beam_width: int = 1): self.blocks_manager = BlocksManager( @@ -247,6 +248,7 @@ def __init__(self, beam_width=beam_width) self.num_pools = len(memory_pools) self.tokens_per_block = tokens_per_block + self.max_kv_cache_len = max_kv_cache_len self.beam_width = beam_width self.lens = [] @@ -259,6 +261,9 @@ def step(self, finished: List[bool]): """ for seq in self.sequences: batch_idx = seq.get_batch_idx() + # Enable cyclic kv cache when it exceeds the max_kv_cache_len + if self.lens[batch_idx] == self.max_kv_cache_len: + continue if not finished[batch_idx] and self.lens[ batch_idx] % self.tokens_per_block == self.tokens_per_block - 1: self.blocks_manager.allocate(seq) @@ -285,7 +290,7 @@ def add_sequence(self, sequence: GenerationSequence, context_len: int): """ Add sequence to the manager and allocate minimum amount of blocks for context """ - self.lens.append(context_len) + self.lens.append(min(context_len, self.max_kv_cache_len)) self.sequences.append(sequence) # With beam_width > 1 we share context blocks between beams. @@ -298,7 +303,8 @@ def add_sequence(self, sequence: GenerationSequence, context_len: int): self.blocks_manager.allocate(sequence, share_across_beam=True) # Get one extra block for each beam. This is always one extra block # because we need space for context_len + 1 tokens. - self.blocks_manager.allocate(sequence, share_across_beam=False) + if context_len < self.max_kv_cache_len or context_len % self.tokens_per_block > 0: + self.blocks_manager.allocate(sequence, share_across_beam=False) def get_pointer_arrays(self, beam_width: int) -> List[torch.Tensor]: """ diff --git a/tensorrt_llm/runtime/session.py b/tensorrt_llm/runtime/session.py index 58659fd9a..3f9b2de53 100644 --- a/tensorrt_llm/runtime/session.py +++ b/tensorrt_llm/runtime/session.py @@ -16,7 +16,7 @@ import contextlib from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import tensorrt as trt @@ -144,9 +144,11 @@ def _print_io_info(self): logger.info( f"Tensor:{name=:}, {mode=:}, {shape=:}, {dtype=:}, {tformat=:}") - def infer_shapes(self, - inputs: List[TensorInfo], - context=None) -> List[TensorInfo]: + def infer_shapes( + self, + inputs: List[TensorInfo], + context: Optional[trt.IExecutionContext] = None + ) -> List[TensorInfo]: ''' @brief: Set input shapes to given context, and infer the output shapes from the given input shapes. This function should be called every time when the input shapes are changed before calling run(). @@ -163,7 +165,10 @@ def infer_shapes(self, raise ValueError(f"Tensor:{i.name} is not an input tensor") if self.engine.get_tensor_dtype(i.name) != i.dtype: raise ValueError(f"Tensor:{i.name} has wrong dtype") - context.set_input_shape(i.name, i.shape) + if not context.set_input_shape(i.name, i.shape): + raise RuntimeError( + f"Could not set shape {i.shape} for tensor {i.name}. Please check the profile range for which your model was build." + ) outputs = [] for i in range(self.engine.num_io_tensors): diff --git a/tests/attention/test_gpt_attention.py b/tests/attention/test_gpt_attention.py index a706e8ee0..961b97bb6 100644 --- a/tests/attention/test_gpt_attention.py +++ b/tests/attention/test_gpt_attention.py @@ -242,10 +242,11 @@ def test_gpt_attention(self, def _construct_execution( session, input_tensor, weight, bias, past_key_value, pointer_array, sequence_length, host_past_key_value_lengths, - context_lengths, host_context_lengths, cache_indirection, - host_request_types, num_heads, hidden_size, num_kv_heads, - output, dtype, max_context_length, shape_dict, - kv_int8_quant_scale, kv_int8_dequant_scale, configuration): + host_max_kv_cache_lengths, context_lengths, + host_context_lengths, cache_indirection, host_request_types, + num_heads, hidden_size, num_kv_heads, output, dtype, + max_context_length, shape_dict, kv_int8_quant_scale, + kv_int8_dequant_scale, configuration): head_size = hidden_size // num_heads # construct trt network builder = tensorrt_llm.Builder() @@ -256,6 +257,8 @@ def _construct_execution( net.plugin_config.enable_remove_input_padding() if paged_kv_cache: net.plugin_config.enable_paged_kv_cache(tokens_per_block) + if enable_multi_block_mmha: + net.plugin_config.enable_mmha_multi_block_mode() with tensorrt_llm.net_guard(net): x_tensor = Tensor(name='input', @@ -269,6 +272,10 @@ def _construct_execution( name='host_past_key_value_lengths', shape=tuple(host_past_key_value_lengths.shape), dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_max_kv_cache_lengths_tensor = Tensor( + name='host_max_kv_cache_lengths', + shape=tuple(host_max_kv_cache_lengths.shape), + dtype=tensorrt_llm.str_dtype_to_trt('int32')) context_lengths_tensor = Tensor( name='context_lengths', shape=tuple(context_lengths.shape), @@ -366,6 +373,7 @@ def _construct_execution( sequence_length=sequence_length_tensor, host_past_key_value_lengths= host_past_key_value_lengths_tensor, + host_max_kv_cache_lengths=host_max_kv_cache_lengths_tensor, context_lengths=context_lengths_tensor, cache_indirection=cache_indirection_tensor, host_request_types=host_request_types_tensor, @@ -380,7 +388,6 @@ def _construct_execution( rotary_embedding_max_positions=configuration. max_position_embeddings, position_embedding_type=position_embedding_type, - multi_block_mode=enable_multi_block_mmha, kv_orig_quant_scale=kv_quant_scale_tensor, kv_quant_orig_scale=kv_dequant_scale_tensor, host_context_lengths=host_context_lengths_tensor, @@ -404,6 +411,7 @@ def _construct_execution( 'input': input_tensor, 'sequence_length': sequence_length, 'host_past_key_value_lengths': host_past_key_value_lengths, + 'host_max_kv_cache_lengths': host_max_kv_cache_lengths, 'context_lengths': context_lengths, 'cache_indirection': cache_indirection, 'host_request_types': host_request_types @@ -465,6 +473,7 @@ def _construct_execution( 'weight': (hidden_size, qkv_hidden_size), 'bias': (qkv_hidden_size, ), 'host_past_key_value_lengths': (batch_size, ), + 'host_max_kv_cache_lengths': (1, ), 'sequence_length': (batch_size, ), 'context_lengths': (batch_size, ), 'kv_quant_scale': (1, ), @@ -496,6 +505,7 @@ def _construct_execution( blocks, tokens_per_block, max_blocks_per_seq, + max_seq_len, beam_width=beam_width) # Add sequences to the manager @@ -807,6 +817,8 @@ def verify_kv_cache(torch_present): shape_dict['output'] = shape_dict['input'] host_past_key_value_lengths = torch.tensor([0] * batch_size, dtype=torch.int32) + host_max_kv_cache_lengths = torch.tensor([max_seq_len], + dtype=torch.int32) input_tensor = torch.randn(shape_dict['input'], dtype=str_dtype_to_torch(dtype), @@ -877,11 +889,11 @@ def verify_kv_cache(torch_present): session, output, present_key_value = _construct_execution( session, input_tensor, weight_plugin, bias_plugin, present_key_value, pointer_array, sequence_length, - host_past_key_value_lengths, input_lengths, - host_context_lengths, cache_indirection, host_request_types, - num_heads, hidden_size, num_kv_heads, output, dtype, - max_context_length, shape_dict, kv_quant_scale, - kv_dequant_scale, configuration) + host_past_key_value_lengths, host_max_kv_cache_lengths, + input_lengths, host_context_lengths, cache_indirection, + host_request_types, num_heads, hidden_size, num_kv_heads, + output, dtype, max_context_length, shape_dict, + kv_quant_scale, kv_dequant_scale, configuration) del session session = None @@ -905,6 +917,8 @@ def verify_kv_cache(torch_present): # Generation stage shape_dict['input'] = (batch_size, 1, hidden_size) host_past_key_value_lengths = sequence_length.cpu() - 1 + host_max_kv_cache_lengths = torch.tensor([max_seq_len], + dtype=torch.int32) input_tensor = torch.randn(shape_dict['input'], dtype=str_dtype_to_torch(dtype), device='cuda') * 1e-3 @@ -1022,11 +1036,11 @@ def tile_beam_width(tensor: torch.Tensor, num_beams: int): session, tiled_input_tensor, weight_plugin, bias_plugin, tiled_present_key_value, pointer_array, tiled_sequence_length, tiled_host_past_key_value_lengths, - tiled_input_lengths, tiled_host_context_lengths, - cache_indirection, tiled_host_request_types, num_heads, - hidden_size, num_kv_heads, tiled_output, dtype, - max_context_length, shape_dict, kv_quant_scale, - kv_dequant_scale, configuration) + host_max_kv_cache_lengths, tiled_input_lengths, + tiled_host_context_lengths, cache_indirection, + tiled_host_request_types, num_heads, hidden_size, + num_kv_heads, tiled_output, dtype, max_context_length, + shape_dict, kv_quant_scale, kv_dequant_scale, configuration) del session session = None diff --git a/tests/attention/test_gpt_attention_IFB.py b/tests/attention/test_gpt_attention_IFB.py index 55c324572..2d2679479 100644 --- a/tests/attention/test_gpt_attention_IFB.py +++ b/tests/attention/test_gpt_attention_IFB.py @@ -192,14 +192,13 @@ def test_gpt_attention_IFB(self, remove_input_padding = True - def _construct_execution(session, input_tensor, weight, bias, - pointer_array, sequence_length, - host_past_key_value_lengths, context_lengths, - max_context_length, cache_indirection, - num_heads, hidden_size, num_kv_heads, output, - dtype, kv_int8_quant_scale, - kv_int8_dequant_scale, host_context_lengths, - host_request_types): + def _construct_execution( + session, input_tensor, weight, bias, pointer_array, + sequence_length, host_past_key_value_lengths, + host_max_kv_cache_lengths, context_lengths, max_context_length, + cache_indirection, num_heads, hidden_size, num_kv_heads, output, + dtype, kv_int8_quant_scale, kv_int8_dequant_scale, + host_context_lengths, host_request_types): head_size = hidden_size // num_heads # construct trt network builder = tensorrt_llm.Builder() @@ -208,6 +207,8 @@ def _construct_execution(session, input_tensor, weight, bias, net.plugin_config.set_context_fmha(context_fmha_type) net.plugin_config.enable_remove_input_padding() net.plugin_config.enable_paged_kv_cache(tokens_per_block) + if enable_multi_block_mmha: + net.plugin_config.enable_mmha_multi_block_mode() with tensorrt_llm.net_guard(net): x_tensor = Tensor(name='input', @@ -221,6 +222,10 @@ def _construct_execution(session, input_tensor, weight, bias, name='host_past_key_value_lengths', shape=tuple(host_past_key_value_lengths.shape), dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_max_kv_cache_lengths_tensor = Tensor( + name='host_max_kv_cache_lengths', + shape=tuple(host_max_kv_cache_lengths.shape), + dtype=tensorrt_llm.str_dtype_to_trt('int32')) input_lengths_tensor = Tensor( name='context_lengths', shape=tuple(context_lengths.shape), @@ -309,6 +314,7 @@ def _construct_execution(session, input_tensor, weight, bias, sequence_length=sequence_length_tensor, host_past_key_value_lengths= host_past_key_value_lengths_tensor, + host_max_kv_cache_lengths=host_max_kv_cache_lengths_tensor, context_lengths=input_lengths_tensor, cache_indirection=cache_indirection_tensor, host_request_types=host_request_types_tensor, @@ -323,7 +329,6 @@ def _construct_execution(session, input_tensor, weight, bias, rotary_embedding_max_positions=configuration. max_position_embeddings, position_embedding_type=position_embedding_type, - multi_block_mode=enable_multi_block_mmha, kv_orig_quant_scale=kv_int8_quant_scale_tensor, kv_quant_orig_scale=kv_int8_dequant_scale_tensor, kv_cache_quant_mode=QuantMode.from_description( @@ -341,6 +346,7 @@ def _construct_execution(session, input_tensor, weight, bias, 'input': input_tensor, 'sequence_length': sequence_length, 'host_past_key_value_lengths': host_past_key_value_lengths, + 'host_max_kv_cache_lengths': host_max_kv_cache_lengths, 'context_lengths': context_lengths, 'cache_indirection': cache_indirection, 'host_request_types': host_request_types, @@ -686,6 +692,7 @@ def torch_exec(step: int, blocks, tokens_per_block, max_blocks_per_seq, + max_seq_len, beam_width=beam_width) torch_cache_list = [None] * num_req @@ -759,6 +766,9 @@ def torch_exec(step: int, host_past_key_value_length_list, dtype=torch.int32, device='cpu').reshape(-1) + host_max_kv_cache_lengths = torch.tensor([max_seq_len], + dtype=torch.int32, + device='cpu') total_num_tokens = int(sum(host_input_lengths)) max_context_length = in_len context_lengths = host_context_lengths.cuda() @@ -799,10 +809,11 @@ def torch_exec(step: int, session, output = _construct_execution( session, input_tensor, weight_plugin, bias_plugin, dense_pointer_arrays, sequence_lengths, - host_past_key_value_lengths, context_lengths, - max_context_length, cache_indirection, num_heads, hidden_size, - num_kv_heads, output, dtype, kv_int8_quant_scale, - kv_int8_dequant_scale, host_context_lengths, host_request_types) + host_past_key_value_lengths, host_max_kv_cache_lengths, + context_lengths, max_context_length, cache_indirection, + num_heads, hidden_size, num_kv_heads, output, dtype, + kv_int8_quant_scale, kv_int8_dequant_scale, + host_context_lengths, host_request_types) del session session = None diff --git a/tests/model/test_bloom.py b/tests/model/test_bloom.py index 9907c9f6b..af1bba98b 100644 --- a/tests/model/test_bloom.py +++ b/tests/model/test_bloom.py @@ -209,6 +209,8 @@ def test_bloom(self, use_gpt_attention_plugin, context_fmha_type, dtype=torch.int32).cuda() ctx_host_past_key_value_lengths = torch.tensor([0] * batch_size, dtype=torch.int32) + host_max_kv_cache_lengths = torch.tensor([total_length], + dtype=torch.int32) cache_indirections = [ torch.full(( @@ -259,6 +261,7 @@ def test_bloom(self, use_gpt_attention_plugin, context_fmha_type, device='cuda') ctx_shape.update({ f'past_key_value_{i}': shape, + f'host_max_kv_cache_length_{i}': (1, ), }) shape = (batch_size, 2, bloom_config.n_head, seq_len, bloom_config.hidden_size // bloom_config.n_head) @@ -269,6 +272,8 @@ def test_bloom(self, use_gpt_attention_plugin, context_fmha_type, torch.zeros(shape, dtype=str_dtype_to_torch(dtype), device='cuda'), + f'host_max_kv_cache_length_{i}': + host_max_kv_cache_lengths, }) context = runtime.ctx_context @@ -305,6 +310,8 @@ def test_bloom(self, use_gpt_attention_plugin, context_fmha_type, gen_host_past_key_value_lengths = torch.tensor([seq_len + step - 1] * batch_size, dtype=torch.int32) + gen_host_max_kv_cache_lengths = torch.tensor([total_length], + dtype=torch.int32) step1_buffer = { 'input_ids': gen_id, 'context_lengths': gen_context_lengths.contiguous(), @@ -331,10 +338,13 @@ def test_bloom(self, use_gpt_attention_plugin, context_fmha_type, bloom_config.hidden_size // bloom_config.n_head) step1_shape.update({ f'past_key_value_{i}': shape, + f'host_max_kv_cache_length_{i}': (1, ), }) step1_buffer.update({ f'past_key_value_{i}': ctx_buffer[f'present_key_value_{i}'], + f'host_max_kv_cache_length_{i}': + host_max_kv_cache_lengths, }) context = runtime.context_1 diff --git a/tests/model/test_falcon.py b/tests/model/test_falcon.py index 5809a522a..a8a72a6a2 100644 --- a/tests/model/test_falcon.py +++ b/tests/model/test_falcon.py @@ -347,6 +347,8 @@ def test_falcon(self, query_type, use_alibi, parallel_attention, # past kv length: (length, is_context) host_past_key_value_lengths = torch.tensor([0] * batch_size, dtype=torch.int32) + host_max_kv_cache_lengths = torch.tensor([total_length], + dtype=torch.int32) ctx_buffer = { 'input_ids': ctx_input_ids.contiguous(), @@ -374,6 +376,7 @@ def test_falcon(self, query_type, use_alibi, parallel_attention, head_dim) for i in range(hf_config.num_hidden_layers): ctx_shape[f'past_key_value_{i}'] = past_kv_shape + ctx_shape[f'host_max_kv_cache_length_{i}'] = (1, ) ctx_buffer[f'present_key_value_{i}'] = torch.zeros( present_kv_shape, dtype=str_dtype_to_torch(kv_dtype), @@ -381,6 +384,8 @@ def test_falcon(self, query_type, use_alibi, parallel_attention, if use_gpt_attengion_plugin: ctx_buffer[f'past_key_value_{i}'] = ctx_buffer[ f'present_key_value_{i}'] + ctx_buffer[ + f'host_max_kv_cache_length_{i}'] = host_max_kv_cache_lengths else: ctx_buffer[f'past_key_value_{i}'] = torch.zeros( (1, ), dtype=str_dtype_to_torch(kv_dtype), device=device) @@ -450,6 +455,8 @@ def test_falcon(self, query_type, use_alibi, parallel_attention, if use_gpt_attengion_plugin: # gpt_attention_plugin shares past/present cache. step1_buffer[f'present_key_value_{i}'] = kv_cache + step1_buffer[ + f'host_max_kv_cache_length_{i}'] = host_max_kv_cache_lengths step1_shape = {k: v.shape for k, v in step1_buffer.items()} context = runtime.context_1 diff --git a/tests/model/test_gpt.py b/tests/model/test_gpt.py index 634132a59..cdefaef93 100644 --- a/tests/model/test_gpt.py +++ b/tests/model/test_gpt.py @@ -497,7 +497,8 @@ def test_gpt_plugin(self, use_refit, fast_building, blocks = batch_size * beam_width * max_blocks_per_seq kv_cache_manager = KVCacheManager(key_value_cache_buffers, blocks, tokens_per_block, - max_blocks_per_seq, beam_width) + max_blocks_per_seq, total_length, + beam_width) # Add sequences to the manager for bi in range(batch_size): @@ -513,6 +514,7 @@ def run_engine(context, last_token_ids, cache_indirection, host_past_key_value_lengths, + host_max_kv_cache_lengths, sequence_length=None, host_context_lengths=None): @@ -543,12 +545,16 @@ def run_engine(context, ctx_buffer[ f'kv_cache_block_pointers_{idx}'] = kv_cache_block_pointers[ idx].reshape(shape).contiguous() + ctx_buffer[ + f'host_max_kv_cache_length_{idx}'] = host_max_kv_cache_lengths else: for i in range(gpt_config.n_layer): ctx_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[ i] ctx_buffer[ f'present_key_value_{i}'] = key_value_cache_buffers[i] + ctx_buffer[ + f'host_max_kv_cache_length_{i}'] = host_max_kv_cache_lengths ctx_shape = { key: buffer.shape @@ -599,6 +605,8 @@ def compare_context(run_ref_only=False): host_past_key_value_lengths = torch.tensor([0] * batch_size, dtype=torch.int32) + host_max_kv_cache_lengths = torch.tensor([total_length], + dtype=torch.int32) host_context_lengths = ctx_context_lengths.cpu( ) if enable_remove_input_padding else None @@ -617,6 +625,7 @@ def compare_context(run_ref_only=False): last_token_ids=ctx_last_token_ids, cache_indirection=cache_indirections[0], host_past_key_value_lengths=host_past_key_value_lengths, + host_max_kv_cache_lengths=host_max_kv_cache_lengths, sequence_length=sequence_length, host_context_lengths=host_context_lengths, host_request_types=host_request_types) @@ -667,6 +676,8 @@ def compare_generation(run_ref_only=False): host_past_key_value_lengths = torch.tensor([seq_len + step - 1] * batch_size, dtype=torch.int32) + host_max_kv_cache_lengths = torch.tensor([seq_len + step], + dtype=torch.int32) host_context_lengths = gen_context_lengths.cpu( ) if enable_remove_input_padding else None @@ -684,6 +695,7 @@ def compare_generation(run_ref_only=False): last_token_ids=gen_last_token_ids, cache_indirection=cache_indirections[1], host_past_key_value_lengths=host_past_key_value_lengths, + host_max_kv_cache_lengths=host_max_kv_cache_lengths, sequence_length=sequence_length, host_context_lengths=host_context_lengths, host_request_types=host_request_types) @@ -739,6 +751,9 @@ def compare_mixing_context_and_generation_phases(): [0] * num_context_input + [seq_len] * num_generation_input, dtype=torch.int32) + host_max_kv_cache_lengths = torch.tensor([total_length], + dtype=torch.int32) + context_lengths = torch.tensor([seq_len] * batch_size, dtype=torch.int32).cuda() if enable_remove_input_padding: @@ -761,6 +776,7 @@ def compare_mixing_context_and_generation_phases(): last_token_ids=gen_last_token_ids, cache_indirection=cache_indirections[0], host_past_key_value_lengths=host_past_key_value_lengths, + host_max_kv_cache_lengths=host_max_kv_cache_lengths, sequence_length=sequence_length, host_context_lengths=host_context_lengths, host_request_types=host_request_types, diff --git a/tests/model/test_gptj.py b/tests/model/test_gptj.py index 590f27600..21d51102b 100644 --- a/tests/model/test_gptj.py +++ b/tests/model/test_gptj.py @@ -223,6 +223,7 @@ def run_engine(context, last_token_ids, cache_indirection, host_past_key_value_lengths, + host_max_kv_cache_lengths, sequence_length, host_context_lengths=None): @@ -238,6 +239,8 @@ def run_engine(context, } for i in range(gpt_config.n_layer): ctx_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[i] + ctx_buffer[ + f'host_max_kv_cache_length_{i}'] = host_max_kv_cache_lengths ctx_buffer[f'present_key_value_{i}'] = key_value_cache_buffers[ i] @@ -312,6 +315,9 @@ def compare_context(): dtype=torch.int32).cpu() host_past_key_value_lengths = torch.tensor([0] * batch_size, dtype=torch.int32) + host_max_kv_cache_lengths = torch.tensor([total_seq_len], + dtype=torch.int32) + host_context_lengths = ctx_context_lengths.cpu( ) if enable_remove_input_padding else None @@ -323,6 +329,7 @@ def compare_context(): last_token_ids=ctx_last_token_ids, cache_indirection=cache_indirections[0], host_past_key_value_lengths=host_past_key_value_lengths, + host_max_kv_cache_lengths=host_max_kv_cache_lengths, sequence_length=sequence_length_buffer, host_context_lengths=host_context_lengths, host_request_types=host_request_types) @@ -387,6 +394,9 @@ def compare_generation(): host_past_key_value_lengths = torch.tensor([seq_len] * batch_size, dtype=torch.int32) + host_max_kv_cache_lengths = torch.tensor([total_seq_len], + dtype=torch.int32) + host_request_types = torch.tensor([1] * batch_size, dtype=torch.int32).cpu() host_context_lengths = gen_context_lengths.cpu( @@ -404,6 +414,7 @@ def compare_generation(): last_token_ids=gen_last_token_ids, cache_indirection=cache_indirections[1], host_past_key_value_lengths=host_past_key_value_lengths, + host_max_kv_cache_lengths=host_max_kv_cache_lengths, sequence_length=sequence_length_buffer, host_context_lengths=host_context_lengths, host_request_types=host_request_types) diff --git a/tests/model/test_gptneox.py b/tests/model/test_gptneox.py index 014c99119..ad4fc2e87 100644 --- a/tests/model/test_gptneox.py +++ b/tests/model/test_gptneox.py @@ -293,6 +293,9 @@ def test_gptneox_plugin(self, context_fmha_flag, ctx_shape[f'past_key_value_{i}'] = shape ctx_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[i] ctx_buffer[f'present_key_value_{i}'] = key_value_cache_buffers[i] + ctx_buffer[f'host_max_kv_cache_length_{i}'] = torch.tensor( + [total_seq_len], dtype=torch.int32) + ctx_shape[f'host_max_kv_cache_length_{i}'] = (1, ) sequence_length_buffer = torch.add(sequence_length_buffer, step) ctx_buffer['sequence_length'] = sequence_length_buffer ctx_shape['sequence_length'] = ctx_buffer['sequence_length'].shape @@ -384,11 +387,14 @@ def test_gptneox_plugin(self, context_fmha_flag, step1_shape = {k: v.shape for k, v in step1_buffer.items()} for i in range(gpt_config.num_hidden_layers): step1_shape[f'past_key_value_{i}'] = shape + step1_shape[f'host_max_kv_cache_length_{i}'] = (1, ) step1_shape['sequence_length'] = (batch_size, ) step1_shape['host_past_key_value_lengths'] = (batch_size, ) for i in range(gpt_config.num_hidden_layers): step1_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[i] step1_buffer[f'present_key_value_{i}'] = key_value_cache_buffers[i] + step1_buffer[f'host_max_kv_cache_length_{i}'] = torch.tensor( + [total_seq_len], dtype=torch.int32) # For step 1, the sequence_lengths = context_lengths + 1. sequence_length_buffer = torch.add(sequence_length_buffer, step) step1_buffer['sequence_length'] = sequence_length_buffer diff --git a/tests/model/test_llama.py b/tests/model/test_llama.py index 8a1c05915..e4aee3ba8 100644 --- a/tests/model/test_llama.py +++ b/tests/model/test_llama.py @@ -318,6 +318,9 @@ def test_llama(self, use_refit, fast_building, context_fmha_flag, ctx_shape[f'past_key_value_{i}'] = kv_shape ctx_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[i] ctx_buffer[f'present_key_value_{i}'] = key_value_cache_buffers[i] + ctx_buffer[f'host_max_kv_cache_length_{i}'] = torch.tensor( + [max_seq_len], dtype=torch.int32) + ctx_shape[f'host_max_kv_cache_length_{i}'] = (1, ) ctx_buffer['sequence_length'] = sequence_length_buffer ctx_shape['sequence_length'] = ctx_buffer['sequence_length'].shape ctx_shape['host_past_key_value_lengths'] = (batch_size, ) @@ -374,11 +377,14 @@ def test_llama(self, use_refit, fast_building, context_fmha_flag, for i in range(llama_config.num_hidden_layers): step1_shape[f'past_key_value_{i}'] = kv_shape + step1_shape[f'host_max_kv_cache_length_{i}'] = (1, ) step1_shape['sequence_length'] = (batch_size, ) step1_shape['host_past_key_value_lengths'] = (batch_size, ) for i in range(llama_config.num_hidden_layers): step1_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[i] step1_buffer[f'present_key_value_{i}'] = key_value_cache_buffers[i] + step1_buffer[f'host_max_kv_cache_length_{i}'] = torch.tensor( + [max_seq_len], dtype=torch.int32) step1_buffer[ 'host_past_key_value_lengths'] = sequence_length_buffer.cpu() sequence_length_buffer = torch.add(sequence_length_buffer, step) diff --git a/tests/model/test_mistral.py b/tests/model/test_mistral.py new file mode 100644 index 000000000..9899541d7 --- /dev/null +++ b/tests/model/test_mistral.py @@ -0,0 +1,577 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import random +import sys +import tempfile +import unittest +from itertools import product +from pathlib import Path + +import numpy as np +import pytest +import torch +from parameterized import parameterized +from transformers import MistralConfig, MistralForCausalLM + +import tensorrt_llm +from tensorrt_llm import Builder +from tensorrt_llm._utils import str_dtype_to_trt +from tensorrt_llm.layers import PositionEmbeddingType +from tensorrt_llm.network import net_guard +from tensorrt_llm.plugin.plugin import ContextFMHAType + +sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) +from examples.llama.weight import load_from_hf_llama + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import getSMVersion + + +class TestMistral(unittest.TestCase): + EOS_TOKEN = 2 + PAD_TOKEN = 2 + + def _gen_tensorrt_llm_network(self, network, hf_mistral, + mistral_config: MistralConfig, batch_size, + beam_width, input_len, output_len, dtype, + rank, tensor_parallel): + list(range(tensor_parallel)) + + with net_guard(network): + kv_dtype = str_dtype_to_trt(dtype) + + # Initialize model + tensorrt_llm_mistral = tensorrt_llm.models.LLaMAForCausalLM( + num_layers=mistral_config.num_hidden_layers, + num_heads=mistral_config.num_attention_heads, + num_kv_heads=mistral_config.num_key_value_heads, + hidden_size=mistral_config.hidden_size, + vocab_size=mistral_config.vocab_size, + hidden_act=mistral_config.hidden_act, + max_position_embeddings=mistral_config.max_position_embeddings, + dtype=kv_dtype, + mlp_hidden_size=mistral_config.intermediate_size, + position_embedding_type=PositionEmbeddingType.rope_gpt_neox, + mapping=tensorrt_llm.Mapping(world_size=tensor_parallel, + tp_size=tensor_parallel), + ) + load_from_hf_llama(tensorrt_llm_mistral, + hf_mistral, + dtype=dtype, + mapping=tensorrt_llm.Mapping( + world_size=tensor_parallel, + rank=rank, + tp_size=tensor_parallel)) + # Prepare + network.set_named_parameters( + tensorrt_llm_mistral.named_parameters()) + inputs = tensorrt_llm_mistral.prepare_inputs( + batch_size, input_len, output_len, True, beam_width) + # Forward + tensorrt_llm_mistral(*inputs) + + return network + + def _gen_tensorrt_llm_engine(self, + dtype, + rank, + world_size, + llama_config, + hf_llama, + model_name, + use_plugin, + batch_size, + beam_width, + input_len, + output_len, + use_refit, + fast_building=False, + context_fmha_flag=ContextFMHAType.disabled, + enable_remove_input_padding=False): + + builder = Builder() + + with tempfile.TemporaryDirectory() as tmpdirname: + network = builder.create_network() + if use_plugin: + network.plugin_config.set_gpt_attention_plugin(dtype) + if fast_building: + network.plugin_config.set_gemm_plugin(dtype) + if enable_remove_input_padding: + network.plugin_config.enable_remove_input_padding() + network.plugin_config.set_context_fmha(context_fmha_flag) + + self._gen_tensorrt_llm_network(network, hf_llama, llama_config, + batch_size, beam_width, input_len, + output_len, dtype, rank, world_size) + + builder_config = builder.create_builder_config( + name=model_name, + precision=dtype, + timing_cache='model.cache', + tensor_parallel=world_size, # TP only + use_refit=use_refit, + ) + engine_buffer = builder.build_engine(network, builder_config) + return engine_buffer + + def _gen_tensorrt_llm_runtime(self, + log_level, + dtype, + world_size, + rank, + llama_config, + hf_llama, + model_name, + use_plugin, + batch_size, + beam_width, + input_len, + output_len, + use_refit, + fast_building=False, + context_fmha_flag=ContextFMHAType.disabled, + enable_remove_input_padding=False): + tensorrt_llm.logger.set_level(log_level) + mapping = tensorrt_llm.Mapping(world_size, rank, tp_size=world_size) + engine_buffer = self._gen_tensorrt_llm_engine( + dtype, rank, world_size, llama_config, hf_llama, model_name, + use_plugin, batch_size, beam_width, input_len, output_len, + use_refit, fast_building, context_fmha_flag, + enable_remove_input_padding) + runtime = tensorrt_llm.runtime.generation._Runtime( + engine_buffer, mapping) + return runtime, engine_buffer + + def load_test_cases(): + test_cases = list( + product([False], [False, True], + [ContextFMHAType.disabled, ContextFMHAType.enabled], + [False, True], ['float16', 'bfloat16'], [2, 4])) + return test_cases + + def custom_name_func(testcase_func, param_num, param): + return "%s_%s" % ( + testcase_func.__name__, + parameterized.to_safe_name("_".join(str(x) for x in param.args)), + ) + + @parameterized.expand(load_test_cases, name_func=custom_name_func) + def test_mistral(self, use_refit, fast_building, context_fmha_flag, + enable_remove_input_padding, dtype, num_kv_heads): + + # Skip tests that are not supported in pre-ampere architecture + if getSMVersion() < 80: + if context_fmha_flag == ContextFMHAType.enabled: + pytest.skip( + "ContextFMHAType is not supported in pre-ampere architecture" + ) + elif context_fmha_flag == ContextFMHAType.enabled_with_fp32_acc: + pytest.skip( + "ContextFMHAType with fp32 acc is not supported in pre-ampere architecture" + ) + elif dtype == 'bfloat16': + pytest.skip( + "bfloat16 is not supported in pre-ampere architecture") + + PRECHECKED_GOOD_RANDOM_SEEDS = [1, 4, 5, 8] + model = 'llama' + log_level = 'error' + use_plugin = True # gpt plugin + batch_size = 4 + beam_width = 1 + input_len = 4 + output_len = 2 + max_seq_len = input_len + output_len + world_size = 1 + head_size = 32 + rank = 0 + mistral_config = MistralConfig() + mistral_config.hidden_act = 'silu' + mistral_config.num_hidden_layers = 2 + mistral_config.max_position_embeddings = 64 + mistral_config.vocab_size = 128 + mistral_config.num_attention_heads = 2 * num_kv_heads + mistral_config.hidden_size = mistral_config.num_attention_heads * head_size + mistral_config.intermediate_size = (( + (mistral_config.hidden_size * 4 * 2 // 3) + head_size - 1) // + head_size) * head_size + mistral_config.num_key_value_heads = num_kv_heads + assert (mistral_config.num_attention_heads % + mistral_config.num_key_value_heads) == 0 + mistral_config.pad_token_id = self.PAD_TOKEN + mistral_config.eos_token_id = self.EOS_TOKEN + seed_idx = random.randint(0, len(PRECHECKED_GOOD_RANDOM_SEEDS) - 1) + torch.manual_seed(PRECHECKED_GOOD_RANDOM_SEEDS[seed_idx]) + hf_mistral = MistralForCausalLM(mistral_config).cuda() + runtime, _ = self._gen_tensorrt_llm_runtime( + log_level, dtype, world_size, rank, mistral_config, hf_mistral, + model, use_plugin, batch_size, beam_width, input_len, output_len, + use_refit, fast_building, context_fmha_flag, + enable_remove_input_padding) + key_value_cache_buffers = [] + head_size = mistral_config.hidden_size // mistral_config.num_attention_heads + for i in range(mistral_config.num_hidden_layers): + key_value_cache_buffers.append( + torch.zeros(( + batch_size, + 2, + mistral_config.num_key_value_heads, + max_seq_len, + head_size, + ), + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device='cuda')) + + # compare context + step = 0 + ctx_ids = torch.randint(100, (batch_size, input_len)).int().cuda() + ctx_context_lengths = input_len * torch.ones( + (batch_size), dtype=torch.int32, device='cuda') + ctx_position_ids = torch.tensor(range(input_len), + dtype=torch.int32).reshape([ + 1, input_len + ]).expand([batch_size, + input_len]).cuda() + ctx_last_token_ids = ctx_context_lengths.clone() + ctx_host_request_types = torch.tensor([0] * batch_size, + dtype=torch.int32) + + # We need sequence_lengths start as context_lengths for step 0, + # and it will be added one after each step. + sequence_length_buffer = ctx_context_lengths.detach().clone() + + with torch.no_grad(): + hf_outputs = hf_mistral.forward(ctx_ids) + torch.cuda.synchronize() + ref = hf_outputs.logits[:, -1, :] + + if enable_remove_input_padding: + ctx_ids = ctx_ids.view([1, batch_size * input_len]) + ctx_position_ids = ctx_position_ids.view( + [1, batch_size * input_len]) + ctx_last_token_ids = torch.cumsum(ctx_last_token_ids, dim=0).int() + + cache_indirections = [ + torch.full(( + batch_size, + beam_width, + max_seq_len, + ), + 0, + dtype=torch.int32, + device='cuda'), + torch.full(( + batch_size, + beam_width, + max_seq_len, + ), + 0, + dtype=torch.int32, + device='cuda') + ] # ping-pong buffers + + ctx_buffer = { + 'input_ids': ctx_ids, + 'context_lengths': ctx_context_lengths, + 'position_ids': ctx_position_ids, + 'last_token_ids': ctx_last_token_ids, + 'cache_indirection': cache_indirections[0], + 'host_request_types': ctx_host_request_types, + } + if enable_remove_input_padding: + ctx_buffer['host_context_lengths'] = ctx_context_lengths.cpu() + + ctx_shape = {k: v.shape for k, v in ctx_buffer.items()} + + kv_shape = (batch_size, 2, mistral_config.num_key_value_heads, + max_seq_len, head_size) + for i in range(mistral_config.num_hidden_layers): + ctx_shape[f'past_key_value_{i}'] = kv_shape + ctx_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[i] + ctx_buffer[f'present_key_value_{i}'] = key_value_cache_buffers[i] + ctx_buffer[f'host_max_kv_cache_length_{i}'] = torch.tensor( + [max_seq_len], dtype=torch.int32) + ctx_shape[f'host_max_kv_cache_length_{i}'] = (1, ) + ctx_buffer['sequence_length'] = sequence_length_buffer + ctx_shape['sequence_length'] = ctx_buffer['sequence_length'].shape + ctx_shape['host_past_key_value_lengths'] = (batch_size, ) + ctx_buffer['host_past_key_value_lengths'] = torch.tensor( + [0] * batch_size, dtype=torch.int32) + + context = runtime.ctx_context + runtime._set_shape(context, ctx_shape) + runtime._set_buffer(context, ctx_buffer) + runtime._run(context) + torch.cuda.synchronize() + res = ctx_buffer['logits'] + + np.testing.assert_allclose(ref.to(torch.float32).cpu().numpy(), + res.to(torch.float32).cpu().numpy(), + atol=0.12) + + # compare generation + step = 1 + step1_id = torch.randint(100, (batch_size, 1)).int().cuda() + gen_context_lengths = ctx_context_lengths.clone() + gen_position_ids = torch.ones_like(step1_id).int().cuda() * input_len + gen_last_token_ids = torch.zeros_like(gen_context_lengths).int().cuda() + gen_host_request_types = torch.tensor([1] * batch_size, + dtype=torch.int32) + + with torch.no_grad(): + hf_outputs = hf_mistral.forward( + step1_id, + past_key_values=hf_outputs.past_key_values, + use_cache=True) + torch.cuda.synchronize() + ref = hf_outputs.logits[:, -1, :] + + if enable_remove_input_padding: + step1_id = step1_id.view([1, batch_size]) + gen_position_ids = gen_position_ids.view([1, batch_size]) + gen_last_token_ids = torch.ones_like( + gen_context_lengths).int().cuda() + gen_last_token_ids = torch.cumsum(gen_last_token_ids, dim=0).int() + + step1_buffer = { + 'input_ids': step1_id, + 'context_lengths': gen_context_lengths, + 'position_ids': gen_position_ids, + 'last_token_ids': gen_last_token_ids, + 'host_request_types': gen_host_request_types, + 'cache_indirection': cache_indirections[1], + } + if enable_remove_input_padding: + step1_buffer['host_context_lengths'] = gen_context_lengths.cpu() + + step1_shape = {k: v.shape for k, v in step1_buffer.items()} + + for i in range(mistral_config.num_hidden_layers): + step1_shape[f'past_key_value_{i}'] = kv_shape + step1_shape[f'host_max_kv_cache_length_{i}'] = (1, ) + step1_shape['sequence_length'] = (batch_size, ) + step1_shape['host_past_key_value_lengths'] = (batch_size, ) + for i in range(mistral_config.num_hidden_layers): + step1_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[i] + step1_buffer[f'present_key_value_{i}'] = key_value_cache_buffers[i] + step1_buffer[f'host_max_kv_cache_length_{i}'] = torch.tensor( + [max_seq_len], dtype=torch.int32) + step1_buffer[ + 'host_past_key_value_lengths'] = sequence_length_buffer.cpu() + sequence_length_buffer = torch.add(sequence_length_buffer, step) + step1_buffer['sequence_length'] = sequence_length_buffer + + context = runtime.context_1 + runtime._set_shape(context, step1_shape) + runtime._set_buffer(context, step1_buffer) + runtime._run(context) + torch.cuda.synchronize() + res = step1_buffer['logits'] + + np.testing.assert_allclose(ref.to(torch.float32).cpu().numpy(), + res.to(torch.float32).cpu().numpy(), + atol=0.12) + + def get_loader_test_cases(): + test_cases = [] + test_cases.extend( + list( + product([ + ("mistral-7b-hf", "mistral-7b"), + ], [ + (1, 0), + (2, 0), + (2, 1), + ], [ + -1, + 0, + 1, + ]))) + return test_cases + + def loader_name_func(testcase_func, param_num, param): + expand_params = lambda params: '_'.join([ + expand_params(x) if isinstance(x, (list, tuple)) else str(x) + for x in params + ]) + name = expand_params(param.args) + return "%s_%s" % ( + testcase_func.__name__, + parameterized.to_safe_name(name), + ) + + @parameterized.expand(get_loader_test_cases, name_func=loader_name_func) + def test_loaders(self, paths, tp_info, emb_sharding_dim): + model_root = os.getenv("LLM_MODELS_ROOT") + if model_root is None: + pytest.skip("Skipping since real weights are unavailable.") + hf_path = Path(model_root, paths[0]) + mistralai_path = Path(model_root, paths[1]) + if not hf_path.exists(): + pytest.skip(f"Skipping since the path {hf_path} does not exist.") + if not mistralai_path.exists(): + pytest.skip( + f"Skipping since the path {mistralai_path} does not exist.") + + def print_corner(name, t: np.ndarray): + if len(t.shape) == 1: + tl = t[:2] + br = t[-2:] + elif len(t.shape) == 2: + tl = t[:2, :2] + br = t[-2:, -2:] + print(name, np.concatenate([tl, br]).flatten()) + + def print_layers(m: tensorrt_llm.models.LLaMAForCausalLM): + print_corner("vocab", m.vocab_embedding.weight._value) + print_corner("lm_head", m.lm_head.weight._value) + print_corner("ln_f", m.ln_f.weight._value) + print_corner("qkv", m.layers[0].attention.qkv.weight._value) + print_corner("gate", m.layers[0].mlp.gate.weight._value) + print_corner("inorm", m.layers[0].input_layernorm.weight._value) + print(flush=True) + return + + import tensorrt as trt + + from examples.llama.weight import load_from_meta_llama + tp_size = tp_info[0] + rank = tp_info[1] + dtype = "float16" + use_parallel_embedding = (emb_sharding_dim >= 0) + embedding_sharding_dim = abs(emb_sharding_dim) + hf_mistral = MistralForCausalLM.from_pretrained( + hf_path, + device_map={ + "model": "cpu", + "lm_head": "cpu" + }, # Load to CPU memory + torch_dtype="auto") + assert hf_mistral.config.torch_dtype == torch.float16 + kv_dtype = trt.float16 if hf_mistral.config.torch_dtype == torch.float16 else trt.float32 + max_context_length = 128 # for loader tests this value does not matter + tensorrt_llm_mistral_wHF = tensorrt_llm.models.LLaMAForCausalLM( + num_layers=hf_mistral.config.num_hidden_layers, + num_heads=hf_mistral.config.num_attention_heads, + num_kv_heads=hf_mistral.config.num_key_value_heads, + hidden_size=hf_mistral.config.hidden_size, + vocab_size=hf_mistral.config.vocab_size, + hidden_act=hf_mistral.config.hidden_act, + max_position_embeddings=hf_mistral.config.max_position_embeddings, + dtype=kv_dtype, + mlp_hidden_size=hf_mistral.config.intermediate_size, + position_embedding_type=PositionEmbeddingType.rope_gpt_neox, + mapping=tensorrt_llm.Mapping(world_size=tp_size, + rank=rank, + tp_size=tp_size), + use_parallel_embedding=use_parallel_embedding, + embedding_sharding_dim=embedding_sharding_dim) + # print_layers(tensorrt_llm_mistral_wHF) + load_from_hf_llama(tensorrt_llm_mistral_wHF, + hf_mistral, + mapping=tensorrt_llm.Mapping(world_size=tp_size, + rank=rank, + tp_size=tp_size), + dtype=dtype) + # print_layers(tensorrt_llm_mistral_wHF) + + tensorrt_llm_mistral_wMAI = tensorrt_llm.models.LLaMAForCausalLM( + num_layers=hf_mistral.config.num_hidden_layers, + num_heads=hf_mistral.config.num_attention_heads, + num_kv_heads=hf_mistral.config.num_key_value_heads, + hidden_size=hf_mistral.config.hidden_size, + vocab_size=hf_mistral.config.vocab_size, + hidden_act=hf_mistral.config.hidden_act, + max_position_embeddings=hf_mistral.config.max_position_embeddings, + dtype=kv_dtype, + mlp_hidden_size=hf_mistral.config.intermediate_size, + position_embedding_type=PositionEmbeddingType.rope_gpt_neox, + mapping=tensorrt_llm.Mapping(world_size=tp_size, + rank=rank, + tp_size=tp_size), + use_parallel_embedding=use_parallel_embedding, + embedding_sharding_dim=embedding_sharding_dim) + # print_layers(tensorrt_llm_mistral_wMAI) + load_from_meta_llama(tensorrt_llm_mistral_wMAI, + mistralai_path, + mapping=tensorrt_llm.Mapping(world_size=tp_size, + rank=rank, + tp_size=tp_size), + dtype=dtype) + # print_layers(tensorrt_llm_mistral_wMAI) + # token embedding + np.testing.assert_allclose( + tensorrt_llm_mistral_wHF.vocab_embedding.weight._value, + tensorrt_llm_mistral_wMAI.vocab_embedding.weight._value, + atol=1e-3) + # output + np.testing.assert_allclose( + tensorrt_llm_mistral_wHF.lm_head.weight._value, + tensorrt_llm_mistral_wMAI.lm_head.weight._value, + atol=1e-3) + # norm + np.testing.assert_allclose(tensorrt_llm_mistral_wHF.ln_f.weight._value, + tensorrt_llm_mistral_wMAI.ln_f.weight._value, + atol=1e-3) + # Checking all of the layers takes too much time, just check one random layer + l = np.random.randint(0, tensorrt_llm_mistral_wHF.num_layers) + # for l in range(tensorrt_llm_mistral_wHF.num_layers): + if l >= 0: + print(f"Checking Layer-{l} weights ...", flush=True) + # layer{l}.input_layernorm + np.testing.assert_allclose(tensorrt_llm_mistral_wHF.layers[l]. + input_layernorm.weight._value, + tensorrt_llm_mistral_wMAI.layers[l]. + input_layernorm.weight._value, + atol=1e-3) + # layer{l}.post_layernorm + np.testing.assert_allclose( + tensorrt_llm_mistral_wHF.layers[l].post_layernorm.weight._value, + tensorrt_llm_mistral_wMAI.layers[l].post_layernorm.weight. + _value, + atol=1e-3) + # layer{l}.mlp.gate + np.testing.assert_allclose( + tensorrt_llm_mistral_wHF.layers[l].mlp.gate.weight._value, + tensorrt_llm_mistral_wMAI.layers[l].mlp.gate.weight._value, + atol=1e-3) + # layer{l}.mlp.proj + np.testing.assert_allclose( + tensorrt_llm_mistral_wHF.layers[l].mlp.proj.weight._value, + tensorrt_llm_mistral_wMAI.layers[l].mlp.proj.weight._value, + atol=1e-3) + # layer{l}.mlp.fc + np.testing.assert_allclose( + tensorrt_llm_mistral_wHF.layers[l].mlp.fc.weight._value, + tensorrt_llm_mistral_wMAI.layers[l].mlp.fc.weight._value, + atol=1e-3) + # layer{l}.dense + np.testing.assert_allclose(tensorrt_llm_mistral_wHF.layers[l]. + attention.dense.weight._value, + tensorrt_llm_mistral_wMAI.layers[l]. + attention.dense.weight._value, + atol=1e-3) + # layer{l}.qkv + np.testing.assert_allclose( + tensorrt_llm_mistral_wHF.layers[l].attention.qkv.weight._value, + tensorrt_llm_mistral_wMAI.layers[l].attention.qkv.weight._value, + atol=1e-3) + return + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_graph_rewriter.py b/tests/test_graph_rewriter.py index e920bf484..0b5f9761d 100644 --- a/tests/test_graph_rewriter.py +++ b/tests/test_graph_rewriter.py @@ -72,6 +72,10 @@ def _construct_execution( name='host_past_key_value_lengths', shape=shape_dict['host_past_key_value_lengths'], dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_max_kv_cache_lengths_tensor = Tensor( + name='host_max_kv_cache_lengths', + shape=shape_dict['host_max_kv_cache_lengths'], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) context_lengths_tensor = Tensor( name='context_lengths', shape=shape_dict['context_lengths'], @@ -119,6 +123,7 @@ def _construct_execution( past_key_value=past_key_value_tensor, sequence_length=sequence_length_tensor, host_past_key_value_lengths=host_past_key_value_lengths_tensor, + host_max_kv_cache_lengths=host_max_kv_cache_lengths_tensor, context_lengths=context_lengths_tensor, cache_indirection=cache_indirection_tensor, host_request_types=host_request_types_tensor, @@ -129,7 +134,6 @@ def _construct_execution( max_context_length=in_len, rotary_embedding_dim=rotary_embedding_dim, position_embedding_type=position_embedding_type, - multi_block_mode=False, kv_orig_quant_scale=None, kv_quant_orig_scale=None, kv_cache_quant_mode=QuantMode.from_description( @@ -169,7 +173,8 @@ def _construct_execution( 'cache_indirection': (batch_size, 1, max_seq_len), 'input': (batch_size, in_len, hidden_size), 'output': (batch_size, in_len, hidden_size), - 'host_past_key_value_lengths': (batch_size, ) + 'host_past_key_value_lengths': (batch_size, ), + 'host_max_kv_cache_lengths': (1, ) } weight = torch.randn(shape_dict['weight'], diff --git a/tests/test_kv_cache_manager.py b/tests/test_kv_cache_manager.py index a65a29aae..649233ce6 100644 --- a/tests/test_kv_cache_manager.py +++ b/tests/test_kv_cache_manager.py @@ -282,6 +282,8 @@ def test_kv_cache_manager(self): manager = KVCacheManager(memory_pools=[memory_pool_1, memory_pool_2], blocks=blocks, tokens_per_block=tokens_per_block, + max_kv_cache_len=max_blocks_per_seq * + tokens_per_block, max_blocks_per_seq=max_blocks_per_seq) manager.add_sequence(GenerationSequence(seq_idx=0, batch_idx=0), 30) manager.add_sequence(GenerationSequence(seq_idx=1, batch_idx=1), 35) diff --git a/tests/test_layer.py b/tests/test_layer.py index a13c92fcb..c67819b36 100644 --- a/tests/test_layer.py +++ b/tests/test_layer.py @@ -691,6 +691,11 @@ def test_attention(self, host_past_key_value_lengths = torch.tensor([0] * batch_size, dtype=torch.int32) + # the max kv cache length for each layer. + # single tensor since we only have 1 layer here. + host_max_kv_cache_lengths = torch.tensor([max_seq_len], + dtype=torch.int32) + sequence_length = torch.full([batch_size], seq_len, dtype=torch.int32, @@ -762,6 +767,10 @@ def test_attention(self, name='host_past_key_value_lengths', shape=tuple(host_past_key_value_lengths.shape), dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_max_kv_cache_lengths_tensor = Tensor( + name='host_max_kv_cache_lengths', + shape=tuple(host_max_kv_cache_lengths.shape), + dtype=tensorrt_llm.str_dtype_to_trt('int32')) cache_indirection_tensor = Tensor( name='cache_indirection', shape=tuple(cache_indirection.shape), @@ -792,6 +801,8 @@ def test_attention(self, past_key_value=[past_key_value_tensor], host_past_key_value_lengths= host_past_key_value_lengths_tensor, + host_max_kv_cache_lengths= + host_max_kv_cache_lengths_tensor, cache_indirection=cache_indirection_tensor), attention_params=AttentionParams( sequence_length=sequence_length_tensor, @@ -820,6 +831,7 @@ def test_attention(self, 'past_key_value': past_key_value, 'sequence_length': sequence_length, 'host_past_key_value_lengths': host_past_key_value_lengths, + 'host_max_kv_cache_lengths': host_max_kv_cache_lengths, 'context_lengths': context_lengths, 'host_request_types': host_request_types, 'cache_indirection': cache_indirection From a995c1d4066403779b17dc893cd80ee96e6bb13d Mon Sep 17 00:00:00 2001 From: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> Date: Fri, 10 Nov 2023 05:34:26 -0800 Subject: [PATCH 2/2] update --- .../aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a | 3 +++ .../libtensorrt_llm_batch_manager_static.pre_cxx11.a | 3 +++ cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt | 3 +++ 3 files changed, 9 insertions(+) create mode 100644 cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a create mode 100644 cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a create mode 100644 cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a new file mode 100644 index 000000000..d62103f67 --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a3ec9a8760d7b8ace53e420572aeb1b3607effc92fd56e13351fa4cbddbbb37 +size 1646420 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a new file mode 100644 index 000000000..fd17c5bc6 --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:114348de9f6d1b3fa147f4fbccede10b7dbe13da6c5c86e968bb56bf05f9ec5a +size 1657852 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt new file mode 100644 index 000000000..f381e5401 --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt @@ -0,0 +1,3 @@ +0776a4d41c06192c4ca0409ad8b837de libtensorrt_llm_batch_manager_static.a +c901725d5d278fd8d41f524f81fe5170 libtensorrt_llm_batch_manager_static.pre_cxx11.a +b3330c65d9b23d4f20c2b8d5a7c24cd45c910cd4 commit