From 97f027ab153a7add8a186292b8420a1d30ea9cbc Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Thu, 14 Mar 2024 21:16:58 +1100 Subject: [PATCH] Add force_run_chain Notify (#1525) --- .../src/redis_get_rewrite.rs | 3 +- shotover/benches/benches/chain.rs | 24 +++++---- shotover/src/server.rs | 50 +++++++++++++------ .../src/transforms/cassandra/peers_rewrite.rs | 6 ++- .../transforms/cassandra/sink_cluster/mod.rs | 5 +- .../src/transforms/cassandra/sink_single.rs | 5 +- shotover/src/transforms/chain.rs | 25 +++++++--- shotover/src/transforms/coalesce.rs | 4 +- shotover/src/transforms/debug/force_parse.rs | 3 +- shotover/src/transforms/debug/log_to_file.rs | 4 +- shotover/src/transforms/debug/printer.rs | 5 +- shotover/src/transforms/debug/returner.rs | 5 +- .../tuneable_consistency_scatter.rs | 12 +++-- shotover/src/transforms/filter.rs | 4 +- .../src/transforms/kafka/sink_cluster/mod.rs | 4 +- shotover/src/transforms/kafka/sink_single.rs | 6 ++- shotover/src/transforms/load_balance.rs | 10 ++-- shotover/src/transforms/loopback.rs | 4 +- shotover/src/transforms/mod.rs | 27 +++++++++- shotover/src/transforms/null.rs | 4 +- shotover/src/transforms/opensearch/mod.rs | 4 +- shotover/src/transforms/parallel_map.rs | 10 ++-- shotover/src/transforms/protect/mod.rs | 4 +- shotover/src/transforms/query_counter.rs | 3 +- shotover/src/transforms/redis/cache.rs | 7 +-- .../transforms/redis/cluster_ports_rewrite.rs | 3 +- shotover/src/transforms/redis/sink_cluster.rs | 5 +- shotover/src/transforms/redis/sink_single.rs | 6 ++- .../src/transforms/redis/timestamp_tagging.rs | 5 +- shotover/src/transforms/tee.rs | 10 ++-- shotover/src/transforms/throttling.rs | 4 +- 31 files changed, 180 insertions(+), 91 deletions(-) diff --git a/custom-transforms-example/src/redis_get_rewrite.rs b/custom-transforms-example/src/redis_get_rewrite.rs index 55d757790..ed956374d 100644 --- a/custom-transforms-example/src/redis_get_rewrite.rs +++ b/custom-transforms-example/src/redis_get_rewrite.rs @@ -3,6 +3,7 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use shotover::frame::{Frame, RedisFrame}; use shotover::message::{MessageIdSet, Messages}; +use shotover::transforms::TransformContextBuilder; use shotover::transforms::{ Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, }; @@ -33,7 +34,7 @@ pub struct RedisGetRewriteBuilder { } impl TransformBuilder for RedisGetRewriteBuilder { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(RedisGetRewrite { get_requests: MessageIdSet::default(), result: self.result.clone(), diff --git a/shotover/benches/benches/chain.rs b/shotover/benches/benches/chain.rs index 140dc476d..d23256052 100644 --- a/shotover/benches/benches/chain.rs +++ b/shotover/benches/benches/chain.rs @@ -19,7 +19,9 @@ use shotover::transforms::protect::{KeyManagerConfig, ProtectConfig}; use shotover::transforms::redis::cluster_ports_rewrite::RedisClusterPortsRewrite; use shotover::transforms::redis::timestamp_tagging::RedisTimestampTagger; use shotover::transforms::throttling::RequestThrottlingConfig; -use shotover::transforms::{TransformConfig, TransformContextConfig, Wrapper}; +use shotover::transforms::{ + TransformConfig, TransformContextBuilder, TransformContextConfig, Wrapper, +}; fn criterion_benchmark(c: &mut Criterion) { crate::init(); @@ -38,7 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("loopback", |b| { b.to_async(&rt).iter_batched( || BenchInput { - chain: chain.build(), + chain: chain.build(TransformContextBuilder::new()), wrapper: wrapper.clone(), }, BenchInput::bench, @@ -57,7 +59,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("nullsink", |b| { b.to_async(&rt).iter_batched( || BenchInput { - chain: chain.build(), + chain: chain.build(TransformContextBuilder::new()), wrapper: wrapper.clone(), }, BenchInput::bench, @@ -95,7 +97,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("redis_filter", |b| { b.to_async(&rt).iter_batched( || BenchInput { - chain: chain.build(), + chain: chain.build(TransformContextBuilder::new()), wrapper: wrapper.clone(), }, BenchInput::bench, @@ -129,7 +131,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("redis_timestamp_tagger_untagged", |b| { b.to_async(&rt).iter_batched( || BenchInput { - chain: chain.build(), + chain: chain.build(TransformContextBuilder::new()), wrapper: wrapper_set.clone(), }, BenchInput::bench, @@ -148,7 +150,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("redis_timestamp_tagger_tagged", |b| { b.to_async(&rt).iter_batched( || BenchInput { - chain: chain.build(), + chain: chain.build(TransformContextBuilder::new()), wrapper: wrapper_get.clone(), }, BenchInput::bench, @@ -177,7 +179,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("redis_cluster_ports_rewrite", |b| { b.to_async(&rt).iter_batched( || BenchInput { - chain: chain.build(), + chain: chain.build(TransformContextBuilder::new()), wrapper: wrapper.clone(), }, BenchInput::bench, @@ -224,7 +226,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("cassandra_request_throttling_unparsed", |b| { b.to_async(&rt).iter_batched( || BenchInput { - chain: chain.build(), + chain: chain.build(TransformContextBuilder::new()), wrapper: wrapper.clone(), }, BenchInput::bench, @@ -278,7 +280,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("cassandra_rewrite_peers_passthrough", |b| { b.to_async(&rt).iter_batched( || BenchInput { - chain: chain.build(), + chain: chain.build(TransformContextBuilder::new()), wrapper: wrapper.clone(), }, BenchInput::bench, @@ -324,7 +326,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("cassandra_protect_unprotected", |b| { b.to_async(&rt).iter_batched( || BenchInput { - chain: chain.build(), + chain: chain.build(TransformContextBuilder::new()), wrapper: wrapper.clone(), }, BenchInput::bench, @@ -339,7 +341,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("cassandra_protect_protected", |b| { b.to_async(&rt).iter_batched( || BenchInput { - chain: chain.build(), + chain: chain.build(TransformContextBuilder::new()), wrapper: wrapper.clone(), }, BenchInput::bench, diff --git a/shotover/src/server.rs b/shotover/src/server.rs index 2eef16403..c28ebcec4 100644 --- a/shotover/src/server.rs +++ b/shotover/src/server.rs @@ -4,7 +4,7 @@ use crate::message::{Message, Messages}; use crate::sources::Transport; use crate::tls::{AcceptError, TlsAcceptor}; use crate::transforms::chain::{TransformChain, TransformChainBuilder}; -use crate::transforms::{TransformContextConfig, Wrapper}; +use crate::transforms::{TransformContextBuilder, TransformContextConfig, Wrapper}; use anyhow::{anyhow, Context, Result}; use bytes::BytesMut; use futures::future::join_all; @@ -16,7 +16,7 @@ use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; -use tokio::sync::{mpsc, watch, OwnedSemaphorePermit, Semaphore}; +use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore}; use tokio::task::JoinHandle; use tokio::time; use tokio::time::Duration; @@ -190,10 +190,15 @@ impl TcpCodecListener { let (pushed_messages_tx, pushed_messages_rx) = tokio::sync::mpsc::unbounded_channel::(); + let force_run_chain = Arc::new(Notify::new()); + let context = TransformContextBuilder { + force_run_chain: force_run_chain.clone(), + }; + let handler = Handler { chain: self .chain_builder - .build_with_pushed_messages(pushed_messages_tx), + .build_with_pushed_messages(pushed_messages_tx, context), codec: self.codec.clone(), shutdown: Shutdown::new(self.trigger_shutdown_rx.clone()), tls: self.tls.clone(), @@ -206,7 +211,7 @@ impl TcpCodecListener { self.connection_handles.push(tokio::spawn( async move { // Process the connection. If an error is encountered, log it. - if let Err(err) = handler.run(stream, transport).await { + if let Err(err) = handler.run(stream, transport, force_run_chain).await { error!( "{:?}", err.context("connection was unexpectedly terminated") @@ -576,7 +581,12 @@ impl Handler { /// /// When the shutdown signal is received, the connection is processed until /// it reaches a safe state, at which point it is terminated. - pub async fn run(mut self, stream: TcpStream, transport: Transport) -> Result<()> { + pub async fn run( + mut self, + stream: TcpStream, + transport: Transport, + force_run_chain: Arc, + ) -> Result<()> { stream.set_nodelay(true)?; let client_details = stream @@ -658,7 +668,7 @@ impl Handler { }; let result = self - .process_messages(&client_details, local_addr, in_rx, out_tx) + .process_messages(&client_details, local_addr, in_rx, out_tx, force_run_chain) .await; // Flush messages regardless of if we are shutting down due to a failure or due to application shutdown @@ -700,6 +710,7 @@ impl Handler { local_addr: SocketAddr, mut in_rx: mpsc::Receiver, out_tx: mpsc::UnboundedSender, + force_run_chain: Arc, ) -> Result<()> { // As long as the shutdown signal has not been received, try to read a // new request frame. @@ -707,6 +718,24 @@ impl Handler { // While reading a request frame, also listen for the shutdown signal debug!("Waiting for message {client_details}"); let responses = tokio::select! { + biased; + _ = self.shutdown.recv() => { + // If a shutdown signal is received, return from `run`. + // This will result in the task terminating. + return Ok(()); + } + Some(responses) = self.pushed_messages_rx.recv() => { + debug!("Received unrequested responses from destination {:?}", responses); + self.process_backward(client_details, local_addr, responses).await? + } + () = force_run_chain.notified() => { + let mut requests = vec!(); + while let Ok(x) = in_rx.try_recv() { + requests.extend(x); + } + debug!("A transform in the chain requested that a chain run occur, requests {:?}", requests); + self.process_forward(client_details, local_addr, &out_tx, requests).await? + }, requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => { match requests { Some(mut requests) => { @@ -722,15 +751,6 @@ impl Handler { } } }, - Some(responses) = self.pushed_messages_rx.recv() => { - debug!("Received unrequested responses from destination {:?}", responses); - self.process_backward(client_details, local_addr, responses).await? - } - _ = self.shutdown.recv() => { - // If a shutdown signal is received, return from `run`. - // This will result in the task terminating. - return Ok(()); - } }; debug!("sending response to client: {:?}", responses); diff --git a/shotover/src/transforms/cassandra/peers_rewrite.rs b/shotover/src/transforms/cassandra/peers_rewrite.rs index c887a0d16..fafbfa7f5 100644 --- a/shotover/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover/src/transforms/cassandra/peers_rewrite.rs @@ -1,6 +1,8 @@ use crate::message::{Message, MessageIdMap, Messages}; use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::transforms::{ + Transform, TransformBuilder, TransformConfig, TransformContextBuilder, Wrapper, +}; use crate::{ frame::{ value::{GenericValue, IntSize}, @@ -52,7 +54,7 @@ impl CassandraPeersRewrite { } impl TransformBuilder for CassandraPeersRewrite { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) } diff --git a/shotover/src/transforms/cassandra/sink_cluster/mod.rs b/shotover/src/transforms/cassandra/sink_cluster/mod.rs index 5fb4fd222..907f2bd20 100644 --- a/shotover/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover/src/transforms/cassandra/sink_cluster/mod.rs @@ -6,7 +6,8 @@ use crate::message::{Message, MessageIdMap, Messages, Metadata}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::cassandra::connection::{CassandraConnection, Response, ResponseError}; use crate::transforms::{ - Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, + Transform, TransformBuilder, TransformConfig, TransformContextBuilder, TransformContextConfig, + Wrapper, }; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; @@ -160,7 +161,7 @@ impl CassandraSinkClusterBuilder { } impl TransformBuilder for CassandraSinkClusterBuilder { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(CassandraSinkCluster { contact_points: self.contact_points.clone(), message_rewriter: self.message_rewriter.clone(), diff --git a/shotover/src/transforms/cassandra/sink_single.rs b/shotover/src/transforms/cassandra/sink_single.rs index 527ca5db9..1bae4f9f6 100644 --- a/shotover/src/transforms/cassandra/sink_single.rs +++ b/shotover/src/transforms/cassandra/sink_single.rs @@ -5,7 +5,8 @@ use crate::message::{Messages, Metadata}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::cassandra::connection::Response; use crate::transforms::{ - Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, + Transform, TransformBuilder, TransformConfig, TransformContextBuilder, TransformContextConfig, + Wrapper, }; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -83,7 +84,7 @@ impl CassandraSinkSingleBuilder { } impl TransformBuilder for CassandraSinkSingleBuilder { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(CassandraSinkSingle { outbound: None, version: self.version, diff --git a/shotover/src/transforms/chain.rs b/shotover/src/transforms/chain.rs index dd5b798e0..507d00bb2 100644 --- a/shotover/src/transforms/chain.rs +++ b/shotover/src/transforms/chain.rs @@ -9,6 +9,8 @@ use tokio::sync::{mpsc, oneshot}; use tokio::time::{Duration, Instant}; use tracing::{debug, error, info, trace, Instrument}; +use super::TransformContextBuilder; + type InnerChain = Vec; #[derive(Debug)] @@ -233,9 +235,9 @@ pub struct TransformBuilderAndMetrics { } impl TransformBuilderAndMetrics { - fn build(&self) -> TransformAndMetrics { + fn build(&self, context: TransformContextBuilder) -> TransformAndMetrics { TransformAndMetrics { - transform: self.builder.build(), + transform: self.builder.build(context), transform_total: self.transform_total.clone(), transform_failures: self.transform_failures.clone(), transform_latency: self.transform_latency.clone(), @@ -331,7 +333,11 @@ impl TransformChainBuilder { errors } - pub fn build_buffered(&self, buffer_size: usize) -> BufferedChain { + pub fn build_buffered( + &self, + buffer_size: usize, + context: TransformContextBuilder, + ) -> BufferedChain { let (tx, mut rx) = mpsc::channel::(buffer_size); #[cfg(test)] @@ -341,7 +347,7 @@ impl TransformChainBuilder { // Even though we don't keep the join handle, this thread will wrap up once all corresponding senders have been dropped. - let mut chain = self.build(); + let mut chain = self.build(context); let _jh = tokio::spawn( async move { while let Some(BufferedChainMessages { @@ -398,8 +404,12 @@ impl TransformChainBuilder { } /// Clone the chain while adding a producer for the pushed messages channel - pub fn build(&self) -> TransformChain { - let chain = self.chain.iter().map(|x| x.build()).collect(); + pub fn build(&self, context: TransformContextBuilder) -> TransformChain { + let chain = self + .chain + .iter() + .map(|x| x.build(context.clone())) + .collect(); TransformChain { name: self.name, @@ -414,12 +424,13 @@ impl TransformChainBuilder { pub fn build_with_pushed_messages( &self, pushed_messages_tx: mpsc::UnboundedSender, + context: TransformContextBuilder, ) -> TransformChain { let chain = self .chain .iter() .map(|x| { - let mut transform = x.build(); + let mut transform = x.build(context.clone()); transform .transform .set_pushed_messages_tx(pushed_messages_tx.clone()); diff --git a/shotover/src/transforms/coalesce.rs b/shotover/src/transforms/coalesce.rs index b2a0add1f..5a87b4424 100644 --- a/shotover/src/transforms/coalesce.rs +++ b/shotover/src/transforms/coalesce.rs @@ -1,4 +1,4 @@ -use super::TransformContextConfig; +use super::{TransformContextBuilder, TransformContextConfig}; use crate::message::Messages; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; @@ -39,7 +39,7 @@ impl TransformConfig for CoalesceConfig { } impl TransformBuilder for Coalesce { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) } diff --git a/shotover/src/transforms/debug/force_parse.rs b/shotover/src/transforms/debug/force_parse.rs index d71ff8e3f..28392be2e 100644 --- a/shotover/src/transforms/debug/force_parse.rs +++ b/shotover/src/transforms/debug/force_parse.rs @@ -7,6 +7,7 @@ use crate::message::Messages; /// It could also be used to ensure that messages round trip correctly when parsed. #[cfg(feature = "alpha-transforms")] use crate::transforms::TransformConfig; +use crate::transforms::TransformContextBuilder; #[cfg(feature = "alpha-transforms")] use crate::transforms::TransformContextConfig; use crate::transforms::{Transform, TransformBuilder, Wrapper}; @@ -76,7 +77,7 @@ pub struct DebugForceParse { } impl TransformBuilder for DebugForceParse { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) } diff --git a/shotover/src/transforms/debug/log_to_file.rs b/shotover/src/transforms/debug/log_to_file.rs index 9a4b707ff..2d25a4b63 100644 --- a/shotover/src/transforms/debug/log_to_file.rs +++ b/shotover/src/transforms/debug/log_to_file.rs @@ -1,5 +1,5 @@ use crate::message::{Encodable, Message}; -use crate::transforms::{Transform, TransformBuilder, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformContextBuilder, Wrapper}; use anyhow::{Context, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -35,7 +35,7 @@ pub struct DebugLogToFileBuilder { } impl TransformBuilder for DebugLogToFileBuilder { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { self.connection_counter.fetch_add(1, Ordering::Relaxed); let connection_current = self.connection_counter.load(Ordering::Relaxed); diff --git a/shotover/src/transforms/debug/printer.rs b/shotover/src/transforms/debug/printer.rs index 2ba8e8f67..0fa509b22 100644 --- a/shotover/src/transforms/debug/printer.rs +++ b/shotover/src/transforms/debug/printer.rs @@ -1,6 +1,7 @@ use crate::message::Messages; use crate::transforms::{ - Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, + Transform, TransformBuilder, TransformConfig, TransformContextBuilder, TransformContextConfig, + Wrapper, }; use anyhow::Result; use async_trait::async_trait; @@ -41,7 +42,7 @@ impl DebugPrinter { } impl TransformBuilder for DebugPrinter { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) } diff --git a/shotover/src/transforms/debug/returner.rs b/shotover/src/transforms/debug/returner.rs index 5f62acc68..4c9adff5c 100644 --- a/shotover/src/transforms/debug/returner.rs +++ b/shotover/src/transforms/debug/returner.rs @@ -1,6 +1,7 @@ use crate::message::{Message, Messages}; use crate::transforms::{ - Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, + Transform, TransformBuilder, TransformConfig, TransformContextBuilder, TransformContextConfig, + Wrapper, }; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -47,7 +48,7 @@ impl DebugReturner { } impl TransformBuilder for DebugReturner { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) } diff --git a/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs b/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs index efffb0025..00c3b7383 100644 --- a/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs +++ b/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs @@ -3,7 +3,8 @@ use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages, QueryType}; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; use crate::transforms::{ - Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, + Transform, TransformBuilder, TransformConfig, TransformContextBuilder, TransformContextConfig, + Wrapper, }; use anyhow::Result; use async_trait::async_trait; @@ -56,12 +57,12 @@ pub struct TuneableConsistencyScatterBuilder { } impl TransformBuilder for TuneableConsistencyScatterBuilder { - fn build(&self) -> Box { + fn build(&self, transform_context: TransformContextBuilder) -> Box { Box::new(TuneableConsistentencyScatter { route_map: self .route_map .iter() - .map(|x| x.build_buffered(10)) + .map(|x| x.build_buffered(10, transform_context.clone())) .collect(), write_consistency: self.write_consistency, read_consistency: self.read_consistency, @@ -301,7 +302,7 @@ mod scatter_transform_tests { TuneableConsistencyScatterBuilder, TuneableConsistentencyScatter, }; use crate::transforms::null::NullSink; - use crate::transforms::{Transform, TransformBuilder, Wrapper}; + use crate::transforms::{Transform, TransformBuilder, TransformContextBuilder, Wrapper}; use bytes::Bytes; use std::collections::HashMap; @@ -318,9 +319,10 @@ mod scatter_transform_tests { } fn build_chains(route_map: HashMap) -> Vec { + let context = TransformContextBuilder::new(); route_map .into_values() - .map(|x| x.build_buffered(10)) + .map(|x| x.build_buffered(10, context.clone())) .collect() } diff --git a/shotover/src/transforms/filter.rs b/shotover/src/transforms/filter.rs index 4c9e0b51f..51d9ac058 100644 --- a/shotover/src/transforms/filter.rs +++ b/shotover/src/transforms/filter.rs @@ -1,4 +1,4 @@ -use super::TransformContextConfig; +use super::{TransformContextBuilder, TransformContextConfig}; use crate::message::{Message, MessageIdMap, Messages, QueryType}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; @@ -41,7 +41,7 @@ impl TransformConfig for QueryTypeFilterConfig { } impl TransformBuilder for QueryTypeFilter { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) } diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs index f936c1134..80ee06f1e 100644 --- a/shotover/src/transforms/kafka/sink_cluster/mod.rs +++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs @@ -4,7 +4,7 @@ use crate::frame::Frame; use crate::message::{Message, Messages}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::util::{Request, Response}; -use crate::transforms::{Transform, TransformBuilder, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformContextBuilder, Wrapper}; use crate::transforms::{TransformConfig, TransformContextConfig}; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -143,7 +143,7 @@ impl KafkaSinkClusterBuilder { } impl TransformBuilder for KafkaSinkClusterBuilder { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(KafkaSinkCluster { first_contact_points: self.first_contact_points.clone(), shotover_nodes: self.shotover_nodes.clone(), diff --git a/shotover/src/transforms/kafka/sink_single.rs b/shotover/src/transforms/kafka/sink_single.rs index a3c3be49d..4eb41a995 100644 --- a/shotover/src/transforms/kafka/sink_single.rs +++ b/shotover/src/transforms/kafka/sink_single.rs @@ -7,7 +7,9 @@ use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::kafka::common::produce_channel; use crate::transforms::util::cluster_connection_pool::{spawn_read_write_tasks, Connection}; use crate::transforms::util::{Request, Response}; -use crate::transforms::{Transform, TransformBuilder, TransformContextConfig, Wrapper}; +use crate::transforms::{ + Transform, TransformBuilder, TransformContextBuilder, TransformContextConfig, Wrapper, +}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -76,7 +78,7 @@ impl KafkaSinkSingleBuilder { } impl TransformBuilder for KafkaSinkSingleBuilder { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(KafkaSinkSingle { outbound: None, address_port: self.address_port, diff --git a/shotover/src/transforms/load_balance.rs b/shotover/src/transforms/load_balance.rs index 9d82929f9..3e2597a29 100644 --- a/shotover/src/transforms/load_balance.rs +++ b/shotover/src/transforms/load_balance.rs @@ -1,4 +1,4 @@ -use super::TransformContextConfig; +use super::{TransformContextBuilder, TransformContextConfig}; use crate::config::chain::TransformChainConfig; use crate::message::Messages; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; @@ -43,12 +43,13 @@ pub struct ConnectionBalanceAndPoolBuilder { } impl TransformBuilder for ConnectionBalanceAndPoolBuilder { - fn build(&self) -> Box { + fn build(&self, transform_context: TransformContextBuilder) -> Box { Box::new(ConnectionBalanceAndPool { active_connection: None, max_connections: self.max_connections, all_connections: self.all_connections.clone(), chain_to_clone: self.chain_to_clone.clone(), + transform_context, }) } @@ -69,6 +70,7 @@ pub struct ConnectionBalanceAndPool { max_connections: usize, all_connections: Arc>>, chain_to_clone: Arc, + transform_context: TransformContextBuilder, } #[async_trait] @@ -81,7 +83,9 @@ impl Transform for ConnectionBalanceAndPool { if self.active_connection.is_none() { let mut all_connections = self.all_connections.lock().await; if all_connections.len() < self.max_connections { - let chain = self.chain_to_clone.build_buffered(5); + let chain = self + .chain_to_clone + .build_buffered(5, self.transform_context.clone()); self.active_connection = Some(chain.clone()); all_connections.push(chain); } else { diff --git a/shotover/src/transforms/loopback.rs b/shotover/src/transforms/loopback.rs index 7345a2832..5ba587b85 100644 --- a/shotover/src/transforms/loopback.rs +++ b/shotover/src/transforms/loopback.rs @@ -3,13 +3,15 @@ use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::Result; use async_trait::async_trait; +use super::TransformContextBuilder; + const NAME: &str = "Loopback"; #[derive(Debug, Clone, Default)] pub struct Loopback {} impl TransformBuilder for Loopback { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) } diff --git a/shotover/src/transforms/mod.rs b/shotover/src/transforms/mod.rs index 034d8cb09..1fec11dc5 100644 --- a/shotover/src/transforms/mod.rs +++ b/shotover/src/transforms/mod.rs @@ -12,7 +12,8 @@ use std::iter::Rev; use std::net::SocketAddr; use std::pin::Pin; use std::slice::IterMut; -use tokio::sync::mpsc; +use std::sync::Arc; +use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; #[cfg(feature = "cassandra")] @@ -42,8 +43,30 @@ pub mod tee; pub mod throttling; pub mod util; +/// Provides extra context that may be needed when creating a Transform +#[derive(Clone, Debug)] +pub struct TransformContextBuilder { + /// The chain is run naturally whenever messages are received from the client. + /// However, for various reasons, a transform may want to force a chain run. + /// This can be done by calling `notify_one` on this field. + /// + /// For example: + /// * This must be used when a sink transform has asynchronously received responses in the background + /// * This should be used when a transform needs to generate or flush messages after some kind of timeout or background process completes. + pub force_run_chain: Arc, +} + +#[allow(clippy::new_without_default)] +impl TransformContextBuilder { + pub fn new() -> Self { + TransformContextBuilder { + force_run_chain: Arc::new(Notify::new()), + } + } +} + pub trait TransformBuilder: Send + Sync { - fn build(&self) -> Box; + fn build(&self, transform_context: TransformContextBuilder) -> Box; fn get_name(&self) -> &'static str; diff --git a/shotover/src/transforms/null.rs b/shotover/src/transforms/null.rs index 0c46e6a09..cb2f7aaab 100644 --- a/shotover/src/transforms/null.rs +++ b/shotover/src/transforms/null.rs @@ -1,4 +1,4 @@ -use super::TransformContextConfig; +use super::{TransformContextBuilder, TransformContextConfig}; use crate::message::Messages; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; @@ -25,7 +25,7 @@ impl TransformConfig for NullSinkConfig { pub struct NullSink {} impl TransformBuilder for NullSink { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) } diff --git a/shotover/src/transforms/opensearch/mod.rs b/shotover/src/transforms/opensearch/mod.rs index 6464d29c5..0682d1693 100644 --- a/shotover/src/transforms/opensearch/mod.rs +++ b/shotover/src/transforms/opensearch/mod.rs @@ -1,4 +1,4 @@ -use super::TransformContextConfig; +use super::{TransformContextBuilder, TransformContextConfig}; use crate::tcp; use crate::transforms::{Messages, Transform, TransformBuilder, TransformConfig, Wrapper}; use crate::{ @@ -56,7 +56,7 @@ impl OpenSearchSinkSingleBuilder { } impl TransformBuilder for OpenSearchSinkSingleBuilder { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(OpenSearchSinkSingle { address: self.address.clone(), connect_timeout: self.connect_timeout, diff --git a/shotover/src/transforms/parallel_map.rs b/shotover/src/transforms/parallel_map.rs index a716fd7e6..ef89bb512 100644 --- a/shotover/src/transforms/parallel_map.rs +++ b/shotover/src/transforms/parallel_map.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use std::future::Future; use std::pin::Pin; -use super::TransformContextConfig; +use super::{TransformContextBuilder, TransformContextConfig}; #[derive(Debug)] pub struct ParallelMapBuilder { @@ -132,9 +132,13 @@ impl Transform for ParallelMap { } impl TransformBuilder for ParallelMapBuilder { - fn build(&self) -> Box { + fn build(&self, transform_context: TransformContextBuilder) -> Box { Box::new(ParallelMap { - chains: self.chains.iter().map(|x| x.build()).collect(), + chains: self + .chains + .iter() + .map(|x| x.build(transform_context.clone())) + .collect(), ordered: self.ordered, }) } diff --git a/shotover/src/transforms/protect/mod.rs b/shotover/src/transforms/protect/mod.rs index f2041daac..35005f879 100644 --- a/shotover/src/transforms/protect/mod.rs +++ b/shotover/src/transforms/protect/mod.rs @@ -14,6 +14,8 @@ use cql3_parser::select::SelectElement; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use super::TransformContextBuilder; + mod aws_kms; mod crypto; mod key_management; @@ -73,7 +75,7 @@ pub struct Protect { } impl TransformBuilder for Protect { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) } diff --git a/shotover/src/transforms/query_counter.rs b/shotover/src/transforms/query_counter.rs index 575ee9d72..f77f274ca 100644 --- a/shotover/src/transforms/query_counter.rs +++ b/shotover/src/transforms/query_counter.rs @@ -1,6 +1,7 @@ use crate::frame::Frame; use crate::message::Messages; use crate::transforms::TransformConfig; +use crate::transforms::TransformContextBuilder; use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::Result; use async_trait::async_trait; @@ -30,7 +31,7 @@ impl QueryCounter { } impl TransformBuilder for QueryCounter { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) } diff --git a/shotover/src/transforms/redis/cache.rs b/shotover/src/transforms/redis/cache.rs index b05643f5e..337387099 100644 --- a/shotover/src/transforms/redis/cache.rs +++ b/shotover/src/transforms/redis/cache.rs @@ -3,7 +3,8 @@ use crate::frame::{CassandraFrame, CassandraOperation, Frame, MessageType, Redis use crate::message::{Message, MessageIdMap, Messages, Metadata}; use crate::transforms::chain::{TransformChain, TransformChainBuilder}; use crate::transforms::{ - Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, + Transform, TransformBuilder, TransformConfig, TransformContextBuilder, TransformContextConfig, + Wrapper, }; use anyhow::{bail, Result}; use async_trait::async_trait; @@ -119,9 +120,9 @@ pub struct SimpleRedisCacheBuilder { } impl TransformBuilder for SimpleRedisCacheBuilder { - fn build(&self) -> Box { + fn build(&self, transform_context: TransformContextBuilder) -> Box { Box::new(SimpleRedisCache { - cache_chain: self.cache_chain.build(), + cache_chain: self.cache_chain.build(transform_context.clone()), caching_schema: self.caching_schema.clone(), missed_requests: self.missed_requests.clone(), pending_cache_requests: Default::default(), diff --git a/shotover/src/transforms/redis/cluster_ports_rewrite.rs b/shotover/src/transforms/redis/cluster_ports_rewrite.rs index 0e704cfcf..2e74233fc 100644 --- a/shotover/src/transforms/redis/cluster_ports_rewrite.rs +++ b/shotover/src/transforms/redis/cluster_ports_rewrite.rs @@ -1,6 +1,7 @@ use crate::frame::Frame; use crate::frame::RedisFrame; use crate::message::{MessageIdMap, Messages}; +use crate::transforms::TransformContextBuilder; use crate::transforms::TransformContextConfig; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, bail, Context, Result}; @@ -28,7 +29,7 @@ impl TransformConfig for RedisClusterPortsRewriteConfig { } impl TransformBuilder for RedisClusterPortsRewrite { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) } diff --git a/shotover/src/transforms/redis/sink_cluster.rs b/shotover/src/transforms/redis/sink_cluster.rs index 5e423dcf8..6f0cff019 100644 --- a/shotover/src/transforms/redis/sink_cluster.rs +++ b/shotover/src/transforms/redis/sink_cluster.rs @@ -8,7 +8,8 @@ use crate::transforms::redis::TransformError; use crate::transforms::util::cluster_connection_pool::{Authenticator, ConnectionPool}; use crate::transforms::util::{Request, Response}; use crate::transforms::{ - ResponseFuture, Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, + ResponseFuture, Transform, TransformBuilder, TransformConfig, TransformContextBuilder, + TransformContextConfig, Wrapper, }; use anyhow::{anyhow, bail, ensure, Context, Result}; use async_trait::async_trait; @@ -81,7 +82,7 @@ pub struct RedisSinkClusterBuilder { } impl TransformBuilder for RedisSinkClusterBuilder { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(RedisSinkCluster::new( self.first_contact_points.clone(), self.direct_destination.clone(), diff --git a/shotover/src/transforms/redis/sink_single.rs b/shotover/src/transforms/redis/sink_single.rs index 5329127bf..78af0788a 100644 --- a/shotover/src/transforms/redis/sink_single.rs +++ b/shotover/src/transforms/redis/sink_single.rs @@ -2,7 +2,9 @@ use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages}; use crate::tcp; use crate::tls::{AsyncStream, TlsConnector, TlsConnectorConfig}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::transforms::{ + Transform, TransformBuilder, TransformConfig, TransformContextBuilder, Wrapper, +}; use crate::{ codec::{ redis::{RedisCodecBuilder, RedisDecoder, RedisEncoder}, @@ -78,7 +80,7 @@ impl RedisSinkSingleBuilder { } impl TransformBuilder for RedisSinkSingleBuilder { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(RedisSinkSingle { address: self.address.clone(), tls: self.tls.clone(), diff --git a/shotover/src/transforms/redis/timestamp_tagging.rs b/shotover/src/transforms/redis/timestamp_tagging.rs index 2716c5591..a370c2f2a 100644 --- a/shotover/src/transforms/redis/timestamp_tagging.rs +++ b/shotover/src/transforms/redis/timestamp_tagging.rs @@ -2,7 +2,8 @@ use crate::frame::redis::redis_query_type; use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages, QueryType}; use crate::transforms::{ - Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, + Transform, TransformBuilder, TransformConfig, TransformContextBuilder, TransformContextConfig, + Wrapper, }; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -39,7 +40,7 @@ impl RedisTimestampTagger { } impl TransformBuilder for RedisTimestampTagger { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) } diff --git a/shotover/src/transforms/tee.rs b/shotover/src/transforms/tee.rs index 4ebaedaaa..97fff9f85 100644 --- a/shotover/src/transforms/tee.rs +++ b/shotover/src/transforms/tee.rs @@ -1,4 +1,4 @@ -use super::TransformContextConfig; +use super::{TransformContextBuilder, TransformContextConfig}; use crate::config::chain::TransformChainConfig; use crate::message::{Message, MessageIdMap, Messages}; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; @@ -68,9 +68,11 @@ impl TeeBuilder { } impl TransformBuilder for TeeBuilder { - fn build(&self) -> Box { + fn build(&self, transform_context: TransformContextBuilder) -> Box { Box::new(Tee { - tx: self.tx.build_buffered(self.buffer_size), + tx: self + .tx + .build_buffered(self.buffer_size, transform_context.clone()), behavior: match &self.behavior { ConsistencyBehaviorBuilder::Ignore => ConsistencyBehavior::Ignore, ConsistencyBehaviorBuilder::LogWarningOnMismatch => { @@ -79,7 +81,7 @@ impl TransformBuilder for TeeBuilder { ConsistencyBehaviorBuilder::FailOnMismatch => ConsistencyBehavior::FailOnMismatch, ConsistencyBehaviorBuilder::SubchainOnMismatch(chain) => { ConsistencyBehavior::SubchainOnMismatch( - chain.build_buffered(self.buffer_size), + chain.build_buffered(self.buffer_size, transform_context), Default::default(), ) } diff --git a/shotover/src/transforms/throttling.rs b/shotover/src/transforms/throttling.rs index 1d0f4845a..99d3422f9 100644 --- a/shotover/src/transforms/throttling.rs +++ b/shotover/src/transforms/throttling.rs @@ -1,4 +1,4 @@ -use super::TransformContextConfig; +use super::{TransformContextBuilder, TransformContextConfig}; use crate::message::{Message, MessageIdMap, Messages}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; @@ -46,7 +46,7 @@ pub struct RequestThrottling { } impl TransformBuilder for RequestThrottling { - fn build(&self) -> Box { + fn build(&self, _transform_context: TransformContextBuilder) -> Box { Box::new(self.clone()) }