Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion scylla/src/client/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::policies::timestamp_generator::TimestampGenerator;
use crate::response::query_result::{MaybeFirstRowError, QueryResult, RowsError};
use crate::response::{NonErrorQueryResponse, PagingState, PagingStateResponse, QueryResponse};
use crate::routing::partitioner::PartitionerName;
use crate::routing::Shard;
use crate::routing::{Shard, ShardAwarePortRange};
use crate::statement::batch::batch_values;
use crate::statement::batch::{Batch, BatchStatement};
use crate::statement::prepared::{PartitionKeyError, PreparedStatement};
Expand Down Expand Up @@ -162,6 +162,11 @@ pub struct SessionConfig {
/// - `in6addr_any` for IPv6 ([`Ipv6Addr::UNSPECIFIED`][std::net::Ipv6Addr::UNSPECIFIED])
pub local_ip_address: Option<IpAddr>,

/// Specifies the local port range used for shard-aware connections.
///
/// By default set to [`ShardAwarePortRange::EPHEMERAL_PORT_RANGE`].
pub shard_aware_local_port_range: ShardAwarePortRange,

/// Preferred compression algorithm to use on connections.
/// If it's not supported by database server Session will fall back to no compression.
pub compression: Option<Compression>,
Expand Down Expand Up @@ -301,6 +306,7 @@ impl SessionConfig {
SessionConfig {
known_nodes: Vec::new(),
local_ip_address: None,
shard_aware_local_port_range: ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
compression: None,
tcp_nodelay: true,
tcp_keepalive_interval: None,
Expand Down Expand Up @@ -906,6 +912,7 @@ impl Session {

let connection_config = ConnectionConfig {
local_ip_address: config.local_ip_address,
shard_aware_local_port_range: config.shard_aware_local_port_range,
compression: config.compression,
tcp_nodelay: config.tcp_nodelay,
tcp_keepalive_interval: config.tcp_keepalive_interval,
Expand Down
30 changes: 30 additions & 0 deletions scylla/src/client/session_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::errors::NewSessionError;
use crate::policies::address_translator::AddressTranslator;
use crate::policies::host_filter::HostFilter;
use crate::policies::timestamp_generator::TimestampGenerator;
use crate::routing::ShardAwarePortRange;
use crate::statement::Consistency;
use std::borrow::Borrow;
use std::marker::PhantomData;
Expand Down Expand Up @@ -420,6 +421,35 @@ impl<K: SessionBuilderKind> GenericSessionBuilder<K> {
self
}

/// Specifies the local port range used for shard-aware connections.
///
/// A possible use case is when you want to have multiple [`Session`] objects and do not want
/// them to compete for the ports within the same range. It is then advised to assign
/// mutually non-overlapping port ranges to each session object.
///
/// By default this option is set to [`ShardAwarePortRange::EPHEMERAL_PORT_RANGE`].
///
/// For details, see [`ShardAwarePortRange`] documentation.
///
/// # Example
/// ```
/// # use scylla::client::session::Session;
/// # use scylla::client::session_builder::SessionBuilder;
/// # use scylla::routing::ShardAwarePortRange;
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let session: Session = SessionBuilder::new()
/// .known_node("127.0.0.1:9042")
/// .shard_aware_local_port_range(ShardAwarePortRange::new(49200..=50000)?)
/// .build()
/// .await?;
/// # Ok(())
/// # }
/// ```
pub fn shard_aware_local_port_range(mut self, port_range: ShardAwarePortRange) -> Self {
self.config.shard_aware_local_port_range = port_range;
self
}

/// Set preferred Compression algorithm.
/// The default is no compression.
/// If it is not supported by database server Session will fall back to no encryption.
Expand Down
10 changes: 8 additions & 2 deletions scylla/src/network/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
NonErrorAuthResponse, NonErrorStartupResponse, PagingState, PagingStateResponse, QueryResponse,
};
use crate::routing::locator::tablets::{RawTablet, TabletParsingError};
use crate::routing::{Shard, ShardInfo, Sharder, ShardingError};
use crate::routing::{Shard, ShardAwarePortRange, ShardInfo, Sharder, ShardingError};
use crate::statement::batch::{Batch, BatchStatement};
use crate::statement::prepared::PreparedStatement;
use crate::statement::unprepared::Statement;
Expand Down Expand Up @@ -274,6 +274,7 @@
#[derive(Clone)]
pub(crate) struct ConnectionConfig {
pub(crate) local_ip_address: Option<IpAddr>,
pub(crate) shard_aware_local_port_range: ShardAwarePortRange,
pub(crate) compression: Option<Compression>,
pub(crate) tcp_nodelay: bool,
pub(crate) tcp_keepalive_interval: Option<Duration>,
Expand Down Expand Up @@ -309,6 +310,7 @@

HostConnectionConfig {
local_ip_address: self.local_ip_address,
shard_aware_local_port_range: self.shard_aware_local_port_range.clone(),
compression: self.compression,
tcp_nodelay: self.tcp_nodelay,
tcp_keepalive_interval: self.tcp_keepalive_interval,
Expand All @@ -334,6 +336,7 @@
#[derive(Clone)]
pub(crate) struct HostConnectionConfig {
pub(crate) local_ip_address: Option<IpAddr>,
pub(crate) shard_aware_local_port_range: ShardAwarePortRange,
pub(crate) compression: Option<Compression>,
pub(crate) tcp_nodelay: bool,
pub(crate) tcp_keepalive_interval: Option<Duration>,
Expand All @@ -359,6 +362,7 @@
fn default() -> Self {
Self {
local_ip_address: None,
shard_aware_local_port_range: ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
compression: None,
tcp_nodelay: true,
tcp_keepalive_interval: None,
Expand Down Expand Up @@ -387,6 +391,7 @@
fn default() -> Self {
Self {
local_ip_address: None,
shard_aware_local_port_range: ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
compression: None,
tcp_nodelay: true,
tcp_keepalive_interval: None,
Expand Down Expand Up @@ -2022,7 +2027,8 @@
config: &HostConnectionConfig,
) -> Result<(Connection, ErrorReceiver), ConnectionError> {
// Create iterator over all possible source ports for this shard
let source_port_iter = sharder.iter_source_ports_for_shard(shard);
let source_port_iter =
sharder.iter_source_ports_for_shard_from_range(shard, &config.shard_aware_local_port_range);

for port in source_port_iter {
let connect_result = open_connection(endpoint, Some(port), config).await;
Expand Down Expand Up @@ -2282,14 +2288,14 @@

#[cfg(test)]
mod tests {
use assert_matches::assert_matches;

Check warning on line 2291 in scylla/src/network/connection.rs

View workflow job for this annotation

GitHub Actions / build

unused import: `assert_matches::assert_matches`
use scylla_cql::frame::protocol_features::{
LWT_OPTIMIZATION_META_BIT_MASK_KEY, SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION,
};
use scylla_cql::frame::types;
use scylla_proxy::{
Condition, Node, Proxy, Reaction, RequestFrame, RequestOpcode, RequestReaction,
RequestRule, ResponseFrame, ShardAwareness,

Check warning on line 2298 in scylla/src/network/connection.rs

View workflow job for this annotation

GitHub Actions / build

unused import: `ShardAwareness`
};

use tokio::select;
Expand All @@ -2298,14 +2304,14 @@
use super::{open_connection, HostConnectionConfig};
use crate::cluster::metadata::UntranslatedEndpoint;
use crate::cluster::node::ResolvedContactPoint;
use crate::statement::unprepared::Statement;

Check warning on line 2307 in scylla/src/network/connection.rs

View workflow job for this annotation

GitHub Actions / build

unused import: `crate::statement::unprepared::Statement`
use crate::test_utils::setup_tracing;
use crate::utils::test_utils::{resolve_hostname, unique_keyspace_name, PerformDDL};

Check warning on line 2309 in scylla/src/network/connection.rs

View workflow job for this annotation

GitHub Actions / build

unused imports: `PerformDDL`, `resolve_hostname`, and `unique_keyspace_name`
use futures::{StreamExt, TryStreamExt};

Check warning on line 2310 in scylla/src/network/connection.rs

View workflow job for this annotation

GitHub Actions / build

unused import: `TryStreamExt`
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

Check warning on line 2314 in scylla/src/network/connection.rs

View workflow job for this annotation

GitHub Actions / build

unused import: `std::time::Duration`

/// Tests for Connection::query_iter
/// 1. SELECT from an empty table.
Expand Down
2 changes: 1 addition & 1 deletion scylla/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub mod locator;
pub mod partitioner;
mod sharding;

pub use sharding::{Shard, ShardCount, Sharder};
pub use sharding::{InvalidShardAwarePortRange, Shard, ShardAwarePortRange, ShardCount, Sharder};
pub(crate) use sharding::{ShardInfo, ShardingError};

#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)]
Expand Down
177 changes: 147 additions & 30 deletions scylla/src/routing/sharding.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,50 @@
use std::collections::HashMap;
use std::num::NonZeroU16;
use std::ops::RangeInclusive;

use rand::Rng as _;
use thiserror::Error;

use super::Token;

/// A range of ports that can be used for shard-aware connections.
///
/// The range is inclusive and has to be a sub-range of [1024, 65535].
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct ShardAwarePortRange(RangeInclusive<u16>);

impl ShardAwarePortRange {
/// The default shard-aware local port range - [49152, 65535].
pub const EPHEMERAL_PORT_RANGE: Self = Self(49152..=65535);

/// Creates a new `ShardAwarePortRange` with the given range.
///
/// The error is returned in two cases:
/// 1. Provided range is empty (`end` < `start`).
/// 2. Provided range starts at a port lower than 1024. Ports 0-1023 are reserved and
/// should not be used by application.
#[inline]
pub fn new(range: impl Into<RangeInclusive<u16>>) -> Result<Self, InvalidShardAwarePortRange> {
let range = range.into();
if range.is_empty() || range.start() < &1024 {
return Err(InvalidShardAwarePortRange);
}
Ok(Self(range))
}
}

impl Default for ShardAwarePortRange {
fn default() -> Self {
Self::EPHEMERAL_PORT_RANGE
}
}

/// An error returned by [`ShardAwarePortRange::new()`].
#[derive(Debug, Error)]
#[error("Invalid shard-aware local port range")]
pub struct InvalidShardAwarePortRange;

pub type Shard = u32;
pub type ShardCount = NonZeroU16;

Expand Down Expand Up @@ -64,31 +103,68 @@ impl Sharder {
}

/// Randomly choose a source port `p` such that `shard == shard_of_source_port(p)`.
///
/// The port is chosen from ephemeral port range [49152, 65535].
pub fn draw_source_port_for_shard(&self, shard: Shard) -> u16 {
self.draw_source_port_for_shard_from_range(
shard,
&ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
)
}

/// Randomly choose a source port `p` such that `shard == shard_of_source_port(p)`.
///
/// The port is chosen from the provided port range.
pub(crate) fn draw_source_port_for_shard_from_range(
&self,
shard: Shard,
port_range: &ShardAwarePortRange,
) -> u16 {
assert!(shard < self.nr_shards.get() as u32);
rand::rng()
.random_range((49152 + self.nr_shards.get() - 1)..(65535 - self.nr_shards.get() + 1))
/ self.nr_shards.get()
let (range_start, range_end) = (port_range.0.start(), port_range.0.end());
rand::rng().random_range(
(range_start + self.nr_shards.get() - 1)..(range_end - self.nr_shards.get() + 1),
) / self.nr_shards.get()
* self.nr_shards.get()
+ shard as u16
}

/// Returns iterator over source ports `p` such that `shard == shard_of_source_port(p)`.
/// Starts at a random port and goes forward by `nr_shards`. After reaching maximum wraps back around.
/// Stops once all possible ports have been returned
///
/// The ports are chosen from ephemeral port range [49152, 65535].
pub fn iter_source_ports_for_shard(&self, shard: Shard) -> impl Iterator<Item = u16> {
self.iter_source_ports_for_shard_from_range(
shard,
&ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
)
}

/// Returns iterator over source ports `p` such that `shard == shard_of_source_port(p)`.
/// Starts at a random port and goes forward by `nr_shards`. After reaching maximum wraps back around.
/// Stops once all possible ports have been returned
///
/// The ports are chosen from the provided port range.
pub(crate) fn iter_source_ports_for_shard_from_range(
&self,
shard: Shard,
port_range: &ShardAwarePortRange,
) -> impl Iterator<Item = u16> {
assert!(shard < self.nr_shards.get() as u32);

let (range_start, range_end) = (port_range.0.start(), port_range.0.end());

// Randomly choose a port to start at
let starting_port = self.draw_source_port_for_shard(shard);
let starting_port = self.draw_source_port_for_shard_from_range(shard, port_range);

// Choose smallest available port number to begin at after wrapping
// apply the formula from draw_source_port_for_shard for lowest possible gen_range result
let first_valid_port = (49152 + self.nr_shards.get() - 1) / self.nr_shards.get()
let first_valid_port = (range_start + self.nr_shards.get() - 1) / self.nr_shards.get()
* self.nr_shards.get()
+ shard as u16;

let before_wrap = (starting_port..=65535).step_by(self.nr_shards.get().into());
let before_wrap = (starting_port..=*range_end).step_by(self.nr_shards.get().into());
let after_wrap = (first_valid_port..starting_port).step_by(self.nr_shards.get().into());

before_wrap.chain(after_wrap)
Expand Down Expand Up @@ -174,12 +250,30 @@ impl ShardInfo {

#[cfg(test)]
mod tests {
use crate::routing::{Shard, ShardAwarePortRange};
use crate::test_utils::setup_tracing;

use super::Token;
use super::{ShardCount, Sharder};
use std::collections::HashSet;

#[test]
fn test_shard_aware_port_range_constructor() {
setup_tracing();

// Test valid range
let range = ShardAwarePortRange::new(49152..=65535).unwrap();
assert_eq!(range, ShardAwarePortRange::EPHEMERAL_PORT_RANGE);

// Test invalid range (empty)
#[allow(clippy::reversed_empty_ranges)]
{
assert!(ShardAwarePortRange::new(49152..=49151).is_err());
}
// Test invalid range (too low)
assert!(ShardAwarePortRange::new(0..=65535).is_err());
}

#[test]
fn test_shard_of() {
setup_tracing();
Expand All @@ -202,36 +296,59 @@ mod tests {
#[test]
fn test_iter_source_ports_for_shard() {
setup_tracing();
let nr_shards = 4;
let max_port_num = 65535;
let min_port_num = (49152 + nr_shards - 1) / nr_shards * nr_shards;

let sharder = Sharder::new(ShardCount::new(nr_shards).unwrap(), 12);

// Test for each shard
for shard in 0..nr_shards {
// Find lowest port for this shard
let mut lowest_port = min_port_num;
while lowest_port % nr_shards != shard {
lowest_port += 1;
}

// Find total number of ports the iterator should return
let possible_ports_number: usize =
((max_port_num - lowest_port) / nr_shards + 1).into();
fn test_helper<F, I>(nr_shards: u16, port_range: ShardAwarePortRange, get_iter: F)
where
F: Fn(&Sharder, Shard) -> I,
I: Iterator<Item = u16>,
{
let max_port_num = port_range.0.end();
let min_port_num = (port_range.0.start() + nr_shards - 1) / nr_shards * nr_shards;

let sharder = Sharder::new(ShardCount::new(nr_shards).unwrap(), 12);

// Test for each shard
for shard in 0..nr_shards {
// Find lowest port for this shard
let mut lowest_port = min_port_num;
while lowest_port % nr_shards != shard {
lowest_port += 1;
}

// Find total number of ports the iterator should return
let possible_ports_number: usize =
((max_port_num - lowest_port) / nr_shards + 1).into();

let port_iter = get_iter(&sharder, shard.into());

let port_iter = sharder.iter_source_ports_for_shard(shard.into());
let mut returned_ports: HashSet<u16> = HashSet::new();
for port in port_iter {
assert!(!returned_ports.contains(&port)); // No port occurs two times
assert!(port % nr_shards == shard); // Each port maps to this shard

let mut returned_ports: HashSet<u16> = HashSet::new();
for port in port_iter {
assert!(!returned_ports.contains(&port)); // No port occurs two times
assert!(port % nr_shards == shard); // Each port maps to this shard
returned_ports.insert(port);
}

returned_ports.insert(port);
// Numbers of ports returned matches the expected value
assert_eq!(returned_ports.len(), possible_ports_number);
}
}

// Test of public method (with default range)
{
test_helper(
4,
ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
|sharder, shard| sharder.iter_source_ports_for_shard(shard),
);
}

// Numbers of ports returned matches the expected value
assert_eq!(returned_ports.len(), possible_ports_number);
// Test of private method with some custom port range.
{
let port_range = ShardAwarePortRange::new(21371..=42424).unwrap();
test_helper(4, port_range.clone(), |sharder, shard| {
sharder.iter_source_ports_for_shard_from_range(shard, &port_range)
});
}
}
}
Loading