Skip to content
This repository was archived by the owner on Oct 18, 2023. It is now read-only.

fix mt test #536

Merged
merged 1 commit into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions libsqlx-server/src/allocation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ use tokio::time::timeout;
use crate::hrana;
use crate::hrana::http::handle_pipeline;
use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody};
use crate::linc::bus::{Bus, Dispatch};
use crate::linc::bus::{Dispatch};
use crate::linc::proto::{
BuilderStep, Enveloppe, Frames, Message, ProxyResponse, StepError, Value,
};
use crate::linc::{Inbound, NodeId, Outbound};
use crate::manager::Manager;
use crate::meta::DatabaseId;

use self::config::{AllocConfig, DbConfig};
Expand Down Expand Up @@ -505,7 +504,7 @@ impl Database {
next_req_id: 0,
primary_id: *primary_id,
database_id: DatabaseId::from_name(&alloc.db_name),
dispatcher: alloc.bus.clone(),
dispatcher: alloc.dispatcher.clone(),
}),
}
}
Expand Down Expand Up @@ -687,7 +686,7 @@ pub struct Allocation {

pub hrana_server: Arc<hrana::http::Server>,
/// handle to the message bus
pub bus: Arc<Bus<Arc<Manager>>>,
pub dispatcher: Arc<dyn Dispatch>,
pub db_name: String,
}

