diff --git a/config/topology.yaml b/config/topology.yaml index c0dce7ebf..04c45c60f 100644 --- a/config/topology.yaml +++ b/config/topology.yaml @@ -7,8 +7,12 @@ sources: connection_limit: 1000 chain_config: redis_chain: - - RedisCluster: - first_contact_points: ["redis://127.0.0.1:2220/", "redis://127.0.0.1:2221/", "redis://127.0.0.1:2222/", "redis://127.0.0.1:2223/", "redis://127.0.0.1:2224/", "redis://127.0.0.1:2225/"] + - PoolConnections: + name: "RedisCluster-subchain" + parallelism: 128 + chain: + - RedisCluster: + first_contact_points: [ "redis://127.0.0.1:2220/", "redis://127.0.0.1:2221/", "redis://127.0.0.1:2222/", "redis://127.0.0.1:2223/", "redis://127.0.0.1:2224/", "redis://127.0.0.1:2225/" ] named_topics: - testtopic source_to_chain_mapping: diff --git a/examples/redis-cluster/config.yaml b/examples/redis-cluster/config.yaml index 8ea87542d..bcf70835b 100644 --- a/examples/redis-cluster/config.yaml +++ b/examples/redis-cluster/config.yaml @@ -6,8 +6,12 @@ sources: listen_addr: "127.0.0.1:6379" chain_config: redis_chain: - - RedisCluster: - first_contact_points: ["redis://127.0.0.1:2220/", "redis://127.0.0.1:2221/", "redis://127.0.0.1:2222/", "redis://127.0.0.1:2223/", "redis://127.0.0.1:2224/", "redis://127.0.0.1:2225/"] + - PoolConnections: + name: "RedisCluster-subchain" + parallelism: 3 + chain: + - RedisCluster: + first_contact_points: ["redis://127.0.0.1:2220/", "redis://127.0.0.1:2221/", "redis://127.0.0.1:2222/", "redis://127.0.0.1:2223/", "redis://127.0.0.1:2224/", "redis://127.0.0.1:2225/"] named_topics: - testtopic source_to_chain_mapping: diff --git a/src/transforms/chain.rs b/src/transforms/chain.rs index 40ac3dd5b..c44005b47 100644 --- a/src/transforms/chain.rs +++ b/src/transforms/chain.rs @@ -1,18 +1,21 @@ use crate::config::topology::ChannelMessage; use crate::error::ChainResponse; use crate::transforms::{Transforms, Wrapper}; -use anyhow::{Result}; +use anyhow::Result; use bytes::Bytes; use evmap::ReadHandleFactory; use 
futures::FutureExt; use itertools::Itertools; use metrics::{counter, timing}; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::oneshot::Receiver as OneReceiver; +use tokio::sync::Mutex; +use tokio::time::timeout; use tokio::time::Duration; use tokio::time::Instant; -use tokio::time::{timeout}; use tracing::{info, trace}; type InnerChain = Vec; @@ -44,6 +47,8 @@ impl Clone for TransformChain { #[derive(Debug, Clone)] pub struct BufferedChain { send_handle: Sender, + #[cfg(test)] + pub count: Arc>, } impl BufferedChain { @@ -78,15 +83,24 @@ impl TransformChain { ) -> BufferedChain { let (tx, mut rx) = tokio::sync::mpsc::channel::(buffer_size); + // If this is not a test, this should get removed by the compiler + let mut count_outer: Arc> = Arc::new(Mutex::new(0 as usize)); + let mut count = count_outer.clone(); + // Even though we don't keep the join handle, this thread will wrap up once all corresponding senders have been dropped. 
let _jh = tokio::spawn(async move { let mut chain = self; + while let Some(ChannelMessage { return_chan, messages, }) = rx.recv().await { let name = chain.name.clone(); + if cfg!(test) { + let mut count = count.lock().await; + *count += 1; + } let future = async { match timeout_millis { None => Ok(chain.process_request(Wrapper::new(messages), name).await), @@ -121,10 +135,13 @@ impl TransformChain { } } } - }); - BufferedChain { send_handle: tx } + BufferedChain { + send_handle: tx, + #[cfg(test)] + count: count_outer, + } } pub fn new_no_shared_state(transform_list: Vec, name: String) -> Self {
diff --git a/src/transforms/load_balance.rs b/src/transforms/load_balance.rs
new file mode 100644
index 000000000..e747860c8
--- /dev/null
+++ b/src/transforms/load_balance.rs
@@ -0,0 +1,202 @@
+use crate::config::topology::TopicHolder;
+use crate::error::ChainResponse;
+use crate::transforms::chain::{BufferedChain, TransformChain};
+use crate::transforms::{
+    build_chain_from_config, Transform, Transforms, TransformsConfig, TransformsFromConfig, Wrapper,
+};
+use anyhow::Result;
+use async_trait::async_trait;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+use tokio::sync::Mutex;
+
+/// Capacity of the request buffer placed in front of every pooled sub-chain.
+const SUBCHAIN_BUFFER_SIZE: usize = 5;
+
+/// Configuration of the `PoolConnections` transform.
+#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
+pub struct ConnectionBalanceAndPoolConfig {
+    /// Name given to the sub-chain built from `chain`.
+    pub name: String,
+    /// Upper bound on the number of buffered sub-chains in the pool; must be at least 1.
+    pub parallelism: usize,
+    /// Transforms that make up the pooled sub-chain.
+    pub chain: Vec<TransformsConfig>,
+}
+
+#[async_trait]
+impl TransformsFromConfig for ConnectionBalanceAndPoolConfig {
+    /// Builds the sub-chain template and wraps it in a transform with an empty pool.
+    ///
+    /// # Errors
+    /// Fails when `parallelism` is zero or when the sub-chain cannot be built.
+    async fn get_source(&self, topics: &TopicHolder) -> Result<Transforms> {
+        anyhow::ensure!(
+            self.parallelism > 0,
+            "PoolConnections `parallelism` must be at least 1"
+        );
+        let chain = build_chain_from_config(self.name.clone(), &self.chain, topics).await?;
+
+        Ok(Transforms::PoolConnections(ConnectionBalanceAndPool {
+            name: "PoolConnections",
+            active_connection: None,
+            parallelism: self.parallelism,
+            other_connections: Arc::new(Mutex::new(Vec::with_capacity(self.parallelism))),
+            chain_to_clone: chain,
+        }))
+    }
+}
+
+/// Transform that forwards requests to one of up to `parallelism` buffered copies of a sub-chain.
+#[derive(Debug)]
+pub struct ConnectionBalanceAndPool {
+    pub name: &'static str,
+    /// Buffered chain this instance forwards to; selected on the first request.
+    pub active_connection: Option<BufferedChain>,
+    pub parallelism: usize,
+    /// Pool of buffered chains shared by every clone of this transform.
+    pub other_connections: Arc<Mutex<Vec<BufferedChain>>>,
+    /// Template that is cloned whenever the pool needs a new buffered chain.
+    pub chain_to_clone: TransformChain,
+}
+
+impl Clone for ConnectionBalanceAndPool {
+    /// Clones share the pool but start without an active connection, so every clone selects
+    /// its own buffered chain on its first request.
+    fn clone(&self) -> Self {
+        ConnectionBalanceAndPool {
+            name: self.name,
+            active_connection: None,
+            parallelism: self.parallelism,
+            other_connections: Arc::clone(&self.other_connections),
+            chain_to_clone: self.chain_to_clone.clone(),
+        }
+    }
+}
+
+#[async_trait]
+impl Transform for ConnectionBalanceAndPool {
+    /// Forwards `qd` to this instance's buffered chain, selecting one from the shared pool
+    /// (or building a new one while the pool is below its limit) on the first call.
+    async fn transform<'a>(&'a mut self, qd: Wrapper<'a>) -> ChainResponse {
+        let chain = match self.active_connection.take() {
+            Some(chain) => chain,
+            None => {
+                // A limit of zero would leave nothing to reuse, so always allow one chain.
+                let limit = self.parallelism.max(1);
+                let mut pool = self.other_connections.lock().await;
+                if pool.len() < limit {
+                    let fresh = self
+                        .chain_to_clone
+                        .clone()
+                        .build_buffered_chain(SUBCHAIN_BUFFER_SIZE, None);
+                    pool.push(fresh.clone());
+                    fresh
+                } else {
+                    // Take the first existing chain and move it to the back of the list so
+                    // the next caller gets a different one.
+                    let reused = pool.remove(0);
+                    pool.push(reused.clone());
+                    reused
+                }
+            }
+        };
+        self.active_connection
+            .get_or_insert(chain)
+            .process_request(qd, "Connection Balance and Pooler".to_string())
+            .await
+    }
+
+    fn get_name(&self) -> &'static str {
+        self.name
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::config::topology::TopicHolder;
+    use crate::message::Messages;
+    use crate::transforms::chain::TransformChain;
+    use crate::transforms::load_balance::ConnectionBalanceAndPool;
+    use crate::transforms::test_transforms::ReturnerTransform;
+    use crate::transforms::{Transforms, Wrapper};
+    use anyhow::Result;
+    use std::sync::Arc;
+
+    /// Builds a test chain whose only transform is a pool with the given `parallelism`.
+    fn pooled_chain(topic_holder: &TopicHolder, parallelism: usize) -> TransformChain {
+        let transform = Transforms::PoolConnections(ConnectionBalanceAndPool {
+            name: "",
+            active_connection: None,
+            parallelism,
+            other_connections: Arc::new(Default::default()),
+            chain_to_clone: TransformChain::new(
+                vec![Transforms::RepeatMessage(Box::new(ReturnerTransform {
+                    message: Messages::new(),
+                    ok: true,
+                }))],
+                "child_test".to_string(),
+                topic_holder.global_map_handle.clone(),
+                topic_holder.global_tx.clone(),
+            ),
+        });
+
+        TransformChain::new(
+            vec![transform],
+            "test".to_string(),
+            topic_holder.global_map_handle.clone(),
+            topic_holder.global_tx.clone(),
+        )
+    }
+
+    #[tokio::test(threaded_scheduler)]
+    pub async fn test_balance() -> Result<()> {
+        let topic_holder = TopicHolder::get_test_holder();
+        let mut chain = pooled_chain(&topic_holder, 3);
+
+        for _ in 0..90 {
+            let r = chain
+                .clone()
+                .process_request(Wrapper::new(Messages::new()), "test_client".to_string())
+                .await;
+            assert!(r.is_ok());
+        }
+
+        match chain.chain.remove(0) {
+            Transforms::PoolConnections(p) => {
+                let guard = p.other_connections.lock().await;
+                assert_eq!(guard.len(), 3);
+                for bc in guard.iter() {
+                    let guard = bc.count.lock().await;
+                    assert_eq!(*guard, 30);
+                }
+            }
+            _ => panic!("whoops"),
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test(threaded_scheduler)]
+    pub async fn test_zero_parallelism_does_not_panic() -> Result<()> {
+        let topic_holder = TopicHolder::get_test_holder();
+        let mut chain = pooled_chain(&topic_holder, 0);
+
+        for _ in 0..2 {
+            let r = chain
+                .clone()
+                .process_request(Wrapper::new(Messages::new()), "test_client".to_string())
+                .await;
+            assert!(r.is_ok());
+        }
+
+        match chain.chain.remove(0) {
+            Transforms::PoolConnections(p) => {
+                assert_eq!(p.other_connections.lock().await.len(), 1);
+            }
+            _ => panic!("whoops"),
+        }
+
+        Ok(())
+    }
+}
diff --git a/src/transforms/mod.rs b/src/transforms/mod.rs index 8f656e77c..7cd7fe8be 100644 --- a/src/transforms/mod.rs +++ b/src/transforms/mod.rs @@ -17,6 +17,7 @@ use crate::transforms::distributed::tunable_consistency_scatter::{ TunableConsistency, TunableConsistencyConfig, }; use crate::transforms::kafka_destination::{KafkaConfig, KafkaDestination}; +use crate::transforms::load_balance::{ConnectionBalanceAndPool, ConnectionBalanceAndPoolConfig}; use crate::transforms::lua::LuaFilterTransform; use crate::transforms::mpsc::{Buffer, BufferConfig, Tee, TeeConfig}; use crate::transforms::null::Null; @@ -41,6 +42,7 @@ pub mod cassandra_codec_destination; pub mod chain; pub mod distributed; pub mod kafka_destination; +pub mod load_balance; pub mod lua; pub mod mpsc; pub mod noop; @@ -77,6 +79,7 @@ pub enum Transforms { Printer(Printer), SequentialMap(SequentialMap), ParallelMap(ParallelMap), + PoolConnections(ConnectionBalanceAndPool), } impl Debug for Transforms { @@ -108,6 +111,7 @@ impl Transform for Transforms { Transforms::RedisCluster(r) => r.transform(qd).await, Transforms::SequentialMap(s) =>
s.transform(qd).await, Transforms::ParallelMap(s) => s.transform(qd).await, + Transforms::PoolConnections(s) => s.transform(qd).await, } } @@ -132,6 +136,7 @@ impl Transform for Transforms { Transforms::RedisCluster(r) => r.get_name(), Transforms::SequentialMap(s) => s.get_name(), Transforms::ParallelMap(s) => s.get_name(), + Transforms::PoolConnections(s) => s.get_name(), } } @@ -156,6 +161,7 @@ impl Transform for Transforms { Transforms::RedisCluster(r) => r.prep_transform_chain(t).await, Transforms::SequentialMap(s) => s.prep_transform_chain(t).await, Transforms::ParallelMap(s) => s.prep_transform_chain(t).await, + Transforms::PoolConnections(s) => s.prep_transform_chain(t).await, } } } @@ -176,6 +182,7 @@ pub enum TransformsConfig { Printer, SequentialMap(SequentialMapConfig), ParallelMap(ParallelMapConfig), + PoolConnections(ConnectionBalanceAndPoolConfig), } impl TransformsConfig { @@ -197,6 +204,7 @@ impl TransformsConfig { TransformsConfig::RedisCluster(r) => r.get_source(topics).await, TransformsConfig::SequentialMap(s) => s.get_source(topics).await, TransformsConfig::ParallelMap(s) => s.get_source(topics).await, + TransformsConfig::PoolConnections(s) => s.get_source(topics).await, } } }