Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ability to close a queue. #573

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions newsfragments/573.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add the ability to close a :class:`trio.Queue`, cancelling all waiting getters and putters, and
preventing anyone else from getting or putting onto it.
67 changes: 66 additions & 1 deletion trio/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"StrictFIFOLock",
"Condition",
"Queue",
"QueueClosed",
]


Expand Down Expand Up @@ -802,6 +803,11 @@ class _QueueStats:
tasks_waiting_get = attr.ib()


class QueueClosed(Exception):
"""Raised on waiters for the queue when a queue is closed.
"""


# Like queue.Queue, with the notable difference that the capacity argument is
# mandatory.
class Queue:
Expand Down Expand Up @@ -842,6 +848,9 @@ def __init__(self, capacity):
# if len(self._data) < self.capacity, then self._put_wait is empty
# if len(self._data) > 0, then self._get_wait is empty
self._data = deque()
# closed state
self._put_close = False
self._all_closed = False

def __repr__(self):
return (
Expand Down Expand Up @@ -887,6 +896,9 @@ def put_nowait(self, obj):
WouldBlock: if the queue is full.

"""
if self._put_close or self._all_closed:
raise QueueClosed

if self._get_wait:
assert not self._data
task, _ = self._get_wait.popitem(last=False)
Expand All @@ -905,6 +917,9 @@ async def put(self, obj):

"""
await _core.checkpoint_if_cancelled()
if self._put_close or self._all_closed:
raise QueueClosed

try:
self.put_nowait(obj)
except _core.WouldBlock:
Expand Down Expand Up @@ -933,6 +948,9 @@ def get_nowait(self):
WouldBlock: if the queue is empty.

"""
if self._all_closed:
raise QueueClosed

if self._put_wait:
task, value = self._put_wait.popitem(last=False)
# No need to check max_size, b/c we'll pop an item off again right
Expand All @@ -942,6 +960,16 @@ def get_nowait(self):
if self._data:
value = self._data.popleft()
return value
if self._put_close:
# this confused me a bit so its bound to confuse somebody else as to why this is here
# 1) there's no put waiters, so we skip that branch
# 2) there's no data so we skip that branch
# that means that if there's no data at all, and the put side is closed
# we cannot ever have more data, so we close this side and raise QueueClosed so that
# any getters from here on close early
self._all_closed = True
raise QueueClosed

raise _core.WouldBlock()

@_core.enable_ki_protection
Expand All @@ -953,6 +981,9 @@ async def get(self):

"""
await _core.checkpoint_if_cancelled()
if self._all_closed:
raise QueueClosed

try:
value = self.get_nowait()
except _core.WouldBlock:
Expand All @@ -972,12 +1003,46 @@ def abort_fn(_):
value = await _core.wait_task_rescheduled(abort_fn)
return value

def close_put(self):
"""Closes one side of this queue, preventing any putters from putting data onto the queue.

If this queue is empty, it will also cancel all getters.
"""
if self.empty():
# pointless to let the getters wait on closed data
self.close_both_sides()
else:
self._put_close = True
for task in self._put_wait.values():
_core.reschedule(task, outcome.Error(QueueClosed))

self._put_wait.clear()

def close_both_sides(self):
"""Closes both the getter and putter sides of the queue, discarding all data.
"""
self._put_close, self._all_closed = True, True
for task in self._get_wait.values():
_core.reschedule(task, outcome.Error(QueueClosed))

self._get_wait.clear()

for task in self._put_wait.values():
_core.reschedule(task, outcome.Error(QueueClosed))

self._put_wait.clear()

self._data.clear()

@aiter_compat
def __aiter__(self):
return self

async def __anext__(self):
return await self.get()
try:
return await self.get()
except QueueClosed:
raise StopAsyncIteration from None

def statistics(self):
"""Returns an object containing debugging information.
Expand Down
21 changes: 20 additions & 1 deletion trio/tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ..testing import wait_all_tasks_blocked, assert_checkpoints

from .. import _core
from .. import _timeouts
from .._timeouts import sleep_forever, move_on_after
from .._sync import *

Expand Down Expand Up @@ -542,6 +541,26 @@ async def do_put(q, v):
q.get_nowait()


async def test_Queue_close():
q1 = Queue(capacity=1)

await q1.put(1)
q1.close_put()
with pytest.raises(QueueClosed):
await q1.put(2)

assert (await q1.get()) == 1
with pytest.raises(QueueClosed):
await q1.get()

q2 = Queue(capacity=1)
await q2.put(1)
q2.close_both_sides()

with pytest.raises(QueueClosed):
await q2.get()


# Two ways of implementing a Lock in terms of a Queue. Used to let us put the
# Queue through the generic lock tests.

Expand Down