Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions lib/llm/src/block_manager/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ pub mod view;
pub use crate::tokens::TokenBlockError;
pub use anyhow::Result;
use nixl_sys::NixlDescriptor;

pub use registry::{GlobalRegistry, RegistrationHandle};
pub use state::{BlockState, BlockStateInvalid};

use crate::block_manager::{
Expand Down Expand Up @@ -176,7 +178,7 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
pub fn sequence_hash(&self) -> Result<SequenceHash, BlockError> {
match self.state() {
BlockState::Complete(state) => Ok(state.token_block().sequence_hash()),
BlockState::Registered(state) => Ok(state.sequence_hash()),
BlockState::Registered(state, _) => Ok(state.sequence_hash()),
_ => Err(BlockError::InvalidState(
"Block is not complete".to_string(),
)),
Expand Down Expand Up @@ -250,14 +252,14 @@ pub(crate) trait PrivateBlockExt {
fn register(
&mut self,
registry: &mut registry::BlockRegistry,
) -> Result<PublishHandle, registry::BlockRegistationError>;
) -> Result<Option<PublishHandle>, registry::BlockRegistationError>;
}

impl<S: Storage, M: BlockMetadata> PrivateBlockExt for Block<S, M> {
fn register(
&mut self,
registry: &mut registry::BlockRegistry,
) -> Result<PublishHandle, registry::BlockRegistationError> {
) -> Result<Option<PublishHandle>, registry::BlockRegistationError> {
registry.register_block(&mut self.state)
}
}
Expand Down
167 changes: 131 additions & 36 deletions lib/llm/src/block_manager/block/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//! # KV Cache Block Registration
//!
//! - This module is responsible for maintaining a registry of all blocks currently within a pool.
//! This consists of two components: A global registry of all blocks, and a per-pool registry of blocks.
//! - The global registry is a mapping of sequences hashes to registration handles. If two blocks in different pools
//! have the same sequence hash, then they will share the same registration handle. The global registry is shared across all pools.
//! - The per-pool registry is a mapping of sequence hashes to block handles. This is used to track which blocks are
//! currently within a specific pool. The block handle is unique across pools, and is used to track the block's lifetime.
//! - When a block is in the registered state, it has a unique block handle and a possibly shared registration handle.
//!
//! ## Workflow
//!
//! 1. When a block is registered into a pool, we create a unique block handle.
//! 2. We then check the global registry to see if the block already exists in any other pool.
//! 3. If it does, we use the existing registration handle. Otherwise, we create a new one.
//! 4. When the block handle is dropped, it means that the block is no longer in the pool.
//! 5. When the registration handle is dropped, it means that the block is no longer in any pool.

use std::{
collections::HashMap,
sync::{Arc, Weak},
sync::{Arc, Mutex, Weak},
};

use super::super::events::{EventManager, EventReleaseManager, PublishHandle};
Expand All @@ -24,6 +42,9 @@ use super::state::BlockState;
use crate::tokens::{BlockHash, SequenceHash, TokenBlock};

use derive_getters::Getters;
use tokio::{runtime::Handle, sync::mpsc};

pub type GlobalRegistry = Arc<Mutex<HashMap<SequenceHash, Weak<RegistrationHandle>>>>;

#[derive(Debug, thiserror::Error)]
pub enum BlockRegistationError {
Expand All @@ -34,27 +55,88 @@ pub enum BlockRegistationError {
InvalidState(String),
}

/// Error returned when an attempt is made to unregister a block that is still active.
#[derive(Debug, thiserror::Error)]
#[error("Failed to unregister block: {0}")]
pub struct UnregisterFailure(SequenceHash);
/// A block entry is a handle to a block that is registered in the pool.
/// On drop, we need to notify the pool that the block has been unregistered.
/// This is different than the registration handle, which is only dropped when the block is no longer in ANY pool.
#[derive(Debug)]
pub struct BlockHandle {
sequence_hash: SequenceHash,
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
}

impl BlockHandle {
pub fn new(
sequence_hash: SequenceHash,
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
) -> Self {
Self {
sequence_hash,
unregister_tx,
}
}
}

impl Drop for BlockHandle {
fn drop(&mut self) {
let _ = self.unregister_tx.send(self.sequence_hash);
}
}

#[derive()]
pub struct BlockRegistry {
blocks: HashMap<SequenceHash, Weak<RegistrationHandle>>,
blocks: Arc<Mutex<HashMap<SequenceHash, Weak<BlockHandle>>>>,
event_manager: Arc<dyn EventManager>,
global_registry: GlobalRegistry,
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
}

impl BlockRegistry {
pub fn new(event_manager: Arc<dyn EventManager>) -> Self {
pub fn new(
event_manager: Arc<dyn EventManager>,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> Self {
let (unregister_tx, mut unregister_rx) = mpsc::unbounded_channel();

let blocks: Arc<Mutex<HashMap<SequenceHash, Weak<BlockHandle>>>> =
Arc::new(Mutex::new(HashMap::new()));

let blocks_clone = blocks.clone();
let global_registry_clone = global_registry.clone();
async_runtime.spawn(async move {
let blocks = blocks_clone;
let global_registry = global_registry_clone;
while let Some(sequence_hash) = unregister_rx.recv().await {
{
let mut blocks = blocks.lock().unwrap();

if let Some(handle) = blocks.get(&sequence_hash) {
if handle.upgrade().is_none() {
blocks.remove(&sequence_hash);
}
}
}

let mut global_registry = global_registry.lock().unwrap();

if let Some(entry) = global_registry.get(&sequence_hash) {
if entry.upgrade().is_none() {
global_registry.remove(&sequence_hash);
}
}
}
});

Self {
blocks: HashMap::new(),
blocks,
event_manager,
global_registry,
unregister_tx,
}
}

pub fn is_registered(&self, sequence_hash: SequenceHash) -> bool {
if let Some(handle) = self.blocks.get(&sequence_hash) {
let blocks = self.blocks.lock().unwrap();
if let Some(handle) = blocks.get(&sequence_hash) {
if let Some(_handle) = handle.upgrade() {
return true;
}
Expand All @@ -65,7 +147,7 @@ impl BlockRegistry {
pub fn register_block(
&mut self,
block_state: &mut BlockState,
) -> Result<PublishHandle, BlockRegistationError> {
) -> Result<Option<PublishHandle>, BlockRegistationError> {
match block_state {
BlockState::Reset => Err(BlockRegistationError::InvalidState(
"Block is in Reset state".to_string(),
Expand All @@ -76,47 +158,60 @@ impl BlockRegistry {

BlockState::Complete(state) => {
let sequence_hash = state.token_block().sequence_hash();
if let Some(handle) = self.blocks.get(&sequence_hash) {
let mut blocks = self.blocks.lock().unwrap();

// If an identical block already exists in this pool, return an error.
if let Some(handle) = blocks.get(&sequence_hash) {
if let Some(_handle) = handle.upgrade() {
return Err(BlockRegistationError::BlockAlreadyRegistered(sequence_hash));
}
}

// Create the [RegistrationHandle] and [PublishHandle]
let publish_handle =
Self::create_publish_handle(state.token_block(), self.event_manager.clone());
let reg_handle = publish_handle.remove_handle();
let mut publish_handle = None;

let block_handle =
Arc::new(BlockHandle::new(sequence_hash, self.unregister_tx.clone()));

let reg_handle = 'reg_block: {
// Now, check the global registry.
let mut global_registry = self.global_registry.lock().unwrap();

// If an identical block exists in other pool, use the same registration handle.
if let Some(handle) = global_registry.get(&sequence_hash) {
if let Some(handle) = handle.upgrade() {
break 'reg_block handle;
}
}

// Otherwise, create a new registration handle.
publish_handle = Some(Self::create_publish_handle(
state.token_block(),
self.event_manager.clone(),
));
let reg_handle = publish_handle.as_ref().unwrap().remove_handle();

// Insert the registration handle into the global registry.
global_registry.insert(sequence_hash, Arc::downgrade(&reg_handle));

// Insert the [RegistrationHandle] into the registry
self.blocks
.insert(sequence_hash, Arc::downgrade(&reg_handle));
reg_handle
};

blocks.insert(sequence_hash, Arc::downgrade(&block_handle));

// Update the [BlockState] to [BlockState::Registered]
let _ = std::mem::replace(block_state, BlockState::Registered(reg_handle));
let _ = std::mem::replace(
block_state,
BlockState::Registered(reg_handle, block_handle),
);

Ok(publish_handle)
}
BlockState::Registered(registered) => Err(
BlockState::Registered(registered, _) => Err(
BlockRegistationError::BlockAlreadyRegistered(registered.sequence_hash()),
),
}
}

pub fn unregister_block(
&mut self,
sequence_hash: SequenceHash,
) -> Result<(), UnregisterFailure> {
if let Some(handle) = self.blocks.get(&sequence_hash) {
if handle.upgrade().is_none() {
self.blocks.remove(&sequence_hash);
return Ok(());
} else {
return Err(UnregisterFailure(sequence_hash));
}
}
Ok(())
}

fn create_publish_handle(
token_block: &TokenBlock,
event_manager: Arc<dyn EventManager>,
Expand Down
16 changes: 8 additions & 8 deletions lib/llm/src/block_manager/block/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::sync::Arc;

use derive_getters::Getters;

use super::registry::RegistrationHandle;
use super::registry::{BlockHandle, RegistrationHandle};
use super::Result;
use crate::tokens::{PartialTokenBlock, SaltHash, Token, TokenBlock, Tokens};

Expand All @@ -30,7 +30,7 @@ pub enum BlockState {
Reset,
Partial(PartialState),
Complete(CompleteState),
Registered(Arc<RegistrationHandle>),
Registered(Arc<RegistrationHandle>, Arc<BlockHandle>),
}

impl BlockState {
Expand Down Expand Up @@ -109,7 +109,7 @@ impl BlockState {
BlockState::Reset => Some(0),
BlockState::Partial(state) => Some(state.block.len()),
BlockState::Complete(state) => Some(state.token_block.tokens().len()),
BlockState::Registered(_) => None,
BlockState::Registered(_, _) => None,
}
}

Expand All @@ -126,15 +126,15 @@ impl BlockState {
match self {
BlockState::Reset => true,
BlockState::Partial(state) => state.block.is_empty(),
BlockState::Complete(_) => false, // Always full
BlockState::Registered(_) => false, // Always full
BlockState::Complete(_) => false, // Always full
BlockState::Registered(_, _) => false, // Always full
}
}

/// Returns a reference to the underlying TokenBlock if the state is Complete or Registered.
pub fn tokens(&self) -> Option<&Tokens> {
match self {
BlockState::Reset | BlockState::Registered(_) => None,
BlockState::Reset | BlockState::Registered(_, _) => None,
BlockState::Partial(state) => Some(state.block.tokens()),
BlockState::Complete(state) => Some(state.token_block.tokens()),
}
Expand All @@ -147,12 +147,12 @@ impl BlockState {

/// Returns true if the block is in the complete or registered state
pub fn is_complete(&self) -> bool {
matches!(self, BlockState::Complete(_) | BlockState::Registered(_))
matches!(self, BlockState::Complete(_) | BlockState::Registered(_, _))
}

/// Returns true if the block is in the registered state
pub fn is_registered(&self) -> bool {
matches!(self, BlockState::Registered(_state))
matches!(self, BlockState::Registered(_state, _))
}
}

Expand Down
8 changes: 4 additions & 4 deletions lib/llm/src/block_manager/offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
priority: u64,
) -> core::result::Result<(), BlockPoolError> {
match block.state() {
BlockState::Registered(_) => {}
BlockState::Registered(_, _) => {}
_ => {
return Err(BlockPoolError::BlockError(BlockError::InvalidState(
"Block is not registered.".to_string(),
Expand Down Expand Up @@ -397,7 +397,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
) -> BlockResult<DeviceStorage, Metadata> {
for block in &blocks {
match block.state() {
BlockState::Registered(_) => {}
BlockState::Registered(_, _) => {}
_ => {
return Err(BlockPoolError::BlockError(BlockError::InvalidState(
"Block is not registered.".to_string(),
Expand Down Expand Up @@ -857,7 +857,7 @@ mod tests {
// Check that the block is registered.
assert!(matches!(
onboarded_blocks[0].state(),
BlockState::Registered(_)
BlockState::Registered(_, _)
));

check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?;
Expand Down Expand Up @@ -940,7 +940,7 @@ mod tests {
);
assert!(matches!(
onboarded_blocks[0].state(),
BlockState::Registered(_)
BlockState::Registered(_, _)
));

check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?;
Expand Down
2 changes: 1 addition & 1 deletion lib/llm/src/block_manager/offload/pending.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>(
target: &mut MutableBlock<Target, Metadata>,
) -> Result<()> {
// Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail.
if let BlockState::Registered(reg_handle) = source.state() {
if let BlockState::Registered(reg_handle, _) = source.state() {
// Bring the block back to the 'Reset' state.
target.reset();
// Transfer metadata.
Expand Down
Loading
Loading