diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..0e7c706 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,50 @@ +on: [push, pull_request] + +name: CI + +jobs: + check: + name: Check + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v3 + + - name: Install stable toolchain + uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + + - name: Run cargo check + run: cargo check + + test: + name: Test Suite + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v3 + + - name: Install stable toolchain + uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + + - name: Run cargo test + run: cargo test + + lints: + name: Lints + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v3 + + - name: Install stable toolchain + uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: rustfmt, clippy + + - name: Run cargo fmt + run: cargo fmt --all -- --check \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 5d4de66..e2b67c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,12 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "anyhow" -version = "1.0.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" - [[package]] name = "bitflags" version = "1.3.2" @@ -297,6 +291,14 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "threadsafe_zmq" +version = "1.0.0" +dependencies = [ + "crossbeam-channel", + "zmq", +] + [[package]] name = "toml" version = "0.8.19" @@ -475,12 +477,3 @@ dependencies = [ "system-deps", "zeromq-src", ] - -[[package]] -name = "zmq_rchan" -version = "1.0.0" -dependencies = [ - "anyhow", - "crossbeam-channel", - "zmq", -] diff --git a/Cargo.toml b/Cargo.toml index 4750a6d..882ea32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,14 @@ [package] -name = "zmq_rchan" +name = "threadsafe_zmq" version = "1.0.0" edition = "2021" +authors = ["Elvis Sabanovic "] +description = "Threadsafe zeromq" +readme = "README.md" +keywords = ["threadsafe", "zeromq", "sockets", "unix-domain-sockets", "ipc"] +categories = ["concurrency"] +repository = "https://github.com/Elvis339/threadsafe_zmq" [dependencies] zmq = "0.10.0" -crossbeam-channel = "0.5.13" -anyhow = "1.0.89" \ No newline at end of file +crossbeam-channel = "0.5.13" \ No newline at end of file diff --git a/README.md b/README.md index 228f1be..5ccdf90 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,8 @@ This implementation is based on Golang's [zmqchan](https://github.com/abligh/zmq This is just a tweaked implementation in Rust +# Message Flow + ``` +-------------------+ | Client | diff --git a/example/Cargo.toml b/example/Cargo.toml new file mode 100644 index 0000000..3fd0c78 --- /dev/null +++ b/example/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "threadsafe_zmq_example" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "server" +path = "src/server.rs" + +[[bin]] +name = "client" +path = "src/client.rs" + +[dependencies] +threadsafe_zmq = { path = "../src", version = "1.0.0" } \ No newline at end of file diff --git a/example/README.md b/example/README.md new file mode 100644 index 0000000..7231e03 --- /dev/null +++ b/example/README.md @@ -0,0 +1,6 @@ +# Example +In one terminal window run: +`RUST_LOG=trace cargo run --bin server` + +In another run +`RUST_LOG=info cargo run --bin client` \ No newline at end of file diff --git a/example/src/client.rs b/example/src/client.rs new file mode 100644 index 0000000..e465df2 --- /dev/null +++ b/example/src/client.rs @@ -0,0 +1,67 @@ +use env_logger; +use log::{error, info}; +use rand::Rng; +use zmq::Context; + +fn main() { + env_logger::init(); + + let clients = 4; + let mut handles = Vec::with_capacity(clients); + + for i in 0..clients { + let client_id = i; + let handle = std::thread::spawn(move || { + let addr = "tcp://localhost:5555"; + let ctx = Context::new(); + let socket = ctx + .socket(zmq::DEALER) + .expect("Failed to create PAIR socket"); + + let rand_id = client_id as u8 + generate_random_number(); + let id = format!("client-{}", rand_id); + socket + .set_identity(id.clone().as_bytes()) + .expect("Failed to set identity"); + socket.connect(addr).expect("Failed to connect to server"); + + info!("{} connected to: {}", id, addr); + loop { + let rand_num = generate_random_number(); + let rand_num_bytes = rand_num.to_le_bytes().to_vec(); + + match socket.send_multipart(vec![rand_num_bytes], 0) { + Ok(_) => info!("{}, sent number: {}", id, rand_num), + Err(snd_err) => { + error!("{}, failed to send message: {:?}", id, snd_err); + continue; + } + } + + match socket.recv_multipart(0) { + Ok(message) => { + info!("Client {}, received result: {:?}", client_id, message); + } + Err(rcv_err) => { + error!( + "Client {}, failed to receive message: {:?}", + client_id, rcv_err + ); + } + } + + std::thread::sleep(std::time::Duration::from_millis(100)); + } + }); + handles.push(handle); + } + + loop { + std::thread::sleep(std::time::Duration::from_secs(1)); + } +} + +fn generate_random_number() -> u8 { + let mut rng = rand::thread_rng(); + rng.gen_range(0..=30) +} diff --git a/example/src/server.rs b/example/src/server.rs new file mode 100644 index 0000000..6cae4fc --- /dev/null +++ b/example/src/server.rs @@ -0,0 +1,84 @@ +use env_logger; +use log::{debug, error, info}; +use threadsafe_zmq::{ChannelPair, Sender, ZmqByteStream}; +use zmq::Context; + +fn main() { + env_logger::init(); + + let addr = "tcp://*:5555"; + let ctx = Context::new(); + let socket = ctx + .socket(zmq::ROUTER) + .expect("Failed to create ROUTER socket"); + socket.bind(addr).expect("Failed to bind to address"); + + let channel_pair = ChannelPair::new(socket).expect("Failed to create channel pair"); + info!("Server listening on {}", addr); + + loop { + debug!("Waiting to receive messages..."); + + match channel_pair.rx_chan().recv() { + Ok(message) => { + if message.len() < 2 { + error!("Received malformed message: {:?}", message); + continue; + } + + println!("Received message: {:?}", message); + let cp = channel_pair.clone(); + std::thread::spawn(move || { + calculate_fib(message, cp.tx_chan()); + }); + } + Err(rcv_err) => { + error!("Failed to receive message: {:?}", rcv_err); + } + } + } +} + +fn calculate_fib(messages: ZmqByteStream, sender: &Sender) { + // The first part is the identity, and the second part is the actual message + let identity = messages[0].clone(); + let payload = messages[1].clone(); + + let id_str = String::from_utf8_lossy(&identity); + + if payload.is_empty() { + error!("Received an empty payload, skipping Fibonacci calculation."); + return; + } + + info!("Received message from: {:?}", id_str); + + // Deserialize the message into u32 + // let number = match payload.as_slice().try_into() { + // Ok(bytes) => u32::from_le_bytes(bytes), + // Err(_) => { + // error!("Failed to deserialize payload, skipping."); + // return; + // } + // }; + let number = 13; + + info!("Calculating Fibonacci for number: {}", number); + let result = fibonacci_recursive(number); + let result_bytes = result.to_le_bytes().to_vec(); + + // The response must include the identity frame, followed by the result + let response = vec![identity.clone(), result_bytes]; + match sender.send(response) { + Ok(_) => info!("Successfully sent response: {:?} to: {:?}", result, id_str), + Err(err) => error!("Failed to send response: {:?}", err), + } +} + +fn fibonacci_recursive(n: u32) -> u32 { + if n <= 1 { + n + } else { + fibonacci_recursive(n - 1) + fibonacci_recursive(n - 2) + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..cf06126 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,29 @@ +use std::fmt; +use zmq; + +#[derive(Debug)] +pub enum ChannelPairError { + Zmq(zmq::Error), + ChannelError(String), + ConfigurationError(String), + Other(String), +} + +impl From for ChannelPairError { + fn from(error: zmq::Error) -> Self { + ChannelPairError::Zmq(error) + } +} + +impl fmt::Display for ChannelPairError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ChannelPairError::Zmq(e) => write!(f, "ZeroMQ Error: {}", e), + ChannelPairError::ChannelError(msg) => write!(f, "Channel Error: {}", msg), + ChannelPairError::ConfigurationError(msg) => write!(f, "Configuration Error: {}", msg), + ChannelPairError::Other(msg) => write!(f, "Other Error: {}", msg), + } + } +} + +impl std::error::Error for ChannelPairError {} diff --git a/src/lib.rs b/src/lib.rs index ffba19e..2efaaf7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,89 +1,42 @@ -use anyhow::anyhow; -use crossbeam_channel::{select, unbounded, Receiver, Sender}; +mod error; + +use crossbeam_channel::{ + select, unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender, +}; +use error::ChannelPairError; use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::Arc; -use zmq::{Context, PollEvents, Socket, SocketType}; +use zmq::{Context, PollEvents, Socket}; static UNIQUE_INDEX: AtomicU16 = AtomicU16::new(0); const IN: usize = 0; const OUT: usize = 1; -const DEFAULT_LINGER_TIME: i32 = 30; -type ZmqByteStream = Vec>; +pub type ZmqByteStream = Vec>; +pub type Sender = CrossbeamSender; +pub type Receiver = CrossbeamReceiver; -struct ChannelPair { - ctx: Context, +pub struct ChannelPair { z_sock: Socket, z_tx: Vec, z_control: Vec, - tx_chan: Sender, - rx_chan: Receiver, - error_chan: (Sender, Receiver), - control_chan: (Sender, Receiver), + tx_chan: Sender, + rx_chan: Receiver, + error_chan: ( + CrossbeamSender, + CrossbeamReceiver, + ), + control_chan: (CrossbeamSender, CrossbeamReceiver), } -impl Clone for ChannelPair { - // Avoid calling clone!!! - fn clone(&self) -> Self { - let socket_type = self.z_sock.get_socket_type().unwrap(); - let last_endpoint = self.z_sock.get_last_endpoint().unwrap().unwrap(); - Self::disconnect_socket(&self.z_sock).unwrap(); - - for z_tx_socket in self.z_tx.iter() { - Self::disconnect_socket(z_tx_socket).unwrap(); - } - - for z_control_socket in self.z_control.iter() { - Self::disconnect_socket(z_control_socket).unwrap(); - } - - let z_tx = Self::new_pair(&self.ctx).unwrap(); - let z_control = Self::new_pair(&self.ctx).unwrap(); - - let new_socket = self.ctx.socket(socket_type).unwrap(); - - new_socket.set_rcvtimeo(0).unwrap(); - new_socket.set_sndtimeo(0).unwrap(); - - match socket_type { - SocketType::PUB - | SocketType::REP - | SocketType::PUSH - | SocketType::XPUB - | SocketType::ROUTER => { - new_socket.bind(&last_endpoint).unwrap(); - } - SocketType::SUB - | SocketType::REQ - | SocketType::DEALER - | SocketType::PULL - | SocketType::XSUB - | SocketType::STREAM => { - new_socket.connect(&last_endpoint).unwrap(); - } - // Special handling for PAIR socket (try bind first, then connect if it fails) - SocketType::PAIR => match new_socket.bind(&last_endpoint) { - Ok(_) => {} - Err(_) => { - new_socket.connect(&last_endpoint).unwrap(); - } - }, - } - - let mut cp = Self { - ctx: self.ctx.clone(), - z_sock: new_socket, - tx_chan: self.tx_chan.clone(), - rx_chan: self.rx_chan.clone(), - error_chan: self.error_chan.clone(), - control_chan: self.control_chan.clone(), - z_tx, - z_control, - }; - - Self::configure_socket(&mut cp).unwrap(); +enum SocketState { + Idle, + ReadyToSend(ZmqByteStream), +} - cp +impl SocketState { + fn reset(&mut self) { + *self = SocketState::Idle; } } @@ -91,28 +44,25 @@ unsafe impl Send for ChannelPair {} unsafe impl Sync for ChannelPair {} impl ChannelPair { - pub fn new(ctx: Context, socket: Socket) -> anyhow::Result> { - let z_tx = Self::new_pair(&ctx)?; - let z_control = Self::new_pair(&ctx)?; + pub fn new(socket: Socket) -> Result, ChannelPairError> { + let z_tx = Self::new_pair()?; + let z_control = Self::new_pair()?; let (tx_chan, rx_chan) = unbounded::(); let mut channel_pair = Self { - ctx, z_tx, z_control, tx_chan, rx_chan, z_sock: socket, - error_chan: unbounded::(), + error_chan: unbounded::(), control_chan: unbounded::(), }; Self::configure_socket(&mut channel_pair)?; - - // let tmp = channel_pair.clone(); let channel_pair = Arc::new(channel_pair); - // === run sockets ==={ + // === run sockets === let channel_pair_clone = Arc::clone(&channel_pair); std::thread::spawn(move || channel_pair_clone.run_sockets()); @@ -124,114 +74,87 @@ impl ChannelPair { } fn run_sockets(&self) { - let mut to_xmit: ZmqByteStream = vec![]; + let mut state = SocketState::Idle; - // Create poll items array. let mut items = [ - self.z_sock.as_poll_item(PollEvents::empty()), - self.z_tx[OUT].as_poll_item(PollEvents::empty()), - self.z_control[OUT].as_poll_item(zmq::POLLIN), + self.z_sock.as_poll_item(PollEvents::empty()), // z_sock for reading incoming messages + self.z_tx[OUT].as_poll_item(PollEvents::empty()), // z_tx[OUT] for receiving messages from `z_tx[IN]` + self.z_control[OUT].as_poll_item(PollEvents::empty()), // z_control for handling control messages ]; loop { - let mut z_sock_flag = PollEvents::empty(); - let mut tx_sock_flag = PollEvents::empty(); - - if to_xmit.is_empty() { - tx_sock_flag |= zmq::POLLIN; - } else { - z_sock_flag |= zmq::POLLOUT; - } + // Set events to monitor based on the state + items[0].set_events(match state { + SocketState::ReadyToSend(_) => zmq::POLLOUT, // If we have data to send, poll for writable events + _ => zmq::POLLIN, // If we have no data to send, poll for readable events + }); - items[0].set_events(z_sock_flag); - items[1].set_events(tx_sock_flag); + items[1].set_events(match state { + SocketState::Idle => zmq::POLLIN, // Poll for messages from `z_tx[OUT]` only when idle + _ => PollEvents::empty(), // No need to poll if we're in a different state + }); - // Pass the whole `items` array as a mutable reference. match zmq::poll(&mut items, -1) { Ok(_) => { + // Check if `z_sock` is readable or writable if items[0].is_readable() { match self.z_sock.recv_multipart(0) { - Ok(zmq_byte_stream) => self.tx_chan().send(zmq_byte_stream).unwrap(), + Ok(zmq_byte_stream) => { + if let Err(err) = self.tx_chan().send(zmq_byte_stream) { + self.on_err(ChannelPairError::ChannelError(format!( + "Failed to send message to channel: {:?}", + err + ))); + return; + } + } Err(recv_err) => { - self.on_err(anyhow!( - "Failed to receive message on the socket: {:?}", - recv_err - )); + self.on_err(ChannelPairError::Zmq(recv_err)); return; } } } - if items[0].is_writable() && !to_xmit.is_empty() { - match self.z_sock.send_multipart(to_xmit.clone(), 0) { - Ok(_) => { - to_xmit = vec![]; - } - Err(snd_err) => { - self.on_err(anyhow!( - "Failed to send message on the socket: {:?}", - snd_err - )); - return; + if items[0].is_writable() { + if let SocketState::ReadyToSend(message) = &state { + match self.z_sock.send_multipart(message.clone(), 0) { + Ok(_) => { + state.reset(); + } + Err(snd_err) => { + self.on_err(ChannelPairError::Zmq(snd_err)); + return; + } } } } - if items[1].is_readable() && to_xmit.is_empty() { + // Check if there's a message in the transmit socket (`z_tx[OUT]`) + if items[1].is_readable() { match self.z_tx[OUT].recv_multipart(0) { Ok(zmq_byte_stream) => { - to_xmit = zmq_byte_stream; + state = SocketState::ReadyToSend(zmq_byte_stream); } Err(z_tx_recv_err) => { - self.on_err(anyhow!( - "Failed to receive message on the tx socket: {:?}", - z_tx_recv_err - )); + self.on_err(ChannelPairError::Zmq(z_tx_recv_err)); return; } } } - if items[2].is_writable() { + // Handle control messages if there are any + if items[2].is_readable() { match self.z_control[OUT].recv_multipart(0) { Ok(_) => {} Err(ctrl_err) => { - self.on_err(anyhow!( - "Failed to receive message on the control socket: {:?}", - ctrl_err - )); + self.on_err(ChannelPairError::Zmq(ctrl_err)); return; } } - - let linger = self.z_sock.get_linger().unwrap_or(DEFAULT_LINGER_TIME); - if let Err(err) = self.z_sock.set_linger(linger) { - self.on_err(anyhow!("Failed to set linger on the socket: {:?}", err)); - return; - } - - if !to_xmit.is_empty() { - match self.z_sock.send_multipart(to_xmit.clone(), 0) { - Ok(_) => {} - Err(snd_err) => { - self.on_err(anyhow!( - "Failed to send pending message on the socket: {:?}", - snd_err - )); - return; - } - } - } else { - to_xmit = vec![]; - } - - items[2].set_events(PollEvents::empty()); - items[0].set_events(PollEvents::empty()); - items[1].set_events(zmq::POLLIN); } } Err(poll_err) => { - self.on_err(anyhow!("Polling error: {:?}", poll_err)); + self.on_err(ChannelPairError::Zmq(poll_err)); return; } } @@ -245,14 +168,13 @@ impl ChannelPair { recv(self.rx_chan()) -> msg => { match msg { Ok(msg) => { - // If a message is received, try to send it through zTx. if let Err(err) = self.z_tx[IN].send_multipart(&msg, 0) { - self.on_err(anyhow!("Failed to send message on tx socket: {:?}", err)); + self.on_err(ChannelPairError::Zmq(err)); return; } }, Err(_) => { - self.on_err(anyhow!("ZMQ tx channel closed unexpectedly")); + self.on_err(ChannelPairError::ChannelError("ZMQ tx channel closed unexpectedly".into())); return; } } @@ -261,15 +183,14 @@ impl ChannelPair { match control { Ok(control) => { if control { - // Send an empty message as a control signal. if let Err(err) = self.z_control[IN].send("", 0) { - self.on_err(anyhow!("Failed to send message on the control socket: {:?}", err)); + self.on_err(ChannelPairError::Zmq(err)); } } return; }, Err(_) => { - self.on_err(anyhow!("ZMQ control channel closed unexpectedly")); + self.on_err(ChannelPairError::ChannelError("ZMQ control channel closed unexpectedly".into())); return; } } @@ -278,32 +199,36 @@ impl ChannelPair { } } - pub fn rx_chan(&self) -> &Receiver { + pub fn rx_chan(&self) -> &Receiver { &self.rx_chan } - pub fn tx_chan(&self) -> &Sender { + pub fn tx_chan(&self) -> &Sender { &self.tx_chan } - fn tx_control_chan(&self) -> &Sender { + pub fn rx_err_chan(&self) -> &CrossbeamReceiver { + &self.error_chan.1 + } + + fn tx_control_chan(&self) -> &CrossbeamSender { &self.control_chan.0 } - fn rx_control_chan(&self) -> &Receiver { + fn rx_control_chan(&self) -> &CrossbeamReceiver { &self.control_chan.1 } - fn tx_err_chan(&self) -> &Sender { + fn tx_err_chan(&self) -> &CrossbeamSender { &self.error_chan.0 } - fn on_err(&self, error: anyhow::Error) { + fn on_err(&self, error: ChannelPairError) { let _ = self.tx_err_chan().send(error); let _ = self.tx_control_chan().send(false); } - fn configure_socket(&mut self) -> anyhow::Result<()> { + fn configure_socket(&mut self) -> Result<(), ChannelPairError> { self.z_sock.set_rcvtimeo(0)?; self.z_sock.set_sndtimeo(0)?; @@ -320,7 +245,8 @@ impl ChannelPair { Ok(()) } - fn new_pair(context: &Context) -> anyhow::Result> { + fn new_pair() -> Result, ChannelPairError> { + let context = Context::new(); let addr = format!("inproc://_channelpair_internal-{}", get_unique_id()); let server_pair = context.socket(zmq::PAIR)?; server_pair.bind(&addr)?; @@ -330,177 +256,101 @@ impl ChannelPair { Ok(vec![server_pair, client_pair]) } - - fn disconnect_socket(socket: &Socket) -> Result<(), zmq::Error> { - // Retrieve the last endpoint, handling both the outer and inner Result. - match socket.get_last_endpoint()? { - // Handle the case where the endpoint is a valid UTF-8 string. - Ok(endpoint) => socket.disconnect(&endpoint), - - // Handle the case where the endpoint is raw bytes (non-UTF-8). - Err(bytes) => { - let endpoint = String::from_utf8_lossy(&bytes).into_owned(); - socket.disconnect(&endpoint) - } - } - } } fn get_unique_id() -> u16 { UNIQUE_INDEX.fetch_add(1, Ordering::SeqCst) } -// // Helper function to create a new pair socket. -// fn new_pair_socket() -> (Socket, Context, Socket, Context) { -// let ctx1 = Context::new(); -// let socket1 = ctx1.socket(zmq::PAIR).unwrap(); -// socket1.bind("tcp://127.0.0.1:9737").unwrap(); -// -// let ctx2 = Context::new(); -// let socket2 = ctx2.socket(zmq::PAIR).unwrap(); -// socket2.connect("tcp://127.0.0.1:9737").unwrap(); -// -// (socket1, ctx1, socket2, ctx2) -// } -// -// // Helper function to compare two ZMQ byte streams. -// fn msg_equal(a: &ZmqByteStream, b: &ZmqByteStream) -> bool { -// if a.len() != b.len() { -// return false; -// } -// for (i, msg) in a.iter().enumerate() { -// if msg.len() != b[i].len() { -// return false; -// } -// for (j, byte) in msg.iter().enumerate() { -// if *byte != b[i][j] { -// return false; -// } -// } -// } -// true -// } -// -// // Helper function to run the echo logic. -// fn run_echo(num: usize, c: &ChannelPair) { -// let mut remaining = num; -// -// while remaining > 0 { -// select! { -// recv(c.rx_chan()) -> msg => { -// match msg { -// Ok(msg) => { -// // Send the received message back to the Tx channel. -// c.tx_chan().send(msg).expect("Failed to send message"); -// remaining -= 1; -// -// if remaining <= 0 { -// println!("ECHO: done"); -// return; -// } -// }, -// Err(_) => { -// panic!("Cannot read from echo channel"); -// } -// } -// }, -// default(Duration::from_secs(5)) => { -// panic!("Timeout in run_echo"); -// } -// } -// } -// } -// -// // Helper function to run the write logic. -// fn run_write(num: usize, c: &ChannelPair) { -// let mut tx = 0; -// let mut rx = 0; -// let src_msg: ZmqByteStream = vec![b"Hello".to_vec(), b"World".to_vec()]; -// let mut txchan = Some(c.tx_chan()); -// -// while rx < num { -// select! { -// recv(c.rx_chan()) -> msg => { -// match msg { -// Ok(received_msg) => { -// rx += 1; -// if !msg_equal(&received_msg, &src_msg) { -// panic!("Messages do not match"); -// } -// if rx >= num { -// println!("MAIN: done"); -// return; -// } -// }, -// Err(_) => { -// panic!("Cannot read from main channel"); -// } -// } -// }, -// send(txchan.unwrap_or(&c.tx_chan()), src_msg.clone()) -> result => { -// if let Ok(_) = result { -// tx += 1; -// if tx >= num { -// txchan = None; // Disable the txchan (like setting to `nil` in Go) -// } -// } -// }, -// default(Duration::from_secs(1)) => { -// panic!("Timeout in runWrite"); -// } -// } -// } -// } - #[cfg(test)] mod tests { use super::*; + use std::sync::{Arc, Mutex}; + use std::time::Duration; #[test] - fn channel_pair() { - let (sb, ctx1, sc, ctx2) = new_pair_socket(); - - let num: usize = 10; - let mut handles = Vec::with_capacity(2); - - { - let echo = std::thread::spawn(move || { - let cp = ChannelPair::new(ctx1, sb).unwrap(); - run_echo(num, &cp) - }); - handles.push(echo); - } - - { - let write = std::thread::spawn(move || { - let cc = ChannelPair::new(ctx2, sc).unwrap(); - run_write(num, &cc); - }); - handles.push(write); - } - - for handle in handles { - handle.join().expect("Thread panicked!"); - } - } - - - #[test] - fn test_channel_pair_basic() { + fn channel_pair_test() { let ctx = Context::new(); - let server_socket = ctx.socket(zmq::PAIR).unwrap(); - let client_socket = ctx.socket(zmq::PAIR).unwrap(); - server_socket.bind("tcp://127.0.0.1:5555").unwrap(); - client_socket.connect("tcp://127.0.0.1:5555").unwrap(); - - let cp = ChannelPair::new(ctx, client_socket).unwrap(); - - let msg = vec![b"Hello".to_vec(), b"World".to_vec()]; - cp.tx_chan().send(msg.clone()).unwrap(); + let server_socket = ctx + .socket(zmq::ROUTER) + .expect("Failed to create ROUTER socket"); + server_socket + .bind("tcp://127.0.0.1:5555") + .expect("Failed to bind server socket"); + + let processed_messages = Arc::new(Mutex::new(20)); + let channel_pair = ChannelPair::new(server_socket).expect("Failed to create ChannelPair"); + + // Spawn the server thread + let server_processed_messages = Arc::clone(&processed_messages); + std::thread::spawn(move || { + while *server_processed_messages.lock().unwrap() > 0 { + match channel_pair.rx_chan().recv() { + Ok(message) => { + let message_clone = message.clone(); + + let cp_clone = Arc::clone(&channel_pair); + let server_processed_messages = Arc::clone(&server_processed_messages); + std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(2)); // Simulate processing time + + // Send the same message back as the response + cp_clone + .tx_chan() + .send(message_clone) + .expect("Failed to send response"); + + // Decrement the shared counter safely + let mut counter = server_processed_messages.lock().unwrap(); + *counter -= 1; + }); + } + Err(err) => { + panic!("Server: Timeout receive message: {:?}", err); + } + } + } + }); + + // Client logic + let client_handle = std::thread::spawn(move || { + let client_ctx = Context::new(); + let client_socket = client_ctx + .socket(zmq::DEALER) + .expect("Failed to create DEALER socket"); + client_socket + .connect("tcp://127.0.0.1:5555") + .expect("Failed to connect client socket"); + client_socket + .set_identity("client-1".as_bytes()) + .expect("Failed to set client identity"); + + let mut num_of_messages_client = 20; + while num_of_messages_client >= 0 { + // Send a message to the server + let msg = vec![b"Hello".to_vec()]; + client_socket + .send_multipart(&msg, 0) + .expect("Failed to send message from client"); + + // Wait for a response from the server + match client_socket.recv_multipart(0) { + Ok(_) => { + num_of_messages_client -= 1; + } + Err(e) => panic!("Client: Failed to receive response: {:?}", e), + } + } + }); - let received_msg = server_socket.recv_multipart(0).unwrap(); + client_handle.join().expect("Client thread failed"); - assert!(msg_equal(&msg, &received_msg), "Sent and received messages do not match"); + // Check that all messages have been processed + assert_eq!( + *processed_messages.lock().unwrap(), + 0, + "Not all messages were processed." + ); } }