Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce TransformContextConfig #1490

Merged
merged 1 commit into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions custom-transforms-example/src/redis_get_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use shotover::frame::{Frame, RedisFrame};
use shotover::message::{MessageIdSet, Messages};
use shotover::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use shotover::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
Expand All @@ -15,7 +17,10 @@ const NAME: &str = "RedisGetRewrite";
#[typetag::serde(name = "RedisGetRewrite")]
#[async_trait(?Send)]
impl TransformConfig for RedisGetRewriteConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(RedisGetRewriteBuilder {
result: self.result.clone(),
}))
Expand Down
14 changes: 10 additions & 4 deletions shotover/benches/benches/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use criterion::{criterion_group, BatchSize, Criterion};
use hex_literal::hex;
use shotover::codec::CodecState;
use shotover::frame::cassandra::{parse_statement_single, Tracing};
use shotover::frame::RedisFrame;
use shotover::frame::{CassandraFrame, CassandraOperation, Frame};
use shotover::frame::{MessageType, RedisFrame};
use shotover::message::{Message, MessageIdMap, QueryType};
use shotover::transforms::cassandra::peers_rewrite::CassandraPeersRewrite;
use shotover::transforms::chain::{TransformChain, TransformChainBuilder};
Expand All @@ -19,7 +19,7 @@ use shotover::transforms::protect::{KeyManagerConfig, ProtectConfig};
use shotover::transforms::redis::cluster_ports_rewrite::RedisClusterPortsRewrite;
use shotover::transforms::redis::timestamp_tagging::RedisTimestampTagger;
use shotover::transforms::throttling::RequestThrottlingConfig;
use shotover::transforms::{TransformConfig, Wrapper};
use shotover::transforms::{TransformConfig, TransformContextConfig, Wrapper};

fn criterion_benchmark(c: &mut Criterion) {
crate::init();
Expand Down Expand Up @@ -194,7 +194,10 @@ fn criterion_benchmark(c: &mut Criterion) {
// an absurdly large value is given so that all messages will pass through
max_requests_per_second: std::num::NonZeroU32::new(100_000_000).unwrap(),
}
.get_builder("".to_owned()),
.get_builder(TransformContextConfig {
chain_name: "".into(),
protocol: MessageType::Redis,
}),
)
.unwrap(),
Box::<NullSink>::default(),
Expand Down Expand Up @@ -303,7 +306,10 @@ fn criterion_benchmark(c: &mut Criterion) {
kek_id: "".to_string(),
},
}
.get_builder("".to_owned()),
.get_builder(TransformContextConfig {
chain_name: "".into(),
protocol: MessageType::Redis,
}),
)
.unwrap(),
Box::<NullSink>::default(),
Expand Down
14 changes: 10 additions & 4 deletions shotover/src/config/chain.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::transforms::chain::TransformChainBuilder;
use crate::transforms::{TransformBuilder, TransformConfig};
use crate::transforms::{TransformBuilder, TransformConfig, TransformContextConfig};
use anyhow::Result;
use serde::de::{DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor};
use serde::{Deserialize, Serialize};
Expand All @@ -14,12 +14,18 @@ pub struct TransformChainConfig(
);

