Skip to content

Commit

Permalink
RedisClusterPortsRewrite port to MessageId (#1485)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Feb 19, 2024
1 parent 203691a commit 022529e
Showing 1 changed file with 45 additions and 29 deletions.
74 changes: 45 additions & 29 deletions shotover/src/transforms/redis/cluster_ports_rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::collections::HashMap;

use crate::frame::Frame;
use crate::frame::RedisFrame;
use crate::message::MessageId;
use crate::message::Messages;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::{anyhow, bail, Context, Result};
Expand All @@ -19,9 +22,7 @@ const NAME: &str = "RedisClusterPortsRewrite";
#[async_trait(?Send)]
impl TransformConfig for RedisClusterPortsRewriteConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(RedisClusterPortsRewrite {
new_port: self.new_port,
}))
Ok(Box::new(RedisClusterPortsRewrite::new(self.new_port)))
}
}

Expand All @@ -38,11 +39,21 @@ impl TransformBuilder for RedisClusterPortsRewrite {
#[derive(Clone)]
pub struct RedisClusterPortsRewrite {
new_port: u16,
request_type: HashMap<MessageId, RequestType>,
}

#[derive(Clone)]
enum RequestType {
ClusterSlot,
ClusterNodes,
}

impl RedisClusterPortsRewrite {
pub fn new(new_port: u16) -> Self {
RedisClusterPortsRewrite { new_port }
RedisClusterPortsRewrite {
new_port,
request_type: HashMap::new(),
}
}
}

Expand All @@ -53,43 +64,48 @@ impl Transform for RedisClusterPortsRewrite {
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
// Find the indices of cluster slot messages
let mut cluster_slots_indices = vec![];
let mut cluster_nodes_indices = vec![];

for (i, message) in requests_wrapper.requests.iter_mut().enumerate() {
for message in requests_wrapper.requests.iter_mut() {
let message_id = message.id();
if let Some(frame) = message.frame() {
if is_cluster_slots(frame) {
cluster_slots_indices.push(i);
self.request_type
.insert(message_id, RequestType::ClusterSlot);
}

if is_cluster_nodes(frame) {
cluster_nodes_indices.push(i);
self.request_type
.insert(message_id, RequestType::ClusterNodes);
}
}
}

let mut response = requests_wrapper.call_next_transform().await?;

// Rewrite the ports in the cluster slots responses
for i in cluster_slots_indices {
if let Some(frame) = response[i].frame() {
rewrite_port_slot(frame, self.new_port)
.context("failed to rewrite CLUSTER SLOTS port")?;
}
response[i].invalidate_cache();
}

// Rewrite the ports in the cluster nodes responses
for i in cluster_nodes_indices {
if let Some(frame) = response[i].frame() {
rewrite_port_node(frame, self.new_port)
.context("failed to rewrite CLUSTER NODES port")?;
let mut responses = requests_wrapper.call_next_transform().await?;

for response in &mut responses {
if let Some(request_id) = response.request_id() {
match self.request_type.remove(&request_id) {
// Rewrite the ports in the cluster slots responses
Some(RequestType::ClusterSlot) => {
if let Some(frame) = response.frame() {
rewrite_port_slot(frame, self.new_port)
.context("failed to rewrite CLUSTER SLOTS port")?;
}
response.invalidate_cache();
}
// Rewrite the ports in the cluster nodes responses
Some(RequestType::ClusterNodes) => {
if let Some(frame) = response.frame() {
rewrite_port_node(frame, self.new_port)
.context("failed to rewrite CLUSTER NODES port")?;
}
response.invalidate_cache();
}
None => {}
}
}
response[i].invalidate_cache();
}

Ok(response)
Ok(responses)
}
}

Expand Down

0 comments on commit 022529e

Please sign in to comment.