Skip to content

Commit be88e7b

Browse files
committed
Graceful shutdown + batched nixl transfers + fix DiskStorage unregistration
1 parent 93702e4 commit be88e7b

File tree

5 files changed

+134
-136
lines changed

5 files changed

+134
-136
lines changed

lib/llm/src/block_manager/block/transfer.rs

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ where
133133
pub trait WriteTo<Target> {
134134
fn write_to(
135135
&self,
136-
dst: &mut Target,
136+
dst: &mut Vec<Target>,
137137
notify: Option<String>,
138138
ctx: Arc<TransferContext>,
139139
) -> Result<(), TransferError>;
@@ -143,49 +143,61 @@ pub trait WriteTo<Target> {
143143
/// Returns a future that will complete when the transfer is complete.
144144
fn nixl_write_to(
145145
&self,
146-
dst: &mut Target,
146+
dst: &mut Vec<Target>,
147147
notify: Option<String>,
148148
ctx: Arc<TransferContext>,
149149
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError>;
150150
}
151151

152-
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for RB
152+
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for Vec<Arc<RB>>
153153
where
154154
RB: WriteToStrategy<WB> + Local,
155155
{
156156
fn write_to(
157157
&self,
158-
dst: &mut WB,
158+
dst: &mut Vec<WB>,
159159
notify: Option<String>,
160160
ctx: Arc<TransferContext>,
161161
) -> Result<(), TransferError> {
162-
match Self::write_to_strategy() {
163-
TransferStrategy::Memcpy => memcpy::copy_block(self, dst),
162+
match RB::write_to_strategy() {
163+
TransferStrategy::Memcpy => {
164+
for (src, dst) in self.iter().zip(dst.iter_mut()) {
165+
memcpy::copy_block(src.as_ref(), dst)?;
166+
}
167+
Ok(())
168+
}
164169
TransferStrategy::CudaAsyncH2D
165170
| TransferStrategy::CudaAsyncD2H
166171
| TransferStrategy::CudaAsyncD2D => {
167-
cuda::copy_block(self, dst, ctx.stream().as_ref(), RB::write_to_strategy())
172+
for (src, dst) in self.iter().zip(dst.iter_mut()) {
173+
cuda::copy_block(
174+
src.as_ref(),
175+
dst,
176+
ctx.stream().as_ref(),
177+
RB::write_to_strategy(),
178+
)?;
179+
}
180+
Ok(())
168181
}
169182
TransferStrategy::NixlWrite => {
170-
std::mem::drop(nixl::write_block_to(self, dst, ctx, notify)?);
183+
std::mem::drop(nixl::write_blocks_to(self, dst, ctx, notify)?);
171184
Ok(())
172185
}
173186
_ => Err(TransferError::IncompatibleTypes(format!(
174187
"Unsupported copy strategy: {:?}",
175188
RB::write_to_strategy()
176189
))),
177190
}
178-
// dispatch_copy_to(self, dst, self.transfer_context())
179191
}
180192

181193
fn nixl_write_to(
182194
&self,
183-
dst: &mut WB,
195+
dst: &mut Vec<WB>,
184196
notify: Option<String>,
185197
ctx: Arc<TransferContext>,
186198
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> {
187199
if let TransferStrategy::NixlWrite = RB::write_to_strategy() {
188-
Ok(nixl::write_block_to(self, dst, ctx, notify)?)
200+
Ok(nixl::write_blocks_to(self, dst, ctx, notify)?)
189201
} else {
190202
Err(TransferError::IncompatibleTypes(format!(
191203
"Expected NIXL transfer strategy, got: {:?}",

lib/llm/src/block_manager/block/transfer/nixl.rs

Lines changed: 64 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,14 @@ use super::*;
1818
use anyhow::Result;
1919
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp};
2020
use std::future::{poll_fn, Future};
21-
use std::ops::Range;
2221
use std::task::Poll;
2322

24-
/// Copy a block from a source to a destination using CUDA memcpy
25-
pub fn write_block_to<'a, Source, Destination>(
26-
src: &'a Source,
27-
dst: &'a mut Destination,
28-
ctx: Arc<TransferContext>,
29-
notify: Option<String>,
30-
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
23+
fn append_xfer_request<Source, Destination>(
24+
src: &Arc<Source>,
25+
dst: &mut Destination,
26+
src_dl: &mut XferDescList,
27+
dst_dl: &mut XferDescList,
28+
) -> Result<()>
3129
where
3230
Source: BlockDataProvider,
3331
Destination: BlockDataProviderMut,
@@ -36,17 +34,6 @@ where
3634
let dst_data = dst.block_data_mut(private::PrivateToken);
3735

3836
if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
39-
// Keep the arc to use in the returned future.
40-
let nixl_agent_arc = ctx.as_ref().nixl_agent();
41-
42-
let nixl_agent = nixl_agent_arc
43-
.as_ref()
44-
.as_ref()
45-
.expect("NIXL agent not found");
46-
47-
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
48-
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
49-
5037
let src_desc = src_data.block_view()?.as_nixl_descriptor();
5138
let dst_desc = dst_data.block_view_mut()?.as_nixl_descriptor_mut();
5239

@@ -64,121 +51,99 @@ where
6451
)?;
6552
}
6653

67-
let xfer_req = nixl_agent
68-
.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &nixl_agent.name(), None)
69-
.unwrap();
70-
71-
let mut xfer_args = OptArgs::new()?;
72-
73-
if let Some(notify) = notify {
74-
xfer_args.set_has_notification(true)?;
75-
xfer_args.set_notification_message(notify.as_bytes())?;
76-
}
77-
78-
let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
79-
80-
// Return a future that completes when the transfer is complete.
81-
// TODO: How efficient is this? Can we do better?
82-
Ok(Box::new(poll_fn(move |_cx| {
83-
let nixl_agent = nixl_agent_arc
84-
.as_ref()
85-
.as_ref()
86-
.expect("NIXL agent not found");
87-
88-
// The nixl agent returns true if the transfer is still in progress.
89-
if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
90-
Poll::Ready(())
91-
} else {
92-
Poll::Pending
93-
}
94-
})))
54+
Ok(())
9555
} else {
9656
assert_eq!(src_data.num_layers(), dst_data.num_layers());
97-
write_layers_to(0..src_data.num_layers(), src, dst, ctx, notify)
57+
for layer_idx in 0..src_data.num_layers() {
58+
for outer_idx in 0..src_data.num_outer_dims() {
59+
let src_view = src_data.layer_view(layer_idx, outer_idx)?;
60+
let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;
61+
62+
debug_assert_eq!(src_view.size(), dst_view.size());
63+
64+
let src_desc = src_view.as_nixl_descriptor();
65+
let dst_desc = dst_view.as_nixl_descriptor_mut();
66+
67+
unsafe {
68+
src_dl.add_desc(
69+
src_desc.as_ptr() as usize,
70+
src_desc.size(),
71+
src_desc.device_id(),
72+
)?;
73+
74+
dst_dl.add_desc(
75+
dst_desc.as_ptr() as usize,
76+
dst_desc.size(),
77+
dst_desc.device_id(),
78+
)?;
79+
}
80+
}
81+
}
82+
Ok(())
9883
}
9984
}
10085

101-
/// Copy a range of layers from a source to a destination using CUDA memcpy
102-
pub fn write_layers_to<'a, Source, Destination>(
103-
layer_range: Range<usize>,
104-
src: &'a Source,
105-
dst: &'a mut Destination,
86+
/// Write a batch of blocks from sources to destinations using NIXL transfers
87+
pub fn write_blocks_to<Source, Destination>(
88+
src: &[Arc<Source>],
89+
dst: &mut [Destination],
10690
ctx: Arc<TransferContext>,
10791
notify: Option<String>,
10892
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
10993
where
11094
Source: BlockDataProvider,
11195
Destination: BlockDataProviderMut,
11296
{
113-
let src_data = src.block_data(private::PrivateToken);
114-
let dst_data = dst.block_data_mut(private::PrivateToken);
97+
if src.is_empty() || dst.is_empty() {
98+
return Ok(Box::new(std::future::ready(())));
99+
}
100+
assert_eq!(src.len(), dst.len());
115101

116102
let nixl_agent_arc = ctx.as_ref().nixl_agent();
117103
let nixl_agent = nixl_agent_arc
118104
.as_ref()
119105
.as_ref()
120106
.expect("NIXL agent not found");
121107

122-
let remote_worker_id = dst_data.worker_id.to_string();
123-
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
124-
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
125-
126-
// #[cfg(debug_assertions)]
127-
// {
128-
// let expected_strategy = <<Source as BlockDataProvider>::StorageType as WriteToStrategy<
129-
// Destination::StorageType,
130-
// >>::write_to_strategy();
131-
// assert_eq!(strategy, expected_strategy);
132-
// }
133-
134-
for layer_idx in layer_range {
135-
for outer_idx in 0..src_data.num_outer_dims() {
136-
let src_view = src_data.layer_view(layer_idx, outer_idx)?;
137-
let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;
138-
139-
debug_assert_eq!(src_view.size(), dst_view.size());
140-
141-
let src_desc = src_view.as_nixl_descriptor();
142-
let dst_desc = dst_view.as_nixl_descriptor_mut();
143-
144-
unsafe {
145-
src_dl.add_desc(
146-
src_desc.as_ptr() as usize,
147-
src_desc.size(),
148-
src_desc.device_id(),
149-
)?;
150-
151-
dst_dl.add_desc(
152-
dst_desc.as_ptr() as usize,
153-
dst_desc.size(),
154-
dst_desc.device_id(),
155-
)?;
156-
}
157-
}
108+
let src_mem_type = src
109+
.first()
110+
.unwrap()
111+
.block_data(private::PrivateToken)
112+
.storage_type()
113+
.nixl_mem_type();
114+
let dst_mem_type = dst
115+
.first()
116+
.unwrap()
117+
.block_data(private::PrivateToken)
118+
.storage_type()
119+
.nixl_mem_type();
120+
121+
let mut src_dl = XferDescList::new(src_mem_type)?;
122+
let mut dst_dl = XferDescList::new(dst_mem_type)?;
123+
124+
for (src, dst) in src.iter().zip(dst.iter_mut()) {
125+
append_xfer_request(src, dst, &mut src_dl, &mut dst_dl)?;
158126
}
159127

128+
let xfer_req =
129+
nixl_agent.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &nixl_agent.name(), None)?;
130+
160131
let mut xfer_args = OptArgs::new()?;
161132

162133
if let Some(notify) = notify {
163134
xfer_args.set_has_notification(true)?;
164135
xfer_args.set_notification_message(notify.as_bytes())?;
165136
}
166137

167-
let xfer_req = nixl_agent.create_xfer_req(
168-
XferOp::Write,
169-
&src_dl,
170-
&dst_dl,
171-
&remote_worker_id,
172-
Some(&xfer_args),
173-
)?;
174-
175138
let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
176139

177140
Ok(Box::new(poll_fn(move |_cx| {
178141
let nixl_agent = nixl_agent_arc
179142
.as_ref()
180143
.as_ref()
181144
.expect("NIXL agent not found");
145+
146+
// The nixl agent returns true if the transfer is still in progress.
182147
if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
183148
Poll::Ready(())
184149
} else {

lib/llm/src/block_manager/offload.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
129129
let device_clone = this.device.clone();
130130
let host_clone = this.host.clone();
131131
async_rt_handle.spawn(async move {
132-
OffloadManager::offload_worker(
132+
let res = OffloadManager::offload_worker(
133133
device_clone,
134134
host_clone,
135135
device_offload_rx,
@@ -138,8 +138,8 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
138138
MAX_OFFLOAD_STREAM_DEPTH,
139139
)),
140140
)
141-
.await
142-
.unwrap()
141+
.await;
142+
tracing::warn!("Offload worker finished: {:?}", res);
143143
});
144144

145145
let transfer_ctx = Arc::new(TransferContext::new(
@@ -152,7 +152,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
152152
let disk_clone = this.disk.clone();
153153
let transfer_ctx_clone = transfer_ctx.clone();
154154
async_rt_handle.spawn(async move {
155-
OffloadManager::offload_worker(
155+
let res = OffloadManager::offload_worker(
156156
host_clone,
157157
disk_clone,
158158
host_offload_rx,
@@ -161,38 +161,38 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
161161
MAX_OFFLOAD_STREAM_DEPTH,
162162
)),
163163
)
164-
.await
165-
.unwrap()
164+
.await;
165+
tracing::warn!("Offload worker finished: {:?}", res);
166166
});
167167

168168
// Host -> Device onboarding
169169
let host_clone = this.host.clone();
170170
let device_clone = this.device.clone();
171171
let transfer_ctx_clone = transfer_ctx.clone();
172172
async_rt_handle.spawn(async move {
173-
OffloadManager::onboard_worker(
173+
let res = OffloadManager::onboard_worker(
174174
host_clone,
175175
device_clone,
176176
host_onboard_rx,
177177
Arc::new(CudaTransferManager::new(transfer_ctx_clone, 16384)),
178178
)
179-
.await
180-
.unwrap()
179+
.await;
180+
tracing::warn!("Onboard worker finished: {:?}", res);
181181
});
182182

183183
// Disk -> Device onboarding
184184
let disk_clone = this.disk.clone();
185185
let device_clone = this.device.clone();
186186
let transfer_ctx_clone = transfer_ctx.clone();
187187
async_rt_handle.spawn(async move {
188-
OffloadManager::onboard_worker(
188+
let res = OffloadManager::onboard_worker(
189189
disk_clone,
190190
device_clone,
191191
disk_onboard_rx,
192192
Arc::new(DiskTransferManager::new(transfer_ctx_clone, 16384)),
193193
)
194-
.await
195-
.unwrap()
194+
.await;
195+
tracing::warn!("Onboard worker terminated: {:?}", res);
196196
});
197197

198198
Ok(this_clone)

0 commit comments

Comments
 (0)