Skip to content

Commit

Permalink
TLS: Use destination as hostname instead of hardcoded localhost
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Nov 4, 2022
1 parent 25f2621 commit 3eec28e
Show file tree
Hide file tree
Showing 18 changed files with 132 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ chain_config:
certificate_authority_path: "example-configs/docker-images/cassandra-tls-4.0.6/certs/localhost_CA.crt"
certificate_path: "example-configs/docker-images/cassandra-tls-4.0.6/certs/localhost.crt"
private_key_path: "example-configs/docker-images/cassandra-tls-4.0.6/certs/localhost.key"
verify_hostname: true
verify_hostname: false
source_to_chain_mapping:
cassandra_prod: main_chain
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ sources:
chain_config:
main_chain:
- CassandraSinkSingle:
remote_address: "127.0.0.1:9042"
remote_address: "localhost:9042"
tls:
certificate_authority_path: "example-configs/docker-images/cassandra-tls-4.0.6/certs/localhost_CA.crt"
certificate_path: "example-configs/docker-images/cassandra-tls-4.0.6/certs/localhost.crt"
Expand Down
4 changes: 2 additions & 2 deletions shotover-proxy/example-configs/cassandra-tls/topology.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ sources:
chain_config:
main_chain:
- CassandraSinkSingle:
remote_address: "127.0.0.1:9042"
remote_address: "localhost:9042"
tls:
certificate_authority_path: "example-configs/docker-images/cassandra-tls-4.0.6/certs/localhost_CA.crt"
verify_hostname: false
verify_hostname: true
source_to_chain_mapping:
cassandra_prod: main_chain
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ chain_config:
certificate_authority_path: "example-configs/redis-tls/certs/ca.crt"
certificate_path: "example-configs/redis-tls/certs/redis.crt"
private_key_path: "example-configs/redis-tls/certs/redis.key"
verify_hostname: true
verify_hostname: false
source_to_chain_mapping:
redis_prod: redis_chain
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ chain_config:
first_contact_points: ["127.0.0.1:2220", "127.0.0.1:2221", "127.0.0.1:2222", "127.0.0.1:2223", "127.0.0.1:2224", "127.0.0.1:2225"]
tls:
certificate_authority_path: "example-configs/redis-tls/certs/ca.crt"
verify_hostname: true
verify_hostname: false
source_to_chain_mapping:
redis_prod: redis_chain
4 changes: 2 additions & 2 deletions shotover-proxy/example-configs/redis-tls/topology.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ sources:
chain_config:
redis_chain_tls:
- RedisSinkSingle:
remote_address: "127.0.0.1:1111"
remote_address: "localhost:1111"
tls:
certificate_authority_path: "example-configs/redis-tls/certs/ca.crt"
certificate_path: "example-configs/redis-tls/certs/redis.crt"
private_key_path: "example-configs/redis-tls/certs/redis.key"
verify_hostname: false
verify_hostname: true
source_to_chain_mapping:
redis_prod: redis_chain_tls
redis_prod_tls: redis_chain_tls
1 change: 1 addition & 0 deletions shotover-proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ mod observability;
pub mod runner;
mod server;
pub mod sources;
pub mod tcp;
pub mod tls;
pub mod transforms;
19 changes: 19 additions & 0 deletions shotover-proxy/src/tcp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use anyhow::{anyhow, Result};
use std::time::Duration;
use tokio::{
net::{TcpStream, ToSocketAddrs},
time::timeout,
};

pub async fn tcp_stream<A: ToSocketAddrs + std::fmt::Debug>(destination: A) -> Result<TcpStream> {
timeout(Duration::from_secs(3), TcpStream::connect(&destination))
.await
.map_err(|_| {
anyhow!(
"destination {destination:?} did not respond to connection attempt within 3 seconds"
)
})?
.map_err(|e| {
anyhow!(e).context(format!("Failed to connect to destination {destination:?}"))
})
}
74 changes: 70 additions & 4 deletions shotover-proxy/src/tls.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use crate::tcp;
use anyhow::{anyhow, Result};
use openssl::ssl::Ssl;
use openssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio_openssl::SslStream;

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand Down Expand Up @@ -127,15 +129,19 @@ impl TlsConnector {
})
}

