diff --git a/examples/custom_load_balancing_policy.rs b/examples/custom_load_balancing_policy.rs index fb1ae0cb7c..5c279f2331 100644 --- a/examples/custom_load_balancing_policy.rs +++ b/examples/custom_load_balancing_policy.rs @@ -18,12 +18,12 @@ struct CustomLoadBalancingPolicy { fav_datacenter_name: String, } -fn with_random_shard(node: NodeRef) -> (NodeRef, Shard) { +fn with_random_shard(node: NodeRef) -> (NodeRef, Option) { let nr_shards = node .sharder() .map(|sharder| sharder.nr_shards.get()) .unwrap_or(1); - (node, thread_rng().gen_range(0..nr_shards) as Shard) + (node, Some(thread_rng().gen_range(0..nr_shards) as Shard)) } impl LoadBalancingPolicy for CustomLoadBalancingPolicy { @@ -31,7 +31,7 @@ impl LoadBalancingPolicy for CustomLoadBalancingPolicy { &'a self, _info: &'a RoutingInfo, cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { self.fallback(_info, cluster).next() } diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs index e3c1f97377..625232cfae 100644 --- a/scylla/src/transport/load_balancing/default.rs +++ b/scylla/src/transport/load_balancing/default.rs @@ -10,8 +10,10 @@ use itertools::{Either, Itertools}; use rand::{prelude::SliceRandom, thread_rng, Rng}; use rand_pcg::Pcg32; use scylla_cql::{errors::QueryError, frame::types::SerialConsistency, Consistency}; +use std::hash::{Hash, Hasher}; use std::{fmt, sync::Arc, time::Duration}; use tracing::{debug, warn}; +use uuid::Uuid; #[derive(Clone, Copy)] enum NodeLocationCriteria<'a> { @@ -75,7 +77,7 @@ pub struct DefaultPolicy { preferences: NodeLocationPreference, is_token_aware: bool, permit_dc_failover: bool, - pick_predicate: Box, Shard)) -> bool + Send + Sync>, + pick_predicate: Box, Option) -> bool + Send + Sync>, latency_awareness: Option, fixed_seed: Option, } @@ -97,7 +99,7 @@ impl LoadBalancingPolicy for DefaultPolicy { &'a self, query: &'a RoutingInfo, cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { let routing_info = self.routing_info(query, cluster); if let Some(ref token_with_strategy) = routing_info.token_with_strategy { if self.preferences.datacenter().is_some() @@ -126,13 +128,13 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let local_rack_picked = self.pick_replica( ts, NodeLocationCriteria::DatacenterAndRack(dc, rack), - &self.pick_predicate, + |node, shard| (self.pick_predicate)(node, Some(shard)), cluster, statement_type, ); - if let Some(alive_local_rack_replica) = local_rack_picked { - return Some(alive_local_rack_replica); + if let Some((alive_local_rack_replica, shard)) = local_rack_picked { + return Some((alive_local_rack_replica, Some(shard))); } } @@ -143,13 +145,13 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let picked = self.pick_replica( ts, NodeLocationCriteria::Datacenter(dc), - &self.pick_predicate, + |node, shard| (self.pick_predicate)(node, Some(shard)), cluster, statement_type, ); - if let Some(alive_local_replica) = picked { - return Some(alive_local_replica); + if let Some((alive_local_replica, shard)) = picked { + return Some((alive_local_replica, Some(shard))); } } @@ -161,12 +163,12 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let picked = self.pick_replica( ts, NodeLocationCriteria::Any, - &self.pick_predicate, + |node, shard| (self.pick_predicate)(node, Some(shard)), cluster, statement_type, ); - if let Some(alive_remote_replica) = picked { - return Some(alive_remote_replica); + if let Some((alive_remote_replica, shard)) = picked { + return Some((alive_remote_replica, Some(shard))); } } }; @@ -179,47 +181,47 @@ or refrain from preferring datacenters (which may ban all other datacenters, if if let NodeLocationPreference::DatacenterAndRack(dc, rack) = &self.preferences { // Try to pick some alive local rack random node. let rack_predicate = Self::make_rack_predicate( - &self.pick_predicate, + |node| (self.pick_predicate)(node, None), NodeLocationCriteria::DatacenterAndRack(dc, rack), ); let local_rack_picked = self.pick_node(nodes, rack_predicate); if let Some(alive_local_rack) = local_rack_picked { - return Some(alive_local_rack); + return Some((alive_local_rack, None)); } } // Try to pick some alive local random node. - if let Some(alive_local) = self.pick_node(nodes, &self.pick_predicate) { - return Some(alive_local); + if let Some(alive_local) = self.pick_node(nodes, |node| (self.pick_predicate)(node, None)) { + return Some((alive_local, None)); } let all_nodes = cluster.replica_locator().unique_nodes_in_global_ring(); // If a datacenter failover is possible, loosen restriction about locality. if self.is_datacenter_failover_possible(&routing_info) { - let picked = self.pick_node(all_nodes, &self.pick_predicate); + let picked = self.pick_node(all_nodes, |node| (self.pick_predicate)(node, None)); if let Some(alive_maybe_remote) = picked { - return Some(alive_maybe_remote); + return Some((alive_maybe_remote, None)); } } // Previous checks imply that every node we could have selected is down. // Let's try to return a down node that wasn't disabled. - let picked = self.pick_node(nodes, |(node, _shard)| node.is_enabled()); + let picked = self.pick_node(nodes, |node| node.is_enabled()); if let Some(down_but_enabled_local_node) = picked { - return Some(down_but_enabled_local_node); + return Some((down_but_enabled_local_node, None)); } // If a datacenter failover is possible, loosen restriction about locality. if self.is_datacenter_failover_possible(&routing_info) { - let picked = self.pick_node(all_nodes, |(node, _shard)| node.is_enabled()); + let picked = self.pick_node(all_nodes, |node| node.is_enabled()); if let Some(down_but_enabled_maybe_remote_node) = picked { - return Some(down_but_enabled_maybe_remote_node); + return Some((down_but_enabled_maybe_remote_node, None)); } } // Every node is disabled. This could be due to a bad host filter - configuration error. - nodes.first().map(|node| self.with_random_shard(node)) + nodes.first().map(|node| (node, None)) } fn fallback<'a>( @@ -241,7 +243,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let local_rack_replicas = self.fallback_replicas( ts, NodeLocationCriteria::DatacenterAndRack(dc, rack), - Self::is_alive, + |node, shard| Self::is_alive(node, Some(shard)), cluster, statement_type, ); @@ -257,7 +259,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let local_replicas = self.fallback_replicas( ts, NodeLocationCriteria::Datacenter(dc), - Self::is_alive, + |node, shard| Self::is_alive(node, Some(shard)), cluster, statement_type, ); @@ -273,7 +275,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let remote_replicas = self.fallback_replicas( ts, NodeLocationCriteria::Any, - Self::is_alive, + |node, shard| Self::is_alive(node, Some(shard)), cluster, statement_type, ); @@ -287,10 +289,11 @@ or refrain from preferring datacenters (which may ban all other datacenters, if Either::Left( maybe_local_rack_replicas .chain(maybe_local_replicas) - .chain(maybe_remote_replicas), + .chain(maybe_remote_replicas) + .map(|(node, shard)| (node, Some(shard))), ) } else { - Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>()) + Either::Right(std::iter::empty::<(NodeRef<'a>, Option)>()) }; // Get a list of all local alive nodes, and apply a round robin to it @@ -299,31 +302,37 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let maybe_local_rack_nodes = if let NodeLocationPreference::DatacenterAndRack(dc, rack) = &self.preferences { let rack_predicate = Self::make_rack_predicate( - &self.pick_predicate, + |node| (self.pick_predicate)(node, None), NodeLocationCriteria::DatacenterAndRack(dc, rack), ); - Either::Left(self.round_robin_nodes_with_shards(local_nodes, rack_predicate)) + Either::Left( + self.round_robin_nodes(local_nodes, rack_predicate) + .map(|node| (node, None)), + ) } else { - Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>()) + Either::Right(std::iter::empty::<(NodeRef<'a>, Option)>()) }; - let robinned_local_nodes = self.round_robin_nodes_with_shards(local_nodes, Self::is_alive); + let robinned_local_nodes = self + .round_robin_nodes(local_nodes, |node| Self::is_alive(node, None)) + .map(|node| (node, None)); let all_nodes = cluster.replica_locator().unique_nodes_in_global_ring(); // If a datacenter failover is possible, loosen restriction about locality. let maybe_remote_nodes = if self.is_datacenter_failover_possible(&routing_info) { - let robinned_all_nodes = self.round_robin_nodes_with_shards(all_nodes, Self::is_alive); + let robinned_all_nodes = + self.round_robin_nodes(all_nodes, |node| Self::is_alive(node, None)); - Either::Left(robinned_all_nodes) + Either::Left(robinned_all_nodes.map(|node| (node, None))) } else { - Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>()) + Either::Right(std::iter::empty::<(NodeRef<'a>, Option)>()) }; // Even if we consider some enabled nodes to be down, we should try contacting them in the last resort. let maybe_down_local_nodes = local_nodes .iter() .filter(|node| node.is_enabled()) - .map(|node| self.with_random_shard(node)); + .map(|node| (node, None)); // If a datacenter failover is possible, loosen restriction about locality. let maybe_down_nodes = if self.is_datacenter_failover_possible(&routing_info) { @@ -331,12 +340,36 @@ or refrain from preferring datacenters (which may ban all other datacenters, if all_nodes .iter() .filter(|node| node.is_enabled()) - .map(|node| self.with_random_shard(node)), + .map(|node| (node, None)), ) } else { Either::Right(std::iter::empty()) }; + struct DefaultPolicyTargetComparator { + host_id: Uuid, + shard: Option, + } + + impl PartialEq for DefaultPolicyTargetComparator { + fn eq(&self, other: &Self) -> bool { + match (self.shard, other.shard) { + (_, None) | (None, _) => self.host_id.eq(&other.host_id), + (Some(shard_left), Some(shard_right)) => { + self.host_id.eq(&other.host_id) && shard_left.eq(&shard_right) + } + } + } + } + + impl Eq for DefaultPolicyTargetComparator {} + + impl Hash for DefaultPolicyTargetComparator { + fn hash(&self, state: &mut H) { + self.host_id.hash(state); + } + } + // Construct a fallback plan as a composition of replicas, local nodes and remote nodes. let plan = maybe_replicas .chain(maybe_local_rack_nodes) @@ -344,7 +377,10 @@ or refrain from preferring datacenters (which may ban all other datacenters, if .chain(maybe_remote_nodes) .chain(maybe_down_local_nodes) .chain(maybe_down_nodes) - .unique(); + .unique_by(|(node, shard)| DefaultPolicyTargetComparator { + host_id: node.host_id, + shard: *shard, + }); if let Some(latency_awareness) = self.latency_awareness.as_ref() { Box::new(latency_awareness.wrap(plan)) @@ -433,15 +469,28 @@ impl DefaultPolicy { /// Wraps the provided predicate, adding the requirement for rack to match. fn make_rack_predicate<'a>( - predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool + 'a, + predicate: impl Fn(NodeRef<'a>) -> bool + 'a, replica_location: NodeLocationCriteria<'a>, - ) -> impl Fn(&(NodeRef<'a>, Shard)) -> bool { - move |node_and_shard @ (node, _shard)| match replica_location { + ) -> impl Fn(NodeRef<'a>) -> bool { + move |node| match replica_location { + NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => predicate(node), + NodeLocationCriteria::DatacenterAndRack(_, rack) => { + predicate(node) && node.rack.as_deref() == Some(rack) + } + } + } + + /// Wraps the provided predicate, adding the requirement for rack to match. + fn make_sharded_rack_predicate<'a>( + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, + replica_location: NodeLocationCriteria<'a>, + ) -> impl Fn(NodeRef<'a>, Shard) -> bool { + move |node, shard| match replica_location { NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => { - predicate(node_and_shard) + predicate(node, shard) } NodeLocationCriteria::DatacenterAndRack(_, rack) => { - predicate(node_and_shard) && node.rack.as_deref() == Some(rack) + predicate(node, shard) && node.rack.as_deref() == Some(rack) } } } @@ -450,11 +499,11 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool + 'a, + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, cluster: &'a ClusterData, order: ReplicaOrder, ) -> impl Iterator, Shard)> { - let predicate = Self::make_rack_predicate(predicate, replica_location); + let predicate = Self::make_sharded_rack_predicate(predicate, replica_location); let replica_iter = match order { ReplicaOrder::Arbitrary => Either::Left( @@ -467,14 +516,14 @@ impl DefaultPolicy { .into_iter(), ), }; - replica_iter.filter(move |node_and_shard: &(NodeRef<'a>, Shard)| predicate(node_and_shard)) + replica_iter.filter(move |(node, shard): &(NodeRef<'a>, Shard)| predicate(node, *shard)) } fn pick_replica<'a>( &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool, + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, cluster: &'a ClusterData, statement_type: StatementType, ) -> Option<(NodeRef<'a>, Shard)> { @@ -502,7 +551,7 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool, + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, cluster: &'a ClusterData, ) -> Option<(NodeRef<'a>, Shard)> { match replica_location { @@ -521,8 +570,8 @@ impl DefaultPolicy { .into_replicas_ordered() .into_iter() .next() - .and_then(|primary_replica| { - predicate(&primary_replica).then_some(primary_replica) + .and_then(|(primary_replica, shard)| { + predicate(primary_replica, shard).then_some((primary_replica, shard)) }) } NodeLocationCriteria::Datacenter(_) | NodeLocationCriteria::DatacenterAndRack(_, _) => { @@ -534,7 +583,7 @@ impl DefaultPolicy { self.replicas( ts, replica_location, - move |node_and_shard| predicate(node_and_shard), + predicate, cluster, ReplicaOrder::RingOrder, ) @@ -547,18 +596,18 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool, + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, cluster: &'a ClusterData, ) -> Option<(NodeRef<'a>, Shard)> { - let predicate = Self::make_rack_predicate(predicate, replica_location); + let predicate = Self::make_sharded_rack_predicate(predicate, replica_location); let replica_set = self.nonfiltered_replica_set(ts, replica_location, cluster); if let Some(fixed) = self.fixed_seed { let mut gen = Pcg32::new(fixed, 0); - replica_set.choose_filtered(&mut gen, predicate) + replica_set.choose_filtered(&mut gen, |(node, shard)| predicate(node, *shard)) } else { - replica_set.choose_filtered(&mut thread_rng(), predicate) + replica_set.choose_filtered(&mut thread_rng(), |(node, shard)| predicate(node, *shard)) } } @@ -566,7 +615,7 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: impl Fn(&(NodeRef<'_>, Shard)) -> bool + 'a, + predicate: impl Fn(NodeRef<'_>, Shard) -> bool + 'a, cluster: &'a ClusterData, statement_type: StatementType, ) -> impl Iterator, Shard)> { @@ -604,22 +653,18 @@ impl DefaultPolicy { fn pick_node<'a>( &'a self, nodes: &'a [Arc], - predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool, - ) -> Option<(NodeRef<'_>, Shard)> { + predicate: impl Fn(NodeRef<'a>) -> bool, + ) -> Option> { // Select the first node that matches the predicate - Self::randomly_rotated_nodes(nodes) - .map(|node| self.with_random_shard(node)) - .find(predicate) + Self::randomly_rotated_nodes(nodes).find(|&node| predicate(node)) } - fn round_robin_nodes_with_shards<'a>( + fn round_robin_nodes<'a>( &'a self, nodes: &'a [Arc], - predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool, - ) -> impl Iterator, Shard)> { - Self::randomly_rotated_nodes(nodes) - .map(|node| self.with_random_shard(node)) - .filter(predicate) + predicate: impl Fn(NodeRef<'a>) -> bool, + ) -> impl Iterator> { + Self::randomly_rotated_nodes(nodes).filter(move |node| predicate(node)) } fn shuffle<'a>( @@ -638,23 +683,7 @@ impl DefaultPolicy { vec.into_iter() } - fn with_random_shard<'a>(&self, node: NodeRef<'a>) -> (NodeRef<'a>, Shard) { - let nr_shards = node - .sharder() - .map(|sharder| sharder.nr_shards.get()) - .unwrap_or(1); - ( - node, - (if let Some(fixed) = self.fixed_seed { - let mut gen = Pcg32::new(fixed, 0); - gen.gen_range(0..nr_shards) - } else { - thread_rng().gen_range(0..nr_shards) - }) as Shard, - ) - } - - fn is_alive(&(node, _shard): &(NodeRef<'_>, Shard)) -> bool { + fn is_alive(node: NodeRef, _shard: Option) -> bool { // For now, we leave this as stub, until we have time to improve node events. // node.is_enabled() && !node.is_down() node.is_enabled() @@ -720,11 +749,10 @@ impl DefaultPolicyBuilder { let latency_awareness = self.latency_awareness.map(|builder| builder.build()); let pick_predicate = if let Some(ref latency_awareness) = latency_awareness { let latency_predicate = latency_awareness.generate_predicate(); - Box::new( - move |node_and_shard @ (node, _shard): &(NodeRef<'_>, Shard)| { - DefaultPolicy::is_alive(node_and_shard) && latency_predicate(node) - }, - ) as Box, Shard)) -> bool + Send + Sync + 'static> + Box::new(move |node: NodeRef<'_>, shard| { + DefaultPolicy::is_alive(node, shard) && latency_predicate(node) + }) + as Box, Option) -> bool + Send + Sync + 'static> } else { Box::new(DefaultPolicy::is_alive) }; @@ -912,6 +940,7 @@ impl<'a> TokenWithStrategy<'a> { #[cfg(test)] mod tests { use scylla_cql::{frame::types::SerialConsistency, Consistency}; + use tracing::info; use self::framework::{ get_plan_and_collect_node_identifiers, mock_cluster_data_for_token_unaware_tests, @@ -919,7 +948,7 @@ mod tests { }; use crate::{ load_balancing::{ - default::tests::framework::mock_cluster_data_for_token_aware_tests, RoutingInfo, + default::tests::framework::mock_cluster_data_for_token_aware_tests, Plan, RoutingInfo, }, routing::Token, test_utils::setup_tracing, @@ -947,6 +976,7 @@ mod tests { }, }; + #[derive(Debug)] enum ExpectedGroup { NonDeterministic(HashSet), Deterministic(HashSet), @@ -1002,6 +1032,7 @@ mod tests { } } + #[derive(Debug)] pub(crate) struct ExpectedGroups { groups: Vec, } @@ -1197,6 +1228,19 @@ mod tests { let plan = get_plan_and_collect_node_identifiers(policy, routing_info, cluster); plans.push(plan); } + let example_plan = Plan::new(policy, routing_info, cluster); + info!("Example plan from policy:",); + for (node, shard) in example_plan { + info!( + "Node port: {}, shard: {}, dc: {:?}, rack: {:?}, down: {:?}", + node.address.port(), + shard, + node.datacenter, + node.rack, + node.is_down() + ); + } + expected_groups.assert_proper_grouping_in_plans(&plans); } @@ -1252,9 +1296,11 @@ mod tests { #[tokio::test] async fn test_default_policy_with_token_aware_statements() { setup_tracing(); - use crate::transport::locator::test::{A, B, C, D, E, F, G}; + use crate::transport::locator::test::{A, B, C, D, E, F, G}; let cluster = mock_cluster_data_for_token_aware_tests().await; + + #[derive(Debug)] struct Test<'a> { policy: DefaultPolicy, routing_info: RoutingInfo<'a>, @@ -1704,12 +1750,13 @@ mod tests { }, ]; - for Test { - policy, - routing_info, - expected_groups, - } in tests - { + for test in tests { + info!("Test: {:?}", test); + let Test { + policy, + routing_info, + expected_groups, + } = test; test_default_policy_with_given_cluster_and_routing_info( &policy, &cluster, @@ -2401,21 +2448,45 @@ mod latency_awareness { pub(super) fn wrap<'a>( &self, - fallback: impl Iterator, Shard)>, - ) -> impl Iterator, Shard)> { + fallback: impl Iterator, Option)>, + ) -> impl Iterator, Option)> { let min_avg_latency = match self.last_min_latency.load() { Some(min_avg) => min_avg, None => return Either::Left(fallback), // noop, as no latency data has been collected yet }; - Either::Right(IteratorWithSkippedNodes::new( - self.node_avgs.read().unwrap().deref(), - fallback, - self.exclusion_threshold, - self.retry_period, - self.minimum_measurements, - min_avg_latency, - )) + let average_latencies = self.node_avgs.read().unwrap(); + let targets = fallback; + + let mut fast_targets = vec![]; + let mut penalised_targets = vec![]; + + for node_and_shard @ (node, _shard) in targets { + match fast_enough( + average_latencies.deref(), + node.host_id, + self.exclusion_threshold, + self.retry_period, + self.minimum_measurements, + min_avg_latency, + ) { + FastEnough::Yes => fast_targets.push(node_and_shard), + FastEnough::No { average } => { + trace!("Latency awareness: Penalising node {{address={}, datacenter={:?}, rack={:?}}} for being on average at least {} times slower (latency: {}ms) than the fastest ({}ms).", + node.address, node.datacenter, node.rack, self.exclusion_threshold, average.as_millis(), min_avg_latency.as_millis()); + penalised_targets.push(node_and_shard); + } + } + } + + let mut fast_targets = fast_targets.into_iter(); + let mut penalised_targets = penalised_targets.into_iter(); + + let skipping_penalised_targets_iterator = std::iter::from_fn(move || { + fast_targets.next().or_else(|| penalised_targets.next()) + }); + + Either::Right(skipping_penalised_targets_iterator) } pub(super) fn report_query(&self, node: &Node, latency: Duration) { @@ -2721,71 +2792,6 @@ mod latency_awareness { } } - struct IteratorWithSkippedNodes<'a, Fast, Penalised> - where - Fast: Iterator, Shard)>, - Penalised: Iterator, Shard)>, - { - fast_nodes: Fast, - penalised_nodes: Penalised, - } - - impl<'a> - IteratorWithSkippedNodes< - 'a, - std::vec::IntoIter<(NodeRef<'a>, Shard)>, - std::vec::IntoIter<(NodeRef<'a>, Shard)>, - > - { - fn new( - average_latencies: &HashMap>>, - nodes: impl Iterator, Shard)>, - exclusion_threshold: f64, - retry_period: Duration, - minimum_measurements: usize, - min_avg: Duration, - ) -> Self { - let mut fast_nodes = vec![]; - let mut penalised_nodes = vec![]; - - for node_and_shard @ (node, _shard) in nodes { - match fast_enough( - average_latencies, - node.host_id, - exclusion_threshold, - retry_period, - minimum_measurements, - min_avg, - ) { - FastEnough::Yes => fast_nodes.push(node_and_shard), - FastEnough::No { average } => { - trace!("Latency awareness: Penalising node {{address={}, datacenter={:?}, rack={:?}}} for being on average at least {} times slower (latency: {}ms) than the fastest ({}ms).", - node.address, node.datacenter, node.rack, exclusion_threshold, average.as_millis(), min_avg.as_millis()); - penalised_nodes.push(node_and_shard); - } - } - } - - Self { - fast_nodes: fast_nodes.into_iter(), - penalised_nodes: penalised_nodes.into_iter(), - } - } - } - - impl<'a, Fast, Penalised> Iterator for IteratorWithSkippedNodes<'a, Fast, Penalised> - where - Fast: Iterator, Shard)>, - Penalised: Iterator, Shard)>, - { - type Item = (NodeRef<'a>, Shard); - - fn next(&mut self) -> Option { - self.fast_nodes - .next() - .or_else(|| self.penalised_nodes.next()) - } - } #[cfg(test)] mod tests { use scylla_cql::Consistency; @@ -2860,12 +2866,10 @@ mod latency_awareness { ) -> DefaultPolicy { let pick_predicate = { let latency_predicate = latency_awareness.generate_predicate(); - Box::new( - move |node_and_shard @ (node, _shard): &(NodeRef<'_>, Shard)| { - DefaultPolicy::is_alive(node_and_shard) && latency_predicate(node) - }, - ) - as Box, Shard)) -> bool + Send + Sync + 'static> + Box::new(move |node: NodeRef<'_>, shard| { + DefaultPolicy::is_alive(node, shard) && latency_predicate(node) + }) + as Box, Option) -> bool + Send + Sync + 'static> }; DefaultPolicy { diff --git a/scylla/src/transport/load_balancing/mod.rs b/scylla/src/transport/load_balancing/mod.rs index 977e3d508f..f1cd5bdf27 100644 --- a/scylla/src/transport/load_balancing/mod.rs +++ b/scylla/src/transport/load_balancing/mod.rs @@ -39,7 +39,8 @@ pub struct RoutingInfo<'a> { /// /// It is computed on-demand, only if querying the most preferred node fails /// (or when speculative execution is triggered). -pub type FallbackPlan<'a> = Box, Shard)> + Send + Sync + 'a>; +pub type FallbackPlan<'a> = + Box, Option)> + Send + Sync + 'a>; /// Policy that decides which nodes and shards to contact for each query. /// @@ -67,7 +68,7 @@ pub trait LoadBalancingPolicy: Send + Sync + std::fmt::Debug { &'a self, query: &'a RoutingInfo, cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)>; + ) -> Option<(NodeRef<'a>, Option)>; /// Returns all contact-appropriate nodes for a given query. fn fallback<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) diff --git a/scylla/src/transport/load_balancing/plan.rs b/scylla/src/transport/load_balancing/plan.rs index 5fc6294467..3dc946c58b 100644 --- a/scylla/src/transport/load_balancing/plan.rs +++ b/scylla/src/transport/load_balancing/plan.rs @@ -1,3 +1,4 @@ +use rand::{thread_rng, Rng}; use tracing::error; use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo}; @@ -6,20 +7,65 @@ use crate::{routing::Shard, transport::ClusterData}; enum PlanState<'a> { Created, PickedNone, // This always means an abnormal situation: it means that no nodes satisfied locality/node filter requirements. - Picked((NodeRef<'a>, Shard)), + Picked((NodeRef<'a>, Option)), Fallback { iter: FallbackPlan<'a>, - node_to_filter_out: (NodeRef<'a>, Shard), + target_to_filter_out: (NodeRef<'a>, Option), }, } -/// The list of nodes constituting the query plan. +/// The list of targets constituting the query plan. Target here is a pair `(NodeRef<'a>, Shard)`. /// -/// The plan is partly lazily computed, with the first node computed -/// eagerly in the first place and the remaining nodes computed on-demand +/// The plan is partly lazily computed, with the first target computed +/// eagerly in the first place and the remaining targets computed on-demand /// (all at once). /// This significantly reduces the allocation overhead on "the happy path" -/// (when the first node successfully handles the request), +/// (when the first target successfully handles the request). +/// +/// `Plan` implements `Iterator, Shard)>` but LoadBalancingPolicy +/// returns `Option` instead of `Shard` both in `pick` and in `fallback`. +/// `Plan` handles the `None` case by using random shard for a given node. +/// There is currently no way to configure RNG used by `Plan`. +/// If you don't want `Plan` to do randomize shards or you want to control the RNG, +/// use custom LBP that will always return non-`None` shards. +/// Example of LBP that always uses shard 0, preventing `Plan` from using random numbers: +/// +/// ``` +/// # use std::sync::Arc; +/// # use scylla::load_balancing::LoadBalancingPolicy; +/// # use scylla::load_balancing::RoutingInfo; +/// # use scylla::transport::ClusterData; +/// # use scylla::transport::NodeRef; +/// # use scylla::routing::Shard; +/// # use scylla::load_balancing::FallbackPlan; +/// +/// #[derive(Debug)] +/// struct NonRandomLBP { +/// inner: Arc, +/// } +/// impl LoadBalancingPolicy for NonRandomLBP { +/// fn pick<'a>( +/// &'a self, +/// info: &'a RoutingInfo, +/// cluster: &'a ClusterData, +/// ) -> Option<(NodeRef<'a>, Option)> { +/// self.inner +/// .pick(info, cluster) +/// .map(|(node, shard)| (node, shard.or(Some(0)))) +/// } +/// +/// fn fallback<'a>(&'a self, info: &'a RoutingInfo, cluster: &'a ClusterData) -> FallbackPlan<'a> { +/// Box::new(self.inner +/// .fallback(info, cluster) +/// .map(|(node, shard)| (node, shard.or(Some(0))))) +/// } +/// +/// fn name(&self) -> String { +/// "NonRandomLBP".to_string() +/// } +/// } +/// ``` + pub struct Plan<'a> { policy: &'a dyn LoadBalancingPolicy, routing_info: &'a RoutingInfo<'a>, @@ -41,6 +87,21 @@ impl<'a> Plan<'a> { state: PlanState::Created, } } + + fn with_random_shard_if_unknown( + (node, shard): (NodeRef<'_>, Option), + ) -> (NodeRef<'_>, Shard) { + ( + node, + shard.unwrap_or_else(|| { + let nr_shards = node + .sharder() + .map(|sharder| sharder.nr_shards.get()) + .unwrap_or(1); + thread_rng().gen_range(0..nr_shards).into() + }), + ) + } } impl<'a> Iterator for Plan<'a> { @@ -52,7 +113,7 @@ impl<'a> Iterator for Plan<'a> { let picked = self.policy.pick(self.routing_info, self.cluster); if let Some(picked) = picked { self.state = PlanState::Picked(picked); - Some(picked) + Some(Self::with_random_shard_if_unknown(picked)) } else { // `pick()` returned None, which semantically means that a first node cannot be computed _cheaply_. // This, however, does not imply that fallback would return an empty plan, too. @@ -64,9 +125,9 @@ impl<'a> Iterator for Plan<'a> { if let Some(node) = first_fallback_node { self.state = PlanState::Fallback { iter, - node_to_filter_out: node, + target_to_filter_out: node, }; - Some(node) + Some(Self::with_random_shard_if_unknown(node)) } else { error!("Load balancing policy returned an empty plan! The query cannot be executed. Routing info: {:?}", self.routing_info); self.state = PlanState::PickedNone; @@ -77,20 +138,20 @@ impl<'a> Iterator for Plan<'a> { PlanState::Picked(node) => { self.state = PlanState::Fallback { iter: self.policy.fallback(self.routing_info, self.cluster), - node_to_filter_out: *node, + target_to_filter_out: *node, }; self.next() } PlanState::Fallback { iter, - node_to_filter_out, + target_to_filter_out: node_to_filter_out, } => { for node in iter { if node == *node_to_filter_out { continue; } else { - return Some(node); + return Some(Self::with_random_shard_if_unknown(node)); } } @@ -135,7 +196,7 @@ mod tests { &'a self, _query: &'a RoutingInfo, _cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { None } @@ -147,7 +208,7 @@ mod tests { Box::new( self.expected_nodes .iter() - .map(|(node_ref, shard)| (node_ref, *shard)), + .map(|(node_ref, shard)| (node_ref, Some(*shard))), ) } diff --git a/scylla/tests/integration/consistency.rs b/scylla/tests/integration/consistency.rs index 5f178a3bea..a96e4450bb 100644 --- a/scylla/tests/integration/consistency.rs +++ b/scylla/tests/integration/consistency.rs @@ -379,7 +379,7 @@ impl LoadBalancingPolicy for RoutingInfoReportingWrapper { &'a self, query: &'a RoutingInfo, cluster: &'a scylla::transport::ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { self.routing_info_tx .send(OwnedRoutingInfo::from(query.clone())) .unwrap(); diff --git a/scylla/tests/integration/execution_profiles.rs b/scylla/tests/integration/execution_profiles.rs index c0d1964f0a..59f95dfa88 100644 --- a/scylla/tests/integration/execution_profiles.rs +++ b/scylla/tests/integration/execution_profiles.rs @@ -51,9 +51,13 @@ impl LoadBalancingPolicy for BoundToPredefinedNodePolicy { &'a self, _info: &'a RoutingInfo, cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { self.report_node(Report::LoadBalancing); - cluster.get_nodes_info().iter().next().map(|node| (node, 0)) + cluster + .get_nodes_info() + .iter() + .next() + .map(|node| (node, None)) } fn fallback<'a>( diff --git a/scylla/tests/integration/utils.rs b/scylla/tests/integration/utils.rs index b32be090af..7839d772f3 100644 --- a/scylla/tests/integration/utils.rs +++ b/scylla/tests/integration/utils.rs @@ -1,8 +1,4 @@ use futures::Future; -use itertools::Itertools; -use scylla::load_balancing::LoadBalancingPolicy; -use scylla::routing::Shard; -use scylla::transport::NodeRef; use std::collections::HashMap; use std::env; use std::net::SocketAddr; @@ -19,66 +15,6 @@ pub(crate) fn setup_tracing() { .try_init(); } -fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Shard) { - let nr_shards = node - .sharder() - .map(|sharder| sharder.nr_shards.get()) - .unwrap_or(1); - (node, ((nr_shards - 1) % 42) as Shard) -} - -#[derive(Debug)] -pub(crate) struct FixedOrderLoadBalancer; -impl LoadBalancingPolicy for FixedOrderLoadBalancer { - fn pick<'a>( - &'a self, - _info: &'a scylla::load_balancing::RoutingInfo, - cluster: &'a scylla::transport::ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { - cluster - .get_nodes_info() - .iter() - .sorted_by(|node1, node2| Ord::cmp(&node1.address, &node2.address)) - .next() - .map(with_pseudorandom_shard) - } - - fn fallback<'a>( - &'a self, - _info: &'a scylla::load_balancing::RoutingInfo, - cluster: &'a scylla::transport::ClusterData, - ) -> scylla::load_balancing::FallbackPlan<'a> { - Box::new( - cluster - .get_nodes_info() - .iter() - .sorted_by(|node1, node2| Ord::cmp(&node1.address, &node2.address)) - .map(with_pseudorandom_shard), - ) - } - - fn on_query_success( - &self, - _: &scylla::load_balancing::RoutingInfo, - _: std::time::Duration, - _: NodeRef<'_>, - ) { - } - - fn on_query_failure( - &self, - _: &scylla::load_balancing::RoutingInfo, - _: std::time::Duration, - _: NodeRef<'_>, - _: &scylla_cql::errors::QueryError, - ) { - } - - fn name(&self) -> String { - "FixedOrderLoadBalancer".to_string() - } -} - pub(crate) async fn test_with_3_node_cluster( shard_awareness: ShardAwareness, test: F,