Skip to content

Commit

Permalink
Add flow control in AsyncSSH redirection to StreamWriter objects
Browse files Browse the repository at this point in the history
This commit adds support for flow control when an asyncio.StreamWriter is
passed to AsyncSSH's process redirect feature. Previously, writes were
allowed indefinitely, potentially building up a large buffer in memory if
data was being sent faster than it was being consumed. Thanks go to Benjy
Wiener for reporting this issue and supplying a test script to reproduce it.
  • Loading branch information
ronf committed Jun 22, 2024
1 parent c3dc869 commit f2020ed
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
39 changes: 34 additions & 5 deletions asyncssh/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,24 +571,52 @@ def close(self) -> None:
class _StreamWriter(_UnicodeWriter[AnyStr]):
"""Forward data to an asyncio stream"""

def __init__(self, writer: asyncio.StreamWriter,
def __init__(self, process: 'SSHProcess[AnyStr]',
writer: asyncio.StreamWriter, recv_eof: bool,
encoding: Optional[str], errors: str):
super().__init__(encoding, errors)

self._process: 'SSHProcess[AnyStr]' = process
self._writer = writer
self._recv_eof = recv_eof
self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue()
self._write_task: Optional[asyncio.Task[None]] = \
process.channel.get_connection().create_task(self._feed())

async def _feed(self) -> None:
"""Feed data to the stream"""

while True:
data = await self._queue.get()

if data is None:
self._queue.task_done()
break

self._writer.write(self.encode(data))
await self._writer.drain()
self._queue.task_done()

if self._recv_eof:
self._writer.write_eof()

def write(self, data: AnyStr) -> None:
"""Write data to the stream"""

self._writer.write(self.encode(data))
self._queue.put_nowait(data)

def write_eof(self) -> None:
"""Write EOF to the stream"""

self._writer.write_eof()
self.close()

def close(self) -> None:
"""Ignore close -- the caller must clean up the associated transport"""
"""Stop forwarding data to the stream"""

if self._write_task:
self._write_task = None
self._queue.put_nowait(None)
self._process.add_cleanup_task(self._queue.join())


class _DevNullWriter(_WriterProtocol[AnyStr]):
Expand Down Expand Up @@ -925,7 +953,8 @@ def pipe_factory() -> _PipeWriter:
writer_process.set_reader(reader, send_eof, writer_datatype)
writer = _ProcessWriter[AnyStr](writer_process, writer_datatype)
elif isinstance(target, asyncio.StreamWriter):
writer = _StreamWriter(target, self._encoding, self._errors)
writer = _StreamWriter(self, target, recv_eof,
self._encoding, self._errors)
else:
file: _File
needs_close = True
Expand Down
22 changes: 22 additions & 0 deletions tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,28 @@ async def test_stdout_stream(self):

self.assertEqual(stdout_data, data.encode('ascii'))

@unittest.skipIf(sys.platform == 'win32',
'skip asyncio.subprocess tests on Windows')
@asynctest
async def test_stdout_stream_keep_open(self):
"""Test with stdout redirected to asyncio stream which remains open"""

data = str(id(self))

async with self.connect() as conn:
proc2 = await asyncio.create_subprocess_shell(
'cat', stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE)

await conn.run('echo', input=data, stdout=proc2.stdin,
stderr=asyncssh.DEVNULL, recv_eof=False)
await conn.run('echo', input=data, stdout=proc2.stdin,
stderr=asyncssh.DEVNULL)

stdout_data = await proc2.stdout.read()

self.assertEqual(stdout_data, 2*data.encode('ascii'))

@asynctest
async def test_change_stdout(self):
"""Test changing stdout of an open process"""
Expand Down

0 comments on commit f2020ed

Please sign in to comment.