diff --git a/lib/llm/src/lib.rs b/lib/llm/src/lib.rs index 2bf3d804d0..122a1d6e66 100644 --- a/lib/llm/src/lib.rs +++ b/lib/llm/src/lib.rs @@ -28,6 +28,7 @@ pub mod hub; pub mod key_value_store; pub mod kv_router; pub use kv_router::DEFAULT_KV_BLOCK_SIZE; +pub mod mocker; pub mod model_card; pub mod model_type; pub mod preprocessor; diff --git a/lib/llm/src/mocker.rs b/lib/llm/src/mocker.rs new file mode 100644 index 0000000000..2a9e63a9e2 --- /dev/null +++ b/lib/llm/src/mocker.rs @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod evictor; +pub mod kv_manager; +pub mod protocols; +pub mod scheduler; +pub mod sequence; diff --git a/lib/llm/src/mocker/evictor.rs b/lib/llm/src/mocker/evictor.rs new file mode 100644 index 0000000000..47a312eede --- /dev/null +++ b/lib/llm/src/mocker/evictor.rs @@ -0,0 +1,191 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
/// An LRU evictor that remembers when each object was last inserted and hands
/// objects back oldest-first.
///
/// Bookkeeping is "lazy":
/// 1. `free_table` is the source of truth: it maps every live object to the
///    timestamp of its most recent insertion.
/// 2. `priority_queue` records every insertion in order (oldest at the front).
///    Updates and removals never touch the queue, so it may contain stale
///    entries whose timestamp no longer matches `free_table`.
/// 3. Stale entries are skipped during eviction and purged by a cleanup pass
///    once the queue grows beyond `cleanup_threshold` times the number of
///    live objects.
///
/// Timestamps handed out by [`LRUEvictor::insert`] are strictly increasing, so
/// a refreshed object can always be told apart from its older queue entries
/// even when the clock reports the same reading twice.
#[derive(Debug)]
pub struct LRUEvictor<T> {
    /// Live objects mapped to the timestamp of their latest insertion.
    free_table: HashMap<T, f64>,
    /// Insertion log, oldest first; may contain stale entries.
    priority_queue: VecDeque<(T, f64)>,
    /// Queue-to-table size ratio above which stale entries are purged.
    cleanup_threshold: usize,
    /// Reference point for all timestamps.
    start_time: Instant,
    /// Last timestamp handed out by `insert`; used to keep them strictly increasing.
    last_timestamp: f64,
}

impl<T> Default for LRUEvictor<T> {
    /// Creates an empty evictor with a cleanup threshold of 50.
    fn default() -> Self {
        Self {
            free_table: HashMap::new(),
            priority_queue: VecDeque::new(),
            cleanup_threshold: 50,
            start_time: Instant::now(),
            // Smaller than any clock reading, so the first insert uses the clock as-is.
            last_timestamp: f64::NEG_INFINITY,
        }
    }
}

impl<T: Clone + Eq + Hash> LRUEvictor<T> {
    /// Creates an empty evictor that purges stale queue entries once the queue
    /// holds more than `cleanup_threshold` entries per live object.
    pub fn new(cleanup_threshold: usize) -> Self {
        Self {
            cleanup_threshold,
            ..Default::default()
        }
    }

    /// Returns the seconds elapsed since this evictor was created.
    pub fn current_timestamp(&self) -> f64 {
        self.start_time.elapsed().as_secs_f64()
    }

