22
22
import warnings
23
23
import weakref
24
24
from collections import defaultdict
25
+ from collections .abc import Callable
25
26
from contextlib import contextmanager , nullcontext , suppress
26
27
from glob import glob
27
28
from itertools import count
54
55
from .diagnostics .plugin import WorkerPlugin
55
56
from .metrics import time
56
57
from .nanny import Nanny
58
+ from .node import ServerNode
57
59
from .proctitle import enable_proctitle_on_children
58
60
from .security import Security
59
61
from .utils import (
@@ -770,7 +772,7 @@ async def disconnect_all(addresses, timeout=3, rpc_kwargs=None):
770
772
await asyncio .gather (* (disconnect (addr , timeout , rpc_kwargs ) for addr in addresses ))
771
773
772
774
773
- def gen_test (timeout = _TEST_TIMEOUT ):
775
+ def gen_test (timeout : float = _TEST_TIMEOUT ) -> Callable [[ Callable ], Callable ] :
774
776
"""Coroutine test
775
777
776
778
@gen_test(timeout=5)
@@ -797,14 +799,14 @@ def test_func():
797
799
798
800
799
801
async def start_cluster (
800
- nthreads ,
801
- scheduler_addr ,
802
- loop ,
803
- security = None ,
804
- Worker = Worker ,
805
- scheduler_kwargs = {},
806
- worker_kwargs = {},
807
- ):
802
+ nthreads : list [ tuple [ str , int ] | tuple [ str , int , dict ]] ,
803
+ scheduler_addr : str ,
804
+ loop : IOLoop ,
805
+ security : Security | dict [ str , Any ] | None = None ,
806
+ Worker : type [ ServerNode ] = Worker ,
807
+ scheduler_kwargs : dict [ str , Any ] = {},
808
+ worker_kwargs : dict [ str , Any ] = {},
809
+ ) -> tuple [ Scheduler , list [ ServerNode ]] :
808
810
s = await Scheduler (
809
811
loop = loop ,
810
812
validate = True ,
@@ -813,6 +815,7 @@ async def start_cluster(
813
815
host = scheduler_addr ,
814
816
** scheduler_kwargs ,
815
817
)
818
+
816
819
workers = [
817
820
Worker (
818
821
s .address ,
@@ -822,7 +825,11 @@ async def start_cluster(
822
825
loop = loop ,
823
826
validate = True ,
824
827
host = ncore [0 ],
825
- ** (merge (worker_kwargs , ncore [2 ]) if len (ncore ) > 2 else worker_kwargs ),
828
+ ** (
829
+ merge (worker_kwargs , ncore [2 ]) # type: ignore
830
+ if len (ncore ) > 2
831
+ else worker_kwargs
832
+ ),
826
833
)
827
834
for i , ncore in enumerate (nthreads )
828
835
]
@@ -854,21 +861,24 @@ async def end_worker(w):
854
861
855
862
856
863
def gen_cluster (
857
- nthreads = [("127.0.0.1" , 1 ), ("127.0.0.1" , 2 )],
858
- ncores = None ,
864
+ nthreads : list [tuple [str , int ] | tuple [str , int , dict ]] = [
865
+ ("127.0.0.1" , 1 ),
866
+ ("127.0.0.1" , 2 ),
867
+ ],
868
+ ncores : None = None , # deprecated
859
869
scheduler = "127.0.0.1" ,
860
- timeout = _TEST_TIMEOUT ,
861
- security = None ,
862
- Worker = Worker ,
863
- client = False ,
864
- scheduler_kwargs = {},
865
- worker_kwargs = {},
866
- client_kwargs = {},
867
- active_rpc_timeout = 1 ,
868
- config = {},
869
- clean_kwargs = {},
870
- allow_unclosed = False ,
871
- ):
870
+ timeout : float = _TEST_TIMEOUT ,
871
+ security : Security | dict [ str , Any ] | None = None ,
872
+ Worker : type [ ServerNode ] = Worker ,
873
+ client : bool = False ,
874
+ scheduler_kwargs : dict [ str , Any ] = {},
875
+ worker_kwargs : dict [ str , Any ] = {},
876
+ client_kwargs : dict [ str , Any ] = {},
877
+ active_rpc_timeout : float = 1 ,
878
+ config : dict [ str , Any ] = {},
879
+ clean_kwargs : dict [ str , Any ] = {},
880
+ allow_unclosed : bool = False ,
881
+ ) -> Callable [[ Callable ], Callable ] :
872
882
from distributed import Client
873
883
874
884
""" Coroutine test with small cluster
0 commit comments