Skip to content

Commit

Permalink
Split redis codec (#1051)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Feb 21, 2023
1 parent d3b0300 commit 4e914ee
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 78 deletions.
100 changes: 43 additions & 57 deletions shotover-proxy/src/codec/redis.rs
Original file line number Diff line number Diff line change
@@ -1,56 +1,44 @@
use crate::codec::{CodecBuilder, CodecReadError};
use crate::frame::RedisFrame;
use crate::frame::{Frame, MessageType};
use crate::message::{Encodable, Message, Messages, QueryType};
use crate::message::{Encodable, Message, Messages};
use anyhow::{anyhow, Result};
use bytes::{Buf, BytesMut};
use redis_protocol::resp2::prelude::decode_mut;
use redis_protocol::resp2::prelude::encode_bytes;
use tokio_util::codec::{Decoder, Encoder};

impl CodecBuilder for RedisCodec {
type Decoder = RedisCodec;
type Encoder = RedisCodec;
fn build(&self) -> (RedisCodec, RedisCodec) {
(RedisCodec::new(), RedisCodec::new())
}
}
#[derive(Default, Clone)]
pub struct RedisCodecBuilder {}

#[derive(Debug, Clone)]
pub struct RedisCodec {
messages: Messages,
impl CodecBuilder for RedisCodecBuilder {
type Decoder = RedisDecoder;
type Encoder = RedisEncoder;
fn build(&self) -> (RedisDecoder, RedisEncoder) {
(RedisDecoder::new(), RedisEncoder::new())
}
}

#[inline]
pub fn redis_query_type(frame: &RedisFrame) -> QueryType {
if let RedisFrame::Array(frames) = frame {
if let Some(RedisFrame::BulkString(bytes)) = frames.get(0) {
return match bytes.to_ascii_uppercase().as_slice() {
b"APPEND" | b"BITCOUNT" | b"STRLEN" | b"GET" | b"GETRANGE" | b"MGET"
| b"LRANGE" | b"LINDEX" | b"LLEN" | b"SCARD" | b"SISMEMBER" | b"SMEMBERS"
| b"SUNION" | b"SINTER" | b"ZCARD" | b"ZCOUNT" | b"ZRANGE" | b"ZRANK"
| b"ZSCORE" | b"ZRANGEBYSCORE" | b"HGET" | b"HGETALL" | b"HEXISTS" | b"HKEYS"
| b"HLEN" | b"HSTRLEN" | b"HVALS" | b"PFCOUNT" => QueryType::Read,
_ => QueryType::Write,
};
}
impl RedisCodecBuilder {
pub fn new() -> RedisCodecBuilder {
RedisCodecBuilder::default()
}
QueryType::Write
}

impl Default for RedisCodec {
fn default() -> Self {
Self::new()
}
#[derive(Default)]
pub struct RedisEncoder {}

#[derive(Default)]
pub struct RedisDecoder {
messages: Messages,
}

impl RedisCodec {
pub fn new() -> RedisCodec {
RedisCodec { messages: vec![] }
impl RedisDecoder {
pub fn new() -> Self {
Default::default()
}
}

impl Decoder for RedisCodec {
impl Decoder for RedisDecoder {
type Item = Messages;
type Error = CodecReadError;

Expand Down Expand Up @@ -79,7 +67,13 @@ impl Decoder for RedisCodec {
}
}

impl Encoder<Messages> for RedisCodec {
impl RedisEncoder {
pub fn new() -> Self {
Default::default()
}
}

impl Encoder<Messages> for RedisEncoder {
type Error = anyhow::Error;

fn encode(&mut self, item: Messages, dst: &mut BytesMut) -> Result<()> {
Expand Down Expand Up @@ -108,7 +102,7 @@ impl Encoder<Messages> for RedisCodec {

#[cfg(test)]
mod redis_tests {
use crate::codec::redis::RedisCodec;
use crate::codec::{redis::RedisCodecBuilder, CodecBuilder};
use bytes::BytesMut;
use hex_literal::hex;
use tokio_util::codec::{Decoder, Encoder};
Expand All @@ -135,68 +129,60 @@ mod redis_tests {

const HSET_MESSAGE: [u8; 75] = hex!("2a340d0a24340d0a485345540d0a2431380d0a6d797365743a5f5f72616e645f696e745f5f0d0a2432300d0a656c656d656e743a5f5f72616e645f696e745f5f0d0a24330d0a7878780d0a");

fn test_frame(codec: &mut RedisCodec, raw_frame: &[u8]) {
let message = codec
fn test_frame(raw_frame: &[u8]) {
let (mut decoder, mut encoder) = RedisCodecBuilder::new().build();
let message = decoder
.decode(&mut BytesMut::from(raw_frame))
.unwrap()
.unwrap();

let mut dest = BytesMut::new();
codec.encode(message, &mut dest).unwrap();
encoder.encode(message, &mut dest).unwrap();
assert_eq!(raw_frame, &dest);
}

#[test]
fn test_ok_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &OK_MESSAGE);
test_frame(&OK_MESSAGE);
}

#[test]
fn test_set_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &SET_MESSAGE);
test_frame(&SET_MESSAGE);
}

#[test]
fn test_get_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &GET_MESSAGE);
test_frame(&GET_MESSAGE);
}

#[test]
fn test_inc_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &INC_MESSAGE);
test_frame(&INC_MESSAGE);
}

#[test]
fn test_lpush_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &LPUSH_MESSAGE);
test_frame(&LPUSH_MESSAGE);
}

#[test]
fn test_rpush_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &RPUSH_MESSAGE);
test_frame(&RPUSH_MESSAGE);
}

#[test]
fn test_lpop_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &LPOP_MESSAGE);
test_frame(&LPOP_MESSAGE);
}

#[test]
fn test_sadd_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &SADD_MESSAGE);
test_frame(&SADD_MESSAGE);
}

