@@ -349,3 +349,155 @@ impl WorkerSelector for DefaultWorkerSelector {
349349 } )
350350 }
351351}
352+
353+ #[ cfg( test) ]
354+ mod tests {
355+ use super :: * ;
356+ use crate :: kv_router:: indexer:: OverlapScores ;
357+ use std:: collections:: HashMap ;
358+
359+ // Helper to create a worker endpoint
360+ fn create_endpoint (
361+ worker_id : i64 ,
362+ gpu_cache_usage_perc : f32 ,
363+ num_requests_waiting : u64 ,
364+ ) -> Endpoint {
365+ Endpoint {
366+ name : format ! ( "worker-{}" , worker_id) ,
367+ subject : format ! ( "worker-subject-{:x}" , worker_id) ,
368+ data : ForwardPassMetrics {
369+ gpu_cache_usage_perc,
370+ num_requests_waiting,
371+ // Other fields can be default initialized for this test
372+ ..Default :: default ( )
373+ } ,
374+ }
375+ }
376+
377+ // Helper to create ProcessedEndpoints
378+ fn create_workers ( workers : Vec < ( i64 , f32 , u64 ) > ) -> ProcessedEndpoints {
379+ let mut endpoints = HashMap :: new ( ) ;
380+ for ( id, usage, waiting) in workers {
381+ endpoints. insert ( id, create_endpoint ( id, usage, waiting) ) ;
382+ }
383+ ProcessedEndpoints { endpoints, load_avg : 0.0 , load_std : 0.0 }
384+ }
385+
386+ // Helper to create a scheduling request
387+ fn create_request ( overlaps : Vec < ( i64 , u32 ) > , isl_tokens : usize ) -> SchedulingRequest {
388+ SchedulingRequest {
389+ isl_tokens,
390+ overlap : OverlapScores {
391+ scores : overlaps. into_iter ( ) . collect ( ) ,
392+ frequencies : vec ! [ ] ,
393+ } ,
394+ resp_tx : tokio:: sync:: oneshot:: channel ( ) . 0 ,
395+ }
396+ }
397+
398+ #[ test]
399+ fn test_select_worker_basic ( ) {
400+ // Setup workers
401+ let workers = create_workers ( vec ! [
402+ ( 1 , 0.50 , 1 ) , // worker_id, gpu_usage%, waiting_requests
403+ ( 2 , 0.80 , 0 ) ,
404+ ] ) ;
405+
406+ // Setup request: 100 tokens, block_size=20 (5 blocks)
407+ let request = create_request ( vec ! [ ( 1 , 3 ) , ( 2 , 4 ) ] , 100 ) ;
408+ let selector = DefaultWorkerSelector :: new ( None ) ;
409+ let block_size = 20 ;
410+
411+ // Execute selection
412+ let result = selector
413+ . select_worker ( & workers, & request, block_size)
414+ . expect ( "Should select a worker" ) ;
415+ // Worker 2 should win because:
416+ // Worker1: 2.0 * 0.600 - 1.0 * 0.500 - 1.0 * 1.000 = -0.3
417+ // Worker2: 2.0 * 0.800 - 1.0 * 0.800 - 1.0 * 0.000 = 0.8
418+ assert_eq ! ( result. worker_id, 2 ) ;
419+ assert_eq ! ( result. required_blocks, 5 ) ; // 100 tokens / 20 block_size
420+ assert_eq ! ( result. overlap_blocks, 4 ) ;
421+ }
422+
423+ #[ test]
424+ fn test_no_endpoints ( ) {
425+ let workers = create_workers ( vec ! [ ] ) ;
426+ let request = create_request ( vec ! [ ] , 100 ) ;
427+ let selector = DefaultWorkerSelector :: new ( None ) ;
428+ let block_size = 20 ;
429+
430+ match selector. select_worker ( & workers, & request, block_size) {
431+ Err ( KvSchedulerError :: NoEndpoints ) => { } // Expected
432+ _ => panic ! ( "Should return NoEndpoints error" ) ,
433+ }
434+ }
435+
436+ #[ test]
437+ fn test_no_overlap_scores ( ) {
438+ // Workers exist but request has no overlap scores
439+ let workers = create_workers ( vec ! [ ( 1 , 50.0 , 1 ) ] ) ;
440+ let request = create_request ( vec ! [ ] , 100 ) ; // No overlaps
441+ let selector = DefaultWorkerSelector :: new ( None ) ;
442+ let block_size = 20 ;
443+
444+ let result = selector
445+ . select_worker ( & workers, & request, block_size)
446+ . expect ( "Should fallback to selecting worker" ) ;
447+
448+ // Worker1 should be selected with 0 overlap
449+ assert_eq ! ( result. worker_id, 1 ) ;
450+ assert_eq ! ( result. overlap_blocks, 0 ) ;
451+ }
452+
453+ #[ test]
454+ fn test_tie_breaker_randomness ( ) {
455+ // Two identical workers
456+ let workers = create_workers ( vec ! [
457+ ( 1 , 50.0 , 1 ) ,
458+ ( 2 , 50.0 , 1 ) ,
459+ ] ) ;
460+
461+ // Both have same overlap
462+ let request = create_request ( vec ! [ ( 1 , 3 ) , ( 2 , 3 ) ] , 100 ) ;
463+ let selector = DefaultWorkerSelector :: new ( None ) ;
464+ let block_size = 20 ;
465+
466+ // Run multiple times to verify randomness
467+ let mut results = Vec :: new ( ) ;
468+ for _ in 0 ..10 {
469+ let result = selector
470+ . select_worker ( & workers, & request, block_size)
471+ . expect ( "Should select worker" ) ;
472+ results. push ( result. worker_id ) ;
473+ }
474+ println ! ( "{:?}" , results) ;
475+ // Should have selected both workers at least once
476+ assert ! ( results. contains( & 1 ) ) ;
477+ assert ! ( results. contains( & 2 ) ) ;
478+ }
479+
480+ #[ test]
481+ fn test_custom_weights ( ) {
482+ // Setup workers
483+ let workers = create_workers ( vec ! [
484+ ( 1 , 0.50 , 1 ) ,
485+ ( 2 , 0.80 , 0 ) ,
486+ ] ) ;
487+
488+ // Custom config with high priority on GPU usage
489+ let config = KvRouterConfig {
490+ gpu_cache_usage_weight : 10.0 , // Very high weight
491+ ..Default :: default ( )
492+ } ;
493+ let selector = DefaultWorkerSelector :: new ( Some ( config) ) ;
494+ let request = create_request ( vec ! [ ( 1 , 3 ) , ( 2 , 4 ) ] , 100 ) ;
495+ let block_size = 20 ;
496+
497+ let result = selector
498+ . select_worker ( & workers, & request, block_size)
499+ . expect ( "Should select worker" ) ;
500+
501+ assert_eq ! ( result. worker_id, 1 ) ;
502+ }
503+ }
0 commit comments