
Commit 1585f85

Authored Oct 19, 2021
Type annotations for Worker and gen_cluster (#5438)
1 parent: 7d2516a

File tree

6 files changed, +254 -129 lines changed


distributed/profile.py: +4 -1

@@ -24,12 +24,15 @@
     'children': {...}}}
 }
 """
+from __future__ import annotations
+
 import bisect
 import linecache
 import sys
 import threading
 from collections import defaultdict, deque
 from time import sleep
+from typing import Any
 
 import tlz as toolz
 
@@ -152,7 +155,7 @@ def merge(*args):
 }
 
 
-def create():
+def create() -> dict[str, Any]:
     return {
         "count": 0,
         "children": {},
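The `from __future__ import annotations` import is what makes the new `dict[str, Any]` return annotation viable: at the time of this commit, distributed still supported Python versions older than 3.9, where subscripting the builtin `dict` in an eagerly evaluated annotation raises a TypeError. A minimal standalone sketch of the pattern (the function body is abbreviated, not copied from the module):

```python
# With PEP 563 semantics, annotations are stored as strings and never
# evaluated at runtime, so `dict[str, Any]` parses fine on Python 3.7/3.8.
from __future__ import annotations

from typing import Any


def create() -> dict[str, Any]:
    # Abbreviated profile node; the real create() carries more fields.
    return {"count": 0, "children": {}}


print(create())  # {'count': 0, 'children': {}}
```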
distributed/tests/test_scheduler.py: +1 -1

@@ -1947,7 +1947,7 @@ class NoSchedulerDelayWorker(Worker):
     comparisons using times reported from workers.
     """
 
-    @property
+    @property  # type: ignore
     def scheduler_delay(self):
         return 0
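The added `# type: ignore` is the usual workaround for a mypy restriction: a subclass may not override a writeable base-class attribute with a read-only property. A hedged sketch of the conflict with a stand-in base class (that the real `Worker` exposes `scheduler_delay` as a plain attribute is the assumption this example encodes):

```python
class BaseWorker:
    def __init__(self) -> None:
        self.scheduler_delay = 0.5  # plain, writeable attribute


class NoDelayWorker(BaseWorker):
    # Without the ignore comment, mypy reports an error along the lines of
    # "Cannot override writeable attribute with read-only property".
    @property  # type: ignore
    def scheduler_delay(self):
        return 0
```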
distributed/tests/test_stress.py: -5

@@ -99,9 +99,6 @@ async def test_stress_creation_and_deletion(c, s):
     # Assertions are handled by the validate mechanism in the scheduler
     da = pytest.importorskip("dask.array")
 
-    def _disable_suspicious_counter(dask_worker):
-        dask_worker._suspicious_count_limit = None
-
     rng = da.random.RandomState(0)
     x = rng.random(size=(2000, 2000), chunks=(100, 100))
     y = ((x + 1).T + (x * 2) - x.mean(axis=1)).sum().round(2)
@@ -111,14 +108,12 @@ async def create_and_destroy_worker(delay):
         start = time()
         while time() < start + 5:
             async with Nanny(s.address, nthreads=2) as n:
-                await c.run(_disable_suspicious_counter, workers=[n.worker_address])
                 await asyncio.sleep(delay)
                 print("Killed nanny")
 
     await asyncio.gather(*(create_and_destroy_worker(0.1 * i) for i in range(20)))
 
     async with Nanny(s.address, nthreads=2):
-        await c.run(_disable_suspicious_counter)
         assert await c.compute(z) == 8000884.93
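For context on the deleted helper: `Client.run` executes a plain function directly on workers, and a parameter literally named `dask_worker` is filled in with the local `Worker` instance, which is how `_disable_suspicious_counter` reached worker internals. A small sketch of that mechanism (the cluster setup and the inspected attribute are illustrative, not part of this commit):

```python
from dask.distributed import Client


def get_nthreads(dask_worker):
    # Client.run injects the running Worker instance into a parameter
    # named `dask_worker`.
    return dask_worker.nthreads


if __name__ == "__main__":
    client = Client(n_workers=2)  # assumes a local cluster can be started
    print(client.run(get_nthreads))  # e.g. {'tcp://127.0.0.1:...': 1, ...}
    client.close()
```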
distributed/tests/test_worker.py: +1 -1

@@ -2968,7 +2968,7 @@ async def test_who_has_consistent_remove_replica(c, s, *workers):
 
     await f2
 
-    assert ("missing-dep", f1.key) in a.story(f1.key)
+    assert (f1.key, "missing-dep") in a.story(f1.key)
     assert a.tasks[f1.key].suspicious_count == 0
     assert s.tasks[f1.key].suspicious == 0
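The corrected assertion matches the field order of the worker's story log in this test: the task key comes first, followed by the event name. An illustrative membership check with made-up entries (the story contents here are assumptions; real story entries typically carry extra fields such as a stimulus id and a timestamp):

```python
# Hypothetical, simplified story log for a single task.
story = [
    ("inc-abc123", "missing-dep"),
    ("inc-abc123", "ready", "executing"),
]

assert ("inc-abc123", "missing-dep") in story      # key first: matches
assert ("missing-dep", "inc-abc123") not in story  # old order: no match
```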
distributed/utils_test.py: +34 -24

@@ -22,6 +22,7 @@
 import warnings
 import weakref
 from collections import defaultdict
+from collections.abc import Callable
 from contextlib import contextmanager, nullcontext, suppress
 from glob import glob
 from itertools import count
@@ -54,6 +55,7 @@
 from .diagnostics.plugin import WorkerPlugin
 from .metrics import time
 from .nanny import Nanny
+from .node import ServerNode
 from .proctitle import enable_proctitle_on_children
 from .security import Security
 from .utils import (
@@ -770,7 +772,7 @@ async def disconnect_all(addresses, timeout=3, rpc_kwargs=None):
     await asyncio.gather(*(disconnect(addr, timeout, rpc_kwargs) for addr in addresses))
 
 
-def gen_test(timeout=_TEST_TIMEOUT):
+def gen_test(timeout: float = _TEST_TIMEOUT) -> Callable[[Callable], Callable]:
     """Coroutine test
 
     @gen_test(timeout=5)
@@ -797,14 +799,14 @@ def test_func():
 
 
 async def start_cluster(
-    nthreads,
-    scheduler_addr,
-    loop,
-    security=None,
-    Worker=Worker,
-    scheduler_kwargs={},
-    worker_kwargs={},
-):
+    nthreads: list[tuple[str, int] | tuple[str, int, dict]],
+    scheduler_addr: str,
+    loop: IOLoop,
+    security: Security | dict[str, Any] | None = None,
+    Worker: type[ServerNode] = Worker,
+    scheduler_kwargs: dict[str, Any] = {},
+    worker_kwargs: dict[str, Any] = {},
+) -> tuple[Scheduler, list[ServerNode]]:
     s = await Scheduler(
         loop=loop,
         validate=True,
@@ -813,6 +815,7 @@ async def start_cluster(
         host=scheduler_addr,
         **scheduler_kwargs,
     )
+
     workers = [
         Worker(
             s.address,
@@ -822,7 +825,11 @@
             loop=loop,
             validate=True,
             host=ncore[0],
-            **(merge(worker_kwargs, ncore[2]) if len(ncore) > 2 else worker_kwargs),
+            **(
+                merge(worker_kwargs, ncore[2])  # type: ignore
+                if len(ncore) > 2
+                else worker_kwargs
+            ),
         )
         for i, ncore in enumerate(nthreads)
     ]
@@ -854,21 +861,24 @@ async def end_worker(w):
 
 
 def gen_cluster(
-    nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2)],
-    ncores=None,
+    nthreads: list[tuple[str, int] | tuple[str, int, dict]] = [
+        ("127.0.0.1", 1),
+        ("127.0.0.1", 2),
+    ],
+    ncores: None = None,  # deprecated
     scheduler="127.0.0.1",
-    timeout=_TEST_TIMEOUT,
-    security=None,
-    Worker=Worker,
-    client=False,
-    scheduler_kwargs={},
-    worker_kwargs={},
-    client_kwargs={},
-    active_rpc_timeout=1,
-    config={},
-    clean_kwargs={},
-    allow_unclosed=False,
-):
+    timeout: float = _TEST_TIMEOUT,
+    security: Security | dict[str, Any] | None = None,
+    Worker: type[ServerNode] = Worker,
+    client: bool = False,
+    scheduler_kwargs: dict[str, Any] = {},
+    worker_kwargs: dict[str, Any] = {},
+    client_kwargs: dict[str, Any] = {},
+    active_rpc_timeout: float = 1,
+    config: dict[str, Any] = {},
+    clean_kwargs: dict[str, Any] = {},
+    allow_unclosed: bool = False,
+) -> Callable[[Callable], Callable]:
     from distributed import Client
 
     """ Coroutine test with small cluster

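For reference, here is how the two annotated decorators are used. The `@gen_test(timeout=5)` form is taken straight from the docstring visible in the hunk above; the `gen_cluster` test is the canonical shape implied by its defaults, where `client=True` prepends a `Client` to the test's arguments and the two workers come from the default `nthreads`. A sketch:

```python
from distributed.utils_test import gen_cluster, gen_test


@gen_test(timeout=5)
async def test_without_cluster():
    # gen_test only runs the coroutine on a fresh event loop; no cluster.
    assert 1 + 1 == 2


@gen_cluster(client=True)
async def test_with_cluster(c, s, a, b):
    # c: Client, s: Scheduler, a and b: the two workers created from the
    # default nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2)].
    result = await c.submit(lambda x: x + 1, 41)
    assert result == 42
```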