Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support asynchronous tasks #5151

Merged
merged 5 commits into from
Aug 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,6 +1556,9 @@ def submit(
Parameters
----------
func : callable
Callable to be scheduled as ``func(*args, **kwargs)``. If ``func`` is a coroutine function,
it will be run on the main event loop of a worker. Otherwise ``func`` will be run
in a worker's task executor pool (see ``Worker.executors`` for more information).
*args
**kwargs
pure : bool (defaults to True)
Expand Down Expand Up @@ -1674,6 +1677,9 @@ def map(
Parameters
----------
func : callable
Callable to be scheduled for execution. If ``func`` is a coroutine function,
it will be run on the main event loop of a worker. Otherwise ``func`` will be run
in a worker's task executor pool (see ``Worker.executors`` for more information).
iterables : Iterables
List-like objects to map over. They should have the same length.
key : str, list
Expand Down
17 changes: 14 additions & 3 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from distributed.utils import is_valid_xml, mp_context, sync, tmp_text, tmpfile
from distributed.utils_test import (
TaskStateMetadataPlugin,
_UnhashableCallable,
async_wait_for,
asyncinc,
captured_logger,
Expand Down Expand Up @@ -5595,9 +5596,9 @@ async def test_warn_when_submitting_large_values(c, s, a, b):

@gen_cluster(client=True)
async def test_unhashable_function(c, s, a, b):
d = {"a": 1}
result = await c.submit(d.get, "a")
assert result == 1
func = _UnhashableCallable()
result = await c.submit(func, 1)
Comment on lines +5599 to +5600
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is because dict.get became hashable starting with Python 3.8. This is why test_unhashable_function was only failing in our Python 3.7 CI builds

assert result == 2


@gen_cluster()
Expand Down Expand Up @@ -6923,3 +6924,13 @@ def f():
assert results[n.worker_address] == 123

assert files == set(os.listdir()) # no change


@gen_cluster(client=True)
async def test_async_task(c, s, a, b):
    """Submitting a coroutine function as a task yields its awaited result."""

    async def f(x):
        return x + 1

    # ``f`` is a coroutine function, so the future must resolve to the value
    # produced by awaiting ``f(10)`` rather than to a coroutine object.
    future = c.submit(f, 10)
    result = await future
    assert result == 11
16 changes: 15 additions & 1 deletion distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_traceback,
is_kernel,
is_valid_xml,
iscoroutinefunction,
nbytes,
offload,
open_port,
Expand All @@ -41,7 +42,15 @@
truncate_exception,
warn_on_duration,
)
from distributed.utils_test import captured_logger, div, gen_test, has_ipv6, inc, throws
from distributed.utils_test import (
_UnhashableCallable,
captured_logger,
div,
gen_test,
has_ipv6,
inc,
throws,
)


def test_All(loop):
Expand Down Expand Up @@ -584,3 +593,8 @@ def test_parse_timedelta_deprecated():
with pytest.warns(FutureWarning, match="parse_timedelta is deprecated"):
from distributed.utils import parse_timedelta
assert parse_timedelta is dask.utils.parse_timedelta


def test_iscoroutinefunction_unhashable_input():
    """``iscoroutinefunction`` must not raise for an unhashable callable."""
    # Ensure iscoroutinefunction can handle unhashable callables; the helper
    # instance is a plain (non-coroutine) callable, so the answer is falsy.
    assert not iscoroutinefunction(_UnhashableCallable())
8 changes: 8 additions & 0 deletions distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from distributed.metrics import time
from distributed.utils import get_ip
from distributed.utils_test import (
_UnhashableCallable,
cluster,
gen_cluster,
gen_test,
Expand Down Expand Up @@ -269,3 +270,10 @@ async def test_tls_scheduler(security, cleanup):
security=security, host="localhost", dashboard_address=":0"
) as s:
assert s.address.startswith("tls")


def test__UnhashableCallable():
    """The test helper is callable (adds one) but cannot be hashed."""
    func = _UnhashableCallable()
    assert func(1) == 2
    # Hashing must fail with the standard "unhashable type" TypeError.
    with pytest.raises(TypeError, match="unhashable"):
        hash(func)
14 changes: 13 additions & 1 deletion distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,9 +1126,21 @@ def color_of(x, palette=palette):
return palette[n % len(palette)]


def _iscoroutinefunction(f):
    """Uncached check: is ``f`` a coroutine function?

    True when either ``inspect.iscoroutinefunction`` or
    ``gen.is_coroutine_function`` recognises ``f``.
    """
    return inspect.iscoroutinefunction(f) or gen.is_coroutine_function(f)


# The cache is unbounded (``maxsize=None``) and keyed on the callable itself,
# so it only works for hashable callables.
@functools.lru_cache(None)
def _iscoroutinefunction_cached(f):
    """Memoized variant of ``_iscoroutinefunction`` for hashable callables."""
    return _iscoroutinefunction(f)


def iscoroutinefunction(f):
    """Return whether ``f`` is a coroutine function.

    Parameters
    ----------
    f : callable
        Object to inspect. It does not need to be hashable.

    Returns
    -------
    bool-like
        Result of ``_iscoroutinefunction(f)``, served from an ``lru_cache``
        whenever ``f`` can be used as a cache key.
    """
    # Attempt to use lru_cache version and fall back to non-cached version if needed
    try:
        return _iscoroutinefunction_cached(f)
    except TypeError:  # unhashable type
        return _iscoroutinefunction(f)


@contextmanager
Expand Down
7 changes: 7 additions & 0 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,13 @@ def slowidentity(*args, **kwargs):
return args


class _UnhashableCallable:
    """Test helper: a callable whose instances cannot be hashed.

    Calling an instance returns its argument plus one.
    """

    # Setting ``__hash__`` to None makes ``hash(instance)`` raise TypeError.
    __hash__ = None

    def __call__(self, x):
        return x + 1


def run_for(duration, timer=time):
"""
Burn CPU for *duration* seconds.
Expand Down
45 changes: 44 additions & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2910,7 +2910,14 @@ async def execute(self, key):
try:
e = self.executors[executor]
ts.start_time = time()
if "ThreadPoolExecutor" in str(type(e)):
if iscoroutinefunction(function):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The addition of this check revealed that since our iscoroutinefunction utility function is wrapped by functools.lru_cache for performance reasons, iscoroutinefunction raised an error when it encountered an unhashable function (which occurred in one of our tests). I pushed a commit which adds some fallback logic to make sure iscoroutinefunction can handle unhashable inputs.

result = await apply_function_async(
function,
args2,
kwargs2,
self.scheduler_delay,
)
elif "ThreadPoolExecutor" in str(type(e)):
result = await self.loop.run_in_executor(
e,
apply_function,
Expand Down Expand Up @@ -3885,6 +3892,42 @@ def apply_function_simple(
return msg


async def apply_function_async(
    function,
    args,
    kwargs,
    time_delay,
):
    """Run a function, collect information

    Parameters
    ----------
    function : callable
        Called as ``function(*args, **kwargs)``; the object it returns is
        awaited.
    args : sequence
        Positional arguments for ``function``.
    kwargs : mapping
        Keyword arguments for ``function``.
    time_delay : number
        Offset added to the recorded start and stop timestamps.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    start = time()
    try:
        result = await function(*args, **kwargs)
    except Exception as e:
        msg = error_message(e)
        msg["op"] = "task-erred"
        msg["actual-exception"] = e
    else:
        msg = {
            "op": "task-finished",
            "status": "OK",
            "result": result,
            "nbytes": sizeof(result),
            "type": type(result) if result is not None else None,
        }
    finally:
        # Only the stop time is captured here. ``msg`` is completed and
        # returned after the ``try`` statement so that anything not caught
        # by ``except Exception`` propagates instead of being replaced by a
        # ``return`` inside ``finally``.
        end = time()
    msg["start"] = start + time_delay
    msg["stop"] = end + time_delay
    msg["thread"] = ident
    return msg


def apply_function_actor(
function, args, kwargs, execution_state, key, active_threads, active_threads_lock
):
Expand Down