Skip to content

Commit 7677f74

Browse files
authored
feat: KVBM async Python bindings and Layer class (#1141)
1 parent a0512bd commit 7677f74

File tree

7 files changed

+784
-245
lines changed

7 files changed

+784
-245
lines changed

lib/bindings/python/rust/llm/block_manager.rs

Lines changed: 83 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@
1414
// limitations under the License.
1515

1616
#![cfg(feature = "block-manager")]
17-
// Silence warnings about deprecated features (like pyo3::IntoPy::into_py)
18-
#![allow(deprecated)]
1917

2018
use super::*;
2119
use pyo3::PyResult;
22-
use tokio;
2320

2421
mod block;
2522
mod block_list;
23+
mod dlpack;
24+
mod layer;
2625

2726
/// Add bindings from this crate to the provided module
2827
pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
28+
m.add_class::<layer::Layer>()?;
2929
m.add_class::<block::Block>()?;
3030
m.add_class::<block_list::BlockList>()?;
3131
m.add_class::<BlockManager>()?;
@@ -34,9 +34,6 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
3434

3535
#[pyclass]
3636
pub struct BlockManager {
37-
// TODO: Can this be implicitly created and referenced?
38-
tokio_runtime: tokio::runtime::Runtime,
39-
// Block manager
4037
inner: Arc<dynamo_llm::block_manager::ReferenceBlockManager>,
4138
// TODO: Metadata should be stored in the block manager?
4239
dtype: dynamo_llm::common::dtype::DType,
@@ -62,7 +59,7 @@ impl BlockManager {
6259
dynamo_llm::block_manager::KvManagerRuntimeConfig::builder()
6360
.worker_id(worker_id)
6461
.build()
65-
.unwrap(),
62+
.map_err(to_pyerr)?,
6663
);
6764
let mut model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder()
6865
.num_layers(num_layer)
@@ -93,14 +90,17 @@ impl BlockManager {
9390
};
9491
}
9592
model_config = model_config.dtype(dtype_.clone());
96-
config = config.model(model_config.build().unwrap());
93+
config = config.model(model_config.build().map_err(to_pyerr)?);
9794
if let Some(host_num_blocks) = host_num_blocks {
9895
config = config.host_layout(
9996
dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
10097
.num_blocks(host_num_blocks)
101-
.allocator(dynamo_llm::block_manager::storage::PinnedAllocator::new().unwrap())
98+
.allocator(
99+
dynamo_llm::block_manager::storage::PinnedAllocator::new()
100+
.map_err(to_pyerr)?,
101+
)
102102
.build()
103-
.unwrap(),
103+
.map_err(to_pyerr)?,
104104
);
105105
}
106106
if let Some(device_num_blocks) = device_num_blocks {
@@ -109,23 +109,22 @@ impl BlockManager {
109109
.num_blocks(device_num_blocks)
110110
.allocator(
111111
dynamo_llm::block_manager::storage::DeviceAllocator::new(device_id)
112-
.unwrap(),
112+
.map_err(to_pyerr)?,
113113
)
114114
.build()
115-
.unwrap(),
115+
.map_err(to_pyerr)?,
116116
);
117117
}
118-
let config = config.build().unwrap();
119-
let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
120-
.enable_all()
121-
.build()
122-
.unwrap();
123-
let block_manager = tokio_runtime.block_on(async {
124-
dynamo_llm::block_manager::ReferenceBlockManager::new(config).unwrap()
125-
});
118+
let config = config.build().map_err(to_pyerr)?;
119+
let tokio_runtime = pyo3_async_runtimes::tokio::get_runtime();
126120
Ok(BlockManager {
127-
tokio_runtime: tokio_runtime,
128-
inner: Arc::from(block_manager),
121+
inner: Arc::from(
122+
tokio_runtime
123+
.block_on(async {
124+
dynamo_llm::block_manager::ReferenceBlockManager::new(config)
125+
})
126+
.map_err(to_pyerr)?,
127+
),
129128
dtype: dtype_,
130129
device_id: device_id,
131130
})
@@ -135,9 +134,11 @@ impl BlockManager {
135134
let blocks = self
136135
.inner
137136
.host()
138-
.unwrap()
137+
.ok_or_else(|| {
138+
pyo3::exceptions::PyRuntimeError::new_err("Host allocator not available")
139+
})?
139140
.allocate_blocks_blocking(count)
140-
.unwrap();
141+
.map_err(to_pyerr)?;
141142
// Wrap each block in the enum that distinguishes Pinned and Device blocks
142143
let blocks = blocks
143144
.into_iter()
@@ -150,13 +151,42 @@ impl BlockManager {
150151
))
151152
}
152153

154+
#[pyo3(signature = (count))]
155+
fn allocate_host_blocks<'py>(
156+
&self,
157+
py: Python<'py>,
158+
count: usize,
159+
) -> PyResult<Bound<'py, PyAny>> {
160+
let inner = self.inner.clone();
161+
let dtype = self.dtype.clone();
162+
let device_id = self.device_id;
163+
pyo3_async_runtimes::tokio::future_into_py(py, async move {
164+
let blocks = inner
165+
.host()
166+
.ok_or_else(|| {
167+
pyo3::exceptions::PyRuntimeError::new_err("Host allocator not available")
168+
})?
169+
.allocate_blocks(count)
170+
.await
171+
.map_err(to_pyerr)?;
172+
// Wrap each block in the enum that distinguishes Pinned and Device blocks
173+
let blocks = blocks
174+
.into_iter()
175+
.map(|b| block::BlockType::Pinned(b))
176+
.collect();
177+
Ok(block_list::BlockList::from_rust(blocks, dtype, device_id))
178+
})
179+
}
180+
153181
fn allocate_device_blocks_blocking(&self, count: usize) -> PyResult<block_list::BlockList> {
154182
let blocks = self
155183
.inner
156184
.device()
157-
.unwrap()
185+
.ok_or_else(|| {
186+
pyo3::exceptions::PyRuntimeError::new_err("Device allocator not available")
187+
})?
158188
.allocate_blocks_blocking(count)
159-
.unwrap();
189+
.map_err(to_pyerr)?;
160190
// Wrap each block in the enum that distinguishes Pinned and Device blocks
161191
let blocks = blocks
162192
.into_iter()
@@ -168,4 +198,31 @@ impl BlockManager {
168198
self.device_id,
169199
))
170200
}
201+
202+
#[pyo3(signature = (count))]
203+
fn allocate_device_blocks<'py>(
204+
&self,
205+
py: Python<'py>,
206+
count: usize,
207+
) -> PyResult<Bound<'py, PyAny>> {
208+
let inner = self.inner.clone();
209+
let dtype = self.dtype.clone();
210+
let device_id = self.device_id;
211+
pyo3_async_runtimes::tokio::future_into_py(py, async move {
212+
let blocks = inner
213+
.device()
214+
.ok_or_else(|| {
215+
pyo3::exceptions::PyRuntimeError::new_err("Device allocator not available")
216+
})?
217+
.allocate_blocks(count)
218+
.await
219+
.map_err(to_pyerr)?;
220+
// Wrap each block in the enum that distinguishes Pinned and Device blocks
221+
let blocks = blocks
222+
.into_iter()
223+
.map(|b| block::BlockType::Device(b))
224+
.collect();
225+
Ok(block_list::BlockList::from_rust(blocks, dtype, device_id))
226+
})
227+
}
171228
}

0 commit comments

Comments
 (0)