Skip to content

Commit

Permalink
Merge pull request #8 from lwlee2608/feature/add-tls-support-to-client
Browse files Browse the repository at this point in the history
Add TLS support for DiameterClient
  • Loading branch information
lwlee2608 authored Apr 1, 2024
2 parents f9ac06f + b7cdee7 commit acafaeb
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 19 deletions.
7 changes: 6 additions & 1 deletion examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use diameter::avp::Unsigned32;
use diameter::dictionary;
use diameter::flags;
use diameter::transport::DiameterClient;
use diameter::transport::DiameterClientConfig;
use diameter::{ApplicationId, CommandCode, DiameterMessage};
use std::fs;
use std::net::Ipv4Addr;
Expand All @@ -26,7 +27,11 @@ async fn main() {
}

// Initialize a Diameter client and connect it to the server
let mut client = DiameterClient::new("localhost:3868");
let client_config = DiameterClientConfig {
use_tls: false,
verify_cert: false,
};
let mut client = DiameterClient::new("localhost:3868", client_config);
let mut handler = client.connect().await.unwrap();
tokio::spawn(async move {
DiameterClient::handle(&mut handler).await;
Expand Down
7 changes: 6 additions & 1 deletion examples/load_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use diameter::avp::Unsigned32;
use diameter::dictionary;
use diameter::flags;
use diameter::transport::DiameterClient;
use diameter::transport::DiameterClientConfig;
use diameter::{ApplicationId, CommandCode, DiameterMessage};
use std::fs;
use std::io::Write;
Expand Down Expand Up @@ -53,7 +54,11 @@ async fn main() {
local
.run_until(async move {
// Initialize a Diameter client and connect it to the server
let mut client = DiameterClient::new("localhost:3868");
let client_config = DiameterClientConfig {
use_tls: false,
verify_cert: false,
};
let mut client = DiameterClient::new("localhost:3868", client_config);
let mut handler = client.connect().await.unwrap();
task::spawn_local(async move {
DiameterClient::handle(&mut handler).await;
Expand Down
4 changes: 2 additions & 2 deletions examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use diameter::transport::DiameterServerConfig;
use diameter::CommandCode;
use diameter::DiameterMessage;
use std::fs;
use std::fs::File;
use std::io::Read;
// use std::fs::File;
// use std::io::Read;
use std::io::Write;
use std::thread;

Expand Down
65 changes: 51 additions & 14 deletions src/transport/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,21 @@ use crate::transport::Codec;
use std::collections::HashMap;
use std::ops::DerefMut;
use std::sync::Arc;
use tokio::net::tcp::OwnedReadHalf;
use tokio::net::tcp::OwnedWriteHalf;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::sync::oneshot::Receiver;
use tokio::sync::oneshot::Sender;
use tokio::sync::Mutex;

/// Configuration for a Diameter protocol client.
///
pub struct DiameterClientConfig {
pub use_tls: bool,
pub verify_cert: bool,
// pub native_tls: Option<native_tls::Identity>, // Future Implementation
}

/// A Diameter protocol client for sending and receiving Diameter messages.
///
/// The client maintains a connection to a Diameter server and provides
Expand All @@ -25,8 +32,9 @@ use tokio::sync::Mutex;
/// seq_num: The next sequence number to use for a message.
pub struct DiameterClient {
config: DiameterClientConfig,
address: String,
writer: Option<Arc<Mutex<OwnedWriteHalf>>>,
writer: Option<Arc<Mutex<dyn AsyncWrite + Send + Unpin>>>,
msg_caches: Arc<Mutex<HashMap<u32, Sender<DiameterMessage>>>>,
seq_num: u32,
}
Expand All @@ -42,8 +50,9 @@ impl DiameterClient {
///
/// Returns:
/// A new instance of `DiameterClient`.
pub fn new(addr: &str) -> DiameterClient {
pub fn new(addr: &str, config: DiameterClientConfig) -> DiameterClient {
DiameterClient {
config,
address: addr.into(),
writer: None,
msg_caches: Arc::new(Mutex::new(HashMap::new())),
Expand All @@ -58,13 +67,39 @@ impl DiameterClient {
pub async fn connect(&mut self) -> Result<ClientHandler> {
let stream = TcpStream::connect(self.address.clone()).await?;

let (reader, writer) = stream.into_split();
let writer = Arc::new(Mutex::new(writer));
if self.config.use_tls {
let tls_connector = tokio_native_tls::TlsConnector::from(
native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(!self.config.verify_cert)
.build()?,
);
let tls_stream = tls_connector.connect(&self.address.clone(), stream).await?;
let (reader, writer) = tokio::io::split(tls_stream);

self.writer = Some(writer);
// writer
let writer = Arc::new(Mutex::new(writer));
self.writer = Some(writer);

let msg_caches = Arc::clone(&self.msg_caches);
Ok(ClientHandler { reader, msg_caches })
// reader
let msg_caches = Arc::clone(&self.msg_caches);
Ok(ClientHandler {
reader: Box::new(reader),
msg_caches,
})
} else {
let (reader, writer) = tokio::io::split(stream);

// writer
let writer = Arc::new(Mutex::new(writer));
self.writer = Some(writer);

// reader
let msg_caches = Arc::clone(&self.msg_caches);
Ok(ClientHandler {
reader: Box::new(reader),
msg_caches,
})
}
}

/// Handles incoming Diameter messages.
Expand All @@ -77,11 +112,12 @@ impl DiameterClient {
///
/// Example:
/// ```no_run
/// use diameter::transport::client::{ClientHandler, DiameterClient};
/// use diameter::transport::client::{ClientHandler, DiameterClient, DiameterClientConfig};
///
/// #[tokio::main]
/// async fn main() {
/// let mut client = DiameterClient::new("localhost:3868");
/// let config = DiameterClientConfig { use_tls: false, verify_cert: false };
/// let mut client = DiameterClient::new("localhost:3868", config);
/// let mut handler = client.connect().await.unwrap();
/// tokio::spawn(async move {
/// DiameterClient::handle(&mut handler).await;
Expand Down Expand Up @@ -183,7 +219,8 @@ impl DiameterClient {
/// A Diameter protocol client handler for receiving Diameter messages.
///
pub struct ClientHandler {
reader: OwnedReadHalf,
// reader: ReadHalf<TcpStream>,
reader: Box<dyn AsyncRead + Send + Unpin>,
msg_caches: Arc<Mutex<HashMap<u32, Sender<DiameterMessage>>>>,
}

Expand All @@ -199,7 +236,7 @@ pub struct ClientHandler {
pub struct DiameterRequest {
request: DiameterMessage,
receiver: Arc<Mutex<Option<Receiver<DiameterMessage>>>>,
writer: Arc<Mutex<OwnedWriteHalf>>,
writer: Arc<Mutex<dyn AsyncWrite + Send + Unpin>>,
}

impl DiameterRequest {
Expand All @@ -215,7 +252,7 @@ impl DiameterRequest {
pub fn new(
request: DiameterMessage,
receiver: Receiver<DiameterMessage>,
writer: Arc<Mutex<OwnedWriteHalf>>,
writer: Arc<Mutex<dyn AsyncWrite + Send + Unpin>>,
) -> Self {
DiameterRequest {
request,
Expand Down
8 changes: 7 additions & 1 deletion src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod experimental;
pub mod server;

pub use crate::transport::client::DiameterClient;
pub use crate::transport::client::DiameterClientConfig;
pub use crate::transport::server::DiameterServer;
pub use crate::transport::server::DiameterServerConfig;

Expand Down Expand Up @@ -84,6 +85,7 @@ mod tests {
use crate::diameter::flags;
use crate::diameter::{ApplicationId, CommandCode, DiameterMessage};
use crate::transport::DiameterClient;
use crate::transport::DiameterClientConfig;
use crate::transport::DiameterServer;
use crate::transport::DiameterServerConfig;

Expand Down Expand Up @@ -120,7 +122,11 @@ mod tests {
});

// Diameter Client
let mut client = DiameterClient::new("localhost:3868");
let client_config = DiameterClientConfig {
use_tls: false,
verify_cert: false,
};
let mut client = DiameterClient::new("localhost:3868", client_config);
let mut handler = client.connect().await.unwrap();
tokio::spawn(async move {
DiameterClient::handle(&mut handler).await;
Expand Down
1 change: 1 addition & 0 deletions src/transport/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;

/// Configuration for the Diameter server.
///
pub struct DiameterServerConfig {
pub native_tls: Option<native_tls::Identity>,
}
Expand Down

0 comments on commit acafaeb

Please sign in to comment.