diff --git a/shotover-proxy/src/codec/redis.rs b/shotover-proxy/src/codec/redis.rs index c80ee0735..9a9066edf 100644 --- a/shotover-proxy/src/codec/redis.rs +++ b/shotover-proxy/src/codec/redis.rs @@ -1,56 +1,44 @@ use crate::codec::{CodecBuilder, CodecReadError}; -use crate::frame::RedisFrame; use crate::frame::{Frame, MessageType}; -use crate::message::{Encodable, Message, Messages, QueryType}; +use crate::message::{Encodable, Message, Messages}; use anyhow::{anyhow, Result}; use bytes::{Buf, BytesMut}; use redis_protocol::resp2::prelude::decode_mut; use redis_protocol::resp2::prelude::encode_bytes; use tokio_util::codec::{Decoder, Encoder}; -impl CodecBuilder for RedisCodec { - type Decoder = RedisCodec; - type Encoder = RedisCodec; - fn build(&self) -> (RedisCodec, RedisCodec) { - (RedisCodec::new(), RedisCodec::new()) - } -} +#[derive(Default, Clone)] +pub struct RedisCodecBuilder {} -#[derive(Debug, Clone)] -pub struct RedisCodec { - messages: Messages, +impl CodecBuilder for RedisCodecBuilder { + type Decoder = RedisDecoder; + type Encoder = RedisEncoder; + fn build(&self) -> (RedisDecoder, RedisEncoder) { + (RedisDecoder::new(), RedisEncoder::new()) + } } -#[inline] -pub fn redis_query_type(frame: &RedisFrame) -> QueryType { - if let RedisFrame::Array(frames) = frame { - if let Some(RedisFrame::BulkString(bytes)) = frames.get(0) { - return match bytes.to_ascii_uppercase().as_slice() { - b"APPEND" | b"BITCOUNT" | b"STRLEN" | b"GET" | b"GETRANGE" | b"MGET" - | b"LRANGE" | b"LINDEX" | b"LLEN" | b"SCARD" | b"SISMEMBER" | b"SMEMBERS" - | b"SUNION" | b"SINTER" | b"ZCARD" | b"ZCOUNT" | b"ZRANGE" | b"ZRANK" - | b"ZSCORE" | b"ZRANGEBYSCORE" | b"HGET" | b"HGETALL" | b"HEXISTS" | b"HKEYS" - | b"HLEN" | b"HSTRLEN" | b"HVALS" | b"PFCOUNT" => QueryType::Read, - _ => QueryType::Write, - }; - } +impl RedisCodecBuilder { + pub fn new() -> RedisCodecBuilder { + RedisCodecBuilder::default() } - QueryType::Write } -impl Default for RedisCodec { - fn default() -> Self { - Self::new() - } +#[derive(Default)] +pub struct RedisEncoder {} + +#[derive(Default)] +pub struct RedisDecoder { + messages: Messages, } -impl RedisCodec { - pub fn new() -> RedisCodec { - RedisCodec { messages: vec![] } +impl RedisDecoder { + pub fn new() -> Self { + Default::default() } } -impl Decoder for RedisCodec { +impl Decoder for RedisDecoder { type Item = Messages; type Error = CodecReadError; @@ -79,7 +67,13 @@ impl Decoder for RedisCodec { } } -impl Encoder for RedisCodec { +impl RedisEncoder { + pub fn new() -> Self { + Default::default() + } +} + +impl Encoder for RedisEncoder { type Error = anyhow::Error; fn encode(&mut self, item: Messages, dst: &mut BytesMut) -> Result<()> { @@ -108,7 +102,7 @@ impl Encoder for RedisCodec { #[cfg(test)] mod redis_tests { - use crate::codec::redis::RedisCodec; + use crate::codec::{redis::RedisCodecBuilder, CodecBuilder}; use bytes::BytesMut; use hex_literal::hex; use tokio_util::codec::{Decoder, Encoder}; @@ -135,68 +129,60 @@ mod redis_tests { const HSET_MESSAGE: [u8; 75] = hex!("2a340d0a24340d0a485345540d0a2431380d0a6d797365743a5f5f72616e645f696e745f5f0d0a2432300d0a656c656d656e743a5f5f72616e645f696e745f5f0d0a24330d0a7878780d0a"); - fn test_frame(codec: &mut RedisCodec, raw_frame: &[u8]) { - let message = codec + fn test_frame(raw_frame: &[u8]) { + let (mut decoder, mut encoder) = RedisCodecBuilder::new().build(); + let message = decoder .decode(&mut BytesMut::from(raw_frame)) .unwrap() .unwrap(); let mut dest = BytesMut::new(); - codec.encode(message, &mut dest).unwrap(); + encoder.encode(message, &mut dest).unwrap(); assert_eq!(raw_frame, &dest); } #[test] fn test_ok_codec() { - let mut codec = RedisCodec::new(); - test_frame(&mut codec, &OK_MESSAGE); + test_frame(&OK_MESSAGE); } #[test] fn test_set_codec() { - let mut codec = RedisCodec::new(); - test_frame(&mut codec, &SET_MESSAGE); + test_frame(&SET_MESSAGE); } #[test] fn test_get_codec() { - let mut codec = RedisCodec::new(); - test_frame(&mut codec, &GET_MESSAGE); + test_frame(&GET_MESSAGE); } #[test] fn test_inc_codec() { - let mut codec = RedisCodec::new(); - test_frame(&mut codec, &INC_MESSAGE); + test_frame(&INC_MESSAGE); } #[test] fn test_lpush_codec() { - let mut codec = RedisCodec::new(); - test_frame(&mut codec, &LPUSH_MESSAGE); + test_frame(&LPUSH_MESSAGE); } #[test] fn test_rpush_codec() { - let mut codec = RedisCodec::new(); - test_frame(&mut codec, &RPUSH_MESSAGE); + test_frame(&RPUSH_MESSAGE); } #[test] fn test_lpop_codec() { - let mut codec = RedisCodec::new(); - test_frame(&mut codec, &LPOP_MESSAGE); + test_frame(&LPOP_MESSAGE); } #[test] fn test_sadd_codec() { - let mut codec = RedisCodec::new(); - test_frame(&mut codec, &SADD_MESSAGE); + test_frame(&SADD_MESSAGE); } #[test] fn test_hset_codec() { - let mut codec = RedisCodec::new(); - test_frame(&mut codec, &HSET_MESSAGE); + test_frame(&HSET_MESSAGE); } } diff --git a/shotover-proxy/src/frame/mod.rs b/shotover-proxy/src/frame/mod.rs index a1f57b9ba..44470527b 100644 --- a/shotover-proxy/src/frame/mod.rs +++ b/shotover-proxy/src/frame/mod.rs @@ -5,7 +5,9 @@ pub use cassandra::{CassandraFrame, CassandraOperation, CassandraResult}; use cassandra_protocol::compression::Compression; pub use redis_protocol::resp2::types::Frame as RedisFrame; use std::fmt::{Display, Formatter, Result as FmtResult}; + pub mod cassandra; +pub mod redis; #[derive(PartialEq, Debug, Clone, Copy)] pub enum MessageType { diff --git a/shotover-proxy/src/frame/redis.rs b/shotover-proxy/src/frame/redis.rs new file mode 100644 index 000000000..a75e234fa --- /dev/null +++ b/shotover-proxy/src/frame/redis.rs @@ -0,0 +1,19 @@ +use crate::frame::RedisFrame; +use crate::message::QueryType; + +#[inline] +pub fn redis_query_type(frame: &RedisFrame) -> QueryType { + if let RedisFrame::Array(frames) = frame { + if let Some(RedisFrame::BulkString(bytes)) = frames.get(0) { + return match bytes.to_ascii_uppercase().as_slice() { + b"APPEND" | b"BITCOUNT" | b"STRLEN" | b"GET" | b"GETRANGE" | b"MGET" + | b"LRANGE" | b"LINDEX" | b"LLEN" | b"SCARD" | b"SISMEMBER" | b"SMEMBERS" + | b"SUNION" | b"SINTER" | b"ZCARD" | b"ZCOUNT" | b"ZRANGE" | b"ZRANK" + | b"ZSCORE" | b"ZRANGEBYSCORE" | b"HGET" | b"HGETALL" | b"HEXISTS" | b"HKEYS" + | b"HLEN" | b"HSTRLEN" | b"HVALS" | b"PFCOUNT" => QueryType::Read, + _ => QueryType::Write, + }; + } + } + QueryType::Write +} diff --git a/shotover-proxy/src/message/mod.rs b/shotover-proxy/src/message/mod.rs index 269c5c93f..83fa4b027 100644 --- a/shotover-proxy/src/message/mod.rs +++ b/shotover-proxy/src/message/mod.rs @@ -1,6 +1,6 @@ -use crate::codec::redis::redis_query_type; use crate::codec::CodecState; use crate::frame::cassandra::Tracing; +use crate::frame::redis::redis_query_type; use crate::frame::{ cassandra, cassandra::{CassandraMetadata, CassandraOperation}, diff --git a/shotover-proxy/src/sources/redis_source.rs b/shotover-proxy/src/sources/redis_source.rs index e0147baf4..e98cff251 100644 --- a/shotover-proxy/src/sources/redis_source.rs +++ b/shotover-proxy/src/sources/redis_source.rs @@ -1,4 +1,4 @@ -use crate::codec::redis::RedisCodec; +use crate::codec::redis::RedisCodecBuilder; use crate::server::TcpCodecListener; use crate::sources::Sources; use crate::tls::{TlsAcceptor, TlsAcceptorConfig}; @@ -65,7 +65,7 @@ impl RedisSource { name.to_string(), listen_addr.clone(), hard_connection_limit.unwrap_or(false), - RedisCodec::new(), + RedisCodecBuilder::new(), Arc::new(Semaphore::new(connection_limit.unwrap_or(512))), trigger_shutdown_rx.clone(), tls.map(TlsAcceptor::new).transpose()?, diff --git a/shotover-proxy/src/transforms/redis/cluster_ports_rewrite.rs b/shotover-proxy/src/transforms/redis/cluster_ports_rewrite.rs index 4520562df..62f9f732f 100644 --- a/shotover-proxy/src/transforms/redis/cluster_ports_rewrite.rs +++ b/shotover-proxy/src/transforms/redis/cluster_ports_rewrite.rs @@ -207,7 +207,7 @@ fn is_cluster_slots(frame: &Frame) -> bool { #[cfg(test)] mod test { use super::*; - use crate::codec::redis::RedisCodec; + use crate::codec::redis::RedisDecoder; use crate::transforms::redis::sink_cluster::parse_slots; use tokio_util::codec::Decoder; @@ -270,7 +270,7 @@ mod test { #[test] fn test_rewrite_port_slots() { let slots_pcap: &[u8] = b"*3\r\n*4\r\n:10923\r\n:16383\r\n*3\r\n$12\r\n192.168.80.6\r\n:6379\r\n$40\r\n3a7c357ed75d2aa01fca1e14ef3735a2b2b8ffac\r\n*3\r\n$12\r\n192.168.80.3\r\n:6379\r\n$40\r\n77c01b0ddd8668fff05e3f6a8aaf5f3ccd454a79\r\n*4\r\n:5461\r\n:10922\r\n*3\r\n$12\r\n192.168.80.5\r\n:6379\r\n$40\r\n969c6215d064e68593d384541ceeb57e9520dbed\r\n*3\r\n$12\r\n192.168.80.2\r\n:6379\r\n$40\r\n3929f69990a75be7b2d49594c57fe620862e6fd6\r\n*4\r\n:0\r\n:5460\r\n*3\r\n$12\r\n192.168.80.7\r\n:6379\r\n$40\r\n15d52a65d1fc7a53e34bf9193415aa39136882b2\r\n*3\r\n$12\r\n192.168.80.4\r\n:6379\r\n$40\r\ncd023916a3528fae7e606a10d8289a665d6c47b0\r\n"; - let mut codec = RedisCodec::new(); + let mut codec = RedisDecoder::new(); let mut message = codec .decode(&mut slots_pcap.into()) .unwrap() diff --git a/shotover-proxy/src/transforms/redis/sink_cluster.rs b/shotover-proxy/src/transforms/redis/sink_cluster.rs index 57793617e..5a74f63b2 100644 --- a/shotover-proxy/src/transforms/redis/sink_cluster.rs +++ b/shotover-proxy/src/transforms/redis/sink_cluster.rs @@ -1,4 +1,4 @@ -use crate::codec::redis::RedisCodec; +use crate::codec::redis::RedisCodecBuilder; use crate::error::ChainResponse; use crate::frame::{Frame, RedisFrame}; use crate::message::Message; @@ -77,7 +77,7 @@ pub struct RedisSinkCluster { load_scores: HashMap<(String, usize), usize>, rng: SmallRng, connection_count: usize, - connection_pool: ConnectionPool, + connection_pool: ConnectionPool, reason_for_no_nodes: Option<&'static str>, rebuild_connections: bool, first_contact_points: Vec, @@ -97,8 +97,12 @@ impl RedisSinkCluster { let authenticator = RedisAuthenticator {}; let connect_timeout = Duration::from_millis(connect_timeout_ms); - let connection_pool = - ConnectionPool::new_with_auth(connect_timeout, RedisCodec::new(), authenticator, tls)?; + let connection_pool = ConnectionPool::new_with_auth( + connect_timeout, + RedisCodecBuilder::new(), + authenticator, + tls, + )?; let sink_cluster = RedisSinkCluster { first_contact_points, @@ -1052,6 +1056,7 @@ impl Authenticator for RedisAuthenticator { #[cfg(test)] mod test { use super::*; + use crate::codec::redis::RedisDecoder; use tokio_util::codec::Decoder; #[test] @@ -1059,7 +1064,7 @@ mod test { // Wireshark capture from a Redis cluster with 3 masters and 3 replicas. let slots_pcap: &[u8] = b"*3\r\n*4\r\n:10923\r\n:16383\r\n*3\r\n$12\r\n192.168.80.6\r\n:6379\r\n$40\r\n3a7c357ed75d2aa01fca1e14ef3735a2b2b8ffac\r\n*3\r\n$12\r\n192.168.80.3\r\n:6379\r\n$40\r\n77c01b0ddd8668fff05e3f6a8aaf5f3ccd454a79\r\n*4\r\n:5461\r\n:10922\r\n*3\r\n$12\r\n192.168.80.5\r\n:6379\r\n$40\r\n969c6215d064e68593d384541ceeb57e9520dbed\r\n*3\r\n$12\r\n192.168.80.2\r\n:6379\r\n$40\r\n3929f69990a75be7b2d49594c57fe620862e6fd6\r\n*4\r\n:0\r\n:5460\r\n*3\r\n$12\r\n192.168.80.7\r\n:6379\r\n$40\r\n15d52a65d1fc7a53e34bf9193415aa39136882b2\r\n*3\r\n$12\r\n192.168.80.4\r\n:6379\r\n$40\r\ncd023916a3528fae7e606a10d8289a665d6c47b0\r\n"; - let mut codec = RedisCodec::new(); + let mut codec = RedisDecoder::new(); let mut message = codec .decode(&mut slots_pcap.into()) diff --git a/shotover-proxy/src/transforms/redis/sink_single.rs b/shotover-proxy/src/transforms/redis/sink_single.rs index 0c7d609e1..67f855d96 100644 --- a/shotover-proxy/src/transforms/redis/sink_single.rs +++ b/shotover-proxy/src/transforms/redis/sink_single.rs @@ -1,4 +1,7 @@ -use crate::codec::redis::RedisCodec; +use crate::codec::redis::RedisCodecBuilder; +use crate::codec::redis::RedisDecoder; +use crate::codec::redis::RedisEncoder; +use crate::codec::CodecBuilder; use crate::codec::CodecReadError; use crate::error::ChainResponse; use crate::frame::Frame; @@ -9,15 +12,17 @@ use crate::tls::{AsyncStream, TlsConnector, TlsConnectorConfig}; use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; -use futures::stream::{SplitSink, SplitStream}; use futures::{FutureExt, SinkExt, StreamExt}; use metrics::{register_counter, Counter}; use serde::Deserialize; use std::fmt::Debug; use std::pin::Pin; use std::time::Duration; +use tokio::io::ReadHalf; +use tokio::io::WriteHalf; use tokio::sync::mpsc; -use tokio_util::codec::Framed; +use tokio_util::codec::FramedRead; +use tokio_util::codec::FramedWrite; use tracing::Instrument; #[derive(Deserialize, Debug, Clone)] @@ -88,10 +93,10 @@ impl RedisSinkSingleBuilder { } } -type RedisFramed = Framed>, RedisCodec>; +type PinStream = Pin>; struct Connection { - outbound_tx: SplitSink, + outbound_tx: FramedWrite, RedisEncoder>, response_messages_rx: mpsc::UnboundedReceiver, sent_message_type_tx: mpsc::UnboundedSender, } @@ -130,7 +135,10 @@ impl Transform for RedisSinkSingle { Box::pin(tcp_stream) as Pin> }; - let (outbound_tx, outbound_rx) = Framed::new(generic_stream, RedisCodec::new()).split(); + let (decoder, encoder) = RedisCodecBuilder::new().build(); + let (stream_rx, stream_tx) = tokio::io::split(generic_stream); + let outbound_tx = FramedWrite::new(stream_tx, encoder); + let outbound_rx = FramedRead::new(stream_rx, decoder); let (response_messages_tx, response_messages_rx) = mpsc::unbounded_channel(); let (sent_message_type_tx, sent_message_type_rx) = mpsc::unbounded_channel(); @@ -211,7 +219,7 @@ impl Transform for RedisSinkSingle { /// /// The task will end silently if either the RedisSinkSingle transform is dropped or the server closes the connection. async fn server_response_processing_task( - mut outbound_rx: SplitStream, + mut outbound_rx: FramedRead, RedisDecoder>, subscribe_tx: Option>, response_messages_tx: mpsc::UnboundedSender, mut sent_message_type: mpsc::UnboundedReceiver, diff --git a/shotover-proxy/src/transforms/redis/timestamp_tagging.rs b/shotover-proxy/src/transforms/redis/timestamp_tagging.rs index 6af8dc3d3..4dcb92921 100644 --- a/shotover-proxy/src/transforms/redis/timestamp_tagging.rs +++ b/shotover-proxy/src/transforms/redis/timestamp_tagging.rs @@ -1,5 +1,5 @@ -use crate::codec::redis::redis_query_type; use crate::error::ChainResponse; +use crate::frame::redis::redis_query_type; use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, QueryType}; use crate::transforms::{Transform, Wrapper}; diff --git a/shotover-proxy/src/transforms/util/cluster_connection_pool.rs b/shotover-proxy/src/transforms/util/cluster_connection_pool.rs index 5a4346e4d..c5c9ce932 100644 --- a/shotover-proxy/src/transforms/util/cluster_connection_pool.rs +++ b/shotover-proxy/src/transforms/util/cluster_connection_pool.rs @@ -287,7 +287,7 @@ async fn rx_process( #[cfg(test)] mod test { use super::spawn_read_write_tasks; - use crate::codec::redis::RedisCodec; + use crate::codec::redis::RedisCodecBuilder; use std::mem; use std::time::Duration; use tokio::io::AsyncReadExt; @@ -318,7 +318,7 @@ mod test { let stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap(); let (rx, tx) = stream.into_split(); - let codec = RedisCodec::new(); + let codec = RedisCodecBuilder::new(); let sender = spawn_read_write_tasks(&codec, rx, tx); assert!(remote.await.unwrap()); @@ -358,7 +358,7 @@ mod test { let stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap(); let (rx, tx) = stream.into_split(); - let codec = RedisCodec::new(); + let codec = RedisCodecBuilder::new(); // Drop sender immediately. std::mem::drop(spawn_read_write_tasks(&codec, rx, tx));