Commit faa74d6

feats: add get_offloaded_computed_blocks to get reusable g2 g3 blocks (#1709)
Co-authored-by: jthomson04 <jwillthomson19@gmail.com>
1 parent 800aeec commit faa74d6

24 files changed: +656 -240 lines changed

lib/bindings/python/rust/llm/block_manager.rs

Lines changed: 22 additions & 40 deletions
@@ -44,24 +44,6 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     Ok(())
 }
 
-pub fn map_dtype(dtype: &str) -> anyhow::Result<dynamo_llm::common::dtype::DType> {
-    Ok(match dtype {
-        "fp8" | "FP8" => dynamo_llm::common::dtype::DType::FP8,
-        "fp16" | "FP16" => dynamo_llm::common::dtype::DType::FP16,
-        "bf16" | "BF16" => dynamo_llm::common::dtype::DType::BF16,
-        "fp32" | "FP32" => dynamo_llm::common::dtype::DType::FP32,
-        "u8" | "U8" => dynamo_llm::common::dtype::DType::U8,
-        "u16" | "U16" => dynamo_llm::common::dtype::DType::U16,
-        "u32" | "U32" => dynamo_llm::common::dtype::DType::U32,
-        "u64" | "U64" => dynamo_llm::common::dtype::DType::U64,
-        "i8" | "I8" => dynamo_llm::common::dtype::DType::I8,
-        "i16" | "I16" => dynamo_llm::common::dtype::DType::I16,
-        "i32" | "I32" => dynamo_llm::common::dtype::DType::I32,
-        "i64" | "I64" => dynamo_llm::common::dtype::DType::I64,
-        _ => return Err(anyhow::anyhow!("Unsupported dtype: {}", dtype)),
-    })
-}
-
 type VllmBlockManager = dynamo_llm::block_manager::KvBlockManager<
     Logical<DistributedLeaderWorkerResources>,
     BasicMetadata,
@@ -71,6 +53,7 @@ type VllmBlockManager = dynamo_llm::block_manager::KvBlockManager<
 #[derive(Clone)]
 pub struct BlockManager {
     inner: Arc<VllmBlockManager>,
+    _rt: Arc<tokio::runtime::Runtime>,
 }
 
 #[pymethods]
@@ -94,7 +77,7 @@ impl BlockManager {
 
         tracing::info!("Using {} device blocks", device_num_blocks);
 
-        let mut model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder()
+        let model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder()
             .num_layers(1)
             .outer_dim(1)
             .page_size(page_size)
@@ -110,47 +93,46 @@ impl BlockManager {
                 .map_err(to_pyerr)?,
         );
 
-        if leader.inner().num_host_blocks() > 0 {
-            tracing::info!("Using {} host blocks", leader.inner().num_host_blocks());
+        let (leader, rt) = leader.dissolve();
+
+        if leader.num_host_blocks() > 0 {
+            tracing::info!("Using {} host blocks", leader.num_host_blocks());
             config = config.host_layout(
                 dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
-                    .num_blocks(leader.inner().num_host_blocks())
+                    .num_blocks(leader.num_host_blocks())
                     .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
                     .build()
                     .map_err(to_pyerr)?,
             );
         }
 
-        if leader.inner().num_disk_blocks() > 0 {
-            tracing::info!("Using {} disk blocks", leader.inner().num_disk_blocks());
+        if leader.num_disk_blocks() > 0 {
+            tracing::info!("Using {} disk blocks", leader.num_disk_blocks());
             config = config.disk_layout(
                 dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
-                    .num_blocks(leader.inner().num_disk_blocks())
+                    .num_blocks(leader.num_disk_blocks())
                     .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
                     .build()
                     .map_err(to_pyerr)?,
             );
         }
 
         let config = config.build().map_err(to_pyerr)?;
-        let tokio_runtime = pyo3_async_runtimes::tokio::get_runtime();
         Ok(BlockManager {
             inner: Arc::from(
-                tokio_runtime
-                    .block_on(async {
-                        let resources = DistributedLeaderWorkerResources::new(
-                            leader.inner(),
-                            cancel_token.child_token(),
-                        )?;
-
-                        dynamo_llm::block_manager::KvBlockManager::<
-                            Logical<DistributedLeaderWorkerResources>,
-                            BasicMetadata,
-                        >::new(config, resources)
-                        .await
-                    })
-                    .map_err(to_pyerr)?,
+                rt.block_on(async {
+                    let resources =
+                        DistributedLeaderWorkerResources::new(leader, cancel_token.child_token())?;
+
+                    dynamo_llm::block_manager::KvBlockManager::<
+                        Logical<DistributedLeaderWorkerResources>,
+                        BasicMetadata,
+                    >::new(config, resources)
+                    .await
+                })
+                .map_err(to_pyerr)?,
             ),
+            _rt: rt,
         })
     }
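The constructor now dissolves the leader into its KvbmLeaderImpl and its tokio runtime, blocks on that runtime to build the block manager, and stores the runtime in the new `_rt` field so it stays alive for the lifetime of the manager. A minimal sketch of that ownership pattern, assuming only the tokio crate; the `Inner` and `Manager` names below are placeholders, not types from this repository:

    use std::sync::Arc;

    struct Inner; // stand-in for the async-constructed block manager

    struct Manager {
        #[allow(dead_code)]
        inner: Arc<Inner>,
        // Hold the runtime that built `inner` for as long as the manager
        // itself, mirroring the `_rt` field added in this commit.
        _rt: Arc<tokio::runtime::Runtime>,
    }

    impl Manager {
        fn new(rt: Arc<tokio::runtime::Runtime>) -> Manager {
            // Block on the runtime we own instead of a process-global one.
            let inner = rt.block_on(async { Arc::new(Inner) });
            Manager { inner, _rt: rt }
        }
    }

    fn main() {
        let rt = Arc::new(tokio::runtime::Runtime::new().expect("tokio runtime"));
        let _manager = Manager::new(rt);
    }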

lib/bindings/python/rust/llm/block_manager/distributed.rs

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 use super::*;
 
 mod leader;
+mod utils;
 mod worker;
 
 pub use leader::KvbmLeader;

lib/bindings/python/rust/llm/block_manager/distributed/leader.rs

Lines changed: 9 additions & 11 deletions
@@ -2,7 +2,9 @@
 // SPDX-License-Identifier: Apache-2.0
 
 use super::*;
+use utils::get_barrier_id;
 
+use derive_getters::Dissolve;
 use llm_rs::block_manager::distributed::{KvbmLeader as KvbmLeaderImpl, KvbmLeaderConfig};
 
 fn compute_num_blocks(env_var: &str, bytes_per_block: usize) -> usize {
@@ -14,26 +16,22 @@ fn compute_num_blocks(env_var: &str, bytes_per_block: usize) -> usize {
 }
 
 #[pyclass]
-#[derive(Clone)]
+#[derive(Clone, Dissolve)]
 pub struct KvbmLeader {
     leader: Arc<KvbmLeaderImpl>,
-    _rt: Arc<tokio::runtime::Runtime>,
-}
-
-impl KvbmLeader {
-    pub fn inner(&self) -> Arc<KvbmLeaderImpl> {
-        self.leader.clone()
-    }
+    rt: Arc<tokio::runtime::Runtime>,
 }
 
 #[pymethods]
 impl KvbmLeader {
     #[new]
-    #[pyo3(signature = (barrier_id, bytes_per_block, world_size))]
-    fn new(barrier_id: String, bytes_per_block: usize, world_size: usize) -> PyResult<Self> {
+    #[pyo3(signature = (bytes_per_block, world_size))]
+    fn new(bytes_per_block: usize, world_size: usize) -> PyResult<Self> {
         let num_host_blocks = compute_num_blocks("DYNAMO_KVBM_CPU_CACHE", bytes_per_block);
         let num_disk_blocks = compute_num_blocks("DYNAMO_KVBM_DISK_CACHE", bytes_per_block);
 
+        let barrier_id = get_barrier_id();
+
         let config = KvbmLeaderConfig::builder()
             .barrier_id(barrier_id)
@@ -52,7 +50,7 @@ impl KvbmLeader {
 
         Ok(Self {
             leader: Arc::new(leader),
-            _rt: Arc::new(rt),
+            rt: Arc::new(rt),
         })
     }
 }
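KvbmLeader now derives Dissolve (from the derive_getters crate) in place of the hand-written `inner()` accessor. `dissolve()` consumes the struct and returns its fields as a tuple in declaration order, which is what `let (leader, rt) = leader.dissolve();` in block_manager.rs relies on. Roughly, the derive provides something like the following; this is a simplified sketch with stand-in field types, not the exact generated code:

    use std::sync::Arc;

    struct LeaderImpl; // stand-in for KvbmLeaderImpl
    struct Rt;         // stand-in for tokio::runtime::Runtime

    struct KvbmLeader {
        leader: Arc<LeaderImpl>,
        rt: Arc<Rt>,
    }

    impl KvbmLeader {
        // Approximately what #[derive(Dissolve)] expands to: consume self
        // and return every field, in order, as a tuple.
        fn dissolve(self) -> (Arc<LeaderImpl>, Arc<Rt>) {
            (self.leader, self.rt)
        }
    }

    fn main() {
        let kvbm = KvbmLeader { leader: Arc::new(LeaderImpl), rt: Arc::new(Rt) };
        let (leader, rt) = kvbm.dissolve();
        assert_eq!(Arc::strong_count(&leader), 1);
        assert_eq!(Arc::strong_count(&rt), 1);
    }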

lib/bindings/python/rust/llm/block_manager/distributed/utils.rs (new file, path inferred from `mod utils;` above)

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+pub fn get_barrier_id() -> String {
+    std::env::var("DYNAMO_KVBM_BARRIER_ID").unwrap_or("kvbm".to_string())
+}
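With this helper in place, neither constructor takes a barrier_id argument anymore; both the leader and the worker read the rendezvous ID from the environment. A minimal sketch of the resulting configuration flow, with the variable name and default taken from the code above:

    fn main() {
        // Same lookup as get_barrier_id(): use DYNAMO_KVBM_BARRIER_ID if it
        // is set, otherwise fall back to the default "kvbm".
        let barrier_id =
            std::env::var("DYNAMO_KVBM_BARRIER_ID").unwrap_or_else(|_| "kvbm".to_string());
        println!("leader and workers will rendezvous on barrier id {barrier_id}");
    }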

lib/bindings/python/rust/llm/block_manager/distributed/worker.rs

Lines changed: 7 additions & 6 deletions
@@ -3,6 +3,8 @@
 
 use super::*;
 
+use utils::get_barrier_id;
+
 use llm_rs::block_manager::distributed::{KvbmWorker as KvbmWorkerImpl, KvbmWorkerConfig};
 use llm_rs::block_manager::storage::torch::{TorchDevice, TorchTensor};
 
@@ -80,32 +82,31 @@ pub struct KvbmWorker {
 #[pymethods]
 impl KvbmWorker {
     #[new]
-    #[pyo3(signature = (num_device_blocks, page_size, tensors, device_id=0, worker_id=0, dtype=None, barrier_id="kvbm".to_string()))]
+    #[pyo3(signature = (num_device_blocks, page_size, tensors, device_id=0, worker_id=0, dtype_width_bytes=2))]
     fn new(
         num_device_blocks: usize,
         page_size: usize,
         tensors: Vec<Py<PyAny>>,
         device_id: usize,
         worker_id: usize,
-        dtype: Option<String>,
-        barrier_id: String,
+        dtype_width_bytes: usize,
     ) -> PyResult<Self> {
-        let dtype = map_dtype(&dtype.unwrap_or("fp16".to_string())).map_err(to_pyerr)?;
-
         let mut vllm_tensors: Vec<Box<dyn TorchTensor>> = Vec::with_capacity(tensors.len());
 
         for tensor in tensors {
             let vllm_tensor = VllmTensor::new(tensor.clone()).map_err(to_pyerr)?;
             vllm_tensors.push(Box::new(vllm_tensor));
         }
 
+        let barrier_id = get_barrier_id();
+
         let config = KvbmWorkerConfig::builder()
             .num_device_blocks(num_device_blocks)
             .page_size(page_size)
             .tensors(vllm_tensors)
             .device_id(device_id)
             .worker_id(worker_id)
-            .dtype(dtype)
+            .dtype_width_bytes(dtype_width_bytes)
             .barrier_id(barrier_id)
             .build()
             .map_err(to_pyerr)?;
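KvbmWorker::new no longer parses a dtype string (and map_dtype is removed from block_manager.rs); callers now pass the element width in bytes via dtype_width_bytes, which defaults to 2. A hedged sketch of how a caller could derive that width from a dtype name; the byte sizes are standard element widths, and the helper itself is not part of this commit:

    // Illustrative helper for callers migrating from the old `dtype` string
    // argument to the new `dtype_width_bytes` parameter.
    fn dtype_width_bytes(dtype: &str) -> Option<usize> {
        match dtype.to_ascii_lowercase().as_str() {
            "fp8" | "u8" | "i8" => Some(1),
            "fp16" | "bf16" | "u16" | "i16" => Some(2),
            "fp32" | "u32" | "i32" => Some(4),
            "u64" | "i64" => Some(8),
            _ => None,
        }
    }

    fn main() {
        // fp16 KV-cache entries are 2 bytes wide, matching the new default.
        assert_eq!(dtype_width_bytes("FP16"), Some(2));
        assert_eq!(dtype_width_bytes("fp8"), Some(1));
    }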