impl TransformChainConfig {
pub async fn get_builder(&self, name: String) -> Result<TransformChainBuilder> {
pub async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<TransformChainBuilder> {
let mut transforms: Vec<Box<dyn TransformBuilder>> = Vec::new();
for tc in &self.0 {
transforms.push(tc.get_builder(name.clone()).await?)
transforms.push(tc.get_builder(transform_context.clone()).await?)
}
Ok(TransformChainBuilder::new(transforms, name.leak()))
Ok(TransformChainBuilder::new(
transforms,
transform_context.chain_name.leak(),
))
}
}

Expand Down
8 changes: 6 additions & 2 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::message::{Message, Messages};
use crate::sources::Transport;
use crate::tls::{AcceptError, TlsAcceptor};
use crate::transforms::chain::{TransformChain, TransformChainBuilder};
use crate::transforms::Wrapper;
use crate::transforms::{TransformContextConfig, Wrapper};
use anyhow::{anyhow, Context, Result};
use bytes::BytesMut;
use futures::future::join_all;
Expand Down Expand Up @@ -92,8 +92,12 @@ impl<C: CodecBuilder + 'static> TcpCodecListener<C> {
gauge!("shotover_available_connections_count", "source" => source_name.clone());
available_connections_gauge.set(limit_connections.available_permits() as f64);

let chain_usage_config = TransformContextConfig {
chain_name: source_name.clone(),
protocol: codec.protocol(),
};
let chain_builder = chain_config
.get_builder(source_name.clone())
.get_builder(chain_usage_config)
.await
.map_err(|x| vec![format!("{x:?}")])?;

Expand Down
16 changes: 11 additions & 5 deletions shotover/src/transforms/cassandra/peers_rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use crate::frame::{
value::{GenericValue, IntSize},
CassandraOperation, CassandraResult, Frame,
};
use crate::message::{Message, Messages};
use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::{
frame::{
value::{GenericValue, IntSize},
CassandraOperation, CassandraResult, Frame,
},
transforms::TransformContextConfig,
};
use anyhow::Result;
use async_trait::async_trait;
use cassandra_protocol::frame::events::{ServerEvent, StatusChange};
Expand All @@ -23,7 +26,10 @@ const NAME: &str = "CassandraPeersRewrite";
#[typetag::serde(name = "CassandraPeersRewrite")]
#[async_trait(?Send)]
impl TransformConfig for CassandraPeersRewriteConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(CassandraPeersRewrite::new(self.port)))
}
}
Expand Down
11 changes: 8 additions & 3 deletions shotover/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame};
use crate::message::{Message, MessageIdMap, Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::cassandra::connection::{CassandraConnection, Response, ResponseError};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use cassandra_protocol::events::ServerEvent;
Expand Down Expand Up @@ -66,7 +68,10 @@ const NAME: &str = "CassandraSinkCluster";
#[typetag::serde(name = "CassandraSinkCluster")]
#[async_trait(?Send)]
impl TransformConfig for CassandraSinkClusterConfig {
async fn get_builder(&self, chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
let tls = self.tls.clone().map(TlsConnector::new).transpose()?;
let mut shotover_nodes = self.shotover_nodes.clone();
let index = self
Expand All @@ -84,7 +89,7 @@ impl TransformConfig for CassandraSinkClusterConfig {
Ok(Box::new(CassandraSinkClusterBuilder::new(
self.first_contact_points.clone(),
shotover_nodes,
chain_name,
transform_context.chain_name,
local_node,
tls,
self.connect_timeout_ms,
Expand Down
11 changes: 8 additions & 3 deletions shotover/src/transforms/cassandra/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use crate::frame::cassandra::CassandraMetadata;
use crate::message::{Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::cassandra::connection::Response;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use cassandra_protocol::frame::Version;
Expand All @@ -29,11 +31,14 @@ const NAME: &str = "CassandraSinkSingle";
#[typetag::serde(name = "CassandraSinkSingle")]
#[async_trait(?Send)]
impl TransformConfig for CassandraSinkSingleConfig {
async fn get_builder(&self, chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
let tls = self.tls.clone().map(TlsConnector::new).transpose()?;
Ok(Box::new(CassandraSinkSingleBuilder::new(
self.address.clone(),
chain_name,
transform_context.chain_name,
tls,
self.connect_timeout_ms,
self.read_timeout,
Expand Down
6 changes: 5 additions & 1 deletion shotover/src/transforms/coalesce.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::TransformContextConfig;
use crate::message::Messages;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::Result;
Expand All @@ -24,7 +25,10 @@ const NAME: &str = "Coalesce";
#[typetag::serde(name = "Coalesce")]
#[async_trait(?Send)]
impl TransformConfig for CoalesceConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(Coalesce {
buffer: Vec::with_capacity(self.flush_when_buffered_message_count.unwrap_or(0)),
flush_when_buffered_message_count: self.flush_when_buffered_message_count,
Expand Down
12 changes: 10 additions & 2 deletions shotover/src/transforms/debug/force_parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use crate::message::Messages;
/// It could also be used to ensure that messages round trip correctly when parsed.
#[cfg(feature = "alpha-transforms")]
use crate::transforms::TransformConfig;
#[cfg(feature = "alpha-transforms")]
use crate::transforms::TransformContextConfig;
use crate::transforms::{Transform, TransformBuilder, Wrapper};
use anyhow::Result;
use async_trait::async_trait;
Expand All @@ -25,7 +27,10 @@ pub struct DebugForceParseConfig {
#[typetag::serde(name = "DebugForceParse")]
#[async_trait(?Send)]
impl TransformConfig for DebugForceParseConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(DebugForceParse {
parse_requests: self.parse_requests,
parse_responses: self.parse_responses,
Expand All @@ -49,7 +54,10 @@ const NAME: &str = "DebugForceEncode";
#[typetag::serde(name = "DebugForceEncode")]
#[async_trait(?Send)]
impl TransformConfig for DebugForceEncodeConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(DebugForceParse {
parse_requests: self.encode_requests,
parse_responses: self.encode_responses,
Expand Down
5 changes: 4 additions & 1 deletion shotover/src/transforms/debug/log_to_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ const NAME: &str = "DebugLogToFile";
#[typetag::serde(name = "DebugLogToFile")]
#[async_trait(?Send)]
impl crate::transforms::TransformConfig for DebugLogToFileConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: crate::transforms::TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
// This transform is used for debugging a specific run, so we clean out any logs left over from the previous run
std::fs::remove_dir_all("message-log").ok();

Expand Down
9 changes: 7 additions & 2 deletions shotover/src/transforms/debug/printer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::message::Messages;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
Expand All @@ -13,7 +15,10 @@ const NAME: &str = "DebugPrinter";
#[typetag::serde(name = "DebugPrinter")]
#[async_trait(?Send)]
impl TransformConfig for DebugPrinterConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(DebugPrinter::new()))
}
}
Expand Down
9 changes: 7 additions & 2 deletions shotover/src/transforms/debug/returner.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::message::{Message, Messages};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
Expand All @@ -15,7 +17,10 @@ const NAME: &str = "DebugReturner";
#[typetag::serde(name = "DebugReturner")]
#[async_trait(?Send)]
impl TransformConfig for DebugReturnerConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(DebugReturner::new(self.response.clone())))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::config::chain::TransformChainConfig;
use crate::frame::{Frame, RedisFrame};
use crate::message::{Message, Messages, QueryType};
use crate::transforms::chain::{BufferedChain, TransformChainBuilder};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::FuturesUnordered;
Expand All @@ -23,12 +25,19 @@ const NAME: &str = "TuneableConsistencyScatter";
#[typetag::serde(name = "TuneableConsistencyScatter")]
#[async_trait(?Send)]
impl TransformConfig for TuneableConsistencyScatterConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
let mut route_map = Vec::with_capacity(self.route_map.len());
warn!("Using this transform is considered unstable - Does not work with REDIS pipelines");

for (key, value) in &self.route_map {
route_map.push(value.get_builder(key.clone()).await?);
let chain_config = TransformContextConfig {
chain_name: key.clone(),
protocol: transform_context.protocol,
};
route_map.push(value.get_builder(chain_config).await?);
}
route_map.sort_by_key(|x| x.name);

Expand Down
6 changes: 5 additions & 1 deletion shotover/src/transforms/filter.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::TransformContextConfig;
use crate::message::{Message, MessageIdMap, Messages, QueryType};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::Result;
Expand Down Expand Up @@ -28,7 +29,10 @@ const NAME: &str = "QueryTypeFilter";
#[typetag::serde(name = "QueryTypeFilter")]
#[async_trait(?Send)]
impl TransformConfig for QueryTypeFilterConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(QueryTypeFilter {
filter: self.filter.clone(),
filtered_requests: MessageIdMap::default(),
Expand Down
9 changes: 6 additions & 3 deletions shotover/src/transforms/kafka/sink_cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::message::{Message, Messages};
use crate::tcp;
use crate::transforms::util::cluster_connection_pool::{spawn_read_write_tasks, Connection};
use crate::transforms::util::{Request, Response};
use crate::transforms::TransformConfig;
use crate::transforms::{Transform, TransformBuilder, Wrapper};
use crate::transforms::{TransformConfig, TransformContextConfig};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use dashmap::DashMap;
Expand Down Expand Up @@ -45,11 +45,14 @@ const NAME: &str = "KafkaSinkCluster";
#[typetag::serde(name = "KafkaSinkCluster")]
#[async_trait(?Send)]
impl TransformConfig for KafkaSinkClusterConfig {
async fn get_builder(&self, chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(KafkaSinkClusterBuilder::new(
self.first_contact_points.clone(),
self.shotover_nodes.clone(),
chain_name,
transform_context.chain_name,
self.connect_timeout_ms,
self.read_timeout,
)))
Expand Down
Loading
Loading