Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Avoid ref capture in prefix cache construction #2391

Merged
merged 1 commit into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,13 @@ class EngineImpl : public Engine {
}
EngineConfig engine_config = engine_config_res.Unwrap();
{
EngineState estate = n->estate_;
Array<Model> models = n->models_;
if (engine_config->prefix_cache_mode == PrefixCacheMode::kRadix) {
n->estate_->prefix_cache = PrefixCache::CreateRadixPrefixCache(
static_cast<size_t>(engine_config->prefix_cache_max_num_recycling_seqs),
std::function<void(int64_t)>([estate, models](int64_t seq_id) {
RemoveRequestFromModel(estate, seq_id, models);
estate->id_manager.RecycleId(seq_id);
}));
[engine_ptr = n.get()](int64_t seq_id) {
RemoveRequestFromModel(engine_ptr->estate_, seq_id, engine_ptr->models_);
engine_ptr->estate_->id_manager.RecycleId(seq_id);
});
} else if (engine_config->prefix_cache_mode == PrefixCacheMode::kDisable) {
n->estate_->prefix_cache = PrefixCache::CreateNoPrefixCache();
} else {
Expand Down
16 changes: 8 additions & 8 deletions cpp/serve/prefix_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ using namespace tvm::runtime;
class PrefixCacheImpl : public PrefixCacheObj {
public:
/*!
* \brief Contructor of paged radix tree.
* \brief Constructor of paged radix tree.
* \param max_num_recycling_seqs The maximum number of sequences in prefix cache.
* \param remove_callback The optional callback function to call when removing a sequence.
*/
explicit PrefixCacheImpl(size_t max_num_recycling_seqs, PrefixCacheRemoveCallback remove_callback)
: radix_tree_(PagedRadixTree::Create()),
max_num_recycling_seqs_(max_num_recycling_seqs),
remove_callback_(remove_callback) {
remove_callback_(std::move(remove_callback)) {
recycling_seq_lrus_.clear();
reversed_recycling_seq_lrus_.clear();
seq_states_.clear();
Expand Down Expand Up @@ -64,7 +64,7 @@ class PrefixCacheImpl : public PrefixCacheObj {
// The reusage of recycling sequences logic is different between with/without sliding window
// enabled.
if (sliding_window_size != -1) {
// If sliding window enabled, the reusage of recycling sequences should be limitted to exactly
// If sliding window enabled, the reusage of recycling sequences should be limited to exactly
// matched. And no rolling back is allowed due to the sliding window.
for (int64_t matched_seq_id : matched_seqs) {
if (seq_states_.at(matched_seq_id) == SequenceState::kRecycling &&
Expand Down Expand Up @@ -105,7 +105,7 @@ class PrefixCacheImpl : public PrefixCacheObj {
shortest_recycling_seq_length - matched_offset};
}
// No reusage of recycling sequence, fallback to forking matched sequence. Currently, we only
// fork from sequence without sliding window, due to current paged KVCache implmentation.
// fork from sequence without sliding window, due to current paged KVCache implementation.
size_t longest_forking_offset = 0;
int64_t longest_forking_seq_id = -1;
for (int64_t matched_seq_id : matched_seqs) {
Expand Down Expand Up @@ -137,7 +137,7 @@ class PrefixCacheImpl : public PrefixCacheObj {

/*!
* \brief Extend a sequence with new tokenized sequence suffix.
* \param seq_id The sequence to be extneded.
* \param seq_id The sequence to be extended.
* \param tokens The tokens of tokenized sequence suffix to extend.
* \throw Error if the given sequence id is not valid or active.
*/
Expand Down Expand Up @@ -271,7 +271,7 @@ class PrefixCacheImpl : public PrefixCacheObj {
*/
std::unordered_map<int64_t, size_t> recycling_seq_lrus_;
/*!
* \brief The map from LRU time stamps to sequence, used to find the sequence with earlist LRU
* \brief The map from LRU time stamps to sequence, used to find the sequence with earliest LRU
* time stamp.
*/
std::unordered_map<size_t, int64_t> reversed_recycling_seq_lrus_;
Expand Down Expand Up @@ -326,7 +326,7 @@ class NoPrefixCache : public PrefixCacheObj {

/*!
* \brief Extend a sequence with new tokenized sequence suffix.
* \param seq_id The sequence to be extneded.
* \param seq_id The sequence to be extended.
* \param tokens The tokens of tokenized sequence suffix to extend.
* \throw Error if called since this should never be called.
*/
Expand Down Expand Up @@ -390,7 +390,7 @@ TVM_REGISTER_OBJECT_TYPE(NoPrefixCache);
PrefixCache PrefixCache::CreateRadixPrefixCache(size_t max_num_recycling_seqs,
PrefixCacheRemoveCallback remove_callback) {
ObjectPtr<PrefixCacheImpl> n =
make_object<PrefixCacheImpl>(max_num_recycling_seqs, remove_callback);
make_object<PrefixCacheImpl>(max_num_recycling_seqs, std::move(remove_callback));
return PrefixCache(std::move(n));
}

Expand Down
Loading