Skip to content

Commit

Permalink
Merge branch 'main' into remove-sasl-config
Browse files Browse the repository at this point in the history
  • Loading branch information
conorbros authored Mar 26, 2024
2 parents 471100a + e9a8907 commit 42960be
Show file tree
Hide file tree
Showing 9 changed files with 386 additions and 327 deletions.
62 changes: 27 additions & 35 deletions shotover/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! This is the one true connection implementation that all other transforms + server.rs should be ported to.
//! All Sink transforms use SinkConnection for their outgoing connections.

use crate::codec::{CodecBuilder, CodecReadError, CodecWriteError, Direction};
use crate::codec::{CodecBuilder, CodecReadError, CodecWriteError};
use crate::frame::Frame;
use crate::message::{Message, MessageId, Messages};
use crate::tcp;
Expand All @@ -18,22 +18,21 @@ use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::error;
use tracing::Instrument;

pub struct Connection {
pub struct SinkConnection {
in_rx: mpsc::Receiver<Vec<Message>>,
out_tx: mpsc::UnboundedSender<Vec<Message>>,
connection_closed_rx: mpsc::Receiver<ConnectionError>,
error: Option<ConnectionError>,
dummy_response_inserter: Option<DummyResponseInserter>,
dummy_response_inserter: DummyResponseInserter,
}

impl Connection {
impl SinkConnection {
pub async fn new<A: ToSocketAddrs + ToHostname + std::fmt::Debug, C: CodecBuilder + 'static>(
host: A,
codec_builder: C,
tls: &Option<TlsConnector>,
connect_timeout: Duration,
force_run_chain: Option<Arc<Notify>>,
direction: Direction,
force_run_chain: Arc<Notify>,
) -> anyhow::Result<Self> {
let destination = tokio::net::lookup_host(&host).await?.next().unwrap();
let (in_tx, in_rx) = mpsc::channel::<Messages>(10_000);
Expand Down Expand Up @@ -68,15 +67,12 @@ impl Connection {
);
}

let dummy_response_inserter = match direction {
Direction::Source => None,
Direction::Sink => Some(DummyResponseInserter {
dummy_requests: vec![],
pending_requests_count: 0,
}),
let dummy_response_inserter = DummyResponseInserter {
dummy_requests: vec![],
pending_requests_count: 0,
};

Ok(Connection {
Ok(SinkConnection {
in_rx,
out_tx,
connection_closed_rx,
Expand All @@ -93,9 +89,8 @@ impl Connection {
/// Send messages.
/// If there is a problem with the connection an error is returned.
pub fn send(&mut self, mut messages: Vec<Message>) -> Result<(), ConnectionError> {
if let Some(dummy_response_inserter) = &mut self.dummy_response_inserter {
dummy_response_inserter.process_requests(&mut messages);
}
self.dummy_response_inserter.process_requests(&mut messages);

if let Some(error) = &self.error {
Err(error.clone())
} else {
Expand All @@ -110,20 +105,19 @@ impl Connection {
Err(error.clone())
} else {
// first process any immediately pending dummy responses
if let Some(dummy_response_inserter) = &mut self.dummy_response_inserter {
// ensure we include any received messages so we dont leave them hanging after using up a force_run_chain.
let mut messages = self.in_rx.try_recv().unwrap_or_default();
dummy_response_inserter.process_responses(&mut messages);
if !messages.is_empty() {
return Ok(messages);
}

// ensure we include any received messages so we dont leave them hanging after using up a force_run_chain.
let mut messages = self.in_rx.try_recv().unwrap_or_default();
self.dummy_response_inserter
.process_responses(&mut messages);
if !messages.is_empty() {
return Ok(messages);
}

match self.in_rx.recv().await {
Some(mut messages) => {
if let Some(dummy_response_inserter) = &mut self.dummy_response_inserter {
dummy_response_inserter.process_responses(&mut messages);
}
self.dummy_response_inserter
.process_responses(&mut messages);
Ok(messages)
}
None => Err(self.set_get_error()),
Expand All @@ -139,9 +133,8 @@ impl Connection {
} else {
match self.in_rx.try_recv() {
Ok(mut messages) => {
if let Some(dummy_response_inserter) = &mut self.dummy_response_inserter {
dummy_response_inserter.process_responses(&mut messages);
}
self.dummy_response_inserter
.process_responses(&mut messages);
Ok(messages)
}
Err(TryRecvError::Disconnected) => Err(self.set_get_error()),
Expand Down Expand Up @@ -179,7 +172,7 @@ fn spawn_read_write_tasks<
in_tx: mpsc::Sender<Messages>,
out_rx: UnboundedReceiver<Messages>,
out_tx: UnboundedSender<Messages>,
force_run_chain: Option<Arc<Notify>>,
force_run_chain: Arc<Notify>,
connection_closed_tx: mpsc::Sender<ConnectionError>,
) {
let (decoder, encoder) = codec.build();
Expand Down Expand Up @@ -243,7 +236,7 @@ async fn reader_task<C: CodecBuilder + 'static, R: AsyncRead + Unpin + Send + 's
mut reader: FramedRead<R, <C as CodecBuilder>::Decoder>,
in_tx: mpsc::Sender<Messages>,
out_tx: UnboundedSender<Messages>,
force_run_chain: Option<Arc<Notify>>,
force_run_chain: Arc<Notify>,
) -> Result<(), ConnectionError> {
loop {
tokio::select! {
Expand All @@ -260,9 +253,8 @@ async fn reader_task<C: CodecBuilder + 'static, R: AsyncRead + Unpin + Send + 's
// main task has shutdown, this task is no longer needed
return Ok(());
}
if let Some(force_run_chain) = force_run_chain.as_ref() {
force_run_chain.notify_one();
}

force_run_chain.notify_one();
}
Err(CodecReadError::RespondAndThenCloseConnection(messages)) => {
if let Err(err) = out_tx.send(messages) {
Expand Down
6 changes: 0 additions & 6 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,6 @@ pub fn spawn_read_write_tasks<
in_tx: mpsc::Sender<Messages>,
mut out_rx: UnboundedReceiver<Messages>,
out_tx: UnboundedSender<Messages>,
force_run_chain: Option<Arc<Notify>>,
) {
let (decoder, encoder) = codec.build();
let mut reader = FramedRead::new(rx, decoder);
Expand Down Expand Up @@ -468,9 +467,6 @@ pub fn spawn_read_write_tasks<
// main task has shutdown, this task is no longer needed
return;
}
if let Some(force_run_chain) = force_run_chain.as_ref() {
force_run_chain.notify_one();
}
}
Err(CodecReadError::RespondAndThenCloseConnection(messages)) => {
if let Err(err) = out_tx.send(messages) {
Expand Down Expand Up @@ -618,7 +614,6 @@ impl<C: CodecBuilder + 'static> Handler<C> {
in_tx,
out_rx,
out_tx.clone(),
None,
);
} else {
let (rx, tx) = stream.into_split();
Expand All @@ -629,7 +624,6 @@ impl<C: CodecBuilder + 'static> Handler<C> {
in_tx,
out_rx,
out_tx.clone(),
None,
);
};
}
Expand Down
9 changes: 4 additions & 5 deletions shotover/src/transforms/cassandra/sink_single.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::codec::{cassandra::CassandraCodecBuilder, CodecBuilder, Direction};
use crate::connection::Connection;
use crate::connection::SinkConnection;
use crate::frame::cassandra::CassandraMetadata;
use crate::message::{Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
Expand Down Expand Up @@ -111,7 +111,7 @@ impl TransformBuilder for CassandraSinkSingleBuilder {
pub struct CassandraSinkSingle {
version: Option<Version>,
address: String,
connection: Option<Connection>,
connection: Option<SinkConnection>,
failed_requests: Counter,
tls: Option<TlsConnector>,
connect_timeout: Duration,
Expand Down Expand Up @@ -144,13 +144,12 @@ impl CassandraSinkSingle {
if self.connection.is_none() {
trace!("creating outbound connection {:?}", self.address);
self.connection = Some(
Connection::new(
SinkConnection::new(
self.address.clone(),
self.codec_builder.clone(),
&self.tls,
self.connect_timeout,
Some(self.force_run_chain.clone()),
Direction::Sink,
self.force_run_chain.clone(),
)
.await?,
);
Expand Down
24 changes: 0 additions & 24 deletions shotover/src/transforms/kafka/common.rs

This file was deleted.

1 change: 0 additions & 1 deletion shotover/src/transforms/kafka/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
mod common;
pub mod sink_cluster;
pub mod sink_single;
Loading

0 comments on commit 42960be

Please sign in to comment.