Commit 75decfb

feat: Migrate to GDS MT backend (#1734)

1 parent faa74d6 · commit 75decfb

11 files changed: +106 −76 lines

Cargo.lock

Lines changed: 1 addition & 1 deletion
(Generated file; diff not rendered.)

container/build.sh

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ SGLANG_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"
 VLLM_V1_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
 VLLM_V1_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"
 
-NIXL_COMMIT=16348080f5bdeb9fe6058a23be140cec020ef3f3
+NIXL_COMMIT=fa800bcfe3814b08df9cda9c30443de8c19665e5
 NIXL_REPO=ai-dynamo/nixl.git
 
 NIXL_UCX_EFA_REF=7ec95b95e524a87e81cac92f5ca8523e3966b16b

lib/bindings/python/Cargo.lock

Lines changed: 1 addition & 1 deletion
(Generated file; diff not rendered.)

lib/bindings/python/rust/llm/block_manager/distributed/worker.rs

Lines changed: 3 additions & 2 deletions
@@ -3,6 +3,7 @@
 
 use super::*;
 
+use std::sync::Arc;
 use utils::get_barrier_id;
 
 use llm_rs::block_manager::distributed::{KvbmWorker as KvbmWorkerImpl, KvbmWorkerConfig};
@@ -91,11 +92,11 @@ impl KvbmWorker {
         worker_id: usize,
         dtype_width_bytes: usize,
     ) -> PyResult<Self> {
-        let mut vllm_tensors: Vec<Box<dyn TorchTensor>> = Vec::with_capacity(tensors.len());
+        let mut vllm_tensors: Vec<Arc<dyn TorchTensor>> = Vec::with_capacity(tensors.len());
 
         for tensor in tensors {
             let vllm_tensor = VllmTensor::new(tensor.clone()).map_err(to_pyerr)?;
-            vllm_tensors.push(Box::new(vllm_tensor));
+            vllm_tensors.push(Arc::new(vllm_tensor));
         }
 
         let barrier_id = get_barrier_id();
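The switch from Box<dyn TorchTensor> to Arc<dyn TorchTensor> matters because the same tensor handles are now also stored in KvbmWorkerConfig and handed across tasks. A minimal, self-contained sketch of the ownership difference (Tensor and MockTensor here are hypothetical stand-ins, not the crate's types):

use std::sync::Arc;

// Hypothetical stand-in for the TorchTensor trait object.
trait Tensor {
    fn shape(&self) -> Vec<usize>;
}

struct MockTensor(Vec<usize>);

impl Tensor for MockTensor {
    fn shape(&self) -> Vec<usize> {
        self.0.clone()
    }
}

fn main() {
    // Arc<dyn Trait> is cheaply cloneable, so one tensor list can be
    // shared between the Python binding and the worker config.
    // Box<dyn Trait> has a single owner and cannot be cloned without
    // knowing the concrete type.
    let tensors: Vec<Arc<dyn Tensor>> = vec![Arc::new(MockTensor(vec![2, 8, 4096]))];
    let shared = tensors.clone(); // bumps reference counts only
    assert_eq!(shared[0].shape(), vec![2, 8, 4096]);
}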

lib/llm/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ regex = "1"
 rayon = "1"
 
 # block_manager
-nixl-sys = { git = "https://github.com/ai-dynamo/nixl", rev = "a7c654d46a14cd5ce635cc8c02433d71df93dedf", optional = true }
+nixl-sys = { git = "https://github.com/ai-dynamo/nixl", rev = "fa800bcfe3814b08df9cda9c30443de8c19665e5", optional = true }
 cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }
 ndarray = { version = "0.16", optional = true }
 nix = { version = "0.26", optional = true }

lib/llm/src/block_manager/distributed.rs

Lines changed: 40 additions & 14 deletions
@@ -37,8 +37,7 @@ mod tests {
 
     use dynamo_runtime::logging::init as init_logging;
 
-    const NUM_DEVICE_BLOCKS: usize = 8;
-    const NUM_HOST_BLOCKS: usize = 8;
+    const NUM_BLOCKS: usize = 8;
 
     #[derive(Clone, Debug)]
     struct MockTensor {
@@ -100,12 +99,12 @@
         let barrier_id = get_unique_barrier_id();
 
         for i in 0..num_workers {
-            let tensors: Vec<Box<dyn TorchTensor>> =
-                vec![Box::new(MockTensor::new(vec![2, NUM_DEVICE_BLOCKS, 4096]))];
+            let tensors: Vec<Arc<dyn TorchTensor>> =
+                vec![Arc::new(MockTensor::new(vec![2, NUM_BLOCKS, 4096]))];
 
             let config = KvbmWorkerConfig::builder()
                 .barrier_id(barrier_id.clone())
-                .num_device_blocks(NUM_DEVICE_BLOCKS)
+                .num_device_blocks(NUM_BLOCKS)
                 .tensors(tensors)
                 .worker_id(i)
                 .build()?;
@@ -117,7 +116,8 @@
         let leader_config = KvbmLeaderConfig::builder()
             .barrier_id(barrier_id)
             .world_size(num_workers)
-            .num_host_blocks(NUM_HOST_BLOCKS)
+            .num_host_blocks(NUM_BLOCKS)
+            .num_disk_blocks(NUM_BLOCKS)
             .build()?;
 
         // When/if this returns, we know that all the workers were also successful.
@@ -137,7 +137,9 @@
 
         let (leader, _workers) = build_leader_and_workers(num_workers).await?;
 
-        for block_idx in 0..std::cmp::min(NUM_DEVICE_BLOCKS, NUM_HOST_BLOCKS) {
+        // Do a whole bunch of distributed transfers.
+
+        for block_idx in 0..NUM_BLOCKS {
             leader
                 .transfer_blocks_request(utils::BlockTransferRequest::new(
                     utils::BlockTransferPool::Device,
@@ -148,10 +150,21 @@
                 .await?;
         }
 
-        for block_idx in 0..std::cmp::min(NUM_DEVICE_BLOCKS, NUM_HOST_BLOCKS) {
+        for block_idx in 0..NUM_BLOCKS {
             leader
                 .transfer_blocks_request(utils::BlockTransferRequest::new(
                     utils::BlockTransferPool::Host,
+                    utils::BlockTransferPool::Disk,
+                    vec![(block_idx, block_idx)],
+                ))
+                .await?
+                .await?;
+        }
+
+        for block_idx in 0..NUM_BLOCKS {
+            leader
+                .transfer_blocks_request(utils::BlockTransferRequest::new(
+                    utils::BlockTransferPool::Disk,
                     utils::BlockTransferPool::Device,
                     vec![(block_idx, block_idx)],
                 ))
@@ -194,13 +207,19 @@
             )
             .device_layout(
                 KvManagerLayoutConfig::builder()
-                    .num_blocks(NUM_DEVICE_BLOCKS)
+                    .num_blocks(NUM_BLOCKS)
                     .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
                     .build()?,
             )
            .host_layout(
                 KvManagerLayoutConfig::builder()
-                    .num_blocks(NUM_HOST_BLOCKS)
+                    .num_blocks(NUM_BLOCKS)
+                    .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
+                    .build()?,
+            )
+            .disk_layout(
+                KvManagerLayoutConfig::builder()
+                    .num_blocks(NUM_BLOCKS)
                     .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
                     .build()?,
             )
@@ -218,8 +237,9 @@
 
         let device_pool = block_manager.device().unwrap();
         let host_pool = block_manager.host().unwrap();
+        let disk_pool = block_manager.disk().unwrap();
 
-        let mut device_blocks = device_pool.allocate_blocks(NUM_DEVICE_BLOCKS).await?;
+        let mut device_blocks = device_pool.allocate_blocks(NUM_BLOCKS).await?;
 
         let mut sequence_hashes = Vec::new();
         for block in &mut device_blocks {
@@ -245,15 +265,21 @@
             .match_sequence_hashes(sequence_hashes.as_slice())
             .await?;
 
-        assert_eq!(host_blocks.len(), NUM_DEVICE_BLOCKS);
+        assert_eq!(host_blocks.len(), NUM_BLOCKS);
+
+        let disk_blocks = disk_pool
+            .match_sequence_hashes(sequence_hashes.as_slice())
+            .await?;
+
+        assert_eq!(disk_blocks.len(), NUM_BLOCKS);
 
         // Return the device blocks to the pool.
         drop(immutable_device_blocks);
 
         tokio::time::sleep(std::time::Duration::from_millis(100)).await;
 
         // Clear out the device pool.
-        let _ = device_pool.allocate_blocks(NUM_DEVICE_BLOCKS).await?;
+        let _ = device_pool.allocate_blocks(NUM_BLOCKS).await?;
 
         // Now, all the blocks should be gone.
         assert_eq!(
@@ -270,7 +296,7 @@
         // Now, onboard them back to the device.
         let new_device_blocks = block_manager.onboard_blocks(host_blocks, None).await??;
 
-        assert_eq!(new_device_blocks.len(), NUM_DEVICE_BLOCKS);
+        assert_eq!(new_device_blocks.len(), NUM_BLOCKS);
 
         Ok(())
     }
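The updated test round-trips every block through all three tiers: Device → Host, Host → Disk, then Disk → Device. Condensed into one helper, the pattern looks like the sketch below (it reuses the BlockTransferRequest API from the diff, but the helper function itself is hypothetical). Note the double .await: the first awaits acceptance of the request, the second awaits the completion handle it returns.

use utils::{BlockTransferPool, BlockTransferRequest};

// Hypothetical helper: push block `idx` around the full tier cycle.
async fn roundtrip_block(leader: &KvbmLeader, idx: usize) -> anyhow::Result<()> {
    let hops = [
        (BlockTransferPool::Device, BlockTransferPool::Host),
        (BlockTransferPool::Host, BlockTransferPool::Disk),
        (BlockTransferPool::Disk, BlockTransferPool::Device),
    ];
    for (from, to) in hops {
        leader
            .transfer_blocks_request(BlockTransferRequest::new(from, to, vec![(idx, idx)]))
            .await? // request accepted
            .await?; // transfer finished
    }
    Ok(())
}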

lib/llm/src/block_manager/distributed/leader.rs

Lines changed: 4 additions & 2 deletions
@@ -17,6 +17,8 @@ use std::time::Duration;
 use tokio::sync::oneshot;
 use tokio_util::sync::CancellationToken;
 
+const INIT_TIMEOUT_SECS: u64 = 120;
+
 /// Data that is sent to workers over ETCD to establish a ZMQ connection.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct KvbmLeaderData {
@@ -86,7 +88,7 @@ impl KvbmLeader {
         let leader_barrier: LeaderBarrier<KvbmLeaderData, ()> = LeaderBarrier::new(
             config.barrier_id.clone(),
             config.world_size,
-            Some(Duration::from_secs(30)),
+            Some(Duration::from_secs(INIT_TIMEOUT_SECS)),
         );
 
         let worker_data = leader_barrier
@@ -103,7 +105,7 @@
         let zmq_leader = ZmqActiveMessageLeader::new(
             leader_sockets,
             config.world_size,
-            Duration::from_secs(30),
+            Duration::from_secs(INIT_TIMEOUT_SECS),
             cancel_token.clone(),
         )
         .await?;

lib/llm/src/block_manager/distributed/worker.rs

Lines changed: 43 additions & 42 deletions
@@ -31,7 +31,7 @@ use dynamo_runtime::{
 };
 
 fn load_and_validate_tensors(
-    tensors: Vec<Box<dyn TorchTensor>>,
+    tensors: &[Arc<dyn TorchTensor>],
     device_id: usize,
 ) -> anyhow::Result<(Vec<DeviceStorage>, Vec<usize>)> {
     let mut shape = None;
@@ -67,7 +67,7 @@
         }
 
         // Build the storage object from the tensor.
-        let device_tensor = DeviceStorage::new_from_torch(allocator.ctx(), tensor)?;
+        let device_tensor = DeviceStorage::new_from_torch(allocator.ctx(), tensor.clone())?;
 
         device_tensors.push(device_tensor);
     }
@@ -84,7 +84,7 @@ pub struct KvbmWorkerConfig {
     page_size: usize,
 
     #[builder(default = "Vec::new()")]
-    tensors: Vec<Box<dyn TorchTensor>>,
+    tensors: Vec<Arc<dyn TorchTensor>>,
 
     #[builder(default = "0")]
     device_id: usize,
@@ -105,14 +105,13 @@ impl KvbmWorkerConfig {
     }
 }
 
-fn build_agent(worker_id: usize) -> anyhow::Result<NixlAgent> {
-    // TODO: Get GDS enabled here.
-    // There seems to be some issue with NIXL that causes errors if a large amount of GDS backends are instantiated all at once.
-
+fn build_agent(worker_id: usize, use_gds: bool) -> anyhow::Result<NixlAgent> {
     let agent = NixlAgent::new(&format!("kvbm-worker-{}", worker_id))?;
-    // let (_, gds_params) = agent.get_plugin_params("GDS")?;
+    if use_gds {
+        let (_, gds_params) = agent.get_plugin_params("GDS_MT")?;
+        agent.create_backend("GDS_MT", &gds_params)?;
+    }
     let (_, posix_params) = agent.get_plugin_params("POSIX")?;
-    // agent.create_backend("GDS", &gds_params)?;
     agent.create_backend("POSIX", &posix_params)?;
 
     Ok(agent)
@@ -135,7 +134,7 @@
             return Err(anyhow::anyhow!("num_device_blocks must be greater than 0"));
         }
 
-        let (device_tensors, shape) = load_and_validate_tensors(config.tensors, config.device_id)?;
+        let (device_tensors, shape) = load_and_validate_tensors(&config.tensors, config.device_id)?;
 
         if shape.len() < 3 {
             return Err(anyhow::anyhow!(format!(
@@ -182,28 +181,14 @@
 
         let layout_builder_clone = layout_builder.clone();
 
-        let agent = build_agent(config.worker_id)?;
-
-        let transfer_context = Arc::new(TransferContext::new(
-            Arc::new(Some(agent)),
-            DeviceAllocator::new(config.device_id)
-                .unwrap()
-                .ctx()
-                .new_stream()
-                .unwrap(),
-            Handle::current(),
-        ));
-
         let cancel_token = CancellationToken::new();
         let task = CriticalTaskExecutionHandle::new(
             move |cancel_token| {
                 KvbmWorker::worker_task(
                     device_layout,
                     layout_builder_clone,
                     layout_type,
-                    config.barrier_id,
-                    config.worker_id,
-                    transfer_context,
+                    config,
                     cancel_token,
                 )
             },
@@ -235,26 +220,22 @@
         device_layout: Box<dyn NixlLayout<StorageType = DeviceStorage>>,
         mut layout_builder: LayoutConfigBuilder,
         layout_type: LayoutType,
-        barrier_id: String,
-        worker_id: usize,
-        transfer_context: Arc<TransferContext>,
+        config: KvbmWorkerConfig,
         cancel_token: CancellationToken,
     ) -> anyhow::Result<()> {
-        // Build our device, host, and disk block lists.
-        let device_blocks = Some(Self::make_layout::<_, BasicMetadata>(
-            device_layout,
-            transfer_context.nixl_agent().as_ref(),
-            0,
-            worker_id,
-        )?);
-
         let runtime = Runtime::from_current()?;
         let drt = DistributedRuntime::from_settings(runtime).await?;
 
-        tracing::info!("Worker {} waiting on barrier {}", worker_id, barrier_id);
+        tracing::info!(
+            "Worker {} waiting on barrier {}",
+            config.worker_id,
+            config.barrier_id
+        );
 
-        let worker_barrier =
-            WorkerBarrier::<KvbmLeaderData, ()>::new(barrier_id, worker_id.to_string());
+        let worker_barrier = WorkerBarrier::<KvbmLeaderData, ()>::new(
+            config.barrier_id,
+            config.worker_id.to_string(),
+        );
 
         let leader_data = tokio::select! {
             _ = cancel_token.cancelled() => {
@@ -268,10 +249,30 @@
 
         tracing::info!(
            "Worker {} received leader data: {:?}",
-            worker_id,
+            config.worker_id,
            leader_data
        );
 
+        let agent = build_agent(config.worker_id, leader_data.num_disk_blocks > 0)?;
+
+        let transfer_context = Arc::new(TransferContext::new(
+            Arc::new(Some(agent)),
+            DeviceAllocator::new(config.device_id)
+                .unwrap()
+                .ctx()
+                .new_stream()
+                .unwrap(),
+            Handle::current(),
+        ));
+
+        // Build our device, host, and disk block lists.
+        let device_blocks = Some(Self::make_layout::<_, BasicMetadata>(
+            device_layout,
+            transfer_context.nixl_agent().as_ref(),
+            0,
+            config.worker_id,
+        )?);
+
         let host_blocks = if leader_data.num_host_blocks > 0 {
             let host_allocator = Arc::new(PinnedAllocator::default());
             let host_layout = layout_builder
@@ -283,7 +284,7 @@
                 host_layout,
                 transfer_context.nixl_agent().as_ref(),
                 1,
-                worker_id,
+                config.worker_id,
             )?)
         } else {
             None
@@ -300,7 +301,7 @@
                 disk_layout,
                 transfer_context.nixl_agent().as_ref(),
                 2,
-                worker_id,
+                config.worker_id,
             )?)
         } else {
             None