Skip to content

Commit 3a1ef41

Browse files
committed
Fix types in Python bindings for kv_cache_block_size and context_length
1 parent 78b5c63 commit 3a1ef41

File tree

4 files changed

+20
-21
lines changed

4 files changed

+20
-21
lines changed

lib/bindings/python/rust/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ fn register_llm<'p>(
105105
endpoint: Endpoint,
106106
model_path: &str,
107107
model_name: Option<&str>,
108-
context_length: Option<usize>,
109-
kv_cache_block_size: Option<usize>,
108+
context_length: Option<u32>,
109+
kv_cache_block_size: Option<u32>,
110110
) -> PyResult<Bound<'p, PyAny>> {
111111
let model_type_obj = match model_type {
112112
ModelType::Chat => llm_rs::model_type::ModelType::Chat,

lib/bindings/python/rust/llm/kv.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,13 @@ impl KvRouter {
4040

4141
let runtime = pyo3_async_runtimes::tokio::get_runtime();
4242
runtime.block_on(async {
43-
let inner =
44-
llm_rs::kv_router::KvRouter::new(component.inner.clone(), kv_block_size, None)
45-
.await
46-
.map_err(to_pyerr)?;
43+
let inner = llm_rs::kv_router::KvRouter::new(
44+
component.inner.clone(),
45+
kv_block_size as u32,
46+
None,
47+
)
48+
.await
49+
.map_err(to_pyerr)?;
4750
Ok(Self {
4851
inner: Arc::new(inner),
4952
})
@@ -73,7 +76,7 @@ pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) ->
7376
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
7477
}
7578

76-
let hashes = compute_block_hash_for_seq(&tokens, kv_block_size);
79+
let hashes = compute_block_hash_for_seq(&tokens, kv_block_size as u32);
7780
Ok(hashes.into_iter().map(|h| h.0).collect())
7881
}
7982

@@ -191,7 +194,7 @@ impl ZmqKvEventPublisher {
191194
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
192195
component.inner,
193196
config.worker_id,
194-
config.kv_block_size,
197+
config.kv_block_size as u32,
195198
Some(KvEventSourceConfig::Zmq {
196199
endpoint: config.zmq_endpoint,
197200
topic: config.zmq_topic,
@@ -232,7 +235,7 @@ impl ZmqKvEventListener {
232235
zmq_topic,
233236
tx,
234237
shutdown_token.clone(),
235-
kv_block_size,
238+
kv_block_size as u32,
236239
));
237240

238241
Ok(Self {
@@ -293,7 +296,7 @@ impl KvEventPublisher {
293296
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
294297
component.inner,
295298
worker_id,
296-
kv_block_size,
299+
kv_block_size as u32,
297300
None,
298301
)
299302
.map_err(to_pyerr)?;
@@ -322,7 +325,7 @@ impl KvEventPublisher {
322325
data: KvCacheEventData::Stored(KvCacheStoreData {
323326
parent_hash: parent_hash.map(ExternalSequenceBlockHash::from),
324327
blocks: create_stored_blocks(
325-
self.kv_block_size,
328+
self.kv_block_size as u32,
326329
&token_ids,
327330
&num_block_tokens,
328331
&block_hashes,
@@ -446,7 +449,7 @@ impl KvIndexer {
446449
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
447450
llm_rs::kv_router::indexer::KvIndexer::new(
448451
component.inner.drt().runtime().child_token(),
449-
kv_block_size,
452+
kv_block_size as u32,
450453
)
451454
.into();
452455
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
@@ -478,7 +481,7 @@ impl KvIndexer {
478481
}
479482

480483
fn block_size(&self) -> usize {
481-
self.inner.block_size()
484+
self.inner.block_size() as usize
482485
}
483486

484487
fn find_matches<'p>(&self, py: Python<'p>, sequence: Vec<u64>) -> PyResult<Bound<'p, PyAny>> {

lib/engines/llamacpp/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ impl LlamacppEngine {
7878

7979
let (ctx_set, ctx_get) = tokio::sync::mpsc::channel(NUM_CONTEXTS);
8080
let llama_ctx_params = if model_config.card().context_length > 0 {
81-
let n_ctx = NonZeroU32::new(model_config.card().context_length as u32);
81+
let n_ctx = NonZeroU32::new(model_config.card().context_length);
8282
LlamaContextParams::default().with_n_ctx(n_ctx)
8383
} else {
8484
// Context length defaults to 512 currently

lib/llm/src/kv_router/indexer.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,21 +1315,17 @@ mod tests {
13151315
fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) {
13161316
setup();
13171317
// create a sequence of 64 elements
1318-
let sequence = (0..kv_block_size).map(|i| i as u32).collect::<Vec<u32>>();
1318+
let sequence = (0..kv_block_size).collect::<Vec<u32>>();
13191319
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
13201320
assert_eq!(hashes.len(), 1);
13211321

13221322
// create a sequence of 65 elements
1323-
let sequence = (0..(kv_block_size + 1))
1324-
.map(|i| i as u32)
1325-
.collect::<Vec<u32>>();
1323+
let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>();
13261324
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
13271325
assert_eq!(hashes.len(), 1);
13281326

13291327
// create a sequence of 129 elements
1330-
let sequence = (0..(2 * kv_block_size + 1))
1331-
.map(|i| i as u32)
1332-
.collect::<Vec<u32>>();
1328+
let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>();
13331329
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
13341330
assert_eq!(hashes.len(), 2);
13351331
}

0 commit comments

Comments
 (0)