    /// Returns an iterator over the objects currently held (in arbitrary order).
    pub fn keys(&self) -> std::collections::hash_map::Keys<'_, T, f64> {
        self.free_table.keys()
    }

    /// Inserts `object`, or refreshes it if it is already present, making it
    /// the most recently used entry.
    pub fn insert(&mut self, object: T) {
        let timestamp = self.next_timestamp();
        self.insert_at(object, timestamp);
    }

    /// Returns `true` if `object` is currently held by the evictor.
    pub fn contains(&self, object: &T) -> bool {
        self.free_table.contains_key(object)
    }

    /// Removes and returns the least recently inserted object, or `None` if
    /// the evictor holds no objects.
    pub fn evict(&mut self) -> Option<T> {
        if self.free_table.is_empty() {
            return None;
        }

        while let Some((object, last_accessed)) = self.priority_queue.pop_front() {
            // An entry is live only if it carries the object's latest timestamp;
            // anything else was removed or refreshed after being queued.
            if self.free_table.get(&object) == Some(&last_accessed) {
                self.free_table.remove(&object);
                return Some(object);
            }
        }

        None
    }

    /// Removes `object` from the evictor, returning whether it was present.
    ///
    /// The queue is left untouched on purpose: scanning it here would be
    /// linear, and the stale entry is filtered out during eviction or cleanup.
    pub fn remove(&mut self, object: &T) -> bool {
        self.free_table.remove(object).is_some()
    }

    /// Returns the number of objects currently held.
    pub fn len(&self) -> usize {
        self.free_table.len()
    }

    /// Returns `true` if no objects are currently held.
    pub fn is_empty(&self) -> bool {
        self.free_table.is_empty()
    }

    /// Returns a timestamp that is strictly greater than every timestamp
    /// previously handed out, preferring the real clock reading.
    fn next_timestamp(&mut self) -> f64 {
        let now = self.current_timestamp();
        let timestamp = if now > self.last_timestamp {
            now
        } else {
            // The clock did not advance past the last issued timestamp. That
            // timestamp is a non-negative finite float, so the next bit
            // pattern is the next larger representable value.
            f64::from_bits(self.last_timestamp.to_bits() + 1)
        };
        self.last_timestamp = timestamp;
        timestamp
    }

    /// Records `object` with the given timestamp and purges stale queue
    /// entries if the queue has grown too large.
    fn insert_at(&mut self, object: T, last_accessed: f64) {
        self.free_table.insert(object.clone(), last_accessed);
        self.priority_queue.push_back((object, last_accessed));
        self.cleanup_if_necessary();
    }

    /// Runs a cleanup pass when the queue exceeds
    /// `cleanup_threshold * number_of_live_objects` entries.
    fn cleanup_if_necessary(&mut self) {
        // Saturate so that a very large threshold disables cleanup instead of overflowing.
        let limit = self.cleanup_threshold.saturating_mul(self.free_table.len());
        if self.priority_queue.len() > limit {
            self.cleanup();
        }
    }

    /// Drops every queue entry that no longer matches `free_table`, keeping
    /// the relative order of the remaining entries.
    fn cleanup(&mut self) {
        let free_table = &self.free_table;
        self.priority_queue
            .retain(|(object, timestamp)| free_table.get(object) == Some(timestamp));
    }
}
b/lib/llm/src/mocker/kv_manager.rs @@ -0,0 +1,427 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! # KV Manager +//! A synchronous implementation of a block manager that handles MoveBlock signals for caching KV blocks. +//! +//! ## Block Operations +//! The KV manager processes four types of MoveBlock signals: +//! +//! ### Use +//! - Checks if block exists in active pool → increment reference count +//! - If in inactive pool → move to active pool +//! - If neither → try evicting from inactive pool to make room +//! - If inactive pool is empty → pre-empt the oldest running request +//! +//! ### Destroy +//! - Removes the block from the active pool +//! +//! ### Deref +//! - Decrements reference count of a block in active pool +//! - If count reaches zero → move block to inactive pool +//! +//! ### Promote +//! - Converts a partial block (uuid) into a full block (global block hash) +//! +//! ## Preemption +//! If a Use operation fails (typically due to insufficient space), a false boolean signal +//! is returned to the scheduler for preemption. Initial KV block allocations for new requests +//! should not fail due to the watermark checking. +//! +//! ## NOTE +//! For simplicity (or non-simplicity), reference counting is tracked manually instead of using +//! 
the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror +//! implementation of the main block manager. + +use crate::mocker::evictor::LRUEvictor; +use crate::mocker::protocols::{MoveBlock, PrefillCost, UniqueBlock}; +use crate::mocker::sequence::ActiveSequence; +use derive_getters::Getters; +use std::collections::{HashMap, HashSet}; + +#[derive(Getters)] +pub struct KvManager { + #[getter(copy)] + max_capacity: usize, + + #[getter(copy)] + block_size: usize, + + active_blocks: HashMap, + + inactive_blocks: LRUEvictor, + + all_blocks: HashSet, +} + +impl KvManager { + pub fn new(max_capacity: usize, block_size: usize) -> Self { + let active_blocks = HashMap::new(); + let inactive_blocks = LRUEvictor::default(); + let all_blocks = HashSet::new(); + + KvManager { + max_capacity, + block_size, + active_blocks, + inactive_blocks, + all_blocks, + } + } + + /// Process a MoveBlock instruction synchronously + pub fn process(&mut self, event: &MoveBlock) -> bool { + match event { + MoveBlock::Use(hashes, _) => { + for hash in hashes { + // First check if it already exists in active blocks + if let Some(ref_count) = self.active_blocks.get_mut(hash) { + // Block already active, just increment reference count + *ref_count += 1; + continue; + } + + // Then check if it exists in inactive and move it to active if found + if self.inactive_blocks.remove(hash) { + // Insert into active with reference count 1 + self.active_blocks.insert(hash.clone(), 1); + continue; + } + + // Get counts for capacity check + let active_count = self.active_blocks.len(); + let inactive_count = self.inactive_blocks.len(); + + // If at max capacity, evict the oldest entry from inactive blocks + if active_count + inactive_count >= self.max_capacity { + if let Some(evicted) = self.inactive_blocks.evict() { + // Remove evicted block from all_blocks + self.all_blocks.remove(&evicted); + } else { + // Cannot evict block, meaning no free blocks left in inactive pool + // 
Send a signal, scheduler would expect to handle preemption upon receiving this + return false; + } + } + + // Now insert the new block in active blocks with reference count 1 + self.active_blocks.insert(hash.clone(), 1); + // Add to all_blocks as it's a new block + self.all_blocks.insert(hash.clone()); + } + } + MoveBlock::Destroy(hashes) => { + // Loop in inverse direction + for hash in hashes.iter().rev() { + self.active_blocks.remove(hash).unwrap(); + // Remove from all_blocks when destroyed + assert!(self.all_blocks.remove(hash)); + } + } + MoveBlock::Deref(hashes) => { + // Loop in inverse direction + for hash in hashes.iter().rev() { + // Decrement reference count and check if we need to move to inactive + if let Some(ref_count) = self.active_blocks.get_mut(hash) { + if *ref_count == 0 { + panic!("Negative reference count would be encountered after Deref."); + } + *ref_count -= 1; + + // If reference count reaches zero, remove from active and move to inactive + if *ref_count == 0 { + self.active_blocks.remove(hash); + // Use the LRUEvictor's timing functionality + self.inactive_blocks.insert(hash.clone()); + } + } + } + } + MoveBlock::Promote(uuid, hash) => { + let uuid_block = UniqueBlock::PartialBlock(*uuid); + let hash_block = UniqueBlock::FullBlock(*hash); + + let Some(ref_count) = self.active_blocks.remove(&uuid_block) else { + let in_all_blocks = self.all_blocks.contains(&uuid_block); + panic!( + "Missing active block for promotion: {:?}. 
Block still exists: {}", + uuid_block, in_all_blocks + ); + }; + + // Replace with hash block, keeping the same reference count + self.active_blocks.insert(hash_block.clone(), ref_count); + + // Update all_blocks + assert!(self.all_blocks.remove(&uuid_block)); + self.all_blocks.insert(hash_block); + } + } + + // Return true if we made it this far + true + } + + /// Get the count of blocks in the input list that aren't in all_blocks + pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize { + blocks + .iter() + .filter(|&block| !self.all_blocks.contains(block)) + .count() + } + + /// Get the current capacity (active blocks + inactive blocks) + pub fn current_capacity(&self) -> usize { + let active = self.active_blocks.len(); + let inactive = self.inactive_blocks.len(); + active + inactive + } + + /// Get the current capacity as a percentage of the maximum capacity + pub fn current_capacity_perc(&self) -> f64 { + let current = self.current_capacity() as f64; + current / self.max_capacity as f64 + } + + /// Get the number of active blocks + pub fn num_active_blocks(&self) -> usize { + self.active_blocks.len() + } + + /// Get the number of inactive blocks + pub fn num_inactive_blocks(&self) -> usize { + self.inactive_blocks.len() + } + + /// Get the keys of inactive blocks + pub fn get_inactive_blocks(&self) -> Vec<&UniqueBlock> { + self.inactive_blocks.keys().collect() + } + + /// Get the keys of active blocks + pub fn get_active_blocks(&self) -> Vec<&UniqueBlock> { + self.active_blocks.keys().collect() + } + + /// Check if a sequence can be scheduled and calculate cost if possible + pub fn try_schedule( + &self, + sequence: &ActiveSequence, + watermark: f64, + tokens_budget: usize, + ) -> Option { + // Return None immediately if tokens_budget is 0 + if tokens_budget == 0 { + return None; + } + + // Get unique blocks from the sequence + let unique_blocks = sequence.unique_blocks(); + + // Get the count of new blocks + let new_blocks = 
self.probe_new_blocks(unique_blocks); + + // Calculate current usage and available capacity + let active_count = self.active_blocks.len(); + + // Check if we can schedule based on the watermark + if (active_count + new_blocks) as f64 > (1.0 - watermark) * self.max_capacity as f64 { + return None; + } + + // Calculate overlap blocks + let overlap_blocks = unique_blocks.len() - new_blocks; + + // Calculate new tokens + let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size; + + // // Print the full equation with actual values substituted + // println!("{} = {} - ({} * {}) (new_tokens = num_input_tokens - overlap_blocks * block_size)", + // new_tokens, + // sequence.num_input_tokens(), + // overlap_blocks, + // self.block_size); + + // Return None if new_tokens exceeds tokens_budget + if new_tokens > tokens_budget { + return None; + } + + // Calculate prefill compute + let prefill_compute = + new_tokens as f64 * (new_tokens + overlap_blocks * self.block_size) as f64; + + Some(PrefillCost { + new_tokens, + prefill_compute, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_failure_on_max_capacity() { + // Create a KvManager with 10 blocks capacity + let mut manager = KvManager::new(10, 16); + + // Helper function to use multiple blocks that returns the response + fn use_blocks(manager: &mut KvManager, ids: Vec) -> bool { + let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect(); + manager.process(&MoveBlock::Use(blocks, None)) + } + + // First use 10 blocks (0 to 9) in a batch + let response = use_blocks(&mut manager, (0..10).collect()); + assert!(response, "Expected success response"); + + // Verify we are at capacity + assert_eq!(manager.current_capacity(), 10); + + // The 11th block should return false, not panic + let response = use_blocks(&mut manager, vec![10]); + assert!( + !response, + "Expected failure response when exceeding max capacity" + ); + } + + #[test] + // This is taken directly from the 
example in the vllm v1 prefix caching docs + fn test_block_lifecycle_stringent() { + // Create a KvManager with 10 blocks capacity + let mut manager = KvManager::new(10, 16); + + // Helper function to use multiple blocks + fn use_blocks(manager: &mut KvManager, ids: Vec) { + let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect(); + manager.process(&MoveBlock::Use(blocks, None)); + } + + // Helper function to destroy multiple blocks + fn destroy_blocks(manager: &mut KvManager, ids: Vec) { + let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect(); + manager.process(&MoveBlock::Destroy(blocks)); + } + + // Helper function to deref multiple blocks + fn deref_blocks(manager: &mut KvManager, ids: Vec) { + let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect(); + manager.process(&MoveBlock::Deref(blocks)); + } + + // Helper function to check if active blocks contain expected blocks with expected ref counts + fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) { + assert_eq!( + manager.active_blocks().len(), + expected_blocks.len(), + "Active blocks count doesn't match expected" + ); + + for &(id, ref_count) in expected_blocks { + let block = UniqueBlock::FullBlock(id); + assert!( + manager.active_blocks().contains_key(&block), + "Block {} not found in active blocks", + id + ); + assert_eq!( + manager.active_blocks().get(&block), + Some(&ref_count), + "Block {} has wrong reference count", + id + ); + } + } + + // Helper function to check if inactive blocks contain expected blocks + fn assert_inactive_blocks( + manager: &KvManager, + expected_size: usize, + expected_blocks: &[u64], + ) { + let inactive_blocks = manager.get_inactive_blocks(); + let inactive_blocks_count = manager.inactive_blocks().len(); + + assert_eq!( + inactive_blocks_count, expected_size, + "Inactive blocks count doesn't match expected" + ); + + for &id in expected_blocks { + let block = UniqueBlock::FullBlock(id); + assert!( + 
inactive_blocks.iter().any(|&b| *b == block), + "Block {} not found in inactive blocks", + id + ); + } + } + + // First use blocks 0, 1, 2, 3, 4 in a batch + use_blocks(&mut manager, (0..5).collect()); + + // Then use blocks 0, 1, 5, 6 in a batch + use_blocks(&mut manager, vec![0, 1, 5, 6]); + + // Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2 + assert_active_blocks( + &manager, + &[(0, 2), (1, 2), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1)], + ); + + // Now destroy block 4 + destroy_blocks(&mut manager, vec![4]); + + // And deref blocks 3, 2, 1, 0 in this order as a batch + deref_blocks(&mut manager, vec![0, 1, 2, 3]); + + // Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2 + assert_inactive_blocks(&manager, 2, &[3, 2]); + assert_active_blocks(&manager, &[(0, 1), (1, 1), (5, 1), (6, 1)]); + + // Now destroy block 6 + destroy_blocks(&mut manager, vec![6]); + + // And deref blocks 5, 1, 0 as a batch + deref_blocks(&mut manager, vec![0, 1, 5]); + + // Check that the inactive_blocks is size 5, and contains 0, 1, 2, 3, 5 + assert_inactive_blocks(&manager, 5, &[0, 1, 2, 3, 5]); + assert_active_blocks(&manager, &[]); + + // Now use 0, 1, 2, 7, 8, 9 as a batch + use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]); + + // Check that the inactive_blocks is size 2, and contains 3 and 5 + assert_inactive_blocks(&manager, 2, &[3, 5]); + assert_active_blocks(&manager, &[(0, 1), (1, 1), (2, 1), (7, 1), (8, 1), (9, 1)]); + + // Test the new_blocks method - only block 4 should be new out of [0,1,2,3,4] + let blocks_to_check: Vec = vec![0, 1, 2, 3, 4] + .into_iter() + .map(UniqueBlock::FullBlock) + .collect(); + assert_eq!(manager.probe_new_blocks(&blocks_to_check), 1); + + // Now use blocks 10, 11, 12 as a batch + use_blocks(&mut manager, vec![10, 11, 12]); + + // Check that the inactive_blocks is size 1 and contains only 5 + assert_inactive_blocks(&manager, 1, &[5]); + } +} diff --git a/lib/llm/src/mocker/protocols.rs 
b/lib/llm/src/mocker/protocols.rs new file mode 100644 index 0000000000..2b551db61b --- /dev/null +++ b/lib/llm/src/mocker/protocols.rs @@ -0,0 +1,93 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +pub type Token = u32; +pub type LocalBlockHash = u64; +/// A global hash identifier for blocks +pub type GlobalHash = u64; +pub type NumBlocks = usize; + +/// Represents an active block in the cache with a reference count +#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)] +pub enum UniqueBlock { + /// Block identified by UUID + PartialBlock(Uuid), + /// Block identified by hash + FullBlock(GlobalHash), +} + +impl Default for UniqueBlock { + fn default() -> Self { + // Generate a random UUID when default is used + Self::PartialBlock(Uuid::new_v4()) + } +} + +/// Represents different block movement operations in the cache +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum MoveBlock { + Use(Vec, Option), + Destroy(Vec), + Deref(Vec), + Promote(Uuid, GlobalHash), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DirectRequest { + pub tokens: Vec, + pub max_output_tokens: usize, + pub uuid: Option, +} + +/// Represents the cost of prefilling content in the cache +#[derive(Debug, Clone, Serialize, Deserialize)] 
+pub struct PrefillCost { + pub new_tokens: usize, + pub prefill_compute: f64, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_unique_block_default_uniqueness() { + // Create 10 default UniqueBlock instances + let blocks: Vec = (0..10).map(|_| UniqueBlock::default()).collect(); + + // Extract UUIDs from each block + let mut uuids = Vec::new(); + for block in blocks { + match block { + UniqueBlock::PartialBlock(uuid) => uuids.push(uuid), + _ => panic!("Expected UuidIdentifier variant"), + } + } + + // Check that all UUIDs are unique by comparing each with every other + for i in 0..uuids.len() { + for j in i + 1..uuids.len() { + assert_ne!( + uuids[i], uuids[j], + "UUID at index {} and {} are identical", + i, j + ); + } + } + } +} diff --git a/lib/llm/src/mocker/scheduler.rs b/lib/llm/src/mocker/scheduler.rs new file mode 100644 index 0000000000..e71647feab --- /dev/null +++ b/lib/llm/src/mocker/scheduler.rs @@ -0,0 +1,585 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Asynchronous Scheduler for LLM Request Management +//! +//! This module implements an asynchronous scheduler that handles three main functions: +//! 1. Receiving new requests and placing them in the waiting queue +//! 2. Scheduling waiting requests against available KV cache resources +//! 3. 
Simulating the execution of running requests with realistic timing +//! +//! ## Scheduling Process +//! The scheduler uses a watermark-based approach to determine if there's sufficient +//! KV cache space for new requests. It also enforces a batched tokens budget to prevent +//! oversubscription of computational resources. Only requests that can be allocated +//! these resources are moved from waiting to running state. +//! +//! ## Request Simulation +//! The simulation models two key phases: +//! - Prefill phase: Uses a quadratic cost function: (cached_tokens + new_tokens) * new_tokens +//! - Decode phase: Uses a cost function proportional to active KV blocks (linear) +//! +//! ## Resource Management +//! The scheduler communicates with the KvManager through MoveBlock signals at each +//! stage of request processing. When resources become constrained, it employs an +//! LRU-based preemption strategy where the oldest running request is evicted and +//! placed at the back of the waiting queue to be rescheduled later. +//! +//! ## NOTE +//! 
The current prefill and decoding time simulations are not scientific at all and are WIP + +use crate::kv_router::protocols::ForwardPassMetrics; +use crate::mocker::evictor::LRUEvictor; +use crate::mocker::kv_manager::KvManager; +use crate::mocker::protocols::DirectRequest; +use crate::mocker::protocols::{MoveBlock, PrefillCost, UniqueBlock}; +use crate::mocker::sequence::ActiveSequence; +use std::collections::HashMap; +use std::collections::VecDeque; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; +use tokio::time::{interval, Duration}; +use tokio_util::sync::CancellationToken; +use uuid::Uuid; + +/// Enum representing either a direct request or an active sequence +pub enum Request { + Direct(DirectRequest), + Active(ActiveSequence), +} + +#[derive(Default)] +struct SchedulerState { + waiting: VecDeque, + ready: VecDeque, + running: LRUEvictor, + requests: HashMap, + prefill_costs: HashMap>, +} + +impl SchedulerState { + /// Create a new UUID for a DirectRequest, add it to requests, and push the UUID to waiting. + fn receive(&mut self, request: DirectRequest) -> Uuid { + // Use the provided UUID if available, otherwise generate a new one + let uuid = request.uuid.unwrap_or_else(Uuid::new_v4); + + // Add the request to the map and waiting queue + self.requests.insert(uuid, Request::Direct(request)); + self.waiting.push_back(uuid); + uuid + } + + /// Get the next UUID from ready or waiting queue and its associated Request. + /// Returns from ready if not empty, otherwise from waiting, or None if both are empty. + /// Also removes the Request from the requests HashMap. + fn next(&mut self) -> Option<(Uuid, Request)> { + let uuid = self + .ready + .pop_front() + .or_else(|| self.waiting.pop_front())?; + let request = self.requests.remove(&uuid)?; + Some((uuid, request)) + } + + /// Move a UUID and its Request to the ready queue. 
+ fn make_ready(&mut self, uuid: Uuid, active_seq: ActiveSequence) { + self.requests.insert(uuid, Request::Active(active_seq)); + self.ready.push_back(uuid); + } + + /// Schedule the request with the given UUID. + /// Returns the creation signal from the ActiveSequence. + fn run(&mut self, uuid: Uuid, active_seq: ActiveSequence) -> MoveBlock { + // Insert the request into the map + self.requests.insert(uuid, Request::Active(active_seq)); + + // Get the creation signal + let Some(Request::Active(sequence)) = self.requests.get(&uuid) else { + panic!("Failed to get ActiveSequence for UUID"); + }; + let Some(signal) = sequence.creation_signal() else { + panic!("Failed to get creation signal from ActiveSequence"); + }; + + // Add to running requests + self.running.insert(uuid); + signal.clone() + } + + /// Set the prefill cost for a UUID + fn set_prefill_cost(&mut self, uuid: Uuid, cost: Option) { + self.prefill_costs.insert(uuid, cost); + } + + /// Get the prefill compute value for a UUID if available + fn get_prefill_compute(&self, uuid: &Uuid) -> Option { + self.prefill_costs + .get(uuid) + .and_then(|cost| cost.as_ref()) + .map(|cost| cost.prefill_compute) + } + + /// Calculate the current running batched tokens + fn num_batched_tokens(&self) -> usize { + self.prefill_costs + .values() + .map(|cost| match cost { + Some(cost) => cost.new_tokens, + None => 1, + }) + .sum() + } + + /// Remove a UUID and its associated Request from collections. + fn complete(&mut self, uuid: &Uuid) { + // println!("Request {} will complete", uuid); + self.running.remove(uuid); + self.requests.remove(uuid); + self.prefill_costs.remove(uuid); + } + + /// Preempt the oldest running request by evicting it from running, resetting the sequence, + /// and adding it back to the waiting queue. + /// Returns the signal from reset_with_signal or None if no requests are running. 
+ fn preempt(&mut self) -> Option> { + // Evict the oldest UUID from running + let uuid = self.running.evict()?; + eprintln!("Request {} will be preempted", uuid); + + // Remove the request from the requests HashMap and ensure it's an ActiveSequence + let request = self.requests.remove(&uuid)?; + + // Remove the prefill cost to force recomputation + self.prefill_costs.remove(&uuid); + + // Extract the ActiveSequence from the Request enum + let Request::Active(mut active_sequence) = request else { + panic!("Expected ActiveSequence in running queue") + }; + + // Reset the sequence and get the new sequence and signal + let signals = active_sequence.reset_with_signal(); + + // Insert the new sequence back into the requests map and add to waiting queue + self.requests.insert(uuid, Request::Active(active_sequence)); + self.waiting.push_back(uuid); + + Some(signals) + } +} + +/// Manages scheduling of requests using KvManager resources +#[derive(Clone)] +pub struct Scheduler { + state: Arc>, + kv_manager: Arc>, + request_tx: mpsc::Sender, +} + +impl Scheduler { + /// Create a new Scheduler with the given parameters + pub fn new( + kv_capacity: usize, + watermark: f64, + block_size: usize, + chunk_size: Option, + output_tx: Option>, + cancellation_token: Option, + ) -> Self { + // Create KvManager internally + let kv_manager = KvManager::new(kv_capacity, block_size); + + let token_capacity: usize = 8192; + let state = Arc::new(Mutex::new(SchedulerState::default())); + + let kv_manager = Arc::new(Mutex::new(kv_manager)); + let chunk_size = chunk_size.unwrap_or(256); + + // Create channel for request handling + let (request_tx, mut request_rx) = mpsc::channel::(1024); + + // Use provided cancellation token or create new one + let cancellation_token = cancellation_token.unwrap_or_default(); + let token_clone = cancellation_token.clone(); + + // Create a clone for the background task + let state_clone = state.clone(); + let kv_manager_clone = kv_manager.clone(); + let 
output_tx_clone = output_tx.clone(); + + // Spawn main background task with cancellation token + tokio::spawn(async move { + let mut schedule_interval = interval(Duration::from_millis(5)); + let mut simulate_interval = interval(Duration::from_millis(1)); + + loop { + tokio::select! { + biased; + + // Enqueue new request + Some(request) = request_rx.recv() => { + let mut state = state_clone.lock().await; + state.receive(request); + } + + // Try Scheduling Requests + _ = schedule_interval.tick() => { + let mut state_guard = state_clone.lock().await; + let mut kv_manager_guard = kv_manager_clone.lock().await; + + // Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't + // schedule anymore. + while let Some((uuid, request)) = state_guard.next() { + let active_sequence = get_active_sequence(request, block_size, chunk_size); + + // Calculate token budget using new_tokens from PrefillCost + let total_prefill_tokens = state_guard.num_batched_tokens(); + let tokens_budget = token_capacity.saturating_sub(total_prefill_tokens); + + // Check if it can be scheduled + let Some(prefill_cost) = kv_manager_guard.try_schedule(&active_sequence, watermark, tokens_budget) else { + state_guard.make_ready(uuid, active_sequence); + break; + }; + + // Get creation signal and schedule the request + let signal = state_guard.run(uuid, active_sequence); + kv_manager_guard.process(&signal); + state_guard.set_prefill_cost(uuid, Some(prefill_cost)); + } + } + + // Check for cancellation + _ = token_clone.cancelled() => { + break; + } + + // Simulate running requests (prefill + decode) + _ = simulate_interval.tick() => { + let mut state_guard = state_clone.lock().await; + let mut kv_manager_guard = kv_manager_clone.lock().await; + + // Base time needed for decoding (assumed memory bound on KV cache) + let active_tokens = kv_manager_guard.num_active_blocks() * block_size; + // TODO: 2 is a dummy / magic scaling factor + let mut generation_time = 
Duration::from_micros((active_tokens / 2) as u64); + + // Process each running request + let uuids: Vec = state_guard.running.keys().cloned().collect(); + for uuid in uuids { + // Check if UUID is still in running_requests, if not skip this iteration + if !state_guard.running.contains(&uuid) { + continue; + } + + // Get prefill compute value first + let prefill_compute = state_guard.get_prefill_compute(&uuid); + + // Get the active sequence for this UUID + let sequence = state_guard.requests.get_mut(&uuid) + .and_then(|req| if let Request::Active(seq) = req { Some(seq) } else { None }) + .expect("UUID in running_requests must have a corresponding active sequence"); + + // Generate token and get signals + let signals = sequence.generate(); + + // Accumulate sleep duration based on prefill_compute if available + // prefill compute = (cached_tokens + new_tokens) * new_tokens + let sleep_ms = if let Some(compute) = prefill_compute { + // TODO: 1024 is a dummy / magic scaling factor + (compute / 1024.0) as u64 + } else { 0 }; + generation_time += Duration::from_micros(sleep_ms); + + // Process all signals with the KvManager + // Handling of preemption on failure + if !process_signals(&mut kv_manager_guard, &signals) { + sequence.pop(); // revert the failed generation op + + // free_signal derefs the preempted blocks + let Some(free_signal) = state_guard.preempt() else { + panic!("Failed to acquire signal to free KV blocks from preemption"); + }; + + for signal in free_signal { + kv_manager_guard.process(&signal); + } + continue; + } + + // Send UUID notification for each generated token + // TODO: hook this up to an AsyncEngine + if let Some(tx) = &output_tx_clone { + let _ = tx.try_send(uuid); + } + + // Check if we're done after generating + if sequence.generated_tokens() >= sequence.max_output_tokens() { + state_guard.complete(&uuid); + continue; + } + + // Transition to decode (no prefill cost) + if sequence.generated_tokens() == 1 { + 
state_guard.set_prefill_cost(uuid, None); + } + } + + // Sleep once for the accumulated duration + if generation_time.as_millis() > 0 { + tokio::time::sleep(generation_time).await; + } + } + } + } + }); + + Self { + state, + kv_manager, + request_tx, + } + } + + /// Add a new request to the waiting queue + pub async fn receive(&self, request: DirectRequest) { + let _ = self.request_tx.send(request).await; + } + + /// Get the count of waiting requests + pub async fn waiting_count(&self) -> usize { + let state = self.state.lock().await; + state.waiting.len() + } + + /// Get the count of running requests + pub async fn running_count(&self) -> usize { + let state = self.state.lock().await; + state.running.len() + } + + /// Get the current capacity of the KvManager + pub async fn kv_usage_perc(&self) -> f64 { + let kv_manager = self.kv_manager.lock().await; + kv_manager.current_capacity_perc() + } + + /// Returns forward pass metrics for monitoring purposes + pub async fn get_forward_pass_metrics(&self) -> ForwardPassMetrics { + let state = self.state.lock().await; + let kv_manager = self.kv_manager.lock().await; + + // Get the active blocks and total capacity from KvManager + let active_blocks_count = kv_manager.active_blocks().len() as u64; + let total_capacity = kv_manager.max_capacity() as u64; + + // Calculate GPU cache usage percentage + let gpu_cache_usage_perc = if total_capacity > 0 { + active_blocks_count as f32 / total_capacity as f32 + } else { + 0.0 + }; + + ForwardPassMetrics { + request_active_slots: state.running.len() as u64, + request_total_slots: 420, // Dummy value as specified + kv_active_blocks: active_blocks_count, + kv_total_blocks: total_capacity, + num_requests_waiting: state.waiting.len() as u64, + gpu_cache_usage_perc, + gpu_prefix_cache_hit_rate: 0.0, // Placeholder value as specified + } + } +} + +/// Convert a Request to an ActiveSequence +fn get_active_sequence(request: Request, block_size: usize, chunk_size: usize) -> ActiveSequence { + 
if let Request::Active(active_seq) = request { + return active_seq; + } + + let Request::Direct(direct_request) = request else { + unreachable!("Request must be either Direct or Active"); + }; + + ActiveSequence::new( + direct_request.tokens, + direct_request.max_output_tokens, + Some(block_size), + Some(chunk_size), + ) +} + +/// Processes MoveBlock signals with the KvManager. +/// +/// When a signal fails, this function verifies that the failure is for an expected case: +/// specifically a single signal attempting to create a single partial (generation) block. +/// This validation is important because in normal operation, the only legitimate failure +/// case should be when trying to acquire a new generation block - any other failures would +/// indicate an unexpected state in the system. +fn process_signals( + kv_manager_guard: &mut tokio::sync::MutexGuard<'_, KvManager>, + signals: &[MoveBlock], +) -> bool { + for signal in signals { + if kv_manager_guard.process(signal) { + continue; + } + + // Check we have a Use signal with blocks + let MoveBlock::Use(blocks, _) = signal else { + panic!("Failed signal is Invalid. Has to fail on generation signal."); + }; + + // Verify the signal contains exactly one block + if blocks.len() != 1 { + panic!("Failed signal is Invalid. Can have only one generation signal."); + } + + // Verify the block is a PartialBlock (generation block) + if !matches!(blocks[0], UniqueBlock::PartialBlock(_)) { + panic!("Failed signal is Invalid. 
Generation block has to be partial."); + } + + return false; + } + + true +} + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + use std::time::Duration; + + #[rstest] + #[case::random(false)] + #[case::caching(true)] + #[tokio::test] + async fn test_scheduler_token_generation_patterns(#[case] use_shared_tokens: bool) { + std::env::set_var("RUST_LOG", "debug"); + + let kv_capacity: usize = 500; + let watermark: f64 = 0.01; // 1% watermark + let block_size: usize = 64; + let chunk_size: usize = 256; + let num_requests: usize = 100; + let input_len: usize = 1000; + let max_output_tokens: usize = 100; + + // Create channel for token output + let (output_tx, mut output_rx) = mpsc::channel::(1024); + + // Create scheduler with internal KvManager + let scheduler = Scheduler::new( + kv_capacity, + watermark, + block_size, + Some(chunk_size), + Some(output_tx), + None, + ); + + // Create shared tokens for caching case + let shared_tokens = if use_shared_tokens { + Some( + (0..input_len / 2) + .map(|_| rand::random::() % 50000) + .collect::>(), + ) + } else { + None + }; + + // Create test requests + for _ in 0..num_requests { + let input_tokens = if let Some(ref shared) = shared_tokens { + // For caching case: use shared tokens for first half, random for second half + let mut tokens = shared.clone(); + tokens.extend((0..input_len / 2).map(|_| rand::random::() % 50000)); + tokens + } else { + // For random case: create unique random token vector for each request + (0..input_len) + .map(|_| rand::random::() % 50000) + .collect::>() + }; + + let request = DirectRequest { + tokens: input_tokens, + max_output_tokens, + uuid: None, + }; + scheduler.receive(request).await; + } + + let start_time = std::time::Instant::now(); + + // Collect all generated tokens (should be num_requests * max_output_tokens) + let expected_tokens = num_requests * max_output_tokens; + let mut received_tokens = 0; + + // Set up a timeout that causes the test to panic if no tokens are 
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::time::Duration;

    /// Drives 100 requests through the scheduler and counts per-token notifications,
    /// once with fully random prompts and once with a shared first half.
    #[rstest]
    #[case::random(false)]
    #[case::caching(true)]
    #[tokio::test]
    async fn test_scheduler_token_generation_patterns(#[case] use_shared_tokens: bool) {
        // No `std::env::set_var("RUST_LOG", ..)` here: it mutates process-wide state
        // while other test cases may run concurrently, and nothing in this test reads it.

        let kv_capacity: usize = 500;
        let watermark: f64 = 0.01; // 1% watermark
        let block_size: usize = 64;
        let chunk_size: usize = 256;
        let num_requests: usize = 100;
        let input_len: usize = 1000;
        let max_output_tokens: usize = 100;

        // Create channel for token output
        let (output_tx, mut output_rx) = mpsc::channel(1024);

        // Create scheduler with internal KvManager
        let scheduler = Scheduler::new(
            kv_capacity,
            watermark,
            block_size,
            Some(chunk_size),
            Some(output_tx),
            None,
        );

        let random_tokens = |count: usize| -> Vec<u32> {
            (0..count).map(|_| rand::random::<u32>() % 50000).collect()
        };

        // Create shared tokens for caching case
        let shared_tokens = use_shared_tokens.then(|| random_tokens(input_len / 2));

        // Create test requests
        for _ in 0..num_requests {
            let input_tokens = match &shared_tokens {
                // Caching case: shared first half, random second half
                Some(shared) => {
                    let mut tokens = shared.clone();
                    tokens.extend(random_tokens(input_len / 2));
                    tokens
                }
                // Random case: a unique random prompt per request
                None => random_tokens(input_len),
            };

            let request = DirectRequest {
                tokens: input_tokens,
                max_output_tokens,
                uuid: None,
            };
            scheduler.receive(request).await;
        }

        let start_time = std::time::Instant::now();

        // Each request should produce max_output_tokens notifications
        let expected_tokens = num_requests * max_output_tokens;
        let mut received_tokens = 0;

        // Inactivity timeout: the collection loop ends once no token has arrived for 2 seconds
        let timeout = tokio::time::sleep(Duration::from_secs(2));
        tokio::pin!(timeout);

        // Set up debug ticker interval
        let mut debug_interval = interval(Duration::from_millis(500));

        loop {
            tokio::select! {
                biased;

                // Manual debug ticker that fetches forward pass metrics
                _ = debug_interval.tick() => {
                    let _metrics = scheduler.get_forward_pass_metrics().await;
                }

                Some(_) = output_rx.recv() => {
                    received_tokens += 1;
                    // Reset timeout whenever we receive a token
                    timeout.set(tokio::time::sleep(Duration::from_secs(2)));
                }

                _ = &mut timeout => {
                    // Stop collecting; the assertion below decides pass/fail
                    break;
                }
            }
        }

        // Calculate and print elapsed time
        let elapsed = start_time.elapsed();
        println!(
            "Test completed in: {:?} for {} case",
            elapsed,
            if use_shared_tokens {
                "caching"
            } else {
                "random"
            }
        );

        // NOTE(review): the strict `>` only holds when at least one notification beyond
        // num_requests * max_output_tokens is sent (presumably tokens re-generated after
        // a preemption) - confirm this is intended rather than `>=`.
        assert!(
            received_tokens > expected_tokens,
            "Received {} tokens but expected more than {}",
            received_tokens,
            expected_tokens
        );
    }
}
+// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::mocker::protocols::{MoveBlock, UniqueBlock}; +use crate::tokens::{TokenBlockSequence, Tokens}; +use derive_getters::Getters; +use rand::random; +use uuid; + +/// Create unique blocks from a TokenBlockSequence +fn create_unique_blocks_from_sequence( + tokens: &TokenBlockSequence, + uuid: Option, + block_size: usize, +) -> Vec { + let mut unique_blocks: Vec = tokens + .blocks() + .iter() + .map(|block| UniqueBlock::FullBlock(block.sequence_hash())) + .collect(); + + // Only push the partial block if tokens count isn't a multiple of block_size + if tokens.total_tokens() % block_size != 0 { + unique_blocks.push(match uuid { + Some(uuid) => UniqueBlock::PartialBlock(uuid), + None => UniqueBlock::default(), + }); + } + unique_blocks +} + +/// A sequence that is actively being built, with the ability to add tokens and commit to hashes +/// TODO: reuse tokens +#[derive(Debug, Getters)] +pub struct ActiveSequence { + unique_blocks: Vec, + + tokens: TokenBlockSequence, + + #[getter(copy)] + block_size: usize, + + #[getter(copy)] + chunk_size: usize, // TODO: not actually used + + #[getter(copy)] + max_output_tokens: usize, + + #[getter(copy)] + generated_tokens: usize, + + #[getter(copy)] + num_input_tokens: usize, + + creation_signal: Option, +} + +impl ActiveSequence { + /// Create a new ActiveSequence instance with the provided tokens + pub fn new( + tokens: Vec, + max_output_tokens: usize, + block_size: Option, + chunk_size: Option, + ) -> Self { + let block_size = block_size.unwrap_or(64); + assert!(block_size > 1, "block_size must be greater than 1"); + let chunk_size = chunk_size.unwrap_or(256); + let num_input_tokens = tokens.len(); + + let tokens = Tokens::from(tokens).into_sequence(block_size, None); + let unique_blocks = create_unique_blocks_from_sequence(&tokens, None, block_size); + let creation_signal = 
Some(MoveBlock::Use(unique_blocks.clone(), None)); + + Self { + unique_blocks, + tokens, + block_size, + chunk_size, + max_output_tokens, + generated_tokens: 0, + num_input_tokens, + creation_signal, + } + } + + pub fn extra_tokens(&self) -> usize { + self.len() % self.block_size + } + + pub fn len(&self) -> usize { + self.tokens.total_tokens() + } + + pub fn is_empty(&self) -> bool { + self.tokens.total_tokens() == 0 + } + + /// Create a new ActiveSequence instance and return the creation signal + pub fn new_with_signal( + tokens: Vec, + max_output_tokens: usize, + block_size: Option, + chunk_size: Option, + ) -> (Self, Option) { + let mut sequence = Self::new(tokens, max_output_tokens, block_size, chunk_size); + let signal = sequence.creation_signal.take(); + (sequence, signal) + } + + /// Push a token to the sequence + pub fn push(&mut self, token: u32) -> Option> { + self.tokens.append(token).expect("Token push failed."); + self.generated_tokens += 1; + + if self.len() % self.block_size != 1 { + return None; + } + + // Add a partial block for the first token in a new partial sequence + // Send Use signal (to allocate space for this new generation block) + let mut signals = Vec::new(); + + // Replace last partial block with full block if it exists + if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() { + let last_block_hash = self.tokens.last_complete_block().unwrap().sequence_hash(); + self.unique_blocks.pop(); + self.unique_blocks + .push(UniqueBlock::FullBlock(last_block_hash)); + signals.push(MoveBlock::Promote(uuid, last_block_hash)); + } + + let new_partial_block = UniqueBlock::default(); + self.unique_blocks.push(new_partial_block.clone()); + signals.push(MoveBlock::Use(vec![new_partial_block], None)); + Some(signals) + } + + /// Generate a random token, push it to the sequence, and increment generation count. 
+ /// + /// This function: + /// - Generates a random token and adds it to the current sequence + /// - Acquires a new partial block if needed or promotes an existing partial block to a full block + /// - Returns appropriate signals for the KvManager to process + /// + /// # Panics + /// + /// Calling this function when max_output_tokens has already been reached will cause a panic. + /// Always check `generated_tokens < max_output_tokens` before calling this method. + pub fn generate(&mut self) -> Vec { + // Assert that we haven't reached the maximum output tokens + assert!( + self.generated_tokens < self.max_output_tokens, + "Cannot generate more tokens: reached max_output_tokens limit" + ); + + // Generate a random token + let token = random::(); + + // Collect signals + let mut signals = Vec::new(); + + // Push the token to the sequence and collect any signals + if let Some(move_blocks) = self.push(token) { + signals.extend(move_blocks); + } + + // Check if we've reached the limit after pushing + if self.generated_tokens != self.max_output_tokens { + return signals; + } + + // Free all blocks when we reach max tokens + signals.extend(self.free_signal()); + signals + } + + /// Free all blocks, generating appropriate signals for each block type + pub fn free_signal(&self) -> Vec { + self.unique_blocks + .iter() + .rev() + .map(|block| match block { + UniqueBlock::PartialBlock(uuid) => { + MoveBlock::Destroy(vec![UniqueBlock::PartialBlock(*uuid)]) + } + UniqueBlock::FullBlock(hash) => { + MoveBlock::Deref(vec![UniqueBlock::FullBlock(*hash)]) + } + }) + .collect() + } + + /// Reset the sequence to its initial state and return the free signals from freeing current blocks + /// maintaining the uuid of the last partial block + pub fn reset_with_signal(&mut self) -> Vec { + let free_signal = self.free_signal(); + + self.tokens.truncate(self.num_input_tokens).unwrap(); + self.unique_blocks = + create_unique_blocks_from_sequence(&self.tokens, None, self.block_size); + 
self.generated_tokens = 0; + self.creation_signal = Some(MoveBlock::Use(self.unique_blocks.clone(), None)); + + free_signal + } + + /// Pops last token in the sequence. + pub fn pop(&mut self) { + self.tokens.pop(); + self.generated_tokens = self.generated_tokens.saturating_sub(1); + + // Reverts to the last full block + if self.tokens.total_tokens() % self.block_size == 0 { + self.unique_blocks.pop(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_active_sequence_push() { + // Create a sequence with block size 16 initialized with tokens [0..15] + let initial_tokens: Vec = (0..15).collect(); + let (mut seq1, signal1) = + ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), Some(256)); + assert_eq!(seq1.num_input_tokens(), 15); + assert_eq!(seq1.len(), 15); + + // Check that we got a Use signal + assert!(signal1.is_some()); + match &signal1 { + Some(MoveBlock::Use(blocks, _)) => { + assert_eq!(blocks.len(), 1); + } + _ => panic!("Expected Use signal"), + } + + // Push token 15 which should complete the block (no signals yet) + let signal_15 = seq1.push(15); + assert!( + signal_15.is_none(), + "Completing a block should not trigger signals" + ); + + // Push token 16 which should trigger both Promote and Use signals + let signal_16 = seq1.push(16); + assert!(signal_16.is_some()); + let signal_16 = signal_16.unwrap(); + assert_eq!(signal_16.len(), 2); + + // Second signal should be Use for new partial block + match &signal_16[1] { + MoveBlock::Use(blocks, _) => { + assert_eq!(blocks.len(), 1); + assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_))); + } + _ => panic!("Expected Use signal as first signal"), + } + + // First signal should be Promote for the previous block + match &signal_16[0] { + MoveBlock::Promote(uuid, _) => { + // The uuid is generated dynamically, so we just check it exists + let _ = uuid; + } + _ => panic!("Expected Promote signal as second signal"), + } + + // Verify state after pushing tokens + 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_active_sequence_push() {
        // Create a sequence with block size 16 initialized with tokens 0..15 (15 tokens)
        let initial_tokens: Vec<u32> = (0..15).collect();
        let (mut seq1, signal1) =
            ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), Some(256));
        assert_eq!(seq1.num_input_tokens(), 15);
        assert_eq!(seq1.len(), 15);

        // Check that we got a Use signal
        assert!(signal1.is_some());
        match &signal1 {
            Some(MoveBlock::Use(blocks, _)) => {
                assert_eq!(blocks.len(), 1);
            }
            _ => panic!("Expected Use signal"),
        }

        // Push token 15 which should complete the block (no signals yet)
        let signal_15 = seq1.push(15);
        assert!(
            signal_15.is_none(),
            "Completing a block should not trigger signals"
        );

        // Push token 16 which should trigger both Promote and Use signals
        let signal_16 = seq1.push(16);
        assert!(signal_16.is_some());
        let signal_16 = signal_16.unwrap();
        assert_eq!(signal_16.len(), 2);

        // Second signal should be Use for new partial block
        match &signal_16[1] {
            MoveBlock::Use(blocks, _) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Use signal as second signal"),
        }

        // First signal should be Promote for the previous block
        match &signal_16[0] {
            MoveBlock::Promote(uuid, _) => {
                // The uuid is generated dynamically, so we just check it exists
                let _ = uuid;
            }
            _ => panic!("Expected Promote signal as first signal"),
        }

        // Verify state after pushing tokens
        assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block
        assert_eq!(seq1.len(), 17);
        assert_eq!(seq1.len() % seq1.block_size(), 1);

        // Create another sequence with block size 16 initialized with tokens 0..16 (16 tokens)
        let extended_tokens: Vec<u32> = (0..16).collect();
        let (mut seq2, _) =
            ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), Some(256));
        seq2.push(16);
        seq2.pop();
        seq2.push(16);

        // Simplified assertions
        assert_eq!(
            seq1.unique_blocks()[0],
            seq2.unique_blocks()[0],
            "First blocks should be the same"
        );

        assert_ne!(
            seq1.unique_blocks()[1],
            seq2.unique_blocks()[1],
            "Second blocks should be different"
        );

        // Reset partial block on seq1 and push back token 16
        seq1.push(17);
        seq1.pop();
        seq1.pop();
        seq1.push(16);

        // Now push tokens 17..=32 to both sequences
        for token in 17..33 {
            seq1.push(token);
            seq2.push(token);
        }

        // Both sequences now hold 33 tokens, i.e. 3 blocks:
        // 1. FullBlock for tokens 0-15
        // 2. FullBlock for tokens 16-31
        // 3. PartialBlock holding the single token 32
        assert_eq!(
            seq1.unique_blocks().len(),
            3,
            "seq1 should have exactly 3 blocks"
        );
        assert_eq!(
            seq2.unique_blocks().len(),
            3,
            "seq2 should have exactly 3 blocks"
        );
        assert_eq!(
            seq1.len() % seq1.block_size(),
            1,
            "seq1 should have 1 partial token"
        );
        assert_eq!(
            seq2.len() % seq2.block_size(),
            1,
            "seq2 should have 1 partial token"
        );

        // Verify that both sequences have identical full blocks
        assert_eq!(
            &seq1.unique_blocks()[0..2],
            &seq2.unique_blocks()[0..2],
            "First two blocks should be identical"
        );

        // Reset seq1 and collect the signals that free its blocks
        let free_signals = seq1.reset_with_signal();

        // Verify the reset signals include proper cleanup events
        assert!(!free_signals.is_empty());
    }

    #[test]
    fn test_active_sequence_generate_signals() {
        // Create a sequence with block size 16, max_output_tokens 5, initialized with tokens [0..14)
        let initial_tokens: Vec<u32> = (0..14).collect();
        let (mut seq, signal) =
            ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), Some(256));

        // Initial signal - should have received a Use signal for the partial block
        assert!(signal.is_some());
        match signal {
            Some(MoveBlock::Use(blocks, _)) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Use signal for the initial partial block"),
        }

        // Generate first two tokens (15th and 16th overall) - the second one completes
        // the block, which by itself does not trigger signals
        seq.generate();
        let signals_first = seq.generate();
        assert_eq!(signals_first.len(), 0);

        // Generate third token - it opens a new block and should trigger both Promote and Use signals
        let signals_second = seq.generate();
        assert_eq!(signals_second.len(), 2);

        // Second signal should be Use for new partial block
        match &signals_second[1] {
            MoveBlock::Use(blocks, _) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Use signal as second signal after third token"),
        }

        // First signal should be Promote
        match &signals_second[0] {
            MoveBlock::Promote(uuid, hash) => {
                // The uuid and hash values are generated dynamically, so we just check the event type
                let _ = uuid;
                let _ = hash;
            }
            _ => panic!("Expected Promote signal as first signal after third token"),
        }

        // Generate fourth token - should not trigger new signals as it's adding to partial block
        let signals_third = seq.generate();
        assert_eq!(signals_third.len(), 0);

        // Generate fifth (last) token - we reach max_output_tokens, should trigger Destroy and Deref signals
        let signals_last = seq.generate();
        assert_eq!(signals_last.len(), 2);

        // First signal should be Destroy for the partial block
        match &signals_last[0] {
            MoveBlock::Destroy(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
            }
            _ => panic!("Expected Destroy signal for partial block after fifth token"),
        }

        // Second signal should be Deref for the full block
        match &signals_last[1] {
            MoveBlock::Deref(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert!(matches!(blocks[0], UniqueBlock::FullBlock(_)));
            }
            _ => panic!("Expected Deref signal for full block after fifth token"),
        }
    }
}
-#[derive(Debug)] // No Clone: intended to be unique within a sequence +#[derive(Debug, PartialEq)] // No Clone: intended to be unique within a sequence pub struct PartialTokenBlock { tokens: Tokens, block_size: usize, @@ -478,7 +478,7 @@ impl TokenBlock { /// - [`BlockHash`]: Hash of tokens within a single block (seeded by [`SaltHash`]). /// - [`SequenceHash`]: Hash combining the previous block's [`SequenceHash`] and the current /// block's [`BlockHash`] (also seeded by [`SaltHash`]). -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct TokenBlockSequence { blocks: Vec, current_block: PartialTokenBlock,