ADD: WandB option (#25)
* fix ray init and add cleanup to slurm jobs

* add wandb and define dashboard port in slurm

* fix thread error

* group wandb by experiment_name

* remove comment
adamovanja authored Aug 13, 2024
1 parent 2874c70 commit 7c6061e
Showing 6 changed files with 79 additions and 27 deletions.
15 changes: 13 additions & 2 deletions README.md
@@ -29,16 +29,27 @@ python q2_ritme/eval_best_trial_overall.py --model_path "experiments/models"
````

## Model training on HPC with slurm:
Edit file `launch_slurm_syn_cpu.sh` and then run
Edit file `launch_slurm_cpu.sh` and then run
````
sbatch launch_slurm_syn_cpu.sh
sbatch launch_slurm_cpu.sh
````
If you (or your collaborators) plan to launch multiple jobs on the same infrastructure, set the variable `JOB_NB` in `launch_slurm_cpu.sh` accordingly. This variable ensures that the ports assigned to parallel jobs don't overlap (see [here](https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html#slurm-networking-caveats)), as illustrated below.
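
For illustration, here is a sketch of how `JOB_NB` shifts the Ray port ranges; the formulas are the ones used in `launch_slurm_cpu.sh`, and the comments assume `JOB_NB=1`:
````
# sketch: port derivation in launch_slurm_cpu.sh for JOB_NB=1
JOB_NB=1
ray_client_server_port=$((1 + JOB_NB * 10000))   # 10001
redis_shard_ports=$((6602 + JOB_NB * 100))       # 6702
min_worker_port=$((2 + JOB_NB * 10000))          # 10002
max_worker_port=$((9999 + JOB_NB * 10000))       # 19999
dashboard_port=$((8265 + JOB_NB))                # 8266
````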

If you are using SLURM and the following error is returned: "RuntimeError: can't start new thread",
it is probably caused by the limits of your hardware. Try decreasing the number of CPUs allocated to the job and/or decreasing the variable `max_concurrent_trials` in `tune_models.py`.
If you are using SLURM and your error message contains "The process is killed by SIGKILL by OOM killer due to high memory usage", increase the memory assigned per CPU (`--mem-per-cpu`), for example as sketched below.
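
The memory per CPU is set in the `#SBATCH` header of the launch script; the value below is only a placeholder, adjust it to your cluster:
````
# hypothetical example: request more memory per CPU in launch_slurm_cpu.sh
#SBATCH --mem-per-cpu=8G
````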

## Model tracking
In the config file you can choose to track your trials with MLflow (`tracking_uri`: "mlruns") or with WandB (`tracking_uri`: "wandb"). If you use WandB, you need to store your `WANDB_API_KEY` & `WANDB_ENTITY` as environment variables in `.env`. Make sure to exclude this file from version control (add it to `.gitignore`)!

The `WANDB_ENTITY` is the WandB user or team name under which the results are stored. For more information on this parameter see the official WandB documentation on initialization [here](https://docs.wandb.ai/ref/python/init).
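
A minimal `.env` sketch (the keys are the ones read from the environment; the values are placeholders):
````
# .env - keep this file out of version control (add it to .gitignore)
WANDB_API_KEY=your-api-key-here
WANDB_ENTITY=your-wandb-entity
````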

Also, if you are running WandB from an HPC, you might need to point the proxy variables to your respective proxy URLs by exporting them:
````
export HTTPS_PROXY=http://proxy.example.com:8080
export HTTP_PROXY=http://proxy.example.com:8080
````

## Code test coverage
To run test coverage with Code Gutters in VScode run:
````
5 changes: 5 additions & 0 deletions ci/recipe/meta.yaml
@@ -26,10 +26,12 @@ requirements:
- lightning
- mlflow
- numpy
- optuna
- packaging
- pandas
- pip
- pytorch
- python-dotenv
- py-xgboost
# todo: update ray to newest once Q2 has migrated to Python 3.10
# note: currently ray is in v2.8.1 restricted by Q2
@@ -48,6 +50,9 @@ requirements:
- c-lasso
# grpcio pinned due to incompatibility with ray caused by c-lasso
- grpcio==1.51.1
# to enable insights in ray dashboard
- py-spy
- wandb


test:
10 changes: 10 additions & 0 deletions launch_slurm_cpu.sh
@@ -21,6 +21,10 @@ CONFIG="q2_ritme/run_config.json"
# -> only these values are allowed: 1, 2, 3 - for other values the derived
# -> ports below are already taken or not allowed
JOB_NB=1

# if your number of threads is limited, increase these limits as needed
ulimit -u 60000
ulimit -n 524288
# ! USER END __________

# __doc_head_address_start__
@@ -55,6 +59,7 @@ ray_client_server_port=$((1 + JOB_NB * 10000))
redis_shard_ports=$((6602 + JOB_NB * 100))
min_worker_port=$((2 + JOB_NB * 10000))
max_worker_port=$((9999 + JOB_NB * 10000))
dashboard_port=$((8265 + JOB_NB))

ip_head=$head_node_ip:$port
export ip_head
@@ -70,6 +75,7 @@ srun --nodes=1 --ntasks=1 -w "$head_node" \
--redis-shard-ports=$redis_shard_ports \
--min-worker-port=$min_worker_port \
--max-worker-port=$max_worker_port \
--dashboard-port=$dashboard_port \
--num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_TASK:-0}" --block &
# __doc_head_ray_end__

@@ -90,6 +96,10 @@ for ((i = 1; i <= worker_num; i++)); do
done
# __doc_worker_ray_end__

# Output the dashboard URL
dashboard_url="http://${head_node_ip}:${dashboard_port}"
export RAY_DASHBOARD_URL="$dashboard_url"
echo "Ray Dashboard URL: $RAY_DASHBOARD_URL"

# __doc_script_start__
python -u q2_ritme/run_n_eval_tune.py --config $CONFIG
2 changes: 1 addition & 1 deletion q2_ritme/run_config.json
@@ -10,7 +10,6 @@
"nn_corn",
"rf"
],
"mlflow_tracking_uri": "mlruns",
"models_to_evaluate_separately": [
"xgb",
"nn_reg",
@@ -26,5 +25,6 @@
"seed_model": 12,
"target": "age_months",
"test_mode": false,
"tracking_uri": "wandb",
"train_size": 0.8
}
13 changes: 10 additions & 3 deletions q2_ritme/run_n_eval_tune.py
@@ -42,8 +42,15 @@ def run_n_eval_tune(config_path):
f"This experiment tag already exists: {config['experiment_tag']}."
"Please use another one."
)

path_mlflow = os.path.join("experiments", config["mlflow_tracking_uri"])
if config["tracking_uri"] == "mlruns":
path_tracker = os.path.join("experiments", config["tracking_uri"])
elif config["tracking_uri"] == "wandb":
path_tracker = "wandb"
else:
raise ValueError(
f"Invalid tracking_uri: {config['tracking_uri']}. Must be "
f"'mlruns' or 'wandb'."
)
path_exp = os.path.join(base_path, config["experiment_tag"])

# ! Load and split data
@@ -67,7 +74,7 @@
config["seed_model"],
tax,
tree_phylo,
path_mlflow,
path_tracker,
path_exp,
# number of trials to run per model type * grid_search parameters in
# @_static_searchspace
61 changes: 40 additions & 21 deletions q2_ritme/tune_models.py
@@ -1,12 +1,14 @@
import os
import random

import dotenv
import numpy as np
import pandas as pd
import skbio
import torch
from ray import air, init, shutdown, tune
from ray import air, init, tune
from ray.air.integrations.mlflow import MLflowLoggerCallback
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.tune.schedulers import AsyncHyperBandScheduler, HyperBandScheduler

from q2_ritme.model_space import static_searchspace as ss
@@ -34,7 +36,7 @@ def get_slurm_resource(resource_name, default_value=0):


def run_trials(
mlflow_tracking_uri, # MLflow with MLflowLoggerCallback
tracking_uri,
exp_name,
trainable,
search_space,
@@ -52,9 +54,8 @@
scheduler_max_t=100,
resources=None,
):
# todo: this 10 is imposed by my HPC system - should be made flexible (20
# could also be possible)
max_concurrent_trials = 10
    # since each trial starts its own threads, this should not be set too high
max_concurrent_trials = min(num_trials, 5)
if resources is None:
# if not a slurm process: default values are used
all_cpus_avail = get_slurm_resource("SLURM_CPUS_PER_TASK", 1)
@@ -75,19 +76,17 @@
# - xgb, nn_reg, nn_class, nn_corn: parallel processing supported with GPU
# support

if not os.path.exists(mlflow_tracking_uri):
os.makedirs(mlflow_tracking_uri)

# set seed for search algorithms/schedulers
random.seed(seed_model)
np.random.seed(seed_model)
torch.manual_seed(seed_model)

# Initialize Ray with the runtime environment
shutdown()
    # shutdown()  # can't be used when launching on HPC with an externally
    # started ray instance
# todo: configure dashboard here - see "ray dashboard set up" online once
# todo: ray (Q2>Py) is updated
context = init(include_dashboard=False, ignore_reinit_error=True)
context = init(address="auto", include_dashboard=False, ignore_reinit_error=True)
print(context.dashboard_url)
# note: both schedulers might decide to run more trials than allocated
if not fully_reproducible:
@@ -109,6 +108,36 @@

storage_path = os.path.abspath(path2exp)
experiment_tag = os.path.basename(path2exp)
# define callbacks
if tracking_uri.endswith("mlruns"):
if not os.path.exists(tracking_uri):
os.makedirs(tracking_uri)
callbacks = [
MLflowLoggerCallback(
tracking_uri=tracking_uri,
experiment_name=exp_name,
# below would be double saving: local_dir as artifact here
# save_artifact=True,
tags={"experiment_tag": experiment_tag},
)
]
elif tracking_uri == "wandb":
# load wandb API key from .env file
dotenv.load_dotenv()
api_key = os.getenv("WANDB_API_KEY")
entity = os.getenv("WANDB_ENTITY")
if api_key is None:
raise ValueError("No WANDB_API_KEY found in .env file.")
if entity is None:
raise ValueError("No WANDB_ENTITY found in .env file.")
callbacks = [
WandbLoggerCallback(
api_key=api_key,
entity=entity,
project=experiment_tag,
tags={experiment_tag},
)
]
analysis = tune.Tuner(
# trainable with input parameters passed and set resources
tune.with_resources(
@@ -136,17 +165,7 @@
checkpoint_score_order="min",
num_to_keep=3,
),
# ! callback: executing specific tasks (e.g. logging) at specific
# points in training - used in MLflow browser interface
callbacks=[
MLflowLoggerCallback(
tracking_uri=mlflow_tracking_uri,
experiment_name=exp_name,
# below would be double saving: local_dir as artifact here
# save_artifact=True,
tags={"experiment_tag": experiment_tag},
),
],
callbacks=callbacks,
),
# hyperparameter space: passes config used in trainables
param_space=search_space,
