diff --git a/examples/client.rs b/examples/client.rs index 8c5589f..47e453a 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -10,6 +10,7 @@ use diameter::avp::Unsigned32; use diameter::dictionary; use diameter::flags; use diameter::transport::DiameterClient; +use diameter::transport::DiameterClientConfig; use diameter::{ApplicationId, CommandCode, DiameterMessage}; use std::fs; use std::net::Ipv4Addr; @@ -26,7 +27,11 @@ async fn main() { } // Initialize a Diameter client and connect it to the server - let mut client = DiameterClient::new("localhost:3868"); + let client_config = DiameterClientConfig { + use_tls: false, + verify_cert: false, + }; + let mut client = DiameterClient::new("localhost:3868", client_config); let mut handler = client.connect().await.unwrap(); tokio::spawn(async move { DiameterClient::handle(&mut handler).await; diff --git a/examples/load_generator.rs b/examples/load_generator.rs index 04b417d..7c42c8e 100644 --- a/examples/load_generator.rs +++ b/examples/load_generator.rs @@ -11,6 +11,7 @@ use diameter::avp::Unsigned32; use diameter::dictionary; use diameter::flags; use diameter::transport::DiameterClient; +use diameter::transport::DiameterClientConfig; use diameter::{ApplicationId, CommandCode, DiameterMessage}; use std::fs; use std::io::Write; @@ -53,7 +54,11 @@ async fn main() { local .run_until(async move { // Initialize a Diameter client and connect it to the server - let mut client = DiameterClient::new("localhost:3868"); + let client_config = DiameterClientConfig { + use_tls: false, + verify_cert: false, + }; + let mut client = DiameterClient::new("localhost:3868", client_config); let mut handler = client.connect().await.unwrap(); task::spawn_local(async move { DiameterClient::handle(&mut handler).await; diff --git a/examples/server.rs b/examples/server.rs index 518510b..11f22b9 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -14,8 +14,8 @@ use diameter::transport::DiameterServerConfig; use diameter::CommandCode; use diameter::DiameterMessage; use std::fs; -use std::fs::File; -use std::io::Read; +// use std::fs::File; +// use std::io::Read; use std::io::Write; use std::thread; diff --git a/src/transport/client.rs b/src/transport/client.rs index 591c47c..30957cd 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -5,14 +5,21 @@ use crate::transport::Codec; use std::collections::HashMap; use std::ops::DerefMut; use std::sync::Arc; -use tokio::net::tcp::OwnedReadHalf; -use tokio::net::tcp::OwnedWriteHalf; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio::sync::oneshot; use tokio::sync::oneshot::Receiver; use tokio::sync::oneshot::Sender; use tokio::sync::Mutex; +/// Configuration for a Diameter protocol client. +/// +pub struct DiameterClientConfig { + pub use_tls: bool, + pub verify_cert: bool, + // pub native_tls: Option, // Future Implementation +} + /// A Diameter protocol client for sending and receiving Diameter messages. /// /// The client maintains a connection to a Diameter server and provides @@ -25,8 +32,9 @@ use tokio::sync::Mutex; /// seq_num: The next sequence number to use for a message. pub struct DiameterClient { + config: DiameterClientConfig, address: String, - writer: Option>>, + writer: Option>>, msg_caches: Arc>>>, seq_num: u32, } @@ -42,8 +50,9 @@ impl DiameterClient { /// /// Returns: /// A new instance of `DiameterClient`. - pub fn new(addr: &str) -> DiameterClient { + pub fn new(addr: &str, config: DiameterClientConfig) -> DiameterClient { DiameterClient { + config, address: addr.into(), writer: None, msg_caches: Arc::new(Mutex::new(HashMap::new())), @@ -58,13 +67,39 @@ impl DiameterClient { pub async fn connect(&mut self) -> Result { let stream = TcpStream::connect(self.address.clone()).await?; - let (reader, writer) = stream.into_split(); - let writer = Arc::new(Mutex::new(writer)); + if self.config.use_tls { + let tls_connector = tokio_native_tls::TlsConnector::from( + native_tls::TlsConnector::builder() + .danger_accept_invalid_certs(!self.config.verify_cert) + .build()?, + ); + let tls_stream = tls_connector.connect(&self.address.clone(), stream).await?; + let (reader, writer) = tokio::io::split(tls_stream); - self.writer = Some(writer); + // writer + let writer = Arc::new(Mutex::new(writer)); + self.writer = Some(writer); - let msg_caches = Arc::clone(&self.msg_caches); - Ok(ClientHandler { reader, msg_caches }) + // reader + let msg_caches = Arc::clone(&self.msg_caches); + Ok(ClientHandler { + reader: Box::new(reader), + msg_caches, + }) + } else { + let (reader, writer) = tokio::io::split(stream); + + // writer + let writer = Arc::new(Mutex::new(writer)); + self.writer = Some(writer); + + // reader + let msg_caches = Arc::clone(&self.msg_caches); + Ok(ClientHandler { + reader: Box::new(reader), + msg_caches, + }) + } } /// Handles incoming Diameter messages. @@ -77,11 +112,12 @@ impl DiameterClient { /// /// Example: /// ```no_run - /// use diameter::transport::client::{ClientHandler, DiameterClient}; + /// use diameter::transport::client::{ClientHandler, DiameterClient, DiameterClientConfig}; /// /// #[tokio::main] /// async fn main() { - /// let mut client = DiameterClient::new("localhost:3868"); + /// let config = DiameterClientConfig { use_tls: false, verify_cert: false }; + /// let mut client = DiameterClient::new("localhost:3868", config); /// let mut handler = client.connect().await.unwrap(); /// tokio::spawn(async move { /// DiameterClient::handle(&mut handler).await; @@ -183,7 +219,8 @@ impl DiameterClient { /// A Diameter protocol client handler for receiving Diameter messages. /// pub struct ClientHandler { - reader: OwnedReadHalf, + // reader: ReadHalf, + reader: Box, msg_caches: Arc>>>, } @@ -199,7 +236,7 @@ pub struct ClientHandler { pub struct DiameterRequest { request: DiameterMessage, receiver: Arc>>>, - writer: Arc>, + writer: Arc>, } impl DiameterRequest { @@ -215,7 +252,7 @@ impl DiameterRequest { pub fn new( request: DiameterMessage, receiver: Receiver, - writer: Arc>, + writer: Arc>, ) -> Self { DiameterRequest { request, diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 2ad8e8a..709aa43 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -5,6 +5,7 @@ pub mod experimental; pub mod server; pub use crate::transport::client::DiameterClient; +pub use crate::transport::client::DiameterClientConfig; pub use crate::transport::server::DiameterServer; pub use crate::transport::server::DiameterServerConfig; @@ -84,6 +85,7 @@ mod tests { use crate::diameter::flags; use crate::diameter::{ApplicationId, CommandCode, DiameterMessage}; use crate::transport::DiameterClient; + use crate::transport::DiameterClientConfig; use crate::transport::DiameterServer; use crate::transport::DiameterServerConfig; @@ -120,7 +122,11 @@ mod tests { }); // Diameter Client - let mut client = DiameterClient::new("localhost:3868"); + let client_config = DiameterClientConfig { + use_tls: false, + verify_cert: false, + }; + let mut client = DiameterClient::new("localhost:3868", client_config); let mut handler = client.connect().await.unwrap(); tokio::spawn(async move { DiameterClient::handle(&mut handler).await; diff --git a/src/transport/server.rs b/src/transport/server.rs index 39bba5d..793dc5f 100644 --- a/src/transport/server.rs +++ b/src/transport/server.rs @@ -9,6 +9,7 @@ use tokio::io::AsyncWriteExt; use tokio::net::TcpListener; /// Configuration for the Diameter server. +/// pub struct DiameterServerConfig { pub native_tls: Option, }