From e6a78e0c849a44136775f4a46027874149a8f81b Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 23 May 2024 11:21:56 -0400 Subject: [PATCH] [Fix] Avoid ref capture in prefix cache construction This PR fixes the prefix cache construction in Engine, which captured the references of models and thus prevented the GPU memory from being freed when the Engine is destructed. --- cpp/serve/engine.cc | 10 ++++------ cpp/serve/prefix_cache.cc | 16 ++++++++-------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index b3c9c29d22..9c721c0813 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -99,15 +99,13 @@ class EngineImpl : public Engine { } EngineConfig engine_config = engine_config_res.Unwrap(); { - EngineState estate = n->estate_; - Array models = n->models_; if (engine_config->prefix_cache_mode == PrefixCacheMode::kRadix) { n->estate_->prefix_cache = PrefixCache::CreateRadixPrefixCache( static_cast(engine_config->prefix_cache_max_num_recycling_seqs), - std::function([estate, models](int64_t seq_id) { - RemoveRequestFromModel(estate, seq_id, models); - estate->id_manager.RecycleId(seq_id); - })); + [engine_ptr = n.get()](int64_t seq_id) { + RemoveRequestFromModel(engine_ptr->estate_, seq_id, engine_ptr->models_); + engine_ptr->estate_->id_manager.RecycleId(seq_id); + }); } else if (engine_config->prefix_cache_mode == PrefixCacheMode::kDisable) { n->estate_->prefix_cache = PrefixCache::CreateNoPrefixCache(); } else { diff --git a/cpp/serve/prefix_cache.cc b/cpp/serve/prefix_cache.cc index 3362a0dbaf..bb067942a8 100644 --- a/cpp/serve/prefix_cache.cc +++ b/cpp/serve/prefix_cache.cc @@ -18,14 +18,14 @@ using namespace tvm::runtime; class PrefixCacheImpl : public PrefixCacheObj { public: /*! - * \brief Contructor of paged radix tree. + * \brief Constructor of paged radix tree. * \param max_num_recycling_seqs The maximum number of sequences in prefix cache.
* \param remove_callback The optional callback function to call when removing a sequence. */ explicit PrefixCacheImpl(size_t max_num_recycling_seqs, PrefixCacheRemoveCallback remove_callback) : radix_tree_(PagedRadixTree::Create()), max_num_recycling_seqs_(max_num_recycling_seqs), - remove_callback_(remove_callback) { + remove_callback_(std::move(remove_callback)) { recycling_seq_lrus_.clear(); reversed_recycling_seq_lrus_.clear(); seq_states_.clear(); @@ -64,7 +64,7 @@ class PrefixCacheImpl : public PrefixCacheObj { // The reusage of recycling sequences logic is different between with/without sliding window // enabled. if (sliding_window_size != -1) { - // If sliding window enabled, the reusage of recycling sequences should be limitted to exactly + // If sliding window enabled, the reusage of recycling sequences should be limited to exactly // matched. And no rolling back is allowed due to the sliding window. for (int64_t matched_seq_id : matched_seqs) { if (seq_states_.at(matched_seq_id) == SequenceState::kRecycling && @@ -105,7 +105,7 @@ class PrefixCacheImpl : public PrefixCacheObj { shortest_recycling_seq_length - matched_offset}; } // No reusage of recycling sequence, fallback to forking matched sequence. Currently, we only - // fork from sequence without sliding window, due to current paged KVCache implmentation. + // fork from sequence without sliding window, due to current paged KVCache implementation. size_t longest_forking_offset = 0; int64_t longest_forking_seq_id = -1; for (int64_t matched_seq_id : matched_seqs) { @@ -137,7 +137,7 @@ class PrefixCacheImpl : public PrefixCacheObj { /*! * \brief Extend a sequence with new tokenized sequence suffix. - * \param seq_id The sequence to be extneded. + * \param seq_id The sequence to be extended. * \param tokens The tokens of tokenized sequence suffix to extend. * \throw Error if the given sequence id is not valid or active. 
*/ @@ -271,7 +271,7 @@ class PrefixCacheImpl : public PrefixCacheObj { */ std::unordered_map recycling_seq_lrus_; /*! - * \brief The map from LRU time stamps to sequence, used to find the sequence with earlist LRU + * \brief The map from LRU time stamps to sequence, used to find the sequence with earliest LRU * time stamp. */ std::unordered_map reversed_recycling_seq_lrus_; @@ -326,7 +326,7 @@ class NoPrefixCache : public PrefixCacheObj { /*! * \brief Extend a sequence with new tokenized sequence suffix. - * \param seq_id The sequence to be extneded. + * \param seq_id The sequence to be extended. * \param tokens The tokens of tokenized sequence suffix to extend. * \throw Error if called since this should never be called. */ @@ -390,7 +390,7 @@ TVM_REGISTER_OBJECT_TYPE(NoPrefixCache); PrefixCache PrefixCache::CreateRadixPrefixCache(size_t max_num_recycling_seqs, PrefixCacheRemoveCallback remove_callback) { ObjectPtr n = - make_object(max_num_recycling_seqs, remove_callback); + make_object(max_num_recycling_seqs, std::move(remove_callback)); return PrefixCache(std::move(n)); }