Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ const CPU_CACHE_OVERRIDE: &str = "DYNAMO_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS";
const DISK_CACHE: &str = "DYNAMO_KVBM_DISK_CACHE_GB";
const DISK_CACHE_OVERRIDE: &str = "DYNAMO_KVBM_DISK_CACHE_OVERRIDE_NUM_BLOCKS";

const LEADER_WORKER_INIT_TIMEOUT_SECS: &str = "DYNAMO_KVBM_LEADER_WORKER_INIT_TIMEOUT_SECS";
const DEFAULT_INIT_TIMEOUT_SECS: u64 = 120;

fn compute_num_blocks(cache_size_key: &str, override_key: &str, bytes_per_block: usize) -> usize {
if let Ok(override_num_blocks) = std::env::var(override_key) {
override_num_blocks.parse::<usize>().unwrap_or(0)
Expand All @@ -25,6 +28,13 @@ fn compute_num_blocks(cache_size_key: &str, override_key: &str, bytes_per_block:
}
}

fn get_leader_init_timeout_secs(override_key: &str) -> u64 {
std::env::var(override_key)
.ok()
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(DEFAULT_INIT_TIMEOUT_SECS)
}

#[pyclass]
#[derive(Clone, Dissolve)]
pub struct KvbmLeader {
Expand All @@ -41,12 +51,14 @@ impl KvbmLeader {
let num_disk_blocks = compute_num_blocks(DISK_CACHE, DISK_CACHE_OVERRIDE, bytes_per_block);

let barrier_id = get_barrier_id();
let leader_init_timeout_sec: u64 = get_leader_init_timeout_secs(LEADER_WORKER_INIT_TIMEOUT_SECS);

let config = KvbmLeaderConfig::builder()
.barrier_id(barrier_id)
.num_host_blocks(num_host_blocks)
.num_disk_blocks(num_disk_blocks)
.world_size(world_size)
.leader_init_timeout_secs(leader_init_timeout_sec)
.build()
.map_err(to_pyerr)?;

Expand Down
11 changes: 7 additions & 4 deletions lib/llm/src/block_manager/distributed/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ use std::time::Duration;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;

const INIT_TIMEOUT_SECS: u64 = 120;

/// Data that is sent to workers over ETCD to establish a ZMQ connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvbmLeaderData {
Expand All @@ -43,6 +41,10 @@ pub struct KvbmLeaderConfig {
/// The world size.
#[builder(default = "1")]
world_size: usize,

/// The leader-worker init connection timeout seconds.
#[builder(default = "120")]
leader_init_timeout_secs: u64,
}

impl KvbmLeaderConfig {
Expand Down Expand Up @@ -85,10 +87,11 @@ impl KvbmLeader {
});

// Build our leader barrier and publish the data.
// TODO: Use a separate timeout parameter from the ZMQ connection timeout
let leader_barrier: LeaderBarrier<KvbmLeaderData, ()> = LeaderBarrier::new(
config.barrier_id.clone(),
config.world_size,
Some(Duration::from_secs(INIT_TIMEOUT_SECS)),
Some(Duration::from_secs(config.leader_init_timeout_secs)),
);

let worker_data = leader_barrier
Expand All @@ -105,7 +108,7 @@ impl KvbmLeader {
let zmq_leader = ZmqActiveMessageLeader::new(
leader_sockets,
config.world_size,
Duration::from_secs(INIT_TIMEOUT_SECS),
Duration::from_secs(config.leader_init_timeout_secs),
cancel_token.clone(),
)
.await?;
Expand Down
Loading