#[test]
fn test_hset_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &HSET_MESSAGE);
test_frame(&HSET_MESSAGE);
}
}
2 changes: 2 additions & 0 deletions shotover-proxy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ pub use cassandra::{CassandraFrame, CassandraOperation, CassandraResult};
use cassandra_protocol::compression::Compression;
pub use redis_protocol::resp2::types::Frame as RedisFrame;
use std::fmt::{Display, Formatter, Result as FmtResult};

pub mod cassandra;
pub mod redis;

#[derive(PartialEq, Debug, Clone, Copy)]
pub enum MessageType {
Expand Down
19 changes: 19 additions & 0 deletions shotover-proxy/src/frame/redis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use crate::frame::RedisFrame;
use crate::message::QueryType;

#[inline]
pub fn redis_query_type(frame: &RedisFrame) -> QueryType {
if let RedisFrame::Array(frames) = frame {
if let Some(RedisFrame::BulkString(bytes)) = frames.get(0) {
return match bytes.to_ascii_uppercase().as_slice() {
b"APPEND" | b"BITCOUNT" | b"STRLEN" | b"GET" | b"GETRANGE" | b"MGET"
| b"LRANGE" | b"LINDEX" | b"LLEN" | b"SCARD" | b"SISMEMBER" | b"SMEMBERS"
| b"SUNION" | b"SINTER" | b"ZCARD" | b"ZCOUNT" | b"ZRANGE" | b"ZRANK"
| b"ZSCORE" | b"ZRANGEBYSCORE" | b"HGET" | b"HGETALL" | b"HEXISTS" | b"HKEYS"
| b"HLEN" | b"HSTRLEN" | b"HVALS" | b"PFCOUNT" => QueryType::Read,
_ => QueryType::Write,
};
}
}
QueryType::Write
}
2 changes: 1 addition & 1 deletion shotover-proxy/src/message/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::codec::redis::redis_query_type;
use crate::codec::CodecState;
use crate::frame::cassandra::Tracing;
use crate::frame::redis::redis_query_type;
use crate::frame::{
cassandra,
cassandra::{CassandraMetadata, CassandraOperation},
Expand Down
4 changes: 2 additions & 2 deletions shotover-proxy/src/sources/redis_source.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::codec::redis::RedisCodec;
use crate::codec::redis::RedisCodecBuilder;
use crate::server::TcpCodecListener;
use crate::sources::Sources;
use crate::tls::{TlsAcceptor, TlsAcceptorConfig};
Expand Down Expand Up @@ -65,7 +65,7 @@ impl RedisSource {
name.to_string(),
listen_addr.clone(),
hard_connection_limit.unwrap_or(false),
RedisCodec::new(),
RedisCodecBuilder::new(),
Arc::new(Semaphore::new(connection_limit.unwrap_or(512))),
trigger_shutdown_rx.clone(),
tls.map(TlsAcceptor::new).transpose()?,
Expand Down
4 changes: 2 additions & 2 deletions shotover-proxy/src/transforms/redis/cluster_ports_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ fn is_cluster_slots(frame: &Frame) -> bool {
#[cfg(test)]
mod test {
use super::*;
use crate::codec::redis::RedisCodec;
use crate::codec::redis::RedisDecoder;
use crate::transforms::redis::sink_cluster::parse_slots;
use tokio_util::codec::Decoder;

Expand Down Expand Up @@ -270,7 +270,7 @@ mod test {
#[test]
fn test_rewrite_port_slots() {
let slots_pcap: &[u8] = b"*3\r\n*4\r\n:10923\r\n:16383\r\n*3\r\n$12\r\n192.168.80.6\r\n:6379\r\n$40\r\n3a7c357ed75d2aa01fca1e14ef3735a2b2b8ffac\r\n*3\r\n$12\r\n192.168.80.3\r\n:6379\r\n$40\r\n77c01b0ddd8668fff05e3f6a8aaf5f3ccd454a79\r\n*4\r\n:5461\r\n:10922\r\n*3\r\n$12\r\n192.168.80.5\r\n:6379\r\n$40\r\n969c6215d064e68593d384541ceeb57e9520dbed\r\n*3\r\n$12\r\n192.168.80.2\r\n:6379\r\n$40\r\n3929f69990a75be7b2d49594c57fe620862e6fd6\r\n*4\r\n:0\r\n:5460\r\n*3\r\n$12\r\n192.168.80.7\r\n:6379\r\n$40\r\n15d52a65d1fc7a53e34bf9193415aa39136882b2\r\n*3\r\n$12\r\n192.168.80.4\r\n:6379\r\n$40\r\ncd023916a3528fae7e606a10d8289a665d6c47b0\r\n";
let mut codec = RedisCodec::new();
let mut codec = RedisDecoder::new();
let mut message = codec
.decode(&mut slots_pcap.into())
.unwrap()
Expand Down
15 changes: 10 additions & 5 deletions shotover-proxy/src/transforms/redis/sink_cluster.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::codec::redis::RedisCodec;
use crate::codec::redis::RedisCodecBuilder;
use crate::error::ChainResponse;
use crate::frame::{Frame, RedisFrame};
use crate::message::Message;
Expand Down Expand Up @@ -77,7 +77,7 @@ pub struct RedisSinkCluster {
load_scores: HashMap<(String, usize), usize>,
rng: SmallRng,
connection_count: usize,
connection_pool: ConnectionPool<RedisCodec, RedisAuthenticator, UsernamePasswordToken>,
connection_pool: ConnectionPool<RedisCodecBuilder, RedisAuthenticator, UsernamePasswordToken>,
reason_for_no_nodes: Option<&'static str>,
rebuild_connections: bool,
first_contact_points: Vec<String>,
Expand All @@ -97,8 +97,12 @@ impl RedisSinkCluster {
let authenticator = RedisAuthenticator {};

let connect_timeout = Duration::from_millis(connect_timeout_ms);
let connection_pool =
ConnectionPool::new_with_auth(connect_timeout, RedisCodec::new(), authenticator, tls)?;
let connection_pool = ConnectionPool::new_with_auth(
connect_timeout,
RedisCodecBuilder::new(),
authenticator,
tls,
)?;

let sink_cluster = RedisSinkCluster {
first_contact_points,
Expand Down Expand Up @@ -1052,14 +1056,15 @@ impl Authenticator<UsernamePasswordToken> for RedisAuthenticator {
#[cfg(test)]
mod test {
use super::*;
use crate::codec::redis::RedisDecoder;
use tokio_util::codec::Decoder;

#[test]
fn test_parse_slots() {
// Wireshark capture from a Redis cluster with 3 masters and 3 replicas.
let slots_pcap: &[u8] = b"*3\r\n*4\r\n:10923\r\n:16383\r\n*3\r\n$12\r\n192.168.80.6\r\n:6379\r\n$40\r\n3a7c357ed75d2aa01fca1e14ef3735a2b2b8ffac\r\n*3\r\n$12\r\n192.168.80.3\r\n:6379\r\n$40\r\n77c01b0ddd8668fff05e3f6a8aaf5f3ccd454a79\r\n*4\r\n:5461\r\n:10922\r\n*3\r\n$12\r\n192.168.80.5\r\n:6379\r\n$40\r\n969c6215d064e68593d384541ceeb57e9520dbed\r\n*3\r\n$12\r\n192.168.80.2\r\n:6379\r\n$40\r\n3929f69990a75be7b2d49594c57fe620862e6fd6\r\n*4\r\n:0\r\n:5460\r\n*3\r\n$12\r\n192.168.80.7\r\n:6379\r\n$40\r\n15d52a65d1fc7a53e34bf9193415aa39136882b2\r\n*3\r\n$12\r\n192.168.80.4\r\n:6379\r\n$40\r\ncd023916a3528fae7e606a10d8289a665d6c47b0\r\n";

let mut codec = RedisCodec::new();
let mut codec = RedisDecoder::new();

let mut message = codec
.decode(&mut slots_pcap.into())
Expand Down
22 changes: 15 additions & 7 deletions shotover-proxy/src/transforms/redis/sink_single.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::codec::redis::RedisCodec;
use crate::codec::redis::RedisCodecBuilder;
use crate::codec::redis::RedisDecoder;
use crate::codec::redis::RedisEncoder;
use crate::codec::CodecBuilder;
use crate::codec::CodecReadError;
use crate::error::ChainResponse;
use crate::frame::Frame;
Expand All @@ -9,15 +12,17 @@ use crate::tls::{AsyncStream, TlsConnector, TlsConnectorConfig};
use crate::transforms::{Transform, TransformBuilder, Wrapper};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use futures::stream::{SplitSink, SplitStream};
use futures::{FutureExt, SinkExt, StreamExt};
use metrics::{register_counter, Counter};
use serde::Deserialize;
use std::fmt::Debug;
use std::pin::Pin;
use std::time::Duration;
use tokio::io::ReadHalf;
use tokio::io::WriteHalf;
use tokio::sync::mpsc;
use tokio_util::codec::Framed;
use tokio_util::codec::FramedRead;
use tokio_util::codec::FramedWrite;
use tracing::Instrument;

#[derive(Deserialize, Debug, Clone)]
Expand Down Expand Up @@ -88,10 +93,10 @@ impl RedisSinkSingleBuilder {
}
}

type RedisFramed = Framed<Pin<Box<dyn AsyncStream + Send + Sync>>, RedisCodec>;
type PinStream = Pin<Box<dyn AsyncStream + Send + Sync>>;

struct Connection {
outbound_tx: SplitSink<RedisFramed, Messages>,
outbound_tx: FramedWrite<WriteHalf<PinStream>, RedisEncoder>,
response_messages_rx: mpsc::UnboundedReceiver<Message>,
sent_message_type_tx: mpsc::UnboundedSender<MessageType>,
}
Expand Down Expand Up @@ -130,7 +135,10 @@ impl Transform for RedisSinkSingle {
Box::pin(tcp_stream) as Pin<Box<dyn AsyncStream + Send + Sync>>
};

let (outbound_tx, outbound_rx) = Framed::new(generic_stream, RedisCodec::new()).split();
let (decoder, encoder) = RedisCodecBuilder::new().build();
let (stream_rx, stream_tx) = tokio::io::split(generic_stream);
let outbound_tx = FramedWrite::new(stream_tx, encoder);
let outbound_rx = FramedRead::new(stream_rx, decoder);
let (response_messages_tx, response_messages_rx) = mpsc::unbounded_channel();
let (sent_message_type_tx, sent_message_type_rx) = mpsc::unbounded_channel();

Expand Down Expand Up @@ -211,7 +219,7 @@ impl Transform for RedisSinkSingle {
///
/// The task will end silently if either the RedisSinkSingle transform is dropped or the server closes the connection.
async fn server_response_processing_task(
mut outbound_rx: SplitStream<RedisFramed>,
mut outbound_rx: FramedRead<ReadHalf<PinStream>, RedisDecoder>,
subscribe_tx: Option<mpsc::UnboundedSender<Messages>>,
response_messages_tx: mpsc::UnboundedSender<Message>,
mut sent_message_type: mpsc::UnboundedReceiver<MessageType>,
Expand Down
2 changes: 1 addition & 1 deletion shotover-proxy/src/transforms/redis/timestamp_tagging.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::codec::redis::redis_query_type;
use crate::error::ChainResponse;
use crate::frame::redis::redis_query_type;
use crate::frame::{Frame, RedisFrame};
use crate::message::{Message, QueryType};
use crate::transforms::{Transform, Wrapper};
Expand Down
6 changes: 3 additions & 3 deletions shotover-proxy/src/transforms/util/cluster_connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ async fn rx_process<C: DecoderHalf, R: AsyncRead + Unpin + Send + 'static>(
#[cfg(test)]
mod test {
use super::spawn_read_write_tasks;
use crate::codec::redis::RedisCodec;
use crate::codec::redis::RedisCodecBuilder;
use std::mem;
use std::time::Duration;
use tokio::io::AsyncReadExt;
Expand Down Expand Up @@ -318,7 +318,7 @@ mod test {

let stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
let (rx, tx) = stream.into_split();
let codec = RedisCodec::new();
let codec = RedisCodecBuilder::new();
let sender = spawn_read_write_tasks(&codec, rx, tx);

assert!(remote.await.unwrap());
Expand Down Expand Up @@ -358,7 +358,7 @@ mod test {

let stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
let (rx, tx) = stream.into_split();
let codec = RedisCodec::new();
let codec = RedisCodecBuilder::new();

// Drop sender immediately.
std::mem::drop(spawn_read_write_tasks(&codec, rx, tx));
Expand Down

0 comments on commit 4e914ee

Please sign in to comment.