Skip to content

Commit 6099348

Browse files
feat: add fn n_routes() to RouteProvider trait (#584)
1 parent 3010732 commit 6099348

File tree

6 files changed

+325
-33
lines changed

6 files changed

+325
-33
lines changed

ic-agent/src/agent/http_transport/dynamic_routing/dynamic_route_provider.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,20 @@ where
170170
{
171171
fn route(&self) -> Result<Url, AgentError> {
172172
let snapshot = self.routing_snapshot.load();
173-
let node = snapshot.next().ok_or_else(|| {
173+
let node = snapshot.next_node().ok_or_else(|| {
174174
AgentError::RouteProviderError("No healthy API nodes found.".to_string())
175175
})?;
176176
Ok(node.to_routing_url())
177177
}
178+
179+
fn n_ordered_routes(&self, n: usize) -> Result<Vec<Url>, AgentError> {
180+
let snapshot = self.routing_snapshot.load();
181+
let nodes = snapshot.next_n_nodes(n).ok_or_else(|| {
182+
AgentError::RouteProviderError("No healthy API nodes found.".to_string())
183+
})?;
184+
let urls = nodes.iter().map(|n| n.to_routing_url()).collect();
185+
Ok(urls)
186+
}
178187
}
179188

180189
impl<S> DynamicRouteProvider<S>

ic-agent/src/agent/http_transport/dynamic_routing/nodes_fetch.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ where
143143
// - failure should never happen, but we trace it if it does
144144
loop {
145145
let snapshot = self.routing_snapshot.load();
146-
if let Some(node) = snapshot.next() {
146+
if let Some(node) = snapshot.next_node() {
147147
match self.fetcher.fetch((&node).into()).await {
148148
Ok(nodes) => {
149149
let msg = Some(

ic-agent/src/agent/http_transport/dynamic_routing/snapshot/latency_based_routing.rs

Lines changed: 112 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ impl LatencyRoutingSnapshot {
4949
/// Helper function to sample nodes based on their weights.
5050
/// Here weight index is selected based on the input number in range [0, 1]
5151
#[inline(always)]
52-
fn weighted_sample(weights: &[f64], number: f64) -> Option<usize> {
52+
fn weighted_sample(weighted_nodes: &[(f64, &Node)], number: f64) -> Option<usize> {
5353
if !(0.0..=1.0).contains(&number) {
5454
return None;
5555
}
56-
let sum: f64 = weights.iter().sum();
56+
let sum: f64 = weighted_nodes.iter().map(|n| n.0).sum();
5757
let mut weighted_number = number * sum;
58-
for (idx, weight) in weights.iter().enumerate() {
58+
for (idx, &(weight, _)) in weighted_nodes.iter().enumerate() {
5959
weighted_number -= weight;
6060
if weighted_number <= 0.0 {
6161
return Some(idx);
@@ -69,19 +69,40 @@ impl RoutingSnapshot for LatencyRoutingSnapshot {
6969
!self.weighted_nodes.is_empty()
7070
}
7171

72-
fn next(&self) -> Option<Node> {
73-
// We select a node based on it's weight, using a stochastic weighted random sampling approach.
74-
let weights = self
72+
fn next_node(&self) -> Option<Node> {
73+
self.next_n_nodes(1).unwrap_or_default().into_iter().next()
74+
}
75+
76+
// Uses weighted random sampling algorithm without item replacement n times.
77+
fn next_n_nodes(&self, n: usize) -> Option<Vec<Node>> {
78+
if n == 0 {
79+
return Some(Vec::new());
80+
}
81+
82+
let n = std::cmp::min(n, self.weighted_nodes.len());
83+
84+
let mut nodes = Vec::with_capacity(n);
85+
86+
let mut weighted_nodes: Vec<_> = self
7587
.weighted_nodes
7688
.iter()
77-
.map(|n| n.weight)
78-
.collect::<Vec<_>>();
79-
// Generate a random float in the range [0, 1)
89+
.map(|n| (n.weight, &n.node))
90+
.collect();
91+
8092
let mut rng = rand::thread_rng();
81-
let rand_num = rng.gen::<f64>();
82-
// Using this random float and an array of weights we get an index of the node.
83-
let idx = weighted_sample(weights.as_slice(), rand_num);
84-
idx.map(|idx| self.weighted_nodes[idx].node.clone())
93+
94+
for _ in 0..n {
95+
// Generate a random float in the range [0, 1)
96+
let rand_num = rng.gen::<f64>();
97+
if let Some(idx) = weighted_sample(weighted_nodes.as_slice(), rand_num) {
98+
let node = weighted_nodes[idx].1;
99+
nodes.push(node.clone());
100+
// Remove the item, so that it can't be selected anymore.
101+
weighted_nodes.swap_remove(idx);
102+
}
103+
}
104+
105+
Some(nodes)
85106
}
86107

87108
fn sync_nodes(&mut self, nodes: &[Node]) -> bool {
@@ -143,7 +164,10 @@ impl RoutingSnapshot for LatencyRoutingSnapshot {
143164

144165
#[cfg(test)]
145166
mod tests {
146-
use std::{collections::HashSet, time::Duration};
167+
use std::{
168+
collections::{HashMap, HashSet},
169+
time::Duration,
170+
};
147171

148172
use simple_moving_average::SMA;
149173

@@ -166,7 +190,7 @@ mod tests {
166190
assert!(snapshot.weighted_nodes.is_empty());
167191
assert!(snapshot.existing_nodes.is_empty());
168192
assert!(!snapshot.has_nodes());
169-
assert!(snapshot.next().is_none());
193+
assert!(snapshot.next_node().is_none());
170194
}
171195

172196
#[test]
@@ -181,7 +205,7 @@ mod tests {
181205
assert!(!is_updated);
182206
assert!(snapshot.weighted_nodes.is_empty());
183207
assert!(!snapshot.has_nodes());
184-
assert!(snapshot.next().is_none());
208+
assert!(snapshot.next_node().is_none());
185209
}
186210

187211
#[test]
@@ -201,7 +225,7 @@ mod tests {
201225
Duration::from_secs(1)
202226
);
203227
assert_eq!(weighted_node.weight, 1.0);
204-
assert_eq!(snapshot.next().unwrap(), node);
228+
assert_eq!(snapshot.next_node().unwrap(), node);
205229
// Check second update
206230
let health = HealthCheckStatus::new(Some(Duration::from_secs(2)));
207231
let is_updated = snapshot.update_node(&node, health);
@@ -232,7 +256,7 @@ mod tests {
232256
assert_eq!(weighted_node.weight, 1.0 / avg_latency.as_secs_f64());
233257
assert_eq!(snapshot.weighted_nodes.len(), 1);
234258
assert_eq!(snapshot.existing_nodes.len(), 1);
235-
assert_eq!(snapshot.next().unwrap(), node);
259+
assert_eq!(snapshot.next_node().unwrap(), node);
236260
}
237261

238262
#[test]
@@ -307,12 +331,13 @@ mod tests {
307331

308332
#[test]
309333
fn test_weighted_sample() {
334+
let node = &Node::new("api1.com").unwrap();
310335
// Case 1: empty array
311-
let arr: &[f64] = &[];
336+
let arr = &[];
312337
let idx = weighted_sample(arr, 0.5);
313338
assert_eq!(idx, None);
314339
// Case 2: single element in array
315-
let arr: &[f64] = &[1.0];
340+
let arr = &[(1.0, node)];
316341
let idx = weighted_sample(arr, 0.0);
317342
assert_eq!(idx, Some(0));
318343
let idx = weighted_sample(arr, 1.0);
@@ -323,7 +348,7 @@ mod tests {
323348
let idx = weighted_sample(arr, 1.1);
324349
assert_eq!(idx, None);
325350
// Case 3: two elements in array (second element has twice the weight of the first)
326-
let arr: &[f64] = &[1.0, 2.0]; // prefixed_sum = [1.0, 3.0]
351+
let arr = &[(1.0, node), (2.0, node)]; // // prefixed_sum = [1.0, 3.0]
327352
let idx = weighted_sample(arr, 0.0); // 0.0 * 3.0 < 1.0
328353
assert_eq!(idx, Some(0));
329354
let idx = weighted_sample(arr, 0.33); // 0.33 * 3.0 < 1.0
@@ -338,7 +363,7 @@ mod tests {
338363
let idx = weighted_sample(arr, 1.1);
339364
assert_eq!(idx, None);
340365
// Case 4: four elements in array
341-
let arr: &[f64] = &[1.0, 2.0, 1.5, 2.5]; // prefixed_sum = [1.0, 3.0, 4.5, 7.0]
366+
let arr = &[(1.0, node), (2.0, node), (1.5, node), (2.5, node)]; // prefixed_sum = [1.0, 3.0, 4.5, 7.0]
342367
let idx = weighted_sample(arr, 0.14); // 0.14 * 7 < 1.0
343368
assert_eq!(idx, Some(0)); // probability ~0.14
344369
let idx = weighted_sample(arr, 0.15); // 0.15 * 7 > 1.0
@@ -359,4 +384,69 @@ mod tests {
359384
let idx = weighted_sample(arr, 1.1);
360385
assert_eq!(idx, None);
361386
}
387+
388+
#[test]
389+
// #[ignore]
390+
// This test is for manual runs to see the statistics for nodes selection probability.
391+
fn test_stats_for_next_n_nodes() {
392+
// Arrange
393+
let mut snapshot = LatencyRoutingSnapshot::new();
394+
let node_1 = Node::new("api1.com").unwrap();
395+
let node_2 = Node::new("api2.com").unwrap();
396+
let node_3 = Node::new("api3.com").unwrap();
397+
let node_4 = Node::new("api4.com").unwrap();
398+
let node_5 = Node::new("api5.com").unwrap();
399+
let node_6 = Node::new("api6.com").unwrap();
400+
let latency_mov_avg = LatencyMovAvg::from_zero(Duration::ZERO);
401+
snapshot.weighted_nodes = vec![
402+
WeightedNode {
403+
node: node_2.clone(),
404+
latency_mov_avg: latency_mov_avg.clone(),
405+
weight: 8.0,
406+
},
407+
WeightedNode {
408+
node: node_3.clone(),
409+
latency_mov_avg: latency_mov_avg.clone(),
410+
weight: 4.0,
411+
},
412+
WeightedNode {
413+
node: node_1.clone(),
414+
latency_mov_avg: latency_mov_avg.clone(),
415+
weight: 16.0,
416+
},
417+
WeightedNode {
418+
node: node_6.clone(),
419+
latency_mov_avg: latency_mov_avg.clone(),
420+
weight: 2.0,
421+
},
422+
WeightedNode {
423+
node: node_5.clone(),
424+
latency_mov_avg: latency_mov_avg.clone(),
425+
weight: 1.0,
426+
},
427+
WeightedNode {
428+
node: node_4.clone(),
429+
latency_mov_avg: latency_mov_avg.clone(),
430+
weight: 4.1,
431+
},
432+
];
433+
434+
let mut stats = HashMap::new();
435+
let experiments = 30;
436+
let select_nodes_count = 10;
437+
for i in 0..experiments {
438+
let nodes = snapshot.next_n_nodes(select_nodes_count).unwrap();
439+
println!("Experiment {i}: selected nodes {nodes:?}");
440+
for item in nodes.into_iter() {
441+
*stats.entry(item).or_insert(1) += 1;
442+
}
443+
}
444+
for (node, count) in stats {
445+
println!(
446+
"Node {:?} is selected with probability {}",
447+
node.domain(),
448+
count as f64 / experiments as f64
449+
);
450+
}
451+
}
362452
}

0 commit comments

Comments
 (0)