pub async fn connect(&self, tcp_stream: TcpStream) -> Result<SslStream<TcpStream>> {
pub async fn connect<A: ToSocketAddrs + ToHostname + std::fmt::Debug>(
&self,
address: A,
) -> Result<SslStream<TcpStream>> {
let ssl = self
.connector
.configure()
.map_err(openssl_stack_error_to_anyhow)?
.verify_hostname(self.verify_hostname)
.into_ssl("localhost")
.into_ssl(&address.to_hostname())
.map_err(openssl_stack_error_to_anyhow)?;

let tcp_stream = tcp::tcp_stream(address).await?;
let mut ssl_stream =
SslStream::new(ssl, tcp_stream).map_err(openssl_stack_error_to_anyhow)?;
Pin::new(&mut ssl_stream).connect().await.map_err(|e| {
Expand All @@ -147,7 +153,7 @@ impl TlsConnector {
}
}

// Always use these openssl_* conversion methods instead of directly directly converting to anyhow
// Always use these openssl_* conversion methods instead of directly converting to anyhow

fn openssl_ssl_error_to_anyhow(error: openssl::ssl::Error) -> anyhow::Error {
if let Some(stack) = error.ssl_error() {
Expand Down Expand Up @@ -205,3 +211,63 @@ pub trait AsyncStream: AsyncRead + AsyncWrite {}
/// We need to tell rust that these types implement AsyncStream even though they already implement AsyncRead and AsyncWrite
impl AsyncStream for tokio_openssl::SslStream<TcpStream> {}
impl AsyncStream for TcpStream {}

/// Allows retrieving the hostname from any ToSocketAddrs type
pub trait ToHostname {
fn to_hostname(&self) -> String;
}

/// Implement for all reference types
impl<T: ToHostname + ?Sized> ToHostname for &T {
fn to_hostname(&self) -> String {
(**self).to_hostname()
}
}

impl ToHostname for String {
fn to_hostname(&self) -> String {
self.split(':').next().unwrap_or("").to_owned()
}
}

impl ToHostname for &str {
fn to_hostname(&self) -> String {
self.split(':').next().unwrap_or("").to_owned()
}
}

impl ToHostname for (&str, u16) {
fn to_hostname(&self) -> String {
self.0.to_string()
}
}

impl ToHostname for (String, u16) {
fn to_hostname(&self) -> String {
self.0.to_string()
}
}

impl ToHostname for (IpAddr, u16) {
fn to_hostname(&self) -> String {
self.0.to_string()
}
}

impl ToHostname for (Ipv4Addr, u16) {
fn to_hostname(&self) -> String {
self.0.to_string()
}
}

impl ToHostname for (Ipv6Addr, u16) {
fn to_hostname(&self) -> String {
self.0.to_string()
}
}

impl ToHostname for SocketAddr {
fn to_hostname(&self) -> String {
self.ip().to_string()
}
}
23 changes: 6 additions & 17 deletions shotover-proxy/src/transforms/cassandra/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use crate::codec::cassandra::CassandraCodec;
use crate::frame::cassandra::CassandraMetadata;
use crate::message::{Message, Metadata};
use crate::server::CodecReadError;
use crate::tls::TlsConnector;
use crate::tcp;
use crate::tls::{TlsConnector, ToHostname};
use crate::transforms::util::Response;
use crate::transforms::Messages;
use anyhow::{anyhow, Result};
Expand All @@ -13,7 +14,7 @@ use futures::{SinkExt, StreamExt};
use halfbrown::HashMap;
use std::time::Duration;
use tokio::io::{split, AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::net::ToSocketAddrs;
use tokio::sync::{mpsc, oneshot};
use tokio::time::timeout;
use tokio_util::codec::{FramedRead, FramedWrite};
Expand All @@ -34,31 +35,18 @@ pub struct CassandraConnection {
}

impl CassandraConnection {
pub async fn new<A: ToSocketAddrs + std::fmt::Debug>(
pub async fn new<A: ToSocketAddrs + ToHostname + std::fmt::Debug>(
host: A,
codec: CassandraCodec,
mut tls: Option<TlsConnector>,
pushed_messages_tx: Option<mpsc::UnboundedSender<Messages>>,
) -> Result<Self> {
let tcp_stream = timeout(Duration::from_secs(3), TcpStream::connect(&host))
.await
.map_err(|_| {
anyhow!(
"Cassandra node at {:?} did not respond to connection attempt within 3 seconds",
host
)
})?
.map_err(|e| {
anyhow::Error::new(e)
.context(format!("Failed to connect to cassandra node: {:?}", host))
})?;

let (out_tx, out_rx) = mpsc::unbounded_channel::<Request>();
let (return_tx, return_rx) = mpsc::unbounded_channel::<Request>();
let (rx_process_has_shutdown_tx, rx_process_has_shutdown_rx) = oneshot::channel::<()>();

if let Some(tls) = tls.as_mut() {
let tls_stream = tls.connect(tcp_stream).await?;
let tls_stream = tls.connect(host).await?;
let (read, write) = split(tls_stream);
tokio::spawn(
tx_process(
Expand All @@ -81,6 +69,7 @@ impl CassandraConnection {
.in_current_span(),
);
} else {
let tcp_stream = tcp::tcp_stream(host).await?;
let (read, write) = split(tcp_stream);
tokio::spawn(
tx_process(
Expand Down
4 changes: 2 additions & 2 deletions shotover-proxy/src/transforms/cassandra/sink_cluster/node.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::codec::cassandra::CassandraCodec;
use crate::frame::Frame;
use crate::message::{Message, Messages};
use crate::tls::TlsConnector;
use crate::tls::{TlsConnector, ToHostname};
use crate::transforms::cassandra::connection::CassandraConnection;
use anyhow::{anyhow, Result};
use cassandra_protocol::frame::Version;
Expand Down Expand Up @@ -95,7 +95,7 @@ impl ConnectionFactory {
}
}

pub async fn new_connection<A: ToSocketAddrs + std::fmt::Debug>(
pub async fn new_connection<A: ToSocketAddrs + ToHostname + std::fmt::Debug>(
&self,
address: A,
) -> Result<CassandraConnection> {
Expand Down
4 changes: 1 addition & 3 deletions shotover-proxy/src/transforms/redis/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::io;

use crate::transforms::util::ConnectionError;

pub mod cache;
Expand Down Expand Up @@ -43,7 +41,7 @@ pub enum TransformError {
Protocol(String),

#[error("io error: {0}")]
IO(io::Error),
IO(anyhow::Error),

#[error("TLS error: {0}")]
TLS(anyhow::Error),
Expand Down
16 changes: 5 additions & 11 deletions shotover-proxy/src/transforms/redis/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::frame::Frame;
use crate::frame::RedisFrame;
use crate::message::{Message, Messages};
use crate::server::CodecReadError;
use crate::tcp;
use crate::tls::{AsyncStream, TlsConnector, TlsConnectorConfig};
use crate::transforms::{Transform, Transforms, Wrapper};
use anyhow::{anyhow, Context, Result};
Expand All @@ -14,10 +15,9 @@ use metrics::{register_counter, Counter};
use serde::Deserialize;
use std::fmt::Debug;
use std::pin::Pin;
use std::time::Duration;
use tokio::net::TcpStream;

use tokio::sync::mpsc;
use tokio::time::timeout;

use tokio_util::codec::Framed;
use tracing::Instrument;

Expand Down Expand Up @@ -98,17 +98,11 @@ impl Transform for RedisSinkSingle {
}

if self.connection.is_none() {
let tcp_stream = timeout(
Duration::from_secs(3),
TcpStream::connect(self.address.clone()),
)
.await?
.map_err(|e| anyhow::Error::new(e).context("Failed to connect to upstream"))?;

let generic_stream = if let Some(tls) = self.tls.as_mut() {
let tls_stream = tls.connect(tcp_stream).await?;
let tls_stream = tls.connect(self.address.clone()).await?;
Box::pin(tls_stream) as Pin<Box<dyn AsyncStream + Send + Sync>>
} else {
let tcp_stream = tcp::tcp_stream(self.address.clone()).await?;
Box::pin(tcp_stream) as Pin<Box<dyn AsyncStream + Send + Sync>>
};

Expand Down
19 changes: 9 additions & 10 deletions shotover-proxy/src/transforms/util/cluster_connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::Response;
use crate::server::Codec;
use crate::server::CodecReadHalf;
use crate::server::CodecWriteHalf;
use crate::tcp;
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::util::{ConnectionError, Request};
use anyhow::{anyhow, Result};
Expand All @@ -11,12 +12,12 @@ use futures::StreamExt;
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;

use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::Mutex;
use tokio::time::timeout;

use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{debug, trace, warn, Instrument};
Expand Down Expand Up @@ -163,17 +164,15 @@ impl<C: Codec + 'static, A: Authenticator<T>, T: Token> ConnectionPool<C, A, T>
address: &str,
token: &Option<T>,
) -> Result<Connection, ConnectionError<A::Error>> {
let stream = timeout(Duration::from_secs(3), TcpStream::connect(address))
.await
.map_err(|e| ConnectionError::IO(e.into()))?
.map_err(ConnectionError::IO)?;

let mut connection = if let Some(tls) = &self.tls {
let tls_stream = tls.connect(stream).await.map_err(ConnectionError::TLS)?;
let tls_stream = tls.connect(address).await.map_err(ConnectionError::TLS)?;
let (rx, tx) = tokio::io::split(tls_stream);
spawn_read_write_tasks(&self.codec, rx, tx)
} else {
let (rx, tx) = stream.into_split();
let tcp_stream = tcp::tcp_stream(address)
.await
.map_err(ConnectionError::IO)?;
let (rx, tx) = tcp_stream.into_split();
spawn_read_write_tasks(&self.codec, rx, tx)
};

Expand Down
3 changes: 1 addition & 2 deletions shotover-proxy/src/transforms/util/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use anyhow::{Error, Result};
use std::fmt;
use std::io;

use crate::message::Message;

Expand All @@ -23,7 +22,7 @@ pub struct Response {
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError<E: fmt::Debug + fmt::Display> {
#[error("io error: {0}")]
IO(io::Error),
IO(Error),

#[error("TLS error: {0}")]
TLS(Error),
Expand Down
Loading

0 comments on commit 3eec28e

Please sign in to comment.