Skip to content

Commit

Permalink
Use universal _chain from asyncio
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 12, 2024
1 parent 6184c11 commit e32ffa2
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 66 deletions.
2 changes: 1 addition & 1 deletion src/plumpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from .process_listener import *
from .process_states import *
from .processes import *
from .rmq import *
from .utils import *
from .workchains import *
from .rmq import *

__all__ = (
events.__all__
Expand Down
1 change: 0 additions & 1 deletion src/plumpy/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,3 @@ class PersistenceError(Exception):

class ClosedError(Exception):
"""Raised when an mutable operation is attempted on a closed process"""

6 changes: 3 additions & 3 deletions src/plumpy/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import asyncio
import contextlib
from typing import Any, Awaitable, Callable, Optional
from typing import Any, Awaitable, Callable, Generator, Optional

__all__ = ['create_task', 'CancellableAction', 'create_task']
__all__ = ['CancellableAction', 'create_task', 'create_task']


class InvalidFutureError(Exception):
Expand All @@ -18,7 +18,7 @@ class InvalidFutureError(Exception):


@contextlib.contextmanager
def capture_exceptions(future: Future[Any], ignore: tuple[type[BaseException], ...] = ()):
def capture_exceptions(future: Future[Any], ignore: tuple[type[BaseException], ...] = ()) -> Generator[None, Any, None]:
"""
Capture any exceptions in the context and set them as the result of the given future
Expand Down
4 changes: 1 addition & 3 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,12 +448,10 @@ def execute_process(
:param no_reply: if True, this call will be fire-and-forget, i.e. no return value
:return: the result of executing the process
"""
from plumpy.rmq.futures import unwrap_kiwi_future

message = create_create_body(process_class, init_args, init_kwargs, persist=True, loader=loader)

execute_future = kiwipy.Future()
create_future = unwrap_kiwi_future(self._communicator.task_send(message))
create_future = self._communicator.task_send(message)

def on_created(_: Any) -> None:
with kiwipy.capture_exceptions(execute_future):
Expand Down
3 changes: 2 additions & 1 deletion src/plumpy/rmq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from .exceptions import *
# mypy: disable-error-code=name-defined
from .communications import *
from .exceptions import *

__all__ = exceptions.__all__ + communications.__all__
29 changes: 2 additions & 27 deletions src/plumpy/rmq/communications.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
import kiwipy

from plumpy import futures
from plumpy.rmq.futures import wrap_to_kiwi_future
from plumpy.utils import ensure_coroutine

__all__ = [
'Communicator',
'DeliveryFailed',
'RemoteException',
'TaskRejected',
'plum_to_kiwi_future',
'wrap_communicator',
]

Expand All @@ -36,31 +36,6 @@
BroadcastSubscriber = Callable[[kiwipy.Communicator, Any, Any, Any, ID_TYPE], Any]


def plum_to_kiwi_future(plum_future: futures.Future) -> kiwipy.Future:
"""
Return a kiwi future that resolves to the outcome of the plum future
:param plum_future: the plum future
:return: the kiwipy future
"""
kiwi_future = kiwipy.Future()

def on_done(_plum_future: futures.Future) -> None:
with kiwipy.capture_exceptions(kiwi_future):
if plum_future.cancelled():
kiwi_future.cancel()
else:
result = plum_future.result()
# Did we get another future? In which case convert it too
if isinstance(result, futures.Future):
result = plum_to_kiwi_future(result)
kiwi_future.set_result(result)

plum_future.add_done_callback(on_done)
return kiwi_future


def convert_to_comm(
callback: 'Subscriber', loop: Optional[asyncio.AbstractEventLoop] = None
) -> Callable[..., kiwipy.Future]:
Expand Down Expand Up @@ -97,7 +72,7 @@ def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> k

msg_fn = functools.partial(coro, communicator, *args, **kwargs)
task_future = futures.create_task(msg_fn, loop)
return plum_to_kiwi_future(task_future)
return wrap_to_kiwi_future(task_future)

return converted

Expand Down
2 changes: 1 addition & 1 deletion src/plumpy/rmq/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed
import kiwipy
from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed

__all__ = [
'CommunicatorChannelInvalidStateError',
Expand Down
45 changes: 16 additions & 29 deletions src/plumpy/rmq/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,25 @@
Module containing future related methods and classes
"""

import kiwipy

__all__ = ['chain', 'copy_future', 'unwrap_kiwi_future']

copy_future = kiwipy.copy_future
chain = kiwipy.chain
import asyncio
import concurrent.futures
from asyncio.futures import _chain_future, _copy_future_state # type: ignore[attr-defined]
from typing import Any

import kiwipy

def unwrap_kiwi_future(future: kiwipy.Future) -> kiwipy.Future:
"""
Create a kiwi future that represents the final results of a nested series of futures,
meaning that if the futures provided itself resolves to a future the returned
future will not resolve to a value until the final chain of futures is not a future
but a concrete value. If at any point in the chain a future resolves to an exception
then the returned future will also resolve to that exception.
__all__ = ['chain', 'copy_future', 'wrap_to_kiwi_future']

:param future: the future to unwrap
:return: the unwrapping future
copy_future = _copy_future_state
chain = _chain_future

"""
unwrapping = kiwipy.Future()

def unwrap(fut: kiwipy.Future) -> None:
if fut.cancelled():
unwrapping.cancel()
else:
with kiwipy.capture_exceptions(unwrapping):
result = fut.result()
if isinstance(result, kiwipy.Future):
result.add_done_callback(unwrap)
else:
unwrapping.set_result(result)
def wrap_to_kiwi_future(future: asyncio.Future[Any]) -> kiwipy.Future:
"""Wrap to concurrent.futures.Future object."""
if isinstance(future, concurrent.futures.Future):
return future
assert asyncio.isfuture(future), f'concurrent.futures.Future is expected, got {future!r}'

future.add_done_callback(unwrap)
return unwrapping
new_future = kiwipy.Future()
_chain_future(future, new_future)
return new_future

0 comments on commit e32ffa2

Please sign in to comment.