Skip to content

Commit

Permalink
[Fix] Avoid ref capture in prefix cache contruction (#2391)
Browse files Browse the repository at this point in the history
This PR fixes the prefix cache construction in Engine, which captured
the references of models and thus led to the GPU memory unable to
be freed when the Engine is destructed.
  • Loading branch information
MasterJH5574 authored May 23, 2024
1 parent fbe3b9e commit 9631cc3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
10 changes: 4 additions & 6 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,13 @@ class EngineImpl : public Engine {
}
EngineConfig engine_config = engine_config_res.Unwrap();
{
EngineState estate = n->estate_;
Array<Model> models = n->models_;
if (engine_config->prefix_cache_mode == PrefixCacheMode::kRadix) {
n->estate_->prefix_cache = PrefixCache::CreateRadixPrefixCache(
static_cast<size_t>(engine_config->prefix_cache_max_num_recycling_seqs),
std::function<void(int64_t)>([estate, models](int64_t seq_id) {
RemoveRequestFromModel(estate, seq_id, models);
estate->id_manager.RecycleId(seq_id);
}));
[engine_ptr = n.get()](int64_t seq_id) {
RemoveRequestFromModel(engine_ptr->estate_, seq_id, engine_ptr->models_);
engine_ptr->estate_->id_manager.RecycleId(seq_id);
});
} else if (engine_config->prefix_cache_mode == PrefixCacheMode::kDisable) {
n->estate_->prefix_cache = PrefixCache::CreateNoPrefixCache();
} else {
Expand Down
16 changes: 8 additions & 8 deletions cpp/serve/prefix_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ using namespace tvm::runtime;
class PrefixCacheImpl : public PrefixCacheObj {
public:
/*!
* \brief Contructor of paged radix tree.
* \brief Constructor of paged radix tree.
* \param max_num_recycling_seqs The maximum number of sequences in prefix cache.
* \param remove_callback The optional callback function to call when removing a sequence.
*/
explicit PrefixCacheImpl(size_t max_num_recycling_seqs, PrefixCacheRemoveCallback remove_callback)
: radix_tree_(PagedRadixTree::Create()),
max_num_recycling_seqs_(max_num_recycling_seqs),
remove_callback_(remove_callback) {
remove_callback_(std::move(remove_callback)) {
recycling_seq_lrus_.clear();
reversed_recycling_seq_lrus_.clear();
seq_states_.clear();
Expand Down Expand Up @@ -64,7 +64,7 @@ class PrefixCacheImpl : public PrefixCacheObj {
// The reusage of recycling sequences logic is different between with/without sliding window
// enabled.
if (sliding_window_size != -1) {
// If sliding window enabled, the reusage of recycling sequences should be limitted to exactly
// If sliding window enabled, the reusage of recycling sequences should be limited to exactly
// matched. And no rolling back is allowed due to the sliding window.
for (int64_t matched_seq_id : matched_seqs) {
if (seq_states_.at(matched_seq_id) == SequenceState::kRecycling &&
Expand Down Expand Up @@ -105,7 +105,7 @@ class PrefixCacheImpl : public PrefixCacheObj {
shortest_recycling_seq_length - matched_offset};
}
// No reusage of recycling sequence, fallback to forking matched sequence. Currently, we only
// fork from sequence without sliding window, due to current paged KVCache implmentation.
// fork from sequence without sliding window, due to current paged KVCache implementation.
size_t longest_forking_offset = 0;
int64_t longest_forking_seq_id = -1;
for (int64_t matched_seq_id : matched_seqs) {
Expand Down Expand Up @@ -137,7 +137,7 @@ class PrefixCacheImpl : public PrefixCacheObj {

/*!
* \brief Extend a sequence with new tokenized sequence suffix.
* \param seq_id The sequence to be extneded.
* \param seq_id The sequence to be extended.
* \param tokens The tokens of tokenized sequence suffix to extend.
* \throw Error if the given sequence id is not valid or active.
*/
Expand Down Expand Up @@ -271,7 +271,7 @@ class PrefixCacheImpl : public PrefixCacheObj {
*/
std::unordered_map<int64_t, size_t> recycling_seq_lrus_;
/*!
* \brief The map from LRU time stamps to sequence, used to find the sequence with earlist LRU
* \brief The map from LRU time stamps to sequence, used to find the sequence with earliest LRU
* time stamp.
*/
std::unordered_map<size_t, int64_t> reversed_recycling_seq_lrus_;
Expand Down Expand Up @@ -326,7 +326,7 @@ class NoPrefixCache : public PrefixCacheObj {

/*!
* \brief Extend a sequence with new tokenized sequence suffix.
* \param seq_id The sequence to be extneded.
* \param seq_id The sequence to be extended.
* \param tokens The tokens of tokenized sequence suffix to extend.
* \throw Error if called since this should never be called.
*/
Expand Down Expand Up @@ -390,7 +390,7 @@ TVM_REGISTER_OBJECT_TYPE(NoPrefixCache);
PrefixCache PrefixCache::CreateRadixPrefixCache(size_t max_num_recycling_seqs,
PrefixCacheRemoveCallback remove_callback) {
ObjectPtr<PrefixCacheImpl> n =
make_object<PrefixCacheImpl>(max_num_recycling_seqs, remove_callback);
make_object<PrefixCacheImpl>(max_num_recycling_seqs, std::move(remove_callback));
return PrefixCache(std::move(n));
}

Expand Down

0 comments on commit 9631cc3

Please sign in to comment.