From 446be0a245a6f54889a9d1edbf73213e530fda29 Mon Sep 17 00:00:00 2001 From: Connor Peet Date: Tue, 10 Oct 2023 08:39:53 -0700 Subject: [PATCH] cli: fix exec server not reading all stdin with immediate close (#195257) Fixes https://github.com/microsoft/vscode-remote-tunnels/issues/691 --- cli/src/rpc.rs | 75 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 68 insertions(+), 7 deletions(-) diff --git a/cli/src/rpc.rs b/cli/src/rpc.rs index a9a66153735dc..0972ad054757b 100644 --- a/cli/src/rpc.rs +++ b/cli/src/rpc.rs @@ -531,6 +531,7 @@ impl RpcDispatcher { struct StreamRec { write: Option>, q: Vec>, + ended: bool, } #[derive(Clone, Default)] @@ -540,13 +541,24 @@ struct Streams { impl Streams { pub async fn remove(&self, id: u32) { - let stream = self.map.lock().unwrap().remove(&id); - if let Some(s) = stream { - // if there's no 'write' right now, it'll shut down in the write_loop - if let Some(mut w) = s.write { - let _ = w.shutdown().await; + let mut remove = None; + + { + let mut map = self.map.lock().unwrap(); + if let Some(s) = map.get_mut(&id) { + if let Some(w) = s.write.take() { + map.remove(&id); + remove = Some(w); + } else { + s.ended = true; // will shut down in write loop + } } } + + // do this outside of the sync lock: + if let Some(mut w) = remove { + let _ = w.shutdown().await; + } } pub fn write(&self, id: u32, buf: Vec) { @@ -566,6 +578,7 @@ impl Streams { StreamRec { write: Some(stream), q: Vec::new(), + ended: false, }, ); } @@ -595,8 +608,13 @@ async fn write_loop( }; if stream_rec.q.is_empty() { - stream_rec.write = Some(w); - return; + if stream_rec.ended { + lock.remove(&id); + break; + } else { + stream_rec.write = Some(w); + return; + } } std::mem::swap(&mut stream_rec.q, &mut items_vec); @@ -691,3 +709,46 @@ pub enum MaybeSync { Future(BoxFuture<'static, Option>>), Sync(Option>), } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_remove() { + let streams = Streams::default(); + let (writer, mut reader) = tokio::io::duplex(1024); + streams.insert(1, tokio::io::split(writer).1); + streams.remove(1).await; + + assert!(streams.map.lock().unwrap().get(&1).is_none()); + let mut buffer = Vec::new(); + assert_eq!(reader.read_to_end(&mut buffer).await.unwrap(), 0); + } + + #[tokio::test] + async fn test_write() { + let streams = Streams::default(); + let (writer, mut reader) = tokio::io::duplex(1024); + streams.insert(1, tokio::io::split(writer).1); + streams.write(1, vec![1, 2, 3]); + + let mut buffer = [0; 3]; + assert_eq!(reader.read_exact(&mut buffer).await.unwrap(), 3); + assert_eq!(buffer, [1, 2, 3]); + } + + #[tokio::test] + async fn test_write_with_immediate_end() { + let streams = Streams::default(); + let (writer, mut reader) = tokio::io::duplex(1); + streams.insert(1, tokio::io::split(writer).1); + streams.write(1, vec![1, 2, 3]); // spawn write loop + streams.write(1, vec![4, 5, 6]); // enqueued while writing + streams.remove(1).await; // end stream + + let mut buffer = Vec::new(); + assert_eq!(reader.read_to_end(&mut buffer).await.unwrap(), 6); + assert_eq!(buffer, vec![1, 2, 3, 4, 5, 6]); + } +}