Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TLS support for DiameterClient #8

Merged
merged 4 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use diameter::avp::Unsigned32;
use diameter::dictionary;
use diameter::flags;
use diameter::transport::DiameterClient;
use diameter::transport::DiameterClientConfig;
use diameter::{ApplicationId, CommandCode, DiameterMessage};
use std::fs;
use std::net::Ipv4Addr;
Expand All @@ -26,7 +27,11 @@ async fn main() {
}

// Initialize a Diameter client and connect it to the server
let mut client = DiameterClient::new("localhost:3868");
let client_config = DiameterClientConfig {
use_tls: false,
verify_cert: false,
};
let mut client = DiameterClient::new("localhost:3868", client_config);
let mut handler = client.connect().await.unwrap();
tokio::spawn(async move {
DiameterClient::handle(&mut handler).await;
Expand Down
7 changes: 6 additions & 1 deletion examples/load_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use diameter::avp::Unsigned32;
use diameter::dictionary;
use diameter::flags;
use diameter::transport::DiameterClient;
use diameter::transport::DiameterClientConfig;
use diameter::{ApplicationId, CommandCode, DiameterMessage};
use std::fs;
use std::io::Write;
Expand Down Expand Up @@ -53,7 +54,11 @@ async fn main() {
local
.run_until(async move {
// Initialize a Diameter client and connect it to the server
let mut client = DiameterClient::new("localhost:3868");
let client_config = DiameterClientConfig {
use_tls: false,
verify_cert: false,
};
let mut client = DiameterClient::new("localhost:3868", client_config);
let mut handler = client.connect().await.unwrap();
task::spawn_local(async move {
DiameterClient::handle(&mut handler).await;
Expand Down
4 changes: 2 additions & 2 deletions examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use diameter::transport::DiameterServerConfig;
use diameter::CommandCode;
use diameter::DiameterMessage;
use std::fs;
use std::fs::File;
use std::io::Read;
// use std::fs::File;
// use std::io::Read;
use std::io::Write;
use std::thread;

Expand Down
65 changes: 51 additions & 14 deletions src/transport/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,21 @@ use crate::transport::Codec;
use std::collections::HashMap;
use std::ops::DerefMut;
use std::sync::Arc;
use tokio::net::tcp::OwnedReadHalf;
use tokio::net::tcp::OwnedWriteHalf;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::sync::oneshot::Receiver;
use tokio::sync::oneshot::Sender;
use tokio::sync::Mutex;

/// Configuration for a Diameter protocol client.
///
pub struct DiameterClientConfig {
pub use_tls: bool,
pub verify_cert: bool,
// pub native_tls: Option<native_tls::Identity>, // Future Implementation
}

/// A Diameter protocol client for sending and receiving Diameter messages.
///
/// The client maintains a connection to a Diameter server and provides
Expand All @@ -25,8 +32,9 @@ use tokio::sync::Mutex;
/// seq_num: The next sequence number to use for a message.

pub struct DiameterClient {
config: DiameterClientConfig,
address: String,
writer: Option<Arc<Mutex<OwnedWriteHalf>>>,
writer: Option<Arc<Mutex<dyn AsyncWrite + Send + Unpin>>>,
msg_caches: Arc<Mutex<HashMap<u32, Sender<DiameterMessage>>>>,
seq_num: u32,
}
Expand All @@ -42,8 +50,9 @@ impl DiameterClient {
///
/// Returns:
/// A new instance of `DiameterClient`.
pub fn new(addr: &str) -> DiameterClient {
pub fn new(addr: &str, config: DiameterClientConfig) -> DiameterClient {
DiameterClient {
config,
address: addr.into(),
writer: None,
msg_caches: Arc::new(Mutex::new(HashMap::new())),
Expand All @@ -58,13 +67,39 @@ impl DiameterClient {
pub async fn connect(&mut self) -> Result<ClientHandler> {
let stream = TcpStream::connect(self.address.clone()).await?;

let (reader, writer) = stream.into_split();
let writer = Arc::new(Mutex::new(writer));
if self.config.use_tls {
let tls_connector = tokio_native_tls::TlsConnector::from(
native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(!self.config.verify_cert)
.build()?,
);
let tls_stream = tls_connector.connect(&self.address.clone(), stream).await?;
let (reader, writer) = tokio::io::split(tls_stream);

self.writer = Some(writer);
// writer
let writer = Arc::new(Mutex::new(writer));
self.writer = Some(writer);

let msg_caches = Arc::clone(&self.msg_caches);
Ok(ClientHandler { reader, msg_caches })
// reader
let msg_caches = Arc::clone(&self.msg_caches);
Ok(ClientHandler {
reader: Box::new(reader),
msg_caches,
})
} else {
let (reader, writer) = tokio::io::split(stream);

// writer
let writer = Arc::new(Mutex::new(writer));
self.writer = Some(writer);

// reader
let msg_caches = Arc::clone(&self.msg_caches);
Ok(ClientHandler {
reader: Box::new(reader),
msg_caches,
})
}
}

/// Handles incoming Diameter messages.
Expand All @@ -77,11 +112,12 @@ impl DiameterClient {
///
/// Example:
/// ```no_run
/// use diameter::transport::client::{ClientHandler, DiameterClient};
/// use diameter::transport::client::{ClientHandler, DiameterClient, DiameterClientConfig};
///
/// #[tokio::main]
/// async fn main() {
/// let mut client = DiameterClient::new("localhost:3868");
/// let config = DiameterClientConfig { use_tls: false, verify_cert: false };
/// let mut client = DiameterClient::new("localhost:3868", config);
/// let mut handler = client.connect().await.unwrap();
/// tokio::spawn(async move {
/// DiameterClient::handle(&mut handler).await;
Expand Down Expand Up @@ -183,7 +219,8 @@ impl DiameterClient {
/// A Diameter protocol client handler for receiving Diameter messages.
///
pub struct ClientHandler {
reader: OwnedReadHalf,
// reader: ReadHalf<TcpStream>,
reader: Box<dyn AsyncRead + Send + Unpin>,
msg_caches: Arc<Mutex<HashMap<u32, Sender<DiameterMessage>>>>,
}

Expand All @@ -199,7 +236,7 @@ pub struct ClientHandler {
pub struct DiameterRequest {
request: DiameterMessage,
receiver: Arc<Mutex<Option<Receiver<DiameterMessage>>>>,
writer: Arc<Mutex<OwnedWriteHalf>>,
writer: Arc<Mutex<dyn AsyncWrite + Send + Unpin>>,
}

impl DiameterRequest {
Expand All @@ -215,7 +252,7 @@ impl DiameterRequest {
pub fn new(
request: DiameterMessage,
receiver: Receiver<DiameterMessage>,
writer: Arc<Mutex<OwnedWriteHalf>>,
writer: Arc<Mutex<dyn AsyncWrite + Send + Unpin>>,
) -> Self {
DiameterRequest {
request,
Expand Down
8 changes: 7 additions & 1 deletion src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod experimental;
pub mod server;

pub use crate::transport::client::DiameterClient;
pub use crate::transport::client::DiameterClientConfig;
pub use crate::transport::server::DiameterServer;
pub use crate::transport::server::DiameterServerConfig;

Expand Down Expand Up @@ -84,6 +85,7 @@ mod tests {
use crate::diameter::flags;
use crate::diameter::{ApplicationId, CommandCode, DiameterMessage};
use crate::transport::DiameterClient;
use crate::transport::DiameterClientConfig;
use crate::transport::DiameterServer;
use crate::transport::DiameterServerConfig;

Expand Down Expand Up @@ -120,7 +122,11 @@ mod tests {
});

// Diameter Client
let mut client = DiameterClient::new("localhost:3868");
let client_config = DiameterClientConfig {
use_tls: false,
verify_cert: false,
};
let mut client = DiameterClient::new("localhost:3868", client_config);
let mut handler = client.connect().await.unwrap();
tokio::spawn(async move {
DiameterClient::handle(&mut handler).await;
Expand Down
1 change: 1 addition & 0 deletions src/transport/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;

/// Configuration for the Diameter server.
///
pub struct DiameterServerConfig {
pub native_tls: Option<native_tls::Identity>,
}
Expand Down
Loading