Skip to content

Commit

Permalink
Merge branch 'main' into keyspace
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Nov 9, 2022
2 parents 493e239 + c8c8772 commit 032404f
Show file tree
Hide file tree
Showing 14 changed files with 105 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ chain_config:
certificate_authority_path: "example-configs/docker-images/cassandra-tls-4.0.6/certs/localhost_CA.crt"
certificate_path: "example-configs/docker-images/cassandra-tls-4.0.6/certs/localhost.crt"
private_key_path: "example-configs/docker-images/cassandra-tls-4.0.6/certs/localhost.key"
verify_hostname: true
verify_hostname: false
source_to_chain_mapping:
cassandra_prod: main_chain
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ sources:
chain_config:
main_chain:
- CassandraSinkSingle:
remote_address: "127.0.0.1:9042"
remote_address: "localhost:9042"
connect_timeout_ms: 3000
tls:
certificate_authority_path: "example-configs/docker-images/cassandra-tls-4.0.6/certs/localhost_CA.crt"
Expand Down
4 changes: 2 additions & 2 deletions shotover-proxy/example-configs/cassandra-tls/topology.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ sources:
chain_config:
main_chain:
- CassandraSinkSingle:
remote_address: "127.0.0.1:9042"
remote_address: "localhost:9042"
connect_timeout_ms: 3000
tls:
certificate_authority_path: "example-configs/docker-images/cassandra-tls-4.0.6/certs/localhost_CA.crt"
verify_hostname: false
verify_hostname: true
source_to_chain_mapping:
cassandra_prod: main_chain
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ chain_config:
certificate_authority_path: "example-configs/redis-tls/certs/ca.crt"
certificate_path: "example-configs/redis-tls/certs/redis.crt"
private_key_path: "example-configs/redis-tls/certs/redis.key"
verify_hostname: true
verify_hostname: false
source_to_chain_mapping:
redis_prod: redis_chain
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ chain_config:
connect_timeout_ms: 3000
tls:
certificate_authority_path: "example-configs/redis-tls/certs/ca.crt"
verify_hostname: true
verify_hostname: false
source_to_chain_mapping:
redis_prod: redis_chain
4 changes: 2 additions & 2 deletions shotover-proxy/example-configs/redis-tls/topology.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ sources:
chain_config:
redis_chain_tls:
- RedisSinkSingle:
remote_address: "127.0.0.1:1111"
remote_address: "localhost:1111"
connect_timeout_ms: 3000
tls:
certificate_authority_path: "example-configs/redis-tls/certs/ca.crt"
certificate_path: "example-configs/redis-tls/certs/redis.crt"
private_key_path: "example-configs/redis-tls/certs/redis.key"
verify_hostname: false
verify_hostname: true
source_to_chain_mapping:
redis_prod: redis_chain_tls
redis_prod_tls: redis_chain_tls
76 changes: 72 additions & 4 deletions shotover-proxy/src/tls.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use crate::tcp;
use anyhow::{anyhow, Result};
use openssl::ssl::Ssl;
use openssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio_openssl::SslStream;

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand Down Expand Up @@ -127,15 +130,20 @@ impl TlsConnector {
})
}

