Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split redis codec #1051

Merged
merged 4 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 43 additions & 57 deletions shotover-proxy/src/codec/redis.rs
Original file line number Diff line number Diff line change
@@ -1,56 +1,44 @@
use crate::codec::{CodecBuilder, CodecReadError};
use crate::frame::RedisFrame;
use crate::frame::{Frame, MessageType};
use crate::message::{Encodable, Message, Messages, QueryType};
use crate::message::{Encodable, Message, Messages};
use anyhow::{anyhow, Result};
use bytes::{Buf, BytesMut};
use redis_protocol::resp2::prelude::decode_mut;
use redis_protocol::resp2::prelude::encode_bytes;
use tokio_util::codec::{Decoder, Encoder};

impl CodecBuilder for RedisCodec {
type Decoder = RedisCodec;
type Encoder = RedisCodec;
fn build(&self) -> (RedisCodec, RedisCodec) {
(RedisCodec::new(), RedisCodec::new())
}
}
#[derive(Default, Clone)]
pub struct RedisCodecBuilder {}

#[derive(Debug, Clone)]
pub struct RedisCodec {
messages: Messages,
impl CodecBuilder for RedisCodecBuilder {
type Decoder = RedisDecoder;
type Encoder = RedisEncoder;
fn build(&self) -> (RedisDecoder, RedisEncoder) {
(RedisDecoder::new(), RedisEncoder::new())
}
}

#[inline]
pub fn redis_query_type(frame: &RedisFrame) -> QueryType {
if let RedisFrame::Array(frames) = frame {
if let Some(RedisFrame::BulkString(bytes)) = frames.get(0) {
return match bytes.to_ascii_uppercase().as_slice() {
b"APPEND" | b"BITCOUNT" | b"STRLEN" | b"GET" | b"GETRANGE" | b"MGET"
| b"LRANGE" | b"LINDEX" | b"LLEN" | b"SCARD" | b"SISMEMBER" | b"SMEMBERS"
| b"SUNION" | b"SINTER" | b"ZCARD" | b"ZCOUNT" | b"ZRANGE" | b"ZRANK"
| b"ZSCORE" | b"ZRANGEBYSCORE" | b"HGET" | b"HGETALL" | b"HEXISTS" | b"HKEYS"
| b"HLEN" | b"HSTRLEN" | b"HVALS" | b"PFCOUNT" => QueryType::Read,
_ => QueryType::Write,
};
}
impl RedisCodecBuilder {
pub fn new() -> RedisCodecBuilder {
RedisCodecBuilder::default()
}
QueryType::Write
}

impl Default for RedisCodec {
fn default() -> Self {
Self::new()
}
#[derive(Default)]
pub struct RedisEncoder {}

#[derive(Default)]
pub struct RedisDecoder {
messages: Messages,
}

impl RedisCodec {
pub fn new() -> RedisCodec {
RedisCodec { messages: vec![] }
impl RedisDecoder {
pub fn new() -> Self {
Default::default()
}
}

impl Decoder for RedisCodec {
impl Decoder for RedisDecoder {
type Item = Messages;
type Error = CodecReadError;

Expand Down Expand Up @@ -79,7 +67,13 @@ impl Decoder for RedisCodec {
}
}

impl Encoder<Messages> for RedisCodec {
impl RedisEncoder {
pub fn new() -> Self {
Default::default()
}
}

impl Encoder<Messages> for RedisEncoder {
type Error = anyhow::Error;

fn encode(&mut self, item: Messages, dst: &mut BytesMut) -> Result<()> {
Expand Down Expand Up @@ -108,7 +102,7 @@ impl Encoder<Messages> for RedisCodec {

#[cfg(test)]
mod redis_tests {
use crate::codec::redis::RedisCodec;
use crate::codec::{redis::RedisCodecBuilder, CodecBuilder};
use bytes::BytesMut;
use hex_literal::hex;
use tokio_util::codec::{Decoder, Encoder};
Expand All @@ -135,68 +129,60 @@ mod redis_tests {

const HSET_MESSAGE: [u8; 75] = hex!("2a340d0a24340d0a485345540d0a2431380d0a6d797365743a5f5f72616e645f696e745f5f0d0a2432300d0a656c656d656e743a5f5f72616e645f696e745f5f0d0a24330d0a7878780d0a");

fn test_frame(codec: &mut RedisCodec, raw_frame: &[u8]) {
let message = codec
fn test_frame(raw_frame: &[u8]) {
let (mut decoder, mut encoder) = RedisCodecBuilder::new().build();
let message = decoder
.decode(&mut BytesMut::from(raw_frame))
.unwrap()
.unwrap();

let mut dest = BytesMut::new();
codec.encode(message, &mut dest).unwrap();
encoder.encode(message, &mut dest).unwrap();
assert_eq!(raw_frame, &dest);
}

#[test]
fn test_ok_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &OK_MESSAGE);
test_frame(&OK_MESSAGE);
}

#[test]
fn test_set_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &SET_MESSAGE);
test_frame(&SET_MESSAGE);
}

#[test]
fn test_get_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &GET_MESSAGE);
test_frame(&GET_MESSAGE);
}

#[test]
fn test_inc_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &INC_MESSAGE);
test_frame(&INC_MESSAGE);
}

#[test]
fn test_lpush_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &LPUSH_MESSAGE);
test_frame(&LPUSH_MESSAGE);
}

#[test]
fn test_rpush_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &RPUSH_MESSAGE);
test_frame(&RPUSH_MESSAGE);
}

#[test]
fn test_lpop_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &LPOP_MESSAGE);
test_frame(&LPOP_MESSAGE);
}

