

Merge pull request #79 from esm-tools/cleanup/config
Config Cleanup
pgierz authored Nov 29, 2024
2 parents d1bea94 + eb5fb02 commit a6ac277
Showing 11 changed files with 223 additions and 98 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/CI-test.yaml
@@ -43,17 +43,17 @@ jobs:
run: |
export HDF5_DEBUG=1
export NETCDF_DEBUG=1
export XARRAY_BACKEND=h5netcdf
export XARRAY_ENGINE=h5netcdf
export PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=300
pytest -vvv -s --cov tests/meta/*.py
- name: Test with pytest (Unit)
run: |
export XARRAY_BACKEND=h5netcdf
export XARRAY_ENGINE=h5netcdf
export PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=300
pytest -vvv -s --cov tests/unit/*.py
- name: Test with pytest (Integration)
run: |
export XARRAY_BACKEND=h5netcdf
export XARRAY_ENGINE=h5netcdf
export PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=300
pytest -vvv -s --cov tests/integration/*.py
- name: Test with doctest
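For context, a minimal sketch of how test code might pick up the renamed variable, assuming the tests choose the xarray backend from the environment (the helper name open_test_dataset is illustrative and not part of the repository):

import os
import xarray as xr

def open_test_dataset(path):
    # Read the renamed XARRAY_ENGINE variable, falling back to the
    # h5netcdf engine exported in the CI jobs above.
    engine = os.environ.get("XARRAY_ENGINE", "h5netcdf")
    return xr.open_dataset(path, engine=engine)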
3 changes: 3 additions & 0 deletions examples/.gitignore
@@ -0,0 +1,3 @@
*.nc
slurm*.out
pymorize_report.log
43 changes: 22 additions & 21 deletions examples/cleanup.py
@@ -6,6 +6,22 @@
from pathlib import Path


def rm_file(fname):
try:
fname.unlink()
print(f"Removed file: {fname}")
except Exception as e:
print(f"Error removing file {fname}: {e}")


def rm_dir(dirname):
try:
shutil.rmtree(dirname)
print(f"Removed directory: {dirname}")
except Exception as e:
print(f"Error removing directory {dirname}: {e}")


def cleanup():
current_dir = Path.cwd()

@@ -15,34 +31,19 @@ def cleanup():
and item.name.startswith("slurm")
and item.name.endswith("out")
):
try:
item.unlink()
print(f"Removed file: {item}")
except Exception as e:
print(f"Error removing file {item}: {e}")
rm_file(item)
if (
item.is_file()
and item.name.startswith("pymorize")
and item.name.endswith("json")
):
try:
item.unlink()
print(f"Removed file: {item}")
except Exception as e:
print(f"Error removing file {item}: {e}")
rm_file(item)
if item.is_file() and item.name.endswith("nc"):
try:
item.unlink()
print(f"Removed file: {item}")
except Exception as e:
print(f"Error removing file {item}: {e}")

rm_file(item)
if item.name == "pymorize_report.log":
rm_file(item)
elif item.is_dir() and item.name == "logs":
try:
shutil.rmtree(item)
print(f"Removed directory: {item}")
except Exception as e:
print(f"Error removing directory {item}: {e}")
rm_dir(item)
print("Cleanup completed.")


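A short usage sketch of the extracted helpers (illustrative only; the scratch file name is made up):

from pathlib import Path

scratch = Path("scratch.nc")
scratch.touch()
rm_file(scratch)   # prints "Removed file: scratch.nc"
rm_file(scratch)   # already gone, so the except branch reports the error instead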
7 changes: 4 additions & 3 deletions examples/pymorize.slurm
@@ -1,9 +1,10 @@
#!/bin/bash -l
#SBATCH --account=ab0246
#SBATCH --job-name=pymorize-controller # <<< This is the main job; it will launch subjobs if you have Dask enabled.
#SBATCH --account=ab0246 # <<< Adapt this to your computing account!
#SBATCH --partition=compute
#SBATCH --nodes=1
#SBATCH --time=00:30:00
# export PREFECT_SERVER_ALLOW_EPHEMERAL_MODE=False
#SBATCH --time=00:30:00 # <<< You may need more time, adapt as needed!
export PREFECT_SERVER_ALLOW_EPHEMERAL_MODE=True
export PREFECT_SERVER_API_HOST=0.0.0.0
conda activate pymorize
prefect server start &
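The same Prefect settings can be mirrored from Python when the controller is run interactively rather than through sbatch; a sketch under that assumption, not part of the repository:

import os

# Match the batch script: allow Prefect's ephemeral server and bind the
# API to all interfaces so the dashboard is reachable from outside the node.
os.environ.setdefault("PREFECT_SERVER_ALLOW_EPHEMERAL_MODE", "True")
os.environ.setdefault("PREFECT_SERVER_API_HOST", "0.0.0.0")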
7 changes: 5 additions & 2 deletions examples/sample.yaml
@@ -10,11 +10,14 @@ pymorize:
# parallel: True
warn_on_no_rule: False
use_flox: True
cluster_mode: fixed
dask_cluster: "slurm"
dask_cluster_scaling_mode: fixed
fixed_jobs: 12
# minimum_jobs: 8
# maximum_jobs: 30
dimensionless_mapping_table: ../data/dimensionless_mappings.yaml
# You can add your own path to the dimensionless mapping table
# If nothing is specified here, it will use the built-in one.
# dimensionless_mapping_table: ../data/dimensionless_mappings.yaml
rules:
- name: paul_example_rule
description: "You can put some text here"
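A minimal sketch of reading the renamed cluster options from this file (key names as they appear above; the file path and defaults are illustrative):

import yaml

with open("examples/sample.yaml") as f:
    cfg = yaml.safe_load(f)["pymorize"]

cluster_kind = cfg.get("dask_cluster", "slurm")              # "local" or "slurm"
scaling_mode = cfg.get("dask_cluster_scaling_mode", "adapt")  # "adapt" or "fixed"
fixed_jobs = cfg.get("fixed_jobs", 12)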
9 changes: 9 additions & 0 deletions src/pymorize/cluster.py
@@ -3,9 +3,18 @@
"""

import dask
from dask.distributed import LocalCluster
from dask_jobqueue import SLURMCluster

from .logging import logger

CLUSTER_MAPPINGS = {
"local": LocalCluster,
"slurm": SLURMCluster,
}
CLUSTER_SCALE_SUPPORT = {"local": False, "slurm": True}
CLUSTER_ADAPT_SUPPORT = {"local": False, "slurm": True}


def set_dashboard_link(cluster):
"""
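These lookup tables let callers resolve a cluster class from its configured name and check what it supports; a sketch of the intended use (the make_cluster helper is illustrative, the real wiring lives in cmorizer.py below):

from pymorize.cluster import CLUSTER_MAPPINGS, CLUSTER_SCALE_SUPPORT

def make_cluster(name, jobs=5):
    # Resolve "local" -> LocalCluster, "slurm" -> SLURMCluster, then only
    # call scale() when the support table says the cluster allows it.
    cluster_cls = CLUSTER_MAPPINGS[name]
    cluster = cluster_cls()
    if CLUSTER_SCALE_SUPPORT[name]:
        cluster.scale(jobs=jobs)
    return cluster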
86 changes: 58 additions & 28 deletions src/pymorize/cmorizer.py
@@ -10,13 +10,13 @@
import xarray as xr # noqa: F401
import yaml
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from everett.manager import generate_uppercase_key, get_runtime_config
from prefect import flow, task
from prefect.futures import wait
from rich.progress import track

from .cluster import set_dashboard_link
from .cluster import (CLUSTER_ADAPT_SUPPORT, CLUSTER_MAPPINGS,
CLUSTER_SCALE_SUPPORT, set_dashboard_link)
from .config import PymorizeConfig, PymorizeConfigManager
from .data_request import (DataRequest, DataRequestTable, DataRequestVariable,
IgnoreTableFiles)
@@ -88,17 +88,21 @@ def __init__(

################################################################################
# Post_Init:
if self._pymorize_cfg("parallel"):
if self._pymorize_cfg("parallel_backend") == "dask":
self._post_init_configure_dask()
self._post_init_create_dask_cluster()
if self._pymorize_cfg("enable_dask"):
logger.debug("Setting up dask configuration...")
self._post_init_configure_dask()
logger.debug("...done!")
logger.debug("Creating dask cluster...")
self._post_init_create_dask_cluster()
logger.debug("...done!")
self._post_init_create_pipelines()
self._post_init_create_rules()
self._post_init_read_bare_tables()
self._post_init_create_data_request()
self._post_init_populate_rules_with_tables()
self._post_init_read_dimensionless_unit_mappings()
self._post_init_data_request_variables()
logger.debug("...post-init done!")
################################################################################

def _post_init_configure_dask(self):
@@ -120,29 +124,42 @@ def _post_init_configure_dask(self):

def _post_init_create_dask_cluster(self):
# FIXME: In the future, we can support PBS, too.
logger.info("Setting up SLURMCluster...")
self._cluster = SLURMCluster()
logger.info("Setting up dask cluster...")
cluster_name = self._pymorize_cfg("dask_cluster")
ClusterClass = CLUSTER_MAPPINGS[cluster_name]
self._cluster = ClusterClass()
set_dashboard_link(self._cluster)
cluster_mode = self._pymorize_cfg.get("cluster_mode", "adapt")
if cluster_mode == "adapt":
min_jobs = self._pymorize_cfg.get("minimum_jobs", 1)
max_jobs = self._pymorize_cfg.get("maximum_jobs", 10)
self._cluster.adapt(minimum_jobs=min_jobs, maximum_jobs=max_jobs)
elif cluster_mode == "fixed":
jobs = self._pymorize_cfg.get("fixed_jobs", 5)
self._cluster.scale(jobs=jobs)
cluster_scaling_mode = self._pymorize_cfg.get(
"dask_cluster_scaling_mode", "adapt"
)
if cluster_scaling_mode == "adapt":
if CLUSTER_ADAPT_SUPPORT[cluster_name]:
min_jobs = self._pymorize_cfg.get(
"dask_cluster_scaling_minimum_jobs", 1
)
max_jobs = self._pymorize_cfg.get(
"dask_cluster_scaling_maximum_jobs", 10
)
self._cluster.adapt(minimum_jobs=min_jobs, maximum_jobs=max_jobs)
else:
logger.warning(f"{self._cluster} does not support adaptive scaling!")
elif cluster_scaling_mode == "fixed":
if CLUSTER_SCALE_SUPPORT[cluster_name]:
jobs = self._pymorize_cfg.get("dask_cluster_scaling_fixed_jobs", 5)
self._cluster.scale(jobs=jobs)
else:
logger.warning(f"{self._cluster} does not support fixed scaling!")
else:
raise ValueError(
"You need to specify adapt or fixed for pymorize.cluster_mode"
"You need to specify adapt or fixed for pymorize.dask_cluster_scaling_mode"
)
# Wait for at least min_jobs to be available...
# FIXME: Client needs to be available here?
logger.info(f"SLURMCluster can be found at: {self._cluster=}")
# FIXME: Include the gateway option if possible
# FIXME: Does ``Client`` need to be available here?
logger.info(f"Cluster can be found at: {self._cluster=}")
logger.info(f"Dashboard {self._cluster.dashboard_link}")
# NOTE(PG): In CI context, os.getlogin and nodename may not be available (???)

username = getpass.getuser()
nodename = getattr(os.uname(), "nodename", "UNKNOWN")
# FIXME: Include the gateway option if possible
logger.info(
"To see the dashboards run the following command in your computer's "
"terminal:\n"
Expand All @@ -152,7 +169,7 @@ def _post_init_create_dask_cluster(self):

dask_extras = 0
logger.info("Importing Dask Extras...")
if self._pymorize_cfg.get("use_flox", True):
if self._pymorize_cfg.get("enable_flox", True):
dask_extras += 1
logger.info("...flox...")
import flox # noqa: F401
@@ -337,7 +354,9 @@ def validate(self):
# self._check_rules_for_output_dir()
# FIXME(PS): Turn off this check, see GH #59 (https://tinyurl.com/3z7d8uuy)
# self._check_is_subperiod()
logger.debug("Starting validate....")
self._check_units()
logger.debug("...done!")

def _check_is_subperiod(self):
logger.info("checking frequency in netcdf file and in table...")
@@ -443,6 +462,7 @@ def from_dict(cls, data):
instance._post_init_create_data_request()
instance._post_init_data_request_variables()
instance._post_init_read_dimensionless_unit_mappings()
logger.debug("Object creation done!")
return instance

def add_rule(self, rule):
@@ -509,16 +529,23 @@ def check_rules_for_output_dir(self, output_dir):
logger.warning(filepath)

def process(self, parallel=None):
logger.debug("Process start!")
if parallel is None:
parallel = self._pymorize_cfg.get("parallel", True)
if parallel:
parallel_backend = self._pymorize_cfg.get("parallel_backend", "prefect")
return self.parallel_process(backend=parallel_backend)
logger.debug("Parallel processing...")
# FIXME(PG): This is mixed up, hard-coding to prefect for now...
workflow_backend = self._pymorize_cfg.get(
"pipeline_orchestrator", "prefect"
)
logger.debug(f"...with {workflow_backend}...")
return self.parallel_process(backend=workflow_backend)
else:
return self.serial_process()

def parallel_process(self, backend="prefect"):
if backend == "prefect":
logger.debug("About to submit _parallel_process_prefect()")
return self._parallel_process_prefect()
elif backend == "dask":
return self._parallel_process_dask()
@@ -529,6 +556,8 @@ def _parallel_process_prefect(self):
# prefect_logger = get_run_logger()
# logger = prefect_logger
# @flow(task_runner=DaskTaskRunner(address=self._cluster.scheduler_address))
logger.debug("Defining dynamically generated prefect workflow...")

@flow
def dynamic_flow():
rule_results = []
@@ -537,6 +566,9 @@ def dynamic_flow():
wait(rule_results)
return rule_results

logger.debug("...done!")

logger.debug("About to return dynamic_flow()...")
return dynamic_flow()

def _parallel_process_dask(self, external_client=None):
@@ -567,13 +599,11 @@ def _process_rule(self, rule):
# FIXME(PG): This might also be a place we need to consider copies...
rule.match_pipelines(self.pipelines)
data = None
# NOTE(PG): Send in a COPY of the rule, not the original rule
local_rule_copy = copy.deepcopy(rule)
if not len(rule.pipelines) > 0:
logger.error("No pipeline defined, something is wrong!")
for pipeline in rule.pipelines:
logger.info(f"Running {str(pipeline)}")
data = pipeline.run(data, local_rule_copy)
data = pipeline.run(data, rule)
return data

@task
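For reference, the configuration keys consulted by _post_init_create_dask_cluster after this rename, collected in one place and shown as a plain dict (values taken from the fallbacks in the code above and from sample.yaml, for illustration only):

pymorize_dask_options = {
    "enable_dask": True,                    # replaces the old parallel / parallel_backend gate
    "dask_cluster": "slurm",                # key into CLUSTER_MAPPINGS ("local" or "slurm")
    "dask_cluster_scaling_mode": "adapt",   # "adapt" or "fixed"
    "dask_cluster_scaling_minimum_jobs": 1,
    "dask_cluster_scaling_maximum_jobs": 10,
    "dask_cluster_scaling_fixed_jobs": 5,
}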

