diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 7d1e8fe6..fbe0d98e 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -24,12 +24,11 @@ use tokio::time::timeout; use crate::hrana; use crate::hrana::http::handle_pipeline; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; -use crate::linc::bus::{Bus, Dispatch}; +use crate::linc::bus::{Dispatch}; use crate::linc::proto::{ BuilderStep, Enveloppe, Frames, Message, ProxyResponse, StepError, Value, }; use crate::linc::{Inbound, NodeId, Outbound}; -use crate::manager::Manager; use crate::meta::DatabaseId; use self::config::{AllocConfig, DbConfig}; @@ -505,7 +504,7 @@ impl Database { next_req_id: 0, primary_id: *primary_id, database_id: DatabaseId::from_name(&alloc.db_name), - dispatcher: alloc.bus.clone(), + dispatcher: alloc.dispatcher.clone(), }), } } @@ -687,7 +686,7 @@ pub struct Allocation { pub hrana_server: Arc, /// handle to the message bus - pub bus: Arc>>, + pub dispatcher: Arc, pub db_name: String, } @@ -770,7 +769,7 @@ impl Allocation { next_frame_no, req_no, seq_no: 0, - dipatcher: self.bus.clone() as _, + dipatcher: self.dispatcher.clone() as _, notifier: frame_notifier.clone(), buffer: Vec::new(), }; @@ -818,7 +817,7 @@ impl Allocation { Message::ProxyResponse(ref r) => { if let Some(conn) = self .connections - .get(&self.bus.node_id()) + .get(&self.dispatcher.node_id()) .and_then(|m| m.get(&r.connection_id).cloned()) { conn.inbound.send(msg).await.unwrap(); @@ -837,7 +836,7 @@ impl Allocation { req_id: u32, program: Program, ) { - let dispatcher = self.bus.clone(); + let dispatcher = self.dispatcher.clone(); let database_id = DatabaseId::from_name(&self.db_name); let exec = |conn: ConnectionHandle| async move { let _ = conn @@ -878,7 +877,7 @@ impl Allocation { let conn = block_in_place(|| self.database.connect(conn_id, self)); let (exec_sender, exec_receiver) = mpsc::channel(1); let (inbound_sender, inbound_receiver) = mpsc::channel(1); - let id = remote.unwrap_or((self.bus.node_id(), conn_id)); + let id = remote.unwrap_or((self.dispatcher.node_id(), conn_id)); let conn = Connection { id, conn, @@ -903,7 +902,7 @@ impl Allocation { self.next_conn_id = self.next_conn_id.wrapping_add(1); if self .connections - .get(&self.bus.node_id()) + .get(&self.dispatcher.node_id()) .and_then(|m| m.get(&self.next_conn_id)) .is_none() { diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index 4707c989..a31c3368 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -50,6 +50,7 @@ impl Bus { #[async_trait::async_trait] pub trait Dispatch: Send + Sync + 'static { async fn dispatch(&self, msg: Outbound); + fn node_id(&self) -> NodeId; } #[async_trait::async_trait] @@ -62,4 +63,8 @@ impl Dispatch for Bus { // This message is outbound. self.send_queue.enqueue(msg).await; } + + fn node_id(&self) -> NodeId { + self.node_id + } } diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index 09e2ec44..5f5d9f24 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -273,9 +273,9 @@ where mod test { use std::sync::Arc; + use futures::{future, pin_mut}; use tokio::sync::Notify; use turmoil::net::{TcpListener, TcpStream}; - use uuid::Uuid; use super::*; @@ -283,151 +283,50 @@ mod test { fn invalid_handshake() { let mut sim = turmoil::Builder::new().build(); - let host_node_id = NodeId::new_v4(); - sim.host("host", move || async move { - let bus = Bus::new(host_node_id); - let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - let (s, _) = listener.accept().await.unwrap(); - let mut connection = Connection::new_acceptor(s, bus); - connection.tick().await; - - Ok(()) + let host_node_id = 0; + let done = Arc::new(Notify::new()); + let done_clone = done.clone(); + sim.host("host", move || { + let done_clone = done_clone.clone(); + async move { + let bus = Arc::new(Bus::new(host_node_id, |_, _| async {})); + let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") + .await + .unwrap(); + let (s, _) = listener.accept().await.unwrap(); + let connection = Connection::new_acceptor(s, bus); + let done = done_clone.notified(); + let run = connection.run(); + pin_mut!(done); + pin_mut!(run); + future::select(run, done).await; + + Ok(()) + } }); sim.client("client", async move { let s = TcpStream::connect("host:1234").await.unwrap(); - let mut s = AsyncBincodeStream::<_, Message, Message, _>::from(s).for_async(); - - s.send(Message::Node(NodeMessage::Handshake { - protocol_version: 1234, - node_id: Uuid::new_v4(), - })) - .await - .unwrap(); + let mut s = AsyncBincodeStream::<_, Enveloppe, Enveloppe, _>::from(s).for_async(); + + let msg = Enveloppe { + database_id: None, + message: Message::Handshake { + protocol_version: 1234, + node_id: 1, + }, + }; + s.send(msg).await.unwrap(); let m = s.next().await.unwrap().unwrap(); assert!(matches!( - m, - Message::Node(NodeMessage::Error( - NodeError::HandshakeVersionMismatch { .. } - )) + m.message, + Message::Error( + ProtoError::HandshakeVersionMismatch { .. } + ) )); - Ok(()) - }); - - sim.run().unwrap(); - } - - #[test] - fn stream_closed() { - let mut sim = turmoil::Builder::new().build(); - - let database_id = DatabaseId::new_v4(); - let host_node_id = NodeId::new_v4(); - let notify = Arc::new(Notify::new()); - sim.host("host", { - let notify = notify.clone(); - move || { - let notify = notify.clone(); - async move { - let bus = Bus::new(host_node_id); - let mut sub = bus.subscribe(database_id).unwrap(); - let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - let (s, _) = listener.accept().await.unwrap(); - let connection = Connection::new_acceptor(s, bus); - tokio::task::spawn_local(connection.run()); - let mut streams = Vec::new(); - loop { - tokio::select! { - Some(mut stream) = sub.next() => { - let m = stream.next().await.unwrap(); - stream.send(m).await.unwrap(); - streams.push(stream); - } - _ = notify.notified() => { - break; - } - } - } - - Ok(()) - } - } - }); - - sim.client("client", async move { - let stream_id = StreamId::new(1); - let node_id = NodeId::new_v4(); - let s = TcpStream::connect("host:1234").await.unwrap(); - let mut s = AsyncBincodeStream::<_, Message, Message, _>::from(s).for_async(); - - s.send(Message::Node(NodeMessage::Handshake { - protocol_version: CURRENT_PROTO_VERSION, - node_id, - })) - .await - .unwrap(); - let m = s.next().await.unwrap().unwrap(); - assert!(matches!(m, Message::Node(NodeMessage::Handshake { .. }))); - - // send message to unexisting stream: - s.send(Message::Stream { - stream_id, - payload: StreamMessage::Dummy, - }) - .await - .unwrap(); - let m = s.next().await.unwrap().unwrap(); - assert_eq!( - m, - Message::Node(NodeMessage::Error(NodeError::UnknownStream(stream_id))) - ); - - // open stream then send message - s.send(Message::Node(NodeMessage::OpenStream { - stream_id, - database_id, - })) - .await - .unwrap(); - s.send(Message::Stream { - stream_id, - payload: StreamMessage::Dummy, - }) - .await - .unwrap(); - let m = s.next().await.unwrap().unwrap(); - assert_eq!( - m, - Message::Stream { - stream_id, - payload: StreamMessage::Dummy - } - ); - - s.send(Message::Node(NodeMessage::CloseStream { - stream_id: StreamId::new(1), - })) - .await - .unwrap(); - s.send(Message::Stream { - stream_id, - payload: StreamMessage::Dummy, - }) - .await - .unwrap(); - let m = s.next().await.unwrap().unwrap(); - assert_eq!( - m, - Message::Node(NodeMessage::Error(NodeError::UnknownStream(stream_id))) - ); - - notify.notify_waiters(); + done.notify_waiters(); Ok(()) }); @@ -459,7 +358,7 @@ mod test { sim.client("client", async move { let stream = TcpStream::connect("host:1234").await.unwrap(); - let bus = Bus::new(NodeId::new_v4()); + let bus = Arc::new(Bus::new(1, |_, _| async {})); let mut conn = Connection::new_acceptor(stream, bus); notify.notify_waiters(); @@ -473,57 +372,4 @@ mod test { sim.run().unwrap(); } - - #[test] - fn zero_stream_id() { - let mut sim = turmoil::Builder::new().build(); - - let notify = Arc::new(Notify::new()); - sim.host("host", { - let notify = notify.clone(); - move || { - let notify = notify.clone(); - async move { - let listener = TcpListener::bind("0.0.0.0:1234").await.unwrap(); - let (stream, _) = listener.accept().await.unwrap(); - let (connection_messages_sender, connection_messages) = mpsc::channel(1); - let conn = Connection { - peer: Some(NodeId::new_v4()), - state: ConnectionState::Connected, - conn: AsyncBincodeStream::from(stream).for_async(), - streams: HashMap::new(), - connection_messages, - connection_messages_sender, - is_initiator: false, - bus: Bus::new(NodeId::new_v4()), - stream_id_allocator: StreamIdAllocator::new(false), - registration: None, - }; - - conn.run().await; - - Ok(()) - } - } - }); - - sim.client("client", async move { - let stream = TcpStream::connect("host:1234").await.unwrap(); - let mut stream = AsyncBincodeStream::<_, Message, Message, _>::from(stream).for_async(); - - stream - .send(Message::Stream { - stream_id: StreamId::new_unchecked(0), - payload: StreamMessage::Dummy, - }) - .await - .unwrap(); - - assert!(stream.next().await.is_none()); - - Ok(()) - }); - - sim.run().unwrap(); - } } diff --git a/libsqlx-server/src/linc/connection_pool.rs b/libsqlx-server/src/linc/connection_pool.rs index b6113a80..3415dee4 100644 --- a/libsqlx-server/src/linc/connection_pool.rs +++ b/libsqlx-server/src/linc/connection_pool.rs @@ -81,136 +81,3 @@ impl ConnectionPool { self.connections.spawn(fut); } } - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use futures::SinkExt; - use tokio::sync::Notify; - use tokio_stream::StreamExt; - - use crate::linc::{server::Server, AllocId}; - - use super::*; - - #[test] - fn manage_connections() { - let mut sim = turmoil::Builder::new().build(); - let database_id = AllocId::new_v4(); - let notify = Arc::new(Notify::new()); - - let expected_msg = crate::linc::proto::StreamMessage::Proxy( - crate::linc::proto::ProxyMessage::ProxyRequest { - connection_id: 42, - req_id: 42, - program: "foobar".into(), - }, - ); - - let spawn_host = |node_id| { - let notify = notify.clone(); - let expected_msg = expected_msg.clone(); - move || { - let notify = notify.clone(); - let expected_msg = expected_msg.clone(); - async move { - let bus = Bus::new(node_id); - let mut sub = bus.subscribe(database_id).unwrap(); - let mut server = Server::new(bus.clone()); - let mut listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - - let mut has_closed = false; - let mut streams = Vec::new(); - loop { - tokio::select! { - _ = notify.notified() => { - if !has_closed { - streams.clear(); - server.close_connections().await; - has_closed = true; - } else { - break; - } - }, - _ = server.tick(&mut listener) => (), - Some(mut stream) = sub.next() => { - stream - .send(expected_msg.clone()) - .await - .unwrap(); - streams.push(stream); - } - } - } - - Ok(()) - } - } - }; - - let host1_id = NodeId::new_v4(); - sim.host("host1", spawn_host(host1_id)); - - let host2_id = NodeId::new_v4(); - sim.host("host2", spawn_host(host2_id)); - - let host3_id = NodeId::new_v4(); - sim.host("host3", spawn_host(host3_id)); - - sim.client("client", async move { - let bus = Bus::new(NodeId::new_v4()); - let pool = ConnectionPool::new( - bus.clone(), - vec![ - (host1_id, "host1:1234".into()), - (host2_id, "host2:1234".into()), - (host3_id, "host3:1234".into()), - ], - ); - - tokio::task::spawn_local(pool.run::()); - - // all three hosts are reachable: - let mut stream1 = bus.new_stream(host1_id, database_id).await.unwrap(); - let m = stream1.next().await.unwrap(); - assert_eq!(m, expected_msg); - - let mut stream2 = bus.new_stream(host2_id, database_id).await.unwrap(); - let m = stream2.next().await.unwrap(); - assert_eq!(m, expected_msg); - - let mut stream3 = bus.new_stream(host3_id, database_id).await.unwrap(); - let m = stream3.next().await.unwrap(); - assert_eq!(m, expected_msg); - - // sever connections - notify.notify_waiters(); - - assert!(stream1.next().await.is_none()); - assert!(stream2.next().await.is_none()); - assert!(stream3.next().await.is_none()); - - let mut stream1 = bus.new_stream(host1_id, database_id).await.unwrap(); - let m = stream1.next().await.unwrap(); - assert_eq!(m, expected_msg); - - let mut stream2 = bus.new_stream(host2_id, database_id).await.unwrap(); - let m = stream2.next().await.unwrap(); - assert_eq!(m, expected_msg); - - let mut stream3 = bus.new_stream(host3_id, database_id).await.unwrap(); - let m = stream3.next().await.unwrap(); - assert_eq!(m, expected_msg); - - // terminate test - notify.notify_waiters(); - - Ok(()) - }); - - sim.run().unwrap(); - } -} diff --git a/libsqlx-server/src/linc/handler.rs b/libsqlx-server/src/linc/handler.rs index 6403906e..828c8bb6 100644 --- a/libsqlx-server/src/linc/handler.rs +++ b/libsqlx-server/src/linc/handler.rs @@ -1,10 +1,21 @@ use std::sync::Arc; -use super::bus::Bus; +use super::bus::{Dispatch}; use super::Inbound; #[async_trait::async_trait] pub trait Handler: Sized + Send + Sync + 'static { /// Handle inbound message - async fn handle(&self, bus: Arc>, msg: Inbound); + async fn handle(&self, bus: Arc, msg: Inbound); +} + +#[cfg(test)] +#[async_trait::async_trait] +impl Handler for F +where F: Fn(Arc, Inbound) -> Fut + Send + Sync + 'static, + Fut: std::future::Future + Send, +{ + async fn handle(&self, bus: Arc, msg: Inbound) { + (self)(bus, msg).await + } } diff --git a/libsqlx-server/src/linc/net.rs b/libsqlx-server/src/linc/net.rs index 2123c041..a7fa87af 100644 --- a/libsqlx-server/src/linc/net.rs +++ b/libsqlx-server/src/linc/net.rs @@ -74,6 +74,10 @@ mod test { fn accept(&self) -> Self::Future<'_> { Box::pin(self.accept()) } + + fn local_addr(&self) -> color_eyre::Result { + Ok(self.local_addr()?) + } } impl Connector for TcpStream { diff --git a/libsqlx-server/src/linc/server.rs b/libsqlx-server/src/linc/server.rs index f3eacec2..6371a059 100644 --- a/libsqlx-server/src/linc/server.rs +++ b/libsqlx-server/src/linc/server.rs @@ -75,24 +75,20 @@ impl Server { mod test { use std::sync::Arc; - use crate::linc::{proto::ProxyMessage, AllocId, NodeId}; - use super::*; - use futures::{SinkExt, StreamExt}; - use tokio::sync::Notify; use turmoil::net::TcpStream; #[test] fn server_respond_to_handshake() { let mut sim = turmoil::Builder::new().build(); - let host_node_id = NodeId::new_v4(); + let host_node_id = 0; let notify = Arc::new(tokio::sync::Notify::new()); sim.host("host", move || { let notify = notify.clone(); async move { - let bus = Bus::new(host_node_id); + let bus = Arc::new(Bus::new(host_node_id, |_, _| async {})); let mut server = Server::new(bus); let mut listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") .await @@ -105,10 +101,10 @@ mod test { }); sim.client("client", async move { - let node_id = NodeId::new_v4(); + let node_id = 1; let mut c = Connection::new_initiator( TcpStream::connect("host:1234").await.unwrap(), - Bus::new(node_id), + Arc::new(Bus::new(node_id, |_, _| async {})), ); c.tick().await; @@ -121,229 +117,4 @@ mod test { sim.run().unwrap(); } - - #[test] - fn client_create_stream_client_close() { - let mut sim = turmoil::Builder::new().build(); - - let host_node_id = NodeId::new_v4(); - let stream_db_id = AllocId::new_v4(); - let notify = Arc::new(Notify::new()); - let expected_msg = StreamMessage::Proxy(ProxyMessage::ProxyRequest { - connection_id: 12, - req_id: 1, - program: "hello".to_string(), - }); - - sim.host("host", { - let notify = notify.clone(); - let expected_msg = expected_msg.clone(); - move || { - let notify = notify.clone(); - let expected_msg = expected_msg.clone(); - async move { - let bus = Bus::new(host_node_id); - let server = Server::new(bus.clone()); - let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - let mut subs = bus.subscribe(stream_db_id).unwrap(); - tokio::task::spawn_local(server.run(listener)); - - let mut stream = subs.next().await.unwrap(); - - let msg = stream.next().await.unwrap(); - - assert_eq!(msg, expected_msg); - - notify.notify_waiters(); - - assert!(stream.next().await.is_none()); - - notify.notify_waiters(); - - Ok(()) - } - } - }); - - sim.client("client", async move { - let node_id = NodeId::new_v4(); - let bus = Bus::new(node_id); - let mut c = Connection::new_initiator( - TcpStream::connect("host:1234").await.unwrap(), - bus.clone(), - ); - c.tick().await; - c.tick().await; - let _h = tokio::spawn(c.run()); - let mut stream = bus.new_stream(host_node_id, stream_db_id).await.unwrap(); - stream.send(expected_msg).await.unwrap(); - - notify.notified().await; - - drop(stream); - - notify.notified().await; - - Ok(()) - }); - - sim.run().unwrap(); - } - - #[test] - fn client_create_stream_server_close() { - let mut sim = turmoil::Builder::new().build(); - - let host_node_id = NodeId::new_v4(); - let database_id = AllocId::new_v4(); - let notify = Arc::new(Notify::new()); - - sim.host("host", { - let notify = notify.clone(); - move || { - let notify = notify.clone(); - async move { - let bus = Bus::new(host_node_id); - let server = Server::new(bus.clone()); - let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - let mut subs = bus.subscribe(database_id).unwrap(); - tokio::task::spawn_local(server.run(listener)); - - let stream = subs.next().await.unwrap(); - drop(stream); - - notify.notify_waiters(); - notify.notified().await; - - Ok(()) - } - } - }); - - sim.client("client", async move { - let node_id = NodeId::new_v4(); - let bus = Bus::new(node_id); - let mut c = Connection::new_initiator( - TcpStream::connect("host:1234").await.unwrap(), - bus.clone(), - ); - c.tick().await; - c.tick().await; - let _h = tokio::spawn(c.run()); - let mut stream = bus.new_stream(host_node_id, database_id).await.unwrap(); - - notify.notified().await; - assert!(stream.next().await.is_none()); - notify.notify_waiters(); - - Ok(()) - }); - - sim.run().unwrap(); - } - - #[test] - fn server_create_stream_server_close() { - let mut sim = turmoil::Builder::new().build(); - - let host_node_id = NodeId::new_v4(); - let notify = Arc::new(Notify::new()); - let client_id = NodeId::new_v4(); - let database_id = AllocId::new_v4(); - let expected_msg = StreamMessage::Proxy(ProxyMessage::ProxyRequest { - connection_id: 12, - req_id: 1, - program: "hello".to_string(), - }); - - sim.host("host", { - let notify = notify.clone(); - let expected_msg = expected_msg.clone(); - move || { - let notify = notify.clone(); - let expected_msg = expected_msg.clone(); - async move { - let bus = Bus::new(host_node_id); - let server = Server::new(bus.clone()); - let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - tokio::task::spawn_local(server.run(listener)); - - let mut stream = bus.new_stream(client_id, database_id).await.unwrap(); - stream.send(expected_msg).await.unwrap(); - notify.notified().await; - drop(stream); - - Ok(()) - } - } - }); - - sim.client("client", async move { - let bus = Bus::new(client_id); - let mut subs = bus.subscribe(database_id).unwrap(); - let c = Connection::new_initiator( - TcpStream::connect("host:1234").await.unwrap(), - bus.clone(), - ); - let _h = tokio::spawn(c.run()); - - let mut stream = subs.next().await.unwrap(); - let msg = stream.next().await.unwrap(); - assert_eq!(msg, expected_msg); - notify.notify_waiters(); - assert!(stream.next().await.is_none()); - - Ok(()) - }); - - sim.run().unwrap(); - } - - #[test] - fn server_create_stream_client_close() { - let mut sim = turmoil::Builder::new().build(); - - let host_node_id = NodeId::new_v4(); - let client_id = NodeId::new_v4(); - let database_id = AllocId::new_v4(); - - sim.host("host", { - move || async move { - let bus = Bus::new(host_node_id); - let server = Server::new(bus.clone()); - let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - tokio::task::spawn_local(server.run(listener)); - - let mut stream = bus.new_stream(client_id, database_id).await.unwrap(); - assert!(stream.next().await.is_none()); - - Ok(()) - } - }); - - sim.client("client", async move { - let bus = Bus::new(client_id); - let mut subs = bus.subscribe(database_id).unwrap(); - let c = Connection::new_initiator( - TcpStream::connect("host:1234").await.unwrap(), - bus.clone(), - ); - let _h = tokio::spawn(c.run()); - - let stream = subs.next().await.unwrap(); - drop(stream); - - Ok(()) - }); - - sim.run().unwrap(); - } } diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 01870144..414e17bf 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -8,7 +8,7 @@ use tokio::task::JoinSet; use crate::allocation::{Allocation, AllocationMessage, Database}; use crate::hrana; -use crate::linc::bus::Bus; +use crate::linc::bus::Dispatch; use crate::linc::handler::Handler; use crate::linc::Inbound; use crate::meta::{DatabaseId, Store}; @@ -34,7 +34,7 @@ impl Manager { pub async fn alloc( self: &Arc, database_id: DatabaseId, - bus: Arc>>, + dispatcher: Arc, ) -> Option> { if let Some(sender) = self.cache.get(&database_id) { return Some(sender.clone()); @@ -46,12 +46,12 @@ impl Manager { let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN); let alloc = Allocation { inbox, - database: Database::from_config(&config, path, bus.clone()), + database: Database::from_config(&config, path, dispatcher.clone()), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, hrana_server: Arc::new(hrana::http::Server::new(None)), - bus, // TODO: handle self URL? + dispatcher, // TODO: handle self URL? db_name: config.db_name, connections: HashMap::new(), }; @@ -69,7 +69,7 @@ impl Manager { #[async_trait::async_trait] impl Handler for Arc { - async fn handle(&self, bus: Arc>, msg: Inbound) { + async fn handle(&self, bus: Arc, msg: Inbound) { if let Some(sender) = self .clone() .alloc(msg.enveloppe.database_id.unwrap(), bus.clone()) diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index b71b33eb..0167497b 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -27,6 +27,11 @@ impl DatabaseId { reader.read(&mut out); Self(out) } + + #[cfg(test)] + pub fn random() -> Self { + Self(uuid::Uuid::new_v4().into_bytes()) + } } impl fmt::Display for DatabaseId { diff --git a/libsqlx/Cargo.toml b/libsqlx/Cargo.toml index 85fd7a9d..519339ec 100644 --- a/libsqlx/Cargo.toml +++ b/libsqlx/Cargo.toml @@ -8,7 +8,7 @@ edition = "2021" [dependencies] async-trait = "0.1.68" bytesize = "1.2.0" -serde = "1.0.164" +serde = { version = "1", features = ["rc"] } serde_json = "1.0.99" rusqlite = { workspace = true } anyhow = "1.0.71" diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index c0aaed79..5c060e4d 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -196,6 +196,7 @@ mod test { use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering::Relaxed; + use parking_lot::Mutex; use rusqlite::types::Value; use crate::connection::Connection; @@ -205,14 +206,14 @@ mod test { use super::*; - struct ReadRowBuilder(Vec); + struct ReadRowBuilder(Arc>>); impl ResultBuilder for ReadRowBuilder { fn add_row_value( &mut self, v: rusqlite::types::ValueRef, ) -> Result<(), QueryResultBuilderError> { - self.0.push(v.into()); + self.0.lock().push(v.into()); Ok(()) } } @@ -227,22 +228,26 @@ mod test { let mut db = LibsqlDatabase::new(temp.path().to_path_buf(), replica); let mut conn = db.connect().unwrap(); - let mut builder = ReadRowBuilder(Vec::new()); - conn.execute_program(Program::seq(&["select count(*) from test"]), &mut builder) + let row: Arc>> = Default::default(); + let builder = Box::new(ReadRowBuilder(row.clone())); + conn.execute_program(&Program::seq(&["select count(*) from test"]), builder) .unwrap(); - assert!(builder.0.is_empty()); + assert!(row.lock().is_empty()); let file = File::open("assets/test/simple_wallog").unwrap(); let log = LogFile::new(file).unwrap(); let mut injector = db.injector().unwrap(); log.frames_iter() .unwrap() - .for_each(|f| injector.inject(f.unwrap()).unwrap()); + .for_each(|f| { + injector.inject(f.unwrap()).unwrap(); + }); - let mut builder = ReadRowBuilder(Vec::new()); - conn.execute_program(Program::seq(&["select count(*) from test"]), &mut builder) + let row: Arc>> = Default::default(); + let builder = Box::new(ReadRowBuilder(row.clone())); + conn.execute_program(&Program::seq(&["select count(*) from test"]), builder) .unwrap(); - assert_eq!(builder.0[0], Value::Integer(5)); + assert_eq!(row.lock()[0], Value::Integer(5)); } #[test] @@ -253,7 +258,7 @@ mod test { let primary = LibsqlDatabase::new( temp_primary.path().to_path_buf(), PrimaryType { - logger: Arc::new(ReplicationLogger::open(temp_primary.path(), false, ()).unwrap()), + logger: Arc::new(ReplicationLogger::open(temp_primary.path(), false, (), Box::new(|_| ())).unwrap()), }, ); @@ -268,8 +273,8 @@ mod test { let mut primary_conn = primary.connect().unwrap(); primary_conn .execute_program( - Program::seq(&["create table test (x)", "insert into test values (42)"]), - &mut (), + &Program::seq(&["create table test (x)", "insert into test values (42)"]), + Box::new(()), ) .unwrap(); @@ -282,13 +287,14 @@ mod test { } let mut replica_conn = replica.connect().unwrap(); - let mut builder = ReadRowBuilder(Vec::new()); + let row: Arc>> = Default::default(); + let builder = Box::new(ReadRowBuilder(row.clone())); replica_conn - .execute_program(Program::seq(&["select * from test limit 1"]), &mut builder) + .execute_program(&Program::seq(&["select * from test limit 1"]), builder) .unwrap(); - assert_eq!(builder.0.len(), 1); - assert_eq!(builder.0[0], Value::Integer(42)); + assert_eq!(row.lock().len(), 1); + assert_eq!(row.lock()[0], Value::Integer(42)); } #[test] @@ -317,13 +323,14 @@ mod test { temp.path().to_path_buf(), Compactor(compactor_called.clone()), false, + Box::new(|_| ()), ) .unwrap(); let mut conn = db.connect().unwrap(); conn.execute_program( - Program::seq(&["create table test (x)", "insert into test values (12)"]), - &mut (), + &Program::seq(&["create table test (x)", "insert into test values (12)"]), + Box::new(()), ) .unwrap(); assert!(compactor_called.load(Relaxed)); @@ -356,22 +363,23 @@ mod test { temp.path().to_path_buf(), Compactor(compactor_called.clone()), false, + Box::new(|_| ()) ) .unwrap(); let mut conn = db.connect().unwrap(); conn.execute_program( - Program::seq(&[ + &Program::seq(&[ "begin", "create table test (x)", "insert into test values (12)", ]), - &mut (), + Box::new(()) ) .unwrap(); conn.inner_connection().cache_flush().unwrap(); assert!(!compactor_called.load(Relaxed)); - conn.execute_program(Program::seq(&["commit"]), &mut ()) + conn.execute_program(&Program::seq(&["commit"]), Box::new(())) .unwrap(); assert!(compactor_called.load(Relaxed)); } diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index 187f3b25..7bcfb0bf 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -958,7 +958,7 @@ mod test { #[test] fn write_and_read_from_frame_log() { let dir = tempfile::tempdir().unwrap(); - let logger = ReplicationLogger::open(dir.path(), false, ()).unwrap(); + let logger = ReplicationLogger::open(dir.path(), false, (), Box::new(|_| ())).unwrap(); let frames = (0..10) .map(|i| WalPage { @@ -986,7 +986,7 @@ mod test { #[test] fn index_out_of_bounds() { let dir = tempfile::tempdir().unwrap(); - let logger = ReplicationLogger::open(dir.path(), false, ()).unwrap(); + let logger = ReplicationLogger::open(dir.path(), false, (), Box::new(|_| ())).unwrap(); let log_file = logger.log_file.write(); assert!(matches!(log_file.frame(1), Err(LogReadError::Ahead))); } @@ -995,7 +995,7 @@ mod test { #[should_panic] fn incorrect_frame_size() { let dir = tempfile::tempdir().unwrap(); - let logger = ReplicationLogger::open(dir.path(), false, ()).unwrap(); + let logger = ReplicationLogger::open(dir.path(), false, (), Box::new(|_| ())).unwrap(); let entry = WalPage { page_no: 0, size_after: 0, diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs index a06b6620..2d576387 100644 --- a/libsqlx/src/database/proxy/connection.rs +++ b/libsqlx/src/database/proxy/connection.rs @@ -250,36 +250,37 @@ impl ResultBuilder for ExtractFrameNoBuilder { #[cfg(test)] mod test { - use std::cell::Cell; - use std::rc::Rc; use std::sync::Arc; + use parking_lot::Mutex; + + use crate::Connection; use crate::database::test_utils::MockDatabase; use crate::database::{proxy::database::WriteProxyDatabase, Database}; use crate::program::Program; #[test] fn simple_write_proxied() { - let write_called = Rc::new(Cell::new(false)); + let write_called = Arc::new(Mutex::new(false)); let write_db = MockDatabase::new().with_execute({ let write_called = write_called.clone(); - move |_, b| { + move |_, mut b| { b.finnalize(false, Some(42)).unwrap(); - write_called.set(true); + *write_called.lock() =true; Ok(()) } }); - let read_called = Rc::new(Cell::new(false)); + let read_called = Arc::new(Mutex::new(false)); let read_db = MockDatabase::new().with_execute({ let read_called = read_called.clone(); move |_, _| { - read_called.set(true); + *read_called.lock() = true; Ok(()) } }); - let wait_called = Rc::new(Cell::new(false)); + let wait_called = Arc::new(Mutex::new(false)); let db = WriteProxyDatabase::new( read_db, write_db, @@ -287,23 +288,23 @@ mod test { let wait_called = wait_called.clone(); move |fno| { assert_eq!(fno, 42); - wait_called.set(true); + *wait_called.lock() = true; } }), ); let mut conn = db.connect().unwrap(); - conn.execute_program(Program::seq(&["insert into test values (12)"]), &mut ()) + conn.execute_program(&Program::seq(&["insert into test values (12)"]), Box::new(())) .unwrap(); - assert!(!wait_called.get()); - assert!(!read_called.get()); - assert!(write_called.get()); + assert!(!*wait_called.lock()); + assert!(!*read_called.lock()); + assert!(*write_called.lock()); - conn.execute_program(Program::seq(&["select * from test"]), &mut ()) + conn.execute_program(&Program::seq(&["select * from test"]), Box::new(())) .unwrap(); - assert!(read_called.get()); - assert!(wait_called.get()); + assert!(*read_called.lock()); + assert!(*wait_called.lock()); } } diff --git a/libsqlx/src/database/test_utils.rs b/libsqlx/src/database/test_utils.rs index a46aa2ac..93bf3b1d 100644 --- a/libsqlx/src/database/test_utils.rs +++ b/libsqlx/src/database/test_utils.rs @@ -10,16 +10,17 @@ use super::Database; pub struct MockDatabase { #[allow(clippy::type_complexity)] - describe_fn: Arc crate::Result>, + describe_fn: Arc crate::Result +Send +Sync>, #[allow(clippy::type_complexity)] - execute_fn: Arc crate::Result<()>>, + execute_fn: Arc) -> crate::Result<()> +Send +Sync>, } +#[derive(Clone)] pub struct MockConnection { #[allow(clippy::type_complexity)] - describe_fn: Arc crate::Result>, + describe_fn: Arc crate::Result + Send +Sync>, #[allow(clippy::type_complexity)] - execute_fn: Arc crate::Result<()>>, + execute_fn: Arc) -> crate::Result<()> + Send +Sync>, } impl MockDatabase { @@ -32,7 +33,7 @@ impl MockDatabase { pub fn with_execute( mut self, - f: impl Fn(Program, &mut dyn ResultBuilder) -> crate::Result<()> + 'static, + f: impl Fn(&Program, Box) -> crate::Result<()> + Send + Sync +'static, ) -> Self { self.execute_fn = Arc::new(f); self @@ -53,8 +54,8 @@ impl Database for MockDatabase { impl Connection for MockConnection { fn execute_program( &mut self, - pgm: crate::program::Program, - reponse_builder: &mut dyn ResultBuilder, + pgm: &crate::program::Program, + reponse_builder: Box, ) -> crate::Result<()> { (self.execute_fn)(pgm, reponse_builder)?; Ok(()) diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs index fed13fd3..458b50cc 100644 --- a/libsqlx/src/result_builder.rs +++ b/libsqlx/src/result_builder.rs @@ -650,9 +650,10 @@ pub mod test { &mut self, _is_txn: bool, _frame_no: Option, - ) -> Result<(), QueryResultBuilderError> { + ) -> Result { self.maybe_inject_error()?; - self.transition(Finish) + self.transition(Finish)?; + Ok(true) } }