Skip to content

Commit

Permalink
cli: fix exec server not reading all stdin with immediate close (micr…
Browse files Browse the repository at this point in the history
  • Loading branch information
connor4312 authored and Alex0007 committed Oct 26, 2023
1 parent 9692156 commit 446be0a
Showing 1 changed file with 68 additions and 7 deletions.
75 changes: 68 additions & 7 deletions cli/src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
struct StreamRec {
write: Option<WriteHalf<DuplexStream>>,
q: Vec<Vec<u8>>,
ended: bool,
}

#[derive(Clone, Default)]
Expand All @@ -540,13 +541,24 @@ struct Streams {

impl Streams {
pub async fn remove(&self, id: u32) {
let stream = self.map.lock().unwrap().remove(&id);
if let Some(s) = stream {
// if there's no 'write' right now, it'll shut down in the write_loop
if let Some(mut w) = s.write {
let _ = w.shutdown().await;
let mut remove = None;

{
let mut map = self.map.lock().unwrap();
if let Some(s) = map.get_mut(&id) {
if let Some(w) = s.write.take() {
map.remove(&id);
remove = Some(w);
} else {
s.ended = true; // will shut down in write loop
}
}
}

// do this outside of the sync lock:
if let Some(mut w) = remove {
let _ = w.shutdown().await;
}
}

pub fn write(&self, id: u32, buf: Vec<u8>) {
Expand All @@ -566,6 +578,7 @@ impl Streams {
StreamRec {
write: Some(stream),
q: Vec::new(),
ended: false,
},
);
}
Expand Down Expand Up @@ -595,8 +608,13 @@ async fn write_loop(
};

if stream_rec.q.is_empty() {
stream_rec.write = Some(w);
return;
if stream_rec.ended {
lock.remove(&id);
break;
} else {
stream_rec.write = Some(w);
return;
}
}

std::mem::swap(&mut stream_rec.q, &mut items_vec);
Expand Down Expand Up @@ -691,3 +709,46 @@ pub enum MaybeSync {
Future(BoxFuture<'static, Option<Vec<u8>>>),
Sync(Option<Vec<u8>>),
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
async fn test_remove() {
let streams = Streams::default();
let (writer, mut reader) = tokio::io::duplex(1024);
streams.insert(1, tokio::io::split(writer).1);
streams.remove(1).await;

assert!(streams.map.lock().unwrap().get(&1).is_none());
let mut buffer = Vec::new();
assert_eq!(reader.read_to_end(&mut buffer).await.unwrap(), 0);
}

#[tokio::test]
async fn test_write() {
let streams = Streams::default();
let (writer, mut reader) = tokio::io::duplex(1024);
streams.insert(1, tokio::io::split(writer).1);
streams.write(1, vec![1, 2, 3]);

let mut buffer = [0; 3];
assert_eq!(reader.read_exact(&mut buffer).await.unwrap(), 3);
assert_eq!(buffer, [1, 2, 3]);
}

#[tokio::test]
async fn test_write_with_immediate_end() {
let streams = Streams::default();
let (writer, mut reader) = tokio::io::duplex(1);
streams.insert(1, tokio::io::split(writer).1);
streams.write(1, vec![1, 2, 3]); // spawn write loop
streams.write(1, vec![4, 5, 6]); // enqueued while writing
streams.remove(1).await; // end stream

let mut buffer = Vec::new();
assert_eq!(reader.read_to_end(&mut buffer).await.unwrap(), 6);
assert_eq!(buffer, vec![1, 2, 3, 4, 5, 6]);
}
}

0 comments on commit 446be0a

Please sign in to comment.