Expand Down Expand Up @@ -770,7 +769,7 @@ impl Allocation {
next_frame_no,
req_no,
seq_no: 0,
dipatcher: self.bus.clone() as _,
dipatcher: self.dispatcher.clone() as _,
notifier: frame_notifier.clone(),
buffer: Vec::new(),
};
Expand Down Expand Up @@ -818,7 +817,7 @@ impl Allocation {
Message::ProxyResponse(ref r) => {
if let Some(conn) = self
.connections
.get(&self.bus.node_id())
.get(&self.dispatcher.node_id())
.and_then(|m| m.get(&r.connection_id).cloned())
{
conn.inbound.send(msg).await.unwrap();
Expand All @@ -837,7 +836,7 @@ impl Allocation {
req_id: u32,
program: Program,
) {
let dispatcher = self.bus.clone();
let dispatcher = self.dispatcher.clone();
let database_id = DatabaseId::from_name(&self.db_name);
let exec = |conn: ConnectionHandle| async move {
let _ = conn
Expand Down Expand Up @@ -878,7 +877,7 @@ impl Allocation {
let conn = block_in_place(|| self.database.connect(conn_id, self));
let (exec_sender, exec_receiver) = mpsc::channel(1);
let (inbound_sender, inbound_receiver) = mpsc::channel(1);
let id = remote.unwrap_or((self.bus.node_id(), conn_id));
let id = remote.unwrap_or((self.dispatcher.node_id(), conn_id));
let conn = Connection {
id,
conn,
Expand All @@ -903,7 +902,7 @@ impl Allocation {
self.next_conn_id = self.next_conn_id.wrapping_add(1);
if self
.connections
.get(&self.bus.node_id())
.get(&self.dispatcher.node_id())
.and_then(|m| m.get(&self.next_conn_id))
.is_none()
{
Expand Down
5 changes: 5 additions & 0 deletions libsqlx-server/src/linc/bus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ impl<H: Handler> Bus<H> {
#[async_trait::async_trait]
pub trait Dispatch: Send + Sync + 'static {
async fn dispatch(&self, msg: Outbound);
fn node_id(&self) -> NodeId;
}

#[async_trait::async_trait]
Expand All @@ -62,4 +63,8 @@ impl<H: Handler> Dispatch for Bus<H> {
// This message is outbound.
self.send_queue.enqueue(msg).await;
}

fn node_id(&self) -> NodeId {
self.node_id
}
}
228 changes: 37 additions & 191 deletions libsqlx-server/src/linc/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,161 +273,60 @@ where
mod test {
use std::sync::Arc;

use futures::{future, pin_mut};
use tokio::sync::Notify;
use turmoil::net::{TcpListener, TcpStream};
use uuid::Uuid;

use super::*;

#[test]
fn invalid_handshake() {
let mut sim = turmoil::Builder::new().build();

let host_node_id = NodeId::new_v4();
sim.host("host", move || async move {
let bus = Bus::new(host_node_id);
let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234")
.await
.unwrap();
let (s, _) = listener.accept().await.unwrap();
let mut connection = Connection::new_acceptor(s, bus);
connection.tick().await;

Ok(())
let host_node_id = 0;
let done = Arc::new(Notify::new());
let done_clone = done.clone();
sim.host("host", move || {
let done_clone = done_clone.clone();
async move {
let bus = Arc::new(Bus::new(host_node_id, |_, _| async {}));
let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234")
.await
.unwrap();
let (s, _) = listener.accept().await.unwrap();
let connection = Connection::new_acceptor(s, bus);
let done = done_clone.notified();
let run = connection.run();
pin_mut!(done);
pin_mut!(run);
future::select(run, done).await;

Ok(())
}
});

sim.client("client", async move {
let s = TcpStream::connect("host:1234").await.unwrap();
let mut s = AsyncBincodeStream::<_, Message, Message, _>::from(s).for_async();

s.send(Message::Node(NodeMessage::Handshake {
protocol_version: 1234,
node_id: Uuid::new_v4(),
}))
.await
.unwrap();
let mut s = AsyncBincodeStream::<_, Enveloppe, Enveloppe, _>::from(s).for_async();

let msg = Enveloppe {
database_id: None,
message: Message::Handshake {
protocol_version: 1234,
node_id: 1,
},
};
s.send(msg).await.unwrap();
let m = s.next().await.unwrap().unwrap();

assert!(matches!(
m,
Message::Node(NodeMessage::Error(
NodeError::HandshakeVersionMismatch { .. }
))
m.message,
Message::Error(
ProtoError::HandshakeVersionMismatch { .. }
)
));

Ok(())
});

sim.run().unwrap();
}

#[test]
fn stream_closed() {
let mut sim = turmoil::Builder::new().build();

let database_id = DatabaseId::new_v4();
let host_node_id = NodeId::new_v4();
let notify = Arc::new(Notify::new());
sim.host("host", {
let notify = notify.clone();
move || {
let notify = notify.clone();
async move {
let bus = Bus::new(host_node_id);
let mut sub = bus.subscribe(database_id).unwrap();
let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234")
.await
.unwrap();
let (s, _) = listener.accept().await.unwrap();
let connection = Connection::new_acceptor(s, bus);
tokio::task::spawn_local(connection.run());
let mut streams = Vec::new();
loop {
tokio::select! {
Some(mut stream) = sub.next() => {
let m = stream.next().await.unwrap();
stream.send(m).await.unwrap();
streams.push(stream);
}
_ = notify.notified() => {
break;
}
}
}

Ok(())
}
}
});

sim.client("client", async move {
let stream_id = StreamId::new(1);
let node_id = NodeId::new_v4();
let s = TcpStream::connect("host:1234").await.unwrap();
let mut s = AsyncBincodeStream::<_, Message, Message, _>::from(s).for_async();

s.send(Message::Node(NodeMessage::Handshake {
protocol_version: CURRENT_PROTO_VERSION,
node_id,
}))
.await
.unwrap();
let m = s.next().await.unwrap().unwrap();
assert!(matches!(m, Message::Node(NodeMessage::Handshake { .. })));

// send message to unexisting stream:
s.send(Message::Stream {
stream_id,
payload: StreamMessage::Dummy,
})
.await
.unwrap();
let m = s.next().await.unwrap().unwrap();
assert_eq!(
m,
Message::Node(NodeMessage::Error(NodeError::UnknownStream(stream_id)))
);

// open stream then send message
s.send(Message::Node(NodeMessage::OpenStream {
stream_id,
database_id,
}))
.await
.unwrap();
s.send(Message::Stream {
stream_id,
payload: StreamMessage::Dummy,
})
.await
.unwrap();
let m = s.next().await.unwrap().unwrap();
assert_eq!(
m,
Message::Stream {
stream_id,
payload: StreamMessage::Dummy
}
);

s.send(Message::Node(NodeMessage::CloseStream {
stream_id: StreamId::new(1),
}))
.await
.unwrap();
s.send(Message::Stream {
stream_id,
payload: StreamMessage::Dummy,
})
.await
.unwrap();
let m = s.next().await.unwrap().unwrap();
assert_eq!(
m,
Message::Node(NodeMessage::Error(NodeError::UnknownStream(stream_id)))
);

notify.notify_waiters();
done.notify_waiters();

Ok(())
});
Expand Down Expand Up @@ -459,7 +358,7 @@ mod test {

sim.client("client", async move {
let stream = TcpStream::connect("host:1234").await.unwrap();
let bus = Bus::new(NodeId::new_v4());
let bus = Arc::new(Bus::new(1, |_, _| async {}));
let mut conn = Connection::new_acceptor(stream, bus);

notify.notify_waiters();
Expand All @@ -473,57 +372,4 @@ mod test {

sim.run().unwrap();
}

#[test]
fn zero_stream_id() {
let mut sim = turmoil::Builder::new().build();

let notify = Arc::new(Notify::new());
sim.host("host", {
let notify = notify.clone();
move || {
let notify = notify.clone();
async move {
let listener = TcpListener::bind("0.0.0.0:1234").await.unwrap();
let (stream, _) = listener.accept().await.unwrap();
let (connection_messages_sender, connection_messages) = mpsc::channel(1);
let conn = Connection {
peer: Some(NodeId::new_v4()),
state: ConnectionState::Connected,
conn: AsyncBincodeStream::from(stream).for_async(),
streams: HashMap::new(),
connection_messages,
connection_messages_sender,
is_initiator: false,
bus: Bus::new(NodeId::new_v4()),
stream_id_allocator: StreamIdAllocator::new(false),
registration: None,
};

conn.run().await;

Ok(())
}
}
});

sim.client("client", async move {
let stream = TcpStream::connect("host:1234").await.unwrap();
let mut stream = AsyncBincodeStream::<_, Message, Message, _>::from(stream).for_async();

stream
.send(Message::Stream {
stream_id: StreamId::new_unchecked(0),
payload: StreamMessage::Dummy,
})
.await
.unwrap();

assert!(stream.next().await.is_none());

Ok(())
});

sim.run().unwrap();
}
}
Loading