diff --git a/shotover/src/connection.rs b/shotover/src/connection.rs index c27fc0667..d43e467c5 100644 --- a/shotover/src/connection.rs +++ b/shotover/src/connection.rs @@ -1,6 +1,6 @@ -//! This is the one true connection implementation that all other transforms + server.rs should be ported to. +//! All Sink transforms use SinkConnection for their outgoing connections. -use crate::codec::{CodecBuilder, CodecReadError, CodecWriteError, Direction}; +use crate::codec::{CodecBuilder, CodecReadError, CodecWriteError}; use crate::frame::Frame; use crate::message::{Message, MessageId, Messages}; use crate::tcp; @@ -18,22 +18,21 @@ use tokio_util::codec::{FramedRead, FramedWrite}; use tracing::error; use tracing::Instrument; -pub struct Connection { +pub struct SinkConnection { in_rx: mpsc::Receiver>, out_tx: mpsc::UnboundedSender>, connection_closed_rx: mpsc::Receiver, error: Option, - dummy_response_inserter: Option, + dummy_response_inserter: DummyResponseInserter, } -impl Connection { +impl SinkConnection { pub async fn new( host: A, codec_builder: C, tls: &Option, connect_timeout: Duration, - force_run_chain: Option>, - direction: Direction, + force_run_chain: Arc, ) -> anyhow::Result { let destination = tokio::net::lookup_host(&host).await?.next().unwrap(); let (in_tx, in_rx) = mpsc::channel::(10_000); @@ -68,15 +67,12 @@ impl Connection { ); } - let dummy_response_inserter = match direction { - Direction::Source => None, - Direction::Sink => Some(DummyResponseInserter { - dummy_requests: vec![], - pending_requests_count: 0, - }), + let dummy_response_inserter = DummyResponseInserter { + dummy_requests: vec![], + pending_requests_count: 0, }; - Ok(Connection { + Ok(SinkConnection { in_rx, out_tx, connection_closed_rx, @@ -93,9 +89,8 @@ impl Connection { /// Send messages. /// If there is a problem with the connection an error is returned. pub fn send(&mut self, mut messages: Vec) -> Result<(), ConnectionError> { - if let Some(dummy_response_inserter) = &mut self.dummy_response_inserter { - dummy_response_inserter.process_requests(&mut messages); - } + self.dummy_response_inserter.process_requests(&mut messages); + if let Some(error) = &self.error { Err(error.clone()) } else { @@ -110,20 +105,19 @@ impl Connection { Err(error.clone()) } else { // first process any immediately pending dummy responses - if let Some(dummy_response_inserter) = &mut self.dummy_response_inserter { - // ensure we include any received messages so we dont leave them hanging after using up a force_run_chain. - let mut messages = self.in_rx.try_recv().unwrap_or_default(); - dummy_response_inserter.process_responses(&mut messages); - if !messages.is_empty() { - return Ok(messages); - } + + // ensure we include any received messages so we dont leave them hanging after using up a force_run_chain. 
+ let mut messages = self.in_rx.try_recv().unwrap_or_default(); + self.dummy_response_inserter + .process_responses(&mut messages); + if !messages.is_empty() { + return Ok(messages); } match self.in_rx.recv().await { Some(mut messages) => { - if let Some(dummy_response_inserter) = &mut self.dummy_response_inserter { - dummy_response_inserter.process_responses(&mut messages); - } + self.dummy_response_inserter + .process_responses(&mut messages); Ok(messages) } None => Err(self.set_get_error()), @@ -139,9 +133,8 @@ impl Connection { } else { match self.in_rx.try_recv() { Ok(mut messages) => { - if let Some(dummy_response_inserter) = &mut self.dummy_response_inserter { - dummy_response_inserter.process_responses(&mut messages); - } + self.dummy_response_inserter + .process_responses(&mut messages); Ok(messages) } Err(TryRecvError::Disconnected) => Err(self.set_get_error()), @@ -179,7 +172,7 @@ fn spawn_read_write_tasks< in_tx: mpsc::Sender, out_rx: UnboundedReceiver, out_tx: UnboundedSender, - force_run_chain: Option>, + force_run_chain: Arc, connection_closed_tx: mpsc::Sender, ) { let (decoder, encoder) = codec.build(); @@ -243,7 +236,7 @@ async fn reader_task::Decoder>, in_tx: mpsc::Sender, out_tx: UnboundedSender, - force_run_chain: Option>, + force_run_chain: Arc, ) -> Result<(), ConnectionError> { loop { tokio::select! { @@ -260,9 +253,8 @@ async fn reader_task { if let Err(err) = out_tx.send(messages) { diff --git a/shotover/src/server.rs b/shotover/src/server.rs index 1b3984b65..455f24c13 100644 --- a/shotover/src/server.rs +++ b/shotover/src/server.rs @@ -435,7 +435,6 @@ pub fn spawn_read_write_tasks< in_tx: mpsc::Sender, mut out_rx: UnboundedReceiver, out_tx: UnboundedSender, - force_run_chain: Option>, ) { let (decoder, encoder) = codec.build(); let mut reader = FramedRead::new(rx, decoder); @@ -468,9 +467,6 @@ pub fn spawn_read_write_tasks< // main task has shutdown, this task is no longer needed return; } - if let Some(force_run_chain) = force_run_chain.as_ref() { - force_run_chain.notify_one(); - } } Err(CodecReadError::RespondAndThenCloseConnection(messages)) => { if let Err(err) = out_tx.send(messages) { @@ -618,7 +614,6 @@ impl Handler { in_tx, out_rx, out_tx.clone(), - None, ); } else { let (rx, tx) = stream.into_split(); @@ -629,7 +624,6 @@ impl Handler { in_tx, out_rx, out_tx.clone(), - None, ); }; } diff --git a/shotover/src/transforms/cassandra/sink_single.rs b/shotover/src/transforms/cassandra/sink_single.rs index 997cc0309..c4eba329d 100644 --- a/shotover/src/transforms/cassandra/sink_single.rs +++ b/shotover/src/transforms/cassandra/sink_single.rs @@ -1,5 +1,5 @@ use crate::codec::{cassandra::CassandraCodecBuilder, CodecBuilder, Direction}; -use crate::connection::Connection; +use crate::connection::SinkConnection; use crate::frame::cassandra::CassandraMetadata; use crate::message::{Messages, Metadata}; use crate::tls::{TlsConnector, TlsConnectorConfig}; @@ -111,7 +111,7 @@ impl TransformBuilder for CassandraSinkSingleBuilder { pub struct CassandraSinkSingle { version: Option, address: String, - connection: Option, + connection: Option, failed_requests: Counter, tls: Option, connect_timeout: Duration, @@ -144,13 +144,12 @@ impl CassandraSinkSingle { if self.connection.is_none() { trace!("creating outbound connection {:?}", self.address); self.connection = Some( - Connection::new( + SinkConnection::new( self.address.clone(), self.codec_builder.clone(), &self.tls, self.connect_timeout, - Some(self.force_run_chain.clone()), - Direction::Sink, + 
self.force_run_chain.clone(), ) .await?, ); diff --git a/shotover/src/transforms/kafka/common.rs b/shotover/src/transforms/kafka/common.rs deleted file mode 100644 index 34b9e1e3d..000000000 --- a/shotover/src/transforms/kafka/common.rs +++ /dev/null @@ -1,24 +0,0 @@ -use crate::{frame::Frame, message::Message, transforms::util::Response}; -use kafka_protocol::messages::ProduceRequest; -use tokio::sync::oneshot; - -/// when a produce request has acks set to 0, the kafka instance will return no response. -/// In order to maintain shotover transform invariants we need to return a dummy response instead. -pub fn produce_channel( - produce: &ProduceRequest, -) -> ( - Option>, - oneshot::Receiver, -) { - let (tx, rx) = oneshot::channel(); - let return_chan = if produce.acks == 0 { - tx.send(Response { - response: Ok(Message::from_frame(Frame::Dummy)), - }) - .unwrap(); - None - } else { - Some(tx) - }; - (return_chan, rx) -} diff --git a/shotover/src/transforms/kafka/mod.rs b/shotover/src/transforms/kafka/mod.rs index e3c34d5ad..75b708060 100644 --- a/shotover/src/transforms/kafka/mod.rs +++ b/shotover/src/transforms/kafka/mod.rs @@ -1,3 +1,2 @@ -mod common; pub mod sink_cluster; pub mod sink_single; diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs index a8a76b2ab..f31d8cd0a 100644 --- a/shotover/src/transforms/kafka/sink_cluster/mod.rs +++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs @@ -1,9 +1,8 @@ -use super::common::produce_channel; +use crate::connection::SinkConnection; use crate::frame::kafka::{KafkaFrame, RequestBody, ResponseBody}; use crate::frame::Frame; use crate::message::{Message, MessageIdMap, Messages}; use crate::tls::{TlsConnector, TlsConnectorConfig}; -use crate::transforms::util::{Request, Response}; use crate::transforms::{Transform, TransformBuilder, TransformContextBuilder, Wrapper}; use crate::transforms::{TransformConfig, TransformContextConfig}; use anyhow::{anyhow, Result}; @@ -23,16 +22,18 @@ use rand::rngs::SmallRng; use rand::seq::{IteratorRandom, SliceRandom}; use rand::SeedableRng; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::hash::Hasher; use std::net::SocketAddr; use std::sync::atomic::AtomicI64; use std::sync::Arc; use std::time::Duration; -use tokio::sync::{oneshot, RwLock}; +use tokio::sync::RwLock; use tokio::time::timeout; use uuid::Uuid; +mod node; + #[derive(thiserror::Error, Debug)] enum FindCoordinatorError { #[error("Coordinator not available")] @@ -41,8 +42,6 @@ enum FindCoordinatorError { Unrecoverable(#[from] anyhow::Error), } -mod node; - #[derive(Serialize, Deserialize, Debug)] #[serde(deny_unknown_fields)] pub struct KafkaSinkClusterConfig { @@ -149,7 +148,7 @@ impl KafkaSinkClusterBuilder { } impl TransformBuilder for KafkaSinkClusterBuilder { - fn build(&self, _transform_context: TransformContextBuilder) -> Box { + fn build(&self, transform_context: TransformContextBuilder) -> Box { Box::new(KafkaSinkCluster { first_contact_points: self.first_contact_points.clone(), shotover_nodes: self.shotover_nodes.clone(), @@ -162,10 +161,16 @@ impl TransformBuilder for KafkaSinkClusterBuilder { topic_by_id: self.topic_by_id.clone(), rng: SmallRng::from_rng(rand::thread_rng()).unwrap(), auth_complete: false, - connection_factory: ConnectionFactory::new(self.tls.clone(), self.connect_timeout), + connection_factory: ConnectionFactory::new( + self.tls.clone(), + self.connect_timeout, + 
transform_context.force_run_chain, + ), first_contact_node: None, + control_connection: None, fetch_session_id_to_broker: HashMap::new(), fetch_request_destinations: Default::default(), + pending_requests: Default::default(), }) } @@ -214,11 +219,37 @@ pub struct KafkaSinkCluster { auth_complete: bool, connection_factory: ConnectionFactory, first_contact_node: Option, + control_connection: Option, // its not clear from the docs if this cache needs to be accessed cross connection: // https://cwiki.apache.org/confluence/display/KAFKA/KIP-227%3A+Introduce+Incremental+FetchRequests+to+Increase+Partition+Scalability fetch_session_id_to_broker: HashMap, // for use with fetch_session_id_to_broker fetch_request_destinations: MessageIdMap, + /// Maintains the state of each request/response pair. + /// Ordering must be maintained to ensure responses match up with their request. + pending_requests: VecDeque, +} + +/// State of a Request/Response is maintained by this enum. +/// The state progresses from Routed -> Sent -> Received +#[derive(Debug)] +enum PendingRequest { + /// A route has been determined for this request but it has not yet been sent. + Routed { + destination: BrokerId, + request: Message, + }, + /// The request has been sent to the specified broker and we are now awaiting a response from that broker. + Sent { + destination: BrokerId, + /// How many responses must be received before this respose is received. + /// When this is 0 the next response from the broker will be for this request. + /// This field must be manually decremented when another response for this broker comes through. + index: usize, + }, + /// The broker has returned a Response to this request. + /// Returning this response may be delayed until a response to an earlier request comes back from another broker. + Received { response: Message }, } #[async_trait] @@ -247,30 +278,54 @@ impl Transform for KafkaSinkCluster { self.nodes = nodes?; } - self.update_local_nodes().await; - let mut find_coordinator_requests = vec![]; - for (index, request) in requests_wrapper.requests.iter_mut().enumerate() { - if let Some(Frame::Kafka(KafkaFrame::Request { - body: RequestBody::FindCoordinator(find_coordinator), - .. - })) = request.frame() - { - find_coordinator_requests.push(FindCoordinator { - index, - key: find_coordinator.key.clone(), - key_type: find_coordinator.key_type, - }); + + let mut responses = if requests_wrapper.requests.is_empty() { + // there are no requests, so no point sending any, but we should check for any responses without awaiting + self.recv_responses_no_await()? + } else { + self.update_local_nodes().await; + + for (index, request) in requests_wrapper.requests.iter_mut().enumerate() { + if let Some(Frame::Kafka(KafkaFrame::Request { + body: RequestBody::FindCoordinator(find_coordinator), + .. + })) = request.frame() + { + find_coordinator_requests.push(FindCoordinator { + index, + key: find_coordinator.key.clone(), + key_type: find_coordinator.key_type, + }); + } } - } - let responses = self.send_requests(requests_wrapper.requests).await?; - self.receive_responses(&find_coordinator_requests, responses) - .await + let request_count = requests_wrapper.requests.len(); + self.route_requests(requests_wrapper.requests).await?; + self.send_requests().await?; + self.recv_responses(request_count).await? 
+ }; + + self.process_responses(&find_coordinator_requests, &mut responses) + .await?; + Ok(responses) } } impl KafkaSinkCluster { + /// Send a request over the control connection and immediately receive the response. + /// Since we always await the response we know for sure that the response will not get mixed up with any other incoming responses. + async fn control_send_receive(&mut self, requests: Message) -> Result { + if self.control_connection.is_none() { + let address = &self.nodes.choose(&mut self.rng).unwrap().kafka_address; + self.control_connection = + Some(self.connection_factory.create_connection(address).await?); + } + let connection = self.control_connection.as_mut().unwrap(); + connection.send(vec![requests])?; + Ok(connection.recv().await?.remove(0)) + } + fn store_topic(&self, topics: &mut Vec, topic: TopicName) { if self.topic_by_name.get(&topic).is_none() && !topics.contains(&topic) { topics.push(topic); @@ -356,6 +411,7 @@ impl KafkaSinkCluster { } } + async fn route_requests(&mut self, mut requests: Vec) -> Result<()> { let mut topics = vec![]; let mut groups = vec![]; for request in &mut requests { @@ -449,34 +505,22 @@ impl KafkaSinkCluster { as usize]; for node in &mut self.nodes { if node.broker_id == partition.leader_id { - connection = Some( - node.get_connection(&self.connection_factory).await?.clone(), - ); + connection = Some(node.broker_id); } } } - let connection = match connection { + let destination = match connection { Some(connection) => connection, None => { tracing::warn!("no known partition leader for {topic_name:?}, routing message to a random node so that a NOT_LEADER_OR_FOLLOWER or similar error is returned to the client"); - self.nodes - .choose_mut(&mut self.rng) - .unwrap() - .get_connection(&self.connection_factory) - .await? - .clone() + self.nodes.choose(&mut self.rng).unwrap().broker_id } }; - let (return_chan, rx) = produce_channel(produce); - - connection - .send(Request { - message, - return_chan, - }) - .map_err(|_| anyhow!("Failed to send"))?; - results.push(rx); + self.pending_requests.push_back(PendingRequest::Routed { + destination, + request: message, + }) } // route to random partition replica @@ -484,7 +528,7 @@ impl KafkaSinkCluster { body: RequestBody::Fetch(fetch), .. 
})) => { - let node = if fetch.session_id == 0 { + let destination = if fetch.session_id == 0 { // assume that all topics in this message have the same routing requirements let topic = fetch .topics @@ -503,7 +547,7 @@ impl KafkaSinkCluster { topic_meta = topic_by_name.as_deref(); } - let node = if let Some(topic_meta) = topic_meta { + let destination = if let Some(topic_meta) = topic_meta { let partition_index = topic .partitions .first() @@ -519,41 +563,33 @@ impl KafkaSinkCluster { }) .choose(&mut self.rng) .unwrap() + .broker_id } else { let partition_len = topic_meta.partitions.len(); tracing::warn!("no known partition replica for {topic_name:?} at partition index {partition_index} out of {partition_len} partitions, routing message to a random node so that a NOT_LEADER_OR_FOLLOWER or similar error is returned to the client"); - self.nodes.choose_mut(&mut self.rng).unwrap() + self.nodes.choose(&mut self.rng).unwrap().broker_id } } else { tracing::warn!("no known partition replica for {topic_name:?}, routing message to a random node so that a NOT_LEADER_OR_FOLLOWER or similar error is returned to the client"); - self.nodes.choose_mut(&mut self.rng).unwrap() + self.nodes.choose(&mut self.rng).unwrap().broker_id }; self.fetch_request_destinations - .insert(message.id(), node.broker_id); - node + .insert(message.id(), destination); + destination } else { // route via session id if let Some(destination) = self.fetch_session_id_to_broker.get(&fetch.session_id) { - self.nodes - .iter_mut() - .find(|x| &x.broker_id == destination) - .unwrap() + *destination } else { todo!() } }; - let connection = node.get_connection(&self.connection_factory).await?.clone(); - - let (tx, rx) = oneshot::channel(); - connection - .send(Request { - message, - return_chan: Some(tx), - }) - .map_err(|_| anyhow!("Failed to send"))?; - results.push(rx); + self.pending_requests.push_back(PendingRequest::Routed { + destination, + request: message, + }) } // route to group coordinator @@ -562,14 +598,14 @@ impl KafkaSinkCluster { .. })) => { let group_id = heartbeat.group_id.clone(); - results.push(self.route_to_coordinator(message, group_id).await?); + self.route_to_coordinator(message, group_id); } Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::SyncGroup(sync_group), .. })) => { let group_id = sync_group.group_id.clone(); - results.push(self.route_to_coordinator(message, group_id).await?); + self.route_to_coordinator(message, group_id); } Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::OffsetFetch(offset_fetch), @@ -586,76 +622,58 @@ impl KafkaSinkCluster { // For now just pick the first group as that is sufficient for the simple cases. offset_fetch.groups.first().unwrap().group_id.clone() }; - results.push(self.route_to_coordinator(message, group_id).await?); + self.route_to_coordinator(message, group_id); } Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::JoinGroup(join_group), .. })) => { let group_id = join_group.group_id.clone(); - results.push(self.route_to_coordinator(message, group_id).await?); + self.route_to_coordinator(message, group_id); } Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::DeleteGroups(groups), .. })) => { let group_id = groups.groups_names.first().unwrap().clone(); - results.push(self.route_to_coordinator(message, group_id).await?); + self.route_to_coordinator(message, group_id); } // route to controller broker Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::CreateTopics(_), .. 
- })) => results.push(self.route_to_controller(message).await?), - + })) => self.route_to_controller(message), // route to random node _ => { - let connection = self - .nodes - .choose_mut(&mut self.rng) - .unwrap() - .get_connection(&self.connection_factory) - .await?; - let (tx, rx) = oneshot::channel(); - connection - .send(Request { - message, - return_chan: Some(tx), - }) - .map_err(|_| anyhow!("Failed to send"))?; - results.push(rx); + let destination = self.nodes.choose(&mut self.rng).unwrap().broker_id; + self.pending_requests.push_back(PendingRequest::Routed { + destination, + request: message, + }); } } } - Ok(results) + Ok(()) } - async fn route_to_first_contact_node( - &mut self, - message: Message, - return_chan: Option>, - ) -> Result<()> { - let node = if let Some(first_contact_node) = &self.first_contact_node { + fn route_to_first_contact_node(&mut self, message: Message) { + let destination = if let Some(first_contact_node) = &self.first_contact_node { self.nodes .iter_mut() .find(|node| node.kafka_address == *first_contact_node) .unwrap() + .broker_id } else { let node = self.nodes.get_mut(0).unwrap(); self.first_contact_node = Some(node.kafka_address.clone()); - node + node.broker_id }; - node.get_connection(&self.connection_factory) - .await? - .send(Request { - message, - return_chan, - }) - .map_err(|_| anyhow!("Failed to send"))?; - - Ok(()) + self.pending_requests.push_back(PendingRequest::Routed { + destination, + request: message, + }); } async fn find_coordinator_of_group( @@ -678,20 +696,7 @@ impl KafkaSinkCluster { ), })); - let connection = self - .nodes - .choose_mut(&mut self.rng) - .unwrap() - .get_connection(&self.connection_factory) - .await?; - let (tx, rx) = oneshot::channel(); - connection - .send(Request { - message: request, - return_chan: Some(tx), - }) - .map_err(|_| anyhow!("Failed to send"))?; - let mut response = rx.await.unwrap().response.unwrap(); + let mut response = self.control_send_receive(request).await?; match response.frame() { Some(Frame::Kafka(KafkaFrame::Response { body: ResponseBody::FindCoordinator(coordinator), @@ -739,36 +744,172 @@ impl KafkaSinkCluster { ), })); - let connection = self - .nodes - .choose_mut(&mut self.rng) - .unwrap() - .get_connection(&self.connection_factory) - .await?; - let (tx, rx) = oneshot::channel(); - connection - .send(Request { - message: request, - return_chan: Some(tx), - }) - .map_err(|_| anyhow!("Failed to send"))?; - Ok(rx.await.unwrap().response.unwrap()) + self.control_send_receive(request).await } - async fn receive_responses( + /// Convert all PendingRequest::Routed into PendingRequest::Sent + async fn send_requests(&mut self) -> Result<()> { + struct RoutedRequests { + requests: Vec, + already_pending: usize, + } + + let mut broker_to_routed_requests: HashMap = HashMap::new(); + for i in 0..self.pending_requests.len() { + if let PendingRequest::Routed { destination, .. } = &self.pending_requests[i] { + let routed_requests = broker_to_routed_requests + .entry(*destination) + .or_insert_with(|| RoutedRequests { + requests: vec![], + already_pending: self + .pending_requests + .iter() + .filter(|pending_request| { + if let PendingRequest::Sent { + destination: check_destination, + .. 
+ } = pending_request + { + check_destination == destination + } else { + false + } + }) + .count(), + }); + let mut value = PendingRequest::Sent { + destination: *destination, + index: routed_requests.requests.len() + routed_requests.already_pending, + }; + std::mem::swap(&mut self.pending_requests[i], &mut value); + if let PendingRequest::Routed { request, .. } = value { + routed_requests.requests.push(request); + } + } + } + + for (destination, requests) in broker_to_routed_requests { + self.nodes + .iter_mut() + .find(|x| x.broker_id == destination) + .unwrap() + .get_connection(&self.connection_factory) + .await? + .send(requests.requests)?; + } + + Ok(()) + } + + /// Convert some PendingRequest::Sent into PendingRequest::Received + fn recv_responses_no_await(&mut self) -> Result> { + for node in &mut self.nodes { + if let Some(connection) = node.get_connection_if_open() { + if let Ok(responses) = connection.try_recv() { + for response in responses { + let mut response = Some(response); + for pending_request in &mut self.pending_requests { + if let PendingRequest::Sent { destination, index } = pending_request { + if *destination == node.broker_id { + if *index == 0 { + *pending_request = PendingRequest::Received { + response: response.take().unwrap(), + }; + } else { + *index -= 1; + } + } + } + } + } + } + } + } + + let mut responses = vec![]; + while let Some(pending_request) = self.pending_requests.front() { + if let PendingRequest::Received { .. } = pending_request { + // The next response we are waiting on has been received, add it to responses + if let Some(PendingRequest::Received { response }) = + self.pending_requests.pop_front() + { + responses.push(response); + } + } else { + // The pending_request is not received, we need to break to maintain response ordering. + break; + } + } + Ok(responses) + } + + /// Convert some PendingRequest::Sent into PendingRequest::Received + // TODO: This function duplicates a lot of logic from recv_responses_no_await, + // but I plan to delete this function in the near future, + // so there is no point factoring out the common bits. + async fn recv_responses(&mut self, request_count: usize) -> Result> { + let mut responses = vec![]; + while responses.len() < request_count { + for node in &mut self.nodes { + let broker_id = node.broker_id; + if let Some(connection) = node.get_connection_if_open() { + let connection_is_pending = self.pending_requests.iter().any(|x| { + if let PendingRequest::Sent { destination, .. } = x { + *destination == broker_id + } else { + false + } + }); + if connection_is_pending { + let responses = if let Some(read_timeout) = self.read_timeout { + timeout(read_timeout, connection.recv()).await? + } else { + connection.recv().await + }?; + for response in responses { + let mut response = Some(response); + for pending_request in &mut self.pending_requests { + if let PendingRequest::Sent { destination, index } = pending_request + { + if *destination == node.broker_id { + if *index == 0 { + *pending_request = PendingRequest::Received { + response: response.take().unwrap(), + }; + } else { + *index -= 1; + } + } + } + } + } + } + } + } + + while let Some(pending_request) = self.pending_requests.front() { + if let PendingRequest::Received { .. 
} = pending_request { + // The next response we are waiting on has been received, add it to responses + if let Some(PendingRequest::Received { response }) = + self.pending_requests.pop_front() + { + responses.push(response); + } + } else { + // The pending_request is not received, we need to break to maintain response ordering. + break; + } + } + } + Ok(responses) + } + + async fn process_responses( &mut self, find_coordinator_requests: &[FindCoordinator], - responses: Vec>, - ) -> Result> { - // TODO: since kafka will never send requests out of order I wonder if it would be faster to use an mpsc instead of a oneshot or maybe just directly run the sending/receiving here? - let mut responses = if let Some(read_timeout) = self.read_timeout { - timeout(read_timeout, read_responses(responses)).await? - } else { - read_responses(responses).await - }?; - + responses: &mut [Message], + ) -> Result<()> { // TODO: Handle errors like NOT_COORDINATOR by removing element from self.topics and self.coordinator_broker_id - for (i, response) in responses.iter_mut().enumerate() { let request_id = response.request_id(); match response.frame() { @@ -819,73 +960,40 @@ impl KafkaSinkCluster { } } - Ok(responses) + Ok(()) } - async fn route_to_controller( - &mut self, - message: Message, - ) -> Result> { + fn route_to_controller(&mut self, message: Message) { let broker_id = self.controller_broker.get().unwrap(); - let connection = if let Some(node) = + let destination = if let Some(node) = self.nodes.iter_mut().find(|x| x.broker_id == *broker_id) { - node.get_connection(&self.connection_factory).await?.clone() + node.broker_id } else { tracing::warn!("no known broker with id {broker_id:?}, routing message to a random node so that a NOT_CONTROLLER or similar error is returned to the client"); - self.nodes - .choose_mut(&mut self.rng) - .unwrap() - .get_connection(&self.connection_factory) - .await? - .clone() + self.nodes.choose(&mut self.rng).unwrap().broker_id }; - let (tx, rx) = oneshot::channel(); - connection - .send(Request { - message, - return_chan: Some(tx), - }) - .map_err(|_| anyhow!("Failed to send"))?; - Ok(rx) + self.pending_requests.push_back(PendingRequest::Routed { + destination, + request: message, + }); } - async fn route_to_coordinator( - &mut self, - message: Message, - group_id: GroupId, - ) -> Result> { - let mut connection = None; - for node in &mut self.nodes { - if let Some(broker_id) = self.group_to_coordinator_broker.get(&group_id) { - if node.broker_id == *broker_id { - connection = Some(node.get_connection(&self.connection_factory).await?.clone()); - break; - } - } - } - let connection = match connection { - Some(connection) => connection, + fn route_to_coordinator(&mut self, message: Message, group_id: GroupId) { + let destination = self.group_to_coordinator_broker.get(&group_id); + let destination = match destination { + Some(destination) => *destination, None => { tracing::warn!("no known coordinator for {group_id:?}, routing message to a random node so that a NOT_COORDINATOR or similar error is returned to the client"); - self.nodes - .choose_mut(&mut self.rng) - .unwrap() - .get_connection(&self.connection_factory) - .await? 
- .clone() + self.nodes.choose(&mut self.rng).unwrap().broker_id } }; - let (tx, rx) = oneshot::channel(); - connection - .send(Request { - message, - return_chan: Some(tx), - }) - .map_err(|_| anyhow!("Failed to send"))?; - Ok(rx) + self.pending_requests.push_back(PendingRequest::Routed { + destination, + request: message, + }); } async fn process_metadata_response(&mut self, metadata: &MetadataResponse) { @@ -1142,14 +1250,6 @@ impl KafkaSinkCluster { } } -async fn read_responses(responses: Vec>) -> Result { - let mut result = Vec::with_capacity(responses.len()); - for response in responses { - result.push(response.await.unwrap().response?); - } - Ok(result) -} - fn hash_partition(topic_id: Uuid, partition_index: i32) -> usize { let mut hasher = xxhash_rust::xxh3::Xxh3::new(); hasher.write(topic_id.as_bytes()); diff --git a/shotover/src/transforms/kafka/sink_cluster/node.rs b/shotover/src/transforms/kafka/sink_cluster/node.rs index d3ad0b1b1..83ef75cbc 100644 --- a/shotover/src/transforms/kafka/sink_cluster/node.rs +++ b/shotover/src/transforms/kafka/sink_cluster/node.rs @@ -1,30 +1,34 @@ use crate::codec::{kafka::KafkaCodecBuilder, CodecBuilder, Direction}; +use crate::connection::SinkConnection; use crate::message::Message; -use crate::tcp; use crate::tls::TlsConnector; -use crate::transforms::util::cluster_connection_pool::{spawn_read_write_tasks, Connection}; -use crate::transforms::util::Request; use anyhow::{anyhow, Result}; use kafka_protocol::messages::BrokerId; use kafka_protocol::protocol::StrBytes; +use std::sync::Arc; use std::time::Duration; -use tokio::io::split; -use tokio::sync::oneshot; +use tokio::sync::Notify; pub struct ConnectionFactory { tls: Option, connect_timeout: Duration, handshake_message: Option, auth_message: Option, + force_run_chain: Arc, } impl ConnectionFactory { - pub fn new(tls: Option, connect_timeout: Duration) -> Self { + pub fn new( + tls: Option, + connect_timeout: Duration, + force_run_chain: Arc, + ) -> Self { ConnectionFactory { tls, connect_timeout, handshake_message: None, auth_message: None, + force_run_chain, } } @@ -36,46 +40,29 @@ impl ConnectionFactory { self.auth_message = Some(message); } - pub async fn create_connection(&self, kafka_address: &KafkaAddress) -> Result { + pub async fn create_connection(&self, kafka_address: &KafkaAddress) -> Result { let codec = KafkaCodecBuilder::new(Direction::Sink, "KafkaSinkCluster".to_owned()); - let address = (kafka_address.host.to_string(), kafka_address.port as u16); - if let Some(tls) = self.tls.as_ref() { - let tls_stream = tls.connect(self.connect_timeout, address).await?; - let (rx, tx) = split(tls_stream); - let connection = spawn_read_write_tasks(&codec, rx, tx); - Ok(connection) - } else { - let tcp_stream = tcp::tcp_stream(self.connect_timeout, address).await?; - let (rx, tx) = tcp_stream.into_split(); - let connection = spawn_read_write_tasks(&codec, rx, tx); - - if let Some(message) = self.auth_message.as_ref() { - let handshake_msg = self.handshake_message.as_ref().unwrap(); - - let (tx, rx) = oneshot::channel(); - connection - .send(Request { - message: handshake_msg.clone(), - return_chan: Some(tx), - }) - .map_err(|_| anyhow!("Failed to send"))?; - - let _response = rx.await.map_err(|_| anyhow!("Failed to receive"))?; - - let (tx, rx) = oneshot::channel(); - connection - .send(Request { - message: message.clone(), - return_chan: Some(tx), - }) - .map_err(|_| anyhow!("Failed to send"))?; - - let _response = rx.await.map_err(|_| anyhow!("Failed to receive"))?; + let mut connection = 
SinkConnection::new( + address, + codec, + &self.tls, + self.connect_timeout, + self.force_run_chain.clone(), + ) + .await?; + + if let Some(auth_message) = self.auth_message.as_ref() { + let handshake_msg = self.handshake_message.as_ref().unwrap(); + + connection.send(vec![handshake_msg.clone(), auth_message.clone()])?; + let mut received_count = 0; + while received_count < 2 { + received_count += connection.recv().await?.len(); } - - Ok(connection) } + + Ok(connection) } } @@ -108,12 +95,22 @@ impl KafkaAddress { } } -#[derive(Clone, Debug)] pub struct KafkaNode { pub broker_id: BrokerId, pub rack: Option, pub kafka_address: KafkaAddress, - connection: Option, + connection: Option, +} + +impl Clone for KafkaNode { + fn clone(&self) -> Self { + Self { + broker_id: self.broker_id, + rack: self.rack.clone(), + kafka_address: self.kafka_address.clone(), + connection: None, + } + } } impl KafkaNode { @@ -129,7 +126,7 @@ impl KafkaNode { pub async fn get_connection( &mut self, connection_factory: &ConnectionFactory, - ) -> Result<&Connection> { + ) -> Result<&mut SinkConnection> { if self.connection.is_none() { self.connection = Some( connection_factory @@ -137,6 +134,10 @@ impl KafkaNode { .await?, ); } - Ok(self.connection.as_ref().unwrap()) + Ok(self.connection.as_mut().unwrap()) + } + + pub fn get_connection_if_open(&mut self) -> Option<&mut SinkConnection> { + self.connection.as_mut() } } diff --git a/shotover/src/transforms/kafka/sink_single.rs b/shotover/src/transforms/kafka/sink_single.rs index a9d035cfa..9cbd988c0 100644 --- a/shotover/src/transforms/kafka/sink_single.rs +++ b/shotover/src/transforms/kafka/sink_single.rs @@ -1,5 +1,5 @@ use crate::codec::{kafka::KafkaCodecBuilder, CodecBuilder, Direction}; -use crate::connection::Connection; +use crate::connection::SinkConnection; use crate::frame::kafka::{KafkaFrame, RequestBody, ResponseBody}; use crate::frame::Frame; use crate::message::Messages; @@ -96,7 +96,7 @@ impl TransformBuilder for KafkaSinkSingleBuilder { pub struct KafkaSinkSingle { address_port: u16, - connection: Option, + connection: Option, connect_timeout: Duration, read_timeout: Option, tls: Option, @@ -114,13 +114,12 @@ impl Transform for KafkaSinkSingle { let codec = KafkaCodecBuilder::new(Direction::Sink, "KafkaSinkSingle".to_owned()); let address = (requests_wrapper.local_addr.ip(), self.address_port); self.connection = Some( - Connection::new( + SinkConnection::new( address, codec, &self.tls, self.connect_timeout, - Some(self.force_run_chain.clone()), - Direction::Sink, + self.force_run_chain.clone(), ) .await?, ); diff --git a/shotover/src/transforms/redis/sink_single.rs b/shotover/src/transforms/redis/sink_single.rs index b3de9aeec..f56c3aa3e 100644 --- a/shotover/src/transforms/redis/sink_single.rs +++ b/shotover/src/transforms/redis/sink_single.rs @@ -1,5 +1,5 @@ use crate::codec::{CodecBuilder, Direction}; -use crate::connection::Connection; +use crate::connection::SinkConnection; use crate::frame::{Frame, RedisFrame}; use crate::message::Messages; use crate::tls::{TlsConnector, TlsConnectorConfig}; @@ -94,7 +94,7 @@ impl TransformBuilder for RedisSinkSingleBuilder { pub struct RedisSinkSingle { address: String, tls: Option, - connection: Option, + connection: Option, failed_requests: Counter, connect_timeout: Duration, force_run_chain: Arc, @@ -110,13 +110,12 @@ impl Transform for RedisSinkSingle { if self.connection.is_none() { let codec = RedisCodecBuilder::new(Direction::Sink, "RedisSinkSingle".to_owned()); self.connection = Some( - Connection::new( + 
SinkConnection::new(
                     &self.address,
                     codec,
                     &self.tls,
                     self.connect_timeout,
-                    Some(self.force_run_chain.clone()),
-                    Direction::Sink,
+                    self.force_run_chain.clone(),
                 )
                 .await?,
             );
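For reviewers who have not used the new connection type yet, here is a minimal sketch (written from inside the shotover crate) of how a sink transform drives a SinkConnection after this change: open it once, push a batch of requests with send(), then pull the matching responses back with recv(). The codec choice, address, and timeout are illustrative assumptions rather than values taken from this patch, and the helper names are hypothetical.

use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;

use crate::codec::{kafka::KafkaCodecBuilder, CodecBuilder, Direction};
use crate::connection::{ConnectionError, SinkConnection};
use crate::message::Message;

// Open an outgoing connection the same way the sink transforms in this patch do.
async fn open_connection(force_run_chain: Arc<Notify>) -> anyhow::Result<SinkConnection> {
    let codec = KafkaCodecBuilder::new(Direction::Sink, "ExampleSink".to_owned());
    SinkConnection::new(
        ("127.0.0.1", 9092),    // illustrative address
        codec,
        &None,                  // no TLS
        Duration::from_secs(3), // illustrative connect timeout
        force_run_chain,
    )
    .await
}

// Send a batch of requests and wait until every one of them has a response.
// Requests the server never answers (e.g. a Kafka produce with acks=0) are
// filled in with dummy responses by the connection itself, so the counts
// always line up.
async fn send_batch(
    connection: &mut SinkConnection,
    requests: Vec<Message>,
) -> Result<Vec<Message>, ConnectionError> {
    let request_count = requests.len();
    connection.send(requests)?;

    let mut responses = Vec::with_capacity(request_count);
    while responses.len() < request_count {
        // recv() returns whatever responses have been decoded so far,
        // awaiting only when none are available yet.
        responses.extend(connection.recv().await?);
    }
    Ok(responses)
}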