Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dask to run tasks #1714

Closed
wants to merge 31 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
63b032a
Use dask to run preprocessing tasks (and disable other tasks for now)
bouweandela Sep 2, 2022
4792193
Improve use of dask client as suggested by @zklaus
bouweandela Sep 13, 2022
11b12f1
Restore sequential and parallel task run
bouweandela Sep 13, 2022
6faa8f6
Add missing compute argument to save function for multimodel statisti…
bouweandela Sep 23, 2022
49e3589
Add distributed as a dependency
bouweandela Sep 23, 2022
54edb73
Restore previous API to fix tests
bouweandela Sep 23, 2022
021b3e0
Add xarray as a dependency
bouweandela Sep 23, 2022
53aa585
Merge branch 'main' into dask-distributed
bouweandela Sep 23, 2022
d5cf4f6
Merge branch 'main' of github.com:ESMValGroup/ESMValCore into dask-di…
bouweandela Oct 14, 2022
0f5bda8
Support configuring dask
bouweandela Oct 14, 2022
52af816
Add a suffix to output_directory if it exists instead of stopping
bouweandela Oct 16, 2022
823c731
Fix tests
bouweandela Oct 18, 2022
6f5a6bf
single call to compute
fnattino Nov 7, 2022
b455bbb
Only start cluster if necessary and fix filename-future mapping
bouweandela Nov 10, 2022
05d69f7
Merge branch 'main' of github.com:ESMValGroup/ESMValCore into dask-di…
bouweandela Nov 21, 2022
be991f8
Use iris (https://github.com/SciTools/iris/pull/5031) for saving
bouweandela Nov 21, 2022
93b6da1
Merge branch 'main' of github.com:ESMValGroup/ESMValCore into dask-di…
bouweandela Nov 22, 2022
37bb757
Point to iris branch
bouweandela Dec 5, 2022
2b60264
Merge branch 'main' into dask-distributed
bouweandela Dec 5, 2022
66bdf2e
Merge branch 'main' of github.com:ESMValGroup/ESMValCore into dask-di…
bouweandela Mar 30, 2023
065b8a4
Work in progress
bouweandela Apr 4, 2023
f47d6da
Update branch name
bouweandela Apr 4, 2023
6e25901
Remove type hint
bouweandela Apr 4, 2023
e176745
Get iris from main branch
bouweandela Apr 21, 2023
d6787fa
Merge branch 'main' into dask-distributed
bouweandela Apr 21, 2023
222c7e5
Add default scheduler
bouweandela Apr 28, 2023
b760abb
Use release candidate
bouweandela May 11, 2023
6b35638
Try to install iris from PyPI
bouweandela May 11, 2023
cf79d20
Merge branch 'main' of github.com:ESMValGroup/ESMValCore into dask-di…
bouweandela May 12, 2023
4348953
Update dependencies
bouweandela May 19, 2023
6bf6328
Remove pip install of ESMValTool_sample_data
bouweandela May 23, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
- cftime
- compilers
- dask
- distributed
- esgf-pyclient>=0.3.1
- esmpy!=8.1.0
- filelock
Expand All @@ -18,7 +19,7 @@ dependencies:
- geopy
- humanfriendly
- importlib_resources
- iris>=3.4.0
- iris>=3.6.0
- iris-esmf-regrid >=0.6.0 # to work with latest esmpy
- isodate
- jinja2
Expand Down
5 changes: 3 additions & 2 deletions esmvalcore/_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import os
from functools import total_ordering
from pathlib import Path

from netCDF4 import Dataset
from PIL import Image
Expand Down Expand Up @@ -104,15 +105,15 @@ class TrackedFile:
"""File with provenance tracking."""

def __init__(self,
filename,
filename: Path,
attributes=None,
ancestors=None,
prov_filename=None):
"""Create an instance of a file with provenance tracking.

Arguments
---------
filename: str
filename:
Path to the file on disk.
attributes: dict
Dictionary with facets describing the file. If set to None, this
Expand Down
10 changes: 8 additions & 2 deletions esmvalcore/_recipe/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,10 @@ def _get_default_settings(dataset):
settings['remove_supplementary_variables'] = {}

# Configure saving cubes to file
settings['save'] = {'compress': session['compress_netcdf']}
settings['save'] = {
'compress': session['compress_netcdf'],
'compute': session['max_parallel_tasks'] != 0,
}
if facets['short_name'] != facets['original_short_name']:
settings['save']['alias'] = facets['short_name']

Expand Down Expand Up @@ -537,6 +540,9 @@ def _get_downstream_settings(step, order, products):
if key in remaining_steps:
if all(p.settings.get(key, object()) == value for p in products):
settings[key] = value
save = dict(some_product.settings.get('save', {}))
save.pop('filename', None)
settings['save'] = save
return settings


Expand Down Expand Up @@ -1305,7 +1311,7 @@ def run(self):
if self.session['search_esgf'] != 'never':
esgf.download(self._download_files, self.session['download_dir'])

self.tasks.run(max_parallel_tasks=self.session['max_parallel_tasks'])
self.tasks.run(self.session)
logger.info(
"Wrote recipe with version numbers and wildcards "
"to:\nfile://%s", filled_recipe)
Expand Down
198 changes: 181 additions & 17 deletions esmvalcore/_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""ESMValtool task definition."""
from __future__ import annotations

import abc
import contextlib
import datetime
import importlib
import logging
import numbers
import os
Expand All @@ -15,15 +18,21 @@
from multiprocessing import Pool
from pathlib import Path, PosixPath
from shutil import which
from typing import Optional
from typing import TYPE_CHECKING

import dask
import dask.distributed
import psutil
import yaml

from ._citation import _write_citation_files
from ._provenance import TrackedFile, get_task_provenance
from .config import Session
from .config._diagnostics import DIAGNOSTICS, TAGS

if TYPE_CHECKING:
from esmvalcore.preprocessor import PreprocessingTask


def path_representer(dumper, data):
"""For printing pathlib.Path objects in yaml files."""
Expand Down Expand Up @@ -191,7 +200,9 @@ def _ncl_type(value):
lines = []

# ignore some settings for NCL diagnostic
ignore_settings = ['profile_diagnostic', ]
ignore_settings = [
'profile_diagnostic',
]
for sett in ignore_settings:
settings_copy = dict(settings)
if 'diag_script_info' not in settings_copy:
Expand Down Expand Up @@ -414,7 +425,9 @@ def write_settings(self):
run_dir.mkdir(parents=True, exist_ok=True)

# ignore some settings for diagnostic
ignore_settings = ['profile_diagnostic', ]
ignore_settings = [
'profile_diagnostic',
]
for sett in ignore_settings:
settings_copy = dict(self.settings)
settings_copy.pop(sett, None)
Expand Down Expand Up @@ -694,6 +707,54 @@ def __repr__(self):
return string


@contextlib.contextmanager
def get_distributed_client(session):
"""Get a Dask distributed client."""
dask_args = session.get('dask', {})
client_args = dask_args.get('client', {}).copy()
cluster_args = dask_args.get('cluster', {}).copy()

# Start a cluster, if requested
if 'address' in client_args:
# Use an externally managed cluster.
cluster = None
if cluster_args:
logger.warning(
"Not using 'dask: cluster' settings because a cluster "
"'address' is already provided in 'dask: client'.")
elif cluster_args:
# Start cluster.
cluster_type = cluster_args.pop(
'type',
'dask.distributed.LocalCluster',
)
cluster_module_name, cluster_cls_name = cluster_type.rsplit('.', 1)
cluster_module = importlib.import_module(cluster_module_name)
cluster_cls = getattr(cluster_module, cluster_cls_name)
cluster = cluster_cls(**cluster_args)
client_args['address'] = cluster.scheduler_address
else:
# No cluster configured, use Dask default scheduler, or a LocalCluster
# managed through Client.
cluster = None

# Start a client, if requested
if dask_args:
client = dask.distributed.Client(**client_args)
logger.info(f"Dask dashboard: {client.dashboard_link}")
else:
logger.info("Using the Dask default scheduler.")
client = None

try:
yield client
finally:
if client is not None:
client.close()
if cluster is not None:
cluster.close()


class TaskSet(set):
"""Container for tasks."""

Expand All @@ -710,18 +771,101 @@ def get_independent(self) -> 'TaskSet':
independent_tasks.add(task)
return independent_tasks

def run(self, max_parallel_tasks: Optional[int] = None) -> None:
def run(self, session: Session) -> None:
"""Run tasks.

Parameters
----------
max_parallel_tasks : int
Number of processes to run. If `1`, run the tasks sequentially.
session : esmvalcore.config.Session
Session.
"""
if max_parallel_tasks == 1:
self._run_sequential()
else:
self._run_parallel(max_parallel_tasks)
with get_distributed_client(session) as client:
if client is None:
scheduler_address = None
else:
scheduler_address = client.scheduler.address
for task in self.flatten():
if (isinstance(task, DiagnosticTask)
and Path(task.script).suffix.lower() == '.py'):
# Only use the scheduler address if running a
# Python script.
task.settings['scheduler_address'] = scheduler_address

max_parallel_tasks = session['max_parallel_tasks']
if max_parallel_tasks == 0:
if client is None:
raise ValueError(
"Unable to run tasks using Dask distributed without a "
"configured dask client. Please edit config-user.yml "
"to configure dask.")
self._run_distributed(client)
elif max_parallel_tasks == 1:
self._run_sequential()
else:
self._run_parallel(scheduler_address, max_parallel_tasks)

def _run_distributed(self, client: dask.distributed.Client) -> None:
"""Run tasks using Dask Distributed."""
client.forward_logging()
tasks = sorted((t for t in self.flatten()), key=lambda t: t.priority)

# Create a graph for dask.array operations in PreprocessingTasks
preprocessing_tasks = [t for t in tasks if hasattr(t, 'delayeds')]

futures_to_preproc_tasks: dict[dask.distributed.Future,
PreprocessingTask] = {}
for task in preprocessing_tasks:
future = client.submit(_run_preprocessing_task,
task,
priority=-task.priority)
futures_to_preproc_tasks[future] = task

for future in dask.distributed.as_completed(futures_to_preproc_tasks):
task = futures_to_preproc_tasks[future]
_copy_preprocessing_results(task, future)

# Launch dask.array compute operations for PreprocessingTasks
futures_to_files: dict[dask.distributed.Future, Path] = {}
for task in preprocessing_tasks:
logger.info(f"Computing task {task.name}")
futures = client.compute(
list(task.delayeds.values()),
priority=-task.priority,
)
futures_to_files.update(zip(futures, task.delayeds))

# Start computing DiagnosticTasks as soon as the relevant
# PreprocessingTasks complete
waiting = [t for t in tasks if t not in preprocessing_tasks]
futures_to_tasks: dict[dask.distributed.Future, BaseTask] = {}
done_files = set()
done_tasks = set()
iterator = dask.distributed.as_completed(futures_to_files)
for future in iterator:
if future in futures_to_files:
filename = futures_to_files[future]
logger.info(f"Wrote (delayed) {filename}")
done_files.add(filename)
# Check if a PreprocessingTask has finished
for preproc_task in preprocessing_tasks:
filenames = set(preproc_task.delayeds)
if filenames.issubset(done_files):
done_tasks.add(preproc_task)
elif future in futures_to_tasks:
# Check if a ResumeTask or DiagnosticTask has finished
task = futures_to_tasks[future]
_copy_distributed_results(task, future)
done_tasks.add(task)

# Schedule any new tasks that can be scheduled
for task in waiting:
if set(task.ancestors).issubset(done_tasks):
future = client.submit(_run_task,
task,
priority=-task.priority)
iterator.add(future)
futures_to_tasks[future] = task
waiting.pop(waiting.index(task))

def _run_sequential(self) -> None:
"""Run tasks sequentially."""
Expand All @@ -732,7 +876,7 @@ def _run_sequential(self) -> None:
for task in sorted(tasks, key=lambda t: t.priority):
task.run()

def _run_parallel(self, max_parallel_tasks=None):
def _run_parallel(self, scheduler_address, max_parallel_tasks=None):
"""Run tasks in parallel."""
scheduled = self.flatten()
running = {}
Expand All @@ -757,14 +901,15 @@ def done(task):
if len(running) >= max_parallel_tasks:
break
if all(done(t) for t in task.ancestors):
future = pool.apply_async(_run_task, [task])
future = pool.apply_async(_run_task,
[task, scheduler_address])
running[task] = future
scheduled.remove(task)

# Handle completed tasks
ready = {t for t in running if running[t].ready()}
for task in ready:
_copy_results(task, running[task])
_copy_multiprocessing_results(task, running[task])
running.pop(task)

# Wait if there are still tasks running
Expand All @@ -785,12 +930,31 @@ def done(task):
pool.join()


def _copy_results(task, future):
def _run_task(task, scheduler_address=None):
"""Run task and return the result."""
if scheduler_address is None:
output_files = task.run()
else:
with dask.distributed.Client(scheduler_address):
output_files = task.run()
return output_files, task.products


def _copy_distributed_results(task, future):
"""Update task with the results from the dask worker."""
task.output_files, task.products = future.result()


def _copy_multiprocessing_results(task, future):
"""Update task with the results from the remote process."""
task.output_files, task.products = future.get()


def _run_task(task):
"""Run task and return the result."""
def _run_preprocessing_task(task):
output_files = task.run()
return output_files, task.products
return output_files, task.products, task.delayeds


def _copy_preprocessing_results(task, future):
"""Update task with the results from the dask worker."""
task.output_files, task.products, task.delayeds = future.result()
1 change: 1 addition & 0 deletions esmvalcore/config/_config_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def validate_diagnostics(
'auxiliary_data_dir': validate_path,
'compress_netcdf': validate_bool,
'config_developer_file': validate_config_developer,
'dask': validate_dict,
'download_dir': validate_path,
'drs': validate_drs,
'exit_on_warning': validate_bool,
Expand Down
Loading