diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml
index 311eeaae829..457fd864454 100644
--- a/distributed/distributed.yaml
+++ b/distributed/distributed.yaml
@@ -24,6 +24,7 @@ distributed:
     pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings
     preload: []
     preload-argv: []
+    unknown-task-duration: 0.5s # Default duration for all tasks with unknown durations ("15m", "2h")
     default-task-durations: # How long we expect function names to run ("1h", "1s") (helps for long tasks)
       rechunk-split: 1us
       shuffle-split: 1us
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index f99c26d9aba..8bd8ce45a71 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -3693,7 +3693,7 @@ def get_comm_cost(self, ts, ws):
         """
         return sum(dts.nbytes for dts in ts.dependencies - ws.has_what) / self.bandwidth
 
-    def get_task_duration(self, ts, default=0.5):
+    def get_task_duration(self, ts, default=None):
         """
         Get the estimated computation cost of the given task
         (not including any communication cost).
@@ -3701,6 +3701,10 @@ def get_task_duration(self, ts, default=0.5):
         duration = ts.prefix.duration_average
         if duration is None:
             self.unknown_durations[ts.prefix.name].add(ts)
+            if default is None:
+                default = parse_timedelta(
+                    dask.config.get("distributed.scheduler.unknown-task-duration")
+                )
             return default
 
         return duration
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 1068169b200..d7ad5ec9388 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -2057,3 +2057,14 @@ async def test_worker_name_collision(s, a):
     s.validate_state()
     assert set(s.workers) == {a.address}
     assert s.aliases == {a.name: a.address}
+
+
+@gen_cluster(client=True, config={"distributed.scheduler.unknown-task-duration": "1h"})
+async def test_unknown_task_duration_config(client, s, a, b):
+    future = client.submit(inc, 1)
+    while not s.tasks:
+        await asyncio.sleep(0.001)
+    assert sum(s.get_task_duration(ts) for ts in s.tasks.values()) == 3600
+    assert len(s.unknown_durations) == 1
+    await wait(future)
+    assert len(s.unknown_durations) == 0