From 6ecc4449d75908cb864e0ca9960c5edf04f0bd0c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Karol=20Bary=C5=82a?=
Date: Wed, 27 Mar 2024 11:34:49 +0100
Subject: [PATCH 1/4] LBP: Return Option<Shard> instead of Shard

This was already documented as such, but due to an oversight the code was
in disagreement with the documentation. The documented approach is better,
because the previously implemented one prevented deduplication in `Plan`
from working correctly.
---
 examples/custom_load_balancing_policy.rs      |   6 +-
 .../src/transport/load_balancing/default.rs   | 274 ++++++++++--------
 scylla/src/transport/load_balancing/mod.rs    |   5 +-
 scylla/src/transport/load_balancing/plan.rs   |  89 +++++-
 scylla/tests/integration/consistency.rs       |   2 +-
 .../tests/integration/execution_profiles.rs   |   8 +-
 scylla/tests/integration/utils.rs             |   6 +-
 7 files changed, 252 insertions(+), 138 deletions(-)

diff --git a/examples/custom_load_balancing_policy.rs b/examples/custom_load_balancing_policy.rs
index fb1ae0cb7c..5c279f2331 100644
--- a/examples/custom_load_balancing_policy.rs
+++ b/examples/custom_load_balancing_policy.rs
@@ -18,12 +18,12 @@ struct CustomLoadBalancingPolicy {
     fav_datacenter_name: String,
 }
 
-fn with_random_shard(node: NodeRef) -> (NodeRef, Shard) {
+fn with_random_shard(node: NodeRef) -> (NodeRef, Option<Shard>) {
     let nr_shards = node
         .sharder()
         .map(|sharder| sharder.nr_shards.get())
         .unwrap_or(1);
-    (node, thread_rng().gen_range(0..nr_shards) as Shard)
+    (node, Some(thread_rng().gen_range(0..nr_shards) as Shard))
 }
 
 impl LoadBalancingPolicy for CustomLoadBalancingPolicy {
@@ -31,7 +31,7 @@ impl LoadBalancingPolicy for CustomLoadBalancingPolicy {
         &'a self,
         _info: &'a RoutingInfo,
         cluster: &'a ClusterData,
-    ) -> Option<(NodeRef<'a>, Shard)> {
+    ) -> Option<(NodeRef<'a>, Option<Shard>)> {
         self.fallback(_info, cluster).next()
     }
 
diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs
index e3c1f97377..94525400d7 100644
--- a/scylla/src/transport/load_balancing/default.rs
+++ b/scylla/src/transport/load_balancing/default.rs
@@ -10,8 +10,10 @@ use itertools::{Either, Itertools};
 use rand::{prelude::SliceRandom, thread_rng, Rng};
 use rand_pcg::Pcg32;
 use scylla_cql::{errors::QueryError, frame::types::SerialConsistency, Consistency};
+use std::hash::{Hash, Hasher};
 use std::{fmt, sync::Arc, time::Duration};
 use tracing::{debug, warn};
+use uuid::Uuid;
 
 #[derive(Clone, Copy)]
 enum NodeLocationCriteria<'a> {
@@ -75,7 +77,7 @@ pub struct DefaultPolicy {
     preferences: NodeLocationPreference,
     is_token_aware: bool,
     permit_dc_failover: bool,
-    pick_predicate: Box<dyn Fn(&(NodeRef<'_>, Shard)) -> bool + Send + Sync>,
+    pick_predicate: Box<dyn Fn(NodeRef<'_>, Option<Shard>) -> bool + Send + Sync>,
     latency_awareness: Option<LatencyAwareness>,
     fixed_seed: Option<u64>,
 }
@@ -97,7 +99,7 @@ impl LoadBalancingPolicy for DefaultPolicy {
         &'a self,
         query: &'a RoutingInfo,
         cluster: &'a ClusterData,
-    ) -> Option<(NodeRef<'a>, Shard)> {
+    ) -> Option<(NodeRef<'a>, Option<Shard>)> {
         let routing_info = self.routing_info(query, cluster);
         if let Some(ref token_with_strategy) = routing_info.token_with_strategy {
             if self.preferences.datacenter().is_some()
@@ -126,13 +128,13 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
                 let local_rack_picked = self.pick_replica(
                     ts,
                     NodeLocationCriteria::DatacenterAndRack(dc, rack),
-                    &self.pick_predicate,
+                    |node, shard| (self.pick_predicate)(node, Some(shard)),
                     cluster,
                     statement_type,
                 );
 
-                if let Some(alive_local_rack_replica) = local_rack_picked {
-                    return Some(alive_local_rack_replica);
+                if let 
Some((alive_local_rack_replica, shard)) = local_rack_picked { + return Some((alive_local_rack_replica, Some(shard))); } } @@ -143,13 +145,13 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let picked = self.pick_replica( ts, NodeLocationCriteria::Datacenter(dc), - &self.pick_predicate, + |node, shard| (self.pick_predicate)(node, Some(shard)), cluster, statement_type, ); - if let Some(alive_local_replica) = picked { - return Some(alive_local_replica); + if let Some((alive_local_replica, shard)) = picked { + return Some((alive_local_replica, Some(shard))); } } @@ -161,12 +163,12 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let picked = self.pick_replica( ts, NodeLocationCriteria::Any, - &self.pick_predicate, + |node, shard| (self.pick_predicate)(node, Some(shard)), cluster, statement_type, ); - if let Some(alive_remote_replica) = picked { - return Some(alive_remote_replica); + if let Some((alive_remote_replica, shard)) = picked { + return Some((alive_remote_replica, Some(shard))); } } }; @@ -179,47 +181,47 @@ or refrain from preferring datacenters (which may ban all other datacenters, if if let NodeLocationPreference::DatacenterAndRack(dc, rack) = &self.preferences { // Try to pick some alive local rack random node. let rack_predicate = Self::make_rack_predicate( - &self.pick_predicate, + |node| (self.pick_predicate)(node, None), NodeLocationCriteria::DatacenterAndRack(dc, rack), ); - let local_rack_picked = self.pick_node(nodes, rack_predicate); + let local_rack_picked = self.pick_node(nodes, |node| rack_predicate(&node)); if let Some(alive_local_rack) = local_rack_picked { - return Some(alive_local_rack); + return Some((alive_local_rack, None)); } } // Try to pick some alive local random node. - if let Some(alive_local) = self.pick_node(nodes, &self.pick_predicate) { - return Some(alive_local); + if let Some(alive_local) = self.pick_node(nodes, |node| (self.pick_predicate)(node, None)) { + return Some((alive_local, None)); } let all_nodes = cluster.replica_locator().unique_nodes_in_global_ring(); // If a datacenter failover is possible, loosen restriction about locality. if self.is_datacenter_failover_possible(&routing_info) { - let picked = self.pick_node(all_nodes, &self.pick_predicate); + let picked = self.pick_node(all_nodes, |node| (self.pick_predicate)(node, None)); if let Some(alive_maybe_remote) = picked { - return Some(alive_maybe_remote); + return Some((alive_maybe_remote, None)); } } // Previous checks imply that every node we could have selected is down. // Let's try to return a down node that wasn't disabled. - let picked = self.pick_node(nodes, |(node, _shard)| node.is_enabled()); + let picked = self.pick_node(nodes, |node| node.is_enabled()); if let Some(down_but_enabled_local_node) = picked { - return Some(down_but_enabled_local_node); + return Some((down_but_enabled_local_node, None)); } // If a datacenter failover is possible, loosen restriction about locality. if self.is_datacenter_failover_possible(&routing_info) { - let picked = self.pick_node(all_nodes, |(node, _shard)| node.is_enabled()); + let picked = self.pick_node(all_nodes, |node| node.is_enabled()); if let Some(down_but_enabled_maybe_remote_node) = picked { - return Some(down_but_enabled_maybe_remote_node); + return Some((down_but_enabled_maybe_remote_node, None)); } } // Every node is disabled. This could be due to a bad host filter - configuration error. 
- nodes.first().map(|node| self.with_random_shard(node)) + nodes.first().map(|node| (node, None)) } fn fallback<'a>( @@ -241,7 +243,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let local_rack_replicas = self.fallback_replicas( ts, NodeLocationCriteria::DatacenterAndRack(dc, rack), - Self::is_alive, + |node, shard| Self::is_alive(node, Some(shard)), cluster, statement_type, ); @@ -257,7 +259,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let local_replicas = self.fallback_replicas( ts, NodeLocationCriteria::Datacenter(dc), - Self::is_alive, + |node, shard| Self::is_alive(node, Some(shard)), cluster, statement_type, ); @@ -273,7 +275,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let remote_replicas = self.fallback_replicas( ts, NodeLocationCriteria::Any, - Self::is_alive, + |node, shard| Self::is_alive(node, Some(shard)), cluster, statement_type, ); @@ -287,10 +289,11 @@ or refrain from preferring datacenters (which may ban all other datacenters, if Either::Left( maybe_local_rack_replicas .chain(maybe_local_replicas) - .chain(maybe_remote_replicas), + .chain(maybe_remote_replicas) + .map(|(node, shard)| (node, Some(shard))), ) } else { - Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>()) + Either::Right(std::iter::empty::<(NodeRef<'a>, Option)>()) }; // Get a list of all local alive nodes, and apply a round robin to it @@ -299,31 +302,37 @@ or refrain from preferring datacenters (which may ban all other datacenters, if let maybe_local_rack_nodes = if let NodeLocationPreference::DatacenterAndRack(dc, rack) = &self.preferences { let rack_predicate = Self::make_rack_predicate( - &self.pick_predicate, + |node| (self.pick_predicate)(node, None), NodeLocationCriteria::DatacenterAndRack(dc, rack), ); - Either::Left(self.round_robin_nodes_with_shards(local_nodes, rack_predicate)) + Either::Left( + self.round_robin_nodes(local_nodes, rack_predicate) + .map(|node| (node, None)), + ) } else { - Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>()) + Either::Right(std::iter::empty::<(NodeRef<'a>, Option)>()) }; - let robinned_local_nodes = self.round_robin_nodes_with_shards(local_nodes, Self::is_alive); + let robinned_local_nodes = self + .round_robin_nodes(local_nodes, |node| Self::is_alive(node, None)) + .map(|node| (node, None)); let all_nodes = cluster.replica_locator().unique_nodes_in_global_ring(); // If a datacenter failover is possible, loosen restriction about locality. let maybe_remote_nodes = if self.is_datacenter_failover_possible(&routing_info) { - let robinned_all_nodes = self.round_robin_nodes_with_shards(all_nodes, Self::is_alive); + let robinned_all_nodes = + self.round_robin_nodes(all_nodes, |node| Self::is_alive(node, None)); - Either::Left(robinned_all_nodes) + Either::Left(robinned_all_nodes.map(|node| (node, None))) } else { - Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>()) + Either::Right(std::iter::empty::<(NodeRef<'a>, Option)>()) }; // Even if we consider some enabled nodes to be down, we should try contacting them in the last resort. let maybe_down_local_nodes = local_nodes .iter() .filter(|node| node.is_enabled()) - .map(|node| self.with_random_shard(node)); + .map(|node| (node, None)); // If a datacenter failover is possible, loosen restriction about locality. 
let maybe_down_nodes = if self.is_datacenter_failover_possible(&routing_info) { @@ -331,12 +340,36 @@ or refrain from preferring datacenters (which may ban all other datacenters, if all_nodes .iter() .filter(|node| node.is_enabled()) - .map(|node| self.with_random_shard(node)), + .map(|node| (node, None)), ) } else { Either::Right(std::iter::empty()) }; + struct DefaultPolicyTargetComparator { + host_id: Uuid, + shard: Option, + } + + impl PartialEq for DefaultPolicyTargetComparator { + fn eq(&self, other: &Self) -> bool { + match (self.shard, other.shard) { + (_, None) | (None, _) => self.host_id.eq(&other.host_id), + (Some(shard_left), Some(shard_right)) => { + self.host_id.eq(&other.host_id) && shard_left.eq(&shard_right) + } + } + } + } + + impl Eq for DefaultPolicyTargetComparator {} + + impl Hash for DefaultPolicyTargetComparator { + fn hash(&self, state: &mut H) { + self.host_id.hash(state); + } + } + // Construct a fallback plan as a composition of replicas, local nodes and remote nodes. let plan = maybe_replicas .chain(maybe_local_rack_nodes) @@ -344,7 +377,10 @@ or refrain from preferring datacenters (which may ban all other datacenters, if .chain(maybe_remote_nodes) .chain(maybe_down_local_nodes) .chain(maybe_down_nodes) - .unique(); + .unique_by(|(node, shard)| DefaultPolicyTargetComparator { + host_id: node.host_id, + shard: *shard, + }); if let Some(latency_awareness) = self.latency_awareness.as_ref() { Box::new(latency_awareness.wrap(plan)) @@ -433,15 +469,28 @@ impl DefaultPolicy { /// Wraps the provided predicate, adding the requirement for rack to match. fn make_rack_predicate<'a>( - predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool + 'a, + predicate: impl Fn(NodeRef<'a>) -> bool + 'a, + replica_location: NodeLocationCriteria<'a>, + ) -> impl Fn(&NodeRef<'a>) -> bool { + move |node| match replica_location { + NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => predicate(node), + NodeLocationCriteria::DatacenterAndRack(_, rack) => { + predicate(node) && node.rack.as_deref() == Some(rack) + } + } + } + + /// Wraps the provided predicate, adding the requirement for rack to match. 
+ fn make_sharded_rack_predicate<'a>( + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, replica_location: NodeLocationCriteria<'a>, - ) -> impl Fn(&(NodeRef<'a>, Shard)) -> bool { - move |node_and_shard @ (node, _shard)| match replica_location { + ) -> impl Fn(NodeRef<'a>, Shard) -> bool { + move |node, shard| match replica_location { NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => { - predicate(node_and_shard) + predicate(node, shard) } NodeLocationCriteria::DatacenterAndRack(_, rack) => { - predicate(node_and_shard) && node.rack.as_deref() == Some(rack) + predicate(node, shard) && node.rack.as_deref() == Some(rack) } } } @@ -450,11 +499,11 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool + 'a, + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, cluster: &'a ClusterData, order: ReplicaOrder, ) -> impl Iterator, Shard)> { - let predicate = Self::make_rack_predicate(predicate, replica_location); + let predicate = Self::make_sharded_rack_predicate(predicate, replica_location); let replica_iter = match order { ReplicaOrder::Arbitrary => Either::Left( @@ -467,14 +516,14 @@ impl DefaultPolicy { .into_iter(), ), }; - replica_iter.filter(move |node_and_shard: &(NodeRef<'a>, Shard)| predicate(node_and_shard)) + replica_iter.filter(move |(node, shard): &(NodeRef<'a>, Shard)| predicate(node, *shard)) } fn pick_replica<'a>( &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool, + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, cluster: &'a ClusterData, statement_type: StatementType, ) -> Option<(NodeRef<'a>, Shard)> { @@ -502,7 +551,7 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool, + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, cluster: &'a ClusterData, ) -> Option<(NodeRef<'a>, Shard)> { match replica_location { @@ -521,8 +570,8 @@ impl DefaultPolicy { .into_replicas_ordered() .into_iter() .next() - .and_then(|primary_replica| { - predicate(&primary_replica).then_some(primary_replica) + .and_then(|(primary_replica, shard)| { + predicate(primary_replica, shard).then_some((primary_replica, shard)) }) } NodeLocationCriteria::Datacenter(_) | NodeLocationCriteria::DatacenterAndRack(_, _) => { @@ -534,7 +583,7 @@ impl DefaultPolicy { self.replicas( ts, replica_location, - move |node_and_shard| predicate(node_and_shard), + predicate, cluster, ReplicaOrder::RingOrder, ) @@ -547,18 +596,18 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool, + predicate: impl Fn(NodeRef<'a>, Shard) -> bool + 'a, cluster: &'a ClusterData, ) -> Option<(NodeRef<'a>, Shard)> { - let predicate = Self::make_rack_predicate(predicate, replica_location); + let predicate = Self::make_sharded_rack_predicate(predicate, replica_location); let replica_set = self.nonfiltered_replica_set(ts, replica_location, cluster); if let Some(fixed) = self.fixed_seed { let mut gen = Pcg32::new(fixed, 0); - replica_set.choose_filtered(&mut gen, predicate) + replica_set.choose_filtered(&mut gen, |(node, shard)| predicate(node, *shard)) } else { - replica_set.choose_filtered(&mut thread_rng(), predicate) + replica_set.choose_filtered(&mut thread_rng(), |(node, shard)| predicate(node, *shard)) } } @@ -566,7 
+615,7 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: impl Fn(&(NodeRef<'_>, Shard)) -> bool + 'a, + predicate: impl Fn(NodeRef<'_>, Shard) -> bool + 'a, cluster: &'a ClusterData, statement_type: StatementType, ) -> impl Iterator, Shard)> { @@ -604,22 +653,21 @@ impl DefaultPolicy { fn pick_node<'a>( &'a self, nodes: &'a [Arc], - predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool, - ) -> Option<(NodeRef<'_>, Shard)> { + predicate: impl Fn(NodeRef<'a>) -> bool, + ) -> Option> { // Select the first node that matches the predicate - Self::randomly_rotated_nodes(nodes) - .map(|node| self.with_random_shard(node)) - .find(predicate) + Self::randomly_rotated_nodes(nodes).find(|&node| predicate(node)) } - fn round_robin_nodes_with_shards<'a>( + fn round_robin_nodes<'a>( &'a self, nodes: &'a [Arc], - predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool, - ) -> impl Iterator, Shard)> { - Self::randomly_rotated_nodes(nodes) - .map(|node| self.with_random_shard(node)) - .filter(predicate) + // I wanted this to be + // impl Fn(&NodeRef<'a>) -> bool + // but I have no idea how to make this work with borrow checker + predicate: impl Fn(&NodeRef<'a>) -> bool, + ) -> impl Iterator> { + Self::randomly_rotated_nodes(nodes).filter(predicate) } fn shuffle<'a>( @@ -638,23 +686,7 @@ impl DefaultPolicy { vec.into_iter() } - fn with_random_shard<'a>(&self, node: NodeRef<'a>) -> (NodeRef<'a>, Shard) { - let nr_shards = node - .sharder() - .map(|sharder| sharder.nr_shards.get()) - .unwrap_or(1); - ( - node, - (if let Some(fixed) = self.fixed_seed { - let mut gen = Pcg32::new(fixed, 0); - gen.gen_range(0..nr_shards) - } else { - thread_rng().gen_range(0..nr_shards) - }) as Shard, - ) - } - - fn is_alive(&(node, _shard): &(NodeRef<'_>, Shard)) -> bool { + fn is_alive(node: NodeRef, _shard: Option) -> bool { // For now, we leave this as stub, until we have time to improve node events. 
// node.is_enabled() && !node.is_down() node.is_enabled() @@ -720,11 +752,10 @@ impl DefaultPolicyBuilder { let latency_awareness = self.latency_awareness.map(|builder| builder.build()); let pick_predicate = if let Some(ref latency_awareness) = latency_awareness { let latency_predicate = latency_awareness.generate_predicate(); - Box::new( - move |node_and_shard @ (node, _shard): &(NodeRef<'_>, Shard)| { - DefaultPolicy::is_alive(node_and_shard) && latency_predicate(node) - }, - ) as Box, Shard)) -> bool + Send + Sync + 'static> + Box::new(move |node: NodeRef<'_>, shard| { + DefaultPolicy::is_alive(node, shard) && latency_predicate(node) + }) + as Box, Option) -> bool + Send + Sync + 'static> } else { Box::new(DefaultPolicy::is_alive) }; @@ -912,6 +943,7 @@ impl<'a> TokenWithStrategy<'a> { #[cfg(test)] mod tests { use scylla_cql::{frame::types::SerialConsistency, Consistency}; + use tracing::info; use self::framework::{ get_plan_and_collect_node_identifiers, mock_cluster_data_for_token_unaware_tests, @@ -919,7 +951,7 @@ mod tests { }; use crate::{ load_balancing::{ - default::tests::framework::mock_cluster_data_for_token_aware_tests, RoutingInfo, + default::tests::framework::mock_cluster_data_for_token_aware_tests, Plan, RoutingInfo, }, routing::Token, test_utils::setup_tracing, @@ -947,6 +979,7 @@ mod tests { }, }; + #[derive(Debug)] enum ExpectedGroup { NonDeterministic(HashSet), Deterministic(HashSet), @@ -1002,6 +1035,7 @@ mod tests { } } + #[derive(Debug)] pub(crate) struct ExpectedGroups { groups: Vec, } @@ -1197,6 +1231,19 @@ mod tests { let plan = get_plan_and_collect_node_identifiers(policy, routing_info, cluster); plans.push(plan); } + let example_plan = Plan::new(policy, routing_info, cluster); + info!("Example plan from policy:",); + for (node, shard) in example_plan { + info!( + "Node port: {}, shard: {}, dc: {:?}, rack: {:?}, down: {:?}", + node.address.port(), + shard, + node.datacenter, + node.rack, + node.is_down() + ); + } + expected_groups.assert_proper_grouping_in_plans(&plans); } @@ -1252,9 +1299,11 @@ mod tests { #[tokio::test] async fn test_default_policy_with_token_aware_statements() { setup_tracing(); - use crate::transport::locator::test::{A, B, C, D, E, F, G}; + use crate::transport::locator::test::{A, B, C, D, E, F, G}; let cluster = mock_cluster_data_for_token_aware_tests().await; + + #[derive(Debug)] struct Test<'a> { policy: DefaultPolicy, routing_info: RoutingInfo<'a>, @@ -1704,12 +1753,13 @@ mod tests { }, ]; - for Test { - policy, - routing_info, - expected_groups, - } in tests - { + for test in tests { + info!("Test: {:?}", test); + let Test { + policy, + routing_info, + expected_groups, + } = test; test_default_policy_with_given_cluster_and_routing_info( &policy, &cluster, @@ -2401,8 +2451,8 @@ mod latency_awareness { pub(super) fn wrap<'a>( &self, - fallback: impl Iterator, Shard)>, - ) -> impl Iterator, Shard)> { + fallback: impl Iterator, Option)>, + ) -> impl Iterator, Option)> { let min_avg_latency = match self.last_min_latency.load() { Some(min_avg) => min_avg, None => return Either::Left(fallback), // noop, as no latency data has been collected yet @@ -2723,8 +2773,8 @@ mod latency_awareness { struct IteratorWithSkippedNodes<'a, Fast, Penalised> where - Fast: Iterator, Shard)>, - Penalised: Iterator, Shard)>, + Fast: Iterator, Option)>, + Penalised: Iterator, Option)>, { fast_nodes: Fast, penalised_nodes: Penalised, @@ -2733,13 +2783,13 @@ mod latency_awareness { impl<'a> IteratorWithSkippedNodes< 'a, - std::vec::IntoIter<(NodeRef<'a>, Shard)>, - 
std::vec::IntoIter<(NodeRef<'a>, Shard)>, + std::vec::IntoIter<(NodeRef<'a>, Option)>, + std::vec::IntoIter<(NodeRef<'a>, Option)>, > { fn new( average_latencies: &HashMap>>, - nodes: impl Iterator, Shard)>, + nodes: impl Iterator, Option)>, exclusion_threshold: f64, retry_period: Duration, minimum_measurements: usize, @@ -2775,10 +2825,10 @@ mod latency_awareness { impl<'a, Fast, Penalised> Iterator for IteratorWithSkippedNodes<'a, Fast, Penalised> where - Fast: Iterator, Shard)>, - Penalised: Iterator, Shard)>, + Fast: Iterator, Option)>, + Penalised: Iterator, Option)>, { - type Item = (NodeRef<'a>, Shard); + type Item = (NodeRef<'a>, Option); fn next(&mut self) -> Option { self.fast_nodes @@ -2860,12 +2910,10 @@ mod latency_awareness { ) -> DefaultPolicy { let pick_predicate = { let latency_predicate = latency_awareness.generate_predicate(); - Box::new( - move |node_and_shard @ (node, _shard): &(NodeRef<'_>, Shard)| { - DefaultPolicy::is_alive(node_and_shard) && latency_predicate(node) - }, - ) - as Box, Shard)) -> bool + Send + Sync + 'static> + Box::new(move |node: NodeRef<'_>, shard| { + DefaultPolicy::is_alive(node, shard) && latency_predicate(node) + }) + as Box, Option) -> bool + Send + Sync + 'static> }; DefaultPolicy { diff --git a/scylla/src/transport/load_balancing/mod.rs b/scylla/src/transport/load_balancing/mod.rs index 977e3d508f..f1cd5bdf27 100644 --- a/scylla/src/transport/load_balancing/mod.rs +++ b/scylla/src/transport/load_balancing/mod.rs @@ -39,7 +39,8 @@ pub struct RoutingInfo<'a> { /// /// It is computed on-demand, only if querying the most preferred node fails /// (or when speculative execution is triggered). -pub type FallbackPlan<'a> = Box, Shard)> + Send + Sync + 'a>; +pub type FallbackPlan<'a> = + Box, Option)> + Send + Sync + 'a>; /// Policy that decides which nodes and shards to contact for each query. /// @@ -67,7 +68,7 @@ pub trait LoadBalancingPolicy: Send + Sync + std::fmt::Debug { &'a self, query: &'a RoutingInfo, cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)>; + ) -> Option<(NodeRef<'a>, Option)>; /// Returns all contact-appropriate nodes for a given query. fn fallback<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) diff --git a/scylla/src/transport/load_balancing/plan.rs b/scylla/src/transport/load_balancing/plan.rs index 5fc6294467..3dc946c58b 100644 --- a/scylla/src/transport/load_balancing/plan.rs +++ b/scylla/src/transport/load_balancing/plan.rs @@ -1,3 +1,4 @@ +use rand::{thread_rng, Rng}; use tracing::error; use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo}; @@ -6,20 +7,65 @@ use crate::{routing::Shard, transport::ClusterData}; enum PlanState<'a> { Created, PickedNone, // This always means an abnormal situation: it means that no nodes satisfied locality/node filter requirements. - Picked((NodeRef<'a>, Shard)), + Picked((NodeRef<'a>, Option)), Fallback { iter: FallbackPlan<'a>, - node_to_filter_out: (NodeRef<'a>, Shard), + target_to_filter_out: (NodeRef<'a>, Option), }, } -/// The list of nodes constituting the query plan. +/// The list of targets constituting the query plan. Target here is a pair `(NodeRef<'a>, Shard)`. /// -/// The plan is partly lazily computed, with the first node computed -/// eagerly in the first place and the remaining nodes computed on-demand +/// The plan is partly lazily computed, with the first target computed +/// eagerly in the first place and the remaining targets computed on-demand /// (all at once). 
 /// This significantly reduces the allocation overhead on "the happy path"
-/// (when the first node successfully handles the request),
+/// (when the first target successfully handles the request).
+///
+/// `Plan` implements `Iterator<Item = (NodeRef<'a>, Shard)>`, but `LoadBalancingPolicy`
+/// returns `Option<Shard>` instead of `Shard`, both in `pick` and in `fallback`.
+/// `Plan` handles the `None` case by choosing a random shard for the given node.
+/// There is currently no way to configure the RNG used by `Plan`.
+/// If you don't want `Plan` to randomize shards, or you want to control the RNG,
+/// use a custom LBP that always returns non-`None` shards.
+/// Example of an LBP that always uses shard 0, preventing `Plan` from using random numbers:
+///
+/// ```
+/// # use std::sync::Arc;
+/// # use scylla::load_balancing::LoadBalancingPolicy;
+/// # use scylla::load_balancing::RoutingInfo;
+/// # use scylla::transport::ClusterData;
+/// # use scylla::transport::NodeRef;
+/// # use scylla::routing::Shard;
+/// # use scylla::load_balancing::FallbackPlan;
+///
+/// #[derive(Debug)]
+/// struct NonRandomLBP {
+///     inner: Arc<dyn LoadBalancingPolicy>,
+/// }
+/// impl LoadBalancingPolicy for NonRandomLBP {
+///     fn pick<'a>(
+///         &'a self,
+///         info: &'a RoutingInfo,
+///         cluster: &'a ClusterData,
+///     ) -> Option<(NodeRef<'a>, Option<Shard>)> {
+///         self.inner
+///             .pick(info, cluster)
+///             .map(|(node, shard)| (node, shard.or(Some(0))))
+///     }
+///
+///     fn fallback<'a>(&'a self, info: &'a RoutingInfo, cluster: &'a ClusterData) -> FallbackPlan<'a> {
+///         Box::new(self.inner
+///             .fallback(info, cluster)
+///             .map(|(node, shard)| (node, shard.or(Some(0)))))
+///     }
+///
+///     fn name(&self) -> String {
+///         "NonRandomLBP".to_string()
+///     }
+/// }
+/// ```
+
 pub struct Plan<'a> {
     policy: &'a dyn LoadBalancingPolicy,
     routing_info: &'a RoutingInfo<'a>,
@@ -41,6 +87,21 @@ impl<'a> Plan<'a> {
             state: PlanState::Created,
         }
     }
+
+    fn with_random_shard_if_unknown(
+        (node, shard): (NodeRef<'_>, Option<Shard>),
+    ) -> (NodeRef<'_>, Shard) {
+        (
+            node,
+            shard.unwrap_or_else(|| {
+                let nr_shards = node
+                    .sharder()
+                    .map(|sharder| sharder.nr_shards.get())
+                    .unwrap_or(1);
+                thread_rng().gen_range(0..nr_shards).into()
+            }),
+        )
+    }
 }
 
 impl<'a> Iterator for Plan<'a> {
@@ -52,7 +113,7 @@ impl<'a> Iterator for Plan<'a> {
             let picked = self.policy.pick(self.routing_info, self.cluster);
             if let Some(picked) = picked {
                 self.state = PlanState::Picked(picked);
-                Some(picked)
+                Some(Self::with_random_shard_if_unknown(picked))
             } else {
                 // `pick()` returned None, which semantically means that a first node cannot be computed _cheaply_.
                 // This, however, does not imply that fallback would return an empty plan, too.
@@ -64,9 +125,9 @@
                 if let Some(node) = first_fallback_node {
                     self.state = PlanState::Fallback {
                         iter,
-                        node_to_filter_out: node,
+                        target_to_filter_out: node,
                     };
-                    Some(node)
+                    Some(Self::with_random_shard_if_unknown(node))
                 } else {
                     error!("Load balancing policy returned an empty plan! The query cannot be executed. 
Routing info: {:?}", self.routing_info); self.state = PlanState::PickedNone; @@ -77,20 +138,20 @@ impl<'a> Iterator for Plan<'a> { PlanState::Picked(node) => { self.state = PlanState::Fallback { iter: self.policy.fallback(self.routing_info, self.cluster), - node_to_filter_out: *node, + target_to_filter_out: *node, }; self.next() } PlanState::Fallback { iter, - node_to_filter_out, + target_to_filter_out: node_to_filter_out, } => { for node in iter { if node == *node_to_filter_out { continue; } else { - return Some(node); + return Some(Self::with_random_shard_if_unknown(node)); } } @@ -135,7 +196,7 @@ mod tests { &'a self, _query: &'a RoutingInfo, _cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { None } @@ -147,7 +208,7 @@ mod tests { Box::new( self.expected_nodes .iter() - .map(|(node_ref, shard)| (node_ref, *shard)), + .map(|(node_ref, shard)| (node_ref, Some(*shard))), ) } diff --git a/scylla/tests/integration/consistency.rs b/scylla/tests/integration/consistency.rs index 5f178a3bea..a96e4450bb 100644 --- a/scylla/tests/integration/consistency.rs +++ b/scylla/tests/integration/consistency.rs @@ -379,7 +379,7 @@ impl LoadBalancingPolicy for RoutingInfoReportingWrapper { &'a self, query: &'a RoutingInfo, cluster: &'a scylla::transport::ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { self.routing_info_tx .send(OwnedRoutingInfo::from(query.clone())) .unwrap(); diff --git a/scylla/tests/integration/execution_profiles.rs b/scylla/tests/integration/execution_profiles.rs index c0d1964f0a..59f95dfa88 100644 --- a/scylla/tests/integration/execution_profiles.rs +++ b/scylla/tests/integration/execution_profiles.rs @@ -51,9 +51,13 @@ impl LoadBalancingPolicy for BoundToPredefinedNodePolicy { &'a self, _info: &'a RoutingInfo, cluster: &'a ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { self.report_node(Report::LoadBalancing); - cluster.get_nodes_info().iter().next().map(|node| (node, 0)) + cluster + .get_nodes_info() + .iter() + .next() + .map(|node| (node, None)) } fn fallback<'a>( diff --git a/scylla/tests/integration/utils.rs b/scylla/tests/integration/utils.rs index b32be090af..f30021db33 100644 --- a/scylla/tests/integration/utils.rs +++ b/scylla/tests/integration/utils.rs @@ -19,12 +19,12 @@ pub(crate) fn setup_tracing() { .try_init(); } -fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Shard) { +fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Option) { let nr_shards = node .sharder() .map(|sharder| sharder.nr_shards.get()) .unwrap_or(1); - (node, ((nr_shards - 1) % 42) as Shard) + (node, Some(((nr_shards - 1) % 42) as Shard)) } #[derive(Debug)] @@ -34,7 +34,7 @@ impl LoadBalancingPolicy for FixedOrderLoadBalancer { &'a self, _info: &'a scylla::load_balancing::RoutingInfo, cluster: &'a scylla::transport::ClusterData, - ) -> Option<(NodeRef<'a>, Shard)> { + ) -> Option<(NodeRef<'a>, Option)> { cluster .get_nodes_info() .iter() From 91561919eea65e3241168f486cbfdccb04b473ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=20Bary=C5=82a?= Date: Tue, 26 Mar 2024 22:41:44 +0100 Subject: [PATCH 2/4] default_lbp: get rid of redundant indirection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Wojciech Przytuła --- scylla/src/transport/load_balancing/default.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/scylla/src/transport/load_balancing/default.rs 
b/scylla/src/transport/load_balancing/default.rs
index 94525400d7..30e9d19f99 100644
--- a/scylla/src/transport/load_balancing/default.rs
+++ b/scylla/src/transport/load_balancing/default.rs
@@ -184,7 +184,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if
                 |node| (self.pick_predicate)(node, None),
                 NodeLocationCriteria::DatacenterAndRack(dc, rack),
             );
-            let local_rack_picked = self.pick_node(nodes, |node| rack_predicate(&node));
+            let local_rack_picked = self.pick_node(nodes, rack_predicate);
 
             if let Some(alive_local_rack) = local_rack_picked {
                 return Some((alive_local_rack, None));
@@ -471,7 +471,7 @@ impl DefaultPolicy {
     fn make_rack_predicate<'a>(
         predicate: impl Fn(NodeRef<'a>) -> bool + 'a,
         replica_location: NodeLocationCriteria<'a>,
-    ) -> impl Fn(&NodeRef<'a>) -> bool {
+    ) -> impl Fn(NodeRef<'a>) -> bool {
         move |node| match replica_location {
             NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => predicate(node),
             NodeLocationCriteria::DatacenterAndRack(_, rack) => {
                 predicate(node) && node.rack.as_deref() == Some(rack)
@@ -662,12 +662,9 @@ impl DefaultPolicy {
     fn round_robin_nodes<'a>(
         &'a self,
         nodes: &'a [Arc<Node>],
-        // I wanted this to be
-        // impl Fn(&NodeRef<'a>) -> bool
-        // but I have no idea how to make this work with borrow checker
-        predicate: impl Fn(&NodeRef<'a>) -> bool,
+        predicate: impl Fn(NodeRef<'a>) -> bool,
     ) -> impl Iterator<Item = NodeRef<'a>> {
-        Self::randomly_rotated_nodes(nodes).filter(predicate)
+        Self::randomly_rotated_nodes(nodes).filter(move |node| predicate(node))
     }
 
     fn shuffle<'a>(

From 53801ae10f36c6340d252b06925a4958132054f2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?=
Date: Tue, 26 Mar 2024 16:20:55 +0100
Subject: [PATCH 3/4] default_lbp: replace boilerplate with from_fn()

As there is a brilliant `std::iter::from_fn()` function that creates a new
iterator based on a closure, it can be used instead of the verbose
boilerplate incurred by introducing IteratorWithSkippedNodes.
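For readers unfamiliar with the pattern, below is a minimal, self-contained
sketch of the `from_fn`-based chaining this commit switches to. The names
(`chain_fast_then_penalised`, `fast`, `penalised`) are illustrative only and
not part of the patch; in the actual diff the closure captures the
`fast_targets` and `penalised_targets` iterators directly:

```rust
// Sketch: `std::iter::from_fn` turns a stateful closure into an iterator.
// Here it yields every "fast" item first and only then the "penalised"
// ones -- the same order IteratorWithSkippedNodes::next() produced.
fn chain_fast_then_penalised<T>(
    fast: Vec<T>,
    penalised: Vec<T>,
) -> impl Iterator<Item = T> {
    let mut fast = fast.into_iter();
    let mut penalised = penalised.into_iter();
    std::iter::from_fn(move || fast.next().or_else(|| penalised.next()))
}

fn main() {
    let order: Vec<i32> = chain_fast_then_penalised(vec![1, 2], vec![9]).collect();
    assert_eq!(order, [1, 2, 9]);
}
```

In this simple form the closure is equivalent to `fast.chain(penalised)`;
`from_fn` earns its keep when the draining logic needs extra state or
branching, while still avoiding a bespoke iterator struct.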
--- .../src/transport/load_balancing/default.rs | 105 ++++++------------ 1 file changed, 32 insertions(+), 73 deletions(-) diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs index 30e9d19f99..625232cfae 100644 --- a/scylla/src/transport/load_balancing/default.rs +++ b/scylla/src/transport/load_balancing/default.rs @@ -2455,14 +2455,38 @@ mod latency_awareness { None => return Either::Left(fallback), // noop, as no latency data has been collected yet }; - Either::Right(IteratorWithSkippedNodes::new( - self.node_avgs.read().unwrap().deref(), - fallback, - self.exclusion_threshold, - self.retry_period, - self.minimum_measurements, - min_avg_latency, - )) + let average_latencies = self.node_avgs.read().unwrap(); + let targets = fallback; + + let mut fast_targets = vec![]; + let mut penalised_targets = vec![]; + + for node_and_shard @ (node, _shard) in targets { + match fast_enough( + average_latencies.deref(), + node.host_id, + self.exclusion_threshold, + self.retry_period, + self.minimum_measurements, + min_avg_latency, + ) { + FastEnough::Yes => fast_targets.push(node_and_shard), + FastEnough::No { average } => { + trace!("Latency awareness: Penalising node {{address={}, datacenter={:?}, rack={:?}}} for being on average at least {} times slower (latency: {}ms) than the fastest ({}ms).", + node.address, node.datacenter, node.rack, self.exclusion_threshold, average.as_millis(), min_avg_latency.as_millis()); + penalised_targets.push(node_and_shard); + } + } + } + + let mut fast_targets = fast_targets.into_iter(); + let mut penalised_targets = penalised_targets.into_iter(); + + let skipping_penalised_targets_iterator = std::iter::from_fn(move || { + fast_targets.next().or_else(|| penalised_targets.next()) + }); + + Either::Right(skipping_penalised_targets_iterator) } pub(super) fn report_query(&self, node: &Node, latency: Duration) { @@ -2768,71 +2792,6 @@ mod latency_awareness { } } - struct IteratorWithSkippedNodes<'a, Fast, Penalised> - where - Fast: Iterator, Option)>, - Penalised: Iterator, Option)>, - { - fast_nodes: Fast, - penalised_nodes: Penalised, - } - - impl<'a> - IteratorWithSkippedNodes< - 'a, - std::vec::IntoIter<(NodeRef<'a>, Option)>, - std::vec::IntoIter<(NodeRef<'a>, Option)>, - > - { - fn new( - average_latencies: &HashMap>>, - nodes: impl Iterator, Option)>, - exclusion_threshold: f64, - retry_period: Duration, - minimum_measurements: usize, - min_avg: Duration, - ) -> Self { - let mut fast_nodes = vec![]; - let mut penalised_nodes = vec![]; - - for node_and_shard @ (node, _shard) in nodes { - match fast_enough( - average_latencies, - node.host_id, - exclusion_threshold, - retry_period, - minimum_measurements, - min_avg, - ) { - FastEnough::Yes => fast_nodes.push(node_and_shard), - FastEnough::No { average } => { - trace!("Latency awareness: Penalising node {{address={}, datacenter={:?}, rack={:?}}} for being on average at least {} times slower (latency: {}ms) than the fastest ({}ms).", - node.address, node.datacenter, node.rack, exclusion_threshold, average.as_millis(), min_avg.as_millis()); - penalised_nodes.push(node_and_shard); - } - } - } - - Self { - fast_nodes: fast_nodes.into_iter(), - penalised_nodes: penalised_nodes.into_iter(), - } - } - } - - impl<'a, Fast, Penalised> Iterator for IteratorWithSkippedNodes<'a, Fast, Penalised> - where - Fast: Iterator, Option)>, - Penalised: Iterator, Option)>, - { - type Item = (NodeRef<'a>, Option); - - fn next(&mut self) -> Option { - self.fast_nodes - .next() - 
.or_else(|| self.penalised_nodes.next()) - } - } #[cfg(test)] mod tests { use scylla_cql::Consistency; From e90f102dd1b1063d36c49799ec2fdee4d40eaccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=20Bary=C5=82a?= Date: Wed, 27 Mar 2024 11:39:10 +0100 Subject: [PATCH 4/4] Integration tests: Remove unused utils `with_pseudorandom_shard` and `FixedOrderLoadBalancer` are no longer used. No need to keep them around. --- scylla/tests/integration/utils.rs | 64 ------------------------------- 1 file changed, 64 deletions(-) diff --git a/scylla/tests/integration/utils.rs b/scylla/tests/integration/utils.rs index f30021db33..7839d772f3 100644 --- a/scylla/tests/integration/utils.rs +++ b/scylla/tests/integration/utils.rs @@ -1,8 +1,4 @@ use futures::Future; -use itertools::Itertools; -use scylla::load_balancing::LoadBalancingPolicy; -use scylla::routing::Shard; -use scylla::transport::NodeRef; use std::collections::HashMap; use std::env; use std::net::SocketAddr; @@ -19,66 +15,6 @@ pub(crate) fn setup_tracing() { .try_init(); } -fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Option) { - let nr_shards = node - .sharder() - .map(|sharder| sharder.nr_shards.get()) - .unwrap_or(1); - (node, Some(((nr_shards - 1) % 42) as Shard)) -} - -#[derive(Debug)] -pub(crate) struct FixedOrderLoadBalancer; -impl LoadBalancingPolicy for FixedOrderLoadBalancer { - fn pick<'a>( - &'a self, - _info: &'a scylla::load_balancing::RoutingInfo, - cluster: &'a scylla::transport::ClusterData, - ) -> Option<(NodeRef<'a>, Option)> { - cluster - .get_nodes_info() - .iter() - .sorted_by(|node1, node2| Ord::cmp(&node1.address, &node2.address)) - .next() - .map(with_pseudorandom_shard) - } - - fn fallback<'a>( - &'a self, - _info: &'a scylla::load_balancing::RoutingInfo, - cluster: &'a scylla::transport::ClusterData, - ) -> scylla::load_balancing::FallbackPlan<'a> { - Box::new( - cluster - .get_nodes_info() - .iter() - .sorted_by(|node1, node2| Ord::cmp(&node1.address, &node2.address)) - .map(with_pseudorandom_shard), - ) - } - - fn on_query_success( - &self, - _: &scylla::load_balancing::RoutingInfo, - _: std::time::Duration, - _: NodeRef<'_>, - ) { - } - - fn on_query_failure( - &self, - _: &scylla::load_balancing::RoutingInfo, - _: std::time::Duration, - _: NodeRef<'_>, - _: &scylla_cql::errors::QueryError, - ) { - } - - fn name(&self) -> String { - "FixedOrderLoadBalancer".to_string() - } -} - pub(crate) async fn test_with_3_node_cluster( shard_awareness: ShardAwareness, test: F,
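To make the contract established by this series concrete, here is a hedged
sketch of a policy written against the post-series trait. `AnyShardFirstNodePolicy`
is a hypothetical name, not part of the patches; it relies on the behavior
introduced in patch 1, where returning `None` for the shard lets `Plan` pick a
random shard of the chosen node, while `Some(shard)` pins it:

```rust
use scylla::load_balancing::{FallbackPlan, LoadBalancingPolicy, RoutingInfo};
use scylla::routing::Shard;
use scylla::transport::{ClusterData, NodeRef};

// Hypothetical policy: expresses no shard preference (returns None for the
// shard), so `Plan` chooses one of the node's shards at random.
#[derive(Debug)]
struct AnyShardFirstNodePolicy;

impl LoadBalancingPolicy for AnyShardFirstNodePolicy {
    fn pick<'a>(
        &'a self,
        _info: &'a RoutingInfo,
        cluster: &'a ClusterData,
    ) -> Option<(NodeRef<'a>, Option<Shard>)> {
        // `None` delegates shard selection to `Plan`.
        cluster.get_nodes_info().first().map(|node| (node, None))
    }

    fn fallback<'a>(
        &'a self,
        _info: &'a RoutingInfo,
        cluster: &'a ClusterData,
    ) -> FallbackPlan<'a> {
        // Same delegation for every node in the fallback plan; `Plan`
        // deduplicates targets, treating an unknown shard as matching any.
        Box::new(cluster.get_nodes_info().iter().map(|node| (node, None)))
    }

    fn name(&self) -> String {
        "AnyShardFirstNodePolicy".to_string()
    }
}
```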