diff --git a/elfo-network/src/config.rs b/elfo-network/src/config.rs index 6bc50e3..884e8a5 100644 --- a/elfo-network/src/config.rs +++ b/elfo-network/src/config.rs @@ -58,22 +58,46 @@ pub struct Config { pub idle_timeout: Duration, } +/// Preference. +#[derive(Debug, Clone, Copy, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Preference { + /// This is preferred. + Preferred, + + /// This is just supported. + Supported, +} + +/// Preferences for the compression algorithms +/// selection. +#[derive(Debug, Deserialize, Default, Clone)] +pub struct CompressionPreference { + /// LZ4 compression algorithm. + pub lz4: Option, +} + /// Compression settings. #[derive(Debug, Default, Deserialize, Clone)] pub struct CompressionConfig { - /// Compression algorithm. + /// Compression algorithms. + /// + /// Example: + /// ```toml + /// algorithms = { lz4 = "preferred" } + /// ``` + /// + /// Preferred implies supported. #[serde(default)] - pub algorithm: CompressionAlgorithm, + pub algorithms: CompressionPreference, } /// Compression algorithms. #[derive(Debug, Default, PartialEq, Eq, Deserialize, Clone)] pub enum CompressionAlgorithm { /// LZ4 with default compression level. - Lz4, - /// Compression disabled. #[default] - None, + Lz4, } fn default_ping_interval() -> Duration { diff --git a/elfo-network/src/discovery/mod.rs b/elfo-network/src/discovery/mod.rs index 05ba34e..d8bc106 100644 --- a/elfo-network/src/discovery/mod.rs +++ b/elfo-network/src/discovery/mod.rs @@ -12,7 +12,7 @@ use elfo_core::{ use crate::{ codec::format::{NetworkAddr, NetworkEnvelope, NetworkEnvelopePayload}, - config::{self, CompressionAlgorithm, Transport}, + config::{self, Transport}, node_map::{NodeInfo, NodeMap}, protocol::{internode, DataConnectionFailed, GroupInfo, HandleConnection}, socket::{self, ReadError, Socket}, @@ -135,12 +135,21 @@ impl Discovery { Ok(()) } + fn get_compression(&self) -> socket::Compression { + use socket::Algorithms as Algos; + + let mut compression = socket::Compression::empty(); + let cfg = &self.cfg.compression.algorithms; + + compression.toggle(Algos::LZ4, cfg.lz4); + + compression + } + fn get_capabilities(&self) -> socket::Capabilities { - let mut capabilities = socket::Capabilities::empty(); - if self.cfg.compression.algorithm == CompressionAlgorithm::Lz4 { - capabilities |= socket::Capabilities::LZ4; - } - capabilities + let compression = self.get_compression(); + + socket::Capabilities::new(compression) } fn on_update_config(&mut self) { diff --git a/elfo-network/src/socket/handshake.rs b/elfo-network/src/socket/handshake.rs index 6baf0e3..0f7d278 100644 --- a/elfo-network/src/socket/handshake.rs +++ b/elfo-network/src/socket/handshake.rs @@ -6,7 +6,9 @@ use tokio::io; use elfo_core::addr::{NodeLaunchId, NodeNo}; -use super::{raw, Capabilities}; +use crate::config::CompressionAlgorithm; + +use super::{raw, Algorithms, Capabilities}; const THIS_NODE_VERSION: u8 = 0; @@ -39,6 +41,18 @@ impl Handshake { } } + pub(super) fn choose_compression(&self) -> Option { + let compr = self.capabilities.compression().preferred(); + + // Actual selection logic is done in the [`Compression::intersection`], + // let's just check our preferences. + if compr.contains(Algorithms::LZ4) { + Some(CompressionAlgorithm::Lz4) + } else { + None + } + } + pub(super) fn as_bytes(&self) -> Result> { let mut buf = Cursor::new(Self::make_containing_buf()); diff --git a/elfo-network/src/socket/mod.rs b/elfo-network/src/socket/mod.rs index 51e4e41..57a2f28 100644 --- a/elfo-network/src/socket/mod.rs +++ b/elfo-network/src/socket/mod.rs @@ -13,7 +13,7 @@ use elfo_utils::likely; use self::idleness::{IdleTrack, IdleTracker}; use crate::{ codec::{decode::EnvelopeDetails, encode::EncodeError, format::NetworkEnvelope}, - config::Transport, + config::{CompressionAlgorithm, Preference, Transport}, frame::{ read::{FramedRead, FramedReadState, FramedReadStrategy}, write::{FrameState, FramedWrite, FramedWriteStrategy}, @@ -25,9 +25,192 @@ mod idleness; mod raw; bitflags::bitflags! { - #[derive(Clone, Copy)] - pub(crate) struct Capabilities: u32 { - const LZ4 = 1 << 8; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub(crate) struct Algorithms: u8 { + const LZ4 = 1; + // NB: Shift by 2: `const ZSTD = 1 << 2;`. + } +} + +impl Algorithms { + // Calculated to not require keeping in sync by hand. + // TODO: This is plain silly, `bitflags` should support that + // by default, given that they support iteration by names. + const NO_ELEMS: u32 = { + let all = Algorithms::all().bits(); + + // Number of unfilled zeros are odd, since: + // 1 (odd, LZ4) + 1 (odd, zero bit) + 1 (odd, used bit) = odd. + // This is because (ASSUMED) we have one at the lowest bit, so we just subtract + // it. + let unfilled_bits = all.leading_zeros() - 1; + + // Get number of unfilled cells. + // Cell = two flags (supported + preferred). + let unfilled_cells = unfilled_bits / 2; + // Total number of available cells. + let cells = 4; + + cells - unfilled_cells + }; +} + +// Layouts are specified from highest to lowest bits. + +/// Layout: +/// ```text +/// Bits +/// 6 2 +/// +---+-----+ +/// | R | Lz4 | +/// +---+-----+ +/// ``` +/// +/// `R` - reserved, any other mean specific compression algorithm. Layout +/// for specific compression algorithm: +/// ```text +/// Bits +/// 1 1 +/// +---+---+ +/// | S | P | +/// +---+---+ +/// ``` +/// +/// 1. `S` - the compression algorithm is supported. +/// 2. `P` - the compression algorithm is preferred, implies `S`. +#[derive(Debug, Clone, Copy)] +pub(crate) struct Compression(u8); + +impl Compression { + pub(crate) const fn empty() -> Self { + Self::new(Algorithms::empty(), Algorithms::empty()) + } + + pub(crate) const fn from_bits_truncate(v: u8) -> Self { + let supported = Algorithms::from_bits_truncate(v >> 1); + let preferred = Algorithms::from_bits_truncate(v); + + Self::new(supported, preferred) + } + + pub(crate) const fn new(supported: Algorithms, preferred: Algorithms) -> Self { + let preferred = preferred.bits(); + // Preferred implies supported. + let supported = supported.bits() | preferred; + + // 0 1 0 1 | Preferred + // 1 0 1 0 | Supported + // ------- + // 1 1 1 1 + let joined = (supported << 1) | preferred; + + Self(joined) + } + + pub(crate) fn toggle(&mut self, algos: Algorithms, pref: Option) { + let Some(pref) = pref else { + return; + }; + + let preferred = self.preferred(); + let supported = self.supported(); + + *self = match pref { + Preference::Preferred => Self::new(supported, preferred | algos), + Preference::Supported => Self::new(supported | algos, preferred), + }; + } + + pub(crate) const fn intersection(self, rhs: Self) -> Self { + let we_prefer = self.preferred(); + let we_support = self.supported(); + + let they_prefer = rhs.preferred(); + let they_support = rhs.supported(); + + // Let's see what we both support. + let both_support = we_support.intersection(they_support); + // And if we both prefer something. + let both_prefer = we_prefer.intersection(they_prefer); + + let preferred = if both_prefer.is_empty() { + // if we prefer something that is supported by us and + // the remote node, then it's a deal. + we_prefer.intersection(both_support) + } else { + // We both prefer something! + both_prefer + }; + + Self::new(both_support, preferred) + } +} + +impl Compression { + pub(crate) const fn bits(self) -> u8 { + self.0 + } + + pub(crate) const fn supported(self) -> Algorithms { + // `preferred` bits would be discarded. + Algorithms::from_bits_truncate(self.0 >> 1) + } + + pub(crate) const fn preferred(self) -> Algorithms { + // `supported` bits would be discarded. + Algorithms::from_bits_truncate(self.0) + } +} + +/// Layout: +/// ```text +/// Bits +/// 23 2 7 +/// +-----+-----+-----+ +/// | R | C | R | +/// +-----+-----+-----+ +/// ``` +/// +/// 1. C - compression +/// 2. R - reserved +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct Capabilities(u32); + +const fn compression_bits(whole: u32) -> u8 { + // 9 ones. + const MASK: u32 = (1 << 9) - 1; + let part = whole & MASK; + let compression = part >> (9 - Algorithms::NO_ELEMS * 2); + + compression as u8 +} + +impl Capabilities { + pub(crate) const fn new(compression: Compression) -> Self { + let compression = compression.bits() as u32; + let joined = compression << 7; + + Self(joined) + } + + pub(crate) const fn from_bits_truncate(bits: u32) -> Self { + let compression = Compression::from_bits_truncate(compression_bits(bits)); + Self::new(compression) + } + + pub(crate) const fn intersection(self, rhs: Self) -> Self { + let compr = self.compression().intersection(rhs.compression()); + Self::new(compr) + } +} + +impl Capabilities { + pub(crate) const fn compression(self) -> Compression { + Compression(compression_bits(self.0)) + } + + pub(crate) const fn bits(self) -> u32 { + self.0 } } @@ -49,13 +232,17 @@ pub(crate) struct Peer { impl Socket { fn new(raw: raw::Socket, handshake: handshake::Handshake) -> Self { // TODO: maybe do something with the version. + let algorithm = handshake.choose_compression(); - let (framed_read, framed_write) = if handshake.capabilities.contains(Capabilities::LZ4) { - (FramedRead::lz4(), FramedWrite::lz4(None)) + let (framed_read, framed_write) = if let Some(algo) = algorithm { + use CompressionAlgorithm as Algo; + + match algo { + Algo::Lz4 => (FramedRead::lz4(), FramedWrite::lz4(None)), + } } else { (FramedRead::none(), FramedWrite::none(None)) }; - let (idle_tracker, idle_track) = IdleTracker::new(); Self { @@ -301,10 +488,48 @@ mod tests { use super::*; + const EMPTY_CAPS: Capabilities = Capabilities::new(Compression::empty()); + #[message] #[derive(PartialEq)] struct TestSocketMessage(String); + #[test] + fn capabilities_format_is_compatible_with_020alpha17() { + let caps = Capabilities::new(Compression::new(Algorithms::LZ4, Algorithms::empty())); + let lz4_bit = caps.bits() & (1 << 8); + + assert_eq!(lz4_bit, 1 << 8); + } + + #[test] + fn compression_capabilities_encoded_right_way() { + #[track_caller] + fn case(create: (Algorithms, Algorithms), expect: (Algorithms, Algorithms)) { + let caps = Capabilities::new(Compression::new(create.0, create.1)); + let compr = caps.compression(); + + assert_eq!(compr.supported(), expect.0); + assert_eq!(compr.preferred(), expect.1); + + // Just in case we should decode same caps. + + let bits = caps.bits(); + let same_caps = Capabilities::from_bits_truncate(bits); + + assert_eq!(caps, same_caps); + } + + // Supported does not implies preferred. + case((Algorithms::LZ4, Algorithms::empty()), (Algorithms::LZ4, Algorithms::empty())); + + // Preferred implies supported. + case((Algorithms::empty(), Algorithms::LZ4), (Algorithms::LZ4, Algorithms::LZ4)); + + // Nothing ever happens. + case((Algorithms::empty(), Algorithms::empty()), (Algorithms::empty(), Algorithms::empty())); + } + fn feed_frame(client_socket: &mut Socket, envelope: &NetworkEnvelope) { for _ in 0..100 { client_socket @@ -394,23 +619,23 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tracing_test::traced_test] async fn tcp_read_write_no_framing() { - ensure_read_write("tcp://127.0.0.1:9200", Capabilities::empty()).await; + ensure_read_write("tcp://127.0.0.1:9200", EMPTY_CAPS).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tracing_test::traced_test] async fn tcp_read_write_lz4() { - ensure_read_write("tcp://127.0.0.1:9201", Capabilities::LZ4).await; + ensure_read_write( + "tcp://127.0.0.1:9201", + Capabilities::new(Compression::new(Algorithms::empty(), Algorithms::LZ4)), + ) + .await; } #[cfg(unix)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tracing_test::traced_test] async fn uds_read_write_no_framing() { - ensure_read_write( - "uds://test_uds_read_write_no_framing.socket", - Capabilities::empty(), - ) - .await; + ensure_read_write("uds://test_uds_read_write_no_framing.socket", EMPTY_CAPS).await; } }