64 changes: 54 additions & 10 deletions lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs
@@ -11,6 +11,7 @@ use slot::{ConnectorSlotManager, SlotError, SlotManager, SlotState};
use crate::llm::block_manager::BlockManager as PyBlockManager;
use crate::llm::block_manager::{
distributed::KvbmLeader as PyKvbmLeader, vllm::KvbmRequest, VllmBlockManager,
vllm::connector::leader::slot::VllmConnectorSlot,
};
use crate::DistributedRuntime as PyDistributedRuntime;

@@ -139,9 +140,24 @@ impl Leader for KvConnectorLeader {
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;

if slot.state() == SlotState::Prefilling {
tracing::warn!("slot is in the Prefilled state; this seems like we need to reset the slot and start over");
slot.reset();
debug_assert!(
slot.state() != SlotState::Prefilling && slot.state() != SlotState::Decoding,
"slot is in the Prefilled state or Decoding; shouldn't happen"
);

if slot.state() == SlotState::SkippedPrefill || slot.state() == SlotState::SkippedDecode {
tracing::warn!("slot is in the SkippedPrefill or SkippedDecode state; will resume from skipped and return early");
match slot.state() {
SlotState::SkippedPrefill => {
slot.mark_as_prefilling(self.iteration_counter)?;
return Ok((0, false));
}
SlotState::SkippedDecode => {
slot.mark_as_decoding(self.iteration_counter)?;
return Ok((0, false));
}
_ => unreachable!("slot is not in the SkippedPrefill or SkippedDecode state"),
}
}

// early exit if we cannot match full block
@@ -251,6 +267,11 @@ impl Leader for KvConnectorLeader {
tracing::debug!("adding {} pending onboarding operations", pending_ops.len());
md.add_operations(pending_ops);
}

assert!(
inflight_requests.remove(request_id),
"request_id {request_id} not found in inflight_requests: "
);
}

// vLLM provides us with "new_requests" which are "new" after onboarding, but not before or during.
@@ -328,14 +349,13 @@ impl Leader for KvConnectorLeader {

// note, we cannot trigger onboarding here -- perhaps we are supposed to, or perhaps we will get another
// pass at `get_num_new_matched_tokens` or `update_state_after_alloc`.
} else {
// note: evicition might trigger this assert
assert!(
inflight_requests.remove(request_id),
"request_id {request_id} not found in inflight_requests: "
);
}

assert!(
inflight_requests.remove(request_id),
"request_id {request_id} not found in inflight_requests: "
);

let shared_slot = self.slot_manager.get_slot(request_id)?;
let mut slot = shared_slot
.lock()
@@ -363,6 +383,20 @@ impl Leader for KvConnectorLeader {
}
}

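// any request still in inflight_requests at this point was not scheduled this iteration,
// i.e. the scheduler skipped it; mark its slot as skipped so the SkippedPrefill/SkippedDecode
// handling earlier in this file can resume it later instead of resetting it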
for unscheduled_req in inflight_requests.iter() {
let shared_slot = self.slot_manager.get_slot(unscheduled_req)?;
let mut slot_guard = shared_slot
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;

let slot = slot_guard
.as_any_mut()
.downcast_mut::<VllmConnectorSlot>()
.ok_or_else(|| anyhow::anyhow!("Expected VllmConnectorSlot, got different type"))?;

slot.mark_as_skipped()?;
}

tracing::debug!("metadata: {md:#?}");
serde_json::to_vec(&md)
.map_err(|e| anyhow::anyhow!("Failed to serialize connector metadata: {}", e))
@@ -374,6 +408,13 @@ impl Leader for KvConnectorLeader {
block_ids: Vec<BlockId>,
) -> anyhow::Result<bool> {
tracing::debug!("Request finished: {request_id}; block_ids: {block_ids:?}");

if !self.slot_manager.has_slot(&request_id) {
tracing::warn!("request_finished called for request_id: {request_id} but slot is not found");
self.inflight_requests.remove(&request_id);
return Ok(false);
}

// grab the slot
let shared_slot = self.slot_manager.get_slot(&request_id)?;

@@ -388,11 +429,14 @@ impl Leader for KvConnectorLeader {
// we would like to inform it to shut down, then have it signal to the worker that it is officially gone,
// then we can remove the slot and trigger the worker to clean up as well.

// remove the request from the inflight requests
self.inflight_requests.remove(&request_id);

// remove it from the manager as we will never use it again
self.slot_manager.remove_slot(&request_id)?;

// if the slot has finished, we can return false to vllm, indicating all gpu blocks are free to be reused
// otherwise, we return false, which means there are still outstanding operations on gpu blocks which
// otherwise, we return true, which means there are still outstanding operations on gpu blocks which
// must be awaited before the gpu blocks can be reused. if we return true, then it is the worker side
// of the connector api which will be used to inform vllm that the request is finished.
if let SlotState::Finished = slot.state() {
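// per the comment above: the Finished arm presumably returns Ok(false), telling vLLM the
// GPU blocks are already reusable, while the non-finished path returns Ok(true) and leaves
// the final notification to the worker side of the connector API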
lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs
@@ -1,6 +1,8 @@
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::any::Any;

use dynamo_llm::{
block_manager::{
block::{locality::LocalityProvider, BlockMetadata},
@@ -63,10 +65,16 @@ pub enum SlotState {
/// The slot is actively prefilling the sequence.
Prefilling,

/// The slot was skipped while prefilling.
SkippedPrefill,

/// The slot is actively participating in a forward pass which will result in one or more tokens
/// to be applied to the sequence.
Decoding,

/// The slot was skipped while decoding.
SkippedDecode,

/// The slot is marked as finished, but not all resources have been released.
Finishing,

@@ -98,6 +106,9 @@ pub trait Slot: std::fmt::Debug {

fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>;

fn mark_as_prefilling(&mut self, iteration: u64) -> Result<(), SlotError>;
fn mark_as_decoding(&mut self, iteration: u64) -> Result<(), SlotError>;

fn mark_as_finished(&mut self, iteration: u64) -> Result<(), SlotError>;

/// The number of device blocks that have been allocated to the slot.
@@ -131,6 +142,9 @@ pub trait Slot: std::fmt::Debug {

/// Reset the slot.
fn reset(&mut self);

/// Get a mutable reference to the slot as a dynamic Any.
fn as_any_mut(&mut self) -> &mut dyn Any;
}
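The new `as_any_mut` hook lets callers that only hold a `dyn Slot` (as the leader does) recover the concrete `VllmConnectorSlot` and call inherent methods such as `mark_as_skipped`. A minimal, self-contained sketch of that downcast pattern; `SlotLike` below is a hypothetical stand-in, not the real struct:

use std::any::Any;

trait Slot {
    fn as_any_mut(&mut self) -> &mut dyn Any;
}

#[derive(Default)]
struct SlotLike {
    skipped: bool,
}

impl Slot for SlotLike {
    // Returning `self` as `&mut dyn Any` is what makes the downcast possible.
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}

fn main() {
    let mut slot: Box<dyn Slot> = Box::new(SlotLike::default());
    // Recover the concrete type from the trait object, mirroring the
    // downcast_mut::<VllmConnectorSlot>() call in the leader.rs hunk above.
    if let Some(concrete) = slot.as_any_mut().downcast_mut::<SlotLike>() {
        concrete.skipped = true;
        println!("marked skipped: {}", concrete.skipped);
    }
}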

pub trait ExternallyManagedDeviceSlot: Slot {
@@ -329,6 +343,41 @@ impl VllmConnectorSlot {
tokens_cached_from_disk: 0,
}
}

fn mark_as_skipped_prefill(&mut self) -> Result<(), SlotError> {
if self.state != SlotState::Prefilling {
return Err(SlotError::InvalidState(format!(
"cannot mark slot as skipped prefill in state {:?}",
self.state
)));
}
self.state = SlotState::SkippedPrefill;
Ok(())
}

fn mark_as_skipped_decode(&mut self) -> Result<(), SlotError> {
if self.state != SlotState::Decoding {
return Err(SlotError::InvalidState(format!(
"cannot mark slot as skipped decode in state {:?}",
self.state
)));
}
self.state = SlotState::SkippedDecode;
Ok(())
}

pub fn mark_as_skipped(&mut self) -> Result<(), SlotError> {
match self.state {
SlotState::Prefilling => self.mark_as_skipped_prefill(),
SlotState::Decoding => self.mark_as_skipped_decode(),
SlotState::SkippedPrefill => Ok(()), // already skipped
SlotState::SkippedDecode => Ok(()), // already skipped
_ => {
tracing::warn!("slot is in the {:?} state; will not explicitly mark as skipped, request_id: {}", self.state, self.request_id);
Ok(())
},
}
}
}
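Taken together, `mark_as_skipped` here and the SkippedPrefill/SkippedDecode resume handling in leader.rs above form a small, reversible state machine: Prefilling to SkippedPrefill and back, Decoding to SkippedDecode and back. A stripped-down, self-contained model of just those transitions; the enum below is a simplified stand-in for the real `SlotState`:

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SlotState {
    Prefilling,
    SkippedPrefill,
    Decoding,
    SkippedDecode,
}

// Mirrors mark_as_skipped: only active states transition; already-skipped
// (or other) states are left unchanged.
fn mark_as_skipped(state: SlotState) -> SlotState {
    match state {
        SlotState::Prefilling => SlotState::SkippedPrefill,
        SlotState::Decoding => SlotState::SkippedDecode,
        other => other,
    }
}

// Mirrors the resume path on the leader side.
fn resume(state: SlotState) -> SlotState {
    match state {
        SlotState::SkippedPrefill => SlotState::Prefilling,
        SlotState::SkippedDecode => SlotState::Decoding,
        other => other,
    }
}

fn main() {
    let s = mark_as_skipped(SlotState::Decoding);
    assert_eq!(s, SlotState::SkippedDecode);
    assert_eq!(resume(s), SlotState::Decoding);
}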

impl std::fmt::Debug for VllmConnectorSlot {
@@ -370,6 +419,16 @@ impl Slot for VllmConnectorSlot {
self.state = SlotState::Initialized;
}

fn mark_as_prefilling(&mut self, _iteration: u64) -> Result<(), SlotError> {
self.state = SlotState::Prefilling;
Ok(())
}

fn mark_as_decoding(&mut self, _iteration: u64) -> Result<(), SlotError> {
self.state = SlotState::Decoding;
Ok(())
}

fn record_cached_device_tokens(&mut self, num_tokens: usize) {
self.tokens_cached_from_device = num_tokens;
tracing::debug!("recording {} cached device tokens", num_tokens,);
@@ -542,7 +601,7 @@ impl Slot for VllmConnectorSlot {

if !matches!(self.state(), SlotState::Initialized | SlotState::Preempted) {
return Err(SlotError::InvalidOperation(format!(
"slot must be in the NotScheduled state to acquire local matches; got {:?}",
"slot must be in the NotScheduled or Preempted state to acquire local matches; got {:?}",
self.state()
)));
}
@@ -729,6 +788,10 @@ impl Slot for VllmConnectorSlot {

Ok(())
}

fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
}

impl ExternallyManagedDeviceSlot for VllmConnectorSlot {
38 changes: 28 additions & 10 deletions lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs
@@ -296,6 +296,8 @@ impl Worker for KvConnectorWorker {
let mut is_finished_offloading = HashSet::new();
let mut is_finished_onboarding = HashSet::new();



// before we process the maybes, add any newly annotated finished requests
// to the maybe finished set
for request_id in finished_requests {
@@ -317,11 +319,15 @@ impl Worker for KvConnectorWorker {

// visit each request slot in the maybe finished set
for request_id in self.maybe_finished_offloading.iter() {
if self.connector.is_complete(request_id) {
tracing::debug!(request_id, "request slot is finished");
is_finished_offloading.insert(request_id.clone());
if self.connector.has_slot(request_id) {
if self.connector.is_complete(request_id) {
tracing::debug!(request_id, "request slot is finished");
is_finished_offloading.insert(request_id.clone());
} else {
tracing::debug!(request_id, "request slot is not finished");
}
} else {
tracing::debug!(request_id, "request slot is not finished");
tracing::debug!(request_id, "request slot is not found - likely aborted");
}
}

@@ -331,23 +337,35 @@ impl Worker for KvConnectorWorker {
self.maybe_finished_offloading.remove(request_id);

// currently chomping the error as the engine is closed and we are shutting down
self.connector.remove_slot(request_id);
if self.connector.has_slot(request_id) {
self.connector.remove_slot(request_id);
} else {
tracing::debug!(request_id, "is_finished_offloading: request slot is not found - likely aborted; removing from the is_finished_offloading set");
}
}

// visit each request slot in the maybe finished set to see if it is finished
for request_id in self.maybe_finished_onboarding.iter() {
if self.connector.is_complete(request_id) {
tracing::debug!(request_id, "request slot is finished");
is_finished_onboarding.insert(request_id.clone());
if self.connector.has_slot(request_id) {
if self.connector.is_complete(request_id) {
tracing::debug!(request_id, "request slot is finished");
is_finished_onboarding.insert(request_id.clone());
} else {
tracing::debug!(request_id, "request slot is not finished");
}
} else {
tracing::debug!(request_id, "request slot is not finished");
tracing::debug!(request_id, "request slot is not found - likely aborted");
}
}

// remove the finished requests from the maybe finished set
for request_id in &is_finished_onboarding {
self.maybe_finished_onboarding.remove(request_id);
self.connector.remove_slot(request_id);
if self.connector.has_slot(request_id) {
self.connector.remove_slot(request_id);
} else {
tracing::debug!(request_id, "is_finished_onboarding: request slot is not found - likely aborted; removing from the is_finished_onboarding set");
}
}

(is_finished_offloading, is_finished_onboarding)
11 changes: 9 additions & 2 deletions lib/llm/src/block_manager/connector/scheduler.rs
@@ -221,8 +221,15 @@ impl WorkerSchedulerClient {
}

pub fn is_complete(&self, request_id: &str) -> bool {
let slot = self.slots.get(request_id).expect("slot does not exist");
slot.completed.load(Ordering::Relaxed) == slot.operations.len() as u64
match self.slots.get(request_id) {
Some(slot) => {
slot.completed.load(Ordering::Relaxed) == slot.operations.len() as u64
}
None => {
tracing::debug!(request_id, "slot not found - likely aborted");
true
}
}
}
}
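For context, the change above makes `is_complete` treat a missing slot (typically an aborted request) as complete rather than panicking. A small self-contained model of that behaviour, using a plain map of completed/total operation counts in place of the real slot bookkeeping:

use std::collections::HashMap;

// request_id -> (completed operations, total operations)
type Slots = HashMap<String, (u64, u64)>;

fn is_complete(slots: &Slots, request_id: &str) -> bool {
    match slots.get(request_id) {
        Some((completed, total)) => completed == total,
        // slot not found - likely aborted; treat as complete so cleanup can proceed
        None => true,
    }
}

fn main() {
    let mut slots = Slots::new();
    slots.insert("req-1".to_string(), (2, 3));
    assert!(!is_complete(&slots, "req-1")); // outstanding operations remain
    assert!(is_complete(&slots, "req-2"));  // unknown / aborted request
}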
