Skip to content

Commit

Permalink
Merge branch 'scylladb:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
macher259 authored Nov 17, 2021
2 parents a819920 + 83199fa commit 012fc6a
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 24 deletions.
2 changes: 1 addition & 1 deletion scylla/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ futures = "0.3.6"
histogram = "0.6.9"
num_enum = "0.5"
compress = "0.2.1"
tokio = { version = "1.1.0", features = ["net", "time", "io-util", "sync", "rt", "macros"] }
tokio = { version = "1.12", features = ["net", "time", "io-util", "sync", "rt", "macros"] }
snap = "1.0"
uuid = "0.8.1"
rand = "0.8.3"
Expand Down
69 changes: 46 additions & 23 deletions scylla/src/transport/connection.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use bytes::Bytes;
use futures::{future::RemoteHandle, FutureExt};
use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpSocket, TcpStream};
use tokio::sync::{mpsc, oneshot};
use tracing::{error, warn};
use tracing::{error, trace, warn};
use uuid::Uuid;

#[cfg(feature = "ssl")]
Expand Down Expand Up @@ -241,7 +241,7 @@ impl Connection {
stream.set_nodelay(config.tcp_nodelay)?;

// TODO: What should be the size of the channel?
let (sender, receiver) = mpsc::channel(128);
let (sender, receiver) = mpsc::channel(1024);

let (error_sender, error_receiver) = tokio::sync::oneshot::channel();

Expand Down Expand Up @@ -747,8 +747,16 @@ impl Connection {
// across .await points. Therefore, it should not be too expensive.
let handler_map = StdMutex::new(ResponseHandlerMap::new());

let r = Self::reader(read_half, &handler_map, config);
let w = Self::writer(write_half, &handler_map, receiver);
let r = Self::reader(
BufReader::with_capacity(8192, read_half),
&handler_map,
config,
);
let w = Self::writer(
BufWriter::with_capacity(8192, write_half),
&handler_map,
receiver,
);

let result = futures::try_join!(r, w);

Expand Down Expand Up @@ -821,6 +829,23 @@ impl Connection {
}
}

fn alloc_stream_id(
handler_map: &StdMutex<ResponseHandlerMap>,
response_handler: ResponseHandler,
) -> Option<i16> {
// We are guaranteed here that handler_map will not be locked
// by anybody else, so we can do try_lock().unwrap()
let mut lock = handler_map.try_lock().unwrap();
if let Some(stream_id) = lock.allocate(response_handler) {
Some(stream_id)
} else {
// TODO: Handle this error better, for now we drop this
// request and return an error to the receiver
error!("Could not allocate stream id");
None
}
}

async fn writer(
mut write_half: (impl AsyncWrite + Unpin),
handler_map: &StdMutex<ResponseHandlerMap>,
Expand All @@ -829,25 +854,23 @@ impl Connection {
// When the Connection object is dropped, the sender half
// of the channel will be dropped, this task will return an error
// and the whole worker will be stopped
while let Some(task) = task_receiver.recv().await {
let stream_id = {
// We are guaranteed here that handler_map will not be locked
// by anybody else, so we can do try_lock().unwrap()
let mut lock = handler_map.try_lock().unwrap();

if let Some(stream_id) = lock.allocate(task.response_handler) {
stream_id
} else {
// TODO: Handle this error better, for now we drop this
// request and return an error to the receiver
error!("Could not allocate stream id");
continue;
while let Some(mut task) = task_receiver.recv().await {
let mut num_requests = 0;
let mut total_sent = 0;
while let Some(stream_id) = Self::alloc_stream_id(handler_map, task.response_handler) {
let mut req = task.serialized_request;
req.set_stream(stream_id);
let req_data: &[u8] = req.get_data();
total_sent += req_data.len();
num_requests += 1;
write_half.write_all(req_data).await?;
task = match task_receiver.try_recv() {
Ok(t) => t,
Err(_) => break,
}
};

let mut req = task.serialized_request;
req.set_stream(stream_id);
write_half.write_all(req.get_data()).await?;
}
trace!("Sending {} requests; {} bytes", num_requests, total_sent);
write_half.flush().await?;
}

Ok(())
Expand Down

0 comments on commit 012fc6a

Please sign in to comment.