Skip to content

Commit

Permalink
Introduce TransformContextConfig (#1490)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Feb 23, 2024
1 parent 74adab8 commit a8ba522
Show file tree
Hide file tree
Showing 30 changed files with 272 additions and 83 deletions.
9 changes: 7 additions & 2 deletions custom-transforms-example/src/redis_get_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use shotover::frame::{Frame, RedisFrame};
use shotover::message::{MessageIdSet, Messages};
use shotover::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use shotover::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
Expand All @@ -15,7 +17,10 @@ const NAME: &str = "RedisGetRewrite";
#[typetag::serde(name = "RedisGetRewrite")]
#[async_trait(?Send)]
impl TransformConfig for RedisGetRewriteConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(RedisGetRewriteBuilder {
result: self.result.clone(),
}))
Expand Down
14 changes: 10 additions & 4 deletions shotover/benches/benches/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use criterion::{criterion_group, BatchSize, Criterion};
use hex_literal::hex;
use shotover::codec::CodecState;
use shotover::frame::cassandra::{parse_statement_single, Tracing};
use shotover::frame::RedisFrame;
use shotover::frame::{CassandraFrame, CassandraOperation, Frame};
use shotover::frame::{MessageType, RedisFrame};
use shotover::message::{Message, MessageIdMap, QueryType};
use shotover::transforms::cassandra::peers_rewrite::CassandraPeersRewrite;
use shotover::transforms::chain::{TransformChain, TransformChainBuilder};
Expand All @@ -19,7 +19,7 @@ use shotover::transforms::protect::{KeyManagerConfig, ProtectConfig};
use shotover::transforms::redis::cluster_ports_rewrite::RedisClusterPortsRewrite;
use shotover::transforms::redis::timestamp_tagging::RedisTimestampTagger;
use shotover::transforms::throttling::RequestThrottlingConfig;
use shotover::transforms::{TransformConfig, Wrapper};
use shotover::transforms::{TransformConfig, TransformContextConfig, Wrapper};

fn criterion_benchmark(c: &mut Criterion) {
crate::init();
Expand Down Expand Up @@ -194,7 +194,10 @@ fn criterion_benchmark(c: &mut Criterion) {
// an absurdly large value is given so that all messages will pass through
max_requests_per_second: std::num::NonZeroU32::new(100_000_000).unwrap(),
}
.get_builder("".to_owned()),
.get_builder(TransformContextConfig {
chain_name: "".into(),
protocol: MessageType::Redis,
}),
)
.unwrap(),
Box::<NullSink>::default(),
Expand Down Expand Up @@ -303,7 +306,10 @@ fn criterion_benchmark(c: &mut Criterion) {
kek_id: "".to_string(),
},
}
.get_builder("".to_owned()),
.get_builder(TransformContextConfig {
chain_name: "".into(),
protocol: MessageType::Redis,
}),
)
.unwrap(),
Box::<NullSink>::default(),
Expand Down
14 changes: 10 additions & 4 deletions shotover/src/config/chain.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::transforms::chain::TransformChainBuilder;
use crate::transforms::{TransformBuilder, TransformConfig};
use crate::transforms::{TransformBuilder, TransformConfig, TransformContextConfig};
use anyhow::Result;
use serde::de::{DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor};
use serde::{Deserialize, Serialize};
Expand All @@ -14,12 +14,18 @@ pub struct TransformChainConfig(
);

