From 0e3c8fbc78277cfacde2d9da4be801b5c361dcaf Mon Sep 17 00:00:00 2001 From: jothomson Date: Fri, 16 May 2025 14:03:07 -0700 Subject: [PATCH 1/6] Restructure block registration --- lib/llm/src/block_manager/block.rs | 8 ++-- lib/llm/src/block_manager/block/registry.rs | 49 +++++++++++++++----- lib/llm/src/block_manager/block/state.rs | 14 +++--- lib/llm/src/block_manager/offload.rs | 8 ++-- lib/llm/src/block_manager/offload/pending.rs | 2 +- lib/llm/src/block_manager/pool.rs | 26 ++++++++--- lib/llm/src/block_manager/pool/inactive.rs | 11 +++-- lib/llm/src/block_manager/pool/state.rs | 17 ++++--- lib/llm/src/block_manager/state.rs | 11 ++++- 9 files changed, 101 insertions(+), 45 deletions(-) diff --git a/lib/llm/src/block_manager/block.rs b/lib/llm/src/block_manager/block.rs index 479e51f129..9dd2881f11 100644 --- a/lib/llm/src/block_manager/block.rs +++ b/lib/llm/src/block_manager/block.rs @@ -21,6 +21,8 @@ pub mod view; pub use crate::tokens::TokenBlockError; pub use anyhow::Result; use nixl_sys::NixlDescriptor; + +pub use registry::{GlobalRegistry, RegistrationHandle}; pub use state::{BlockState, BlockStateInvalid}; use crate::block_manager::{ @@ -176,7 +178,7 @@ impl Block { pub fn sequence_hash(&self) -> Result { match self.state() { BlockState::Complete(state) => Ok(state.token_block().sequence_hash()), - BlockState::Registered(state) => Ok(state.sequence_hash()), + BlockState::Registered(state, _) => Ok(state.sequence_hash()), _ => Err(BlockError::InvalidState( "Block is not complete".to_string(), )), @@ -250,14 +252,14 @@ pub(crate) trait PrivateBlockExt { fn register( &mut self, registry: &mut registry::BlockRegistry, - ) -> Result; + ) -> Result, registry::BlockRegistationError>; } impl PrivateBlockExt for Block { fn register( &mut self, registry: &mut registry::BlockRegistry, - ) -> Result { + ) -> Result, registry::BlockRegistationError> { registry.register_block(&mut self.state) } } diff --git a/lib/llm/src/block_manager/block/registry.rs 
b/lib/llm/src/block_manager/block/registry.rs index 1e0d9619ae..652d86bc52 100644 --- a/lib/llm/src/block_manager/block/registry.rs +++ b/lib/llm/src/block_manager/block/registry.rs @@ -15,7 +15,7 @@ use std::{ collections::HashMap, - sync::{Arc, Weak}, + sync::{Arc, Mutex, Weak}, }; use super::super::events::{EventManager, EventReleaseManager, PublishHandle}; @@ -39,17 +39,21 @@ pub enum BlockRegistationError { #[error("Failed to unregister block: {0}")] pub struct UnregisterFailure(SequenceHash); +pub type GlobalRegistry = Arc>>>; + #[derive()] pub struct BlockRegistry { - blocks: HashMap>, + blocks: HashMap>, event_manager: Arc, + global_pool: GlobalRegistry, } impl BlockRegistry { - pub fn new(event_manager: Arc) -> Self { + pub fn new(event_manager: Arc, global_pool: GlobalRegistry) -> Self { Self { blocks: HashMap::new(), event_manager, + global_pool, } } @@ -65,7 +69,7 @@ impl BlockRegistry { pub fn register_block( &mut self, block_state: &mut BlockState, - ) -> Result { + ) -> Result, BlockRegistationError> { match block_state { BlockState::Reset => Err(BlockRegistationError::InvalidState( "Block is in Reset state".to_string(), @@ -82,21 +86,42 @@ impl BlockRegistry { } } - // Create the [RegistrationHandle] and [PublishHandle] - let publish_handle = - Self::create_publish_handle(state.token_block(), self.event_manager.clone()); - let reg_handle = publish_handle.remove_handle(); + let mut publish_handle = None; + + let block_handle = Arc::new(()); + + let reg_handle = 'reg_block: { + let mut global_pool = self.global_pool.lock().unwrap(); + + if let Some(handle) = global_pool.get(&sequence_hash) { + if let Some(handle) = handle.upgrade() { + break 'reg_block handle; + } + } + + publish_handle = Some(Self::create_publish_handle( + state.token_block(), + self.event_manager.clone(), + )); + let reg_handle = publish_handle.as_ref().unwrap().remove_handle(); + + global_pool.insert(sequence_hash, Arc::downgrade(®_handle)); + + reg_handle + }; - // Insert the 
[RegistrationHandle] into the registry self.blocks - .insert(sequence_hash, Arc::downgrade(®_handle)); + .insert(sequence_hash, Arc::downgrade(&block_handle)); // Update the [BlockState] to [BlockState::Registered] - let _ = std::mem::replace(block_state, BlockState::Registered(reg_handle)); + let _ = std::mem::replace( + block_state, + BlockState::Registered(reg_handle, block_handle), + ); Ok(publish_handle) } - BlockState::Registered(registered) => Err( + BlockState::Registered(registered, _) => Err( BlockRegistationError::BlockAlreadyRegistered(registered.sequence_hash()), ), } diff --git a/lib/llm/src/block_manager/block/state.rs b/lib/llm/src/block_manager/block/state.rs index c22bb3952e..51dce7e128 100644 --- a/lib/llm/src/block_manager/block/state.rs +++ b/lib/llm/src/block_manager/block/state.rs @@ -30,7 +30,7 @@ pub enum BlockState { Reset, Partial(PartialState), Complete(CompleteState), - Registered(Arc), + Registered(Arc, Arc<()>), } impl BlockState { @@ -109,7 +109,7 @@ impl BlockState { BlockState::Reset => Some(0), BlockState::Partial(state) => Some(state.block.len()), BlockState::Complete(state) => Some(state.token_block.tokens().len()), - BlockState::Registered(_) => None, + BlockState::Registered(_, _) => None, } } @@ -126,15 +126,15 @@ impl BlockState { match self { BlockState::Reset => true, BlockState::Partial(state) => state.block.is_empty(), - BlockState::Complete(_) => false, // Always full - BlockState::Registered(_) => false, // Always full + BlockState::Complete(_) => false, // Always full + BlockState::Registered(_, _) => false, // Always full } } /// Returns a reference to the underlying TokenBlock if the state is Complete or Registered. 
pub fn tokens(&self) -> Option<&Tokens> { match self { - BlockState::Reset | BlockState::Registered(_) => None, + BlockState::Reset | BlockState::Registered(_, _) => None, BlockState::Partial(state) => Some(state.block.tokens()), BlockState::Complete(state) => Some(state.token_block.tokens()), } @@ -147,12 +147,12 @@ impl BlockState { /// Returns true if the block is in the complete or registered state pub fn is_complete(&self) -> bool { - matches!(self, BlockState::Complete(_) | BlockState::Registered(_)) + matches!(self, BlockState::Complete(_) | BlockState::Registered(_, _)) } /// Returns true if the block is in the registered state pub fn is_registered(&self) -> bool { - matches!(self, BlockState::Registered(_state)) + matches!(self, BlockState::Registered(_state, _)) } } diff --git a/lib/llm/src/block_manager/offload.rs b/lib/llm/src/block_manager/offload.rs index b57366cffc..15aa570594 100644 --- a/lib/llm/src/block_manager/offload.rs +++ b/lib/llm/src/block_manager/offload.rs @@ -334,7 +334,7 @@ impl OffloadManager { priority: u64, ) -> core::result::Result<(), BlockPoolError> { match block.state() { - BlockState::Registered(_) => {} + BlockState::Registered(_, _) => {} _ => { return Err(BlockPoolError::BlockError(BlockError::InvalidState( "Block is not registered.".to_string(), @@ -397,7 +397,7 @@ impl OffloadManager { ) -> BlockResult { for block in &blocks { match block.state() { - BlockState::Registered(_) => {} + BlockState::Registered(_, _) => {} _ => { return Err(BlockPoolError::BlockError(BlockError::InvalidState( "Block is not registered.".to_string(), @@ -857,7 +857,7 @@ mod tests { // Check that the block is registered. 
assert!(matches!( onboarded_blocks[0].state(), - BlockState::Registered(_) + BlockState::Registered(_, _) )); check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?; @@ -940,7 +940,7 @@ mod tests { ); assert!(matches!( onboarded_blocks[0].state(), - BlockState::Registered(_) + BlockState::Registered(_, _) )); check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?; diff --git a/lib/llm/src/block_manager/offload/pending.rs b/lib/llm/src/block_manager/offload/pending.rs index ba2792d056..41bc60c3a0 100644 --- a/lib/llm/src/block_manager/offload/pending.rs +++ b/lib/llm/src/block_manager/offload/pending.rs @@ -118,7 +118,7 @@ fn transfer_metadata( target: &mut MutableBlock, ) -> Result<()> { // Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail. - if let BlockState::Registered(reg_handle) = source.state() { + if let BlockState::Registered(reg_handle, _) = source.state() { // Bring the block back to the 'Reset' state. target.reset(); // Transfer metadata. 
diff --git a/lib/llm/src/block_manager/pool.rs b/lib/llm/src/block_manager/pool.rs index 7f3340237f..bfd488a90b 100644 --- a/lib/llm/src/block_manager/pool.rs +++ b/lib/llm/src/block_manager/pool.rs @@ -70,6 +70,7 @@ pub use super::block::{ImmutableBlock, MutableBlock}; use super::block::{ nixl::short_type_name, registry::BlockRegistry, Block, BlockError, BlockMetadata, + GlobalRegistry, }; use super::events::{EventManager, NullEventManager}; use super::storage::Storage; @@ -116,15 +117,18 @@ pub struct BlockPoolArgs { #[builder(default)] blocks: Vec>, + + #[builder(default)] + global_pool: GlobalRegistry, } impl BlockPoolArgsBuilder { pub fn build(self) -> anyhow::Result> { let args = self.build_internal()?; - let (event_manager, cancel_token, blocks) = args.dissolve(); + let (event_manager, cancel_token, blocks, global_pool) = args.dissolve(); tracing::info!("building block pool"); - let pool = BlockPool::new(event_manager, cancel_token, blocks); + let pool = BlockPool::new(event_manager, cancel_token, blocks, global_pool); Ok(pool) } @@ -200,9 +204,10 @@ impl BlockPool { event_manager: Arc, cancel_token: CancellationToken, blocks: Vec>, + global_pool: GlobalRegistry, ) -> Self { let (pool, progress_engine) = - Self::with_progress_engine(event_manager, cancel_token, blocks); + Self::with_progress_engine(event_manager, cancel_token, blocks, global_pool); // pool.runtime.handle().spawn(async move { // let mut progress_engine = progress_engine; @@ -239,12 +244,19 @@ impl BlockPool { event_manager: Arc, cancel_token: CancellationToken, blocks: Vec>, + global_pool: GlobalRegistry, ) -> (Self, ProgressEngine) { let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel(); let (ctrl_tx, ctrl_rx) = tokio::sync::mpsc::unbounded_channel(); - let progress_engine = - ProgressEngine::::new(event_manager, priority_rx, ctrl_rx, cancel_token, blocks); + let progress_engine = ProgressEngine::::new( + event_manager, + priority_rx, + ctrl_rx, + cancel_token, + blocks, + 
global_pool, + ); ( Self { @@ -468,9 +480,9 @@ mod tests { self, ) -> anyhow::Result<(BlockPool, ProgressEngine)> { let args = self.build_internal()?; - let (event_manager, cancel_token, blocks) = args.dissolve(); + let (event_manager, cancel_token, blocks, global_pool) = args.dissolve(); let (pool, progress_engine) = - BlockPool::with_progress_engine(event_manager, cancel_token, blocks); + BlockPool::with_progress_engine(event_manager, cancel_token, blocks, global_pool); Ok((pool, progress_engine)) } diff --git a/lib/llm/src/block_manager/pool/inactive.rs b/lib/llm/src/block_manager/pool/inactive.rs index 9c7e503a00..c4ed7fdf28 100644 --- a/lib/llm/src/block_manager/pool/inactive.rs +++ b/lib/llm/src/block_manager/pool/inactive.rs @@ -138,7 +138,7 @@ impl InactiveBlockPool { block.reset(); self.uninitialized_set.push_back(block); } - BlockState::Registered(state) => { + BlockState::Registered(state, _) => { let sequence_hash = state.sequence_hash(); self.insert_with_sequence_hash(block, sequence_hash); } @@ -499,6 +499,8 @@ pub(crate) mod tests { use super::*; + use std::sync::Mutex; + #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)] pub struct TestMetadata { priority: u32, @@ -615,7 +617,10 @@ pub(crate) mod tests { let mut blocks = create_block_collection(num_blocks).into_blocks().unwrap(); let event_manager = NullEventManager::new(); - let mut registry = BlockRegistry::new(event_manager); + let mut registry = BlockRegistry::new( + event_manager, + GlobalRegistry::new(Mutex::new(HashMap::new())), + ); // Iterate through the generated TokenBlocks and the template Blocks, // setting the state and registering each one. 
@@ -657,7 +662,7 @@ pub(crate) mod tests { let matched_block_count = matched_blocks.len(); let event_manager = NullEventManager::new(); - let mut registry = BlockRegistry::new(event_manager); + let mut registry = BlockRegistry::new(event_manager, Arc::new(Mutex::new(HashMap::new()))); // all matched blocks should be in the complete or registered state for block in &mut matched_blocks { diff --git a/lib/llm/src/block_manager/pool/state.rs b/lib/llm/src/block_manager/pool/state.rs index 91bd19409b..6c6a0ae38a 100644 --- a/lib/llm/src/block_manager/pool/state.rs +++ b/lib/llm/src/block_manager/pool/state.rs @@ -24,11 +24,12 @@ impl State { fn new( event_manager: Arc, return_tx: tokio::sync::mpsc::UnboundedSender>, + global_pool: GlobalRegistry, ) -> Self { Self { active: ActiveBlockPool::new(), inactive: InactiveBlockPool::new(), - registry: BlockRegistry::new(event_manager.clone()), + registry: BlockRegistry::new(event_manager.clone(), global_pool), return_tx, event_manager, } @@ -88,7 +89,7 @@ impl State { return_rx: &mut tokio::sync::mpsc::UnboundedReceiver>, ) -> Block { while let Some(block) = return_rx.recv().await { - if matches!(block.state(), BlockState::Registered(handle) if handle.sequence_hash() == sequence_hash) + if matches!(block.state(), BlockState::Registered(handle, _) if handle.sequence_hash() == sequence_hash) { return block; } @@ -151,7 +152,7 @@ impl State { let mutable = if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash) { - assert!(matches!(raw_block.state(), BlockState::Registered(_))); + assert!(matches!(raw_block.state(), BlockState::Registered(_, _))); MutableBlock::new(raw_block, self.return_tx.clone()) } else { // Attempt to register the block @@ -161,7 +162,10 @@ impl State { match result { Ok(handle) => { - publish_handles.take_handle(handle); + // Only create our publish handle if this block is new, and not transferred.
+ if let Some(handle) = handle { + publish_handles.take_handle(handle); + } block } Err(BlockRegistationError::BlockAlreadyRegistered(_)) => { @@ -222,7 +226,7 @@ impl State { }; // this assert allows us to skip the error checking on the active pool registration step - assert!(matches!(raw_block.state(), BlockState::Registered(_))); + assert!(matches!(raw_block.state(), BlockState::Registered(_, _))); let mutable = MutableBlock::new(raw_block, self.return_tx.clone()); @@ -255,9 +259,10 @@ impl ProgressEngine { ctrl_rx: tokio::sync::mpsc::UnboundedReceiver>, cancel_token: CancellationToken, blocks: Vec>, + global_pool: GlobalRegistry, ) -> Self { let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel(); - let mut state = State::::new(event_manager, return_tx); + let mut state = State::::new(event_manager, return_tx, global_pool); tracing::debug!(count = blocks.len(), "adding blocks to inactive pool"); state.inactive.add_blocks(blocks); diff --git a/lib/llm/src/block_manager/state.rs b/lib/llm/src/block_manager/state.rs index ea1920eef0..47cdc1109d 100644 --- a/lib/llm/src/block_manager/state.rs +++ b/lib/llm/src/block_manager/state.rs @@ -17,11 +17,11 @@ use super::*; use super::offload::OffloadManager; use super::{ - block::{Block, ImmutableBlock}, + block::{Block, GlobalRegistry, ImmutableBlock}, config::NixlOptions, }; use cudarc::driver::CudaStream; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use tokio::runtime::Handle; pub struct TransferContext { @@ -76,6 +76,8 @@ impl KvBlockManagerState { // Create a map of NIXL backends let mut nixl_backends: HashMap> = HashMap::new(); + let global_pool = Arc::new(Mutex::new(HashMap::new())); + // Create a NIXL agent if NIXL is enabled and instantiate requested backends // TODO: Build a map of NIXL backends to block pools/sets let nixl_agent = Arc::new(match config.runtime.nixl { @@ -138,6 +140,7 @@ impl KvBlockManagerState { next_block_set_idx, cancellation_token.clone(), worker_id, + global_pool.clone(), 
)?; (Some(Arc::new(pool)), Some(blocks)) } @@ -158,6 +161,7 @@ impl KvBlockManagerState { next_block_set_idx, cancellation_token.clone(), worker_id, + global_pool.clone(), )?; (Some(Arc::new(pool)), Some(blocks)) } else { @@ -177,6 +181,7 @@ impl KvBlockManagerState { next_block_set_idx, cancellation_token.clone(), worker_id, + global_pool.clone(), )?; (Some(Arc::new(pool)), Some(blocks)) } else { @@ -484,10 +489,12 @@ fn create_block_pool( block_set_idx: usize, cancellation_token: CancellationToken, worker_id: WorkerID, + global_pool: GlobalRegistry, ) -> Result<(BlockPool, Vec>)> { let blocks = block::layout_to_blocks::<_, M>(layout, block_set_idx, worker_id)?; let pool = BlockPool::::builder() .cancel_token(cancellation_token) + .global_pool(global_pool) .build()?; Ok((pool, blocks)) } From 730459d40c0519fcb056a13ff893ed0a814ee7e2 Mon Sep 17 00:00:00 2001 From: jothomson Date: Fri, 16 May 2025 14:20:09 -0700 Subject: [PATCH 2/6] docs --- lib/llm/src/block_manager/block/registry.rs | 24 +++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/lib/llm/src/block_manager/block/registry.rs b/lib/llm/src/block_manager/block/registry.rs index 652d86bc52..4a3499eaec 100644 --- a/lib/llm/src/block_manager/block/registry.rs +++ b/lib/llm/src/block_manager/block/registry.rs @@ -13,6 +13,24 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! # KV Cache Block Registration +//! +//! - This module is responsible for maintaining a registry of all blocks currently within a pool. +//! This consists of two components: A global registry of all blocks, and a per-pool registry of blocks. +//! - The global registry is a mapping of sequence hashes to registration handles. If two blocks in different pools +//! have the same sequence hash, then they will share the same registration handle. The global registry is shared across all pools. +//! - The per-pool registry is a mapping of sequence hashes to block handles.
This is used to track which blocks are +//! currently within a specific pool. The block handle is unique across pools, and is used to track the block's lifetime. +//! - When a block is in the registered state, it has a unique block handle and a possibly shared registration handle. +//! +//! ## Workflow +//! +//! 1. When a block is registered into a pool, we create a unique block handle. +//! 2. We then check the global registry to see if the block already exists in any other pool. +//! 3. If it does, we use the existing registration handle. Otherwise, we create a new one. +//! 4. When the block handle is dropped, it means that the block is no longer in the pool. +//! 5. When the registration handle is dropped, it means that the block is no longer in any pool. + use std::{ collections::HashMap, sync::{Arc, Mutex, Weak}, @@ -80,6 +98,7 @@ impl BlockRegistry { BlockState::Complete(state) => { let sequence_hash = state.token_block().sequence_hash(); + // If an identical block already exists in this pool, return an error. if let Some(handle) = self.blocks.get(&sequence_hash) { if let Some(_handle) = handle.upgrade() { return Err(BlockRegistationError::BlockAlreadyRegistered(sequence_hash)); @@ -91,25 +110,30 @@ impl BlockRegistry { let block_handle = Arc::new(()); let reg_handle = 'reg_block: { + // Now, check the global registry. let mut global_pool = self.global_pool.lock().unwrap(); + // If an identical block exists in other pool, use the same registration handle. if let Some(handle) = global_pool.get(&sequence_hash) { if let Some(handle) = handle.upgrade() { break 'reg_block handle; } } + // Otherwise, create a new registration handle. publish_handle = Some(Self::create_publish_handle( state.token_block(), self.event_manager.clone(), )); let reg_handle = publish_handle.as_ref().unwrap().remove_handle(); + // Insert the registration handle into the global registry. 
global_pool.insert(sequence_hash, Arc::downgrade(®_handle)); reg_handle }; + // Insert our block handle into the per-pool registry. self.blocks .insert(sequence_hash, Arc::downgrade(&block_handle)); From 2840aff7ec3df6b59fb7cb22ca9800b19b84af45 Mon Sep 17 00:00:00 2001 From: jothomson Date: Wed, 21 May 2025 17:00:05 -0700 Subject: [PATCH 3/6] Mild refactor + enable block deregistration from local and global pools --- lib/llm/src/block_manager/block/registry.rs | 118 ++++++++++++++------ lib/llm/src/block_manager/block/state.rs | 4 +- lib/llm/src/block_manager/pool.rs | 24 ++-- lib/llm/src/block_manager/pool/inactive.rs | 9 +- lib/llm/src/block_manager/pool/state.rs | 8 +- lib/llm/src/block_manager/state.rs | 14 +-- 6 files changed, 110 insertions(+), 67 deletions(-) diff --git a/lib/llm/src/block_manager/block/registry.rs b/lib/llm/src/block_manager/block/registry.rs index 4a3499eaec..a99a487f5f 100644 --- a/lib/llm/src/block_manager/block/registry.rs +++ b/lib/llm/src/block_manager/block/registry.rs @@ -42,6 +42,9 @@ use super::state::BlockState; use crate::tokens::{BlockHash, SequenceHash, TokenBlock}; use derive_getters::Getters; +use tokio::sync::mpsc; + +pub type GlobalRegistry = Arc>>>; #[derive(Debug, thiserror::Error)] pub enum BlockRegistationError { @@ -52,31 +55,78 @@ pub enum BlockRegistationError { InvalidState(String), } -/// Error returned when an attempt is made to unregister a block that is still active. -#[derive(Debug, thiserror::Error)] -#[error("Failed to unregister block: {0}")] -pub struct UnregisterFailure(SequenceHash); +/// A block entry is a handle to a block that is registered in the pool. +/// On drop, we need to notify the pool that the block has been unregistered. +/// This is different than the registration handle, which is only dropped when the block is no longer in ANY pool. 
+#[derive(Debug)] +pub struct BlockHandle { + sequence_hash: SequenceHash, + unregister_tx: mpsc::UnboundedSender, +} -pub type GlobalRegistry = Arc>>>; +impl BlockHandle { + pub fn new( + sequence_hash: SequenceHash, + unregister_tx: mpsc::UnboundedSender, + ) -> Self { + Self { + sequence_hash, + unregister_tx, + } + } +} + +impl Drop for BlockHandle { + fn drop(&mut self) { + let _ = self.unregister_tx.send(self.sequence_hash); + } +} -#[derive()] pub struct BlockRegistry { - blocks: HashMap>, + blocks: Arc>>>, event_manager: Arc, - global_pool: GlobalRegistry, + global_registry: GlobalRegistry, + unregister_tx: mpsc::UnboundedSender, } impl BlockRegistry { - pub fn new(event_manager: Arc, global_pool: GlobalRegistry) -> Self { + pub fn new(event_manager: Arc, global_registry: GlobalRegistry) -> Self { + let (unregister_tx, mut unregister_rx) = mpsc::unbounded_channel(); + + let blocks = Arc::new(Mutex::new(HashMap::new())); + + let blocks_clone = blocks.clone(); + let global_registry_clone = global_registry.clone(); + tokio::spawn(async move { + let blocks = blocks_clone; + let global_registry = global_registry_clone; + while let Some(sequence_hash) = unregister_rx.recv().await { + { + let mut blocks = blocks.lock().unwrap(); + blocks.remove(&sequence_hash); + } + + let mut global_registry = global_registry.lock().unwrap(); + + if let Some(entry) = global_registry.get(&sequence_hash) { + if entry.upgrade().is_none() { + global_registry.remove(&sequence_hash); + } + } + } + }); + Self { - blocks: HashMap::new(), + blocks, event_manager, - global_pool, + global_registry, + unregister_tx, } } pub fn is_registered(&self, sequence_hash: SequenceHash) -> bool { - if let Some(handle) = self.blocks.get(&sequence_hash) { + let blocks = self.blocks.lock().unwrap(); + if let Some(handle) = blocks.get(&sequence_hash) { if let Some(_handle) = handle.upgrade() { return true; } @@ -99,22 +149,29 @@ impl BlockRegistry { BlockState::Complete(state) => { let sequence_hash = 
state.token_block().sequence_hash(); // If an identical block already exists in this pool, return an error. - if let Some(handle) = self.blocks.get(&sequence_hash) { - if let Some(_handle) = handle.upgrade() { - return Err(BlockRegistationError::BlockAlreadyRegistered(sequence_hash)); + + { + let blocks = self.blocks.lock().unwrap(); + if let Some(handle) = blocks.get(&sequence_hash) { + if let Some(_handle) = handle.upgrade() { + return Err(BlockRegistationError::BlockAlreadyRegistered( + sequence_hash, + )); + } } } let mut publish_handle = None; - let block_handle = Arc::new(()); + let block_handle = + Arc::new(BlockHandle::new(sequence_hash, self.unregister_tx.clone())); let reg_handle = 'reg_block: { // Now, check the global registry. - let mut global_pool = self.global_pool.lock().unwrap(); + let mut global_registry = self.global_registry.lock().unwrap(); // If an identical block exists in other pool, use the same registration handle. - if let Some(handle) = global_pool.get(&sequence_hash) { + if let Some(handle) = global_registry.get(&sequence_hash) { if let Some(handle) = handle.upgrade() { break 'reg_block handle; } @@ -128,14 +185,16 @@ impl BlockRegistry { let reg_handle = publish_handle.as_ref().unwrap().remove_handle(); // Insert the registration handle into the global registry. - global_pool.insert(sequence_hash, Arc::downgrade(®_handle)); + global_registry.insert(sequence_hash, Arc::downgrade(®_handle)); reg_handle }; - // Insert our block handle into the per-pool registry. - self.blocks - .insert(sequence_hash, Arc::downgrade(&block_handle)); + { + let mut blocks = self.blocks.lock().unwrap(); + // Insert our block handle into the per-pool registry. 
+ blocks.insert(sequence_hash, Arc::downgrade(&block_handle)); + } // Update the [BlockState] to [BlockState::Registered] let _ = std::mem::replace( @@ -151,21 +210,6 @@ impl BlockRegistry { } } - pub fn unregister_block( - &mut self, - sequence_hash: SequenceHash, - ) -> Result<(), UnregisterFailure> { - if let Some(handle) = self.blocks.get(&sequence_hash) { - if handle.upgrade().is_none() { - self.blocks.remove(&sequence_hash); - return Ok(()); - } else { - return Err(UnregisterFailure(sequence_hash)); - } - } - Ok(()) - } - fn create_publish_handle( token_block: &TokenBlock, event_manager: Arc, diff --git a/lib/llm/src/block_manager/block/state.rs b/lib/llm/src/block_manager/block/state.rs index 51dce7e128..5eb6ff5bff 100644 --- a/lib/llm/src/block_manager/block/state.rs +++ b/lib/llm/src/block_manager/block/state.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use derive_getters::Getters; -use super::registry::RegistrationHandle; +use super::registry::{BlockHandle, RegistrationHandle}; use super::Result; use crate::tokens::{PartialTokenBlock, SaltHash, Token, TokenBlock, Tokens}; @@ -30,7 +30,7 @@ pub enum BlockState { Reset, Partial(PartialState), Complete(CompleteState), - Registered(Arc, Arc<()>), + Registered(Arc, Arc), } impl BlockState { diff --git a/lib/llm/src/block_manager/pool.rs b/lib/llm/src/block_manager/pool.rs index bfd488a90b..09fbe5af7b 100644 --- a/lib/llm/src/block_manager/pool.rs +++ b/lib/llm/src/block_manager/pool.rs @@ -119,16 +119,16 @@ pub struct BlockPoolArgs { blocks: Vec>, #[builder(default)] - global_pool: GlobalRegistry, + global_registry: GlobalRegistry, } impl BlockPoolArgsBuilder { pub fn build(self) -> anyhow::Result> { let args = self.build_internal()?; - let (event_manager, cancel_token, blocks, global_pool) = args.dissolve(); + let (event_manager, cancel_token, blocks, global_registry) = args.dissolve(); tracing::info!("building block pool"); - let pool = BlockPool::new(event_manager, cancel_token, blocks, global_pool); + let pool 
= BlockPool::new(event_manager, cancel_token, blocks, global_registry); Ok(pool) } @@ -204,10 +204,10 @@ impl BlockPool { event_manager: Arc, cancel_token: CancellationToken, blocks: Vec>, - global_pool: GlobalRegistry, + global_registry: GlobalRegistry, ) -> Self { let (pool, progress_engine) = - Self::with_progress_engine(event_manager, cancel_token, blocks, global_pool); + Self::with_progress_engine(event_manager, cancel_token, blocks, global_registry); // pool.runtime.handle().spawn(async move { // let mut progress_engine = progress_engine; @@ -244,7 +244,7 @@ impl BlockPool { event_manager: Arc, cancel_token: CancellationToken, blocks: Vec>, - global_pool: GlobalRegistry, + global_registry: GlobalRegistry, ) -> (Self, ProgressEngine) { let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel(); let (ctrl_tx, ctrl_rx) = tokio::sync::mpsc::unbounded_channel(); @@ -255,7 +255,7 @@ impl BlockPool { ctrl_rx, cancel_token, blocks, - global_pool, + global_registry, ); ( @@ -480,9 +480,13 @@ mod tests { self, ) -> anyhow::Result<(BlockPool, ProgressEngine)> { let args = self.build_internal()?; - let (event_manager, cancel_token, blocks, global_pool) = args.dissolve(); - let (pool, progress_engine) = - BlockPool::with_progress_engine(event_manager, cancel_token, blocks, global_pool); + let (event_manager, cancel_token, blocks, global_registry) = args.dissolve(); + let (pool, progress_engine) = BlockPool::with_progress_engine( + event_manager, + cancel_token, + blocks, + global_registry, + ); Ok((pool, progress_engine)) } diff --git a/lib/llm/src/block_manager/pool/inactive.rs b/lib/llm/src/block_manager/pool/inactive.rs index c4ed7fdf28..7d887e5a73 100644 --- a/lib/llm/src/block_manager/pool/inactive.rs +++ b/lib/llm/src/block_manager/pool/inactive.rs @@ -499,8 +499,6 @@ pub(crate) mod tests { use super::*; - use std::sync::Mutex; - #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)] pub struct TestMetadata { priority: u32, @@ -617,10 
+615,7 @@ pub(crate) mod tests { let mut blocks = create_block_collection(num_blocks).into_blocks().unwrap(); let event_manager = NullEventManager::new(); - let mut registry = BlockRegistry::new( - event_manager, - GlobalRegistry::new(Mutex::new(HashMap::new())), - ); + let mut registry = BlockRegistry::new(event_manager, GlobalRegistry::default()); // Iterate through the generated TokenBlocks and the template Blocks, // setting the state and registering each one. @@ -662,7 +657,7 @@ pub(crate) mod tests { let matched_block_count = matched_blocks.len(); let event_manager = NullEventManager::new(); - let mut registry = BlockRegistry::new(event_manager, Arc::new(Mutex::new(HashMap::new()))); + let mut registry = BlockRegistry::new(event_manager, GlobalRegistry::default()); // all matched blocks should be in the complete or registered state for block in &mut matched_blocks { diff --git a/lib/llm/src/block_manager/pool/state.rs b/lib/llm/src/block_manager/pool/state.rs index 6c6a0ae38a..46bf18df75 100644 --- a/lib/llm/src/block_manager/pool/state.rs +++ b/lib/llm/src/block_manager/pool/state.rs @@ -24,12 +24,12 @@ impl State { fn new( event_manager: Arc, return_tx: tokio::sync::mpsc::UnboundedSender>, - global_pool: GlobalRegistry, + global_registry: GlobalRegistry, ) -> Self { Self { active: ActiveBlockPool::new(), inactive: InactiveBlockPool::new(), - registry: BlockRegistry::new(event_manager.clone(), global_pool), + registry: BlockRegistry::new(event_manager.clone(), global_registry), return_tx, event_manager, } @@ -259,10 +259,10 @@ impl ProgressEngine { ctrl_rx: tokio::sync::mpsc::UnboundedReceiver>, cancel_token: CancellationToken, blocks: Vec>, - global_pool: GlobalRegistry, + global_registry: GlobalRegistry, ) -> Self { let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel(); - let mut state = State::::new(event_manager, return_tx, global_pool); + let mut state = State::::new(event_manager, return_tx, global_registry); tracing::debug!(count = 
blocks.len(), "adding blocks to inactive pool"); state.inactive.add_blocks(blocks); diff --git a/lib/llm/src/block_manager/state.rs b/lib/llm/src/block_manager/state.rs index 47cdc1109d..51de6bd87a 100644 --- a/lib/llm/src/block_manager/state.rs +++ b/lib/llm/src/block_manager/state.rs @@ -21,7 +21,7 @@ use super::{ config::NixlOptions, }; use cudarc::driver::CudaStream; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use tokio::runtime::Handle; pub struct TransferContext { @@ -76,7 +76,7 @@ impl KvBlockManagerState { // Create a map of NIXL backends let mut nixl_backends: HashMap> = HashMap::new(); - let global_pool = Arc::new(Mutex::new(HashMap::new())); + let global_registry = GlobalRegistry::default(); // Create a NIXL agent if NIXL is enabled and instantiate requested backends // TODO: Build a map of NIXL backends to block pools/sets @@ -140,7 +140,7 @@ impl KvBlockManagerState { next_block_set_idx, cancellation_token.clone(), worker_id, - global_pool.clone(), + global_registry.clone(), )?; (Some(Arc::new(pool)), Some(blocks)) } @@ -161,7 +161,7 @@ impl KvBlockManagerState { next_block_set_idx, cancellation_token.clone(), worker_id, - global_pool.clone(), + global_registry.clone(), )?; (Some(Arc::new(pool)), Some(blocks)) } else { @@ -181,7 +181,7 @@ impl KvBlockManagerState { next_block_set_idx, cancellation_token.clone(), worker_id, - global_pool.clone(), + global_registry.clone(), )?; (Some(Arc::new(pool)), Some(blocks)) } else { @@ -489,12 +489,12 @@ fn create_block_pool( block_set_idx: usize, cancellation_token: CancellationToken, worker_id: WorkerID, - global_pool: GlobalRegistry, + global_registry: GlobalRegistry, ) -> Result<(BlockPool, Vec>)> { let blocks = block::layout_to_blocks::<_, M>(layout, block_set_idx, worker_id)?; let pool = BlockPool::::builder() .cancel_token(cancellation_token) - .global_pool(global_pool) + .global_registry(global_registry) .build()?; Ok((pool, blocks)) } From b1528ef01fbfccf3aff27ae90502d931fb3305f1 Mon Sep 17 
00:00:00 2001 From: jothomson Date: Thu, 22 May 2025 10:30:41 -0700 Subject: [PATCH 4/6] Integrate with async runtime arg --- lib/llm/src/block_manager/block/registry.rs | 10 ++++-- lib/llm/src/block_manager/pool.rs | 38 +++++++++++++++++---- lib/llm/src/block_manager/pool/inactive.rs | 28 ++++++++++++--- lib/llm/src/block_manager/pool/state.rs | 7 ++-- lib/llm/src/block_manager/state.rs | 23 ++++++++----- 5 files changed, 81 insertions(+), 25 deletions(-) diff --git a/lib/llm/src/block_manager/block/registry.rs b/lib/llm/src/block_manager/block/registry.rs index a99a487f5f..60337d1a6b 100644 --- a/lib/llm/src/block_manager/block/registry.rs +++ b/lib/llm/src/block_manager/block/registry.rs @@ -42,7 +42,7 @@ use super::state::BlockState; use crate::tokens::{BlockHash, SequenceHash, TokenBlock}; use derive_getters::Getters; -use tokio::sync::mpsc; +use tokio::{runtime::Handle, sync::mpsc}; pub type GlobalRegistry = Arc>>>; @@ -90,14 +90,18 @@ pub struct BlockRegistry { } impl BlockRegistry { - pub fn new(event_manager: Arc, global_registry: GlobalRegistry) -> Self { + pub fn new( + event_manager: Arc, + global_registry: GlobalRegistry, + async_runtime: Handle, + ) -> Self { let (unregister_tx, mut unregister_rx) = mpsc::unbounded_channel(); let blocks = Arc::new(Mutex::new(HashMap::new())); let blocks_clone = blocks.clone(); let global_registry_clone = global_registry.clone(); - tokio::spawn(async move { + async_runtime.spawn(async move { let blocks = blocks_clone; let global_registry = global_registry_clone; while let Some(sequence_hash) = unregister_rx.recv().await { diff --git a/lib/llm/src/block_manager/pool.rs b/lib/llm/src/block_manager/pool.rs index 09fbe5af7b..f138026071 100644 --- a/lib/llm/src/block_manager/pool.rs +++ b/lib/llm/src/block_manager/pool.rs @@ -81,6 +81,7 @@ use std::{ collections::{BTreeSet, HashMap, VecDeque}, sync::{Arc, Weak}, }; +use tokio::runtime::Handle; use tokio_util::sync::CancellationToken; use dynamo_runtime::Result; @@ -120,15 
+121,24 @@ pub struct BlockPoolArgs { #[builder(default)] global_registry: GlobalRegistry, + + #[builder(default = "Handle::current()")] + async_runtime: Handle, } impl BlockPoolArgsBuilder { pub fn build(self) -> anyhow::Result> { let args = self.build_internal()?; - let (event_manager, cancel_token, blocks, global_registry) = args.dissolve(); + let (event_manager, cancel_token, blocks, global_registry, async_runtime) = args.dissolve(); tracing::info!("building block pool"); - let pool = BlockPool::new(event_manager, cancel_token, blocks, global_registry); + let pool = BlockPool::new( + event_manager, + cancel_token, + blocks, + global_registry, + async_runtime, + ); Ok(pool) } @@ -205,9 +215,15 @@ impl BlockPool { cancel_token: CancellationToken, blocks: Vec>, global_registry: GlobalRegistry, + async_runtime: Handle, ) -> Self { - let (pool, progress_engine) = - Self::with_progress_engine(event_manager, cancel_token, blocks, global_registry); + let (pool, progress_engine) = Self::with_progress_engine( + event_manager, + cancel_token, + blocks, + global_registry, + async_runtime, + ); // pool.runtime.handle().spawn(async move { // let mut progress_engine = progress_engine; @@ -245,6 +261,7 @@ impl BlockPool { cancel_token: CancellationToken, blocks: Vec>, global_registry: GlobalRegistry, + async_runtime: Handle, ) -> (Self, ProgressEngine) { let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel(); let (ctrl_tx, ctrl_rx) = tokio::sync::mpsc::unbounded_channel(); @@ -256,6 +273,7 @@ impl BlockPool { cancel_token, blocks, global_registry, + async_runtime, ); ( @@ -480,12 +498,14 @@ mod tests { self, ) -> anyhow::Result<(BlockPool, ProgressEngine)> { let args = self.build_internal()?; - let (event_manager, cancel_token, blocks, global_registry) = args.dissolve(); + let (event_manager, cancel_token, blocks, global_registry, async_runtime) = + args.dissolve(); let (pool, progress_engine) = BlockPool::with_progress_engine( event_manager, cancel_token, 
blocks, global_registry, + async_runtime, ); Ok((pool, progress_engine)) @@ -576,8 +596,14 @@ mod tests { .into_blocks() .unwrap(); + let async_runtime = tokio::runtime::Runtime::new().unwrap(); + // Create the BlockPool and add the blocks - let pool = BlockPool::builder().blocks(blocks).build().unwrap(); + let pool = BlockPool::builder() + .blocks(blocks) + .async_runtime(async_runtime.handle().clone()) + .build() + .unwrap(); // All blocks should be in the Reset/Empty state // No blocks should match the expected sequence hash diff --git a/lib/llm/src/block_manager/pool/inactive.rs b/lib/llm/src/block_manager/pool/inactive.rs index 7d887e5a73..3b3276e571 100644 --- a/lib/llm/src/block_manager/pool/inactive.rs +++ b/lib/llm/src/block_manager/pool/inactive.rs @@ -603,6 +603,7 @@ pub(crate) mod tests { pub fn create_blocks( tokens: Tokens, block_size: usize, + async_runtime: Handle, ) -> Vec> { let (token_blocks, _partial_token_block) = tokens.into_sequence(block_size, None).into_parts(); @@ -615,7 +616,8 @@ pub(crate) mod tests { let mut blocks = create_block_collection(num_blocks).into_blocks().unwrap(); let event_manager = NullEventManager::new(); - let mut registry = BlockRegistry::new(event_manager, GlobalRegistry::default()); + let mut registry = + BlockRegistry::new(event_manager, GlobalRegistry::default(), async_runtime); // Iterate through the generated TokenBlocks and the template Blocks, // setting the state and registering each one. 
@@ -645,6 +647,7 @@ pub(crate) mod tests { tokens: Tokens, block_size: usize, pool: &mut InactiveBlockPool, + async_runtime: Handle, ) -> (Vec>, usize) { let (mut token_blocks, _partial_token_block) = tokens.into_sequence(block_size, None).into_parts(); @@ -657,7 +660,8 @@ pub(crate) mod tests { let matched_block_count = matched_blocks.len(); let event_manager = NullEventManager::new(); - let mut registry = BlockRegistry::new(event_manager, GlobalRegistry::default()); + let mut registry = + BlockRegistry::new(event_manager, GlobalRegistry::default(), async_runtime); // all matched blocks should be in the complete or registered state for block in &mut matched_blocks { @@ -697,6 +701,8 @@ pub(crate) mod tests { fn test_block_pool_lifecycle() { dynamo_runtime::logging::init(); + let async_runtime = tokio::runtime::Runtime::new().unwrap(); + const PAGE_SIZE: usize = 2; let mut pool = create_block_pool(10); @@ -715,7 +721,12 @@ pub(crate) mod tests { let tokens = create_token_sequence(&[1, 2, 3, 4]); - let (blocks, matched_block_count) = acquire_blocks(tokens.clone(), PAGE_SIZE, &mut pool); + let (blocks, matched_block_count) = acquire_blocks( + tokens.clone(), + PAGE_SIZE, + &mut pool, + async_runtime.handle().clone(), + ); assert_eq!(blocks.len(), 2); assert_eq!(matched_block_count, 0); assert_eq!(pool.available_blocks(), 8); @@ -725,7 +736,12 @@ pub(crate) mod tests { assert_eq!(pool.total_blocks(), 10); assert_eq!(pool.available_blocks(), 10); - let (blocks, matched_block_count) = acquire_blocks(tokens.clone(), PAGE_SIZE, &mut pool); + let (blocks, matched_block_count) = acquire_blocks( + tokens.clone(), + PAGE_SIZE, + &mut pool, + async_runtime.handle().clone(), + ); assert_eq!(blocks.len(), 2); assert_eq!(matched_block_count, 2); assert_eq!(pool.available_blocks(), 8); @@ -745,9 +761,11 @@ pub(crate) mod tests { fn test_basic_sequence_matching() { let mut pool = InactiveBlockPool::new(); + let async_runtime = tokio::runtime::Runtime::new().unwrap(); + // Create a 
sequence of 4 tokens split into blocks of 2 let sequence = create_token_sequence(&[1, 2, 3, 4]); - let blocks = create_blocks(sequence, 2); + let blocks = create_blocks(sequence, 2, async_runtime.handle().clone()); assert_eq!(blocks.len(), 2); // Match the blocks in sequence diff --git a/lib/llm/src/block_manager/pool/state.rs b/lib/llm/src/block_manager/pool/state.rs index 46bf18df75..8197d0ac8f 100644 --- a/lib/llm/src/block_manager/pool/state.rs +++ b/lib/llm/src/block_manager/pool/state.rs @@ -25,11 +25,12 @@ impl State { event_manager: Arc, return_tx: tokio::sync::mpsc::UnboundedSender>, global_registry: GlobalRegistry, + async_runtime: Handle, ) -> Self { Self { active: ActiveBlockPool::new(), inactive: InactiveBlockPool::new(), - registry: BlockRegistry::new(event_manager.clone(), global_registry), + registry: BlockRegistry::new(event_manager.clone(), global_registry, async_runtime), return_tx, event_manager, } @@ -260,9 +261,11 @@ impl ProgressEngine { cancel_token: CancellationToken, blocks: Vec>, global_registry: GlobalRegistry, + async_runtime: Handle, ) -> Self { let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel(); - let mut state = State::::new(event_manager, return_tx, global_registry); + let mut state = + State::::new(event_manager, return_tx, global_registry, async_runtime); tracing::debug!(count = blocks.len(), "adding blocks to inactive pool"); state.inactive.add_blocks(blocks); diff --git a/lib/llm/src/block_manager/state.rs b/lib/llm/src/block_manager/state.rs index 51de6bd87a..2271fb041c 100644 --- a/lib/llm/src/block_manager/state.rs +++ b/lib/llm/src/block_manager/state.rs @@ -125,6 +125,14 @@ impl KvBlockManagerState { let mut next_block_set_idx = 0; let mut local_block_set = block::nixl::NixlBlockSet::new(worker_id); + let async_rt_handle = match config.runtime.async_runtime { + Some(rt) => rt.handle().clone(), + None => match Handle::try_current() { + Ok(handle) => handle, + Err(e) => anyhow::bail!(e), + }, + }; + let 
(disk_pool, disk_blocks) = if let Some(config) = config.disk_layout { if nixl_agent.is_none() { tracing::warn!("NIXL is disabled; will not allocate disk blocks."); @@ -141,6 +149,7 @@ impl KvBlockManagerState { cancellation_token.clone(), worker_id, global_registry.clone(), + async_rt_handle.clone(), )?; (Some(Arc::new(pool)), Some(blocks)) } @@ -162,6 +171,7 @@ impl KvBlockManagerState { cancellation_token.clone(), worker_id, global_registry.clone(), + async_rt_handle.clone(), )?; (Some(Arc::new(pool)), Some(blocks)) } else { @@ -182,6 +192,7 @@ impl KvBlockManagerState { cancellation_token.clone(), worker_id, global_registry.clone(), + async_rt_handle.clone(), )?; (Some(Arc::new(pool)), Some(blocks)) } else { @@ -195,20 +206,12 @@ impl KvBlockManagerState { local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?); } - let offload_async_rt_handle = match config.runtime.async_runtime { - Some(rt) => rt.handle().clone(), - None => match Handle::try_current() { - Ok(handle) => handle, - Err(e) => anyhow::bail!(e), - }, - }; - let offload_manager = OffloadManager::new( disk_pool.clone(), host_pool.clone(), device_pool.clone(), nixl_agent.clone(), - offload_async_rt_handle, + async_rt_handle, )?; let state = Arc::new(Self { @@ -490,11 +493,13 @@ fn create_block_pool( cancellation_token: CancellationToken, worker_id: WorkerID, global_registry: GlobalRegistry, + async_runtime: Handle, ) -> Result<(BlockPool, Vec>)> { let blocks = block::layout_to_blocks::<_, M>(layout, block_set_idx, worker_id)?; let pool = BlockPool::::builder() .cancel_token(cancellation_token) .global_registry(global_registry) + .async_runtime(async_runtime) .build()?; Ok((pool, blocks)) } From 403dd25973f55a6cea08d0b3157008af527403e8 Mon Sep 17 00:00:00 2001 From: jothomson Date: Tue, 27 May 2025 13:45:28 -0700 Subject: [PATCH 5/6] Fix race conditions --- lib/llm/src/block_manager/block/registry.rs | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git 
a/lib/llm/src/block_manager/block/registry.rs b/lib/llm/src/block_manager/block/registry.rs index 60337d1a6b..9b826d4b1b 100644 --- a/lib/llm/src/block_manager/block/registry.rs +++ b/lib/llm/src/block_manager/block/registry.rs @@ -152,16 +152,12 @@ impl BlockRegistry { BlockState::Complete(state) => { let sequence_hash = state.token_block().sequence_hash(); - // If an identical block already exists in this pool, return an error. + let mut blocks = self.blocks.lock().unwrap(); - { - let blocks = self.blocks.lock().unwrap(); - if let Some(handle) = blocks.get(&sequence_hash) { - if let Some(_handle) = handle.upgrade() { - return Err(BlockRegistationError::BlockAlreadyRegistered( - sequence_hash, - )); - } + // If an identical block already exists in this pool, return an error. + if let Some(handle) = blocks.get(&sequence_hash) { + if let Some(_handle) = handle.upgrade() { + return Err(BlockRegistationError::BlockAlreadyRegistered(sequence_hash)); } } @@ -194,11 +190,7 @@ impl BlockRegistry { reg_handle }; - { - let mut blocks = self.blocks.lock().unwrap(); - // Insert our block handle into the per-pool registry. 
- blocks.insert(sequence_hash, Arc::downgrade(&block_handle)); - } + blocks.insert(sequence_hash, Arc::downgrade(&block_handle)); // Update the [BlockState] to [BlockState::Registered] let _ = std::mem::replace( From fcedaa318cc70a1607761c3f9dc2ba73fe6d5305 Mon Sep 17 00:00:00 2001 From: jothomson Date: Thu, 29 May 2025 08:58:19 -0700 Subject: [PATCH 6/6] little fix --- lib/llm/src/block_manager/block/registry.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/llm/src/block_manager/block/registry.rs b/lib/llm/src/block_manager/block/registry.rs index 9b826d4b1b..a9a915d98a 100644 --- a/lib/llm/src/block_manager/block/registry.rs +++ b/lib/llm/src/block_manager/block/registry.rs @@ -97,7 +97,8 @@ impl BlockRegistry { ) -> Self { let (unregister_tx, mut unregister_rx) = mpsc::unbounded_channel(); - let blocks = Arc::new(Mutex::new(HashMap::new())); + let blocks: Arc>>> = + Arc::new(Mutex::new(HashMap::new())); let blocks_clone = blocks.clone(); let global_registry_clone = global_registry.clone(); @@ -107,7 +108,12 @@ impl BlockRegistry { while let Some(sequence_hash) = unregister_rx.recv().await { { let mut blocks = blocks.lock().unwrap(); - blocks.remove(&sequence_hash); + + if let Some(handle) = blocks.get(&sequence_hash) { + if handle.upgrade().is_none() { + blocks.remove(&sequence_hash); + } + } } let mut global_registry = global_registry.lock().unwrap();