Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up handling of Handles #76

Merged
merged 11 commits into from
May 11, 2020
4 changes: 2 additions & 2 deletions trio_asyncio/_async.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import trio
import asyncio

from ._base import BaseTrioEventLoop
from ._handles import Handle


class TrioEventLoop(BaseTrioEventLoop):
Expand Down Expand Up @@ -69,7 +69,7 @@ def stop_me():
if self._stopped.is_set():
waiter.set()
else:
self._queue_handle(Handle(stop_me, (), self, context=None, is_sync=True))
self._queue_handle(asyncio.Handle(stop_me, (), self))
return waiter

def _close(self):
Expand Down
124 changes: 37 additions & 87 deletions trio_asyncio/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
import concurrent.futures

from ._handles import Handle, TimerHandle
from ._handles import ScopedHandle, AsyncHandle
from ._util import run_aio_future, run_aio_generator
from ._deprecate import deprecated, deprecated_alias
from . import _util
Expand Down Expand Up @@ -39,28 +39,6 @@ def clear(self):
pass


def _h_raise(handle, exc):
"""
Convince a handle to raise an error.

trio-asyncio enhanced handles have a method to do this
but asyncio's native handles don't. Thus we need to fudge things.
smurfix marked this conversation as resolved.
Show resolved Hide resolved
"""
if hasattr(handle, '_raise'):
handle._raise(exc)
return

def _raise(exc):
raise exc

cb, handle._callback = handle._callback, _raise
ar, handle._args = handle._args, (exc,)
try:
handle._run()
finally:
handle._callback, handle._args = cb, ar


class _TrioSelector(_BaseSelectorImpl):
"""A selector that hooks into a ``TrioEventLoop``.

Expand Down Expand Up @@ -241,25 +219,6 @@ async def run_aio_coroutine(self, coro):
finally:
sniffio.current_async_library_cvar.reset(t)

async def __run_trio(self, h):
"""Helper for copying the result of a Trio task to an asyncio future"""
f, proc, *args = h._args
if f.cancelled(): # pragma: no cover
return
try:
with trio.CancelScope() as scope:
h._scope = scope
res = await proc(*args)
if scope.cancelled_caught:
f.cancel()
return
except BaseException as exc:
if not f.cancelled(): # pragma: no branch
f.set_exception(exc)
else:
if not f.cancelled(): # pragma: no branch
f.set_result(res)

def trio_as_future(self, proc, *args):
"""Start a new Trio task to run ``await proc(*args)`` asynchronously.
Return an `asyncio.Future` that will resolve to the value or exception
Expand Down Expand Up @@ -292,14 +251,7 @@ def trio_as_future(self, proc, *args):
an `asyncio.Future` which will resolve to the result of the call to *proc*
"""
f = asyncio.Future(loop=self)
h = Handle(
self.__run_trio, (
f,
proc,
) + args, self, context=None, is_sync=None
)
self._queue_handle(h)
f.add_done_callback(h._cb_future_cancel)
self._queue_handle(AsyncHandle(proc, args, self, result_future=f))
return f

def run_trio_task(self, proc, *args):
Expand All @@ -314,7 +266,7 @@ def run_trio_task(self, proc, *args):
Returns:
an `asyncio.Handle` which can be used to cancel the background task
"""
return self._queue_handle(Handle(proc, args, self, is_sync=False))
return self._queue_handle(AsyncHandle(proc, args, self))

# Callback handling #

Expand All @@ -331,7 +283,7 @@ def _queue_handle(self, handle):
def _call_soon(self, *arks, **kwargs):
raise RuntimeError("_call_soon() should not have been called")

def call_later(self, delay, callback, *args, context=None):
def call_later(self, delay, callback, *args, **context):
"""asyncio's timer-based delay

Note that the callback is a sync function.
Expand All @@ -342,36 +294,36 @@ def call_later(self, delay, callback, *args, context=None):
"""
self._check_callback(callback, 'call_later')
assert delay >= 0, delay
h = TimerHandle(delay + self.time(), callback, args, self, context=context, is_sync=True)
h = asyncio.TimerHandle(delay + self.time(), callback, args, self, **context)
self._queue_handle(h)
return h

def call_at(self, when, callback, *args, context=None):
def call_at(self, when, callback, *args, **context):
"""asyncio's time-based delay

Note that the callback is a sync function.
"""
self._check_callback(callback, 'call_at')
return self._queue_handle(
TimerHandle(when, callback, args, self, context=context, is_sync=True)
asyncio.TimerHandle(when, callback, args, self, **context)
)

def call_soon(self, callback, *args, context=None):
def call_soon(self, callback, *args, **context):
"""asyncio's defer-to-mainloop callback executor.

Note that the callback is a sync function.
"""
self._check_callback(callback, 'call_soon')
return self._queue_handle(Handle(callback, args, self, context=context, is_sync=True))
return self._queue_handle(asyncio.Handle(callback, args, self, **context))

