diff --git a/shotover-proxy/src/server.rs b/shotover-proxy/src/server.rs index df669584c..68494f0c3 100644 --- a/shotover-proxy/src/server.rs +++ b/shotover-proxy/src/server.rs @@ -3,7 +3,7 @@ use crate::tls::TlsAcceptor; use crate::transforms::chain::TransformChain; use crate::transforms::Wrapper; use anyhow::{anyhow, Context, Result}; -use futures::StreamExt; +use futures::{SinkExt, StreamExt}; use metrics::{register_gauge, Gauge}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; @@ -13,7 +13,6 @@ use tokio::sync::{mpsc, watch, Semaphore}; use tokio::time; use tokio::time::timeout; use tokio::time::Duration; -use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::codec::{Decoder, Encoder}; use tokio_util::codec::{FramedRead, FramedWrite}; use tracing::Instrument; @@ -372,12 +371,12 @@ fn spawn_read_write_tasks< rx: R, tx: W, in_tx: UnboundedSender, - out_rx: UnboundedReceiver, + mut out_rx: UnboundedReceiver, out_tx: UnboundedSender, mut terminate_tasks_rx: watch::Receiver<()>, ) { let mut reader = FramedRead::new(rx, codec.clone()); - let writer = FramedWrite::new(tx, codec); + let mut writer = FramedWrite::new(tx, codec); tokio::spawn( async move { @@ -399,19 +398,54 @@ fn spawn_read_write_tasks< } } } + error!("tx task end"); } .in_current_span(), ); tokio::spawn( async move { - let rx_stream = UnboundedReceiverStream::new(out_rx).map(Ok); - tokio::select! { - Err(err) = rx_stream.forward(writer) => { - error!("failed to send or encode message: {:?}", err); + loop { + tokio::select! { + result = out_rx.recv() => { + if let Some(message) = result { + if let Err(err) = writer.send(message).await { + error!("failed to send or encode message: {:?}", err); + } + } else { + error!("tx task ending out_rx closed"); + break; + } + } + _ = terminate_tasks_rx.changed() => { + error!("terminate_tasks_rx received"); + while let Ok(message) = out_rx.try_recv() { + error!("tx task message flushed"); + if let Err(err) = writer.send(message).await { + error!("failed to send or encode message: {:?}", err); + } + } + error!("tx task end flushing finished"); + break; + } } - _ = terminate_tasks_rx.changed() => { } } + error!("tx task end"); + //tokio::time::sleep(std::time::Duration::from_millis(1000)).await; + + // The cassandra protocol needs to: + // 1. receive bad version init + // 2. reply with error + // 3. receive another message + // 4. kill the connection: + // 1. codec returns Err + // 2. rx task receives Err, logging it and returning + // 3. rx task ends dropping the in_tx + // 4. main task receives None from in_rx causing it to return + // 5. main task ends resulting in drop running terminate_tasks_tx.send(()) + // + // I suspect that: + // Sender is backlogged and gets killed before it can process everything so 2 never occurs } .in_current_span(), ); @@ -436,6 +470,7 @@ impl Handler { let mut idle_time_seconds: u64 = 1; let (terminate_tx, terminate_rx) = watch::channel::<()>(()); + tracing::error!("{:?}", terminate_rx.has_changed()); self.terminate_tasks = Some(terminate_tx); let (in_tx, mut in_rx) = mpsc::unbounded_channel::(); @@ -478,7 +513,9 @@ impl Handler { idle_time_seconds = 1; match maybe_message { Some(m) => m, - None => return Ok(()) + None => { + error!("main task ending due to in_rx shutdown"); + return Ok(())} } }, Err(_) => { @@ -528,6 +565,7 @@ impl Handler { // send the result of the process up stream out_tx.send(modified_messages)?; } + error!("main task end"); Ok(()) } }