Skip to content

Commit

Permalink
fix(cmorizer): better constructor for parallelization
Browse files Browse the repository at this point in the history
  • Loading branch information
pgierz committed Nov 12, 2024
1 parent e88b52a commit cc4f159
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
26 changes: 14 additions & 12 deletions src/pymorize/cmorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@
from prefect_dask import DaskTaskRunner
from rich.progress import track

from .data_request import (
DataRequest,
DataRequestTable,
DataRequestVariable,
IgnoreTableFiles,
)
from .data_request import (DataRequest, DataRequestTable, DataRequestVariable,
IgnoreTableFiles)
from .filecache import fc
from .logging import logger
from .pipeline import Pipeline
Expand Down Expand Up @@ -45,8 +41,11 @@ def __init__(
self.rules = rules_cfg or []
self.pipelines = pipelines_cfg or []

self._post_init_configure_dask()
self._post_init_create_dask_cluster()
self._cluster = None # Dask Cluster, might be set up later
if self._pymorize_cfg.get("parallel", True):
if pymorize_cfg.get("parallel_backend") == "dask":
self._post_init_configure_dask()
self._post_init_create_dask_cluster()
self._post_init_create_pipelines()
self._post_init_create_rules()
self._post_init_read_bare_tables()
Expand Down Expand Up @@ -235,7 +234,8 @@ def _post_init_create_pipelines(self):
pipelines.append(p)
elif isinstance(p, dict):
pl = Pipeline.from_dict(p)
pl.assign_cluster(self._cluster)
if self._cluster is not None:
pl.assign_cluster(self._cluster)
pipelines.append(Pipeline.from_dict(p))
else:
raise ValueError(f"Invalid pipeline configuration for {p}")
Expand Down Expand Up @@ -329,8 +329,9 @@ def add_rule(self, rule):
def add_pipeline(self, pipeline):
    """Append a pipeline to ``self.pipelines``.

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline to register.

    Raises
    ------
    TypeError
        If ``pipeline`` is not an instance of ``Pipeline``.
    """
    if not isinstance(pipeline, Pipeline):
        raise TypeError("pipeline must be an instance of Pipeline")
    # ``self._cluster`` stays ``None`` when no Dask cluster was created, so
    # only hand a cluster to the pipeline when one actually exists.
    if self._cluster is not None:
        # Assign the cluster to this pipeline:
        pipeline.assign_cluster(self._cluster)
    self.pipelines.append(pipeline)

def _rule_for_filepath(self, filepath):
Expand Down Expand Up @@ -387,7 +388,8 @@ def process(self, parallel=None):
if parallel is None:
parallel = self._pymorize_cfg.get("parallel", True)
if parallel:
return self.parallel_process()
parallel_backend = self._pymorize_cfg.get("parallel_backend", "prefect")
return self.parallel_process(backend=parallel_backend)
else:
return self.serial_process()

Expand Down
9 changes: 8 additions & 1 deletion src/pymorize/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,18 @@ def _run_prefect(self, data, rule_spec):
)
cmor_name = rule_spec.get("cmor_name")
rule_name = rule_spec.get("name", cmor_name)
if self._cluster is None:
logger.warning(
"No cluster assigned to this pipeline. Using local Dask cluster."
)
dask_scheduler_address = None
else:
dask_scheduler_address = self._cluster.scheduler

@flow(
flow_run_name=f"{self.name} - {rule_name}",
description=f"{rule_spec.get('description', '')}",
task_runner=DaskTaskRunner(address=self._cluster.scheduler_address),
task_runner=DaskTaskRunner(address=dask_scheduler_address),
on_completion=[self.on_completion],
on_failure=[self.on_failure],
)
Expand Down

0 comments on commit cc4f159

Please sign in to comment.