Skip to content

Commit e80b982

Browse files
committed
block-manager
It's a pain because it needs NIXL so I can't run it. All manual.
1 parent cc166b6 commit e80b982

File tree

14 files changed

+98 -96 lines changed

lib/llm/src/block_manager/block.rs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -657,10 +657,10 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> std::fmt::Debug for Muta
657657
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Drop for MutableBlock<S, L, M> {
658658
fn drop(&mut self) {
659659
tracing::debug!("drop: {:?}", self);
660-
if let Some(block) = self.block.take() {
661-
if self.return_tx.send(block).is_err() {
662-
tracing::warn!("block pool shutdown before block was returned");
663-
}
660+
if let Some(block) = self.block.take()
661+
&& self.return_tx.send(block).is_err()
662+
{
663+
tracing::warn!("block pool shutdown before block was returned");
664664
}
665665
}
666666
}

lib/llm/src/block_manager/block/data.rs

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,11 @@ pub trait BlockDataExt<S: Storage>: Send + Sync + 'static + std::fmt::Debug {
4646
fn is_local_mut(&mut self) -> Option<&mut dyn BlockDataViews<S>>;
4747

4848
/// Get a read-only view of this block's storage for a layer
49-
fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>> {
49+
fn layer_view(
50+
&self,
51+
layer_idx: usize,
52+
outer_idx: usize,
53+
) -> BlockResult<view::LayerView<'_, S>> {
5054
match self.is_local() {
5155
Some(views) => views.local_layer_view(layer_idx, outer_idx),
5256
None => Err(BlockError::ViewsNotAvailableOnLogicalBlocks),
@@ -58,23 +62,23 @@ pub trait BlockDataExt<S: Storage>: Send + Sync + 'static + std::fmt::Debug {
5862
&mut self,
5963
layer_idx: usize,
6064
outer_idx: usize,
61-
) -> BlockResult<view::LayerViewMut<S>> {
65+
) -> BlockResult<view::LayerViewMut<'_, S>> {
6266
match self.is_local_mut() {
6367
Some(views) => views.local_layer_view_mut(layer_idx, outer_idx),
6468
None => Err(BlockError::ViewsNotAvailableOnLogicalBlocks),
6569
}
6670
}
6771

6872
/// Get a read-only view of this block's storage
69-
fn block_view(&self) -> BlockResult<view::BlockView<S>> {
73+
fn block_view(&self) -> BlockResult<view::BlockView<'_, S>> {
7074
match self.is_local() {
7175
Some(views) => views.local_block_view(),
7276
None => Err(BlockError::ViewsNotAvailableOnLogicalBlocks),
7377
}
7478
}
7579

7680
/// Get a mutable view of this block's storage
77-
fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>> {
81+
fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<'_, S>> {
7882
match self.is_local_mut() {
7983
Some(views) => views.local_block_view_mut(),
8084
None => Err(BlockError::ViewsNotAvailableOnLogicalBlocks),
@@ -88,20 +92,20 @@ pub trait BlockDataViews<S: Storage> {
8892
&self,
8993
layer_idx: usize,
9094
outer_idx: usize,
91-
) -> BlockResult<view::LayerView<S>>;
95+
) -> BlockResult<view::LayerView<'_, S>>;
9296

9397
/// Get a mutable view of this block's storage for a layer
9498
fn local_layer_view_mut(
9599
&mut self,
96100
layer_idx: usize,
97101
outer_idx: usize,
98-
) -> BlockResult<view::LayerViewMut<S>>;
102+
) -> BlockResult<view::LayerViewMut<'_, S>>;
99103

100104
/// Get a read-only view of this block's storage
101-
fn local_block_view(&self) -> BlockResult<view::BlockView<S>>;
105+
fn local_block_view(&self) -> BlockResult<view::BlockView<'_, S>>;
102106

103107
/// Get a mutable view of this block's storage
104-
fn local_block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>>;
108+
fn local_block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<'_, S>>;
105109
}
106110

107111
pub trait BlockDataProvider: StorageTypeProvider {

lib/llm/src/block_manager/block/data/local.rs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,7 @@ impl<S: Storage> BlockDataViews<S> for LocalBlockData<S> {
101101
&self,
102102
layer_idx: usize,
103103
outer_idx: usize,
104-
) -> BlockResult<view::LayerView<S>> {
104+
) -> BlockResult<view::LayerView<'_, S>> {
105105
let mr = self
106106
.layout
107107
.memory_region(self.block_idx, layer_idx, outer_idx)?;
@@ -113,14 +113,14 @@ impl<S: Storage> BlockDataViews<S> for LocalBlockData<S> {
113113
&mut self,
114114
layer_idx: usize,
115115
outer_idx: usize,
116-
) -> BlockResult<view::LayerViewMut<S>> {
116+
) -> BlockResult<view::LayerViewMut<'_, S>> {
117117
let mr = self
118118
.layout
119119
.memory_region(self.block_idx, layer_idx, outer_idx)?;
120120
unsafe { view::LayerViewMut::new(self, mr.addr(), mr.size(), mr.storage_type()) }
121121
}
122122