pub async fn connect(&self, tcp_stream: TcpStream) -> Result<SslStream<TcpStream>> {
pub async fn connect<A: ToSocketAddrs + ToHostname + std::fmt::Debug>(
&self,
connect_timeout: Duration,
address: A,
) -> Result<SslStream<TcpStream>> {
let ssl = self
.connector
.configure()
.map_err(openssl_stack_error_to_anyhow)?
.verify_hostname(self.verify_hostname)
.into_ssl("localhost")
.into_ssl(&address.to_hostname())
.map_err(openssl_stack_error_to_anyhow)?;

let tcp_stream = tcp::tcp_stream(connect_timeout, address).await?;
let mut ssl_stream =
SslStream::new(ssl, tcp_stream).map_err(openssl_stack_error_to_anyhow)?;
Pin::new(&mut ssl_stream).connect().await.map_err(|e| {
Expand All @@ -147,7 +155,7 @@ impl TlsConnector {
}
}

// Always use these openssl_* conversion methods instead of directly directly converting to anyhow
// Always use these openssl_* conversion methods instead of directly converting to anyhow

fn openssl_ssl_error_to_anyhow(error: openssl::ssl::Error) -> anyhow::Error {
if let Some(stack) = error.ssl_error() {
Expand Down Expand Up @@ -205,3 +213,63 @@ pub trait AsyncStream: AsyncRead + AsyncWrite {}
/// We need to tell rust that these types implement AsyncStream even though they already implement AsyncRead and AsyncWrite
impl AsyncStream for tokio_openssl::SslStream<TcpStream> {}
impl AsyncStream for TcpStream {}

/// Allows retrieving the hostname from any ToSocketAddrs type
pub trait ToHostname {
fn to_hostname(&self) -> String;
}

/// Implement for all reference types
impl<T: ToHostname + ?Sized> ToHostname for &T {
fn to_hostname(&self) -> String {
(**self).to_hostname()
}
}

impl ToHostname for String {
fn to_hostname(&self) -> String {
self.split(':').next().unwrap_or("").to_owned()
}
}

impl ToHostname for &str {
fn to_hostname(&self) -> String {
self.split(':').next().unwrap_or("").to_owned()
}
}

impl ToHostname for (&str, u16) {
fn to_hostname(&self) -> String {
self.0.to_string()
}
}

impl ToHostname for (String, u16) {
fn to_hostname(&self) -> String {
self.0.to_string()
}
}

impl ToHostname for (IpAddr, u16) {
fn to_hostname(&self) -> String {
self.0.to_string()
}
}

impl ToHostname for (Ipv4Addr, u16) {
fn to_hostname(&self) -> String {
self.0.to_string()
}
}

impl ToHostname for (Ipv6Addr, u16) {
fn to_hostname(&self) -> String {
self.0.to_string()
}
}

impl ToHostname for SocketAddr {
fn to_hostname(&self) -> String {
self.ip().to_string()
}
}
9 changes: 4 additions & 5 deletions shotover-proxy/src/transforms/cassandra/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::frame::cassandra::CassandraMetadata;
use crate::message::{Message, Metadata};
use crate::server::CodecReadError;
use crate::tcp;
use crate::tls::TlsConnector;
use crate::tls::{TlsConnector, ToHostname};
use crate::transforms::util::Response;
use crate::transforms::Messages;
use anyhow::{anyhow, Result};
Expand Down Expand Up @@ -35,21 +35,19 @@ pub struct CassandraConnection {
}

impl CassandraConnection {
pub async fn new<A: ToSocketAddrs + std::fmt::Debug>(
pub async fn new<A: ToSocketAddrs + ToHostname + std::fmt::Debug>(
connect_timeout: Duration,
host: A,
codec: CassandraCodec,
mut tls: Option<TlsConnector>,
pushed_messages_tx: Option<mpsc::UnboundedSender<Messages>>,
) -> Result<Self> {
let tcp_stream = tcp::tcp_stream(connect_timeout, host).await?;

let (out_tx, out_rx) = mpsc::unbounded_channel::<Request>();
let (return_tx, return_rx) = mpsc::unbounded_channel::<Request>();
let (rx_process_has_shutdown_tx, rx_process_has_shutdown_rx) = oneshot::channel::<()>();

if let Some(tls) = tls.as_mut() {
let tls_stream = tls.connect(tcp_stream).await?;
let tls_stream = tls.connect(connect_timeout, host).await?;
let (read, write) = split(tls_stream);
tokio::spawn(
tx_process(
Expand All @@ -72,6 +70,7 @@ impl CassandraConnection {
.in_current_span(),
);
} else {
let tcp_stream = tcp::tcp_stream(connect_timeout, host).await?;
let (read, write) = split(tcp_stream);
tokio::spawn(
tx_process(
Expand Down
4 changes: 2 additions & 2 deletions shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::codec::cassandra::CassandraCodec;
use crate::frame::Frame;
use crate::message::{Message, Messages};
use crate::tls::TlsConnector;
use crate::tls::{TlsConnector, ToHostname};
use crate::transforms::cassandra::connection::CassandraConnection;
use anyhow::{anyhow, Result};
use cassandra_protocol::frame::Version;
Expand Down Expand Up @@ -100,7 +100,7 @@ impl ConnectionFactory {
}
}

pub async fn new_connection<A: ToSocketAddrs + std::fmt::Debug>(
pub async fn new_connection<A: ToSocketAddrs + ToHostname + std::fmt::Debug>(
&self,
address: A,
) -> Result<CassandraConnection> {
Expand Down
8 changes: 5 additions & 3 deletions shotover-proxy/src/transforms/redis/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,14 @@ impl Transform for RedisSinkSingle {
}

if self.connection.is_none() {
let tcp_stream = tcp::tcp_stream(self.connect_timeout, self.address.clone()).await?;

let generic_stream = if let Some(tls) = self.tls.as_mut() {
let tls_stream = tls.connect(tcp_stream).await?;
let tls_stream = tls
.connect(self.connect_timeout, self.address.clone())
.await?;
Box::pin(tls_stream) as Pin<Box<dyn AsyncStream + Send + Sync>>
} else {
let tcp_stream =
tcp::tcp_stream(self.connect_timeout, self.address.clone()).await?;
Box::pin(tcp_stream) as Pin<Box<dyn AsyncStream + Send + Sync>>
};

Expand Down
9 changes: 4 additions & 5 deletions shotover-proxy/src/transforms/util/cluster_connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,17 @@ impl<C: Codec + 'static, A: Authenticator<T>, T: Token> ConnectionPool<C, A, T>
address: &str,
token: &Option<T>,
) -> Result<Connection, ConnectionError<A::Error>> {
let tcp_stream = tcp::tcp_stream(self.connect_timeout, address)
.await
.map_err(ConnectionError::IO)?;

let mut connection = if let Some(tls) = &self.tls {
let tls_stream = tls
.connect(tcp_stream)
.connect(self.connect_timeout, address)
.await
.map_err(ConnectionError::TLS)?;
let (rx, tx) = tokio::io::split(tls_stream);
spawn_read_write_tasks(&self.codec, rx, tx)
} else {
let tcp_stream = tcp::tcp_stream(self.connect_timeout, address)
.await
.map_err(ConnectionError::IO)?;
let (rx, tx) = tcp_stream.into_split();
spawn_read_write_tasks(&self.codec, rx, tx)
};
Expand Down
8 changes: 6 additions & 2 deletions shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use shotover_proxy::transforms::cassandra::sink_cluster::{
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use tokio::time::timeout;

pub mod multi_rack;
pub mod single_rack_v3;
Expand All @@ -24,7 +25,7 @@ pub async fn run_topology_task(ca_path: Option<&str>, port: Option<u32>) -> Vec<
certificate_authority_path: ca_path.into(),
certificate_path: None,
private_key_path: None,
verify_hostname: true,
verify_hostname: false,
})
.unwrap()
});
Expand All @@ -46,7 +47,10 @@ pub async fn run_topology_task(ca_path: Option<&str>, port: Option<u32>) -> Vec<
.await
.unwrap();

nodes_rx.changed().await.unwrap();
timeout(Duration::from_secs(30), nodes_rx.changed())
.await
.unwrap()
.unwrap();
let nodes = nodes_rx.borrow().clone();
nodes
}
Expand Down
6 changes: 3 additions & 3 deletions shotover-proxy/tests/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ impl ShotoverManager {
let address = "127.0.0.1";
test_helpers::wait_for_socket_to_open(address, port);

let tcp_stream = tokio::net::TcpStream::connect((address, port))
let connector = TlsConnector::new(config).unwrap();
let tls_stream = connector
.connect(Duration::from_secs(3), (address, port))
.await
.unwrap();
let connector = TlsConnector::new(config).unwrap();
let tls_stream = connector.connect(tcp_stream).await.unwrap();
ShotoverManager::redis_connection_async_inner(
Box::pin(tls_stream) as Pin<Box<dyn AsyncStream + Send + Sync>>
)
Expand Down
2 changes: 1 addition & 1 deletion shotover-proxy/tests/redis_int_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async fn source_tls_and_single_tls() {
certificate_authority_path: "example-configs/redis-tls/certs/ca.crt".into(),
certificate_path: Some("example-configs/redis-tls/certs/redis.crt".into()),
private_key_path: Some("example-configs/redis-tls/certs/redis.key".into()),
verify_hostname: true,
verify_hostname: false,
};

let mut connection = shotover_manager
Expand Down

0 comments on commit 032404f

Please sign in to comment.