From 9348d6ca3dfc579df7f263b38957dc0ce0ff52ba Mon Sep 17 00:00:00 2001 From: Kuangda He Date: Fri, 18 Jun 2021 04:08:03 +1000 Subject: [PATCH] Refactor to fix NOAUTH propagation * Removed ConnectionManager which introduced unnecessary constraints. * Split RedisError into RedisError (upstream) and TransformError. * Replaced Subject+Credential with UsernamePasswordToken. * Re-enabled authenticated connection multiplexing. --- Cargo.toml | 1 - src/transforms/mod.rs | 11 +- src/transforms/redis_transforms/mod.rs | 65 ++- .../redis_transforms/redis_cluster.rs | 429 ++++++------------ .../util/cluster_connection_pool.rs | 44 +- tests/redis_int_tests/basic_driver_tests.rs | 13 + tests/redis_int_tests/support.rs | 20 +- 7 files changed, 217 insertions(+), 366 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ec7f425a6..ab07197f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -78,7 +78,6 @@ wasmer-runtime = "0.17.1" sodiumoxide = "0.2.5" rusoto_kms = "0.46.0" rusoto_signature = "0.46.0" -crossbeam-utils = "0.8.4" [dev-dependencies] criterion = "0.3" diff --git a/src/transforms/mod.rs b/src/transforms/mod.rs index aac66001b..c34c160dc 100644 --- a/src/transforms/mod.rs +++ b/src/transforms/mod.rs @@ -259,7 +259,7 @@ pub struct Wrapper<'a> { // pub next_transform: usize, transforms: Vec<&'a mut Transforms>, pub client_details: String, - chain_name: String + chain_name: String, } impl<'a> Clone for Wrapper<'a> { @@ -268,7 +268,7 @@ impl<'a> Clone for Wrapper<'a> { message: self.message.clone(), transforms: vec![], client_details: self.client_details.clone(), - chain_name: self.chain_name.clone() + chain_name: self.chain_name.clone(), } } } @@ -359,10 +359,5 @@ pub trait Transform: Send { } pub type ResponseFuturesOrdered = FuturesOrdered< - Pin< - Box< - dyn Future)>> - + std::marker::Send, - >, - >, + Pin)>> + std::marker::Send>>, >; diff --git a/src/transforms/redis_transforms/mod.rs b/src/transforms/redis_transforms/mod.rs index 8de9b14c8..162f4d06a 100644 --- a/src/transforms/redis_transforms/mod.rs +++ b/src/transforms/redis_transforms/mod.rs @@ -7,11 +7,8 @@ pub mod redis_cluster; pub mod redis_codec_destination; pub mod timestamp_tagging; -#[derive(thiserror::Error, Debug)] +#[derive(thiserror::Error, Clone, Debug)] pub enum RedisError { - // TODO: Refactor to avoid mixing levels. - - // High level errors. #[error("authentication is required")] NotAuthenticated, @@ -21,26 +18,28 @@ pub enum RedisError { #[error("username or password is incorrect")] BadCredentials, - // TODO: Figure out how to capture context from errors. - // #[error("could not connect to {address} due to {source:?}")] - // ConnectionError { - // address: String, - // #[source] - // source: anyhow::Error, - // }, - - // Low level errors. - #[error("send error")] - SendError(String), + #[error("unknown error: {0}")] + Unknown(String), +} - #[error("receive error")] - ReceiveError(String), +impl RedisError { + fn from_message(error: &str) -> RedisError { + match error.splitn(2, ' ').next() { + Some("NOAUTH") => RedisError::NotAuthenticated, + Some("NOPERM") => RedisError::NotAuthorized, + Some("WRONGPASS") => RedisError::BadCredentials, + _ => RedisError::Unknown(error.to_string()), + } + } +} - #[error("protocol error")] - ProtocolError(String), +#[derive(thiserror::Error, Debug)] +pub enum TransformError { + #[error(transparent)] + Upstream(#[from] RedisError), - #[error("unknown: {0}")] - Unknown(String), + #[error("protocol error: {0}")] + Protocol(String), #[error("io error: {0}")] IO(io::Error), @@ -49,24 +48,22 @@ pub enum RedisError { Other(#[from] anyhow::Error), } -impl RedisError { - fn from_message(error: &str) -> RedisError { - match error.splitn(2, ' ').next() { - Some(name) => match name { - "NOAUTH" => RedisError::NotAuthenticated, - "NOPERM" => RedisError::NotAuthorized, - "WRONGPASS" => RedisError::BadCredentials, - _ => RedisError::Unknown(error.to_string()), - }, - None => RedisError::Unknown(error.to_string()), +impl TransformError { + fn choose_upstream(errors: Vec) -> Option { + match errors.iter().find_map(|e| match e { + TransformError::Upstream(e) => Some(e), + _ => None, + }) { + Some(e) => Some(TransformError::Upstream(e.clone())), + None => errors.into_iter().next(), } } } -impl From> for RedisError { - fn from(error: ConnectionError) -> Self { +impl From> for TransformError { + fn from(error: ConnectionError) -> Self { match error { - ConnectionError::IO(e) => RedisError::IO(e), + ConnectionError::IO(e) => TransformError::IO(e), ConnectionError::Authenticator(e) => e, } } diff --git a/src/transforms/redis_transforms/redis_cluster.rs b/src/transforms/redis_transforms/redis_cluster.rs index 967706a74..02f72f92f 100644 --- a/src/transforms/redis_transforms/redis_cluster.rs +++ b/src/transforms/redis_transforms/redis_cluster.rs @@ -3,12 +3,10 @@ use std::iter::FromIterator; use std::{ collections::{BTreeMap, HashMap, HashSet}, str::FromStr, - sync::Arc, }; use anyhow::{anyhow, bail, Context, Result}; use async_trait::async_trait; -use crossbeam_utils::thread; use derivative::Derivative; use futures::stream::FuturesUnordered; use futures::{StreamExt, TryFutureExt}; @@ -20,7 +18,6 @@ use rand::SeedableRng; use redis_protocol::types::Frame; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc::UnboundedSender; -use tokio::sync::Mutex; use tokio::time::timeout; use tokio::time::Duration; use tracing::{debug, info, trace, warn}; @@ -30,6 +27,7 @@ use crate::message::{Message, MessageDetails, Messages, QueryResponse}; use crate::protocols::redis_codec::RedisCodec; use crate::protocols::RawFrame; use crate::transforms::redis_transforms::RedisError; +use crate::transforms::redis_transforms::TransformError; use crate::transforms::util::cluster_connection_pool::ConnectionPool; use crate::transforms::util::{Request, Response}; use crate::transforms::ResponseFuturesOrdered; @@ -40,14 +38,16 @@ use crate::{ const SLOT_SIZE: usize = 16384; -type RedisSubject = Option; - type ChannelMap = HashMap>>; -#[derive(Clone, Debug)] +#[derive(Clone, Derivative)] +#[derivative(Debug)] pub struct SlotMap { masters: BTreeMap, followers: BTreeMap, + + // Hide redundant information. + #[derivative(Debug = "ignore")] nodes: HashSet, } @@ -78,19 +78,14 @@ impl From for SlotMap { } } -#[derive(Clone)] -pub struct Credential { +#[derive(Clone, PartialEq, Eq, Hash, Derivative)] +#[derivative(Debug)] +pub struct UsernamePasswordToken { pub username: Option, - pub password: String, -} -impl fmt::Debug for Credential { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("Credential") - .field("username", &self.username) - .field("password", &"**SECRET**".to_string()) - .finish() - } + // Reduce risk of logging passwords. + #[derivative(Debug = "ignore")] + pub password: String, } #[derive(Deserialize, Serialize, Debug, Clone, PartialEq)] @@ -151,42 +146,28 @@ fn parse_contact_points(contact_points: &Vec) -> Result for RedisAuthenticator { - type Error = RedisError; +impl Authenticator for RedisAuthenticator { + type Error = TransformError; async fn authenticate( &self, sender: &mut UnboundedSender, - subject: &RedisSubject, - ) -> Result<(), RedisError> { - let credential = match &self.credential_manager.get(subject).await { - Some(credential) => credential.clone(), - None => { - // NOTE: Should be unreachable unless we add expirable credentials. - // TODO: If this becomes expected, it needs its own error? - return Err(RedisError::Other(anyhow!( - "no credentials for subject: {:?}", - subject - ))); - } - }; - + token: &UsernamePasswordToken, + ) -> Result<(), TransformError> { let auth_frame = { let mut args = Vec::new(); args.push(Frame::BulkString(Bytes::from("AUTH"))); // Support non-ACL / username-less. - if let Some(username) = &credential.username { + if let Some(username) = &token.username { args.push(Frame::BulkString(Bytes::from(username.clone()))); } - args.push(Frame::BulkString(Bytes::from(credential.password))); + args.push(Frame::BulkString(Bytes::from(token.password.clone()))); Frame::Array(args) }; @@ -196,19 +177,19 @@ impl Authenticator for RedisAuthenticator { match receive_frame_response(return_chan_rx).await? { Frame::SimpleString(s) => { if s != "OK" { - return Err(RedisError::ProtocolError("bad response value".to_string())); + return Err(TransformError::Protocol("bad response value".to_string())); } - trace!("authenticated upstream as user: {:?}", credential.username); + trace!("authenticated upstream as user: {:?}", token.username); Ok(()) } Frame::Error(e) => { if let RedisError::BadCredentials = RedisError::from_message(&e) { - return Err(RedisError::BadCredentials); + return Err(TransformError::Upstream(RedisError::BadCredentials)); } else { - return Err(RedisError::ProtocolError(format!("bad auth error: {}", e))); + return Err(TransformError::Protocol(format!("bad auth error: {}", e))); } } - _ => return Err(RedisError::ProtocolError("bad response type".to_string())), + _ => return Err(TransformError::Protocol("bad response type".to_string())), } } } @@ -216,6 +197,8 @@ impl Authenticator for RedisAuthenticator { #[async_trait] impl TransformsFromConfig for RedisClusterConfig { async fn get_source(&self, _topics: &TopicHolder) -> Result { + let authenticator = RedisAuthenticator {}; + let mut cluster = RedisCluster { name: "RedisCluster", slots: SlotMap::new(), @@ -224,23 +207,23 @@ impl TransformsFromConfig for RedisClusterConfig { rng: SmallRng::from_rng(rand::thread_rng()).unwrap(), contact_points: parse_contact_points(&self.first_contact_points)?, connection_count: self.connection_count.unwrap_or(1), + connection_pool: ConnectionPool::new_with_auth(RedisCodec::new(true, 3), authenticator), rebuild_slots: false, - subject: None, - - // NOTE: Multiplexing is only safe for anonymous upstreams. - // TODO: Improve safety and allow multiplexing authenticated upstreams. - mux: RedisClusterMux::new(MuxShareMode::Anonymous), + token: None, }; - let credential = self.password.as_ref().map(|password| Credential { - username: self.username.clone(), - password: password.clone(), - }); + let token = self + .password + .as_ref() + .map(|password| UsernamePasswordToken { + username: self.username.clone(), + password: password.clone(), + }); - match cluster.rebuild_connections(credential.as_ref()).await { + match cluster.build_connections(token).await { Ok(()) => info!("eager connections established"), - Err(RedisError::NotAuthenticated) => { - info!("deferring connections due to auth") + Err(TransformError::Upstream(RedisError::NotAuthenticated)) => { + info!("deferring connections due to auth"); } Err(e) => { bail!("failed to connect to upstream: {}", e) @@ -251,112 +234,6 @@ impl TransformsFromConfig for RedisClusterConfig { } } -type CredentialMap = HashMap; - -#[derive(Clone)] -struct CredentialManager { - credentials: Arc>, -} - -impl CredentialManager { - fn new() -> Self { - Self { - credentials: Arc::new(Mutex::new(HashMap::new())), - } - } - - async fn insert(&mut self, credential: Credential) -> Option { - let subject = &credential.username; - self.credentials - .lock() - .await - .insert(subject.clone(), credential) - } - - async fn get(&self, subject: &RedisSubject) -> Option { - self.credentials.lock().await.get(subject).cloned() - } -} - -// TODO: Choose via config. -#[allow(dead_code)] -#[derive(Debug, Copy, Clone)] -enum MuxShareMode { - All, - None, - Anonymous, -} - -/// The part of RedisCluster that is conditionally shared. -#[derive(Derivative)] -#[derivative(Debug)] -struct RedisClusterMux { - share_mode: MuxShareMode, - connection_pool: ConnectionPool, - #[derivative(Debug = "ignore")] - credential_manager: CredentialManager, -} - -impl RedisClusterMux { - fn new(share_mode: MuxShareMode) -> Self { - let credential_manager = CredentialManager::new(); - - let authenticator = RedisAuthenticator { - credential_manager: credential_manager.clone(), - }; - - let connection_pool = - ConnectionPool::new_with_auth(RedisCodec::new(true, 3), authenticator); - - Self { - share_mode, - connection_pool, - credential_manager, - } - } - async fn clone_anonymous(&self) -> Self { - let credential_manager = CredentialManager::new(); - - let authenticator = RedisAuthenticator { - credential_manager: credential_manager.clone(), - }; - - let connection_pool = self.connection_pool.clone_anonymous(authenticator).await; - - Self { - share_mode: self.share_mode, - connection_pool, - credential_manager, - } - } -} - -impl Clone for RedisClusterMux { - fn clone(&self) -> Self { - match self.share_mode { - MuxShareMode::All => Self { - share_mode: self.share_mode.clone(), - connection_pool: self.connection_pool.clone(), - credential_manager: self.credential_manager.clone(), - }, - MuxShareMode::None => Self::new(self.share_mode), - MuxShareMode::Anonymous => { - let handle = tokio::runtime::Handle::current(); - thread::scope(|scope| { - scope - .spawn(move |_| { - // Clone using new thread while we block this one. - handle.block_on(self.clone_anonymous()) - }) - .join() - }) - .unwrap() - .unwrap() - } - } - } -} - pub struct Fmt(pub F) where F: Fn(&mut fmt::Formatter) -> fmt::Result; @@ -381,11 +258,9 @@ pub struct RedisCluster { rng: SmallRng, contact_points: Vec, connection_count: usize, + connection_pool: ConnectionPool, rebuild_slots: bool, - subject: Option, - - // TODO: Delegate to this object (RFC#2393). - mux: RedisClusterMux, + token: Option, } fn fmt_channels(channels: &ChannelMap, fmt: &mut fmt::Formatter) -> fmt::Result { @@ -408,7 +283,7 @@ impl RedisCluster { } } - async fn rebuild_slot_map(&mut self) -> Result<(), RedisError> { + async fn rebuild_slot_map(&mut self) -> Result<(), TransformError> { debug!("rebuilding slot map"); let contact_points = self.get_contact_points(); // IDEA: Retry with original contact points on failure? @@ -418,19 +293,18 @@ impl RedisCluster { async fn rebuild_slot_map_from_contacts( &mut self, contact_points: &[ContactPoint], - ) -> Result<(), RedisError> { + ) -> Result<(), TransformError> { let mut errors = Vec::new(); for contact in contact_points { match self - .mux .connection_pool - .new_connection(&contact.address, &self.subject) + .new_connection(&contact.address, &self.token) .await { Ok(sender) => match get_topology_from_node(&sender).await { Ok(mapping) => { - debug!("successfully updated map {:?}", mapping); + trace!("successfully updated map {:?}", mapping); self.slots = mapping.into(); return Ok(()); } @@ -444,63 +318,40 @@ impl RedisCluster { } debug!("failed to fetch slot map from all hosts"); - - if let Some(preferred_error) = errors.iter().find_map(|e| match e { - RedisError::NotAuthenticated => Some(RedisError::NotAuthenticated), - RedisError::NotAuthorized => Some(RedisError::NotAuthorized), - RedisError::BadCredentials => Some(RedisError::BadCredentials), - _ => None, - }) { - Err(preferred_error) - } else { - Err(errors.into_iter().next().unwrap()) - } + Err(TransformError::choose_upstream(errors).unwrap()) } - async fn rebuild_connections( + async fn build_connections( &mut self, - credential: Option<&Credential>, - ) -> Result<(), RedisError> { - // IDEA: Use state machine and store together? - // IDEA: Refactor to defer commit until success? Can transactional-memory help? - // TODO: Make this safe enough to multiplex authenticated connections. - - // Backup old connections to restore on failure. - let old_subject = self.subject.clone(); + token: Option, + ) -> Result<(), TransformError> { + // Backup existing state to restore on failure. let old_channels = self.channels.clone(); - let old_credential; + let old_token = self.token.clone(); - if let Some(credential) = credential { - old_credential = self.mux.credential_manager.insert(credential.clone()).await; - self.subject = Some(credential.username.clone()); - } else { - old_credential = None; - self.subject = None; - } + debug!("building connections with: {:?}", token); + self.token = token; - debug!("rebuilding connections as subject: {:?}", self.subject); + self.rebuild_connections().await.map_err(|e| { + self.channels = old_channels; + self.token = old_token; + e + }) + } + async fn rebuild_connections(&mut self) -> Result<(), TransformError> { if let Err(e) = self.rebuild_slot_map().await { debug!("failed to rebuild slot map: {}", e); if self.slots.nodes.is_empty() { - // Failed to fetch initial slot map. - // Restore subject and abort. - self.subject = old_subject; - if let Some(old_credential) = old_credential { - self.mux.credential_manager.insert(old_credential).await; - } return Err(e); } else { warn!("using cached slot map") } + } else { + debug!("mapped cluster: {:?}", self.slots); }; - debug!("mapped cluster: {:?}", self.slots); - - // Clear existing connections! - // TODO: Can we restore connections inside the pool on failure? - self.channels = ChannelMap::new(); - + let mut channels = ChannelMap::new(); let mut errors = Vec::new(); // TODO: How to make this DRY...? @@ -508,17 +359,16 @@ impl RedisCluster { debug!("building master connections"); for (_, node) in &self.slots.masters { match self - .mux .connection_pool - .get_connections(node.clone(), &self.subject, self.connection_count) + .get_connections(node.clone(), &self.token, self.connection_count) .await { Ok(connections) => { - self.channels.insert(node.clone(), connections); + channels.insert(node.clone(), connections); } Err(e) => { info!("Could not create connection to {} - {}", node, e); - errors.push(e); + errors.push(e.into()); } } } @@ -526,35 +376,32 @@ impl RedisCluster { debug!("building follower connections"); for (_, node) in &self.slots.followers { match self - .mux .connection_pool - .get_connections(node.clone(), &self.subject, self.connection_count) + .get_connections(node.clone(), &self.token, self.connection_count) .await { Ok(connections) => { - self.channels.insert(node.clone(), connections); + channels.insert(node.clone(), connections); } Err(e) => { info!("Could not create connection to {} - {}", node, e); - errors.push(e); + errors.push(e.into()); } } } - if self.channels.is_empty() && !errors.is_empty() { - self.subject = old_subject; - self.channels = old_channels; - if let Some(old_credential) = old_credential { - self.mux.credential_manager.insert(old_credential).await; - } - // Reject total failure by propagating connection error. - return Err(errors.into_iter().next().unwrap().into()); + if channels.is_empty() && !errors.is_empty() { + debug!("total failure trying to rebuild connections"); + return Err(TransformError::choose_upstream(errors).unwrap()); } + self.channels = channels; + debug!( "Connected to cluster: {:?}", Fmt(|f| fmt_channels(&self.channels, f)) ); + Ok(()) } @@ -570,10 +417,10 @@ impl RedisCluster { // let mut retry = true; // TODO: this is hard to read and may be bug prone - let chan = match self.channels.get_mut(&host) { - Some(chans) if chans.len() == 1 => chans.get_mut(0).unwrap(), - Some(chans) if chans.len() > 1 => { - let candidates = rand::seq::index::sample(&mut self.rng, chans.len(), 2); + let channel = match self.channels.get_mut(&host) { + Some(channels) if channels.len() == 1 => channels.get_mut(0).unwrap(), + Some(channels) if channels.len() > 1 => { + let candidates = rand::seq::index::sample(&mut self.rng, channels.len(), 2); let aidx = candidates.index(0); let bidx = candidates.index(1); @@ -581,7 +428,7 @@ impl RedisCluster { let aload = *self.load_scores.entry((host.clone(), aidx)).or_insert(0); let bload = *self.load_scores.entry((host.clone(), bidx)).or_insert(0); - chans + channels .get_mut(if aload <= bload { aidx } else { bidx }) .ok_or_else(|| anyhow!("Couldn't find host {}", host))? } @@ -589,9 +436,9 @@ impl RedisCluster { debug!("connection {} doesn't exist trying to connect", host); if let Ok(res) = timeout( Duration::from_millis(40), - self.mux.connection_pool.get_connections( + self.connection_pool.get_connections( host.clone(), - &self.subject, + &self.token, self.connection_count, ), ) @@ -616,7 +463,7 @@ impl RedisCluster { } } else { debug!( - "couldn't connect to {} - updating slot map from upstream cluster", + "timed out connecting to {} - updating slot map from upstream cluster", host ); self.rebuild_slots = true; @@ -626,7 +473,7 @@ impl RedisCluster { } }; - if let Err(e) = chan.send(Request { + if let Err(e) = channel.send(Request { messages: message, return_chan: Some(one_tx), message_id: None, @@ -657,25 +504,17 @@ impl RedisCluster { ChannelsResult::Channels(Vec::from_iter(self.slots.nodes.iter().cloned())) } Some(RoutingInfo::AllMasters) => { - // let mut conns = vec![]; - // self.slo - // for host in self.channels.masters.keys() { - // conns.push(host.clone()); - // } ChannelsResult::Channels(Vec::from_iter(self.slots.masters.values().cloned())) } - - Some(RoutingInfo::Random) => { - // TODO: Rename Random as RandomMaster? - let key = self - .slots + Some(RoutingInfo::Random) => ChannelsResult::Channels( + self.slots .masters .values() .next() - .unwrap_or(&"nothing".to_string()) - .clone(); - ChannelsResult::Channels(vec![key]) - } + .map(|key| vec![key.clone()]) + .unwrap_or(vec![]) + .clone(), + ), Some(RoutingInfo::Special(name)) => ChannelsResult::Command(name), None => ChannelsResult::Channels(vec![]), } @@ -707,16 +546,15 @@ impl RedisCluster { } }; let username = args.next(); + let token = UsernamePasswordToken { username, password }; - let credential = Credential { username, password }; - debug!("handling AUTH for {:?}", credential); - - match self.rebuild_connections(Some(&credential)).await { + match self.build_connections(Some(token)).await { Ok(()) => { debug!("successful AUTH"); send_simple_response(one_tx, "OK") } - Err(RedisError::BadCredentials) => { + Err(TransformError::Upstream(RedisError::BadCredentials)) => { + debug!("login attempt with bad credentials"); send_error_response(one_tx, "WRONGPASS invalid username-password") } Err(e) => { @@ -727,10 +565,14 @@ impl RedisCluster { } } -#[derive(Debug)] +#[derive(Derivative)] +#[derivative(Debug)] pub struct RawSlotMapping { pub masters: Vec<(String, u16, u16)>, pub followers: Vec<(String, u16, u16)>, + + // Hide redundant information. + #[derivative(Debug = "ignore")] pub nodes: HashSet, } @@ -865,8 +707,8 @@ fn build_slot_to_server( slots.push((format!("{}:{}", ip, port), start, end)); } -fn parse_slots(results: &Vec) -> Result { - let mut slots: Vec<(String, u16, u16)> = vec![]; +fn parse_slots(results: &Vec) -> Result { + let mut master_slots: Vec<(String, u16, u16)> = vec![]; let mut replica_slots: Vec<(String, u16, u16)> = vec![]; let mut nodes: HashSet = HashSet::new(); @@ -882,13 +724,13 @@ fn parse_slots(results: &Vec) -> Result { (0, Frame::Integer(i)) => start = *i as u16, (1, Frame::Integer(i)) => end = *i as u16, (2, Frame::Array(master)) => { - build_slot_to_server(master, &mut nodes, &mut slots, start, end) + build_slot_to_server(master, &mut nodes, &mut master_slots, start, end) } (n, Frame::Array(follow)) if n > 2 => { build_slot_to_server(&follow, &mut nodes, &mut replica_slots, start, end) } _ => { - return Err(RedisError::ProtocolError( + return Err(TransformError::Protocol( "unexpected value in slot map".to_string(), )) } @@ -896,11 +738,11 @@ fn parse_slots(results: &Vec) -> Result { } } - if slots.is_empty() { - Err(RedisError::Other(anyhow!("empty slot map!"))) + if master_slots.is_empty() { + Err(TransformError::Other(anyhow!("empty slot map!"))) } else { Ok(RawSlotMapping { - masters: slots, + masters: master_slots, followers: replica_slots, nodes, }) @@ -909,7 +751,7 @@ fn parse_slots(results: &Vec) -> Result { async fn get_topology_from_node( sender: &UnboundedSender, -) -> Result { +) -> Result { let return_chan_rx = send_frame_request( sender, Frame::Array(vec![ @@ -920,8 +762,10 @@ async fn get_topology_from_node( match receive_frame_response(return_chan_rx).await? { Frame::Array(results) => parse_slots(&results), - Frame::Error(message) => Err(RedisError::from_message(message.as_str())), - frame => Err(RedisError::ProtocolError(format!( + Frame::Error(message) => Err(TransformError::Upstream(RedisError::from_message( + message.as_str(), + ))), + frame => Err(TransformError::Protocol(format!( "unexpected response frame: {}", frame ))), @@ -999,17 +843,15 @@ fn send_frame_request( ) -> Result)>> { let (return_chan_tx, return_chan_rx) = tokio::sync::oneshot::channel(); - sender - .send(Request { - messages: Message { - details: MessageDetails::Unknown, - modified: false, - original: RawFrame::Redis(frame), - }, - return_chan: Some(return_chan_tx), - message_id: None, - }) - .map_err(|e| RedisError::SendError(format!("failed to send: {}", e)))?; + sender.send(Request { + messages: Message { + details: MessageDetails::Unknown, + modified: false, + original: RawFrame::Redis(frame), + }, + return_chan: Some(return_chan_tx), + message_id: None, + })?; Ok(return_chan_rx) } @@ -1018,9 +860,7 @@ fn send_frame_request( async fn receive_frame_response( receiver: tokio::sync::oneshot::Receiver<(Message, Result)>, ) -> Result { - let (_, result) = receiver - .await - .map_err(|e| RedisError::ReceiveError(format!("failed to receive: {}", e)))?; + let (_, result) = receiver.await?; // TODO: Is it possible for unwrap to panic here? let message = result?.messages.pop().unwrap().original; @@ -1046,6 +886,12 @@ fn response_sender( #[async_trait] impl Transform for RedisCluster { async fn transform<'a>(&'a mut self, qd: Wrapper<'a>) -> ChainResponse { + let unroutable_error = if self.channels.is_empty() { + self.rebuild_connections().await.err() + } else { + None + }; + if self.rebuild_slots { self.rebuild_slot_map().await?; self.rebuild_slots = false; @@ -1067,7 +913,7 @@ impl Transform for RedisCluster { }; let channels = match self.get_channels(command).await { - ChannelsResult::Channels(sender) => sender, + ChannelsResult::Channels(channels) => channels, ChannelsResult::Command(name) => { // Handle special command routing. match name { @@ -1080,10 +926,18 @@ impl Transform for RedisCluster { responses.push(match channels.len() { 0 => { let (one_tx, one_rx) = tokio::sync::oneshot::channel::(); - short_circuit(qd.chain_name.as_str(), one_tx); - Box::pin(one_rx.map_err(|e| { - anyhow!("0 Couldn't get short circuited for no channels - {}", e) - })) + + match unroutable_error { + Some(TransformError::Upstream(RedisError::NotAuthenticated)) => { + let _ = send_error_response( + one_tx, + "NOAUTH Authentication required (cached)", + ); + } + _ => short_circuit(qd.chain_name.as_str(), one_tx), + }; + + Box::pin(one_rx.map_err(anyhow::Error::msg)) } 1 => { let one_rx = self @@ -1093,7 +947,8 @@ impl Transform for RedisCluster { qd.chain_name.as_str(), ) .await?; - Box::pin(one_rx.map_err(|e| anyhow!("1 {}", e))) + + Box::pin(one_rx.map_err(anyhow::Error::msg)) } _ => { let futures: FuturesUnordered< @@ -1129,7 +984,7 @@ impl Transform for RedisCluster { }) .await; - std::result::Result::Ok(( + Ok(( orig, ChainResponse::Ok(Messages::new_from_message(Message { details: MessageDetails::Unknown, diff --git a/src/transforms/util/cluster_connection_pool.rs b/src/transforms/util/cluster_connection_pool.rs index beceea64b..a29c38df3 100644 --- a/src/transforms/util/cluster_connection_pool.rs +++ b/src/transforms/util/cluster_connection_pool.rs @@ -28,13 +28,13 @@ pub trait Authenticator { } // TODO: Replace with trait_alias (RFC#1733). -pub trait Subject: Send + Sync + std::hash::Hash + Eq + Clone {} -impl Subject for T {} +pub trait Token: Send + Sync + std::hash::Hash + Eq + Clone {} +impl Token for T {} #[derive(Clone, Derivative)] #[derivative(Debug)] -pub struct ConnectionPool, S: Subject> { - lanes: Arc, Lane>>>, +pub struct ConnectionPool, T: Token> { + lanes: Arc, Lane>>>, #[derivative(Debug = "ignore")] codec: C, @@ -43,7 +43,7 @@ pub struct ConnectionPool, S: Subject> { authenticator: A, } -impl, S: Subject> ConnectionPool { +impl, T: Token> ConnectionPool { // TODO: Support non-authenticated connection pools (with RFC#1216?). pub fn new_with_auth(codec: C, authenticator: A) -> Self { Self { @@ -53,37 +53,19 @@ impl, S: Subject> ConnectionPool Self { - let anonymous_lane = self.lanes.lock().await.get(&None).cloned(); - - let mut lanes = HashMap::new(); - - if let Some(anonymous_lane) = anonymous_lane { - // Carry-over a copy of the anonymous lane. - // NOTE: Future upstreams can use previous connections but NOT vice versa. - lanes.insert(None, anonymous_lane); - } - - Self { - lanes: Arc::new(Mutex::new(lanes)), - codec: self.codec.clone(), - authenticator, - } - } - /// Try and grab an existing connection, if it's closed (e.g. the listener on the other side /// has closed due to a TCP error), we'll try to reconnect and return the new connection while /// updating the connection map. Errors are returned when a connection can't be established. pub async fn get_connections( &self, address: Address, - subject: &Option, + token: &Option, connection_count: usize, ) -> Result, ConnectionError> { // TODO: Extract return type using generic associated types (RFC#1598). let mut lanes = self.lanes.lock().await; - let lane = lanes.entry(subject.clone()).or_insert_with(HashMap::new); + let lane = lanes.entry(token.clone()).or_insert_with(HashMap::new); if let Some(xs) = lane.get(&address) { // TODO: Reuse connections and not let one bad apple spoil the batch? @@ -93,7 +75,7 @@ impl, S: Subject> ConnectionPool, S: Subject> ConnectionPool, + token: &Option, connection_count: usize, ) -> Result, ConnectionError> { let mut connections: Vec = Vec::new(); let mut errors = Vec::new(); for i in 0..connection_count { - match self.new_connection(address, &subject).await { + match self.new_connection(address, &token).await { Ok(connection) => { connections.push(connection); } @@ -149,7 +131,7 @@ impl, S: Subject> ConnectionPool, + token: &Option, ) -> Result> { let stream: TcpStream = TcpStream::connect(address) .await @@ -157,9 +139,9 @@ impl, S: Subject> ConnectionPool Result<()> { panic!("bad password was ok"); } + let anonymous_context = TestContext::new_without_test(); + let mut anonymous_connection = anonymous_context.connection(); + + let no_auth_result: Result = redis::cmd("GET") + .arg("without authenticating") + .query(&mut anonymous_connection); + + if let Err(error) = no_auth_result { + assert!(error.to_string().starts_with("NOAUTH")) + } else { + panic!("should have got noauth") + } + Ok(()) } diff --git a/tests/redis_int_tests/support.rs b/tests/redis_int_tests/support.rs index 4451251fd..2d248c0ac 100644 --- a/tests/redis_int_tests/support.rs +++ b/tests/redis_int_tests/support.rs @@ -17,16 +17,20 @@ impl Default for TestContext { impl TestContext { pub fn new_auth() -> TestContext { - TestContext::new_internal("redis://default:shotover@127.0.0.1:6379/".to_string()) + TestContext::new_internal("redis://default:shotover@127.0.0.1:6379/", true) } pub fn new() -> TestContext { - TestContext::new_internal("redis://127.0.0.1:6379/".to_string()) + TestContext::new_internal("redis://127.0.0.1:6379/", true) } - pub fn new_internal(conn_string: String) -> TestContext { + pub fn new_without_test() -> TestContext { + TestContext::new_internal("redis://127.0.0.1:6379/", false) + } + + pub fn new_internal(conn_string: &str, test: bool) -> TestContext { info!("Using connection string: {}", conn_string); - let client = redis::Client::open(conn_string.as_str()).unwrap(); + let client = redis::Client::open(conn_string).unwrap(); let mut con; let attempts = 10; @@ -53,6 +57,9 @@ impl TestContext { } Ok(x) => { con = x; + if !test { + break; + } let result: RedisResult> = redis::cmd("GET").arg("nosdjkghsdjghsdkghj").query(&mut con); match result { @@ -70,7 +77,10 @@ impl TestContext { } } } - redis::cmd("FLUSHDB").execute(&mut con); + + if test { + redis::cmd("FLUSHDB").execute(&mut con); + } TestContext { client } }