123-
fn local_block_view(&self) -> BlockResult<view::BlockView<S>> {
123+
fn local_block_view(&self) -> BlockResult<view::BlockView<'_, S>> {
124124
if self.is_fully_contiguous() {
125125
let mr = self.layout.memory_region(self.block_idx, 0, 0)?;
126126
let offset = mr.addr();
@@ -134,7 +134,7 @@ impl<S: Storage> BlockDataViews<S> for LocalBlockData<S> {
134134
}
135135
}
136136

137-
fn local_block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>> {
137+
fn local_block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<'_, S>> {
138138
if self.is_fully_contiguous() {
139139
let mr = self.layout.memory_region(self.block_idx, 0, 0)?;
140140
let offset = mr.addr();

lib/llm/src/block_manager/block/registry.rs

Lines changed: 22 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -109,19 +109,19 @@ impl BlockRegistry {
109109
{
110110
let mut blocks = blocks.lock().unwrap();
111111

112-
if let Some(handle) = blocks.get(&sequence_hash) {
113-
if handle.upgrade().is_none() {
114-
blocks.remove(&sequence_hash);
115-
}
112+
if let Some(handle) = blocks.get(&sequence_hash)
113+
&& handle.upgrade().is_none()
114+
{
115+
blocks.remove(&sequence_hash);
116116
}
117117
}
118118

119119
let mut global_registry = global_registry.lock().unwrap();
120120

121-
if let Some(entry) = global_registry.get(&sequence_hash) {
122-
if entry.upgrade().is_none() {
123-
global_registry.remove(&sequence_hash);
124-
}
121+
if let Some(entry) = global_registry.get(&sequence_hash)
122+
&& entry.upgrade().is_none()
123+
{
124+
global_registry.remove(&sequence_hash);
125125
}
126126
}
127127
});
@@ -136,10 +136,10 @@ impl BlockRegistry {
136136

137137
pub fn is_registered(&self, sequence_hash: SequenceHash) -> bool {
138138
let blocks = self.blocks.lock().unwrap();
139-
if let Some(handle) = blocks.get(&sequence_hash) {
140-
if let Some(_handle) = handle.upgrade() {
141-
return true;
142-
}
139+
if let Some(handle) = blocks.get(&sequence_hash)
140+
&& let Some(_handle) = handle.upgrade()
141+
{
142+
return true;
143143
}
144144
false
145145
}
@@ -161,12 +161,12 @@ impl BlockRegistry {
161161
let mut blocks = self.blocks.lock().unwrap();
162162

163163
// If an identical block already exists in this pool, return an error.
164-
if let Some(handle) = blocks.get(&sequence_hash) {
165-
if let Some(_handle) = handle.upgrade() {
166-
return Err(BlockRegistrationError::BlockAlreadyRegistered(
167-
sequence_hash,
168-
));
169-
}
164+
if let Some(handle) = blocks.get(&sequence_hash)
165+
&& let Some(_handle) = handle.upgrade()
166+
{
167+
return Err(BlockRegistrationError::BlockAlreadyRegistered(
168+
sequence_hash,
169+
));
170170
}
171171

172172
let mut publish_handle = None;
@@ -179,10 +179,10 @@ impl BlockRegistry {
179179
let mut global_registry = self.global_registry.lock().unwrap();
180180

181181
// If an identical block exists in other pool, use the same registration handle.
182-
if let Some(handle) = global_registry.get(&sequence_hash) {
183-
if let Some(handle) = handle.upgrade() {
184-
break 'reg_block handle;
185-
}
182+
if let Some(handle) = global_registry.get(&sequence_hash)
183+
&& let Some(handle) = handle.upgrade()
184+
{
185+
break 'reg_block handle;
186186
}
187187

188188
// Otherwise, create a new registration handle.

lib/llm/src/block_manager/block/transfer/context.rs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -107,10 +107,10 @@ impl TransferContext {
107107
impl Drop for TransferContext {
108108
fn drop(&mut self) {
109109
self.cancel_token.cancel();
110-
if let Some(handle) = self.cuda_event_worker.take() {
111-
if let Err(e) = handle.join() {
112-
tracing::error!("Error joining CUDA event worker: {:?}", e);
113-
}
110+
if let Some(handle) = self.cuda_event_worker.take()
111+
&& let Err(e) = handle.join()
112+
{
113+
tracing::error!("Error joining CUDA event worker: {:?}", e);
114114
}
115115
}
116116
}

lib/llm/src/block_manager/block/transfer/cuda.rs

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -177,9 +177,11 @@ unsafe fn cuda_memcpy_h2d(
177177
"Source and destination device memory regions must not overlap for D2D copy"
178178
);
179179

180-
let src_slice = std::slice::from_raw_parts(src_ptr, size);
181-
cuda_result::memcpy_htod_async(dst_ptr as u64, src_slice, stream.cu_stream())
182-
.map_err(|e| TransferError::ExecutionError(format!("CUDA H2D memcpy failed: {}", e)))?;
180+
unsafe {
181+
let src_slice = std::slice::from_raw_parts(src_ptr, size);
182+
cuda_result::memcpy_htod_async(dst_ptr as u64, src_slice, stream.cu_stream())
183+
.map_err(|e| TransferError::ExecutionError(format!("CUDA H2D memcpy failed: {}", e)))?
184+
};
183185
Ok(())
184186
}
185187

@@ -199,9 +201,11 @@ unsafe fn cuda_memcpy_d2h(
199201
"Source and destination device memory regions must not overlap for D2D copy"
200202
);
201203

202-
let dst_slice = std::slice::from_raw_parts_mut(dst_ptr, size);
203-
cuda_result::memcpy_dtoh_async(dst_slice, src_ptr as u64, stream.cu_stream())
204-
.map_err(|e| TransferError::ExecutionError(format!("CUDA D2H memcpy failed: {}", e)))?;
204+
unsafe {
205+
let dst_slice = std::slice::from_raw_parts_mut(dst_ptr, size);
206+
cuda_result::memcpy_dtoh_async(dst_slice, src_ptr as u64, stream.cu_stream())
207+
.map_err(|e| TransferError::ExecutionError(format!("CUDA D2H memcpy failed: {}", e)))?;
208+
}
205209
Ok(())
206210
}
207211

lib/llm/src/block_manager/block/transfer/memcpy.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,5 +78,5 @@ unsafe fn memcpy(src_ptr: *const u8, dst_ptr: *mut u8, size: usize) {
7878
"Source and destination memory regions must not overlap for copy_nonoverlapping"
7979
);
8080

81-
std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, size);
81+
unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, size) };
8282
}

lib/llm/src/block_manager/connector/protocol.rs

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -256,8 +256,8 @@ impl TransferCompletionHandle for ImmediateTransferCompletionHandle {
256256
let mut guard = self.completion_tx.lock().unwrap();
257257
guard.take()
258258
};
259-
if let Some(completion_tx) = completion_tx {
260-
if completion_tx
259+
if let Some(completion_tx) = completion_tx
260+
&& completion_tx
261261
.send(TransferToSchedulerMessage::ImmediateResult(
262262
ImmediateTransferResult {
263263
request_id: self.request_id.clone(),
@@ -267,9 +267,8 @@ impl TransferCompletionHandle for ImmediateTransferCompletionHandle {
267267
))
268268
.await
269269
.is_err()
270-
{
271-
tracing::error!(DISCONNECTED_WARNING);
272-
}
270+
{
271+
tracing::error!(DISCONNECTED_WARNING);
273272
}
274273
}
275274
}

lib/llm/src/block_manager/distributed/transfer.rs

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -113,15 +113,13 @@ impl BlockTransferHandler {
113113
.collect();
114114

115115
// Perform the transfer, and return the notifying channel.
116-
let channel = match sources.write_to(&mut targets, self.context.clone()) {
116+
match sources.write_to(&mut targets, self.context.clone()) {
117117
Ok(channel) => Ok(channel),
118118
Err(e) => {
119119
tracing::error!("Failed to write to blocks: {:?}", e);
120120
Err(e.into())
121121
}
122-
};
123-
124-
channel
122+
}
125123
}
126124

127125
pub async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> {

lib/llm/src/block_manager/offload.rs

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -321,17 +321,16 @@ impl<Locality: LocalityProvider + 'static, Metadata: BlockMetadata>
321321
if let Ok(blocks) = target_pool
322322
.match_sequence_hashes(vec![request.sequence_hash].as_slice())
323323
.await
324+
&& !blocks.is_empty()
324325
{
325-
if !blocks.is_empty() {
326-
continue;
327-
}
326+
continue;
328327
}
329328

330329
let target_block = 'target_block: {
331-
if let Ok(blocks) = target_pool.allocate_blocks(1).await {
332-
if let Some(block) = blocks.into_iter().next() {
333-
break 'target_block Some(block);
334-
}
330+
if let Ok(blocks) = target_pool.allocate_blocks(1).await
331+
&& let Some(block) = blocks.into_iter().next()
332+
{
333+
break 'target_block Some(block);
335334
}
336335

337336
tracing::warn!(
@@ -507,14 +506,14 @@ impl<Locality: LocalityProvider + 'static, Metadata: BlockMetadata>
507506
}
508507
}
509508

510-
if let Some(targets) = targets.as_ref() {
511-
if targets.len() != blocks.len() {
512-
tx.send(Err(BlockPoolError::BlockError(BlockError::Other(
513-
anyhow::anyhow!("Number of targets does not match number of blocks."),
514-
))))
515-
.unwrap();
516-
return rx;
517-
}
509+
if let Some(targets) = targets.as_ref()
510+
&& targets.len() != blocks.len()
511+
{
512+
tx.send(Err(BlockPoolError::BlockError(BlockError::Other(
513+
anyhow::anyhow!("Number of targets does not match number of blocks."),
514+
))))
515+
.unwrap();
516+
return rx;
518517
}
519518

520519
if blocks.is_empty() {

0 commit comments

Comments
 (0)