@@ -49,13 +49,13 @@ impl LatencyRoutingSnapshot {
 /// Helper function to sample nodes based on their weights.
 /// Here weight index is selected based on the input number in range [0, 1]
 #[inline(always)]
-fn weighted_sample(weights: &[f64], number: f64) -> Option<usize> {
+fn weighted_sample(weighted_nodes: &[(f64, &Node)], number: f64) -> Option<usize> {
     if !(0.0..=1.0).contains(&number) {
         return None;
     }
-    let sum: f64 = weights.iter().sum();
+    let sum: f64 = weighted_nodes.iter().map(|n| n.0).sum();
     let mut weighted_number = number * sum;
-    for (idx, weight) in weights.iter().enumerate() {
+    for (idx, &(weight, _)) in weighted_nodes.iter().enumerate() {
         weighted_number -= weight;
         if weighted_number <= 0.0 {
             return Some(idx);
@@ -69,19 +69,40 @@ impl RoutingSnapshot for LatencyRoutingSnapshot {
         !self.weighted_nodes.is_empty()
     }
 
-    fn next(&self) -> Option<Node> {
-        // We select a node based on it's weight, using a stochastic weighted random sampling approach.
-        let weights = self
+    fn next_node(&self) -> Option<Node> {
+        self.next_n_nodes(1).unwrap_or_default().into_iter().next()
+    }
+
+    // Uses weighted random sampling algorithm without item replacement n times.
+    fn next_n_nodes(&self, n: usize) -> Option<Vec<Node>> {
+        if n == 0 {
+            return Some(Vec::new());
+        }
+
+        let n = std::cmp::min(n, self.weighted_nodes.len());
+
+        let mut nodes = Vec::with_capacity(n);
+
+        let mut weighted_nodes: Vec<_> = self
             .weighted_nodes
             .iter()
-            .map(|n| n.weight)
-            .collect::<Vec<_>>();
-        // Generate a random float in the range [0, 1)
+            .map(|n| (n.weight, &n.node))
+            .collect();
+
         let mut rng = rand::thread_rng();
-        let rand_num = rng.gen::<f64>();
-        // Using this random float and an array of weights we get an index of the node.
-        let idx = weighted_sample(weights.as_slice(), rand_num);
-        idx.map(|idx| self.weighted_nodes[idx].node.clone())
+
+        for _ in 0..n {
+            // Generate a random float in the range [0, 1)
+            let rand_num = rng.gen::<f64>();
+            if let Some(idx) = weighted_sample(weighted_nodes.as_slice(), rand_num) {
+                let node = weighted_nodes[idx].1;
+                nodes.push(node.clone());
+                // Remove the item, so that it can't be selected anymore.
+                weighted_nodes.swap_remove(idx);
+            }
+        }
+
+        Some(nodes)
     }
 
     fn sync_nodes(&mut self, nodes: &[Node]) -> bool {
@@ -143,7 +164,10 @@ impl RoutingSnapshot for LatencyRoutingSnapshot {
 
 #[cfg(test)]
 mod tests {
-    use std::{collections::HashSet, time::Duration};
+    use std::{
+        collections::{HashMap, HashSet},
+        time::Duration,
+    };
 
     use simple_moving_average::SMA;
 
@@ -166,7 +190,7 @@ mod tests {
         assert!(snapshot.weighted_nodes.is_empty());
         assert!(snapshot.existing_nodes.is_empty());
         assert!(!snapshot.has_nodes());
-        assert!(snapshot.next().is_none());
+        assert!(snapshot.next_node().is_none());
     }
171195
172196 #[ test]
@@ -181,7 +205,7 @@ mod tests {
         assert!(!is_updated);
         assert!(snapshot.weighted_nodes.is_empty());
         assert!(!snapshot.has_nodes());
-        assert!(snapshot.next().is_none());
+        assert!(snapshot.next_node().is_none());
     }
 
     #[test]
@@ -201,7 +225,7 @@ mod tests {
             Duration::from_secs(1)
         );
         assert_eq!(weighted_node.weight, 1.0);
-        assert_eq!(snapshot.next().unwrap(), node);
+        assert_eq!(snapshot.next_node().unwrap(), node);
         // Check second update
         let health = HealthCheckStatus::new(Some(Duration::from_secs(2)));
         let is_updated = snapshot.update_node(&node, health);
@@ -232,7 +256,7 @@ mod tests {
         assert_eq!(weighted_node.weight, 1.0 / avg_latency.as_secs_f64());
         assert_eq!(snapshot.weighted_nodes.len(), 1);
         assert_eq!(snapshot.existing_nodes.len(), 1);
-        assert_eq!(snapshot.next().unwrap(), node);
+        assert_eq!(snapshot.next_node().unwrap(), node);
     }
 
     #[test]
@@ -307,12 +331,13 @@ mod tests {
 
     #[test]
     fn test_weighted_sample() {
+        let node = &Node::new("api1.com").unwrap();
         // Case 1: empty array
-        let arr: &[f64] = &[];
+        let arr = &[];
         let idx = weighted_sample(arr, 0.5);
         assert_eq!(idx, None);
         // Case 2: single element in array
-        let arr: &[f64] = &[1.0];
+        let arr = &[(1.0, node)];
         let idx = weighted_sample(arr, 0.0);
         assert_eq!(idx, Some(0));
         let idx = weighted_sample(arr, 1.0);
@@ -323,7 +348,7 @@ mod tests {
         let idx = weighted_sample(arr, 1.1);
         assert_eq!(idx, None);
         // Case 3: two elements in array (second element has twice the weight of the first)
-        let arr: &[f64] = &[1.0, 2.0]; // prefixed_sum = [1.0, 3.0]
+        let arr = &[(1.0, node), (2.0, node)]; // prefixed_sum = [1.0, 3.0]
         let idx = weighted_sample(arr, 0.0); // 0.0 * 3.0 < 1.0
         assert_eq!(idx, Some(0));
         let idx = weighted_sample(arr, 0.33); // 0.33 * 3.0 < 1.0
@@ -338,7 +363,7 @@ mod tests {
         let idx = weighted_sample(arr, 1.1);
         assert_eq!(idx, None);
         // Case 4: four elements in array
-        let arr: &[f64] = &[1.0, 2.0, 1.5, 2.5]; // prefixed_sum = [1.0, 3.0, 4.5, 7.0]
+        let arr = &[(1.0, node), (2.0, node), (1.5, node), (2.5, node)]; // prefixed_sum = [1.0, 3.0, 4.5, 7.0]
         let idx = weighted_sample(arr, 0.14); // 0.14 * 7 < 1.0
         assert_eq!(idx, Some(0)); // probability ~0.14
         let idx = weighted_sample(arr, 0.15); // 0.15 * 7 > 1.0
@@ -359,4 +384,69 @@ mod tests {
         let idx = weighted_sample(arr, 1.1);
         assert_eq!(idx, None);
     }
+
+    #[test]
+    // #[ignore]
+    // This test is for manual runs to see the statistics for nodes selection probability.
+    fn test_stats_for_next_n_nodes() {
+        // Arrange
+        let mut snapshot = LatencyRoutingSnapshot::new();
+        let node_1 = Node::new("api1.com").unwrap();
+        let node_2 = Node::new("api2.com").unwrap();
+        let node_3 = Node::new("api3.com").unwrap();
+        let node_4 = Node::new("api4.com").unwrap();
+        let node_5 = Node::new("api5.com").unwrap();
+        let node_6 = Node::new("api6.com").unwrap();
+        let latency_mov_avg = LatencyMovAvg::from_zero(Duration::ZERO);
+        snapshot.weighted_nodes = vec![
+            WeightedNode {
+                node: node_2.clone(),
+                latency_mov_avg: latency_mov_avg.clone(),
+                weight: 8.0,
+            },
+            WeightedNode {
+                node: node_3.clone(),
+                latency_mov_avg: latency_mov_avg.clone(),
+                weight: 4.0,
+            },
+            WeightedNode {
+                node: node_1.clone(),
+                latency_mov_avg: latency_mov_avg.clone(),
+                weight: 16.0,
+            },
+            WeightedNode {
+                node: node_6.clone(),
+                latency_mov_avg: latency_mov_avg.clone(),
+                weight: 2.0,
+            },
+            WeightedNode {
+                node: node_5.clone(),
+                latency_mov_avg: latency_mov_avg.clone(),
+                weight: 1.0,
+            },
+            WeightedNode {
+                node: node_4.clone(),
+                latency_mov_avg: latency_mov_avg.clone(),
+                weight: 4.1,
+            },
+        ];
+
+        let mut stats = HashMap::new();
+        let experiments = 30;
+        let select_nodes_count = 10;
+        for i in 0..experiments {
+            let nodes = snapshot.next_n_nodes(select_nodes_count).unwrap();
+            println!("Experiment {i}: selected nodes {nodes:?}");
+            for item in nodes.into_iter() {
+                *stats.entry(item).or_insert(0) += 1;
+            }
+        }
+        for (node, count) in stats {
+            println!(
+                "Node {:?} is selected with probability {}",
+                node.domain(),
+                count as f64 / experiments as f64
+            );
+        }
+    }
 }
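
The sketch below is a minimal, self-contained illustration of the sampling scheme introduced in this diff: weighted_sample maps a uniform draw in [0, 1] to an index with probability proportional to the weights, and repeating the draw while removing each chosen entry (as next_n_nodes does with swap_remove) yields selection without replacement. It is illustrative only: the (f64, &str) pairs, the pick_n helper, and the example domains are stand-ins for the crate's (f64, &Node) tuples and next_n_nodes; the rand usage mirrors the code above.

use rand::Rng;

// Maps `number` in [0, 1] to an index, with probability proportional to each weight
// (same prefix-sum walk as the `weighted_sample` helper in the diff).
fn weighted_sample(weighted_items: &[(f64, &str)], number: f64) -> Option<usize> {
    if !(0.0..=1.0).contains(&number) {
        return None;
    }
    let sum: f64 = weighted_items.iter().map(|item| item.0).sum();
    let mut weighted_number = number * sum;
    for (idx, &(weight, _)) in weighted_items.iter().enumerate() {
        weighted_number -= weight;
        if weighted_number <= 0.0 {
            return Some(idx);
        }
    }
    None
}

// Draws up to `n` distinct labels: each draw is weighted, and the chosen entry is
// removed so it cannot be picked again, mirroring the `next_n_nodes` loop.
fn pick_n<'a>(items: &[(f64, &'a str)], n: usize) -> Vec<&'a str> {
    let mut pool = items.to_vec();
    let mut picked = Vec::with_capacity(n.min(pool.len()));
    let mut rng = rand::thread_rng();
    for _ in 0..n.min(items.len()) {
        // Random float in [0, 1), as in the diff.
        let rand_num = rng.gen::<f64>();
        if let Some(idx) = weighted_sample(&pool, rand_num) {
            picked.push(pool[idx].1);
            pool.swap_remove(idx);
        }
    }
    picked
}

fn main() {
    // Same weights as the four-element test case: prefix sums are [1.0, 3.0, 4.5, 7.0].
    let items = [
        (1.0, "api1.com"),
        (2.0, "api2.com"),
        (1.5, "api3.com"),
        (2.5, "api4.com"),
    ];
    // 0.5 * 7.0 = 3.5 lies between 3.0 and 4.5, so index 2 is returned.
    assert_eq!(weighted_sample(&items, 0.5), Some(2));
    // Two draws without replacement always return two distinct domains.
    let picked = pick_n(&items, 2);
    println!("picked: {picked:?}");
    assert_eq!(picked.len(), 2);
    assert_ne!(picked[0], picked[1]);
}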