Streams are iterable + receive_some doesn't require an explicit size
This came out of discussion in python-trio#959
njsmith committed Jun 25, 2019
1 parent f7850e8 commit ee4cedb
Showing 19 changed files with 145 additions and 119 deletions.
2 changes: 1 addition & 1 deletion docs/source/reference-io.rst
@@ -98,7 +98,7 @@ Abstract base classes
* - :class:`ReceiveStream`
- :class:`AsyncResource`
- :meth:`~ReceiveStream.receive_some`
-
- ``__aiter__``, ``__anext__``
- :class:`~trio.testing.MemoryReceiveStream`
* - :class:`Stream`
- :class:`SendStream`, :class:`ReceiveStream`
39 changes: 21 additions & 18 deletions docs/source/tutorial.rst
@@ -908,12 +908,10 @@ And the second task's job is to process the data the server sends back:
:lineno-match:
:pyobject: receiver

It repeatedly calls ``await client_stream.receive_some(...)`` to get
more data from the server (again, all Trio streams provide this
method), and then checks to see if the server has closed the
connection. ``receive_some`` only returns an empty bytestring if the
connection has been closed; otherwise, it waits until some data has
arrived, up to a maximum of ``BUFSIZE`` bytes.
It uses an ``async for`` loop to fetch data from the server.
Alternatively, it could use `~trio.abc.ReceiveStream.receive_some`,
which is the opposite of `~trio.abc.SendStream.send_all`, but using
``async for`` saves some boilerplate.
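The two receiver styles can be compared side by side with a toy in-memory stream. ``ToyReceiveStream`` and its data are invented for this sketch, and ``asyncio`` stands in for Trio so the snippet runs stand-alone; real code would use a Trio stream object such as the one ``open_tcp_stream`` returns:

```python
import asyncio

class ToyReceiveStream:
    """Mimics the part of trio.abc.ReceiveStream the tutorial uses."""
    def __init__(self, chunks):
        self._chunks = list(chunks)

    async def receive_some(self, max_bytes=None):
        # Returns b"" at end-of-file, like a real Trio stream.
        return self._chunks.pop(0) if self._chunks else b""

    def __aiter__(self):
        return self

    async def __anext__(self):
        data = await self.receive_some()
        if not data:
            raise StopAsyncIteration
        return data

async def receiver_manual(stream):
    # The old style: loop, and check for the empty bytestring ourselves.
    received = []
    while True:
        data = await stream.receive_some()
        if not data:            # b"" means the connection was closed
            break
        received.append(data)
    return received

async def receiver_async_for(stream):
    # The new style: the EOF check is handled by the iteration protocol.
    return [data async for data in stream]

manual = asyncio.run(receiver_manual(ToyReceiveStream([b"he", b"llo"])))
loop_style = asyncio.run(receiver_async_for(ToyReceiveStream([b"he", b"llo"])))
```

Both receivers see the same chunks; the ``async for`` version just has less to get wrong.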

And now we're ready to look at the server.

@@ -974,11 +972,11 @@ functions we saw in the last section:

The argument ``server_stream`` is provided by :func:`serve_tcp`, and
is the other end of the connection we made in the client: so the data
that the client passes to ``send_all`` will come out of
``receive_some`` here, and vice-versa. Then we have a ``try`` block
discussed below, and finally the server loop which alternates between
reading some data from the socket and then sending it back out again
(unless the socket was closed, in which case we quit).
that the client passes to ``send_all`` will come out here. Then we
have a ``try`` block discussed below, and finally the server loop
which alternates between reading some data from the socket and then
sending it back out again (unless the socket was closed, in which case
we quit).

So what's that ``try`` block for? Remember that in Trio, like Python
in general, exceptions keep propagating until they're caught. Here we
@@ -1029,7 +1027,7 @@ our client could use a single task like::
while True:
data = ...
await client_stream.send_all(data)
received = await client_stream.receive_some(BUFSIZE)
received = await client_stream.receive_some()
if not received:
sys.exit()
await trio.sleep(1)
@@ -1046,18 +1044,23 @@ line, any time we're expecting more than one byte of data, we have to
be prepared to call ``receive_some`` multiple times.

And where this would go especially wrong is if we find ourselves in
the situation where ``len(data) > BUFSIZE``. On each pass through the
loop, we send ``len(data)`` bytes, but only read *at most* ``BUFSIZE``
bytes. The result is something like a memory leak: we'll end up with
more and more data backed up in the network, until eventually
something breaks.
the situation where ``data`` is big enough that it passes some
internal threshold, and the operating system or network decide to
always break it up into multiple pieces. Now on each pass through the
loop, we send ``len(data)`` bytes, but read less than that. The result
is something like a memory leak: we'll end up with more and more data
backed up in the network, until eventually something breaks.

.. note:: If you're curious *how* things break, then you can use
`~trio.abc.ReceiveStream.receive_some`\'s optional argument to put
a limit on how many bytes you read each time, and see what happens.

We could fix this by keeping track of how much data we're expecting at
each moment, and then keep calling ``receive_some`` until we get it all::

expected = len(data)
while expected > 0:
received = await client_stream.receive_some(BUFSIZE)
received = await client_stream.receive_some(expected)
if not received:
sys.exit(1)
expected -= len(received)
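The loop above can be exercised stand-alone with a toy stream that deliberately fragments data. ``ChunkyStream`` and its two-byte limit are invented for the demo, and ``asyncio`` stands in for Trio:

```python
import asyncio

class ChunkyStream:
    """Toy stream that returns at most two bytes per call, imitating a
    transport that breaks large sends into pieces."""
    def __init__(self, data):
        self._data = data

    async def receive_some(self, max_bytes=None):
        # Cap every read at 2 bytes to force fragmentation.
        limit = 2 if max_bytes is None else min(max_bytes, 2)
        chunk, self._data = self._data[:limit], self._data[limit:]
        return chunk

async def receive_exactly(stream, expected):
    # The loop from the text: keep calling receive_some until we've
    # collected all the bytes we were expecting.
    parts = []
    while expected > 0:
        received = await stream.receive_some(expected)
        if not received:
            raise EOFError("connection closed early")
        expected -= len(received)
        parts.append(received)
    return b"".join(parts)

result = asyncio.run(receive_exactly(ChunkyStream(b"hello!"), 6))
```

Even though each read returns only two bytes, the loop keeps going until all six arrive.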
11 changes: 3 additions & 8 deletions docs/source/tutorial/echo-client.py
@@ -8,9 +8,6 @@
# - can't be in use by some other program on your computer
# - must match what we set in our echo server
PORT = 12345
# How much memory to spend (at most) on each call to recv. Pretty arbitrary,
# but shouldn't be too big or too small.
BUFSIZE = 16384

async def sender(client_stream):
print("sender: started!")
@@ -22,12 +19,10 @@ async def sender(client_stream):

async def receiver(client_stream):
print("receiver: started!")
while True:
data = await client_stream.receive_some(BUFSIZE)
async for data in client_stream:
print("receiver: got data {!r}".format(data))
if not data:
print("receiver: connection closed")
sys.exit()
print("receiver: connection closed")
sys.exit()

async def parent():
print("parent: connecting to 127.0.0.1:{}".format(PORT))
12 changes: 3 additions & 9 deletions docs/source/tutorial/echo-server.py
@@ -8,9 +8,6 @@
# - can't be in use by some other program on your computer
# - must match what we set in our echo client
PORT = 12345
# How much memory to spend (at most) on each call to recv. Pretty arbitrary,
# but shouldn't be too big or too small.
BUFSIZE = 16384

CONNECTION_COUNTER = count()

@@ -20,14 +17,11 @@ async def echo_server(server_stream):
ident = next(CONNECTION_COUNTER)
print("echo_server {}: started".format(ident))
try:
while True:
data = await server_stream.receive_some(BUFSIZE)
async for data in server_stream:
print("echo_server {}: received data {!r}".format(ident, data))
if not data:
print("echo_server {}: connection closed".format(ident))
return
print("echo_server {}: sending data {!r}".format(ident, data))
await server_stream.send_all(data)
print("echo_server {}: connection closed".format(ident))
return
# FIXME: add discussion of MultiErrors to the tutorial, and use
# MultiError.catch here. (Not important in this case, but important if the
# server code uses nurseries internally.)
8 changes: 8 additions & 0 deletions newsfragments/959.feature.rst
@@ -0,0 +1,8 @@
If you have a `~trio.abc.ReceiveStream` object, you can now use
``async for data in stream: ...`` instead of calling
`~trio.abc.ReceiveStream.receive_some` repeatedly. And the best part
is, it automatically checks for EOF for you, so you don't have to.
Also, you no longer have to choose a magic buffer size value before
calling `~trio.abc.ReceiveStream.receive_some`; you can now call
``await stream.receive_some()`` and the stream will automatically pick
a reasonable value for you.
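A compact sketch of both conveniences together, using a stand-in stream class. ``ToyStream`` and its tiny default size are invented for the demo; a real stream would pick something on the order of 65536 bytes, and asyncio stands in for Trio here:

```python
import asyncio

DEFAULT_RECEIVE_SIZE = 4   # deliberately tiny, to make the chunking visible

class ToyStream:
    def __init__(self, data):
        self._data = data

    async def receive_some(self, max_bytes=None):
        if max_bytes is None:
            # The stream picks a reasonable size for you.
            max_bytes = DEFAULT_RECEIVE_SIZE
        chunk, self._data = self._data[:max_bytes], self._data[max_bytes:]
        return chunk

    def __aiter__(self):
        return self

    async def __anext__(self):
        data = await self.receive_some()
        if not data:
            raise StopAsyncIteration   # EOF check handled automatically
        return data

async def main():
    # No buffer size, no explicit EOF check: just iterate.
    return [chunk async for chunk in ToyStream(b"abcdefgh")]

chunks = asyncio.run(main())
```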
2 changes: 1 addition & 1 deletion notes-to-self/graceful-shutdown-idea.py
@@ -30,7 +30,7 @@ def shutting_down(self):
async def stream_handler(stream):
while True:
with gsm.cancel_on_graceful_shutdown():
data = await stream.receive_some(...)
data = await stream.receive_some()
if gsm.shutting_down:
break

23 changes: 16 additions & 7 deletions trio/_abc.py
@@ -378,26 +378,26 @@ class ReceiveStream(AsyncResource):
If you want to receive Python objects rather than raw bytes, see
:class:`ReceiveChannel`.
`ReceiveStream` objects can be used in ``async for`` loops. Each iteration
will produce an arbitrarily sized chunk of bytes, like calling
``receive_some`` with no arguments.
"""
__slots__ = ()

@abstractmethod
async def receive_some(self, max_bytes):
async def receive_some(self, max_bytes=None):
"""Wait until there is data available on this stream, and then return
at most ``max_bytes`` of it.
some of it.
A return value of ``b""`` (an empty bytestring) indicates that the
stream has reached end-of-file. Implementations should be careful that
they return ``b""`` if, and only if, the stream has reached
end-of-file!
This method will return as soon as any data is available, so it may
return fewer than ``max_bytes`` of data. But it will never return
more.
Args:
max_bytes (int): The maximum number of bytes to return. Must be
greater than zero.
greater than zero. Optional; if omitted, then the stream object
is free to pick a reasonable default.
Returns:
bytes or bytearray: The data received.
@@ -413,6 +413,15 @@ async def receive_some(self, max_bytes):
"""

def __aiter__(self):
return self

async def __anext__(self):
data = await self.receive_some()
if not data:
raise StopAsyncIteration
return data


class Stream(SendStream, ReceiveStream):
"""A standard interface for interacting with bidirectional byte streams.
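The new ``__aiter__``/``__anext__`` pair above means every ``ReceiveStream`` subclass gets iteration for free once it implements ``receive_some``. A minimal, trio-free sketch of that inheritance (the class names here are invented; ``asyncio`` stands in for Trio):

```python
import asyncio
from abc import ABC, abstractmethod

class ReceiveStreamBase(ABC):
    """Stand-in for trio.abc.ReceiveStream: iteration is provided by
    the base class, in terms of the abstract receive_some."""
    @abstractmethod
    async def receive_some(self, max_bytes=None):
        ...

    def __aiter__(self):
        return self

    async def __anext__(self):
        data = await self.receive_some()
        if not data:
            # b"" means EOF, which ends the async for loop.
            raise StopAsyncIteration
        return data

class MemoryReceive(ReceiveStreamBase):
    """A subclass only has to supply receive_some."""
    def __init__(self, chunks):
        self._chunks = list(chunks)

    async def receive_some(self, max_bytes=None):
        return self._chunks.pop(0) if self._chunks else b""

async def collect(stream):
    return [c async for c in stream]

got = asyncio.run(collect(MemoryReceive([b"a", b"b"])))
```

This is why the diff only touches the abstract base class: concrete streams like ``SocketStream`` become iterable without any per-class changes.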
4 changes: 2 additions & 2 deletions trio/_highlevel_generic.py
@@ -52,7 +52,7 @@ class StapledStream(HalfCloseableStream):
left, right = trio.testing.memory_stream_pair()
echo_stream = StapledStream(SocketStream(left), SocketStream(right))
await echo_stream.send_all(b"x")
assert await echo_stream.receive_some(1) == b"x"
assert await echo_stream.receive_some() == b"x"
:class:`StapledStream` objects implement the methods in the
:class:`~trio.abc.HalfCloseableStream` interface. They also have two
@@ -96,7 +96,7 @@ async def send_eof(self):
else:
return await self.send_stream.aclose()

async def receive_some(self, max_bytes):
async def receive_some(self, max_bytes=None):
"""Calls ``self.receive_stream.receive_some``.
"""
10 changes: 9 additions & 1 deletion trio/_highlevel_socket.py
@@ -10,6 +10,12 @@

__all__ = ["SocketStream", "SocketListener"]

# XX TODO: this number was picked arbitrarily. We should do experiments to
# tune it. (Or make it dynamic -- one idea is to start small and increase it
# if we observe single reads filling up the whole buffer, at least within some
# limits.)
DEFAULT_RECEIVE_SIZE = 65536

_closed_stream_errnos = {
# Unix
errno.EBADF,
@@ -129,7 +135,9 @@ async def send_eof(self):
with _translate_socket_errors_to_stream_errors():
self.socket.shutdown(tsocket.SHUT_WR)

async def receive_some(self, max_bytes):
async def receive_some(self, max_bytes=None):
if max_bytes is None:
max_bytes = DEFAULT_RECEIVE_SIZE
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
with _translate_socket_errors_to_stream_errors():
51 changes: 26 additions & 25 deletions trio/_ssl.py
@@ -159,11 +159,16 @@
from ._highlevel_generic import aclose_forcefully
from . import _sync
from ._util import ConflictDetector
from ._deprecate import warn_deprecated

################################################################
# SSLStream
################################################################

# XX TODO: this number was pulled out of a hat. We should tune it with
# science.
DEFAULT_RECEIVE_SIZE = 65536


class NeedHandshakeError(Exception):
"""Some :class:`SSLStream` methods can't return any meaningful data until
@@ -197,8 +202,6 @@ def done(self):

_State = _Enum("_State", ["OK", "BROKEN", "CLOSED"])

_default_max_refill_bytes = 32 * 1024


class SSLStream(Stream):
r"""Encrypted communication using SSL/TLS.
@@ -269,15 +272,6 @@ class SSLStream(Stream):
that :class:`~ssl.SSLSocket` implements the
``https_compatible=True`` behavior by default.
max_refill_bytes (int): :class:`~ssl.SSLSocket` maintains an internal
buffer of incoming data, and when it runs low then it calls
:meth:`receive_some` on the underlying transport stream to refill
it. This argument lets you set the ``max_bytes`` argument passed to
the *underlying* :meth:`receive_some` call. It doesn't affect calls
to *this* class's :meth:`receive_some`, or really anything else
user-observable except possibly performance. You probably don't need
to worry about this.
Attributes:
transport_stream (trio.abc.Stream): The underlying transport stream
that was passed to ``__init__``. An example of when this would be
@@ -313,11 +307,14 @@ def __init__(
server_hostname=None,
server_side=False,
https_compatible=False,
max_refill_bytes=_default_max_refill_bytes
max_refill_bytes="unused and deprecated",
):
self.transport_stream = transport_stream
self._state = _State.OK
self._max_refill_bytes = max_refill_bytes
if max_refill_bytes != "unused and deprecated":
warn_deprecated(
"max_refill_bytes=...", "0.12.0", issue=959, instead=None
)
self._https_compatible = https_compatible
self._outgoing = _stdlib_ssl.MemoryBIO()
self._incoming = _stdlib_ssl.MemoryBIO()
@@ -536,9 +533,7 @@ async def _retry(self, fn, *args, ignore_want_read=False):
async with self._inner_recv_lock:
yielded = True
if recv_count == self._inner_recv_count:
data = await self.transport_stream.receive_some(
self._max_refill_bytes
)
data = await self.transport_stream.receive_some()
if not data:
self._incoming.write_eof()
else:
@@ -590,7 +585,7 @@ async def do_handshake(self):
# https://bugs.python.org/issue30141
# So we *definitely* have to make sure that do_handshake is called
# before doing anything else.
async def receive_some(self, max_bytes):
async def receive_some(self, max_bytes=None):
"""Read some data from the underlying transport, decrypt it, and
return it.
@@ -621,9 +616,15 @@ async def receive_some(self, max_bytes):
return b""
else:
raise
max_bytes = _operator.index(max_bytes)
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
if max_bytes is None:
# Heuristic: normally we use DEFAULT_RECEIVE_SIZE, but if
# the transport gave us a bunch of data last time then we'll
# try to decrypt and pass it all back at once.
max_bytes = max(DEFAULT_RECEIVE_SIZE, self._incoming.pending)
else:
max_bytes = _operator.index(max_bytes)
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
try:
return await self._retry(self._ssl_object.read, max_bytes)
except trio.BrokenResourceError as exc:
@@ -837,8 +838,6 @@ class SSLListener(Listener[SSLStream]):
https_compatible (bool): Passed on to :class:`SSLStream`.
max_refill_bytes (int): Passed on to :class:`SSLStream`.
Attributes:
transport_listener (trio.abc.Listener): The underlying listener that was
passed to ``__init__``.
@@ -851,12 +850,15 @@ def __init__(
ssl_context,
*,
https_compatible=False,
max_refill_bytes=_default_max_refill_bytes
max_refill_bytes="unused and deprecated",
):
if max_refill_bytes != "unused and deprecated":
warn_deprecated(
"max_refill_bytes=...", "0.12.0", issue=959, instead=None
)
self.transport_listener = transport_listener
self._ssl_context = ssl_context
self._https_compatible = https_compatible
self._max_refill_bytes = max_refill_bytes

async def accept(self):
"""Accept the next connection and wrap it in an :class:`SSLStream`.
@@ -870,7 +872,6 @@ async def accept(self):
self._ssl_context,
server_side=True,
https_compatible=self._https_compatible,
max_refill_bytes=self._max_refill_bytes,
)

async def aclose(self):
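The ``max_refill_bytes="unused and deprecated"`` default in the SSL diff is a sentinel trick: any value a caller could plausibly pass differs from the sentinel, so comparing against it detects whether the deprecated argument was supplied at all. A minimal sketch of the same pattern, using the stdlib ``warnings`` module in place of trio's internal ``warn_deprecated`` (the class and message here are illustrative):

```python
import warnings

_UNSET = "unused and deprecated"   # sentinel default, as in the diff

class FakeSSLStream:
    def __init__(self, transport_stream, *, max_refill_bytes=_UNSET):
        if max_refill_bytes is not _UNSET:
            # Caller explicitly passed the deprecated argument.
            warnings.warn(
                "max_refill_bytes is deprecated and ignored",
                DeprecationWarning,
                stacklevel=2,
            )
        self.transport_stream = transport_stream

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    FakeSSLStream(None)                        # default: no warning
    FakeSSLStream(None, max_refill_bytes=100)  # triggers the warning

deprecation_count = sum(
    issubclass(w.category, DeprecationWarning) for w in caught
)
```

Keeping the parameter (rather than deleting it) means old call sites still run, just with a deprecation warning, which is the usual way to retire an argument gracefully.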
5 changes: 1 addition & 4 deletions trio/_subprocess.py
@@ -478,10 +478,7 @@ async def feed_input():

async def read_output(stream, chunks):
async with stream:
while True:
chunk = await stream.receive_some(32768)
if not chunk:
break
async for chunk in stream:
chunks.append(chunk)

async with trio.open_nursery() as nursery: