Skip to content

Commit 880854b

Browse files
committed
add tests for scheduler
Signed-off-by: Tianer Zhou <ezhoureal@gmail.com>
1 parent 98708c4 commit 880854b

File tree

1 file changed

+152
-0
lines changed

1 file changed

+152
-0
lines changed

lib/llm/src/kv_router/scheduler.rs

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,3 +349,155 @@ impl WorkerSelector for DefaultWorkerSelector {
349349
})
350350
}
351351
}
352+
353+
#[cfg(test)]
354+
mod tests {
355+
use super::*;
356+
use crate::kv_router::indexer::OverlapScores;
357+
use std::collections::HashMap;
358+
359+
// Helper to create a worker endpoint
360+
fn create_endpoint(
361+
worker_id: i64,
362+
gpu_cache_usage_perc: f32,
363+
num_requests_waiting: u64,
364+
) -> Endpoint {
365+
Endpoint {
366+
name: format!("worker-{}", worker_id),
367+
subject: format!("worker-subject-{:x}", worker_id),
368+
data: ForwardPassMetrics {
369+
gpu_cache_usage_perc,
370+
num_requests_waiting,
371+
// Other fields can be default initialized for this test
372+
..Default::default()
373+
},
374+
}
375+
}
376+
377+
// Helper to create ProcessedEndpoints
378+
fn create_workers(workers: Vec<(i64, f32, u64)>) -> ProcessedEndpoints {
379+
let mut endpoints = HashMap::new();
380+
for (id, usage, waiting) in workers {
381+
endpoints.insert(id, create_endpoint(id, usage, waiting));
382+
}
383+
ProcessedEndpoints { endpoints, load_avg: 0.0, load_std: 0.0 }
384+
}
385+
386+
// Helper to create a scheduling request
387+
fn create_request(overlaps: Vec<(i64, u32)>, isl_tokens: usize) -> SchedulingRequest {
388+
SchedulingRequest {
389+
isl_tokens,
390+
overlap: OverlapScores {
391+
scores: overlaps.into_iter().collect(),
392+
frequencies: vec![],
393+
},
394+
resp_tx: tokio::sync::oneshot::channel().0,
395+
}
396+
}
397+
398+
#[test]
399+
fn test_select_worker_basic() {
400+
// Setup workers
401+
let workers = create_workers(vec![
402+
(1, 0.50, 1), // worker_id, gpu_usage%, waiting_requests
403+
(2, 0.80, 0),
404+
]);
405+
406+
// Setup request: 100 tokens, block_size=20 (5 blocks)
407+
let request = create_request(vec![(1, 3), (2, 4)], 100);
408+
let selector = DefaultWorkerSelector::new(None);
409+
let block_size = 20;
410+
411+
// Execute selection
412+
let result = selector
413+
.select_worker(&workers, &request, block_size)
414+
.expect("Should select a worker");
415+
// Worker 2 should win because:
416+
// Worker1: 2.0 * 0.600 - 1.0 * 0.500 - 1.0 * 1.000 = -0.3
417+
// Worker2: 2.0 * 0.800 - 1.0 * 0.800 - 1.0 * 0.000 = 0.8
418+
assert_eq!(result.worker_id, 2);
419+
assert_eq!(result.required_blocks, 5); // 100 tokens / 20 block_size
420+
assert_eq!(result.overlap_blocks, 4);
421+
}
422+
423+
#[test]
424+
fn test_no_endpoints() {
425+
let workers = create_workers(vec![]);
426+
let request = create_request(vec![], 100);
427+
let selector = DefaultWorkerSelector::new(None);
428+
let block_size = 20;
429+
430+
match selector.select_worker(&workers, &request, block_size) {
431+
Err(KvSchedulerError::NoEndpoints) => {} // Expected
432+
_ => panic!("Should return NoEndpoints error"),
433+
}
434+
}
435+
436+
#[test]
437+
fn test_no_overlap_scores() {
438+
// Workers exist but request has no overlap scores
439+
let workers = create_workers(vec![(1, 50.0, 1)]);
440+
let request = create_request(vec![], 100); // No overlaps
441+
let selector = DefaultWorkerSelector::new(None);
442+
let block_size = 20;
443+
444+
let result = selector
445+
.select_worker(&workers, &request, block_size)
446+
.expect("Should fallback to selecting worker");
447+
448+
// Worker1 should be selected with 0 overlap
449+
assert_eq!(result.worker_id, 1);
450+
assert_eq!(result.overlap_blocks, 0);
451+
}
452+
453+
#[test]
454+
fn test_tie_breaker_randomness() {
455+
// Two identical workers
456+
let workers = create_workers(vec![
457+
(1, 50.0, 1),
458+
(2, 50.0, 1),
459+
]);
460+
461+
// Both have same overlap
462+
let request = create_request(vec![(1, 3), (2, 3)], 100);
463+
let selector = DefaultWorkerSelector::new(None);
464+
let block_size = 20;
465+
466+
// Run multiple times to verify randomness
467+
let mut results = Vec::new();
468+
for _ in 0..10 {
469+
let result = selector
470+
.select_worker(&workers, &request, block_size)
471+
.expect("Should select worker");
472+
results.push(result.worker_id);
473+
}
474+
println!("{:?}", results);
475+
// Should have selected both workers at least once
476+
assert!(results.contains(&1));
477+
assert!(results.contains(&2));
478+
}
479+
480+
#[test]
481+
fn test_custom_weights() {
482+
// Setup workers
483+
let workers = create_workers(vec![
484+
(1, 0.50, 1),
485+
(2, 0.80, 0),
486+
]);
487+
488+
// Custom config with high priority on GPU usage
489+
let config = KvRouterConfig {
490+
gpu_cache_usage_weight: 10.0, // Very high weight
491+
..Default::default()
492+
};
493+
let selector = DefaultWorkerSelector::new(Some(config));
494+
let request = create_request(vec![(1, 3), (2, 4)], 100);
495+
let block_size = 20;
496+
497+
let result = selector
498+
.select_worker(&workers, &request, block_size)
499+
.expect("Should select worker");
500+
501+
assert_eq!(result.worker_id, 1);
502+
}
503+
}

0 commit comments

Comments
 (0)