Skip to content

Commit 0cb8227

Browse files
committed
Restructure block registration
1 parent 34f3fc6 commit 0cb8227

File tree

8 files changed

+100
-46
lines changed

8 files changed

+100
-46
lines changed

lib/llm/src/block_manager/block.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ pub mod view;
2121
pub use crate::tokens::TokenBlockError;
2222
pub use anyhow::Result;
2323
use nixl_sys::NixlDescriptor;
24+
25+
pub use registry::{GlobalRegistry, RegistrationHandle};
2426
pub use state::{BlockState, BlockStateInvalid};
2527

2628
use crate::block_manager::{
@@ -172,7 +174,7 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
172174
pub fn sequence_hash(&self) -> Result<SequenceHash, BlockError> {
173175
match self.state() {
174176
BlockState::Complete(state) => Ok(state.token_block().sequence_hash()),
175-
BlockState::Registered(state) => Ok(state.sequence_hash()),
177+
BlockState::Registered(state, _) => Ok(state.sequence_hash()),
176178
_ => Err(BlockError::InvalidState(
177179
"Block is not complete".to_string(),
178180
)),
@@ -248,14 +250,14 @@ pub(crate) trait PrivateBlockExt {
248250
fn register(
249251
&mut self,
250252
registry: &mut registry::BlockRegistry,
251-
) -> Result<PublishHandle, registry::BlockRegistationError>;
253+
) -> Result<Option<PublishHandle>, registry::BlockRegistationError>;
252254
}
253255

254256
impl<S: Storage, M: BlockMetadata> PrivateBlockExt for Block<S, M> {
255257
fn register(
256258
&mut self,
257259
registry: &mut registry::BlockRegistry,
258-
) -> Result<PublishHandle, registry::BlockRegistationError> {
260+
) -> Result<Option<PublishHandle>, registry::BlockRegistationError> {
259261
registry.register_block(&mut self.state)
260262
}
261263
}

lib/llm/src/block_manager/block/registry.rs

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
use std::{
1717
collections::HashMap,
18-
sync::{Arc, Weak},
18+
sync::{Arc, Mutex, Weak},
1919
};
2020

2121
use super::super::events::{EventManager, EventReleaseManager, PublishHandle};
@@ -39,17 +39,21 @@ pub enum BlockRegistationError {
3939
#[error("Failed to unregister block: {0}")]
4040
pub struct UnregisterFailure(SequenceHash);
4141

42+
pub type GlobalRegistry = Arc<Mutex<HashMap<SequenceHash, Weak<RegistrationHandle>>>>;
43+
4244
#[derive()]
4345
pub struct BlockRegistry {
44-
blocks: HashMap<SequenceHash, Weak<RegistrationHandle>>,
46+
blocks: HashMap<SequenceHash, Weak<()>>,
4547
event_manager: Arc<dyn EventManager>,
48+
global_pool: GlobalRegistry,
4649
}
4750

4851
impl BlockRegistry {
49-
pub fn new(event_manager: Arc<dyn EventManager>) -> Self {
52+
pub fn new(event_manager: Arc<dyn EventManager>, global_pool: GlobalRegistry) -> Self {
5053
Self {
5154
blocks: HashMap::new(),
5255
event_manager,
56+
global_pool,
5357
}
5458
}
5559

@@ -65,7 +69,7 @@ impl BlockRegistry {
6569
pub fn register_block(
6670
&mut self,
6771
block_state: &mut BlockState,
68-
) -> Result<PublishHandle, BlockRegistationError> {
72+
) -> Result<Option<PublishHandle>, BlockRegistationError> {
6973
match block_state {
7074
BlockState::Reset => Err(BlockRegistationError::InvalidState(
7175
"Block is in Reset state".to_string(),
@@ -82,21 +86,42 @@ impl BlockRegistry {
8286
}
8387
}
8488

85-
// Create the [RegistrationHandle] and [PublishHandle]
86-
let publish_handle =
87-
Self::create_publish_handle(state.token_block(), self.event_manager.clone());
88-
let reg_handle = publish_handle.remove_handle();
89+
let mut publish_handle = None;
90+
91+
let block_handle = Arc::new(());
92+
93+
let reg_handle = 'reg_block: {
94+
let mut global_pool = self.global_pool.lock().unwrap();
95+
96+
if let Some(handle) = global_pool.get(&sequence_hash) {
97+
if let Some(handle) = handle.upgrade() {
98+
break 'reg_block handle;
99+
}
100+
}
101+
102+
publish_handle = Some(Self::create_publish_handle(
103+
state.token_block(),
104+
self.event_manager.clone(),
105+
));
106+
let reg_handle = publish_handle.as_ref().unwrap().remove_handle();
107+
108+
global_pool.insert(sequence_hash, Arc::downgrade(&reg_handle));
109+
110+
reg_handle
111+
};
89112

90-
// Insert the [RegistrationHandle] into the registry
91113
self.blocks
92-
.insert(sequence_hash, Arc::downgrade(&reg_handle));
114+
.insert(sequence_hash, Arc::downgrade(&block_handle));
93115

94116
// Update the [BlockState] to [BlockState::Registered]
95-
let _ = std::mem::replace(block_state, BlockState::Registered(reg_handle));
117+
let _ = std::mem::replace(
118+
block_state,
119+
BlockState::Registered(reg_handle, block_handle),
120+
);
96121

97122
Ok(publish_handle)
98123
}
99-
BlockState::Registered(registered) => Err(
124+
BlockState::Registered(registered, _) => Err(
100125
BlockRegistationError::BlockAlreadyRegistered(registered.sequence_hash()),
101126
),
102127
}

lib/llm/src/block_manager/block/state.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pub enum BlockState {
3030
Reset,
3131
Partial(PartialState),
3232
Complete(CompleteState),
33-
Registered(Arc<RegistrationHandle>),
33+
Registered(Arc<RegistrationHandle>, Arc<()>),
3434
}
3535

3636
impl BlockState {
@@ -109,7 +109,7 @@ impl BlockState {
109109
BlockState::Reset => Some(0),
110110
BlockState::Partial(state) => Some(state.block.len()),
111111
BlockState::Complete(state) => Some(state.token_block.tokens().len()),
112-
BlockState::Registered(_) => None,
112+
BlockState::Registered(_, _) => None,
113113
}
114114
}
115115

@@ -126,15 +126,15 @@ impl BlockState {
126126
match self {
127127
BlockState::Reset => true,
128128
BlockState::Partial(state) => state.block.is_empty(),
129-
BlockState::Complete(_) => false, // Always full
130-
BlockState::Registered(_) => false, // Always full
129+
BlockState::Complete(_) => false, // Always full
130+
BlockState::Registered(_, _) => false, // Always full
131131
}
132132
}
133133

134134
/// Returns a reference to the underlying TokenBlock if the state is Complete or Registered.
135135
pub fn tokens(&self) -> Option<&Tokens> {
136136
match self {
137-
BlockState::Reset | BlockState::Registered(_) => None,
137+
BlockState::Reset | BlockState::Registered(_, _) => None,
138138
BlockState::Partial(state) => Some(state.block.tokens()),
139139
BlockState::Complete(state) => Some(state.token_block.tokens()),
140140
}
@@ -147,12 +147,12 @@ impl BlockState {
147147

148148
/// Returns true if the block is in the complete or registered state
149149
pub fn is_complete(&self) -> bool {
150-
matches!(self, BlockState::Complete(_) | BlockState::Registered(_))
150+
matches!(self, BlockState::Complete(_) | BlockState::Registered(_, _))
151151
}
152152

153153
/// Returns true if the block is in the registered state
154154
pub fn is_registered(&self) -> bool {
155-
matches!(self, BlockState::Registered(_state))
155+
matches!(self, BlockState::Registered(_state, _))
156156
}
157157
}
158158

lib/llm/src/block_manager/offload.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
9090
target: &mut MutableBlock<Target, Metadata>,
9191
) -> Result<()> {
9292
// Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail.
93-
if let BlockState::Registered(reg_handle) = source.state() {
93+
if let BlockState::Registered(reg_handle, _) = source.state() {
9494
// Bring the block back to the 'Reset' state.
9595
target.reset();
9696
// Transfer metadata.
@@ -250,7 +250,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
250250
priority: u64,
251251
) -> core::result::Result<(), BlockPoolError> {
252252
match block.state() {
253-
BlockState::Registered(_) => {}
253+
BlockState::Registered(_, _) => {}
254254
_ => {
255255
return Err(BlockPoolError::BlockError(BlockError::InvalidState(
256256
"Block is not registered.".to_string(),
@@ -294,7 +294,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
294294
) -> core::result::Result<Vec<ImmutableBlock<DeviceStorage, Metadata>>, BlockPoolError> {
295295
for block in &blocks {
296296
match block.state() {
297-
BlockState::Registered(_) => {}
297+
BlockState::Registered(_, _) => {}
298298
_ => {
299299
return Err(BlockPoolError::BlockError(BlockError::InvalidState(
300300
"Block is not registered.".to_string(),
@@ -610,7 +610,7 @@ mod tests {
610610
// Check that the block is registered.
611611
assert!(matches!(
612612
onboarded_blocks[0].state(),
613-
BlockState::Registered(_)
613+
BlockState::Registered(_, _)
614614
));
615615

616616
compare_block_contents(&onboarded_blocks[0], &immutable_host_block).await?;
@@ -693,7 +693,7 @@ mod tests {
693693
);
694694
assert!(matches!(
695695
onboarded_blocks[0].state(),
696-
BlockState::Registered(_)
696+
BlockState::Registered(_, _)
697697
));
698698

699699
compare_block_contents(&onboarded_blocks[0], &immutable_host_block).await?;

lib/llm/src/block_manager/pool.rs

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ use priority_key::PriorityKey;
6969
pub use super::block::{ImmutableBlock, MutableBlock};
7070

7171
use super::block::{
72-
nixl::short_type_name, registry::BlockRegistry, Block, BlockError, BlockMetadata,
72+
nixl::short_type_name, registry::BlockRegistry, Block, BlockError, BlockMetadata, GlobalRegistry
7373
};
7474
use super::events::{EventManager, NullEventManager};
7575
use super::storage::Storage;
@@ -116,15 +116,18 @@ pub struct BlockPoolArgs<S: Storage, M: BlockMetadata> {
116116

117117
#[builder(default)]
118118
blocks: Vec<Block<S, M>>,
119+
120+
#[builder(default)]
121+
global_pool: GlobalRegistry,
119122
}
120123

121124
impl<S: Storage, M: BlockMetadata> BlockPoolArgsBuilder<S, M> {
122125
pub fn build(self) -> anyhow::Result<BlockPool<S, M>> {
123126
let args = self.build_internal()?;
124-
let (event_manager, cancel_token, blocks) = args.dissolve();
127+
let (event_manager, cancel_token, blocks, global_pool) = args.dissolve();
125128

126129
tracing::info!("building block pool");
127-
let pool = BlockPool::new(event_manager, cancel_token, blocks);
130+
let pool = BlockPool::new(event_manager, cancel_token, blocks, global_pool);
128131

129132
Ok(pool)
130133
}
@@ -200,9 +203,10 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
200203
event_manager: Arc<dyn EventManager>,
201204
cancel_token: CancellationToken,
202205
blocks: Vec<Block<S, M>>,
206+
global_pool: GlobalRegistry,
203207
) -> Self {
204208
let (pool, progress_engine) =
205-
Self::with_progress_engine(event_manager, cancel_token, blocks);
209+
Self::with_progress_engine(event_manager, cancel_token, blocks, global_pool);
206210

207211
// pool.runtime.handle().spawn(async move {
208212
// let mut progress_engine = progress_engine;
@@ -239,12 +243,19 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
239243
event_manager: Arc<dyn EventManager>,
240244
cancel_token: CancellationToken,
241245
blocks: Vec<Block<S, M>>,
246+
global_pool: GlobalRegistry,
242247
) -> (Self, ProgressEngine<S, M>) {
243248
let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel();
244249
let (ctrl_tx, ctrl_rx) = tokio::sync::mpsc::unbounded_channel();
245250

246-
let progress_engine =
247-
ProgressEngine::<S, M>::new(event_manager, priority_rx, ctrl_rx, cancel_token, blocks);
251+
let progress_engine = ProgressEngine::<S, M>::new(
252+
event_manager,
253+
priority_rx,
254+
ctrl_rx,
255+
cancel_token,
256+
blocks,
257+
global_pool,
258+
);
248259

249260
(
250261
Self {
@@ -468,9 +479,9 @@ mod tests {
468479
self,
469480
) -> anyhow::Result<(BlockPool<S, M>, ProgressEngine<S, M>)> {
470481
let args = self.build_internal()?;
471-
let (event_manager, cancel_token, blocks) = args.dissolve();
482+
let (event_manager, cancel_token, blocks, global_pool) = args.dissolve();
472483
let (pool, progress_engine) =
473-
BlockPool::with_progress_engine(event_manager, cancel_token, blocks);
484+
BlockPool::with_progress_engine(event_manager, cancel_token, blocks, global_pool);
474485

475486
Ok((pool, progress_engine))
476487
}

lib/llm/src/block_manager/pool/inactive.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
138138
block.reset();
139139
self.uninitialized_set.push_back(block);
140140
}
141-
BlockState::Registered(state) => {
141+
BlockState::Registered(state, _) => {
142142
let sequence_hash = state.sequence_hash();
143143
self.insert_with_sequence_hash(block, sequence_hash);
144144
}
@@ -499,6 +499,8 @@ pub(crate) mod tests {
499499

500500
use super::*;
501501

502+
use std::sync::Mutex;
503+
502504
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
503505
pub struct TestMetadata {
504506
priority: u32,
@@ -610,7 +612,10 @@ pub(crate) mod tests {
610612
let mut blocks = create_block_collection(num_blocks).into_blocks().unwrap();
611613

612614
let event_manager = NullEventManager::new();
613-
let mut registry = BlockRegistry::new(event_manager);
615+
let mut registry = BlockRegistry::new(
616+
event_manager,
617+
GlobalRegistry::new(Mutex::new(HashMap::new())),
618+
);
614619

615620
// Iterate through the generated TokenBlocks and the template Blocks,
616621
// setting the state and registering each one.
@@ -652,7 +657,7 @@ pub(crate) mod tests {
652657
let matched_block_count = matched_blocks.len();
653658

654659
let event_manager = NullEventManager::new();
655-
let mut registry = BlockRegistry::new(event_manager);
660+
let mut registry = BlockRegistry::new(event_manager, Arc::new(Mutex::new(HashMap::new())));
656661

657662
// all matched blocks should be in the complete or registered state
658663
for block in &mut matched_blocks {

0 commit comments

Comments
 (0)