diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs index 2211e85d3..f09ee1a7f 100644 --- a/shotover/src/transforms/kafka/sink_cluster/mod.rs +++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs @@ -17,7 +17,7 @@ use kafka_protocol::messages::{ TopicName, }; use kafka_protocol::protocol::{Builder, StrBytes}; -use node::{KafkaAddress, KafkaNode}; +use node::{ConnectionFactory, KafkaAddress, KafkaNode}; use rand::rngs::SmallRng; use rand::seq::{IteratorRandom, SliceRandom}; use rand::SeedableRng; @@ -118,7 +118,6 @@ impl TransformBuilder for KafkaSinkClusterBuilder { first_contact_points: self.first_contact_points.clone(), shotover_nodes: self.shotover_nodes.clone(), pushed_messages_tx: None, - connect_timeout: self.connect_timeout, read_timeout: self.read_timeout, nodes: vec![], nodes_shared: self.nodes_shared.clone(), @@ -126,7 +125,7 @@ impl TransformBuilder for KafkaSinkClusterBuilder { group_to_coordinator_broker: self.group_to_coordinator_broker.clone(), topics: self.topics.clone(), rng: SmallRng::from_rng(rand::thread_rng()).unwrap(), - tls: self.tls.clone(), + connection_factory: ConnectionFactory::new(self.tls.clone(), self.connect_timeout), }) } @@ -165,7 +164,6 @@ pub struct KafkaSinkCluster { first_contact_points: Vec, shotover_nodes: Vec, pushed_messages_tx: Option>, - connect_timeout: Duration, read_timeout: Option, nodes: Vec, nodes_shared: Arc>>, @@ -173,7 +171,7 @@ pub struct KafkaSinkCluster { group_to_coordinator_broker: Arc>, topics: Arc>, rng: SmallRng, - tls: Option, + connection_factory: ConnectionFactory, } #[async_trait] @@ -345,9 +343,7 @@ impl KafkaSinkCluster { for node in &mut self.nodes { if node.broker_id == partition.leader_id { connection = Some( - node.get_connection(self.connect_timeout, &self.tls) - .await? - .clone(), + node.get_connection(&self.connection_factory).await?.clone(), ); } } @@ -359,7 +355,7 @@ impl KafkaSinkCluster { self.nodes .choose_mut(&mut self.rng) .unwrap() - .get_connection(self.connect_timeout, &self.tls) + .get_connection(&self.connection_factory) .await? .clone() } @@ -399,7 +395,7 @@ impl KafkaSinkCluster { .filter(|node| partition.replica_nodes.contains(&node.broker_id)) .choose(&mut self.rng) .unwrap() - .get_connection(self.connect_timeout, &self.tls) + .get_connection(&self.connection_factory) .await? .clone() } else { @@ -408,7 +404,7 @@ impl KafkaSinkCluster { self.nodes .choose_mut(&mut self.rng) .unwrap() - .get_connection(self.connect_timeout, &self.tls) + .get_connection(&self.connection_factory) .await? .clone() }; @@ -472,7 +468,7 @@ impl KafkaSinkCluster { .nodes .choose_mut(&mut self.rng) .unwrap() - .get_connection(self.connect_timeout, &self.tls) + .get_connection(&self.connection_factory) .await?; let (tx, rx) = oneshot::channel(); connection @@ -509,7 +505,7 @@ impl KafkaSinkCluster { .nodes .choose_mut(&mut self.rng) .unwrap() - .get_connection(self.connect_timeout, &self.tls) + .get_connection(&self.connection_factory) .await?; let (tx, rx) = oneshot::channel(); connection @@ -563,7 +559,7 @@ impl KafkaSinkCluster { .nodes .choose_mut(&mut self.rng) .unwrap() - .get_connection(self.connect_timeout, &self.tls) + .get_connection(&self.connection_factory) .await?; let (tx, rx) = oneshot::channel(); connection @@ -672,15 +668,13 @@ impl KafkaSinkCluster { let connection = if let Some(node) = self.nodes.iter_mut().find(|x| x.broker_id == *broker_id) { - node.get_connection(self.connect_timeout, &self.tls) - .await? - .clone() + node.get_connection(&self.connection_factory).await?.clone() } else { tracing::warn!("no known broker with id {broker_id:?}, routing message to a random node so that a NOT_CONTROLLER or similar error is returned to the client"); self.nodes .choose_mut(&mut self.rng) .unwrap() - .get_connection(self.connect_timeout, &self.tls) + .get_connection(&self.connection_factory) .await? .clone() }; @@ -704,11 +698,7 @@ impl KafkaSinkCluster { for node in &mut self.nodes { if let Some(broker_id) = self.group_to_coordinator_broker.get(&group_id) { if node.broker_id == *broker_id { - connection = Some( - node.get_connection(self.connect_timeout, &self.tls) - .await? - .clone(), - ); + connection = Some(node.get_connection(&self.connection_factory).await?.clone()); } } } @@ -719,7 +709,7 @@ impl KafkaSinkCluster { self.nodes .choose_mut(&mut self.rng) .unwrap() - .get_connection(self.connect_timeout, &self.tls) + .get_connection(&self.connection_factory) .await? .clone() } diff --git a/shotover/src/transforms/kafka/sink_cluster/node.rs b/shotover/src/transforms/kafka/sink_cluster/node.rs index 1cfb3c46f..afe059525 100644 --- a/shotover/src/transforms/kafka/sink_cluster/node.rs +++ b/shotover/src/transforms/kafka/sink_cluster/node.rs @@ -8,6 +8,36 @@ use kafka_protocol::protocol::StrBytes; use std::time::Duration; use tokio::io::split; +pub struct ConnectionFactory { + tls: Option, + connect_timeout: Duration, +} + +impl ConnectionFactory { + pub fn new(tls: Option, connect_timeout: Duration) -> Self { + ConnectionFactory { + tls, + connect_timeout, + } + } + + pub async fn create_connection(&self, kafka_address: &KafkaAddress) -> Result { + let codec = KafkaCodecBuilder::new(Direction::Sink, "KafkaSinkCluster".to_owned()); + let address = (kafka_address.host.to_string(), kafka_address.port as u16); + if let Some(tls) = self.tls.as_ref() { + let tls_stream = tls.connect(self.connect_timeout, address).await?; + let (rx, tx) = split(tls_stream); + let connection = spawn_read_write_tasks(&codec, rx, tx); + Ok(connection) + } else { + let tcp_stream = tcp::tcp_stream(self.connect_timeout, address).await?; + let (rx, tx) = tcp_stream.into_split(); + let connection = spawn_read_write_tasks(&codec, rx, tx); + Ok(connection) + } + } +} + #[derive(Clone, PartialEq)] pub struct KafkaAddress { pub host: StrBytes, @@ -55,24 +85,14 @@ impl KafkaNode { pub async fn get_connection( &mut self, - connect_timeout: Duration, - tls: &Option, + connection_factory: &ConnectionFactory, ) -> Result<&Connection> { if self.connection.is_none() { - let codec = KafkaCodecBuilder::new(Direction::Sink, "KafkaSinkCluster".to_owned()); - let address = ( - self.kafka_address.host.to_string(), - self.kafka_address.port as u16, + self.connection = Some( + connection_factory + .create_connection(&self.kafka_address) + .await?, ); - if let Some(tls) = tls.as_ref() { - let tls_stream = tls.connect(connect_timeout, address).await?; - let (rx, tx) = split(tls_stream); - self.connection = Some(spawn_read_write_tasks(&codec, rx, tx)); - } else { - let tcp_stream = tcp::tcp_stream(connect_timeout, address).await?; - let (rx, tx) = tcp_stream.into_split(); - self.connection = Some(spawn_read_write_tasks(&codec, rx, tx)); - } } Ok(self.connection.as_ref().unwrap()) }