diff --git a/README.md b/README.md
index 0fbe0ec..5cb765e 100644
--- a/README.md
+++ b/README.md
@@ -15,34 +15,68 @@ conda activate ritme
 make dev
 ```
-## Model training locally
-To train models with a defined configuration in `q2_ritme/config.json` run:
+## Model training
+The model configuration is defined in `q2_ritme/run_config.json`. If you want to parallelise the training of different model types, we recommend training each model type in a separate experiment. If you run several model types in one experiment, be aware that they are trained sequentially, so the experiment will take longer to finish.
+
+Once you have trained some models, you can check their progress in the tracking software you selected (see section #model-tracking).
+
+To define a suitable model configuration, see the description of each variable in `q2_ritme/run_config.json` below:
+
+| Parameter | Description |
+|-----------|-------------|
+| experiment_tag | Name of the experiment. |
+| host_id | Column name for the unique host_id in the metadata. |
+| target | Column name of the target variable in the metadata. |
+| ls_model_types | List of model types to explore sequentially - options include "linreg", "trac", "xgb", "nn_reg", "nn_class", "nn_corn" and "rf". |
+| models_to_evaluate_separately | List of models to evaluate separately during iterative learning - only possible for "xgb", "nn_reg", "nn_class" and "nn_corn". |
+| num_trials | Total number of trials to run per model type: the larger this value, the more of the complete search space is explored. |
+| max_concurrent_trials | Maximum number of concurrent trials to run. |
+| path_to_ft | Path to the feature table file. |
+| path_to_md | Path to the metadata file. |
+| path_to_phylo | Path to the phylogenetic tree file. |
+| path_to_tax | Path to the taxonomy file. |
+| seed_data | Seed for data-related random operations. |
+| seed_model | Seed for model-related random operations. |
+| test_mode | Boolean flag to indicate if running in test mode. |
+| tracking_uri | Which platform to use for experiment tracking: either "wandb" for WandB or "mlruns" for MLflow. See #model-tracking for set-up instructions. |
+| train_size | Fraction of data to use for training (e.g., 0.8 for an 80%/20% train/test split). |
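+
+For example, a minimal configuration could look like this (illustrative values - adjust names and paths to your own data):
+````
+{
+    "experiment_tag": "my_experiment",
+    "host_id": "host_id",
+    "target": "age",
+    "ls_model_types": ["linreg"],
+    "max_concurrent_trials": 5,
+    "models_to_evaluate_separately": [],
+    "num_trials": 10,
+    "path_to_ft": "experiments/data/feature_table.qza",
+    "path_to_md": "experiments/data/metadata.tsv",
+    "path_to_phylo": "experiments/data/phylo_tree.qza",
+    "path_to_tax": "experiments/data/taxonomy.qza",
+    "seed_data": 12,
+    "seed_model": 12,
+    "test_mode": false,
+    "tracking_uri": "mlruns",
+    "train_size": 0.8
+}
+````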
+
+### Local training
+To locally train models with the configuration defined in `q2_ritme/run_config.json`, run:
 ````
-python q2_ritme/run_n_eval_tune.py --config q2_ritme/run_config.json
+./launch_local.sh q2_ritme/run_config.json
 ````
-Once you have trained some models, you can check the progress of the trained models by launching `mlflow ui --backend-store-uri experiments/mlruns`.
-
-To evaluate the best trial (trial < experiment) of all launched experiments, run:
+The script starts a local Ray instance (if none is running yet) and writes the training logs to `x_<config name>_out.txt`.
+
+To evaluate the best trial (a trial is one model configuration within an experiment) of all launched experiments locally, run:
 ````
 python q2_ritme/eval_best_trial_overall.py --model_path "experiments/models"
 ````
 
-## Model training on HPC with slurm:
-Edit file `launch_slurm_cpu.sh` and then run
+### Training with slurm on HPC
+To train a model with slurm on a single node, edit the file `launch_slurm_cpu.sh` and then run
 ````
 sbatch launch_slurm_cpu.sh
 ````
-If you (or your collaborators) plan to launch multiple jobs on the same infrastructure you should set the variable `JOB_NB` in `launch_slurm_cpu.sh` accordingly. This variable makes sure that the assigned ports don't overlap (see [here](https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html#slurm-networking-caveats)).
-If you are using SLURM and get the following error returned: "RuntimeError: can't start new thread"
-it is probably caused by your hardware. Try decreasing the CPUs allocated to the job and/or decrease the variable `max_concurrent_trials` in `tune_models.py`.
-If you are using SLURM and your error message contains this: "The process is killed by SIGKILL by OOM killer due to high memory usage", you should increase the assigned memory per CPU (`--mem-per-cpu`).
+To train a model with slurm on multiple nodes or to run multiple Ray instances on the same HPC, you can use: `sbatch launch_slurm_cpu_multi_node.sh`. If you (or your collaborators) plan to launch multiple jobs on the same infrastructure, you should set the variable `JOB_NB` in `launch_slurm_cpu_multi_node.sh` accordingly. This variable makes sure that the assigned ports don't overlap (see [here](https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html#slurm-networking-caveats)). Currently, the script allows for 3 parallel ray slurm jobs to be executed.
+**Note:** training a model with slurm on multiple nodes can be very specific to your infrastructure, so you might need to adjust this bash script to your set-up.
+
+#### Some common slurm errors:
+If you are using SLURM and ...
+* ... get the following error returned: "RuntimeError: can't start new thread", it is probably caused by the thread limits of the cluster. You can try increasing the number of allowed threads (`ulimit -u`) in `launch_slurm_cpu.sh` and/or decrease the variable `max_concurrent_trials` in `q2_ritme/run_config.json`. In case neither helps, it might be worth contacting the cluster administrators.
+
+* ... your error message contains this: "The process is killed by SIGKILL by OOM killer due to high memory usage", you should increase the assigned memory per CPU (`--mem-per-cpu`) in `launch_slurm_cpu.sh`.
 
 ## Model tracking
-In the config file you can choose to track your trials with MLflow (tracking_uri=="mlruns") or with WandB (tracking_uri=="wandb"). In case of using WandB you need to store your `WANDB_API_KEY` & `WANDB_ENTITY` as a environment variable in `.env`. Make sure to ignore this file in version control (add to `.gitignore`)!
+In the config file you can choose to track your trials with MLflow (tracking_uri=="mlruns") or with WandB (tracking_uri=="wandb").
+
+### MLflow
+In case of using MLflow you can view your models with `mlflow ui --backend-store-uri experiments/mlruns`. For more information check out the [official MLflow documentation](https://mlflow.org/docs/latest/index.html).
+
+### WandB
+In case of using WandB you need to store your `WANDB_API_KEY` & `WANDB_ENTITY` as environment variables in `.env`. Make sure to ignore this file in version control (add it to `.gitignore`)!
 
-The `WANDB_ENTITY` is the project name you would like to store the results in. For more information on this parameter see the official webpage from WandB initialization [here](https://docs.wandb.ai/ref/python/init).
+The `WANDB_ENTITY` is the WandB user or team name under which the results are stored. For more information on this parameter see the official webpage for initializing WandB [here](https://docs.wandb.ai/ref/python/init).
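+
+For example, your `.env` file could look like this (placeholder values - replace them with your own credentials):
+````
+WANDB_API_KEY=your-api-key
+WANDB_ENTITY=your-team-or-username
+````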
 Also if you are running WandB from a HPC, you might need to set the proxy URLs to your respective URLs by exporting these variables:
 ```
@@ -50,13 +84,14 @@ export HTTPS_PROXY=http://proxy.example.com:8080
 export HTTP_PROXY=http://proxy.example.com:8080
 ````
 
-## Code test coverage
+## Developer topics - to be removed prior to publication
+### Code test coverage
 To run test coverage with Code Gutters in VScode run:
 ````
 pytest --cov=q2_ritme q2_ritme/tests/ --cov-report=xml:coverage.xml
 ````
 
-## Call graphs
+### Call graphs
 To create a call graph for all functions in the package, run the following commands:
 ````
 pip install pyan3==1.1.1
@@ -65,6 +100,6 @@ pyan3 q2_ritme/**/*.py --uses --no-defines --colored --grouped --annotated --svg
 ````
 (Note: different other options to create call graphs were tested such as code2flow and snakeviz. However, these although properly maintained didn't directly output call graphs such as pyan3 did.)
 
-## Background
-### Why ray tune?
+### Background
+#### Why ray tune?
 "By using tuning libraries such as Ray Tune we can try out combinations of hyperparameters. Using sophisticated search strategies, these parameters can be selected so that they are likely to lead to good results (avoiding an expensive exhaustive search). Also, trials that do not perform well can be preemptively stopped to reduce waste of computing resources. Lastly, Ray Tune also takes care of training these runs in parallel, greatly increasing search speed." [source](https://docs.ray.io/en/latest/tune/examples/tune-xgboost.html#tune-xgboost-ref)
diff --git a/launch_local.sh b/launch_local.sh
new file mode 100644
index 0000000..59ffe96
--- /dev/null
+++ b/launch_local.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+if [ $# -ne 1 ]; then
+    echo "Usage: $0 <path_to_config.json>"
+    exit 1
+fi
+
+CONFIG=$1
+PORT=6379
+
+if ! nc -z localhost $PORT; then
+    echo "Starting Ray on port $PORT"
+    ray start --head --port=$PORT --dashboard-port=0
+else
+    echo "Ray is already running on port $PORT"
+fi
+
+OUTNAME=$(basename "$CONFIG" .json)
+python q2_ritme/run_n_eval_tune.py --config "$CONFIG" > x_"$OUTNAME"_out.txt 2>&1
diff --git a/launch_slurm_cpu.sh b/launch_slurm_cpu.sh
index da1d1d5..1c71589 100644
--- a/launch_slurm_cpu.sh
+++ b/launch_slurm_cpu.sh
@@ -3,8 +3,8 @@
 #SBATCH --job-name="run_config"
 #SBATCH -A partition_name
 #SBATCH --nodes=1
-#SBATCH --cpus-per-task=20
-#SBATCH --time=23:59:59
+#SBATCH --cpus-per-task=100
+#SBATCH --time=119:59:59
 #SBATCH --mem-per-cpu=4096
 #SBATCH --output="%x_out.txt"
 #SBATCH --open-mode=append
@@ -17,91 +17,12 @@ echo "SLURM_CPUS_PER_TASK: $SLURM_CPUS_PER_TASK"
 echo "SLURM_GPUS_PER_TASK: $SLURM_GPUS_PER_TASK"
 
 # ! USER SETTINGS HERE
 # -> config file to use
 CONFIG="q2_ritme/run_config.json"
-# -> count of this concurrent job launched on same infrastructure
-# -> only these values are allowed: 1, 2, 3 - since below ports are
-# -> otherwise taken or not allowed
-JOB_NB=1
 
 # if your number of threads are limited increase as needed
 ulimit -u 60000
 ulimit -n 524288
 # ! USER END __________
 
-# __doc_head_address_start__
-# script was edited from:
-# https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html
-
-# Getting the node names
-nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
-nodes_array=($nodes)
-
-head_node=${nodes_array[0]}
-head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
-
-# if we detect a space character in the head node IP, we'll
-# convert it to an ipv4 address. This step is optional.
-if [[ "$head_node_ip" == *" "* ]]; then -IFS=' ' read -ra ADDR <<<"$head_node_ip" -if [[ ${#ADDR[0]} -gt 16 ]]; then - head_node_ip=${ADDR[1]} -else - head_node_ip=${ADDR[0]} -fi -echo "IPV6 address detected. We split the IPV4 address as $head_node_ip" -fi -# __doc_head_address_end__ - -# __doc_head_ray_start__ -port=$((6378 + JOB_NB)) -node_manager_port=$((6600 + JOB_NB * 100)) -object_manager_port=$((6601 + JOB_NB * 100)) -ray_client_server_port=$((1 + JOB_NB * 10000)) -redis_shard_ports=$((6602 + JOB_NB * 100)) -min_worker_port=$((2 + JOB_NB * 10000)) -max_worker_port=$((9999 + JOB_NB * 10000)) -dashboard_port=$((8265 + JOB_NB)) - -ip_head=$head_node_ip:$port -export ip_head -echo "IP Head: $ip_head" - -echo "Starting HEAD at $head_node" -srun --nodes=1 --ntasks=1 -w "$head_node" \ - ray start --head --node-ip-address="$head_node_ip" \ - --port=$port \ - --node-manager-port=$node_manager_port \ - --object-manager-port=$object_manager_port \ - --ray-client-server-port=$ray_client_server_port \ - --redis-shard-ports=$redis_shard_ports \ - --min-worker-port=$min_worker_port \ - --max-worker-port=$max_worker_port \ - --dashboard-port=$dashboard_port \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_TASK:-0}" --block & -# __doc_head_ray_end__ - -# __doc_worker_ray_start__ -# optional, though may be useful in certain versions of Ray < 1.0. -sleep 10 - -# number of nodes other than the head node -worker_num=$((SLURM_JOB_NUM_NODES - 1)) - -for ((i = 1; i <= worker_num; i++)); do - node_i=${nodes_array[$i]} - echo "Starting WORKER $i at $node_i" - srun --nodes=1 --ntasks=1 -w "$node_i" \ - ray start --address "$ip_head" \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_TASK:-0}" --block & - sleep 5 -done -# __doc_worker_ray_end__ - -# Output the dashboard URL -dashboard_url="http://${head_node_ip}:${dashboard_port}" -export RAY_DASHBOARD_URL="$dashboard_url" -echo "Ray Dashboard URL: $RAY_DASHBOARD_URL" - -# __doc_script_start__ python -u q2_ritme/run_n_eval_tune.py --config $CONFIG sstat -j $SLURM_JOB_ID diff --git a/launch_slurm_cpu_multi_node.sh b/launch_slurm_cpu_multi_node.sh new file mode 100644 index 0000000..3dd43fe --- /dev/null +++ b/launch_slurm_cpu_multi_node.sh @@ -0,0 +1,110 @@ +#!/bin/bash + +#SBATCH --job-name="run_config" +#SBATCH -A partition_name +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=100 +#SBATCH --time=23:59:59 +#SBATCH --mem-per-cpu=4096 +#SBATCH --output="%x_out.txt" +#SBATCH --open-mode=append + +set -x + +echo "SLURM_CPUS_PER_TASK: $SLURM_CPUS_PER_TASK" +echo "SLURM_GPUS_PER_TASK: $SLURM_GPUS_PER_TASK" + +# ! USER SETTINGS HERE +# -> config file to use +CONFIG="q2_ritme/run_config.json" +# -> count of this concurrent job launched on same infrastructure +# -> only these values are allowed: 1, 2, 3 - since below ports are +# -> otherwise taken or not allowed +JOB_NB=2 + +# if your number of threads are limited increase as needed +ulimit -u 60000 +ulimit -n 524288 +# ! USER END __________ + +# __doc_head_address_start__ +# script was edited from: +# https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html + +# Getting the node names +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +nodes_array=($nodes) + +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + +# if we detect a space character in the head node IP, we'll +# convert it to an ipv4 address. This step is optional. 
+if [[ "$head_node_ip" == *" "* ]]; then +IFS=' ' read -ra ADDR <<<"$head_node_ip" +if [[ ${#ADDR[0]} -gt 16 ]]; then + head_node_ip=${ADDR[1]} +else + head_node_ip=${ADDR[0]} +fi +echo "IPV6 address detected. We split the IPV4 address as $head_node_ip" +fi +# __doc_head_address_end__ + +# __doc_head_ray_start__ +port=$((6378 + JOB_NB)) +node_manager_port=$((6600 + JOB_NB * 100)) +object_manager_port=$((6601 + JOB_NB * 100)) +ray_client_server_port=$((1 + JOB_NB * 10000)) +redis_shard_ports=$((6602 + JOB_NB * 100)) +min_worker_port=$((2 + JOB_NB * 10000)) +max_worker_port=$((9999 + JOB_NB * 10000)) +dashboard_port=$((8265 + JOB_NB)) + +ip_head=$head_node_ip:$port +export ip_head +echo "IP Head: $ip_head" + +echo "Starting HEAD at $head_node" +srun --nodes=1 --ntasks=1 -w "$head_node" \ + ray start --head --node-ip-address="$head_node_ip" \ + --port=$port \ + --node-manager-port=$node_manager_port \ + --object-manager-port=$object_manager_port \ + --ray-client-server-port=$ray_client_server_port \ + --redis-shard-ports=$redis_shard_ports \ + --min-worker-port=$min_worker_port \ + --max-worker-port=$max_worker_port \ + --dashboard-port=$dashboard_port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_TASK:-0}" --block & +# __doc_head_ray_end__ + +# __doc_worker_ray_start__ +# optional, though may be useful in certain versions of Ray < 1.0. +sleep 10 + +# number of nodes other than the head node +worker_num=$((SLURM_JOB_NUM_NODES - 1)) + +for ((i = 1; i <= worker_num; i++)); do + node_i=${nodes_array[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" \ + ray start --address "$ip_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_TASK:-0}" --block & + sleep 5 +done +# __doc_worker_ray_end__ + +# Output the dashboard URL +dashboard_url="http://${head_node_ip}:${dashboard_port}" +export RAY_DASHBOARD_URL="$dashboard_url" +echo "Ray Dashboard URL: $RAY_DASHBOARD_URL" + +# __doc_script_start__ +python -u q2_ritme/run_n_eval_tune.py --config $CONFIG +sstat -j $SLURM_JOB_ID + +# get elapsed time of job +echo "TIME COUNTER:" +sacct -j $SLURM_JOB_ID --format=elapsed --allocations diff --git a/launch_slurm_cpu_own_ss.sh b/launch_slurm_cpu_own_ss.sh new file mode 100644 index 0000000..23521a1 --- /dev/null +++ b/launch_slurm_cpu_own_ss.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +#SBATCH --job-name="r_optuna_own_ss_rf" +#SBATCH -A es_bokulich +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=100 +#SBATCH --time=119:59:59 +#SBATCH --mem-per-cpu=4096 +#SBATCH --output="%x_out.txt" +#SBATCH --open-mode=append + +set -x + +echo "SLURM_CPUS_PER_TASK: $SLURM_CPUS_PER_TASK" +echo "SLURM_GPUS_PER_TASK: $SLURM_GPUS_PER_TASK" + +# ! USER SETTINGS HERE +# -> config file to use +CONFIG="q2_ritme/r_optuna_own_ss_rf.json" + +# if your number of threads are limited increase as needed +ulimit -u 60000 +ulimit -n 524288 +# ! 
USER END __________ + +python -u q2_ritme/run_n_eval_tune.py --config $CONFIG +sstat -j $SLURM_JOB_ID + +# get elapsed time of job +echo "TIME COUNTER:" +sacct -j $SLURM_JOB_ID --format=elapsed --allocations diff --git a/q2_ritme/evaluate_models.py b/q2_ritme/evaluate_models.py index d5b7740..90d802a 100644 --- a/q2_ritme/evaluate_models.py +++ b/q2_ritme/evaluate_models.py @@ -142,10 +142,7 @@ def select(self, data, split): # assign self.train_selected_fts to be able to run select on test set later train_selected = select_microbial_features( data, - self.data_config["data_selection"], - self.data_config["data_selection_i"], - self.data_config["data_selection_q"], - self.data_config["data_selection_t"], + self.data_config, ft_prefix, ) self.train_selected_fts = train_selected.columns diff --git a/q2_ritme/feature_space/_process_train.py b/q2_ritme/feature_space/_process_train.py index 2b73e81..2a30cd2 100644 --- a/q2_ritme/feature_space/_process_train.py +++ b/q2_ritme/feature_space/_process_train.py @@ -4,7 +4,6 @@ _find_most_nonzero_feature_idx, transform_microbial_features, ) -from q2_ritme.feature_space.utils import _update_config from q2_ritme.process_data import split_data_by_host @@ -23,16 +22,9 @@ def process_train(config, train_val, target, host_id, tax, seed_data): print(f"Number of features after aggregation: {len(ft_agg.columns)}") # SELECT - # adjust data_selection config dependencies by main method selected - # ! parameters are switched to metrics if they are changed during process_train! - config = _update_config(config) - ft_selected = select_microbial_features( ft_agg, - config["data_selection"], - config["data_selection_i"], - config["data_selection_q"], - config["data_selection_t"], + config, feat_prefix, ) print(f"Number of features after selection: {len(ft_selected.columns)}") diff --git a/q2_ritme/feature_space/select_features.py b/q2_ritme/feature_space/select_features.py index 7c2a7d6..f30a525 100644 --- a/q2_ritme/feature_space/select_features.py +++ b/q2_ritme/feature_space/select_features.py @@ -147,20 +147,30 @@ def find_features_to_group_by_variance_threshold( ) -def select_microbial_features(feat, method, i, q, t, ft_prefix): - if method is None: - # return original feature table - return feat.copy() - # if i larger than max index of feat.columns raise warning - if i is not None and i > len(feat.columns): +def _reset_i_too_large(i: int, feat: pd.DataFrame): + if i > len(feat.columns): warnings.warn( f"Selected i={i} is larger than number of " f"features. So it is set to the max. possible value: {len(feat.columns)}." ) - i = len(feat.columns) + return len(feat.columns) + return i + + +def select_microbial_features(feat, config, ft_prefix): + method = config["data_selection"] + + if method is None: + # return original feature table + return feat.copy() group_name = "_low_abun" if method.startswith("abundance") else "_low_var" + if method.endswith("_ith") or method.endswith("_topi"): + # ! HOTFIX: should be done more elegantly such that Optuna knows about + # ! 
this + i = _reset_i_too_large(config["data_selection_i"], feat) + if method == "abundance_ith": group_ft_ls = find_features_to_group_by_abundance_ith(feat, i) elif method == "variance_ith": @@ -170,13 +180,21 @@ def select_microbial_features(feat, method, i, q, t, ft_prefix): elif method == "variance_topi": group_ft_ls = find_features_to_group_by_variance_topi(feat, i) elif method == "abundance_quantile": - group_ft_ls = find_features_to_group_by_abundance_quantile(feat, q) + group_ft_ls = find_features_to_group_by_abundance_quantile( + feat, config["data_selection_q"] + ) elif method == "variance_quantile": - group_ft_ls = find_features_to_group_by_variance_quantile(feat, q) + group_ft_ls = find_features_to_group_by_variance_quantile( + feat, config["data_selection_q"] + ) elif method == "abundance_threshold": - group_ft_ls = find_features_to_group_by_abundance_threshold(feat, t) + group_ft_ls = find_features_to_group_by_abundance_threshold( + feat, config["data_selection_t"] + ) elif method == "variance_threshold": - group_ft_ls = find_features_to_group_by_variance_threshold(feat, t) + group_ft_ls = find_features_to_group_by_variance_threshold( + feat, config["data_selection_t"] + ) else: raise ValueError(f"Unknown method: {method}.") diff --git a/q2_ritme/feature_space/utils.py b/q2_ritme/feature_space/utils.py index 903dbf3..69e766e 100644 --- a/q2_ritme/feature_space/utils.py +++ b/q2_ritme/feature_space/utils.py @@ -18,27 +18,3 @@ def _biom_to_df(biom_tab: biom.Table) -> pd.DataFrame: index=biom_tab.ids(axis="sample"), columns=biom_tab.ids(axis="observation"), ) - - -def _update_config(config): - """Adjust data_selection config dependencies by main method selected""" - data_selection = config.get("data_selection") - - if data_selection is None: - config["data_selection_i"] = None - config["data_selection_q"] = None - config["data_selection_t"] = None - elif data_selection.endswith("_ith") or data_selection.endswith("_topi"): - config["data_selection_i"] = config["dsi_option"] - config["data_selection_q"] = None - config["data_selection_t"] = None - elif data_selection.endswith("_quantile"): - config["data_selection_i"] = None - config["data_selection_q"] = config["dsq_option"] - config["data_selection_t"] = None - elif data_selection.endswith("_threshold"): - config["data_selection_i"] = None - config["data_selection_q"] = None - config["data_selection_t"] = config["dst_option"] - - return config diff --git a/q2_ritme/model_space/static_searchspace.py b/q2_ritme/model_space/static_searchspace.py index b462a6b..db9b912 100644 --- a/q2_ritme/model_space/static_searchspace.py +++ b/q2_ritme/model_space/static_searchspace.py @@ -1,141 +1,141 @@ -from ray import tune +from typing import Any, Dict, Optional -def get_data_eng_space(tax, test_mode=False): +def _get_dependent_data_eng_space(trial, data_selection: str) -> None: + if data_selection.endswith("_ith") or data_selection.endswith("_topi"): + trial.suggest_int("data_selection_i", 1, 20) + elif data_selection.endswith("_quantile"): + trial.suggest_float("data_selection_q", 0.5, 0.9, step=0.1) + elif data_selection.endswith("_threshold"): + trial.suggest_float("data_selection_t", 0.00001, 0.01, log=True) + + +def get_data_eng_space(trial, tax, test_mode: bool = False) -> None: if test_mode: # note: test mode can be adjusted to whatever one wants to test - return { - "data_aggregation": None, - "data_selection": tune.grid_search( - [ - None, - "abundance_ith", - # "variance_ith", - # "abundance_topi", - # "variance_topi", - # 
"abundance_quantile", - # "variance_quantile", - # "abundance_threshold", - # "variance_threshold", - ] - ), - "dsi_option": tune.choice([1, 5]), - "dsq_option": tune.choice([0.5, 0.75]), - "dst_option": tune.choice([0.001, 0.0001]), - "data_transform": None, - } - return { - # grid search specified here checks all options: so new nb_trials= - # num_trials * nb of options w gridsearch * nb of model types - "data_aggregation": tune.grid_search( - [None, "tax_class", "tax_order", "tax_family", "tax_genus"] + data_selection = trial.suggest_categorical( + "data_selection", ["abundance_ith", "variance_threshold"] ) - if not tax.empty - else None, - "data_selection": tune.grid_search( - [ - None, - "abundance_ith", - "variance_ith", - "abundance_topi", - "variance_topi", - "abundance_quantile", - "variance_quantile", - "abundance_threshold", - "variance_threshold", - ] - ), - # todo: adjust the i, q and t ranges to more sophisticated quantities - "dsi_option": tune.choice([1, 3, 5, 10]), - "dsq_option": tune.choice([0.5, 0.75, 0.9, 0.95]), - "dst_option": tune.choice( - [ - 0.01, - 0.005, - 0.001, - 0.0005, - 0.0001, - 0.00005, - 0.00001, - ] - ), - "data_transform": tune.grid_search([None, "clr", "ilr", "alr", "pa"]), - } + if data_selection is not None: + _get_dependent_data_eng_space(trial, data_selection) + trial.suggest_categorical("data_aggregation", [None]) + trial.suggest_categorical("data_transform", [None]) + return None -def get_linreg_space(tax, test_mode=False): - data_eng_space = get_data_eng_space(tax, test_mode) - return dict( - model="linreg", - **data_eng_space, - **{ - "fit_intercept": True, - # alpha controls overall regularization strength, alpha = 0 is equivalent to - # an ordinary least square. large alpha -> more regularization - "alpha": tune.loguniform(1e-5, 1.0), - # balance between L1 and L2 reg., when =0 -> L2, when =1 -> L1 - "l1_ratio": tune.uniform(0, 1), - }, + # feature aggregation + data_aggregation_options = ( + [None] + if tax.empty + else [None, "tax_class", "tax_order", "tax_family", "tax_genus"] ) + trial.suggest_categorical("data_aggregation", data_aggregation_options) + # feature selection + data_selection_options = [ + None, + "abundance_ith", + "variance_ith", + "abundance_topi", + "variance_topi", + "abundance_quantile", + "variance_quantile", + "abundance_threshold", + "variance_threshold", + ] + data_selection = trial.suggest_categorical("data_selection", data_selection_options) + if data_selection is not None: + _get_dependent_data_eng_space(trial, data_selection) -def get_rf_space(tax, test_mode=False): - data_eng_space = get_data_eng_space(tax, test_mode) - return dict( - model="rf", - **data_eng_space, - **{ - "n_estimators": tune.randint(50, 300), - "max_depth": tune.randint(2, 32), - "min_samples_split": tune.choice([0.001, 0.01, 0.1]), - "min_samples_leaf": tune.choice([0.0001, 0.001]), - "max_features": tune.choice([None, "sqrt", "log2", 0.1, 0.2, 0.5, 0.8]), - "min_impurity_decrease": tune.choice([0.0001, 0.001, 0.01]), - "bootstrap": tune.choice([True, False]), - }, - ) + # feature transform + trial.suggest_categorical("data_transform", [None, "clr", "ilr", "alr", "pa"]) + return None -def get_nn_space(tax, model_name, test_mode=False): - data_eng_space = get_data_eng_space(tax, test_mode) - max_layers = 12 - nn_space = { - # Sample random uniformly between [1,9] rounding to multiples of 3 - "n_hidden_layers": tune.qrandint(1, max_layers, 3), - "learning_rate": tune.loguniform(1e-5, 1e-1), - "batch_size": tune.choice([32, 64, 128]), - 
"epochs": tune.choice([10, 50, 100, 200]), - } - # first and last layer are fixed by shape of features and target - for i in range(0, max_layers): - # todo: increase! - nn_space[f"n_units_hl{i}"] = tune.randint(3, 64) - return dict(model=model_name, **data_eng_space, **nn_space) - - -def get_xgb_space(tax, test_mode=False): - data_eng_space = get_data_eng_space(tax, test_mode) - return dict( - model="xgb", - **data_eng_space, - **{ - "objective": "reg:squarederror", - # value between 2 and 6 is often a good starting point - "max_depth": tune.randint(3, 10), - # depends on sample size: 0 - 2% # todo make depends on sample number - "min_child_weight": tune.randint(0, 4), - "subsample": tune.choice([0.7, 0.8, 0.9, 1.0]), - "eta": tune.choice([0.01, 0.05, 0.1, 0.2, 0.3]), - # "n_estimators": tune.choice([5, 10, 20, 50]) - # todo add: nb gradient boosted trees - }, +def get_linreg_space(trial, tax, test_mode: bool = False) -> Dict[str, str]: + get_data_eng_space(trial, tax, test_mode) + + # alpha controls overall regularization strength, alpha = 0 is equivalent to + # an ordinary least square. large alpha -> more regularization + trial.suggest_float("alpha", 0, 1) + # balance between L1 and L2 reg., when =0 -> L2, when =1 -> L1 + trial.suggest_float("l1_ratio", 0, 1) + + return {"model": "linreg"} + + +def get_rf_space(trial, tax, test_mode: bool = False) -> Dict[str, str]: + get_data_eng_space(trial, tax, test_mode) + + # number of trees in forest: the more the higher computational costs + trial.suggest_int("n_estimators", 50, 300) + + # max depths of the tree: the higher the higher probab of overfitting + trial.suggest_int("max_depth", 5, 50) + # min number of samples requires to split internal node: small + # values higher probab of overfitting + trial.suggest_float("min_samples_split", 0.01, 0.1, step=0.01) + + # min # samples requires at leaf node: small values higher probab + # of overfitting + trial.suggest_categorical("min_samples_leaf", [0.005, 0.01, 0.05, 0.1]) + + # max # features to consider when looking for best split: small can + # reduce overfitting + trial.suggest_categorical("max_features", [None, "sqrt", "log2", 0.1, 0.2, 0.5]) + + # node split occurs if impurity is >= to this value: large values + # prevent overfitting + trial.suggest_float("min_impurity_decrease", 0.01, 0.1, step=0.01) + + trial.suggest_categorical("bootstrap", [True, False]) + + return {"model": "rf"} + + +def get_nn_space( + trial, tax, model_name: str, test_mode: bool = False +) -> Dict[str, str]: + get_data_eng_space(trial, tax, test_mode) + max_layers = 30 + # Sample random uniformly between [1,max_layers] rounding to multiples of 5 + trial.suggest_int("n_hidden_layers", 1, max_layers, step=5) + trial.suggest_categorical( + "learning_rate", [0.01, 0.005, 0.001, 0.0005, 0.0001, 0.00005, 0.00001] ) + trial.suggest_categorical("batch_size", [32, 64, 128, 256]) + trial.suggest_categorical("epochs", [10, 50, 100, 200]) + + # first and last layer are fixed by shape of features and target + for i in range(max_layers): + trial.suggest_categorical(f"n_units_hl{i}", [32, 64, 128, 256, 512]) + + return {"model": model_name} + +def get_xgb_space(trial, tax, test_mode: bool = False) -> Dict[str, Any]: + get_data_eng_space(trial, tax, test_mode) -def get_trac_space(tax, test_mode=False): + # value between 2 and 6 is often a good starting point + trial.suggest_int("max_depth", 2, 10) + + # depends on sample size: 0 - 2% # todo make depends on sample number + trial.suggest_int("min_child_weight", 0, 4) + 
trial.suggest_categorical("subsample", [0.7, 0.8, 0.9, 1.0]) + trial.suggest_categorical("eta", [0.01, 0.05, 0.1, 0.2, 0.3]) + trial.suggest_int("num_parallel_tree", 1, 3, step=1) + + return {"model": "xgb"} + + +def get_trac_space(trial, tax, test_mode: bool = False) -> Dict[str, Any]: # no feature_transformation to be used for trac # data_aggregate=taxonomy not an option because tax tree does not match with # regards to feature IDs here + + # with loguniform: sampled values are more densely concentrated + # towards the lower end of the range + trial.suggest_float("lambda", 1e-3, 1.0, log=True) data_eng_space_trac = { "data_aggregation": None, "data_selection": None, @@ -144,24 +144,22 @@ def get_trac_space(tax, test_mode=False): "data_selection_t": None, "data_transform": None, } - return dict( - model="trac", - **data_eng_space_trac, - **{ - # with loguniform: sampled values are more densely concentrated - # towards the lower end of the range - "lambda": tune.loguniform(1e-3, 1.0) - }, - ) + return {"model": "trac", **data_eng_space_trac} -def get_search_space(tax, test_mode=False): - return { - "xgb": get_xgb_space(tax, test_mode), - "nn_reg": get_nn_space(tax, "nn_reg", test_mode), - "nn_class": get_nn_space(tax, "nn_class", test_mode), - "nn_corn": get_nn_space(tax, "nn_corn", test_mode), - "linreg": get_linreg_space(tax, test_mode), - "rf": get_rf_space(tax, test_mode), - "trac": get_trac_space(tax, test_mode), - } +def get_search_space( + trial, model_type: str, tax, test_mode: bool = False +) -> Optional[Dict[str, Any]]: + """Creates the search space""" + if model_type in ["xgb", "linreg", "rf", "trac"]: + space_functions = { + "xgb": get_xgb_space, + "linreg": get_linreg_space, + "rf": get_rf_space, + "trac": get_trac_space, + } + return space_functions[model_type](trial, tax, test_mode) + elif model_type in ["nn_reg", "nn_class", "nn_corn"]: + return get_nn_space(trial, tax, model_type, test_mode) + else: + raise ValueError(f"Model type {model_type} not supported.") diff --git a/q2_ritme/model_space/static_trainables.py b/q2_ritme/model_space/static_trainables.py index df06e7f..391d385 100644 --- a/q2_ritme/model_space/static_trainables.py +++ b/q2_ritme/model_space/static_trainables.py @@ -149,7 +149,7 @@ def train_linreg( linreg = ElasticNet( alpha=config["alpha"], l1_ratio=config["l1_ratio"], - fit_intercept=config["fit_intercept"], + fit_intercept=True, ) linreg.fit(X_train, y_train) diff --git a/q2_ritme/run_config.json b/q2_ritme/run_config.json index c6ab563..b984ce5 100644 --- a/q2_ritme/run_config.json +++ b/q2_ritme/run_config.json @@ -10,13 +10,14 @@ "nn_corn", "rf" ], + "max_cuncurrent_trials": 5, "models_to_evaluate_separately": [ "xgb", "nn_reg", "nn_class", "nn_corn" ], - "num_trials": 10, + "num_trials": 500, "path_to_ft": "experiments/data/all_otu_table_filt.qza", "path_to_md": "experiments/data/metadata_proc_v20240323_r0_r3_le_2yrs.tsv", "path_to_phylo": "experiments/data/silva-138-99-rooted-tree.qza", diff --git a/q2_ritme/run_n_eval_tune.py b/q2_ritme/run_n_eval_tune.py index e69d78e..51c54aa 100644 --- a/q2_ritme/run_n_eval_tune.py +++ b/q2_ritme/run_n_eval_tune.py @@ -79,6 +79,7 @@ def run_n_eval_tune(config_path): # number of trials to run per model type * grid_search parameters in # @_static_searchspace config["num_trials"], + config["max_cuncurrent_trials"], model_types=config["ls_model_types"], fully_reproducible=False, test_mode=config["test_mode"], diff --git a/q2_ritme/tests/test_feature_space.py b/q2_ritme/tests/test_feature_space.py index 
7f915fe..bd60aea 100644 --- a/q2_ritme/tests/test_feature_space.py +++ b/q2_ritme/tests/test_feature_space.py @@ -43,7 +43,7 @@ presence_absence, transform_microbial_features, ) -from q2_ritme.feature_space.utils import _biom_to_df, _df_to_biom, _update_config +from q2_ritme.feature_space.utils import _biom_to_df, _df_to_biom class TestUtils(TestPluginBase): @@ -71,76 +71,6 @@ def test_df_to_biom(self): obs_biom_table = _df_to_biom(self.true_df) assert obs_biom_table == self.true_biom_table - @parameterized.expand( - ["abundance_ith", "variance_ith", "abundance_topi", "variance_topi"] - ) - def test_update_config_i(self, method): - config = { - "data_selection": method, - "dsi_option": 1, - "dsq_option": 0.5, - "dst_option": 0.1, - } - expected_config = { - **config, - "data_selection_i": 1, - "data_selection_q": None, - "data_selection_t": None, - } - obs_config = _update_config(config) - self.assertDictEqual(expected_config, obs_config) - - @parameterized.expand(["abundance_quantile", "variance_quantile"]) - def test_update_config_q(self, method): - config = { - "data_selection": method, - "dsi_option": 1, - "dsq_option": 0.5, - "dst_option": 0.1, - } - expected_config = { - **config, - "data_selection_i": None, - "data_selection_q": 0.5, - "data_selection_t": None, - } - obs_config = _update_config(config) - self.assertDictEqual(expected_config, obs_config) - - @parameterized.expand(["abundance_threshold", "variance_threshold"]) - def test_update_config_t(self, method): - config = { - "data_selection": method, - "dsi_option": 1, - "dsq_option": 0.5, - "dst_option": 0.1, - } - expected_config = { - **config, - "data_selection_i": None, - "data_selection_q": None, - "data_selection_t": 0.1, - } - obs_config = _update_config(config) - self.assertDictEqual(expected_config, obs_config) - - def test_update_config_none(self): - method = None - config = { - "data_selection": method, - "data_selection_i": 1, - "data_selection_q": 0.5, - "data_selection_t": 0.1, - } - expected_config = { - "data_selection": method, - "data_selection_i": None, - "data_selection_q": None, - "data_selection_t": None, - } - obs_config = _update_config(config) - self.assertDictEqual(expected_config, obs_config) - class TestTransformMicrobialFeatures(TestPluginBase): package = "q2_ritme.tests" @@ -459,25 +389,27 @@ def test_find_features_to_group_variance_threshold(self, t, expected_features): self.assertEqual(features_to_group, expected_features) def test_select_microbial_features_method_none(self): - obs_ft = select_microbial_features(self.ft, None, None, None, None, "F") + config = {"data_selection": None} + obs_ft = select_microbial_features(self.ft, config, "F") assert_frame_equal(self.ft, obs_ft) def test_select_microbial_features_none_grouped(self): with self.assertWarnsRegex(Warning, r".* Returning original feature table."): - obs_ft = select_microbial_features( - self.ft, "abundance_ith", 4, None, None, "F" - ) + config = {"data_selection": "abundance_ith", "data_selection_i": 4} + obs_ft = select_microbial_features(self.ft, config, "F") assert_frame_equal(self.ft, obs_ft) def test_select_microbial_features_i_too_large(self): with self.assertWarnsRegex( Warning, r"Selected i=1000 is larger than number of features.*" ): - select_microbial_features(self.ft, "abundance_ith", 1000, None, None, "F") + config = {"data_selection": "abundance_ith", "data_selection_i": 1000} + select_microbial_features(self.ft, config, "F") def test_select_microbial_features_unknown_method(self): with self.assertRaisesRegex(ValueError, 
r"Unknown method: FancyMethod."): - select_microbial_features(self.ft, "FancyMethod", 1, None, None, "F") + config = {"data_selection": "FancyMethod"} + select_microbial_features(self.ft, config, "F") def test_select_microbial_features_abundance_ith(self): # expected @@ -486,7 +418,8 @@ def test_select_microbial_features_abundance_ith(self): exp_ft.drop(columns=["F1", "F2"], inplace=True) # observed - obs_ft = select_microbial_features(self.ft, "abundance_ith", 2, None, None, "F") + config = {"data_selection": "abundance_ith", "data_selection_i": 2} + obs_ft = select_microbial_features(self.ft, config, "F") assert_frame_equal(exp_ft, obs_ft) @@ -499,7 +432,8 @@ def test_select_microbial_features_variance_ith(self): exp_ft.drop(columns=ls_group, inplace=True) # observed - obs_ft = select_microbial_features(self.ft, "variance_ith", 1, None, None, "F") + config = {"data_selection": "variance_ith", "data_selection_i": 1} + obs_ft = select_microbial_features(self.ft, config, "F") assert_frame_equal(exp_ft, obs_ft) @@ -511,9 +445,8 @@ def test_select_microbial_features_abundance_topi(self): exp_ft.drop(columns=ls_group, inplace=True) # observed - obs_ft = select_microbial_features( - self.ft, "abundance_topi", 2, None, None, "F" - ) + config = {"data_selection": "abundance_topi", "data_selection_i": 2} + obs_ft = select_microbial_features(self.ft, config, "F") assert_frame_equal(exp_ft, obs_ft) @@ -526,7 +459,8 @@ def test_select_microbial_features_variance_topi(self): exp_ft.drop(columns=ls_group, inplace=True) # observed - obs_ft = select_microbial_features(self.ft, "variance_topi", 1, None, None, "F") + config = {"data_selection": "variance_topi", "data_selection_i": 1} + obs_ft = select_microbial_features(self.ft, config, "F") assert_frame_equal(exp_ft, obs_ft) @@ -538,9 +472,8 @@ def test_select_microbial_features_abundance_quantile(self): exp_ft.drop(columns=ls_group, inplace=True) # observed - obs_ft = select_microbial_features( - self.ft, "abundance_quantile", None, 0.5, None, "F" - ) + config = {"data_selection": "abundance_quantile", "data_selection_q": 0.5} + obs_ft = select_microbial_features(self.ft, config, "F") assert_frame_equal(exp_ft, obs_ft) @@ -553,9 +486,8 @@ def test_select_microbial_features_variance_quantile(self): exp_ft.drop(columns=ls_group, inplace=True) # observed - obs_ft = select_microbial_features( - self.ft, "variance_quantile", None, 0.5, None, "F" - ) + config = {"data_selection": "variance_quantile", "data_selection_q": 0.5} + obs_ft = select_microbial_features(self.ft, config, "F") assert_frame_equal(exp_ft, obs_ft) @@ -567,9 +499,8 @@ def test_select_microbial_features_abundance_threshold(self): exp_ft.drop(columns=ls_group, inplace=True) # observed - obs_ft = select_microbial_features( - self.ft, "abundance_threshold", None, None, 10, "F" - ) + config = {"data_selection": "abundance_threshold", "data_selection_t": 10} + obs_ft = select_microbial_features(self.ft, config, "F") assert_frame_equal(exp_ft, obs_ft) @@ -581,9 +512,8 @@ def test_select_microbial_features_abundance_threshold_all_grouped(self): exp_ft.drop(columns=ls_group, inplace=True) # observed - obs_ft = select_microbial_features( - self.ft, "abundance_threshold", None, None, 100, "F" - ) + config = {"data_selection": "abundance_threshold", "data_selection_t": 100} + obs_ft = select_microbial_features(self.ft, config, "F") assert_frame_equal(exp_ft, obs_ft) @@ -596,9 +526,8 @@ def test_select_microbial_features_variance_threshold(self): exp_ft.drop(columns=ls_group, inplace=True) # observed - 
obs_ft = select_microbial_features( - self.ft, "variance_threshold", None, None, 3, "F" - ) + config = {"data_selection": "variance_threshold", "data_selection_t": 3} + obs_ft = select_microbial_features(self.ft, config, "F") assert_frame_equal(exp_ft, obs_ft) @@ -612,6 +541,9 @@ def setUp(self): "data_transform": None, "data_aggregation": None, "data_selection": None, + "data_selection_i": None, + "data_selection_q": None, + "data_selection_t": None, } self.train_val = pd.DataFrame( { @@ -669,9 +601,7 @@ def test_process_train_no_feature_engineering( # Assert self._assert_called_with_df(mock_aggregate_features, ft, None, self.tax) - self._assert_called_with_df( - mock_select_features, ft, None, None, None, None, "F" - ) + self._assert_called_with_df(mock_select_features, ft, self.config, "F") self._assert_called_with_df(mock_transform_features, ft, None) self._assert_called_with_df( mock_split_data_by_host, diff --git a/q2_ritme/tests/test_static_searchspace.py b/q2_ritme/tests/test_static_searchspace.py index a28602e..c084f21 100644 --- a/q2_ritme/tests/test_static_searchspace.py +++ b/q2_ritme/tests/test_static_searchspace.py @@ -1,10 +1,27 @@ import pandas as pd +from parameterized import parameterized from qiime2.plugin.testing import TestPluginBase -from ray import tune from q2_ritme.model_space import static_searchspace as ss +class MockTrial: + def __init__(self): + self.params = {} + + def suggest_categorical(self, name, categories): + self.params[name] = categories[0] if categories else None + return self.params[name] + + def suggest_int(self, name, low, high, step=1): + self.params[name] = low + return low + + def suggest_float(self, name, low, high, step=None, log=False): + self.params[name] = low + return low + + class TestStaticSearchSpace(TestPluginBase): package = "q2_ritme.tests" @@ -12,114 +29,119 @@ def setUp(self): super().setUp() self.tax = pd.DataFrame() - def test_get_search_space(self): - search_space = ss.get_search_space(self.tax) - - self.assertIsInstance(search_space, dict) - self.assertIn("xgb", search_space) - self.assertIn("nn_reg", search_space) - self.assertIn("nn_class", search_space) - self.assertIn("nn_corn", search_space) - self.assertIn("linreg", search_space) - self.assertIn("rf", search_space) - - def test_get_data_eng_space_w_tax(self): - tax = pd.DataFrame({"Taxon": ["Bacteria", "Firmicutes", "Clostridia"]}) - data_eng_space = ss.get_data_eng_space(tax) - - self.assertIsInstance(data_eng_space, dict) - self.assertEqual( - data_eng_space["data_aggregation"], - tune.grid_search( - [None, "tax_class", "tax_order", "tax_family", "tax_genus"] - ), - ) - self.assertEqual( - data_eng_space["data_selection"], - tune.grid_search( - [ - None, - "abundance_ith", - "variance_ith", - "abundance_topi", - "variance_topi", - "abundance_quantile", - "variance_quantile", - "abundance_threshold", - "variance_threshold", - ] - ), - ) - self.assertEqual( - data_eng_space["dsi_option"].categories, - [1, 3, 5, 10], - ) - self.assertEqual( - data_eng_space["data_transform"], - tune.grid_search([None, "clr", "ilr", "alr", "pa"]), - ) - - def test_get_data_eng_space_empty_tax(self): - data_eng_space = ss.get_data_eng_space(self.tax) - self.assertEqual(data_eng_space["data_aggregation"], None) - - # todo: add this test once test mode is more clearly defined - # def test_get_data_eng_space_test_mode(self): - # data_eng_space = ss.get_data_eng_space(self.tax, True) - # for key in ["data_aggregation", "data_selection", "data_transform"]: - # 
self.assertEqual(data_eng_space[key], None) + @parameterized.expand( + [ + ("abundance_ith", "i"), + ("variance_quantile", "q"), + ("abundance_threshold", "t"), + ] + ) + def test_get_dependent_data_eng_space(self, data_selection, expected_suffix): + trial = MockTrial() + ss._get_dependent_data_eng_space(trial, data_selection) + + hyperparam = f"data_selection_{expected_suffix}" + self.assertIn(hyperparam, trial.params) + self.assertIsNotNone(trial.params[hyperparam]) + + def test_get_data_eng_space_test_mode(self): + trial = MockTrial() + ss.get_data_eng_space(trial, self.tax, test_mode=True) + expected_params = {"data_selection", "data_aggregation", "data_transform"} + self.assertTrue(expected_params.issubset(trial.params.keys())) + self.assertIn(trial.params["data_selection"], [None, "abundance_ith"]) + self.assertEqual(trial.params["data_aggregation"], None) + self.assertEqual(trial.params["data_transform"], None) + + def test_get_data_eng_space(self): + trial = MockTrial() + ss.get_data_eng_space(trial, self.tax) + expected_params = {"data_selection", "data_aggregation", "data_transform"} + self.assertTrue(expected_params.issubset(trial.params.keys())) def test_get_linreg_space(self): - linreg_space = ss.get_linreg_space(self.tax) - + trial = MockTrial() + linreg_space = ss.get_linreg_space(trial, self.tax) self.assertIsInstance(linreg_space, dict) self.assertEqual(linreg_space["model"], "linreg") - self.assertIn("data_transform", linreg_space) - self.assertIn("fit_intercept", linreg_space) - self.assertIn("alpha", linreg_space) - self.assertIn("l1_ratio", linreg_space) - - def test_get_trac_space(self): - trac_space = ss.get_trac_space(self.tax) - - self.assertIsInstance(trac_space, dict) - self.assertEqual(trac_space["model"], "trac") - self.assertEqual(trac_space["data_transform"], None) - self.assertIn("lambda", trac_space) + expected_params = {"alpha", "l1_ratio"} + self.assertTrue(expected_params.issubset(trial.params.keys())) def test_get_rf_space(self): - rf_space = ss.get_rf_space(self.tax) - + trial = MockTrial() + rf_space = ss.get_rf_space(trial, self.tax) self.assertIsInstance(rf_space, dict) self.assertEqual(rf_space["model"], "rf") - self.assertIn("data_transform", rf_space) - self.assertIn("n_estimators", rf_space) - self.assertIn("max_depth", rf_space) - self.assertIn("min_samples_split", rf_space) - self.assertIn("min_samples_leaf", rf_space) - self.assertIn("max_features", rf_space) - self.assertIn("min_impurity_decrease", rf_space) - self.assertIn("bootstrap", rf_space) + expected_params = { + "n_estimators", + "max_depth", + "min_samples_split", + "min_samples_leaf", + "max_features", + "min_impurity_decrease", + "bootstrap", + } + self.assertTrue(expected_params.issubset(trial.params.keys())) + + @parameterized.expand( + [ + ("nn_reg",), + ("nn_class",), + ("nn_corn",), + ] + ) + def test_get_nn_space(self, model_type): + trial = MockTrial() + nn_space = ss.get_nn_space(trial, self.tax, model_type) - def test_get_xgb_space(self): - xgb_space = ss.get_xgb_space(self.tax) + self.assertIsInstance(nn_space, dict) + self.assertEqual(nn_space["model"], model_type) + + expected_params = {"n_hidden_layers", "learning_rate", "batch_size", "epochs"} + self.assertTrue(expected_params.issubset(trial.params.keys())) + + self.assertTrue(any(f"n_units_hl{i}" in trial.params for i in range(30))) + def test_get_xgb_space(self): + trial = MockTrial() + xgb_space = ss.get_xgb_space(trial, self.tax) self.assertIsInstance(xgb_space, dict) self.assertEqual(xgb_space["model"], "xgb") - 
self.assertIn("data_transform", xgb_space) - self.assertIn("objective", xgb_space) - self.assertIn("max_depth", xgb_space) - self.assertIn("min_child_weight", xgb_space) - self.assertIn("subsample", xgb_space) - self.assertIn("eta", xgb_space) + expected_params = { + "max_depth", + "min_child_weight", + "subsample", + "eta", + "num_parallel_tree", + } + self.assertTrue(expected_params.issubset(trial.params.keys())) - def test_get_nn_space(self): - nn_space = ss.get_nn_space(self.tax, "nn_reg") + def test_get_trac_space(self): + trial = MockTrial() + trac_space = ss.get_trac_space(trial, self.tax) + self.assertIsInstance(trac_space, dict) + self.assertEqual(trac_space["model"], "trac") + self.assertIn("lambda", trial.params) + + @parameterized.expand( + [ + ("xgb",), + ("nn_reg",), + ("nn_class",), + ("nn_corn",), + ("linreg",), + ("rf",), + ("trac",), + ] + ) + def test_get_search_space(self, model_type): + trial = MockTrial() + search_space = ss.get_search_space(trial, model_type, self.tax) + self.assertIsInstance(search_space, dict) + self.assertEqual(search_space["model"], model_type) - self.assertIsInstance(nn_space, dict) - self.assertEqual(nn_space["model"], "nn_reg") - self.assertIn("data_transform", nn_space) - self.assertIn("n_hidden_layers", nn_space) - self.assertIn("learning_rate", nn_space) - self.assertIn("batch_size", nn_space) - self.assertIn("epochs", nn_space) + def test_get_search_space_model_not_supported(self): + model_type = "FakeModel" + trial = MockTrial() + with self.assertRaisesRegex(ValueError, "Model type FakeModel not supported."): + ss.get_search_space(trial, model_type, self.tax) diff --git a/q2_ritme/tune_models.py b/q2_ritme/tune_models.py index 89cc2a1..9e6357e 100644 --- a/q2_ritme/tune_models.py +++ b/q2_ritme/tune_models.py @@ -1,15 +1,18 @@ import os import random +from functools import partial import dotenv import numpy as np import pandas as pd +import ray import skbio import torch from ray import air, init, tune from ray.air.integrations.mlflow import MLflowLoggerCallback from ray.air.integrations.wandb import WandbLoggerCallback from ray.tune.schedulers import AsyncHyperBandScheduler, HyperBandScheduler +from ray.tune.search.optuna import OptunaSearch from q2_ritme.model_space import static_searchspace as ss from q2_ritme.model_space import static_trainables as st @@ -39,7 +42,7 @@ def run_trials( tracking_uri, exp_name, trainable, - search_space, + test_mode, train_val, target, host_id, @@ -49,13 +52,14 @@ def run_trials( tree_phylo, path2exp, num_trials, + max_concurrent_trials, fully_reproducible=False, # if True hyperband instead of ASHA scheduler is used - scheduler_grace_period=5, + scheduler_grace_period=10, scheduler_max_t=100, resources=None, ): # since each trial starts it own threads - this should not be set to highly - max_concurrent_trials = min(num_trials, 5) + max_concurrent_trials = min(num_trials, max_concurrent_trials) if resources is None: # if not a slurm process: default values are used all_cpus_avail = get_slurm_resource("SLURM_CPUS_PER_TASK", 1) @@ -86,8 +90,21 @@ def run_trials( # ray instance # todo: configure dashboard here - see "ray dashboard set up" online once # todo: ray (Q2>Py) is updated - context = init(address="auto", include_dashboard=False, ignore_reinit_error=True) - print(context.dashboard_url) + context = init( + address="local", + include_dashboard=False, + ignore_reinit_error=True, + # logging_level=logging.DEBUG, + # log_to_driver=True, + ) + print(f"Ray cluster resources: {ray.cluster_resources()}") + 
print(f"Dashboard URL at: {context.dashboard_url}") + + # define metric and mode to optimize + metric = "rmse_val" + mode = "min" + + # define schedulers: # note: both schedulers might decide to run more trials than allocated if not fully_reproducible: # AsyncHyperBand enables aggressive early stopping of bad trials. @@ -101,11 +118,21 @@ def run_trials( max_t=scheduler_max_t, ) else: - # ! slower BUT + # ! HyperBandScheduler slower BUT # ! improves the reproducibility of experiments by ensuring that all trials # ! are evaluated in the same order. scheduler = HyperBandScheduler(max_t=scheduler_max_t) + # define search algorithm with search space + # partial function needed to pass additional parameters + define_search_space = partial( + ss.get_search_space, model_type=exp_name, tax=tax, test_mode=test_mode + ) + + search_algo = OptunaSearch( + space=define_search_space, seed=seed_model, metric=metric, mode=mode + ) + storage_path = os.path.abspath(path2exp) experiment_tag = os.path.basename(path2exp) # define callbacks @@ -161,26 +188,23 @@ def run_trials( # ! checkpoint: to store best model - is retrieved in # evaluate_models.py checkpoint_config=air.CheckpointConfig( - checkpoint_score_attribute="rmse_val", - checkpoint_score_order="min", + checkpoint_score_attribute=metric, + checkpoint_score_order=mode, num_to_keep=3, ), callbacks=callbacks, ), - # hyperparameter space: passes config used in trainables - param_space=search_space, tune_config=tune.TuneConfig( - metric="rmse_val", - mode="min", + metric=metric, + mode=mode, # define the scheduler scheduler=scheduler, # number of trials to run - schedulers might decide to run more trials num_samples=num_trials, - # # todo: remove below or change search_alg - # max_concurrent_trials=max_concurrent_trials, - # ! set seed - # todo: set advanced search algo -> here default random - search_alg=tune.search.BasicVariantGenerator(), + # set max concurrent trials to launch + max_concurrent_trials=max_concurrent_trials, + # define search algorithm + search_alg=search_algo, ), ) # ResultGrid output @@ -206,6 +230,7 @@ def run_all_trials( mlflow_uri: str, path_exp: str, num_trials: int, + max_concurrent_trials: int, model_types: list = [ "xgb", "nn_reg", @@ -219,7 +244,6 @@ def run_all_trials( test_mode: bool = False, ) -> dict: results_all = {} - model_search_space = ss.get_search_space(tax, test_mode) # if tax + phylogeny empty we can't run trac if (tax.empty or tree_phylo.children == []) and "trac" in model_types: @@ -237,7 +261,7 @@ def run_all_trials( mlflow_uri, model, model_trainables[model], - model_search_space[model], + test_mode, train_val, target, host_id, @@ -247,6 +271,7 @@ def run_all_trials( tree_phylo, path_exp, num_trials, + max_concurrent_trials, fully_reproducible=fully_reproducible, ) results_all[model] = result