#[test]
fn test_sadd_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &SADD_MESSAGE);
test_frame(&SADD_MESSAGE);
}

#[test]
fn test_hset_codec() {
let mut codec = RedisCodec::new();
test_frame(&mut codec, &HSET_MESSAGE);
test_frame(&HSET_MESSAGE);
}
}
2 changes: 2 additions & 0 deletions shotover-proxy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ pub use cassandra::{CassandraFrame, CassandraOperation, CassandraResult};
use cassandra_protocol::compression::Compression;
pub use redis_protocol::resp2::types::Frame as RedisFrame;
use std::fmt::{Display, Formatter, Result as FmtResult};

pub mod cassandra;
pub mod redis;

#[derive(PartialEq, Debug, Clone, Copy)]
pub enum MessageType {
Expand Down
19 changes: 19 additions & 0 deletions shotover-proxy/src/frame/redis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use crate::frame::RedisFrame;
use crate::message::QueryType;

#[inline]
pub fn redis_query_type(frame: &RedisFrame) -> QueryType {
if let RedisFrame::Array(frames) = frame {
if let Some(RedisFrame::BulkString(bytes)) = frames.get(0) {
return match bytes.to_ascii_uppercase().as_slice() {
b"APPEND" | b"BITCOUNT" | b"STRLEN" | b"GET" | b"GETRANGE" | b"MGET"
| b"LRANGE" | b"LINDEX" | b"LLEN" | b"SCARD" | b"SISMEMBER" | b"SMEMBERS"
| b"SUNION" | b"SINTER" | b"ZCARD" | b"ZCOUNT" | b"ZRANGE" | b"ZRANK"
| b"ZSCORE" | b"ZRANGEBYSCORE" | b"HGET" | b"HGETALL" | b"HEXISTS" | b"HKEYS"
| b"HLEN" | b"HSTRLEN" | b"HVALS" | b"PFCOUNT" => QueryType::Read,
_ => QueryType::Write,
};
}
}
QueryType::Write
}
2 changes: 1 addition & 1 deletion shotover-proxy/src/message/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::codec::redis::redis_query_type;
use crate::codec::CodecState;
use crate::frame::cassandra::Tracing;
use crate::frame::redis::redis_query_type;
use crate::frame::{
cassandra,
cassandra::{CassandraMetadata, CassandraOperation},
Expand Down
4 changes: 2 additions & 2 deletions shotover-proxy/src/sources/redis_source.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::codec::redis::RedisCodec;
use crate::codec::redis::RedisCodecBuilder;
use crate::server::TcpCodecListener;
use crate::sources::Sources;
use crate::tls::{TlsAcceptor, TlsAcceptorConfig};
Expand Down Expand Up @@ -65,7 +65,7 @@ impl RedisSource {
name.to_string(),
listen_addr.clone(),
hard_connection_limit.unwrap_or(false),
RedisCodec::new(),
RedisCodecBuilder::new(),
Arc::new(Semaphore::new(connection_limit.unwrap_or(512))),
trigger_shutdown_rx.clone(),
tls.map(TlsAcceptor::new).transpose()?,
Expand Down
4 changes: 2 additions & 2 deletions shotover-proxy/src/transforms/redis/cluster_ports_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ fn is_cluster_slots(frame: &Frame) -> bool {
#[cfg(test)]
mod test {
use super::*;
use crate::codec::redis::RedisCodec;
use crate::codec::redis::RedisDecoder;
use crate::transforms::redis::sink_cluster::parse_slots;
use tokio_util::codec::Decoder;

Expand Down Expand Up @@ -270,7 +270,7 @@ mod test {
#[test]
fn test_rewrite_port_slots() {
let slots_pcap: &[u8] = b"*3\r\n*4\r\n:10923\r\n:16383\r\n*3\r\n$12\r\n192.168.80.6\r\n:6379\r\n$40\r\n3a7c357ed75d2aa01fca1e14ef3735a2b2b8ffac\r\n*3\r\n$12\r\n192.168.80.3\r\n:6379\r\n$40\r\n77c01b0ddd8668fff05e3f6a8aaf5f3ccd454a79\r\n*4\r\n:5461\r\n:10922\r\n*3\r\n$12\r\n192.168.80.5\r\n:6379\r\n$40\r\n969c6215d064e68593d384541ceeb57e9520dbed\r\n*3\r\n$12\r\n192.168.80.2\r\n:6379\r\n$40\r\n3929f69990a75be7b2d49594c57fe620862e6fd6\r\n*4\r\n:0\r\n:5460\r\n*3\r\n$12\r\n192.168.80.7\r\n:6379\r\n$40\r\n15d52a65d1fc7a53e34bf9193415aa39136882b2\r\n*3\r\n$12\r\n192.168.80.4\r\n:6379\r\n$40\r\ncd023916a3528fae7e606a10d8289a665d6c47b0\r\n";
let mut codec = RedisCodec::new();
let mut codec = RedisDecoder::new();
let mut message = codec
.decode(&mut slots_pcap.into())
.unwrap()
Expand Down
15 changes: 10 additions & 5 deletions shotover-proxy/src/transforms/redis/sink_cluster.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::codec::redis::RedisCodec;
use crate::codec::redis::RedisCodecBuilder;
use crate::error::ChainResponse;
use crate::frame::{Frame, RedisFrame};
use crate::message::Message;
Expand Down Expand Up @@ -77,7 +77,7 @@ pub struct RedisSinkCluster {
load_scores: HashMap<(String, usize), usize>,
rng: SmallRng,
connection_count: usize,
connection_pool: ConnectionPool<RedisCodec, RedisAuthenticator, UsernamePasswordToken>,
connection_pool: ConnectionPool<RedisCodecBuilder, RedisAuthenticator, UsernamePasswordToken>,
reason_for_no_nodes: Option<&'static str>,
rebuild_connections: bool,
first_contact_points: Vec<String>,
Expand All @@ -97,8 +97,12 @@ impl RedisSinkCluster {
let authenticator = RedisAuthenticator {};

let connect_timeout = Duration::from_millis(connect_timeout_ms);
let connection_pool =
ConnectionPool::new_with_auth(connect_timeout, RedisCodec::new(), authenticator, tls)?;
let connection_pool = ConnectionPool::new_with_auth(
connect_timeout,
RedisCodecBuilder::new(),
authenticator,
tls,
)?;

let sink_cluster = RedisSinkCluster {
first_contact_points,
Expand Down Expand Up @@ -1052,14 +1056,15 @@ impl Authenticator<UsernamePasswordToken> for RedisAuthenticator {
#[cfg(test)]
mod test {
use super::*;
use crate::codec::redis::RedisDecoder;
use tokio_util::codec::Decoder;

#[test]
fn test_parse_slots() {
// Wireshark capture from a Redis cluster with 3 masters and 3 replicas.
let slots_pcap: &[u8] = b"*3\r\n*4\r\n:10923\r\n:16383\r\n*3\r\n$12\r\n192.168.80.6\r\n:6379\r\n$40\r\n3a7c357ed75d2aa01fca1e14ef3735a2b2b8ffac\r\n*3\r\n$12\r\n192.168.80.3\r\n:6379\r\n$40\r\n77c01b0ddd8668fff05e3f6a8aaf5f3ccd454a79\r\n*4\r\n:5461\r\n:10922\r\n*3\r\n$12\r\n192.168.80.5\r\n:6379\r\n$40\r\n969c6215d064e68593d384541ceeb57e9520dbed\r\n*3\r\n$12\r\n192.168.80.2\r\n:6379\r\n$40\r\n3929f69990a75be7b2d49594c57fe620862e6fd6\r\n*4\r\n:0\r\n:5460\r\n*3\r\n$12\r\n192.168.80.7\r\n:6379\r\n$40\r\n15d52a65d1fc7a53e34bf9193415aa39136882b2\r\n*3\r\n$12\r\n192.168.80.4\r\n:6379\r\n$40\r\ncd023916a3528fae7e606a10d8289a665d6c47b0\r\n";

let mut codec = RedisCodec::new();
let mut codec = RedisDecoder::new();

let mut message = codec
.decode(&mut slots_pcap.into())
Expand Down
22 changes: 15 additions & 7 deletions shotover-proxy/src/transforms/redis/sink_single.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::codec::redis::RedisCodec;
use crate::codec::redis::RedisCodecBuilder;
use crate::codec::redis::RedisDecoder;
use crate::codec::redis::RedisEncoder;
use crate::codec::CodecBuilder;
use crate::codec::CodecReadError;
use crate::error::ChainResponse;
use crate::frame::Frame;
Expand All @@ -9,15 +12,17 @@ use crate::tls::{AsyncStream, TlsConnector, TlsConnectorConfig};
use crate::transforms::{Transform, TransformBuilder, Wrapper};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use futures::stream::{SplitSink, SplitStream};
use futures::{FutureExt, SinkExt, StreamExt};
use metrics::{register_counter, Counter};
use serde::Deserialize;
use std::fmt::Debug;
use std::pin::Pin;
use std::time::Duration;
use tokio::io::ReadHalf;
use tokio::io::WriteHalf;
use tokio::sync::mpsc;
use tokio_util::codec::Framed;
use tokio_util::codec::FramedRead;
use tokio_util::codec::FramedWrite;
use tracing::Instrument;

#[derive(Deserialize, Debug, Clone)]
Expand Down Expand Up @@ -88,10 +93,10 @@ impl RedisSinkSingleBuilder {
}
}

type RedisFramed = Framed<Pin<Box<dyn AsyncStream + Send + Sync>>, RedisCodec>;
type PinStream = Pin<Box<dyn AsyncStream + Send + Sync>>;

struct Connection {
outbound_tx: SplitSink<RedisFramed, Messages>,
outbound_tx: FramedWrite<WriteHalf<PinStream>, RedisEncoder>,
response_messages_rx: mpsc::UnboundedReceiver<Message>,
sent_message_type_tx: mpsc::UnboundedSender<MessageType>,
}
Expand Down Expand Up @@ -130,7 +135,10 @@ impl Transform for RedisSinkSingle {
Box::pin(tcp_stream) as Pin<Box<dyn AsyncStream + Send + Sync>>
};

let (outbound_tx, outbound_rx) = Framed::new(generic_stream, RedisCodec::new()).split();
let (decoder, encoder) = RedisCodecBuilder::new().build();
let (stream_rx, stream_tx) = tokio::io::split(generic_stream);
let outbound_tx = FramedWrite::new(stream_tx, encoder);
let outbound_rx = FramedRead::new(stream_rx, decoder);
let (response_messages_tx, response_messages_rx) = mpsc::unbounded_channel();
let (sent_message_type_tx, sent_message_type_rx) = mpsc::unbounded_channel();

Expand Down Expand Up @@ -211,7 +219,7 @@ impl Transform for RedisSinkSingle {
///
/// The task will end silently if either the RedisSinkSingle transform is dropped or the server closes the connection.
async fn server_response_processing_task(
mut outbound_rx: SplitStream<RedisFramed>,
mut outbound_rx: FramedRead<ReadHalf<PinStream>, RedisDecoder>,
subscribe_tx: Option<mpsc::UnboundedSender<Messages>>,
response_messages_tx: mpsc::UnboundedSender<Message>,
mut sent_message_type: mpsc::UnboundedReceiver<MessageType>,
Expand Down
2 changes: 1 addition & 1 deletion shotover-proxy/src/transforms/redis/timestamp_tagging.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::codec::redis::redis_query_type;
use crate::error::ChainResponse;
use crate::frame::redis::redis_query_type;
use crate::frame::{Frame, RedisFrame};
use crate::message::{Message, QueryType};
use crate::transforms::{Transform, Wrapper};
Expand Down
6 changes: 3 additions & 3 deletions shotover-proxy/src/transforms/util/cluster_connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ async fn rx_process<C: DecoderHalf, R: AsyncRead + Unpin + Send + 'static>(
#[cfg(test)]
mod test {
use super::spawn_read_write_tasks;
use crate::codec::redis::RedisCodec;
use crate::codec::redis::RedisCodecBuilder;
use std::mem;
use std::time::Duration;
use tokio::io::AsyncReadExt;
Expand Down Expand Up @@ -318,7 +318,7 @@ mod test {

let stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
let (rx, tx) = stream.into_split();
let codec = RedisCodec::new();
let codec = RedisCodecBuilder::new();
let sender = spawn_read_write_tasks(&codec, rx, tx);

assert!(remote.await.unwrap());
Expand Down Expand Up @@ -358,7 +358,7 @@ mod test {

let stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
let (rx, tx) = stream.into_split();
let codec = RedisCodec::new();
let codec = RedisCodecBuilder::new();

// Drop sender immediately.
std::mem::drop(spawn_read_write_tasks(&codec, rx, tx));
Expand Down