Skip to content

Commit 0373980

Browse files
committed
Fix types in Python bindings for kv_cache_block_size and context_length
1 parent a4d6bba commit 0373980

File tree

8 files changed

+28
-29
lines changed

8 files changed

+28
-29
lines changed

lib/bindings/python/rust/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ fn register_llm<'p>(
105105
endpoint: Endpoint,
106106
model_path: &str,
107107
model_name: Option<&str>,
108-
context_length: Option<usize>,
109-
kv_cache_block_size: Option<usize>,
108+
context_length: Option<u32>,
109+
kv_cache_block_size: Option<u32>,
110110
) -> PyResult<Bound<'p, PyAny>> {
111111
let model_type_obj = match model_type {
112112
ModelType::Chat => llm_rs::model_type::ModelType::Chat,

lib/bindings/python/rust/llm/kv.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,13 @@ impl KvRouter {
4040

4141
let runtime = pyo3_async_runtimes::tokio::get_runtime();
4242
runtime.block_on(async {
43-
let inner =
44-
llm_rs::kv_router::KvRouter::new(component.inner.clone(), kv_block_size, None)
45-
.await
46-
.map_err(to_pyerr)?;
43+
let inner = llm_rs::kv_router::KvRouter::new(
44+
component.inner.clone(),
45+
kv_block_size as u32,
46+
None,
47+
)
48+
.await
49+
.map_err(to_pyerr)?;
4750
Ok(Self {
4851
inner: Arc::new(inner),
4952
})
@@ -73,7 +76,7 @@ pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) ->
7376
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
7477
}
7578

76-
let hashes = compute_block_hash_for_seq(&tokens, kv_block_size);
79+
let hashes = compute_block_hash_for_seq(&tokens, kv_block_size as u32);
7780
Ok(hashes.into_iter().map(|h| h.0).collect())
7881
}
7982

@@ -191,7 +194,7 @@ impl ZmqKvEventPublisher {
191194
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
192195
component.inner,
193196
config.worker_id,
194-
config.kv_block_size,
197+
config.kv_block_size as u32,
195198
Some(KvEventSourceConfig::Zmq {
196199
endpoint: config.zmq_endpoint,
197200
topic: config.zmq_topic,
@@ -232,7 +235,7 @@ impl ZmqKvEventListener {
232235
zmq_topic,
233236
tx,
234237
shutdown_token.clone(),
235-
kv_block_size,
238+
kv_block_size as u32,
236239
));
237240

238241
Ok(Self {
@@ -293,7 +296,7 @@ impl KvEventPublisher {
293296
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
294297
component.inner,
295298
worker_id,
296-
kv_block_size,
299+
kv_block_size as u32,
297300
None,
298301
)
299302
.map_err(to_pyerr)?;
@@ -322,7 +325,7 @@ impl KvEventPublisher {
322325
data: KvCacheEventData::Stored(KvCacheStoreData {
323326
parent_hash: parent_hash.map(ExternalSequenceBlockHash::from),
324327
blocks: create_stored_blocks(
325-
self.kv_block_size,
328+
self.kv_block_size as u32,
326329
&token_ids,
327330
&num_block_tokens,
328331
&block_hashes,
@@ -446,7 +449,7 @@ impl KvIndexer {
446449
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
447450
llm_rs::kv_router::indexer::KvIndexer::new(
448451
component.inner.drt().runtime().child_token(),
449-
kv_block_size,
452+
kv_block_size as u32,
450453
)
451454
.into();
452455
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
@@ -478,7 +481,7 @@ impl KvIndexer {
478481
}
479482

480483
fn block_size(&self) -> usize {
481-
self.inner.block_size()
484+
self.inner.block_size() as usize
482485
}
483486

484487
fn find_matches<'p>(&self, py: Python<'p>, sequence: Vec<u64>) -> PyResult<Bound<'p, PyAny>> {

lib/engines/llamacpp/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ impl LlamacppEngine {
7878

7979
let (ctx_set, ctx_get) = tokio::sync::mpsc::channel(NUM_CONTEXTS);
8080
let llama_ctx_params = if model_config.card().context_length > 0 {
81-
let n_ctx = NonZeroU32::new(model_config.card().context_length as u32);
81+
let n_ctx = NonZeroU32::new(model_config.card().context_length);
8282
LlamaContextParams::default().with_n_ctx(n_ctx)
8383
} else {
8484
// Context length defaults to 512 currently

lib/llm/src/block_manager/block.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,7 +1605,7 @@ mod tests {
16051605
use dynamo_runtime::logging::init as init_logging;
16061606
use nixl_sys::Agent as NixlAgent;
16071607

1608-
const BLOCK_SIZE: usize = 4;
1608+
const BLOCK_SIZE: u32 = 4;
16091609
const SALT_HASH: SaltHash = 12345;
16101610

16111611
// Helper to create a default reset block
@@ -1666,7 +1666,7 @@ mod tests {
16661666

16671667
// Extend to fill capacity
16681668
assert!(block.add_tokens(Tokens::from(vec![4])).is_ok()); // 1, 2, 3, 4
1669-
assert_eq!(block.len(), BLOCK_SIZE);
1669+
assert_eq!(block.len(), BLOCK_SIZE as usize);
16701670

16711671
// Append when full (should fail)
16721672
assert!(block.add_token(5).is_err(), "Append on full Partial block");
@@ -1690,7 +1690,7 @@ mod tests {
16901690

16911691
// Fill block again for commit
16921692
assert!(block.add_tokens(Tokens::from(vec![1, 2, 3, 4])).is_ok());
1693-
assert_eq!(block.len(), BLOCK_SIZE);
1693+
assert_eq!(block.len(), BLOCK_SIZE as usize);
16941694

16951695
// --- Partial -> Complete (via commit) --- //
16961696
assert!(block.commit().is_ok());

lib/llm/src/block_manager/block/state.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ impl BlockState {
4343
return Err(BlockStateInvalid("Block is not reset".to_string()));
4444
}
4545

46-
let block = PartialTokenBlock::create_sequence_root(page_size, salt_hash);
46+
let block = PartialTokenBlock::create_sequence_root(page_size as u32, salt_hash);
4747
*self = BlockState::Partial(PartialState::new(block));
4848
Ok(())
4949
}

lib/llm/src/block_manager/pool/inactive.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ pub(crate) mod tests {
648648
/// Each block is initialized to the Complete state and then Registered.
649649
pub fn create_blocks(
650650
tokens: Tokens,
651-
block_size: usize,
651+
block_size: u32,
652652
async_runtime: Handle,
653653
) -> Vec<Block<NullDeviceStorage, TestMetadata>> {
654654
let (token_blocks, _partial_token_block) =
@@ -691,7 +691,7 @@ pub(crate) mod tests {
691691

692692
pub fn acquire_blocks(
693693
tokens: Tokens,
694-
block_size: usize,
694+
block_size: u32,
695695
pool: &mut InactiveBlockPool<NullDeviceStorage, TestMetadata>,
696696
async_runtime: Handle,
697697
) -> (Vec<Block<NullDeviceStorage, TestMetadata>>, usize) {
@@ -749,7 +749,7 @@ pub(crate) mod tests {
749749

750750
let async_runtime = tokio::runtime::Runtime::new().unwrap();
751751

752-
const PAGE_SIZE: usize = 2;
752+
const PAGE_SIZE: u32 = 2;
753753

754754
let mut pool = create_block_pool(10);
755755
assert_eq!(pool.total_blocks(), 10);

lib/llm/src/kv_router/indexer.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,21 +1315,17 @@ mod tests {
13151315
fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) {
13161316
setup();
13171317
// create a sequence of 64 elements
1318-
let sequence = (0..kv_block_size).map(|i| i as u32).collect::<Vec<u32>>();
1318+
let sequence = (0..kv_block_size).collect::<Vec<u32>>();
13191319
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
13201320
assert_eq!(hashes.len(), 1);
13211321

13221322
// create a sequence of 65 elements
1323-
let sequence = (0..(kv_block_size + 1))
1324-
.map(|i| i as u32)
1325-
.collect::<Vec<u32>>();
1323+
let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>();
13261324
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
13271325
assert_eq!(hashes.len(), 1);
13281326

13291327
// create a sequence of 129 elements
1330-
let sequence = (0..(2 * kv_block_size + 1))
1331-
.map(|i| i as u32)
1332-
.collect::<Vec<u32>>();
1328+
let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>();
13331329
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
13341330
assert_eq!(hashes.len(), 2);
13351331
}

lib/llm/src/tokenizers/sp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ impl Decoder for SentencePieceTokenizer {
8181
/// # Arguments
8282
/// * `token_ids` - The sequence of token IDs to decode
8383
/// * `skip_special_tokens` - Currently unsupported in SentencePieceTokenizer and
84-
/// it will return an error if true
84+
/// it will return an error if true
8585
///
8686
/// # Returns
8787
/// * `Result<String>` - The decoded text

0 commit comments

Comments
 (0)