Skip to content

Commit

Permalink
Introduce config for default task duration (#3642)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabriel Sailer authored Mar 27, 2020
1 parent 3fceec6 commit eda27be
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
1 change: 1 addition & 0 deletions distributed/distributed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ distributed:
pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings
preload: []
preload-argv: []
unknown-task-duration: 500ms # Default duration for all tasks with unknown durations ("15m", "2h")
default-task-durations: # How long we expect function names to run ("1h", "1s") (helps for long tasks)
rechunk-split: 1us
shuffle-split: 1us
Expand Down
6 changes: 5 additions & 1 deletion distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3700,14 +3700,18 @@ def get_comm_cost(self, ts, ws):
"""
return sum(dts.nbytes for dts in ts.dependencies - ws.has_what) / self.bandwidth

def get_task_duration(self, ts, default=0.5):
def get_task_duration(self, ts, default=None):
"""
Get the estimated computation cost of the given task
(not including any communication cost).
"""
duration = ts.prefix.duration_average
if duration is None:
self.unknown_durations[ts.prefix.name].add(ts)
if default is None:
default = parse_timedelta(
dask.config.get("distributed.scheduler.unknown-task-duration")
)
return default

return duration
Expand Down
11 changes: 11 additions & 0 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,3 +2071,14 @@ async def test_worker_name_collision(s, a):
s.validate_state()
assert set(s.workers) == {a.address}
assert s.aliases == {a.name: a.address}


@gen_cluster(client=True, config={"distributed.scheduler.unknown-task-duration": "1h"})
async def test_unknown_task_duration_config(client, s, a, b):
future = client.submit(slowinc, 1)
while not s.tasks:
await asyncio.sleep(0.001)
assert sum(s.get_task_duration(ts) for ts in s.tasks.values()) == 3600
assert len(s.unknown_durations) == 1
await wait(future)
assert len(s.unknown_durations) == 0

0 comments on commit eda27be

Please sign in to comment.