Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(asgi.BoundedStream): Mixing iteration and read() can lead to subtle errors #1692

Merged
merged 2 commits into from
Mar 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions falcon/asgi/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"""ASGI BoundedStream class."""


from falcon.errors import OperationNotAllowed


__all__ = ['BoundedStream']


Expand Down Expand Up @@ -87,12 +90,14 @@ class BoundedStream:
'_buffer',
'_bytes_remaining',
'_closed',
'_iteration_started',
'_pos',
'_receive',
]

def __init__(self, receive, content_length=None):
self._closed = False
self._iteration_started = False

self._receive = receive
self._buffer = b''
Expand All @@ -108,9 +113,9 @@ def __init__(self, receive, content_length=None):
self._pos = 0

def __aiter__(self):
    """Return the async iterable created by ``_iter_content()``."""
    # NOTE(kgriffs): Strictly speaking this returns an async generator
    # rather than a plain async iterator, but that's OK because it also
    # implements the iterator protocol defined in PEP 492, albeit in a
    # more efficient way than a regular async iterator.
    return self._iter_content()

# -------------------------------------------------------------------------
Expand Down Expand Up @@ -224,7 +229,7 @@ async def readall(self):
"""

if self._closed:
raise ValueError(
raise OperationNotAllowed(
'This stream is closed; no further operations on it are permitted.'
)

Expand Down Expand Up @@ -298,7 +303,7 @@ async def read(self, size=None):
"""

if self._closed:
raise ValueError(
raise OperationNotAllowed(
'This stream is closed; no further operations on it are permitted.'
)

Expand Down Expand Up @@ -360,23 +365,27 @@ async def read(self, size=None):

return data

# NOTE: In the docs, tell people not to mix different read modes. If you
# read the stream in middleware, or there is a chance that something else
# might read it, make sure you exhaust it in a finally block. We don't want
# anyone to end up trying to read a half-read stream!
async def _iter_content(self):
if self._closed:
raise ValueError(
raise OperationNotAllowed(
'This stream is closed; no further operations on it are permitted.'
)

if self.eof:
yield b''
return

# TODO(kgriffs): Should we check for any buffered data and return
# that first? Or simply raise an error if any data has already
# been read?
if self._iteration_started:
raise OperationNotAllowed('This stream is already being iterated over.')

self._iteration_started = True

if self._buffer:
next_chunk = self._buffer
self._buffer = b''

self._pos += len(next_chunk)
yield next_chunk

while self._bytes_remaining > 0:
event = await self._receive()
Expand Down
6 changes: 6 additions & 0 deletions falcon/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ class UnsupportedError(RuntimeError):
"""The method or operation is not supported."""


# NOTE(kgriffs): This inherits from ValueError to be consistent with the type
# raised by Python's built-in file-like objects.
class OperationNotAllowed(ValueError):
    """The requested operation is not allowed.

    Because this is a subclass of ``ValueError``, code that catches
    ``ValueError`` also catches this error.
    """


class HTTPBadRequest(HTTPError):
"""400 Bad Request.

Expand Down
34 changes: 33 additions & 1 deletion tests/asgi/test_boundedstream_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,15 @@ async def test_mixed_b():

assert b''.join(chunks) == body

for t in (test_nonmixed, test_mixed_a, test_mixed_b):
async def test_mixed_iter():
    # Pull one chunk via read(), then let async iteration drain whatever
    # is left; together the pieces must reproduce the original body.
    s = stream()

    first = await s.read(chunk_size)
    rest = [data async for data in s]

    assert b''.join([first] + rest) == body

for t in (test_nonmixed, test_mixed_a, test_mixed_b, test_mixed_iter):
testing.invoke_coroutine_sync(t)
testing.invoke_coroutine_sync(t)

Expand All @@ -219,6 +227,30 @@ async def t():
testing.invoke_coroutine_sync(t)


def test_iteration_already_started():
    """A second iteration attempt fails while the first one is in progress.

    The first iterator must remain usable afterwards and still deliver the
    complete body.
    """
    body = testing.rand_string(1, 2048).encode()
    s = _stream(body)

    async def t():
        first_iter = s.__aiter__()
        received = [await first_iter.__anext__()]

        # Creating and stepping a second iterator over the same stream
        # is expected to raise.
        with pytest.raises(ValueError):
            second_iter = s.__aiter__()
            await second_iter.__anext__()

        # Drain the original iterator by hand until it signals exhaustion.
        while True:
            try:
                chunk = await first_iter.__anext__()
            except StopAsyncIteration:
                break

            received.append(chunk)

        assert b''.join(received) == body

    testing.invoke_coroutine_sync(t)


def _stream(body, content_length=None):
    """Create a BoundedStream whose receive callable emits ``body``."""
    receive = testing.ASGIRequestEventEmitter(body)
    return asgi.BoundedStream(receive, content_length=content_length)