-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
open source 4dbf696ae9b74a26829d120b67ab8443d70c8e58 (#2297)
* Update TensorRT-LLM --------- Co-authored-by: Bhuvanesh Sridharan <bhuvanesh.sridharan@sprinklr.com> Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
- Loading branch information
1 parent
48686bc
commit 8681b3a
Showing
205 changed files
with
5,553 additions
and
1,792 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
187 changes: 187 additions & 0 deletions
187
cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
/* | ||
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "common.h" | ||
#include "tensorrt_llm/batch_manager/llmRequest.h" | ||
#include "tensorrt_llm/common/algorithm.h" | ||
#include "tensorrt_llm/runtime/common.h" | ||
#include <variant> | ||
|
||
namespace tensorrt_llm::batch_manager | ||
{ | ||
namespace kv_cache_manager | ||
{ | ||
class KVCacheManager; | ||
} | ||
class BasePeftCacheManager; | ||
} // namespace tensorrt_llm::batch_manager | ||
|
||
namespace tensorrt_llm::batch_manager | ||
{ | ||
|
||
using tensorrt_llm::runtime::SizeType32; | ||
|
||
/// @brief This scheduler takes into account the given request capacity and the KV cache capacity. | ||
/// Depending on the CapacitySchedulerPolicy it will schedule already started and new requests, | ||
/// or even pause previously started requests. | ||
class BaseCapacityScheduler | ||
{ | ||
public: | ||
explicit BaseCapacityScheduler(LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState) | ||
: mNoScheduleUntilState(noScheduleUntilState) | ||
, mNoScheduleAfterState(noScheduleAfterState) | ||
{ | ||
} | ||
|
||
[[nodiscard]] LlmRequestState constexpr getNoScheduleUntilState() const noexcept | ||
{ | ||
return mNoScheduleUntilState; | ||
} | ||
|
||
[[nodiscard]] LlmRequestState constexpr getNoScheduleAfterState() const noexcept | ||
{ | ||
return mNoScheduleAfterState; | ||
} | ||
|
||
private: | ||
/// The state until/after which the scheduler should not schedule requests | ||
LlmRequestState mNoScheduleUntilState; | ||
LlmRequestState mNoScheduleAfterState; | ||
}; | ||
|
||
/// @brief Schedule up to maxNumRequests requests
class MaxRequestsScheduler : public BaseCapacityScheduler
{
public:
    /// @param maxNumRequests Upper bound on the number of requests scheduled per iteration.
    /// @param kvCacheManager Self-attention KV cache manager.
    /// @param crossKvCacheManager Cross-attention KV cache manager (presumably only set for
    ///        encoder-decoder models — TODO confirm against callers).
    /// @param noScheduleUntilState Requests before this state are not scheduled.
    /// @param noScheduleAfterState Requests at or after this state are not scheduled.
    explicit MaxRequestsScheduler(SizeType32 maxNumRequests,
        std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
        std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    /// @brief Takes as input a sorted list of requests and outputs a sorted lists of requests
    /// to update for this current iteration, and a map of requests to pause
    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;

private:
    // Maximum number of requests to schedule in a single iteration.
    SizeType32 mMaxNumRequests;
    std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
    std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
};
|
||
/// @brief Schedule requests using the MAX_UTILIZATION policy
/// @details Try reserving resources to advance requests by one step,
/// may pause previously started requests.
class MaxUtilizationScheduler : public BaseCapacityScheduler
{
public:
    /// @param maxNumRequests Upper bound on the number of requests scheduled per iteration.
    /// @param kvCacheManager Self-attention KV cache manager.
    /// @param crossKvCacheManager Cross-attention KV cache manager (presumably only set for
    ///        encoder-decoder models — TODO confirm against callers).
    /// @param peftCacheManager PEFT (e.g. LoRA) cache manager.
    /// @param manyMicroBatches True if multiple micro batches might be in flight.
    /// @param noScheduleUntilState Requests before this state are not scheduled.
    /// @param noScheduleAfterState Requests at or after this state are not scheduled.
    MaxUtilizationScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
        std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
        std::shared_ptr<BasePeftCacheManager> peftCacheManager, bool manyMicroBatches,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    /// @brief Takes a sorted list of active requests and returns the requests to update
    /// this iteration and the requests to pause (implementation in the .cpp).
    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;

private:
    /// @brief Tries to schedule one request, accumulating the scheduled KV cache blocks and
    /// PEFT pages into the in-out counters and tracking already-seen PEFT task ids.
    /// @return {fitsKvCache, fitsPeft}
    std::pair<bool, bool> trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req,
        RequestVector& scheduledRequests, SizeType32& numScheduledBlocks, SizeType32& numScheduledPeftPages,
        std::unordered_set<uint64_t>& seenTaskIds) const;

    // Maximum number of requests to schedule in a single iteration.
    SizeType32 mMaxNumRequests;
    std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
    std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
    std::shared_ptr<BasePeftCacheManager> mPeftCacheManager{nullptr};
    /// @brief Boolean that indicates if multiple micro batches might be in flight
    bool mManyMicroBatches;
};
|
||
/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
class GuaranteedNoEvictScheduler : public BaseCapacityScheduler
{
public:
    /// @param maxNumRequests Upper bound on the number of requests scheduled per iteration.
    /// @param kvCacheManager Self-attention KV cache manager.
    /// @param crossKvCacheManager Cross-attention KV cache manager (presumably only set for
    ///        encoder-decoder models — TODO confirm against callers).
    /// @param peftCacheManager PEFT (e.g. LoRA) cache manager.
    /// @param noScheduleUntilState Requests before this state are not scheduled.
    /// @param noScheduleAfterState Requests at or after this state are not scheduled.
    GuaranteedNoEvictScheduler(SizeType32 maxNumRequests,
        std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
        std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
        std::shared_ptr<BasePeftCacheManager> peftCacheManager,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    /// @brief Takes a sorted list of active requests and returns the requests to update
    /// this iteration and the requests to pause (implementation in the .cpp).
    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;

protected:
    /// @brief Shared implementation; staticBatchScheduling toggles the STATIC_BATCH
    /// behavior reused by the derived StaticBatchScheduler.
    [[nodiscard]] std::tuple<RequestVector, RequestVector> forwardImpl(
        RequestList const& activeRequests, bool staticBatchScheduling) const;

private:
    // Maximum number of requests to schedule in a single iteration.
    SizeType32 mMaxNumRequests;
    std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
    std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
    std::shared_ptr<BasePeftCacheManager> mPeftCacheManager{nullptr};
};
|
||
/// @brief Schedule requests using the STATIC_BATCH policy
/// @details Reuses GuaranteedNoEvictScheduler's machinery (see forwardImpl's
/// staticBatchScheduling flag in the base class).
class StaticBatchScheduler : public GuaranteedNoEvictScheduler
{
public:
    /// @param maxNumRequests Upper bound on the number of requests scheduled per iteration.
    /// @param kvCacheManager Self-attention KV cache manager.
    /// @param crossKvCacheManager Cross-attention KV cache manager (may be unset — see base class).
    /// @param peftCacheManager PEFT (e.g. LoRA) cache manager.
    /// @param noScheduleUntilState Requests before this state are not scheduled.
    /// @param noScheduleAfterState Requests at or after this state are not scheduled.
    StaticBatchScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
        std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
        std::shared_ptr<BasePeftCacheManager> peftCacheManager,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    /// @brief Takes a sorted list of active requests and returns the requests to update
    /// this iteration and the requests to pause (implementation in the .cpp).
    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
};
|
||
/// @brief Polymorphic capacity scheduler that owns one of the concrete policy
/// implementations in a std::variant and forwards scheduling calls to it.
/// A default-constructed instance holds std::monostate (no policy selected).
class CapacityScheduler : public Algorithm
{
public:
    constexpr static auto name{"CapacityScheduler"};

    CapacityScheduler() = default;

    /// @param capacitySchedulerPolicy Selects which concrete scheduler variant is created.
    /// @param manyMicroBatches True if multiple micro batches might be in flight
    ///        (forwarded to MaxUtilizationScheduler).
    CapacityScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
        std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
        std::shared_ptr<BasePeftCacheManager> peftCacheManager,
        executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool manyMicroBatches = false,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    /// @brief Convenience factory; forwards all arguments to the constructor above.
    static CapacityScheduler make(SizeType32 maxNumRequests,
        std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
        std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
        std::shared_ptr<BasePeftCacheManager> peftCacheManager,
        executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool manyMicroBatches = false,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE)
    {
        return CapacityScheduler{maxNumRequests, std::move(kvCacheManager), std::move(crossKvCacheManager),
            std::move(peftCacheManager), capacitySchedulerPolicy, manyMicroBatches, noScheduleUntilState,
            noScheduleAfterState};
    }

    /// @brief Takes a sorted list of active requests and returns the requests to update
    /// this iteration and the requests to pause (implementation in the .cpp).
    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;

private:
    // The active policy implementation; std::monostate until a policy is chosen.
    std::variant<std::monostate, MaxRequestsScheduler, MaxUtilizationScheduler, GuaranteedNoEvictScheduler,
        StaticBatchScheduler>
        mScheduler;
};
|
||
} // namespace tensorrt_llm::batch_manager |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
/* | ||
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "tensorrt_llm/runtime/common.h" | ||
#include <cstdint> | ||
#include <list> | ||
#include <memory> | ||
#include <unordered_set> | ||
#include <utility> | ||
#include <vector> | ||
|
||
namespace tensorrt_llm::executor | ||
{ | ||
class RequestWithId; | ||
} | ||
|
||
namespace tensorrt_llm::batch_manager | ||
{ | ||
class LlmRequest; | ||
|
||
using RequestList = std::list<std::shared_ptr<LlmRequest>>; | ||
using RequestIdType = std::uint64_t; | ||
using RequestVector = std::vector<std::shared_ptr<LlmRequest>>; | ||
using ReqIdsSet = std::unordered_set<RequestIdType>; | ||
|
||
class ScheduledRequests | ||
{ | ||
public: | ||
/// @brief context phase requests (for decoder-only models) or encoder phase requests (for encoder-decoder models | ||
/// and encoder-only models) | ||
RequestVector contextRequests; | ||
|
||
/// @brief generation phase requests (for decoder-only models) or empty for others | ||
RequestVector generationRequests; | ||
|
||
ScheduledRequests() = default; | ||
|
||
explicit ScheduledRequests(RequestVector contextRequests, RequestVector generationRequests) | ||
: contextRequests{std::move(contextRequests)} | ||
, generationRequests{std::move(generationRequests)} | ||
{ | ||
} | ||
|
||
[[nodiscard]] bool empty() const | ||
{ | ||
return contextRequests.empty() && generationRequests.empty(); | ||
} | ||
|
||
[[nodiscard]] std::size_t size() const | ||
{ | ||
return contextRequests.size() + generationRequests.size(); | ||
} | ||
}; | ||
|
||
class BatchState | ||
{ | ||
public: | ||
BatchState() = default; | ||
|
||
BatchState(runtime::SizeType32 numCtxRequests, runtime::SizeType32 numGenRequests, runtime::SizeType32 numTokens, | ||
runtime::SizeType32 maxKvCacheLength) | ||
: mNumCtxRequests{numCtxRequests} | ||
, mNumGenRequests{numGenRequests} | ||
, mNumTokens{numTokens} | ||
, mMaxKvCacheLength{maxKvCacheLength} | ||
{ | ||
} | ||
|
||
bool isAnyContext() const | ||
{ | ||
return mNumCtxRequests > 0; | ||
} | ||
|
||
bool operator==(BatchState const& other) const | ||
{ | ||
return mNumCtxRequests == other.mNumCtxRequests && mNumGenRequests == other.mNumGenRequests | ||
&& mNumTokens == other.mNumTokens && mMaxKvCacheLength == other.mMaxKvCacheLength; | ||
} | ||
|
||
size_t hash() const | ||
{ | ||
size_t h1 = std::hash<runtime::SizeType32>{}(mNumCtxRequests); | ||
size_t h2 = std::hash<runtime::SizeType32>{}(mNumGenRequests); | ||
size_t h3 = std::hash<runtime::SizeType32>{}(mNumTokens); | ||
size_t h4 = std::hash<runtime::SizeType32>{}(mMaxKvCacheLength); | ||
return h1 ^ h2 ^ h3 ^ h4; | ||
} | ||
|
||
runtime::SizeType32 mNumCtxRequests; | ||
runtime::SizeType32 mNumGenRequests; | ||
runtime::SizeType32 mNumTokens; | ||
runtime::SizeType32 mMaxKvCacheLength; | ||
}; | ||
|
||
struct BatchStateHash | ||
{ | ||
size_t operator()(BatchState const& bs) const | ||
{ | ||
return bs.hash(); | ||
} | ||
}; | ||
|
||
} // namespace tensorrt_llm::batch_manager |
Oops, something went wrong.