def call_soon_threadsafe(self, callback, *args, context=None):
def call_soon_threadsafe(self, callback, *args, **context):
"""asyncio's thread-safe defer-to-mainloop

Note that the callback is a sync function.
"""
self._check_callback(callback, 'call_soon_threadsafe')
self._check_closed()
h = Handle(callback, args, self, context=context, is_sync=True)
h = asyncio.Handle(callback, args, self, **context)
self._token.run_sync_soon(self._q_send.send_nowait, h)

# drop all timers
Expand Down Expand Up @@ -471,7 +423,7 @@ async def synchronize(self):

"""
w = trio.Event()
self._queue_handle(Handle(w.set, (), self, is_sync=True))
self._queue_handle(asyncio.Handle(w.set, (), self))
await w.wait()

# Signal handling #
Expand All @@ -488,7 +440,7 @@ def add_signal_handler(self, sig, callback, *args):
self._check_signal(sig)
if sig == signal.SIGKILL:
raise RuntimeError("SIGKILL cannot be caught")
h = Handle(callback, args, self, context=None, is_sync=True)
h = asyncio.Handle(callback, args, self)
assert sig not in self._signal_handlers, \
"Signal %d is already being caught" % (sig,)
self._orig_signals[sig] = signal.signal(sig, self._handle_sig)
Expand Down Expand Up @@ -528,7 +480,7 @@ def add_reader(self, fd, callback, *args):

def _add_reader(self, fd, callback, *args):
self._check_closed()
handle = Handle(callback, args, self, context=None, is_sync=True)
handle = ScopedHandle(callback, args, self)
reader = self._set_read_handle(fd, handle)
if reader is not None:
reader.cancel()
Expand All @@ -547,20 +499,17 @@ def _set_read_handle(self, fd, handle):
self._selector.modify(fd, mask | EVENT_READ, (handle, writer))
return reader

async def _reader_loop(self, fd, handle, task_status=trio.TASK_STATUS_IGNORED):
task_status.started()
with trio.CancelScope() as scope:
handle._scope = scope
async def _reader_loop(self, fd, handle):
with handle._scope:
try:
while not handle._cancelled: # pragma: no branch
while True:
await _wait_readable(fd)
handle._call_sync()
if handle._cancelled:
break
handle._run()
await self.synchronize()
except Exception as exc:
_h_raise(handle, exc)
return
finally:
handle._scope = None
handle._raise(exc)

# writing to a file descriptor

Expand All @@ -583,7 +532,7 @@ def add_writer(self, fd, callback, *args):

def _add_writer(self, fd, callback, *args):
self._check_closed()
handle = Handle(callback, args, self, context=None, is_sync=True)
handle = ScopedHandle(callback, args, self)
writer = self._set_write_handle(fd, handle)
if writer is not None:
writer.cancel()
Expand All @@ -601,20 +550,17 @@ def _set_write_handle(self, fd, handle):
self._selector.modify(fd, mask | EVENT_WRITE, (reader, handle))
return writer

async def _writer_loop(self, fd, handle, task_status=trio.TASK_STATUS_IGNORED):
with trio.CancelScope() as scope:
handle._scope = scope
task_status.started()
async def _writer_loop(self, fd, handle):
with handle._scope:
try:
while not handle._cancelled: # pragma: no branch
while True:
await _wait_writable(fd)
handle._call_sync()
if handle._cancelled:
break
handle._run()
await self.synchronize()
except Exception as exc:
_h_raise(handle, exc)
return
finally:
handle._scope = None
handle._raise(exc)

def autoclose(self, fd):
"""
Expand Down Expand Up @@ -717,7 +663,7 @@ async def _main_loop_one(self, no_wait=False):
# so restart from the beginning.
return

if isinstance(obj, TimerHandle):
if isinstance(obj, asyncio.TimerHandle):
# A TimerHandle is added to the list of timers.
heapq.heappush(self._timers, obj)
return
Expand All @@ -732,13 +678,17 @@ async def _main_loop_one(self, no_wait=False):

# Don't go through the expensive nursery dance
# if this is a sync function.
if getattr(obj, '_is_sync', True):
if isinstance(obj, AsyncHandle):
if hasattr(obj, '_context'):
obj._context.run(self._nursery.start_soon, obj._run, name=obj._callback)
else:
self._nursery.start_soon(obj._run, name=obj._callback)
await obj._started.wait()
else:
if hasattr(obj, '_context'):
obj._context.run(obj._callback, *obj._args)
else:
obj._callback(*obj._args)
else:
await self._nursery.start(obj._call_async)

async def _main_loop_exit(self):
"""Finalize the loop. It may not be re-entered."""
Expand Down
Loading