impl TransformChainConfig {
pub async fn get_builder(&self, name: String) -> Result<TransformChainBuilder> {
pub async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<TransformChainBuilder> {
let mut transforms: Vec<Box<dyn TransformBuilder>> = Vec::new();
for tc in &self.0 {
transforms.push(tc.get_builder(name.clone()).await?)
transforms.push(tc.get_builder(transform_context.clone()).await?)
}
Ok(TransformChainBuilder::new(transforms, name.leak()))
Ok(TransformChainBuilder::new(
transforms,
transform_context.chain_name.leak(),
))
}
}

Expand Down
8 changes: 6 additions & 2 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::message::{Message, Messages};
use crate::sources::Transport;
use crate::tls::{AcceptError, TlsAcceptor};
use crate::transforms::chain::{TransformChain, TransformChainBuilder};
use crate::transforms::Wrapper;
use crate::transforms::{TransformContextConfig, Wrapper};
use anyhow::{anyhow, Context, Result};
use bytes::BytesMut;
use futures::future::join_all;
Expand Down Expand Up @@ -92,8 +92,12 @@ impl<C: CodecBuilder + 'static> TcpCodecListener<C> {
gauge!("shotover_available_connections_count", "source" => source_name.clone());
available_connections_gauge.set(limit_connections.available_permits() as f64);

let chain_usage_config = TransformContextConfig {
chain_name: source_name.clone(),
protocol: codec.protocol(),
};
let chain_builder = chain_config
.get_builder(source_name.clone())
.get_builder(chain_usage_config)
.await
.map_err(|x| vec![format!("{x:?}")])?;

Expand Down
16 changes: 11 additions & 5 deletions shotover/src/transforms/cassandra/peers_rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use crate::frame::{
value::{GenericValue, IntSize},
CassandraOperation, CassandraResult, Frame,
};
use crate::message::{Message, Messages};
use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::{
frame::{
value::{GenericValue, IntSize},
CassandraOperation, CassandraResult, Frame,
},
transforms::TransformContextConfig,
};
use anyhow::Result;
use async_trait::async_trait;
use cassandra_protocol::frame::events::{ServerEvent, StatusChange};
Expand All @@ -23,7 +26,10 @@ const NAME: &str = "CassandraPeersRewrite";
#[typetag::serde(name = "CassandraPeersRewrite")]
#[async_trait(?Send)]
impl TransformConfig for CassandraPeersRewriteConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(CassandraPeersRewrite::new(self.port)))
}
}
Expand Down
11 changes: 8 additions & 3 deletions shotover/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame};
use crate::message::{Message, MessageIdMap, Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::cassandra::connection::{CassandraConnection, Response, ResponseError};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use cassandra_protocol::events::ServerEvent;
Expand Down Expand Up @@ -66,7 +68,10 @@ const NAME: &str = "CassandraSinkCluster";
#[typetag::serde(name = "CassandraSinkCluster")]
#[async_trait(?Send)]
impl TransformConfig for CassandraSinkClusterConfig {
async fn get_builder(&self, chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
let tls = self.tls.clone().map(TlsConnector::new).transpose()?;
let mut shotover_nodes = self.shotover_nodes.clone();
let index = self
Expand All @@ -84,7 +89,7 @@ impl TransformConfig for CassandraSinkClusterConfig {
Ok(Box::new(CassandraSinkClusterBuilder::new(
self.first_contact_points.clone(),
shotover_nodes,
chain_name,
transform_context.chain_name,
local_node,
tls,
self.connect_timeout_ms,
Expand Down
11 changes: 8 additions & 3 deletions shotover/src/transforms/cassandra/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use crate::frame::cassandra::CassandraMetadata;
use crate::message::{Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::cassandra::connection::Response;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use cassandra_protocol::frame::Version;
Expand All @@ -29,11 +31,14 @@ const NAME: &str = "CassandraSinkSingle";
#[typetag::serde(name = "CassandraSinkSingle")]
#[async_trait(?Send)]
impl TransformConfig for CassandraSinkSingleConfig {
async fn get_builder(&self, chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
let tls = self.tls.clone().map(TlsConnector::new).transpose()?;
Ok(Box::new(CassandraSinkSingleBuilder::new(
self.address.clone(),
chain_name,
transform_context.chain_name,
tls,
self.connect_timeout_ms,
self.read_timeout,
Expand Down
6 changes: 5 additions & 1 deletion shotover/src/transforms/coalesce.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::TransformContextConfig;
use crate::message::Messages;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::Result;
Expand All @@ -24,7 +25,10 @@ const NAME: &str = "Coalesce";
#[typetag::serde(name = "Coalesce")]
#[async_trait(?Send)]
impl TransformConfig for CoalesceConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(Coalesce {
buffer: Vec::with_capacity(self.flush_when_buffered_message_count.unwrap_or(0)),
flush_when_buffered_message_count: self.flush_when_buffered_message_count,
Expand Down
12 changes: 10 additions & 2 deletions shotover/src/transforms/debug/force_parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use crate::message::Messages;
/// It could also be used to ensure that messages round trip correctly when parsed.
#[cfg(feature = "alpha-transforms")]
use crate::transforms::TransformConfig;
#[cfg(feature = "alpha-transforms")]
use crate::transforms::TransformContextConfig;
use crate::transforms::{Transform, TransformBuilder, Wrapper};
use anyhow::Result;
use async_trait::async_trait;
Expand All @@ -25,7 +27,10 @@ pub struct DebugForceParseConfig {
#[typetag::serde(name = "DebugForceParse")]
#[async_trait(?Send)]
impl TransformConfig for DebugForceParseConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(DebugForceParse {
parse_requests: self.parse_requests,
parse_responses: self.parse_responses,
Expand All @@ -49,7 +54,10 @@ const NAME: &str = "DebugForceEncode";
#[typetag::serde(name = "DebugForceEncode")]
#[async_trait(?Send)]
impl TransformConfig for DebugForceEncodeConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(DebugForceParse {
parse_requests: self.encode_requests,
parse_responses: self.encode_responses,
Expand Down
5 changes: 4 additions & 1 deletion shotover/src/transforms/debug/log_to_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ const NAME: &str = "DebugLogToFile";
#[typetag::serde(name = "DebugLogToFile")]
#[async_trait(?Send)]
impl crate::transforms::TransformConfig for DebugLogToFileConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: crate::transforms::TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
// This transform is used for debugging a specific run, so we clean out any logs left over from the previous run
std::fs::remove_dir_all("message-log").ok();

Expand Down
9 changes: 7 additions & 2 deletions shotover/src/transforms/debug/printer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::message::Messages;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
Expand All @@ -13,7 +15,10 @@ const NAME: &str = "DebugPrinter";
#[typetag::serde(name = "DebugPrinter")]
#[async_trait(?Send)]
impl TransformConfig for DebugPrinterConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(DebugPrinter::new()))
}
}
Expand Down
9 changes: 7 additions & 2 deletions shotover/src/transforms/debug/returner.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::message::{Message, Messages};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
Expand All @@ -15,7 +17,10 @@ const NAME: &str = "DebugReturner";
#[typetag::serde(name = "DebugReturner")]
#[async_trait(?Send)]
impl TransformConfig for DebugReturnerConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(DebugReturner::new(self.response.clone())))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::config::chain::TransformChainConfig;
use crate::frame::{Frame, RedisFrame};
use crate::message::{Message, Messages, QueryType};
use crate::transforms::chain::{BufferedChain, TransformChainBuilder};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::FuturesUnordered;
Expand All @@ -23,12 +25,19 @@ const NAME: &str = "TuneableConsistencyScatter";
#[typetag::serde(name = "TuneableConsistencyScatter")]
#[async_trait(?Send)]
impl TransformConfig for TuneableConsistencyScatterConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
let mut route_map = Vec::with_capacity(self.route_map.len());
warn!("Using this transform is considered unstable - Does not work with REDIS pipelines");

for (key, value) in &self.route_map {
route_map.push(value.get_builder(key.clone()).await?);
let chain_config = TransformContextConfig {
chain_name: key.clone(),
protocol: transform_context.protocol,
};
route_map.push(value.get_builder(chain_config).await?);
}
route_map.sort_by_key(|x| x.name);

Expand Down
6 changes: 5 additions & 1 deletion shotover/src/transforms/filter.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::TransformContextConfig;
use crate::message::{Message, MessageIdMap, Messages, QueryType};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::Result;
Expand Down Expand Up @@ -28,7 +29,10 @@ const NAME: &str = "QueryTypeFilter";
#[typetag::serde(name = "QueryTypeFilter")]
#[async_trait(?Send)]
impl TransformConfig for QueryTypeFilterConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(QueryTypeFilter {
filter: self.filter.clone(),
filtered_requests: MessageIdMap::default(),
Expand Down
9 changes: 6 additions & 3 deletions shotover/src/transforms/kafka/sink_cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::message::{Message, Messages};
use crate::tcp;
use crate::transforms::util::cluster_connection_pool::{spawn_read_write_tasks, Connection};
use crate::transforms::util::{Request, Response};
use crate::transforms::TransformConfig;
use crate::transforms::{Transform, TransformBuilder, Wrapper};
use crate::transforms::{TransformConfig, TransformContextConfig};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use dashmap::DashMap;
Expand Down Expand Up @@ -45,11 +45,14 @@ const NAME: &str = "KafkaSinkCluster";
#[typetag::serde(name = "KafkaSinkCluster")]
#[async_trait(?Send)]
impl TransformConfig for KafkaSinkClusterConfig {
async fn get_builder(&self, chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(KafkaSinkClusterBuilder::new(
self.first_contact_points.clone(),
self.shotover_nodes.clone(),
chain_name,
transform_context.chain_name,
self.connect_timeout_ms,
self.read_timeout,
)))
Expand Down
Loading

0 comments on commit a8ba522

Please sign in to comment.