diff --git a/shotover-proxy/src/transforms/cassandra/sink_cluster.rs b/shotover-proxy/src/transforms/cassandra/sink_cluster.rs index 2acccd092..3056217bb 100644 --- a/shotover-proxy/src/transforms/cassandra/sink_cluster.rs +++ b/shotover-proxy/src/transforms/cassandra/sink_cluster.rs @@ -16,7 +16,7 @@ use cql3_parser::common::{FQName, Identifier}; use metrics::{register_counter, Counter}; use rand::prelude::*; use serde::Deserialize; -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Duration; use tokio::net::ToSocketAddrs; @@ -45,9 +45,9 @@ impl CassandraSinkClusterConfig { pub struct CassandraSinkCluster { contact_points: Vec, - contact_point_connection: Option, + init_handshake_connection: Option, init_handshake: Vec, - init_handshake_address: Option, + init_handshake_address: Option, init_handshake_complete: bool, chain_name: String, failed_requests: Counter, @@ -70,7 +70,7 @@ impl Clone for CassandraSinkCluster { fn clone(&self) -> Self { CassandraSinkCluster { contact_points: self.contact_points.clone(), - contact_point_connection: None, + init_handshake_connection: None, init_handshake: vec![], init_handshake_address: None, init_handshake_complete: false, @@ -112,7 +112,7 @@ impl CassandraSinkCluster { CassandraSinkCluster { contact_points, - contact_point_connection: None, + init_handshake_connection: None, init_handshake: vec![], init_handshake_address: None, init_handshake_complete: false, @@ -133,12 +133,24 @@ impl CassandraSinkCluster { impl CassandraSinkCluster { async fn send_message(&mut self, messages: Messages) -> ChainResponse { + // Attempt to populate nodes list if we still dont have one yet + if self.local_nodes.is_empty() { + let nodes_shared = self.topology_task_nodes.read().await; + self.local_nodes = nodes_shared.clone(); + } + // Create the initial connection. - // Messages will be sent through this connection until we have extracted the handshake and list of nodes - // TODO: initial connection should come from node list too - if self.contact_point_connection.is_none() { - let random_point = self.contact_points.choose(&mut self.rng).unwrap(); - self.contact_point_connection = Some( + // Messages will be sent through this connection until we have extracted the handshake. + if self.init_handshake_connection.is_none() { + let random_point = if let Some(random_point) = self.local_nodes.choose(&mut self.rng) { + SocketAddr::new(random_point.address, 9042) + } else { + tokio::net::lookup_host(self.contact_points.choose(&mut self.rng).unwrap()) + .await? + .next() + .unwrap() + }; + self.init_handshake_connection = Some( CassandraConnection::new( random_point, CassandraCodec::new(), @@ -147,13 +159,7 @@ impl CassandraSinkCluster { ) .await?, ); - self.init_handshake_address = Some(random_point.clone()); - } - - // Attempt to populate nodes list if we still dont have one yet - if self.local_nodes.is_empty() { - let nodes_shared = self.topology_task_nodes.read().await; - self.local_nodes.extend(nodes_shared.iter().cloned()); + self.init_handshake_address = Some(random_point); } if !self.init_handshake_complete { @@ -169,7 +175,7 @@ impl CassandraSinkCluster { if let Ok(permit) = self.task_handshake_tx.try_reserve() { permit.send(TaskHandshake { handshake: self.init_handshake.clone(), - address: self.init_handshake_address.as_ref().unwrap().clone(), + address: self.init_handshake_address.unwrap(), }) } self.init_handshake_complete = true; @@ -184,8 +190,11 @@ impl CassandraSinkCluster { for message in messages { let (return_chan_tx, return_chan_rx) = oneshot::channel(); if self.local_nodes.is_empty() || !self.init_handshake_complete { - self.contact_point_connection.as_mut().unwrap() + // If the handshake is incomplete then we need to keep sending down this connection until we have formed a complete handshake. + // If the handshake is complete but the nodes list isnt ready yet then this connection will make do until we have a nodes list. + self.init_handshake_connection.as_mut().unwrap() } else { + // We have a full nodes list and handshake, so we can do proper routing now. let random_node = self.local_nodes.choose_mut(&mut self.rng).unwrap(); random_node .get_connection(&self.init_handshake, &self.tls, &self.pushed_messages_tx) @@ -362,7 +371,7 @@ pub struct CassandraNode { #[derive(Debug)] pub struct TaskHandshake { pub handshake: Vec, - pub address: String, + pub address: SocketAddr, } impl CassandraNode { diff --git a/shotover-proxy/tests/cassandra_int_tests/cluster.rs b/shotover-proxy/tests/cassandra_int_tests/cluster.rs index ae47961db..2fc36d8b6 100644 --- a/shotover-proxy/tests/cassandra_int_tests/cluster.rs +++ b/shotover-proxy/tests/cassandra_int_tests/cluster.rs @@ -21,7 +21,7 @@ pub async fn test() { // Normally the handshake is the handshake that the client gave shotover. task_handshake_tx .send(TaskHandshake { - address: "172.16.1.2:9042".to_string(), + address: "172.16.1.2:9042".parse().unwrap(), handshake: create_handshake(), }) .await