Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lib/bindings/python/rust/llm/block_manager/vllm/slot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ impl<S: Storage, L: LocalityProvider> Slot<S, L> {
// apply the token blocks to the mutable blocks
for (mut mutable_block, token_block) in zipped_blocks {
mutable_block
.state_mut()
.apply_token_block(token_block.clone())
.map_err(|e| {
SlotError::from_str(&format!("failed to apply token block: {:?}", e))
Expand Down
237 changes: 194 additions & 43 deletions lib/llm/src/block_manager/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ pub struct Block<S: Storage, L: LocalityProvider, M: BlockMetadata> {
metadata: M,
state: BlockState,
manager: Option<Arc<BlockManager<L, M>>>,
/// If this is Some, then the current block is a duplicate holding a reference to the primary
registered_block: Option<Arc<MutableBlock<S, L, M>>>,
}

impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Block<S, L, M> {
Expand All @@ -208,11 +210,12 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Block<S, L, M> {
metadata,
state: BlockState::Reset,
manager: None,
registered_block: None,
})
}

pub fn sequence_hash(&self) -> Result<SequenceHash, BlockError> {
match self.state() {
match &self.state {
BlockState::Complete(state) => Ok(state.token_block().sequence_hash()),
BlockState::Registered(state, _) => Ok(state.sequence_hash()),
_ => Err(BlockError::InvalidState(
Expand All @@ -221,8 +224,52 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Block<S, L, M> {
}
}

/// Returns true if this block is a duplicate (holds a reference to a primary block)
pub fn is_duplicate(&self) -> bool {
self.registered_block.is_some()
}

/// Marks this block as a duplicate, storing a reference to the primary block.
/// The block must be in Complete state and have matching sequence hash with the primary.
pub(crate) fn mark_as_duplicate(
&mut self,
primary_block: Arc<MutableBlock<S, L, M>>,
) -> Result<(), BlockError> {
// Validate the block is in Complete state
match &self.state {
BlockState::Complete(state) => {
let sequence_hash = state.token_block().sequence_hash();
if sequence_hash != primary_block.sequence_hash()? {
return Err(BlockError::InvalidState(
"duplicate blocks must have the same sequence hash".to_string(),
));
}
}
_ => {
return Err(BlockError::InvalidState(
"duplicate blocks must be in the complete state on creation".to_string(),
));
}
}

// Ensure this block isn't already a duplicate
if self.is_duplicate() {
return Err(BlockError::InvalidState(
"block is already marked as duplicate".to_string(),
));
}

self.registered_block = Some(primary_block);
Ok(())
}

/// Detaches and returns the registered block reference if this is a duplicate
pub(crate) fn detach_registered_block(&mut self) -> Option<Arc<MutableBlock<S, L, M>>> {
self.registered_block.take()
}

pub fn parent_sequence_hash(&self) -> Result<Option<SequenceHash>, BlockError> {
match self.state() {
match &self.state {
BlockState::Complete(state) => Ok(state.token_block().parent_sequence_hash()),
BlockState::Registered(state, _) => Ok(state.parent_sequence_hash()),
_ => Err(BlockError::InvalidState(
Expand All @@ -235,6 +282,34 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Block<S, L, M> {
pub fn reset(&mut self) {
self.state = BlockState::Reset;
self.metadata.reset_metadata();
self.registered_block = None;
}

/// Returns true if the block is in the reset state
pub fn is_reset(&self) -> bool {
matches!(self.state, BlockState::Reset)
}

/// Returns true if the block is in the complete or registered state
pub fn is_complete(&self) -> bool {
matches!(
self.state,
BlockState::Complete(_) | BlockState::Registered(_, _)
)
}

/// Returns true if the block is in the registered state
pub fn is_registered(&self) -> bool {
matches!(self.state, BlockState::Registered(_, _))
}

/// Get the registration handle if this block is registered
pub fn registration_handle(&self) -> Option<Arc<RegistrationHandle>> {
if let BlockState::Registered(ref handle, _) = self.state {
Some(handle.clone())
} else {
None
}
}

/// Initialize a sequence on the block using a [SaltHash]
Expand Down Expand Up @@ -352,16 +427,6 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Block<S, L, M> {
self.state = state;
}

/// Get a reference to the state of the block
pub fn state(&self) -> &BlockState {
&self.state
}

/// Get a mutable reference to the state of the block
pub fn state_mut(&mut self) -> &mut BlockState {
&mut self.state
}

/// Get the number of blocks in the block
/// todo(ryan): validate this can be removed
pub fn num_blocks(&self) -> usize {
Expand Down Expand Up @@ -635,6 +700,20 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> MutableBlock<S, L, M> {
pub fn set_parent(&mut self, parent: Arc<MutableBlock<S, L, M>>) {
self.parent = Some(parent);
}

/// Marks the underlying block as a duplicate with reference to the primary block
pub(crate) fn mark_as_duplicate(
&mut self,
primary_block: Arc<MutableBlock<S, L, M>>,
) -> Result<(), BlockError> {
if let Some(ref mut block) = self.block {
block.mark_as_duplicate(primary_block)
} else {
Err(BlockError::InvalidState(
"MutableBlock has no underlying block".to_string(),
))
}
}
}

impl<S: Storage, L: LocalityProvider, M: BlockMetadata> std::fmt::Debug for MutableBlock<S, L, M> {
Expand Down Expand Up @@ -808,17 +887,25 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ImmutableBlock<S, L, M>
}
}

/// Attempts to add a duplicate block to the ImmutableBlock.
pub(crate) fn with_duplicate(
self,
duplicate: Arc<MutableBlock<S, L, M>>,
/// Creates an ImmutableBlock from a duplicate block attached to an existing primary.
/// This makes the relationship clear: we're creating a block representation from a duplicate
/// by attaching it to the existing primary ImmutableBlock.
pub(crate) fn from_duplicate(
mut duplicate: MutableBlock<S, L, M>,
primary: ImmutableBlock<S, L, M>,
) -> Result<Self, BlockError> {
if self.duplicate.is_some() {
// Validate that the primary is not itself a duplicate
if primary.duplicate.is_some() {
return Err(BlockError::IncompatibleImmutableBlock);
}

// Mark the duplicate block as a duplicate with reference to the primary
duplicate.mark_as_duplicate(primary.block.clone())?;

Ok(Self {
duplicate: Some(duplicate),
..self
block: primary.block,
sequence_hash: primary.sequence_hash,
duplicate: Some(Arc::new(duplicate)),
})
}

Expand All @@ -830,19 +917,65 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ImmutableBlock<S, L, M>
self.sequence_hash
}

/// Returns true if this block is registered (either as primary or duplicate)
pub fn is_registered(&self) -> bool {
// If this is a duplicate, check if it has a registered primary
if let Some(ref duplicate_arc) = self.duplicate
&& let Some(ref block) = duplicate_arc.block
&& let Some(ref registered_block) = block.registered_block
{
// Check the primary's state
return registered_block.is_registered();
}
// Otherwise check the primary block's state
self.block.is_registered()
}

/// Returns true if this is a duplicate block
pub fn is_duplicate(&self) -> bool {
self.duplicate.is_some()
}

/// Returns true if this block is in reset state
pub fn is_reset(&self) -> bool {
// Duplicates are never reset (they're Complete with a registered primary)
if self.is_duplicate() {
false
} else {
self.block.is_reset()
}
}

/// Returns true if this block is complete (including registered)
pub fn is_complete(&self) -> bool {
// Duplicates present as registered (which is a form of complete)
if self.is_duplicate() {
true
} else {
self.block.is_complete()
}
}

/// Get the registration handle if this block is registered
pub fn registration_handle(&self) -> Option<Arc<RegistrationHandle>> {
// If this is a duplicate, check the primary's registration
if let Some(ref duplicate_arc) = self.duplicate
&& let Some(ref block) = duplicate_arc.block
&& let Some(ref registered_block) = block.registered_block
{
return registered_block.registration_handle();
}
// Check if the primary is registered
self.block.registration_handle()
}

/// If the ImmutableBlock is a duplicate, returns the block ID of the duplicate;
/// otherwise, returns the block ID of the primary block.
pub fn block_id(&self) -> BlockId {
self.duplicate
.as_ref()
.map_or(self.block.block_id(), |duplicate| duplicate.block_id())
}

/// Returns true if the ImmutableBlock holds a duplicate block.
#[allow(unused)]
pub(crate) fn is_duplicate(&self) -> bool {
self.duplicate.is_some()
}
}

impl<S: Storage, L: LocalityProvider, M: BlockMetadata> StorageTypeProvider
Expand Down Expand Up @@ -925,20 +1058,38 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> MaybeReturnableBlock<S,
}
}

fn try_take_block(mut self, token: private::PrivateToken) -> Option<Vec<Block<S, L, M>>> {
let blocks = [
Arc::try_unwrap(self.block).ok(),
self.duplicate
.take()
.and_then(|duplicate| Arc::try_unwrap(duplicate).ok()),
];
fn try_take_block(mut self, _token: private::PrivateToken) -> Option<Vec<Block<S, L, M>>> {
let mut blocks = Vec::new();

// If we have a duplicate, we need to handle it specially
if let Some(duplicate_arc) = self.duplicate.take() {
// First unwrap the duplicate and detach its reference to the primary
if let Ok(mut duplicate_mutable) = Arc::try_unwrap(duplicate_arc) {
// Extract the duplicate block
if let Some(mut duplicate_block) = duplicate_mutable.block.take() {
// The duplicate block holds a reference to the primary
// Detach it so we can unwrap the primary
duplicate_block.detach_registered_block();
// Add the duplicate block
blocks.push(duplicate_block);
}
}

let blocks = blocks
.into_iter()
.flatten()
.filter_map(|block| block.try_take_block(token))
.flatten()
.collect::<Vec<_>>();
// Now try to unwrap the primary from self.block
// This should work now that the duplicate's reference has been detached
if let Ok(mut primary_mutable) = Arc::try_unwrap(self.block)
&& let Some(primary_block) = primary_mutable.block.take()
{
blocks.push(primary_block);
}
} else {
// No duplicate, just try to unwrap the primary normally
if let Ok(mut primary_mutable) = Arc::try_unwrap(self.block)
&& let Some(primary_block) = primary_mutable.block.take()
{
blocks.push(primary_block);
}
}

if blocks.is_empty() {
None
Expand Down Expand Up @@ -1403,7 +1554,7 @@ mod tests {
#[test]
fn test_block_state_transitions_and_ops() {
let mut block = create_reset_block();
assert!(matches!(block.state(), BlockState::Reset));
assert!(block.is_reset());

// --- Reset State --- //
assert!(block.add_token(1).is_err(), "Append on Reset should fail");
Expand All @@ -1420,7 +1571,7 @@ mod tests {

// --- Reset -> Partial (via init_sequence) --- //
assert!(block.init_sequence(SALT_HASH).is_ok());
assert!(matches!(block.state(), BlockState::Partial(_)));
assert!(!block.is_reset() && !block.is_complete() && !block.is_registered());

// --- Partial State --- //
let invalid_block = create_full_token_block();
Expand Down Expand Up @@ -1469,7 +1620,7 @@ mod tests {

// --- Partial -> Complete (via commit) --- //
assert!(block.commit().is_ok());
assert!(matches!(block.state(), BlockState::Complete(_)));
assert!(block.is_complete());
assert_eq!(block.tokens().unwrap().as_ref(), &[1, 2, 3, 4]);

// --- Complete State --- //
Expand Down Expand Up @@ -1499,12 +1650,12 @@ mod tests {

// --- Complete -> Reset (via reset) --- //
block.reset();
assert!(matches!(block.state(), BlockState::Reset));
assert!(block.is_reset());

// --- Reset -> Complete (via apply_token_block) --- //
let full_block = create_full_token_block();
assert!(block.apply_token_block(full_block.clone()).is_ok());
assert!(matches!(block.state(), BlockState::Complete(_)));
assert!(block.is_complete());
let applied_tokens = block.tokens().unwrap();
assert_eq!(applied_tokens, full_block.tokens());

Expand Down
Loading
Loading