Skip to content

Commit 3d40a69

Browse files
authored
feat: Restructure kv manager block registration (#1093)
1 parent 7d0c938 commit 3d40a69

File tree

9 files changed

+261
-84
lines changed

9 files changed

+261
-84
lines changed

lib/llm/src/block_manager/block.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ pub mod view;
2121
pub use crate::tokens::TokenBlockError;
2222
pub use anyhow::Result;
2323
use nixl_sys::NixlDescriptor;
24+
25+
pub use registry::{GlobalRegistry, RegistrationHandle};
2426
pub use state::{BlockState, BlockStateInvalid};
2527

2628
use crate::block_manager::{
@@ -176,7 +178,7 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
176178
pub fn sequence_hash(&self) -> Result<SequenceHash, BlockError> {
177179
match self.state() {
178180
BlockState::Complete(state) => Ok(state.token_block().sequence_hash()),
179-
BlockState::Registered(state) => Ok(state.sequence_hash()),
181+
BlockState::Registered(state, _) => Ok(state.sequence_hash()),
180182
_ => Err(BlockError::InvalidState(
181183
"Block is not complete".to_string(),
182184
)),
@@ -250,14 +252,14 @@ pub(crate) trait PrivateBlockExt {
250252
fn register(
251253
&mut self,
252254
registry: &mut registry::BlockRegistry,
253-
) -> Result<PublishHandle, registry::BlockRegistationError>;
255+
) -> Result<Option<PublishHandle>, registry::BlockRegistationError>;
254256
}
255257

256258
impl<S: Storage, M: BlockMetadata> PrivateBlockExt for Block<S, M> {
257259
fn register(
258260
&mut self,
259261
registry: &mut registry::BlockRegistry,
260-
) -> Result<PublishHandle, registry::BlockRegistationError> {
262+
) -> Result<Option<PublishHandle>, registry::BlockRegistationError> {
261263
registry.register_block(&mut self.state)
262264
}
263265
}

lib/llm/src/block_manager/block/registry.rs

Lines changed: 131 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,27 @@
1313
// See the License for the specific language governing permissions and
1414
// limitations under the License.
1515

16+
//! # KV Cache Block Registration
17+
//!
18+
//! - This module is responsible for maintaining a registry of all blocks currently within a pool.
19+
//! This consists of two components: A global registry of all blocks, and a per-pool registry of blocks.
20+
//! - The global registry is a mapping of sequences hashes to registration handles. If two blocks in different pools
21+
//! have the same sequence hash, then they will share the same registration handle. The global registry is shared across all pools.
22+
//! - The per-pool registry is a mapping of sequence hashes to block handles. This is used to track which blocks are
23+
//! currently within a specific pool. The block handle is unique across pools, and is used to track the block's lifetime.
24+
//! - When a block is in the registered state, it has a unique block handle and a possibly shared registration handle.
25+
//!
26+
//! ## Workflow
27+
//!
28+
//! 1. When a block is registered into a pool, we create a unique block handle.
29+
//! 2. We then check the global registry to see if the block already exists in any other pool.
30+
//! 3. If it does, we use the existing registration handle. Otherwise, we create a new one.
31+
//! 4. When the block handle is dropped, it means that the block is no longer in the pool.
32+
//! 5. When the registration handle is dropped, it means that the block is no longer in any pool.
33+
1634
use std::{
1735
collections::HashMap,
18-
sync::{Arc, Weak},
36+
sync::{Arc, Mutex, Weak},
1937
};
2038

2139
use super::super::events::{EventManager, EventReleaseManager, PublishHandle};
@@ -24,6 +42,9 @@ use super::state::BlockState;
2442
use crate::tokens::{BlockHash, SequenceHash, TokenBlock};
2543

2644
use derive_getters::Getters;
45+
use tokio::{runtime::Handle, sync::mpsc};
46+
47+
pub type GlobalRegistry = Arc<Mutex<HashMap<SequenceHash, Weak<RegistrationHandle>>>>;
2748

2849
#[derive(Debug, thiserror::Error)]
2950
pub enum BlockRegistationError {
@@ -34,27 +55,88 @@ pub enum BlockRegistationError {
3455
InvalidState(String),
3556
}
3657

37-
/// Error returned when an attempt is made to unregister a block that is still active.
38-
#[derive(Debug, thiserror::Error)]
39-
#[error("Failed to unregister block: {0}")]
40-
pub struct UnregisterFailure(SequenceHash);
58+
/// A block entry is a handle to a block that is registered in the pool.
59+
/// On drop, we need to notify the pool that the block has been unregistered.
60+
/// This is different than the registration handle, which is only dropped when the block is no longer in ANY pool.
61+
#[derive(Debug)]
62+
pub struct BlockHandle {
63+
sequence_hash: SequenceHash,
64+
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
65+
}
66+
67+
impl BlockHandle {
68+
pub fn new(
69+
sequence_hash: SequenceHash,
70+
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
71+
) -> Self {
72+
Self {
73+
sequence_hash,
74+
unregister_tx,
75+
}
76+
}
77+
}
78+
79+
impl Drop for BlockHandle {
80+
fn drop(&mut self) {
81+
let _ = self.unregister_tx.send(self.sequence_hash);
82+
}
83+
}
4184

42-
#[derive()]
4385
pub struct BlockRegistry {
44-
blocks: HashMap<SequenceHash, Weak<RegistrationHandle>>,
86+
blocks: Arc<Mutex<HashMap<SequenceHash, Weak<BlockHandle>>>>,
4587
event_manager: Arc<dyn EventManager>,
88+
global_registry: GlobalRegistry,
89+
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
4690
}
4791

4892
impl BlockRegistry {
49-
pub fn new(event_manager: Arc<dyn EventManager>) -> Self {
93+
pub fn new(
94+
event_manager: Arc<dyn EventManager>,
95+
global_registry: GlobalRegistry,
96+
async_runtime: Handle,
97+
) -> Self {
98+
let (unregister_tx, mut unregister_rx) = mpsc::unbounded_channel();
99+
100+
let blocks: Arc<Mutex<HashMap<SequenceHash, Weak<BlockHandle>>>> =
101+
Arc::new(Mutex::new(HashMap::new()));
102+
103+
let blocks_clone = blocks.clone();
104+
let global_registry_clone = global_registry.clone();
105+
async_runtime.spawn(async move {
106+
let blocks = blocks_clone;
107+
let global_registry = global_registry_clone;
108+
while let Some(sequence_hash) = unregister_rx.recv().await {
109+
{
110+
let mut blocks = blocks.lock().unwrap();
111+
112+
if let Some(handle) = blocks.get(&sequence_hash) {
113+
if handle.upgrade().is_none() {
114+
blocks.remove(&sequence_hash);
115+
}
116+
}
117+
}
118+
119+
let mut global_registry = global_registry.lock().unwrap();
120+
121+
if let Some(entry) = global_registry.get(&sequence_hash) {
122+
if entry.upgrade().is_none() {
123+
global_registry.remove(&sequence_hash);
124+
}
125+
}
126+
}
127+
});
128+
50129
Self {
51-
blocks: HashMap::new(),
130+
blocks,
52131
event_manager,
132+
global_registry,
133+
unregister_tx,
53134
}
54135
}
55136

56137
pub fn is_registered(&self, sequence_hash: SequenceHash) -> bool {
57-
if let Some(handle) = self.blocks.get(&sequence_hash) {
138+
let blocks = self.blocks.lock().unwrap();
139+
if let Some(handle) = blocks.get(&sequence_hash) {
58140
if let Some(_handle) = handle.upgrade() {
59141
return true;
60142
}
@@ -65,7 +147,7 @@ impl BlockRegistry {
65147
pub fn register_block(
66148
&mut self,
67149
block_state: &mut BlockState,
68-
) -> Result<PublishHandle, BlockRegistationError> {
150+
) -> Result<Option<PublishHandle>, BlockRegistationError> {
69151
match block_state {
70152
BlockState::Reset => Err(BlockRegistationError::InvalidState(
71153
"Block is in Reset state".to_string(),
@@ -76,47 +158,60 @@ impl BlockRegistry {
76158

77159
BlockState::Complete(state) => {
78160
let sequence_hash = state.token_block().sequence_hash();
79-
if let Some(handle) = self.blocks.get(&sequence_hash) {
161+
let mut blocks = self.blocks.lock().unwrap();
162+
163+
// If an identical block already exists in this pool, return an error.
164+
if let Some(handle) = blocks.get(&sequence_hash) {
80165
if let Some(_handle) = handle.upgrade() {
81166
return Err(BlockRegistationError::BlockAlreadyRegistered(sequence_hash));
82167
}
83168
}
84169

85-
// Create the [RegistrationHandle] and [PublishHandle]
86-
let publish_handle =
87-
Self::create_publish_handle(state.token_block(), self.event_manager.clone());
88-
let reg_handle = publish_handle.remove_handle();
170+
let mut publish_handle = None;
171+
172+
let block_handle =
173+
Arc::new(BlockHandle::new(sequence_hash, self.unregister_tx.clone()));
174+
175+
let reg_handle = 'reg_block: {
176+
// Now, check the global registry.
177+
let mut global_registry = self.global_registry.lock().unwrap();
178+
179+
// If an identical block exists in other pool, use the same registration handle.
180+
if let Some(handle) = global_registry.get(&sequence_hash) {
181+
if let Some(handle) = handle.upgrade() {
182+
break 'reg_block handle;
183+
}
184+
}
185+
186+
// Otherwise, create a new registration handle.
187+
publish_handle = Some(Self::create_publish_handle(
188+
state.token_block(),
189+
self.event_manager.clone(),
190+
));
191+
let reg_handle = publish_handle.as_ref().unwrap().remove_handle();
192+
193+
// Insert the registration handle into the global registry.
194+
global_registry.insert(sequence_hash, Arc::downgrade(&reg_handle));
89195

90-
// Insert the [RegistrationHandle] into the registry
91-
self.blocks
92-
.insert(sequence_hash, Arc::downgrade(&reg_handle));
196+
reg_handle
197+
};
198+
199+
blocks.insert(sequence_hash, Arc::downgrade(&block_handle));
93200

94201
// Update the [BlockState] to [BlockState::Registered]
95-
let _ = std::mem::replace(block_state, BlockState::Registered(reg_handle));
202+
let _ = std::mem::replace(
203+
block_state,
204+
BlockState::Registered(reg_handle, block_handle),
205+
);
96206

97207
Ok(publish_handle)
98208
}
99-
BlockState::Registered(registered) => Err(
209+
BlockState::Registered(registered, _) => Err(
100210
BlockRegistationError::BlockAlreadyRegistered(registered.sequence_hash()),
101211
),
102212
}
103213
}
104214

105-
pub fn unregister_block(
106-
&mut self,
107-
sequence_hash: SequenceHash,
108-
) -> Result<(), UnregisterFailure> {
109-
if let Some(handle) = self.blocks.get(&sequence_hash) {
110-
if handle.upgrade().is_none() {
111-
self.blocks.remove(&sequence_hash);
112-
return Ok(());
113-
} else {
114-
return Err(UnregisterFailure(sequence_hash));
115-
}
116-
}
117-
Ok(())
118-
}
119-
120215
fn create_publish_handle(
121216
token_block: &TokenBlock,
122217
event_manager: Arc<dyn EventManager>,

lib/llm/src/block_manager/block/state.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use std::sync::Arc;
1717

1818
use derive_getters::Getters;
1919

20-
use super::registry::RegistrationHandle;
20+
use super::registry::{BlockHandle, RegistrationHandle};
2121
use super::Result;
2222
use crate::tokens::{PartialTokenBlock, SaltHash, Token, TokenBlock, Tokens};
2323

@@ -30,7 +30,7 @@ pub enum BlockState {
3030
Reset,
3131
Partial(PartialState),
3232
Complete(CompleteState),
33-
Registered(Arc<RegistrationHandle>),
33+
Registered(Arc<RegistrationHandle>, Arc<BlockHandle>),
3434
}
3535

3636
impl BlockState {
@@ -109,7 +109,7 @@ impl BlockState {
109109
BlockState::Reset => Some(0),
110110
BlockState::Partial(state) => Some(state.block.len()),
111111
BlockState::Complete(state) => Some(state.token_block.tokens().len()),
112-
BlockState::Registered(_) => None,
112+
BlockState::Registered(_, _) => None,
113113
}
114114
}
115115

@@ -126,15 +126,15 @@ impl BlockState {
126126
match self {
127127
BlockState::Reset => true,
128128
BlockState::Partial(state) => state.block.is_empty(),
129-
BlockState::Complete(_) => false, // Always full
130-
BlockState::Registered(_) => false, // Always full
129+
BlockState::Complete(_) => false, // Always full
130+
BlockState::Registered(_, _) => false, // Always full
131131
}
132132
}
133133

134134
/// Returns a reference to the underlying TokenBlock if the state is Complete or Registered.
135135
pub fn tokens(&self) -> Option<&Tokens> {
136136
match self {
137-
BlockState::Reset | BlockState::Registered(_) => None,
137+
BlockState::Reset | BlockState::Registered(_, _) => None,
138138
BlockState::Partial(state) => Some(state.block.tokens()),
139139
BlockState::Complete(state) => Some(state.token_block.tokens()),
140140
}
@@ -147,12 +147,12 @@ impl BlockState {
147147

148148
/// Returns true if the block is in the complete or registered state
149149
pub fn is_complete(&self) -> bool {
150-
matches!(self, BlockState::Complete(_) | BlockState::Registered(_))
150+
matches!(self, BlockState::Complete(_) | BlockState::Registered(_, _))
151151
}
152152

153153
/// Returns true if the block is in the registered state
154154
pub fn is_registered(&self) -> bool {
155-
matches!(self, BlockState::Registered(_state))
155+
matches!(self, BlockState::Registered(_state, _))
156156
}
157157
}
158158

lib/llm/src/block_manager/offload.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
334334
priority: u64,
335335
) -> core::result::Result<(), BlockPoolError> {
336336
match block.state() {
337-
BlockState::Registered(_) => {}
337+
BlockState::Registered(_, _) => {}
338338
_ => {
339339
return Err(BlockPoolError::BlockError(BlockError::InvalidState(
340340
"Block is not registered.".to_string(),
@@ -397,7 +397,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
397397
) -> BlockResult<DeviceStorage, Metadata> {
398398
for block in &blocks {
399399
match block.state() {
400-
BlockState::Registered(_) => {}
400+
BlockState::Registered(_, _) => {}
401401
_ => {
402402
return Err(BlockPoolError::BlockError(BlockError::InvalidState(
403403
"Block is not registered.".to_string(),
@@ -857,7 +857,7 @@ mod tests {
857857
// Check that the block is registered.
858858
assert!(matches!(
859859
onboarded_blocks[0].state(),
860-
BlockState::Registered(_)
860+
BlockState::Registered(_, _)
861861
));
862862

863863
check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?;
@@ -940,7 +940,7 @@ mod tests {
940940
);
941941
assert!(matches!(
942942
onboarded_blocks[0].state(),
943-
BlockState::Registered(_)
943+
BlockState::Registered(_, _)
944944
));
945945

946946
check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?;

lib/llm/src/block_manager/offload/pending.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>(
118118
target: &mut MutableBlock<Target, Metadata>,
119119
) -> Result<()> {
120120
// Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail.
121-
if let BlockState::Registered(reg_handle) = source.state() {
121+
if let BlockState::Registered(reg_handle, _) = source.state() {
122122
// Bring the block back to the 'Reset' state.
123123
target.reset();
124124
// Transfer metadata.

0 commit comments

Comments
 (0)