diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 048e2275..38e2c473 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: - uses: actions/checkout@v3 - uses: conda-incubator/setup-miniconda@v2 with: - activate-environment: yaib_updated + activate-environment: yaib environment-file: environment.yml auto-activate-base: false - name: Lint with flake8 @@ -33,11 +33,11 @@ jobs: # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # the GitHub editor is 127 chars wide - flake8 . --count --max-complexity=14 --max-line-length=127 --statistics + flake8 . --count --max-complexity=30 --max-line-length=127 --statistics # - name: Test with pytest # run: python -m pytest ./tests/recipes # If we want to test running the tool later on # - name: Setup package # run: pip install -e . # - name: Test command line tool - # run: python -m icu_benchmarks.run --help \ No newline at end of file + # run: python -m icu_benchmarks.run --help diff --git a/.gitignore b/.gitignore index 92bd0a44..a0ab243e 100644 --- a/.gitignore +++ b/.gitignore @@ -127,4 +127,5 @@ wandb/ .vscode/launch.json yaib_logs/ *.ckpt -*.csv \ No newline at end of file +*.csv +!demo_data/*/*/attrition.csv \ No newline at end of file diff --git a/README.md b/README.md index d001dfa4..5700f243 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,8 @@ [//]: # (TODO: add coverage once we have some tests ) -Yet another ICU benchmark (YAIB) provides a framework for doing clinical machine learning experiments on Intensive Care Unit ( -ICU) EHR data. +Yet another ICU benchmark (YAIB) provides a framework for doing clinical machine learning experiments on Intensive Care Unit +(ICU) EHR data. We support the following datasets out of the box: @@ -43,7 +43,7 @@ We provide five common tasks for clinical prediction by default: | 5 | Length of Stay (LoS) | Hourly (within 7D) | Regression | New tasks can be easily added. -For the purposes of getting started right away, we include the eICU and MIMIC-III demo datasets in our repository. +To get started right away, we include the eICU and MIMIC-III demo datasets in our repository. The following repositories may be relevant as well: @@ -51,27 +51,25 @@ The following repositories may be relevant as well: - [YAIB-models](https://github.com/rvandewater/YAIB-models): Pretrained models for YAIB. - [ReciPys](https://github.com/rvandewater/ReciPys): Preprocessing package for YAIB pipelines. -For all YAIB related repositories, please see: https://github.com/stars/rvandewater/lists/yaib. +For all YAIB-related repositories, please see: https://github.com/stars/rvandewater/lists/yaib. # 📄Paper -To reproduce the benchmarks in our paper, we refer to: the [ML reproducibility document](PAPER.md). +To reproduce the benchmarks in our paper, we refer to the [ML reproducibility document](PAPER.md). If you use this code in your research, please cite the following publication: ``` -@article{vandewaterYetAnotherICUBenchmark2023, - title = {Yet Another ICU Benchmark: A Flexible Multi-Center Framework for Clinical ML}, - shorttitle = {Yet Another ICU Benchmark}, - url = {http://arxiv.org/abs/2306.05109}, - language = {en}, - urldate = {2023-06-09}, - publisher = {arXiv}, - author = {Robin van de Water and Hendrik Schmidt and Paul Elbers and Patrick Thoral and Bert Arnrich and Patrick Rockenschaub}, - month = jun, - year = {2023}, - note = {arXiv:2306.05109 [cs]}, - keywords = {Computer Science - Machine Learning}, +@inproceedings{vandewaterYetAnotherICUBenchmark2024, + title = {Yet Another ICU Benchmark: A Flexible Multi-Center Framework for Clinical ML}, + shorttitle = {Yet Another ICU Benchmark}, + booktitle = {The Twelfth International Conference on Learning Representations}, + author = {van de Water, Robin and Schmidt, Hendrik Nils Aurel and Elbers, Paul and Thoral, Patrick and Arnrich, Bert and Rockenschaub, Patrick}, + year = {2024}, + month = oct, + urldate = {2024-02-19}, + langid = {english}, } + ``` This paper can also be found on arxiv [2306.05109](https://arxiv.org/abs/2306.05109) @@ -182,17 +180,16 @@ load existing cache files. ``` -icu-benchmarks train \ +icu-benchmarks \ -d demo_data/mortality24/mimic_demo \ -n mimic_demo \ -t BinaryClassification \ -tn Mortality24 \ -m LGBMClassifier \ -hp LGBMClassifier.min_child_samples=10 \ - --generate_cache + --generate_cache \ --load_cache \ --seed 2222 \ - -s 2222 \ -l ../yaib_logs/ \ --tune ``` @@ -224,13 +221,14 @@ wandb agent > Note: You will need to have a wandb account and be logged in to run the above commands. -## Evaluate +## Evaluate or Finetune -It is possible to evaluate a model trained on another dataset. In this case, the source dataset is the demo data from MIMIC and -the target is the eICU demo: +It is possible to evaluate a model trained on another dataset and no additional training is done. +In this case, the source dataset is the demo data from MIMIC and the target is the eICU demo: ``` -icu-benchmarks evaluate \ +icu-benchmarks \ + --eval \ -d demo_data/mortality24/eicu_demo \ -n eicu_demo \ -t BinaryClassification \ @@ -241,9 +239,11 @@ icu-benchmarks evaluate \ -s 2222 \ -l ../yaib_logs \ -sn mimic \ - --source-dir ../yaib_logs/mimic_demo/Mortality24/LGBMClassifier/2022-12-12T15-24-46/fold_0 + --source-dir ../yaib_logs/mimic_demo/Mortality24/LGBMClassifier/2022-12-12T15-24-46/repetition_0/fold_0 ``` +> A similar syntax is used for finetuning, where a model is loaded and then retrained. To run finetuning, replace `--eval` with `-ft`. + ## Models We provide several existing machine learning models that are commonly used for multivariate time-series data. @@ -275,6 +275,8 @@ We appreciate contributions to the project. Please read the [contribution guidel request. # Acknowledgements +This project has been developed partially under the funding of “Gemeinsamer Bundesausschuss (G-BA) Innovationsausschuss” in the framework of “CASSANDRA - Clinical ASSist AND aleRt Algorithms”. +(project number 01VSF20015). We would like to acknowledge the work of Alisher Turubayev, Anna Shopova, Fabian Lange, Mahmut Kamalak, Paul Mattes, and Victoria Ayvasky for adding Pytorch Lightning, Weights and Biases compatibility, and several optional imputation methods to a later version of the benchmark repository. We do not own any of the datasets used in this benchmark. This project uses heavily adapted components of the [HiRID benchmark](https://github.com/ratschlab/HIRID-ICU-Benchmark/). We thank the authors for providing this codebase and diff --git a/configs/experiments/LGBM_Mortality.gin b/configs/experiments/LGBM_Mortality.gin deleted file mode 100644 index 2e5abf3b..00000000 --- a/configs/experiments/LGBM_Mortality.gin +++ /dev/null @@ -1,7 +0,0 @@ -include "configs/tasks/BinaryClassification.gin" -include "configs/models/LGBMClassifier.gin" - -model/hyperparameter.max_depth = 7 -model/hyperparameter.num_leaves = 32 -model/hyperparameter.subsample = 1 -model/hyperparameter.colsample_bytree = 0.66 diff --git a/configs/experiments/LSTM_Mortality.gin b/configs/experiments/LSTM_Mortality.gin deleted file mode 100644 index 7226f37f..00000000 --- a/configs/experiments/LSTM_Mortality.gin +++ /dev/null @@ -1,9 +0,0 @@ -include "configs/tasks/BinaryClassification.gin" -include "configs/models/LSTM.gin" - -# Optimizer params -optimizer/hyperparameter.lr = 1e-4 - -# Encoder params -model/hyperparameter.hidden_dim = 128 -model/hyperparameter.layer_dim = 1 diff --git a/configs/prediction_models/BRFClassifier.gin b/configs/prediction_models/BRFClassifier.gin new file mode 100644 index 00000000..3682bb8c --- /dev/null +++ b/configs/prediction_models/BRFClassifier.gin @@ -0,0 +1,18 @@ +# Settings for ImbLearn Balanced Random Forest Classifier. + +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @BRFClassifier + +model/hyperparameter.class_to_tune = @BRFClassifier +model/hyperparameter.n_estimators = [50, 100, 250, 500, 750,1000,1500] +model/hyperparameter.max_depth = [3, 5, 10, 15] +model/hyperparameter.min_samples_split = (2, 5, 10) +model/hyperparameter.min_samples_leaf = (1, 2, 4) +model/hyperparameter.max_features = ['sqrt', 'log2', 1.0] +model/hyperparameter.bootstrap = [True, False] +model/hyperparameter.class_weight = [None, 'balanced'] + + diff --git a/configs/prediction_models/CBClassifier.gin b/configs/prediction_models/CBClassifier.gin new file mode 100644 index 00000000..e9abbecd --- /dev/null +++ b/configs/prediction_models/CBClassifier.gin @@ -0,0 +1,15 @@ +# Settings for Catboost classifier. + +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @CBClassifier + +model/hyperparameter.class_to_tune = @CBClassifier +model/hyperparameter.learning_rate = (1e-4, 0.5, "log") +model/hyperparameter.num_trees = [50, 100, 250, 500, 750,1000,1500] +model/hyperparameter.depth = [3, 5, 10, 15] +model/hyperparameter.scale_pos_weight = [1, 5, 10, 25, 50, 75, 99, 100, 1000] +model/hyperparameter.border_count = [5, 10, 20, 50, 100, 200] +model/hyperparameter.l2_leaf_reg = [1, 3, 5, 7, 9] \ No newline at end of file diff --git a/configs/prediction_models/GRU.gin b/configs/prediction_models/GRU.gin index d2a28a79..43cb0218 100644 --- a/configs/prediction_models/GRU.gin +++ b/configs/prediction_models/GRU.gin @@ -9,11 +9,11 @@ train_common.model = @GRUNet # Optimizer params optimizer/hyperparameter.class_to_tune = @Adam optimizer/hyperparameter.weight_decay = 1e-6 -optimizer/hyperparameter.lr = (1e-5, 3e-4) +optimizer/hyperparameter.lr = (1e-6, 1e-4, "log") # Encoder params model/hyperparameter.class_to_tune = @GRUNet model/hyperparameter.num_classes = %NUM_CLASSES -model/hyperparameter.hidden_dim = (32, 256, "log-uniform", 2) -model/hyperparameter.layer_dim = (1, 3) +model/hyperparameter.hidden_dim = (32, 512, "log") +model/hyperparameter.layer_dim = (1, 10) diff --git a/configs/prediction_models/LGBMClassifier.gin b/configs/prediction_models/LGBMClassifier.gin index b7bfec9a..f29f40cc 100644 --- a/configs/prediction_models/LGBMClassifier.gin +++ b/configs/prediction_models/LGBMClassifier.gin @@ -11,6 +11,6 @@ model/hyperparameter.colsample_bytree = (0.33, 1.0) model/hyperparameter.max_depth = (3, 7) model/hyperparameter.min_child_samples = 1000 model/hyperparameter.n_estimators = 100000 -model/hyperparameter.num_leaves = (8, 128, "log-uniform", 2) +model/hyperparameter.num_leaves = (8, 128, "log", 2) model/hyperparameter.subsample = (0.33, 1.0) model/hyperparameter.subsample_freq = 1 diff --git a/configs/prediction_models/RFClassifier.gin b/configs/prediction_models/RFClassifier.gin index 72d03e66..61d627d6 100644 --- a/configs/prediction_models/RFClassifier.gin +++ b/configs/prediction_models/RFClassifier.gin @@ -8,11 +8,11 @@ train_common.model = @RFClassifier model/hyperparameter.class_to_tune = @RFClassifier model/hyperparameter.n_estimators = (10, 50, 100, 200, 500) -model/hyperparameter.max_depth = (None, 5, 10, 20) +model/hyperparameter.max_depth = (5, 10, 20) model/hyperparameter.min_samples_split = (2, 5, 10) model/hyperparameter.min_samples_leaf = (1, 2, 4) -model/hyperparameter.max_features = ('sqrt', 'log2', None) -model/hyperparameter.bootstrap = (True, False) -model/hyperparameter.class_weight = (None, 'balanced') +model/hyperparameter.max_features = ['sqrt', 'log2', None] +model/hyperparameter.bootstrap = [True, False] +model/hyperparameter.class_weight = [None, 'balanced'] diff --git a/configs/prediction_models/RUSBClassifier.gin b/configs/prediction_models/RUSBClassifier.gin new file mode 100644 index 00000000..e8f17722 --- /dev/null +++ b/configs/prediction_models/RUSBClassifier.gin @@ -0,0 +1,14 @@ +# Settings for ImbLearn Balanced Random Forest Classifier. + +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @RUSBClassifier + +model/hyperparameter.class_to_tune = @RUSBClassifier +model/hyperparameter.n_estimators = (10, 50, 100, 200, 500) +model/hyperparameter.learning_rate = (0.005, 1, "log") +model/hyperparameter.sampling_strategy = "auto" + + diff --git a/configs/prediction_models/TCN.gin b/configs/prediction_models/TCN.gin index c6b314db..d1cb748a 100644 --- a/configs/prediction_models/TCN.gin +++ b/configs/prediction_models/TCN.gin @@ -9,12 +9,12 @@ train_common.model = @TemporalConvNet # Optimizer params optimizer/hyperparameter.class_to_tune = @Adam optimizer/hyperparameter.weight_decay = 1e-6 -optimizer/hyperparameter.lr = (1e-5, 3e-4) +optimizer/hyperparameter.lr = (1e-6, 3e-4) # Encoder params model/hyperparameter.class_to_tune = @TemporalConvNet model/hyperparameter.num_classes = %NUM_CLASSES model/hyperparameter.max_seq_length = %HORIZON -model/hyperparameter.num_channels = (32, 256, "log-uniform", 2) -model/hyperparameter.kernel_size = (2, 32, "log-uniform", 2) +model/hyperparameter.num_channels = (32, 256, "log") +model/hyperparameter.kernel_size = (2, 128, "log") model/hyperparameter.dropout = (0.0, 0.4) diff --git a/configs/prediction_models/Transformer.gin b/configs/prediction_models/Transformer.gin index 2767fd37..69f31e51 100644 --- a/configs/prediction_models/Transformer.gin +++ b/configs/prediction_models/Transformer.gin @@ -8,17 +8,17 @@ train_common.model = @Transformer optimizer/hyperparameter.class_to_tune = @Adam optimizer/hyperparameter.weight_decay = 1e-6 -optimizer/hyperparameter.lr = (1e-5, 3e-4) +optimizer/hyperparameter.lr = (1e-6, 1e-4) # Encoder params model/hyperparameter.class_to_tune = @Transformer -model/hyperparameter.ff_hidden_mult = 2 -model/hyperparameter.l1_reg = 0.0 +model/hyperparameter.ff_hidden_mult = (2,4,6,8) +model/hyperparameter.l1_reg = (0.0,1.0) model/hyperparameter.num_classes = %NUM_CLASSES -model/hyperparameter.hidden = (32, 256, "log-uniform", 2) -model/hyperparameter.heads = (1, 8, "log-uniform", 2) +model/hyperparameter.hidden = (32, 512, "log") +model/hyperparameter.heads = (1, 8, "log") model/hyperparameter.depth = (1, 3) -model/hyperparameter.dropout = (0.0, 0.4) -model/hyperparameter.dropout_att = (0.0, 0.4) +model/hyperparameter.dropout = 0 # no improvement (0.0, 0.4) +model/hyperparameter.dropout_att = (0.0, 1.0) diff --git a/configs/prediction_models/XGBClassifier.gin b/configs/prediction_models/XGBClassifier.gin new file mode 100644 index 00000000..f1070672 --- /dev/null +++ b/configs/prediction_models/XGBClassifier.gin @@ -0,0 +1,17 @@ +# Settings for XGBoost classifier. + +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @XGBClassifier + +model/hyperparameter.class_to_tune = @XGBClassifier +model/hyperparameter.learning_rate = (0.01, 0.1, "log") +model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000,1500,2000] +model/hyperparameter.max_depth = [3, 5, 10, 15] +model/hyperparameter.scale_pos_weight = [1, 5, 10, 15, 20, 25, 30, 35, 40, 50, 75, 99, 100, 1000] +model/hyperparameter.min_child_weight = [1, 0.5] +model/hyperparameter.max_delta_step = [0, 1, 2, 3, 4, 5, 10] +model/hyperparameter.colsample_bytree = [0.1, 0.25, 0.5, 0.75, 1.0] +model/hyperparameter.eval_metric = "aucpr" \ No newline at end of file diff --git a/configs/prediction_models/common/DLCommon.gin b/configs/prediction_models/common/DLCommon.gin index c220e6ab..9d790775 100644 --- a/configs/prediction_models/common/DLCommon.gin +++ b/configs/prediction_models/common/DLCommon.gin @@ -3,7 +3,9 @@ # Imports to register the models import gin.torch.external_configurables import icu_benchmarks.models.wrappers -import icu_benchmarks.models.dl_models +import icu_benchmarks.models.dl_models.rnn +import icu_benchmarks.models.dl_models.transformer +import icu_benchmarks.models.dl_models.tcn import icu_benchmarks.models.utils # Do not generate features from dynamic data @@ -12,7 +14,7 @@ base_regression_preprocessor.generate_features = False # Train params train_common.optimizer = @Adam -train_common.epochs = 1000 +train_common.epochs = 50 train_common.batch_size = 64 train_common.patience = 10 train_common.min_delta = 1e-4 diff --git a/configs/prediction_models/common/DLTuning.gin b/configs/prediction_models/common/DLTuning.gin index b4d13e12..0d71c2f8 100644 --- a/configs/prediction_models/common/DLTuning.gin +++ b/configs/prediction_models/common/DLTuning.gin @@ -2,4 +2,4 @@ tune_hyperparameters.scopes = ["model", "optimizer"] tune_hyperparameters.n_initial_points = 5 tune_hyperparameters.n_calls = 30 -tune_hyperparameters.folds_to_tune_on = 2 \ No newline at end of file +tune_hyperparameters.folds_to_tune_on = 5 \ No newline at end of file diff --git a/configs/prediction_models/common/MLCommon.gin b/configs/prediction_models/common/MLCommon.gin index 460bceba..4d26b8c7 100644 --- a/configs/prediction_models/common/MLCommon.gin +++ b/configs/prediction_models/common/MLCommon.gin @@ -3,7 +3,11 @@ # Imports to register the models import gin.torch.external_configurables import icu_benchmarks.models.wrappers -import icu_benchmarks.models.ml_models +import icu_benchmarks.models.ml_models.sklearn +import icu_benchmarks.models.ml_models.lgbm +import icu_benchmarks.models.ml_models.xgboost +import icu_benchmarks.models.ml_models.imblearn +import icu_benchmarks.models.ml_models.catboost import icu_benchmarks.models.utils # Patience for early stopping diff --git a/configs/prediction_models/common/MLTuning.gin b/configs/prediction_models/common/MLTuning.gin index c582a02d..9df38c47 100644 --- a/configs/prediction_models/common/MLTuning.gin +++ b/configs/prediction_models/common/MLTuning.gin @@ -1,5 +1,5 @@ # Hyperparameter tuner settings for classical Machine Learning. tune_hyperparameters.scopes = ["model"] -tune_hyperparameters.n_initial_points = 10 -tune_hyperparameters.n_calls = 50 -tune_hyperparameters.folds_to_tune_on = 3 \ No newline at end of file +tune_hyperparameters.n_initial_points = 5 +tune_hyperparameters.n_calls = 30 +tune_hyperparameters.folds_to_tune_on = 5 \ No newline at end of file diff --git a/configs/tasks/BinaryClassification.gin b/configs/tasks/BinaryClassification.gin index 492a12eb..f86436a4 100644 --- a/configs/tasks/BinaryClassification.gin +++ b/configs/tasks/BinaryClassification.gin @@ -19,11 +19,10 @@ DLPredictionWrapper.loss = @cross_entropy # SELECTING PREPROCESSOR preprocess.preprocessor = @base_classification_preprocessor +preprocess.modality_mapping = %modality_mapping preprocess.vars = %vars preprocess.use_static = True # SELECTING DATASET -PredictionDataset.vars = %vars -PredictionDataset.ram_cache = True - +include "configs/tasks/common/Dataloader.gin" diff --git a/configs/tasks/DatasetImputation.gin b/configs/tasks/DatasetImputation.gin index ddbd56a2..55914adc 100644 --- a/configs/tasks/DatasetImputation.gin +++ b/configs/tasks/DatasetImputation.gin @@ -22,6 +22,6 @@ preprocess.file_names = { preprocess.preprocessor = @base_imputation_preprocessor preprocess.vars = %vars -ImputationDataset.vars = %vars -ImputationDataset.ram_cache = True + +include "configs/tasks/common/Dataloader.gin" diff --git a/configs/tasks/Regression.gin b/configs/tasks/Regression.gin index 5cf3f8d9..c2c54174 100644 --- a/configs/tasks/Regression.gin +++ b/configs/tasks/Regression.gin @@ -28,6 +28,5 @@ base_regression_preprocessor.outcome_min = 0 base_regression_preprocessor.outcome_max = 15 # SELECTING DATASET -PredictionDataset.vars = %vars -PredictionDataset.ram_cache = True +include "configs/tasks/common/Dataloader.gin" diff --git a/configs/tasks/common/Dataloader.gin b/configs/tasks/common/Dataloader.gin new file mode 100644 index 00000000..6bed1b7e --- /dev/null +++ b/configs/tasks/common/Dataloader.gin @@ -0,0 +1,8 @@ +# Prediction +PredictionPandasDataset.vars = %vars +PredictionPandasDataset.ram_cache = True +PredictionPolarsDataset.vars = %vars +PredictionPolarsDataset.ram_cache = True +# Imputation +ImputationPandasDataset.vars = %vars +ImputationPandasDataset.ram_cache = True \ No newline at end of file diff --git a/configs/tasks/common/PredictionTaskVariables.gin b/configs/tasks/common/PredictionTaskVariables.gin index 6e38638e..d5006041 100644 --- a/configs/tasks/common/PredictionTaskVariables.gin +++ b/configs/tasks/common/PredictionTaskVariables.gin @@ -15,4 +15,12 @@ vars = { "methb", "mg", "na", "neut", "o2sat", "pco2", "ph", "phos", "plt", "po2", "ptt", "resp", "sbp", "temp", "tnt", "urine", "wbc"], "STATIC": ["age", "sex", "height", "weight"], +} + +modality_mapping = { + "DYNAMIC": ["alb", "alp", "alt", "ast", "be", "bicar", "bili", "bili_dir", "bnd", "bun", "ca", "cai", "ck", "ckmb", "cl", + "crea", "crp", "dbp", "fgn", "fio2", "glu", "hgb", "hr", "inr_pt", "k", "lact", "lymph", "map", "mch", "mchc", "mcv", + "methb", "mg", "na", "neut", "o2sat", "pco2", "ph", "phos", "plt", "po2", "ptt", "resp", "sbp", "temp", "tnt", "urine", + "wbc"], + "STATIC": ["age", "sex", "height", "weight"], } \ No newline at end of file diff --git a/demo_data/kf/eicu_demo/attrition.csv b/demo_data/kidney_function/eicu_demo/attrition.csv similarity index 100% rename from demo_data/kf/eicu_demo/attrition.csv rename to demo_data/kidney_function/eicu_demo/attrition.csv diff --git a/demo_data/kf/eicu_demo/dyn.parquet b/demo_data/kidney_function/eicu_demo/dyn.parquet similarity index 100% rename from demo_data/kf/eicu_demo/dyn.parquet rename to demo_data/kidney_function/eicu_demo/dyn.parquet diff --git a/demo_data/kf/eicu_demo/outc.parquet b/demo_data/kidney_function/eicu_demo/outc.parquet similarity index 100% rename from demo_data/kf/eicu_demo/outc.parquet rename to demo_data/kidney_function/eicu_demo/outc.parquet diff --git a/demo_data/kf/eicu_demo/sta.parquet b/demo_data/kidney_function/eicu_demo/sta.parquet similarity index 100% rename from demo_data/kf/eicu_demo/sta.parquet rename to demo_data/kidney_function/eicu_demo/sta.parquet diff --git a/demo_data/kf/mimic_demo/attrition.csv b/demo_data/kidney_function/mimic_demo/attrition.csv similarity index 100% rename from demo_data/kf/mimic_demo/attrition.csv rename to demo_data/kidney_function/mimic_demo/attrition.csv diff --git a/demo_data/kf/mimic_demo/dyn.parquet b/demo_data/kidney_function/mimic_demo/dyn.parquet similarity index 100% rename from demo_data/kf/mimic_demo/dyn.parquet rename to demo_data/kidney_function/mimic_demo/dyn.parquet diff --git a/demo_data/kf/mimic_demo/outc.parquet b/demo_data/kidney_function/mimic_demo/outc.parquet similarity index 100% rename from demo_data/kf/mimic_demo/outc.parquet rename to demo_data/kidney_function/mimic_demo/outc.parquet diff --git a/demo_data/kf/mimic_demo/sta.parquet b/demo_data/kidney_function/mimic_demo/sta.parquet similarity index 100% rename from demo_data/kf/mimic_demo/sta.parquet rename to demo_data/kidney_function/mimic_demo/sta.parquet diff --git a/docs/adding_model/RNN.gin b/docs/adding_model/RNN.gin new file mode 100644 index 00000000..531aeff6 --- /dev/null +++ b/docs/adding_model/RNN.gin @@ -0,0 +1,18 @@ +# Settings for Recurrent Neural Network (RNN) models. + +# Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" + +# Train params +train_common.model = @RNNet + +# Optimizer params +optimizer/hyperparameter.class_to_tune = @Adam +optimizer/hyperparameter.weight_decay = 1e-6 +optimizer/hyperparameter.lr = (1e-5, 3e-4) + +# Encoder params +model/hyperparameter.class_to_tune = @RNNet +model/hyperparameter.num_classes = %NUM_CLASSES +model/hyperparameter.hidden_dim = (32, 256, "log-uniform", 2) +model/hyperparameter.layer_dim = (1, 3) diff --git a/docs/adding_model/instructions.md b/docs/adding_model/instructions.md new file mode 100644 index 00000000..0896fedd --- /dev/null +++ b/docs/adding_model/instructions.md @@ -0,0 +1,190 @@ +# Adding new models to YAIB +## Example +We refer to the page [adding a new model](https://github.com/rvandewater/YAIB/wiki/Adding-a-new-model) for detailed instructions on adding new models. +We allow prediction models to be easily added and integrated into a Pytorch Lightning module. This +incorporates advanced logging and debugging capabilities, as well as +built-in parallelism. Our interface derives from the [`BaseModule`](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html). + +Adding a model consists of three steps: +1. Add a model through the existing `MLPredictionWrapper` or `DLPredictionWrapper`. +2. Add a GIN config file to bind hyperparameters. +3. Execute YAIB using a simple command. + +This folder contains everything you need to add a model to YAIB. +Putting the `RNN.gin` file in `configs/prediction_models` and the `rnn.py` file into icu_benchmarks/models allows you to run the model fully. + +``` +icu-benchmarks train \ + -d demo_data/mortality24/mimic_demo \ # Insert cohort dataset here + -n mimic_demo \ + -t BinaryClassification \ # Insert task name here + -tn Mortality24 \ + --log-dir ../yaib_logs/ \ + -m RNN \ # Insert model here + -s 2222 \ + -l ../yaib_logs/ \ + --tune +``` +# Adding more models +## Regular ML +For standard Scikit-Learn type models (e.g., LGBM), one can +simply wrap `MLPredictionWrapper` the function with minimal code +overhead. Many ML (and some DL) models can be incorporated this way, requiring minimal code additions. See below. + +``` {#code:ml-model-definition frame="single" style="pycharm" caption="\\textit{Example ML model definition}" label="code:ml-model-definition" columns="fullflexible"} +@gin.configurable +class RFClassifier(MLWrapper): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model = self.model_args() + + @gin.configurable(module="RFClassifier") + def model_args(self, *args, **kwargs): + return RandomForestClassifier(*args, **kwargs) +``` +## Adding DL models +It is relatively straightforward to add new Pytorch models to YAIB. We first provide a standard RNN-model which needs no extra components. Then, we show the implementation of the Temporal Fusion Transformer model. + +### Standard RNN-model +The definition of dl models can be done by creating a subclass from the +`DLPredictionWrapper`, inherits the standard methods needed for +training dl learning models. Pytorch Lightning significantly reduces the code +overhead. + + +``` {#code:dl-model-definition frame="single" style="pycharm" caption="\\textit{Example DL model definition}" label="code:dl-model-definition" columns="fullflexible"} +@gin.configurable +class RNNet(DLPredictionWrapper): + """Torch standard RNN model""" + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs + ) + self.hidden_dim = hidden_dim + self.layer_dim = layer_dim + self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.logit = nn.Linear(hidden_dim, num_classes) + + def init_hidden(self, x): + h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + return h0 + + def forward(self, x): + h0 = self.init_hidden(x) + out, hn = self.rnn(x, h0) + pred = self.logit(out) + return pred +``` +### Adding a SOTA model: Temporal Fusion Transformer +There are two main questions when you want to add a more complex model: + +* _Do you want to manually define the model or use an existing library?_ This might require adapting the `DLPredictionWrapper`. +* _Does the model expect the data to be in a certain format?_ This might require adapting the `PredictionDataset`. + +By adapting, we mean creating a new subclass that inherits most functionality to avoid code duplication, is future-proof, and follows good coding practices. + +First, you can add modules to `models/layers.py` to use them for your model. +``` {#code:building blocks frame="single" style="pycharm" caption="\\textit{Example building block}" label="code: layers" columns="fullflexible"} +class StaticCovariateEncoder(nn.Module): + """ + Network to produce 4 context vectors to enrich static variables + Variable selection Network --> GRNs + """ + + def __init__(self, num_static_vars, hidden, dropout): + super().__init__() + self.vsn = VariableSelectionNetwork(hidden, dropout, num_static_vars) + self.context_grns = nn.ModuleList([GRN(hidden, hidden, dropout=dropout) for _ in range(4)]) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + variable_ctx, sparse_weights = self.vsn(x) + + # Context vectors: + # variable selection context + # enrichment context + # state_c context + # state_h context + cs, ce, ch, cc = [m(variable_ctx) for m in self.context_grns] + + return cs, ce, ch, cc +``` +Note that we can create modules out of modules as well. + +### Adapting the `DLPredictionWrapper` +The next step is to use the building blocks defined in layers.py or modules from an existing library to add to the model in `models/dl_models.py`. In this In this case, we use the Pytorch-forecasting library (https://github.com/jdb78/pytorch-forecasting): + +``` {#code:dl-model-definition frame="single" style="pycharm" caption="\\textit{Example DL model definition}" label="code:dl-model-definition" columns="fullflexible"} +class TFTpytorch(DLPredictionWrapper): + + supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__(self, dataset, hidden, dropout, n_heads, dropout_att, lr, optimizer, num_classes, *args, **kwargs): + super().__init__(lr=lr, optimizer=optimizer, *args, **kwargs) + self.model = TemporalFusionTransformer.from_dataset( + dataset=dataset) + self.logit = nn.Linear(7, num_classes) + + + def forward(self, x): + out = self.model(x) + pred = self.logit(out["prediction"]) + return pred +``` + +### Adapting the `PredictionDataset` +Some models require an adjusted dataloader to facilitate, for example, explainability methods. In this case, changes need to be made to the `data/loader.py` file to ensure the data loader returns the data in the correct format. +This can be done by creating a class that inherits from PredictionDataset and editing the get_item method. +``` {#code:dataset frame="single" style="pycharm" caption="\\textit{Example custom dataset definition}" label="code: dataset" columns="fullflexible"} +@gin.configurable("PredictionDatasetTFT") +class PredictionDatasetTFT(PredictionDataset): + def __init__(self, *args, ram_cache: bool = True, **kwargs): + super().__init__(*args, ram_cache=True, **kwargs) + +def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]: + """Function to sample from the data split of choice. Used for TFT. + The data needs to be given to the model in the following order + [static categorical, static continuous,known categorical,known continuous, observed categorical, observed continuous,target,id] +``` +Then, you must check `models/wrapper.py`, particularly the step_fn method, to ensure the data is correctly transferred to the device. + +## Adding the model config GIN file +To define hyperparameters for each model in a standardized manner, we use GIN-config. We need to specify a GIN file to bind the parameters to train and optimize this model from a choice of hyperparameters. Note that we can use modifiers for the optimizer (e.g, Adam optimizer) and ranges that we can specify in rounded brackets "()". Square brackets, "[]", result in a random choice where the variable is uniformly sampled. +``` +# Hyperparameters for TFT model. + +# Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" + +# Optimizer params +train_common.model = @TFT + +optimizer/hyperparameter.class_to_tune = @Adam +optimizer/hyperparameter.weight_decay = 1e-6 +optimizer/hyperparameter.lr = (1e-5, 3e-4) + +# Encoder params +model/hyperparameter.class_to_tune = @TFT +model/hyperparameter.encoder_length = 24 +model/hyperparameter.hidden = 256 +model/hyperparameter.num_classes = %NUM_CLASSES +model/hyperparameter.dropout = (0.0, 0.4) +model/hyperparameter.dropout_att = (0.0, 0.4) +model/hyperparameter.n_heads =4 +model/hyperparameter.example_length=25 +``` +## Training the model +After these steps, your model should be trainable with the following command: + +``` +icu-benchmarks train \ + -d demo_data/mortality24/mimic_demo \ # Insert cohort dataset here + -n mimic_demo \ + -t BinaryClassification \ # Insert task name here + -tn Mortality24 \ + --log-dir ../yaib_logs/ \ + -m TFT \ # Insert model here + -s 2222 \ + -l ../yaib_logs/ \ + --tune +``` diff --git a/docs/adding_model/rnn.py b/docs/adding_model/rnn.py new file mode 100644 index 00000000..d2215627 --- /dev/null +++ b/docs/adding_model/rnn.py @@ -0,0 +1,30 @@ +import gin +import torch.nn as nn +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.wrappers import DLPredictionWrapper + + +@gin.configurable +class RNNet(DLPredictionWrapper): + """Torch standard RNN model""" + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs + ) + self.hidden_dim = hidden_dim + self.layer_dim = layer_dim + self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.logit = nn.Linear(hidden_dim, num_classes) + + def init_hidden(self, x): + h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + return h0 + + def forward(self, x): + h0 = self.init_hidden(x) + out, hn = self.rnn(x, h0) + pred = self.logit(out) + return pred diff --git a/docs/imputation_methods.md b/docs/imputation_methods.md index 2eead00b..26e73058 100644 --- a/docs/imputation_methods.md +++ b/docs/imputation_methods.md @@ -7,28 +7,29 @@ To add another imputation model, you have to create a class that inherits from ` from icu_benchmarks.models.wrappers import ImputationWrapper import gin + @gin.configurable("newmethod") class New_Method(ImputationWrapper): - - # adjust this accordingly - needs_training = False # if true, the method is trained iteratively (like a deep learning model) - needs_fit = True # if true, it receives the complete training data to perform a fit on - - def __init__(self, *args, model_arg1, model_arg2, **kwargs): - super().__init__(*args, **kwargs) - # define your new model here - self.model = ... - - # the following method has to be implemented for all methods - def forward(self, amputated_values, amputation_mask): - imputated_values = amputated_values - ... - return imputated_values - - # implement this, if needs_fit is true, otherwise you can leave it out. - # this method receives the complete input training data to perform a fit on. - def fit(self, train_data): - ... + # adjust this accordingly + # if true, the method is trained iteratively (like a deep learning model). + # If false it receives the complete training data to perform a fit on + requires_backprop = False + + def __init__(self, *args, model_arg1, model_arg2, **kwargs): + super().__init__(*args, **kwargs) + # define your new model here + self.model = ... + + # the following method has to be implemented for all methods + def forward(self, amputated_values, amputation_mask): + imputated_values = amputated_values + ... + return imputated_values + + # implement this, if needs_fit is true, otherwise you can leave it out. + # this method receives the complete input training data to perform a fit on. + def fit(self, train_data): + ... ``` You also need to create a gin configuration file in the `configs/imputation` directory, diff --git a/environment.yml b/environment.yml index f2db5efc..405d9b47 100644 --- a/environment.yml +++ b/environment.yml @@ -1,37 +1,11 @@ name: yaib channels: - - pytorch - - nvidia - conda-forge - - anaconda dependencies: - python=3.10 - - black=23.3.0 - - coverage=7.2.3 - - flake8=5.0.4 - - matplotlib=3.7.1 - - gin-config=0.5.0 - - ignite=0.4.11 - - pytorch=2.0.1 - - pytorch-cuda=11.8 - - lightgbm=3.3.5 - - numpy=1.24.3 - - pandas=2.0.0 - - pyarrow=11.0.0 - - pytest=7.3.1 - - scikit-learn=1.2.2 - - tensorboard=2.12.2 - - tqdm=4.64.1 - - pytorch-lightning=2.0.3 - - wandb=0.15.4 - - pip=23.1 - - einops=0.6.1 - - hydra-core=1.3 - - pip: - - recipies==0.1.2 - # Fixed version because of NumPy incompatibility and stale development status. - - scikit-optimize-fix==0.9.1 - - hydra-submitit-launcher==1.2.0 -# Note: versioning of Pytorch might be dependent on compatible CUDA version. -# Please check yourself if your Pytorch installation supports cuda (for gpu acceleration) + - pip>=24.0 + - flake8=7.1.0 +# - pip: +# - -r requirements.txt + diff --git a/experiments/benchmark_regression.yml b/experiments/benchmark_regression.yml index 8aa8d13e..1f9176f3 100644 --- a/experiments/benchmark_regression.yml +++ b/experiments/benchmark_regression.yml @@ -21,10 +21,10 @@ parameters: - ../data/los/hirid - ../data/los/eicu - ../data/los/aumc - - ../data/kf/miiv - - ../data/kf/hirid - - ../data/kf/eicu - - ../data/kf/aumc + - ../data/kidney_function/miiv + - ../data/kidney_function/hirid + - ../data/kidney_function/eicu + - ../data/kidney_function/aumc model: values: - ElasticNet diff --git a/experiments/demo_benchmark_regression.yml b/experiments/demo_benchmark_regression.yml index 3b678371..17d25d38 100644 --- a/experiments/demo_benchmark_regression.yml +++ b/experiments/demo_benchmark_regression.yml @@ -19,8 +19,8 @@ parameters: values: - demo_data/los/eicu_demo - demo_data/los/mimic_demo - - demo_data/kf/eicu_demo - - demo_data/kf/mimic_demo + - demo_data/kidney_function/eicu_demo + - demo_data/kidney_function/mimic_demo model: values: - ElasticNet diff --git a/experiments/experiment_eval_pooled.yml b/experiments/experiment_eval_pooled.yml new file mode 100644 index 00000000..a71da879 --- /dev/null +++ b/experiments/experiment_eval_pooled.yml @@ -0,0 +1,56 @@ +# This experiment evaluates the pooled models trained in a previous experiment. +command: + - ${env} + - ${program} + - --eval + - -d + - ../data/ + - -t + - BinaryClassification +# Manually set for regression tasks +# - Regression + - --log-dir + - ../yaib_logs + - --wandb-sweep + - -gc + - -lc + - --source-dir + - path/to/pooled_model +method: grid +name: yaib_pooled_eval +parameters: + data_dir: + values: + - ../data/mortality24/miiv + - ../data/mortality24/hirid + - ../data/mortality24/eicu + - ../data/mortality24/aumc + - ../data/aki/miiv + - ../data/aki/hirid + - ../data/aki/eicu + - ../data/aki/aumc + - ../data/sepsis/miiv + - ../data/sepsis/hirid + - ../data/sepsis/eicu + - ../data/sepsis/aumc + - ../data/los/miiv + - ../data/los/hirid + - ../data/los/eicu + - ../data/los/aumc + - ../data/kidney_function/miiv + - ../data/kidney_function/hirid + - ../data/kidney_function/eicu + - ../data/kidney_function/aumc + model: + values: + - GRU + - LSTM + - TCN + - Transformer + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/experiment_finetuning.yml b/experiments/experiment_finetuning.yml new file mode 100644 index 00000000..db3a5ceb --- /dev/null +++ b/experiments/experiment_finetuning.yml @@ -0,0 +1,50 @@ +# Finetuning setup for model +command: + - ${env} + - ${program} + - -ft + - 0 + - -d + - ../data/ + - -t + - BinaryClassification + - --log-dir + - ../yaib_logs_finetune + - --tune + - --wandb-sweep + - -gc + - -lc + - -sn + - eicu + - --source-dir + - path/to/model/to/finetune +method: grid +name: yaib_finetuning_benchmark +parameters: + fine_tune: + values: + - 100 + - 500 + - 1000 + - 2000 + - 4000 + - 6000 + - 8000 + - 10000 + - 12000 + data_dir: + values: + - ../data/mortality24/miiv + - ../data/mortality24/hirid + - ../data/mortality24/eicu + - ../data/mortality24/aumc + model: + values: + - GRU + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks diff --git a/experiments/experiment_full_training.yml b/experiments/experiment_full_training.yml new file mode 100644 index 00000000..309b128e --- /dev/null +++ b/experiments/experiment_full_training.yml @@ -0,0 +1,50 @@ +# This experiment trains a production-ready model with a full dataset (no cross-validation). +command: + - ${env} + - ${program} + - --full-train + - -d + - ../data/ + - -t + - BinaryClassification +# - Regression + - --log-dir + - ../yaib_logs + - --tune + - --wandb-sweep + - --tune +method: grid +name: yaib_full_benchmark +parameters: + data_dir: + values: + - ../data/mortality24/miiv + - ../data/mortality24/hirid + - ../data/mortality24/eicu + - ../data/mortality24/aumc + - ../data/aki/miiv + - ../data/aki/hirid + - ../data/aki/eicu + - ../data/aki/aumc + - ../data/sepsis/miiv + - ../data/sepsis/hirid + - ../data/sepsis/eicu + - ../data/sepsis/aumc + - ../data/los/miiv + - ../data/los/hirid + - ../data/los/eicu + - ../data/los/aumc + - ../data/kidney_function/miiv + - ../data/kidney_function/hirid + - ../data/kidney_function/eicu + - ../data/kidney_function/aumc + model: + values: + - GRU + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks diff --git a/experiments/experiment_pooled.yml b/experiments/experiment_pooled.yml new file mode 100644 index 00000000..f194dfba --- /dev/null +++ b/experiments/experiment_pooled.yml @@ -0,0 +1,60 @@ +# This experiment trains pooled models with pooled data. Note that you have to replace BinaryClassification with Regression for the appropriate dataset. +command: + - ${env} + - ${program} + - -d + - ../data/ + - -t + - BinaryClassification + # Manually set for regression tasks + # - Regression + - --log-dir + - ../yaib_logs_pooled + - --tune + - --wandb-sweep + - --hp-checkpoint + - path/to/checkpoint + - -gc + - -lc +method: grid +name: yaib_pooled_benchmark +parameters: + data_dir: + values: + - ../data/mortality24/aumc_hirid_eicu_10000 + - ../data/mortality24/hirid_eicu_miiv_10000 + - ../data/mortality24/aumc_eicu_miiv_10000 + - ../data/mortality24/aumc_hirid_miiv_10000 + - ../data/mortality24/aumc_hirid_eicu_miiv_10000 + - ../data/aki/aumc_hirid_eicu_10000 + - ../data/aki/hirid_eicu_miiv_10000 + - ../data/aki/aumc_eicu_miiv_10000 + - ../data/aki/aumc_hirid_miiv_10000 + - ../data/aki/aumc_hirid_eicu_miiv_10000 + - ../data/sepsis/aumc_hirid_eicu_10000 + - ../data/sepsis/hirid_eicu_miiv_10000 + - ../data/sepsis/aumc_eicu_miiv_10000 + - ../data/sepsis/aumc_hirid_miiv_10000 + - ../data/sepsis/aumc_hirid_eicu_miiv_10000 + - ../data/kidney_function/aumc_hirid_eicu_10000 + - ../data/kidney_function/hirid_eicu_miiv_10000 + - ../data/kidney_function/aumc_eicu_miiv_10000 + - ../data/kidney_function/aumc_hirid_miiv_10000 + - ../data/kidney_function/aumc_hirid_eicu_miiv_10000 + - ../data/los/aumc_hirid_eicu_10000 + - ../data/los/hirid_eicu_miiv_10000 + - ../data/los/aumc_eicu_miiv_10000 + - ../data/los/aumc_hirid_miiv_10000 + - ../data/los/aumc_hirid_eicu_miiv_10000 + + model: + values: + # - GRU + - Transformer + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks diff --git a/experiments/experiment_small_set_training.yml b/experiments/experiment_small_set_training.yml new file mode 100644 index 00000000..f8aadb73 --- /dev/null +++ b/experiments/experiment_small_set_training.yml @@ -0,0 +1,46 @@ +# This experiment trains models with progressively more samples to see how the performance changes +command: + - ${env} + - ${program} + - --samples + - 0 + - -d + - ../data/ + - -t + - BinaryClassification + - --log-dir + - ../yaib_logs_small_set_training + - --tune + - --wandb-sweep + - --source-dir + - /dhc/home/robin.vandewater/projects/transfer_learning/gru_mortality/hirid +method: grid +name: yaib_samples_benchmark +parameters: + samples: + values: + - 100 + - 500 + - 1000 + - 2000 + - 4000 + - 6000 + - 8000 + - 10000 + - 12000 + data_dir: + values: + - ../../data/mortality24/miiv + - ../../data/mortality24/hirid + - ../../data/mortality24/eicu + - ../../data/mortality24/aumc + model: + values: + - GRU + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks diff --git a/icu_benchmarks/contants.py b/icu_benchmarks/constants.py similarity index 100% rename from icu_benchmarks/contants.py rename to icu_benchmarks/constants.py diff --git a/icu_benchmarks/cross_validation.py b/icu_benchmarks/cross_validation.py index 89a98864..e1563f9e 100644 --- a/icu_benchmarks/cross_validation.py +++ b/icu_benchmarks/cross_validation.py @@ -11,7 +11,7 @@ from icu_benchmarks.models.train import train_common from icu_benchmarks.models.utils import JsonResultLoggingEncoder from icu_benchmarks.run_utils import log_full_line -from icu_benchmarks.contants import RunMode +from icu_benchmarks.constants import RunMode @gin.configurable @@ -19,6 +19,8 @@ def execute_repeated_cv( data_dir: Path, log_dir: Path, seed: int, + eval_only: bool = False, + train_size: int = None, load_weights: bool = False, source_dir: Path = None, cv_repetitions: int = 5, @@ -35,14 +37,19 @@ def execute_repeated_cv( cpu: bool = False, verbose: bool = False, wandb: bool = False, + complete_train: bool = False, ) -> float: """Preprocesses data and trains a model for each fold. Args: + complete_train: Use the full data for training instead of held out test splits. + wandb: Use wandb for logging. data_dir: Path to the data directory. log_dir: Path to the log directory. seed: Random seed. + eval_only: Whether to only evaluate the model. + train_size: Fixed size of train split (including validation data). load_weights: Whether to load weights from source_dir. source_dir: Path to the source directory. cv_folds: Number of folds for cross validation. @@ -66,10 +73,21 @@ def execute_repeated_cv( if not cv_folds_to_train: cv_folds_to_train = cv_folds agg_loss = 0 - seed_everything(seed, reproducible) + if complete_train: + logging.info("Will train full model without cross validation.") + cv_repetitions_to_train = 1 + cv_folds_to_train = 1 + + else: + logging.info(f"Starting nested CV with {cv_repetitions_to_train} repetitions of {cv_folds_to_train} folds.") + # Train model for each repetition (a manner of splitting the folds) for repetition in range(cv_repetitions_to_train): + # Train model for each fold configuration (i.e, one fold is test fold and the rest are train/val folds) for fold_index in range(cv_folds_to_train): + repetition_fold_dir = log_dir / f"repetition_{repetition}" / f"fold_{fold_index}" + repetition_fold_dir.mkdir(parents=True, exist_ok=True) + start_time = datetime.now() data = preprocess_data( data_dir, @@ -79,19 +97,19 @@ def execute_repeated_cv( generate_cache=generate_cache, cv_repetitions=cv_repetitions, repetition_index=repetition, + train_size=train_size, cv_folds=cv_folds, fold_index=fold_index, pretrained_imputation_model=pretrained_imputation_model, runmode=mode, + complete_train=complete_train, ) - - repetition_fold_dir = log_dir / f"repetition_{repetition}" / f"fold_{fold_index}" - repetition_fold_dir.mkdir(parents=True, exist_ok=True) preprocess_time = datetime.now() - start_time start_time = datetime.now() agg_loss += train_common( data, log_dir=repetition_fold_dir, + eval_only=eval_only, load_weights=load_weights, source_dir=source_dir, reproducible=reproducible, @@ -100,11 +118,12 @@ def execute_repeated_cv( cpu=cpu, verbose=verbose, use_wandb=wandb, + train_only=complete_train, ) train_time = datetime.now() - start_time log_full_line( - f"FINISHED FOLD {fold_index}| PREPROCESSING DURATION {preprocess_time}| TRAINING DURATION {train_time}", + f"FINISHED FOLD {fold_index}| PREPROCESSING DURATION {preprocess_time}| PROCEDURE DURATION {train_time}", level=logging.INFO, ) durations = {"preprocessing_duration": preprocess_time, "train_duration": train_time} @@ -114,7 +133,10 @@ def execute_repeated_cv( if wandb: wandb_log({"Iteration": repetition * cv_folds_to_train + fold_index}) if repetition * cv_folds_to_train + fold_index > 1: - aggregate_results(log_dir) + try: + aggregate_results(log_dir) + except Exception as e: + logging.error(f"Failed to aggregate results: {e}") log_full_line(f"FINISHED CV REPETITION {repetition}", level=logging.INFO, char="=", num_newlines=3) return agg_loss / (cv_repetitions_to_train * cv_folds_to_train) diff --git a/icu_benchmarks/data/loader.py b/icu_benchmarks/data/loader.py index 9b831aa7..b227dbf2 100644 --- a/icu_benchmarks/data/loader.py +++ b/icu_benchmarks/data/loader.py @@ -1,3 +1,4 @@ +import warnings from typing import List from pandas import DataFrame import gin @@ -6,13 +7,183 @@ from torch.utils.data import Dataset import logging from typing import Dict, Tuple - +import polars as pl from icu_benchmarks.imputation.amputations import ampute_data from .constants import DataSegment as Segment from .constants import DataSplit as Split -class CommonDataset(Dataset): +@gin.configurable("CommonPolarsDataset") +class CommonPolarsDataset(Dataset): + def __init__( + self, + data: dict, + split: str = Split.train, + vars: Dict[str, str] = gin.REQUIRED, + grouping_segment: str = Segment.outcome, + mps: bool = False, + name: str = "", + *args, + **kwargs, + ): + # super().__init__(*args, **kwargs) + self.split = split + self.vars = vars + self.grouping_df = data[split][grouping_segment] # .set_index(self.vars["GROUP"]) + # logging.info(f"data split: {data[split]}") + # self.features_df = ( + # data[split][Segment.features].set_index(self.vars["GROUP"]).drop(labels=self.vars["SEQUENCE"], axis=1) + # ) + # Get the row indicators for the data to be able to match predicted labels + if "SEQUENCE" in self.vars and self.vars["SEQUENCE"] in data[split][Segment.features].columns: + # We have a time series dataset + self.row_indicators = data[split][Segment.features][self.vars["GROUP"], self.vars["SEQUENCE"]] + self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"]).dt.total_hours()) + self.features_df = data[split][Segment.features] + self.features_df = self.features_df.sort([self.vars["GROUP"], self.vars["SEQUENCE"]]) + self.features_df = self.features_df.drop(self.vars["SEQUENCE"]) + else: + # We have a static dataset + logging.info("Using static dataset") + self.row_indicators = data[split][Segment.features][self.vars["GROUP"]] + self.features_df = data[split][Segment.features] + # calculate basic info for the data + self.num_stays = self.grouping_df[self.vars["GROUP"]].unique().shape[0] + self.maxlen = self.features_df.group_by([self.vars["GROUP"]]).len().max().item(0, 1) + self.mps = mps + self.name = name + + def ram_cache(self, cache: bool = True): + self._cached_dataset = None + if cache: + logging.info(f"Caching {self.split} dataset in ram.") + self._cached_dataset = [self[i] for i in range(len(self))] + + def __len__(self) -> int: + """Returns number of stays in the data. + + Returns: + number of stays in the data + """ + return self.num_stays + + def get_feature_names(self) -> List[str]: + return self.features_df.columns + + def to_tensor(self) -> List[Tensor]: + values = [] + for entry in self: + for i, value in enumerate(entry): + if len(values) <= i: + values.append([]) + values[i].append(value.unsqueeze(0)) + return [cat(value, dim=0) for value in values] + + +@gin.configurable("PredictionPolarsDataset") +class PredictionPolarsDataset(CommonPolarsDataset): + """Subclass of common dataset for prediction tasks. + + Args: + ram_cache (bool, optional): Whether the complete dataset should be stored in ram. Defaults to True. + """ + + def __init__(self, *args, ram_cache: bool = True, **kwargs): + super().__init__(*args, **kwargs) + self.outcome_df = self.grouping_df + self.ram_cache(ram_cache) + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]: + """Function to sample from the data split of choice. Used for deep learning implementations. + + Args: + idx: A specific row index to sample. + + Returns: + A sample from the data, consisting of data, labels and padding mask. + """ + if self._cached_dataset is not None: + return self._cached_dataset[idx] + + pad_value = 0.0 + # stay_id = self.outcome_df.index.unique()[idx] # [self.vars["GROUP"]] + stay_id = self.outcome_df[self.vars["GROUP"]].unique()[idx] # [self.vars["GROUP"]] + + # slice to make sure to always return a DF + # window = self.features_df.loc[stay_id:stay_id].to_numpy() + # labels = self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy(dtype=float) + window = self.features_df.filter(pl.col(self.vars["GROUP"]) == stay_id).to_numpy() + labels = self.outcome_df.filter(pl.col(self.vars["GROUP"]) == stay_id)[self.vars["LABEL"]].to_numpy().astype(float) + + if len(labels) == 1: + # only one label per stay, align with window + labels = np.concatenate([np.empty(window.shape[0] - 1) * np.nan, labels], axis=0) + + length_diff = self.maxlen - window.shape[0] + pad_mask = np.ones(window.shape[0]) + + # Padding the array to fulfill size requirement + if length_diff > 0: + # window shorter than the longest window in dataset, pad to same length + window = np.concatenate([window, np.ones((length_diff, window.shape[1])) * pad_value], axis=0) + labels = np.concatenate([labels, np.ones(length_diff) * pad_value], axis=0) + pad_mask = np.concatenate([pad_mask, np.zeros(length_diff)], axis=0) + + not_labeled = np.argwhere(np.isnan(labels)) + if len(not_labeled) > 0: + labels[not_labeled] = -1 + pad_mask[not_labeled] = 0 + + pad_mask = pad_mask.astype(bool) + labels = labels.astype(np.float32) + data = window.astype(np.float32) + + return from_numpy(data), from_numpy(labels), from_numpy(pad_mask) + + def get_balance(self) -> list: + """Return the weight balance for the split of interest. + + Returns: + Weights for each label. + """ + counts = self.outcome_df[self.vars["LABEL"]].value_counts(parallel=True).get_columns()[1] + counts = counts.to_numpy() + weights = list((1 / counts) * np.sum(counts) / counts.shape[0]) + return weights + + def get_data_and_labels(self) -> Tuple[np.array, np.array, np.array]: + """Function to return all the data and labels aligned at once. + + We use this function for the ML methods which don't require an iterator. + + Returns: + A Tuple containing data points and label for the split. + """ + labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float) + rep = self.features_df + + if len(labels) == self.num_stays: + # order of groups could be random, we make sure not to change it + # rep = rep.groupby(level=self.vars["GROUP"], sort=False).last() + rep = rep.group_by(self.vars["GROUP"]).last() + else: + # Adding segment count for each stay id and timestep. + rep = rep.with_columns(pl.col(self.vars["GROUP"]).cum_count().over(self.vars["GROUP"]).alias("counter")) + rep = rep.to_numpy().astype(float) + logging.debug(f"rep shape: {rep.shape}") + logging.debug(f"labels shape: {labels.shape}") + return rep, labels, self.row_indicators.to_numpy() + + def to_tensor(self) -> Tuple[Tensor, Tensor, Tensor]: + data, labels, row_indicators = self.get_data_and_labels() + if self.mps: + return from_numpy(data).to(float32), from_numpy(labels).to(float32) + else: + return from_numpy(data), from_numpy(labels), row_indicators + + +@gin.configurable("CommonPandasDataset") +class CommonPandasDataset(Dataset): """Common dataset: subclass of Torch Dataset that represents the data to learn on. Args: data: Dict of the different splits of the data. split: Either 'train','val' or 'test'. vars: Contains the names of @@ -26,10 +197,14 @@ def __init__( split: str = Split.train, vars: Dict[str, str] = gin.REQUIRED, grouping_segment: str = Segment.outcome, + mps: bool = False, + name: str = "", ): + warnings.warn("CommonPandasDataset is deprecated. Use CommonPolarsDataset instead.", DeprecationWarning, stacklevel=2) self.split = split self.vars = vars self.grouping_df = data[split][grouping_segment].set_index(self.vars["GROUP"]) + # logging.info(f"data split: {data[split]}") self.features_df = ( data[split][Segment.features].set_index(self.vars["GROUP"]).drop(labels=self.vars["SEQUENCE"], axis=1) ) @@ -37,11 +212,13 @@ def __init__( # calculate basic info for the data self.num_stays = self.grouping_df.index.unique().shape[0] self.maxlen = self.features_df.groupby([self.vars["GROUP"]]).size().max() + self.mps = mps + self.name = name def ram_cache(self, cache: bool = True): self._cached_dataset = None if cache: - logging.info("Caching dataset in ram.") + logging.info(f"Caching {self.split} dataset in ram.") self._cached_dataset = [self[i] for i in range(len(self))] def __len__(self) -> int: @@ -52,10 +229,10 @@ def __len__(self) -> int: """ return self.num_stays - def get_feature_names(self): + def get_feature_names(self) -> List[str]: return self.features_df.columns - def to_tensor(self): + def to_tensor(self) -> List[Tensor]: values = [] for entry in self: for i, value in enumerate(entry): @@ -65,8 +242,8 @@ def to_tensor(self): return [cat(value, dim=0) for value in values] -@gin.configurable("PredictionDataset") -class PredictionDataset(CommonDataset): +@gin.configurable("PredictionPandasDataset") +class PredictionPandasDataset(CommonPandasDataset): """Subclass of common dataset for prediction tasks. Args: @@ -102,7 +279,6 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]: labels = np.concatenate([np.empty(window.shape[0] - 1) * np.nan, labels], axis=0) length_diff = self.maxlen - window.shape[0] - pad_mask = np.ones(window.shape[0]) # Padding the array to fulfill size requirement @@ -130,6 +306,7 @@ def get_balance(self) -> list: Weights for each label. """ counts = self.outcome_df[self.vars["LABEL"]].value_counts() + # weights = list((1 / counts) * np.sum(counts) / counts.shape[0]) return list((1 / counts) * np.sum(counts) / counts.shape[0]) def get_data_and_labels(self) -> Tuple[np.array, np.array]: @@ -151,11 +328,14 @@ def get_data_and_labels(self) -> Tuple[np.array, np.array]: def to_tensor(self): data, labels = self.get_data_and_labels() - return from_numpy(data), from_numpy(labels) + if self.mps: + return from_numpy(data).to(float32), from_numpy(labels).to(float32) + else: + return from_numpy(data), from_numpy(labels) -@gin.configurable("ImputationDataset") -class ImputationDataset(CommonDataset): +@gin.configurable("ImputationPandasDataset") +class ImputationPandasDataset(CommonPandasDataset): """Subclass of Common Dataset that contains data for imputation models.""" def __init__( diff --git a/icu_benchmarks/data/pooling.py b/icu_benchmarks/data/pooling.py new file mode 100644 index 00000000..e8ee8ba6 --- /dev/null +++ b/icu_benchmarks/data/pooling.py @@ -0,0 +1,202 @@ +from pathlib import Path +import logging +import pandas as pd +from sklearn.model_selection import train_test_split +from .constants import DataSegment as Segment, VarType as Var +from icu_benchmarks.constants import RunMode +import pyarrow.parquet as pq + + +class PooledDataset: + hirid_eicu_miiv = ["hirid", "eicu", "miiv"] + aumc_hirid_eicu = ["aumc", "hirid", "eicu"] + aumc_eicu_miiv = ["aumc", "eicu", "miiv"] + aumc_hirid_miiv = ["aumc", "hirid", "miiv"] + aumc_hirid_eicu_miiv = ["aumc", "hirid", "eicu", "miiv"] + + +class PooledData: + def __init__( + self, + data_dir, + vars, + datasets, + file_names, + shuffle=False, + stratify=None, + runmode=RunMode.classification, + save_test=True, + ): + """ + Generate pooled data from existing datasets. + Args: + data_dir: Where to read the data from + vars: Variables dictionary + datasets: Which datasets to pool + file_names: Which files to read from + shuffle: Whether to shuffle data + stratify: Stratify data + runmode: Which task runmode + save_test: Save left over test data to test on without leakage + """ + self.data_dir = data_dir + self.vars = vars + self.datasets = datasets + self.file_names = file_names + self.shuffle = shuffle + self.stratify = stratify + self.runmode = runmode + self.save_test = save_test + + def generate( + self, + datasets, + samples=10000, + seed=42, + ): + """ + Generate pooled data from existing datasets. + Args: + datasets: Which datasets to pool + samples: Amount of samples to pool + seed: Random seed + """ + data = {} + for folder in self.data_dir.iterdir(): + if folder.is_dir(): + if folder.name in datasets: + data[folder.name] = { + f: pq.read_table(folder / self.file_names[f]).to_pandas(self_destruct=True) for f in self.file_names + } + data = self._pool_datasets( + datasets=data, + samples=samples, + vars=vars, + shuffle=self.shuffle, + stratify=self.stratify, + seed=seed, + runmode=self.runmode, + data_dir=self.data_dir, + save_test=self.save_test, + ) + self._save_pooled_data(self.data_dir, data, datasets, self.file_names, samples=samples) + + def _save_pooled_data(self, data_dir, data, datasets, file_names, samples=10000): + """ + Save pooled data to disk. + Args: + data_dir: Directory to save the data + data: Data to save + datasets: Which datasets were pooled + file_names: The file names to save to + samples: Amount of samples to save + """ + save_folder = "_".join(datasets) + save_folder += f"_{samples}" + save_dir = data_dir / save_folder + if not save_dir.exists(): + save_dir.mkdir() + for key, value in data.items(): + value.to_parquet(save_dir / Path(file_names[key])) + logging.info(f"Saved pooled data at {save_dir}") + + def _pool_datasets( + self, + datasets=None, + samples=10000, + vars=None, + seed=42, + shuffle=True, + runmode=RunMode.classification, + data_dir=Path("data"), + save_test=True, + ): + """ + Pool datasets into a single dataset. + Args: + datasets: list of datasets to pool + samples: Amount of samples + vars: The variables dictionary + seed: Random seed + shuffle: Shuffle samples + runmode: Runmode + data_dir: Where to save the data + save_test: If true, save test data to test on without leakage + Returns: + pooled dataset + """ + if datasets is None: + datasets = {} + if vars is None: + vars = [] + if len(datasets) == 0: + raise ValueError("No datasets supplied.") + pooled_data = {Segment.static: [], Segment.dynamic: [], Segment.outcome: []} + id = vars[Var.group] + int_id = 0 + for key, value in datasets.items(): + int_id += 1 + # Preventing id clashing + repeated_digit = str(int_id) * 4 + outcome = value[Segment.outcome] + static = value[Segment.static] + dynamic = value[Segment.dynamic] + # Get unique stay IDs from outcome segment + stays = pd.Series(outcome[id].unique()) + + if runmode is RunMode.classification: + # If we have more outcomes than stays, check max label value per stay id + labels = outcome.groupby(id).max()[vars[Var.label]].reset_index(drop=True) + # if pd.Series(outcome[id].unique()) is outcome[id]): + selected_stays = train_test_split( + stays, stratify=labels, shuffle=shuffle, random_state=seed, train_size=samples + ) + else: + selected_stays = train_test_split(stays, shuffle=shuffle, random_state=seed, train_size=samples) + # Select only stays that are in the selected_stays + # Save test sets to test on without leakage + if save_test: + select = selected_stays[1] + outcome, static, dynamic = self._select_stays( + outcome=outcome, static=static, dynamic=dynamic, select=select, repeated_digit=repeated_digit + ) + save_folder = key + save_folder += f"_test_{len(select)}" + save_dir = data_dir / save_folder + if not save_dir.exists(): + save_dir.mkdir() + outcome.to_parquet(save_dir / Path("outc.parquet")) + static.to_parquet(save_dir / Path("sta.parquet")) + dynamic.to_parquet(save_dir / Path("dyn.parquet")) + logging.info(f"Saved train data at {save_dir}") + selected_stays = selected_stays[0] + outcome, static, dynamic = self._select_stays( + outcome=outcome, static=static, dynamic=dynamic, select=selected_stays, repeated_digit=repeated_digit + ) + # Adding to pooled data + pooled_data[Segment.static].append(static) + pooled_data[Segment.dynamic].append(dynamic) + pooled_data[Segment.outcome].append(outcome) + # Add each datatype together + for key, value in pooled_data.items(): + pooled_data[key] = pd.concat(value, ignore_index=True) + return pooled_data + + def _select_stays(self, outcome, static, dynamic, select, repeated_digit=1): + """Selects stays for outcome, static, dynamic dataframes. + + Args: + outcome: Outcome dataframe + static: Static dataframe + dynamic: Dynamic dataframe + select: Stay IDs to select + repeated_digit: Digit to repeat for ID clashing + """ + outcome = outcome.loc[outcome[id].isin(select)] + static = static.loc[static[id].isin(select)] + dynamic = dynamic.loc[dynamic[id].isin(select)] + # Preventing id clashing + outcome[id] = outcome[id].map(lambda x: int(str(x) + repeated_digit)) + static[id] = static[id].map(lambda x: int(str(x) + repeated_digit)) + dynamic[id] = dynamic[id].map(lambda x: int(str(x) + repeated_digit)) + return outcome, static, dynamic diff --git a/icu_benchmarks/data/preprocessor.py b/icu_benchmarks/data/preprocessor.py index c924980a..9f6103dd 100644 --- a/icu_benchmarks/data/preprocessor.py +++ b/icu_benchmarks/data/preprocessor.py @@ -1,11 +1,25 @@ +import copy +import pickle + import torch import logging import gin import pandas as pd +import polars as pl from recipys.recipe import Recipe from recipys.selector import all_numeric_predictors, all_outcomes, has_type, all_of -from recipys.step import StepScale, StepImputeFill, StepSklearn, StepHistorical, Accumulator, StepImputeModel +from recipys.step import ( + StepScale, + StepImputeFastForwardFill, + StepImputeFastZeroFill, + StepImputeFill, + StepSklearn, + StepHistorical, + Accumulator, + StepImputeModel, +) + from sklearn.impute import SimpleImputer, MissingIndicator from sklearn.preprocessing import LabelEncoder, FunctionTransformer, MinMaxScaler @@ -15,9 +29,9 @@ import abc -class Preprocessor: +class Preprocessor(abc.ABC): @abc.abstractmethod - def apply(self, data, vars): + def apply(self, data, vars, save_cache=False, load_cache=None, vars_to_exclude=None): return data @abc.abstractmethod @@ -31,13 +45,24 @@ def set_imputation_model(self, imputation_model): @gin.configurable("base_classification_preprocessor") -class DefaultClassificationPreprocessor(Preprocessor): - def __init__(self, generate_features: bool = True, scaling: bool = True, use_static_features: bool = True): +class PolarsClassificationPreprocessor(Preprocessor): + def __init__( + self, + generate_features: bool = False, + scaling: bool = True, + use_static_features: bool = True, + save_cache=None, + load_cache=None, + vars_to_exclude=None, + ): """ Args: generate_features: Generate features for dynamic data. scaling: Scaling of dynamic and static data. use_static_features: Use static features. + save_cache: Save recipe cache from this path. + load_cache: Load recipe cache from this path. + vars_to_exclude: Variables to exclude from missing indicator/ feature generation. Returns: Preprocessed data. """ @@ -45,6 +70,241 @@ def __init__(self, generate_features: bool = True, scaling: bool = True, use_sta self.scaling = scaling self.use_static_features = use_static_features self.imputation_model = None + self.save_cache = save_cache + self.load_cache = load_cache + self.vars_to_exclude = vars_to_exclude + + def apply(self, data, vars) -> dict[dict[pl.DataFrame]]: + """ + Args: + data: Train, validation and test data dictionary. Further divided in static, dynamic, and outcome. + vars: Variables for static, dynamic, outcome. + Returns: + Preprocessed data. + """ + # Check if dynamic features are present + if ( + self.use_static_features + and all(Segment.static in value for value in data.values()) + and len(vars[Segment.static]) > 0 + ): + logging.info("Preprocessing static features.") + data = self._process_static(data, vars) + else: + self.use_static_features = False + + if all(Segment.dynamic in value for value in data.values()): + logging.info("Preprocessing dynamic features.") + logging.info(data.keys()) + data = self._process_dynamic(data, vars) + if self.use_static_features: + # Join static and dynamic data. + data[Split.train][Segment.dynamic] = data[Split.train][Segment.dynamic].join( + data[Split.train][Segment.static], on=vars["GROUP"] + ) + data[Split.val][Segment.dynamic] = data[Split.val][Segment.dynamic].join( + data[Split.val][Segment.static], on=vars["GROUP"] + ) + data[Split.test][Segment.dynamic] = data[Split.test][Segment.dynamic].join( + data[Split.test][Segment.static], on=vars["GROUP"] + ) + + # Remove static features from splits + data[Split.train][Segment.features] = data[Split.train].pop(Segment.static) + data[Split.val][Segment.features] = data[Split.val].pop(Segment.static) + data[Split.test][Segment.features] = data[Split.test].pop(Segment.static) + + # Create feature splits + data[Split.train][Segment.features] = data[Split.train].pop(Segment.dynamic) + data[Split.val][Segment.features] = data[Split.val].pop(Segment.dynamic) + data[Split.test][Segment.features] = data[Split.test].pop(Segment.dynamic) + elif self.use_static_features: + data[Split.train][Segment.features] = data[Split.train].pop(Segment.static) + data[Split.val][Segment.features] = data[Split.val].pop(Segment.static) + data[Split.test][Segment.features] = data[Split.test].pop(Segment.static) + else: + raise Exception(f"No recognized data segments data to preprocess. Available: {data.keys()}") + logging.debug("Data head") + logging.debug(data[Split.train][Segment.features].head()) + logging.debug(data[Split.train][Segment.outcome]) + for split in [Split.train, Split.val, Split.test]: + if vars["SEQUENCE"] in data[split][Segment.outcome] and len(data[split][Segment.features]) != len( + data[split][Segment.outcome] + ): + raise Exception( + f"Data and outcome length mismatch in {split} split: " + f"features: {len(data[split][Segment.features])}, outcome: {len(data[split][Segment.outcome])}" + ) + data[Split.train][Segment.features] = data[Split.train][Segment.features].unique() + data[Split.val][Segment.features] = data[Split.val][Segment.features].unique() + data[Split.test][Segment.features] = data[Split.test][Segment.features].unique() + + logging.info(f"Generate features: {self.generate_features}") + return data + + def _process_static(self, data, vars): + sta_rec = Recipe(data[Split.train][Segment.static], [], vars[Segment.static]) + sta_rec.add_step(StepSklearn(MissingIndicator(features="all"), sel=all_of(vars[Segment.static]), in_place=False)) + if self.scaling: + sta_rec.add_step(StepScale()) + sta_rec.add_step(StepImputeFill(sel=all_numeric_predictors(), strategy="zero")) + # sta_rec.add_step(StepImputeFastZeroFill(sel=all_numeric_predictors())) + # if len(data[Split.train][Segment.static].select_dtypes(include=["object"]).columns) > 0: + types = ["String", "Object", "Categorical"] + sel = has_type(types) + if len(sel(sta_rec.data)) > 0: + # if len(data[Split.train][Segment.static].select(cs.by_dtype(types)).columns) > 0: + sta_rec.add_step(StepSklearn(SimpleImputer(missing_values=None, strategy="most_frequent"), sel=has_type(types))) + sta_rec.add_step(StepSklearn(LabelEncoder(), sel=has_type(types), columnwise=True)) + + data = apply_recipe_to_splits(sta_rec, data, Segment.static, self.save_cache, self.load_cache) + + return data + + def _model_impute(self, data, group=None): + dataset = ImputationPredictionDataset(data, group, self.imputation_model.trained_columns) + input_data = torch.cat([data_point.unsqueeze(0) for data_point in dataset], dim=0) + self.imputation_model.eval() + with torch.no_grad(): + logging.info(f"Imputing with {self.imputation_model.__class__.__name__}.") + imputation = self.imputation_model.predict(input_data) + logging.info("Imputation done.") + assert imputation.isnan().sum() == 0 + data = data.copy() + data.loc[:, self.imputation_model.trained_columns] = imputation.flatten(end_dim=1).to("cpu") + if group is not None: + data.drop(columns=group, inplace=True) + return data + + def _process_dynamic(self, data, vars): + dyn_rec = Recipe(data[Split.train][Segment.dynamic], [], vars[Segment.dynamic], vars["GROUP"], vars["SEQUENCE"]) + if self.scaling: + dyn_rec.add_step(StepScale()) + if self.imputation_model is not None: + dyn_rec.add_step(StepImputeModel(model=self.model_impute, sel=all_of(vars[Segment.dynamic]))) + if self.vars_to_exclude is not None: + # Exclude vars_to_exclude from missing indicator/ feature generation + vars_to_apply = list(set(vars[Segment.dynamic]) - set(self.vars_to_exclude)) + else: + vars_to_apply = vars[Segment.dynamic] + dyn_rec.add_step(StepSklearn(MissingIndicator(features="all"), sel=all_of(vars_to_apply), in_place=False)) + # dyn_rec.add_step(StepImputeFastForwardFill()) + dyn_rec.add_step(StepImputeFill(strategy="forward")) + # dyn_rec.add_step(StepImputeFastZeroFill()) + dyn_rec.add_step(StepImputeFill(strategy="zero")) + if self.generate_features: + dyn_rec = self._dynamic_feature_generation(dyn_rec, all_of(vars_to_apply)) + data = apply_recipe_to_splits(dyn_rec, data, Segment.dynamic, self.save_cache, self.load_cache) + return data + + def _dynamic_feature_generation(self, data, dynamic_vars): + logging.debug("Adding dynamic feature generation.") + data.add_step(StepHistorical(sel=dynamic_vars, fun=Accumulator.MIN, suffix="min_hist")) + data.add_step(StepHistorical(sel=dynamic_vars, fun=Accumulator.MAX, suffix="max_hist")) + data.add_step(StepHistorical(sel=dynamic_vars, fun=Accumulator.COUNT, suffix="count_hist")) + data.add_step(StepHistorical(sel=dynamic_vars, fun=Accumulator.MEAN, suffix="mean_hist")) + return data + + def to_cache_string(self): + return ( + super().to_cache_string() + + f"_classification_{self.generate_features}_{self.scaling}_{self.imputation_model.__class__.__name__}" + ) + + +@gin.configurable("base_regression_preprocessor") +class PolarsRegressionPreprocessor(PolarsClassificationPreprocessor): + # Override base classification preprocessor + def __init__( + self, + generate_features: bool = False, + scaling: bool = True, + use_static_features: bool = True, + outcome_max=None, + outcome_min=None, + save_cache=None, + load_cache=None, + ): + """ + Args: + generate_features: Generate features for dynamic data. + scaling: Scaling of dynamic and static data. + use_static_features: Use static features. + max_range: Maximum value in outcome. + min_range: Minimum value in outcome. + save_cache: Save recipe cache. + load_cache: Load recipe cache. + Returns: + Preprocessed data. + """ + super().__init__(generate_features, scaling, use_static_features, save_cache, load_cache) + self.outcome_max = outcome_max + self.outcome_min = outcome_min + + def apply(self, data, vars): + """ + Args: + data: Train, validation and test data dictionary. Further divided in static, dynamic, and outcome. + vars: Variables for static, dynamic, outcome. + Returns: + Preprocessed data. + """ + for split in [Split.train, Split.val, Split.test]: + data = self._process_outcome(data, vars, split) + + data = super().apply(data, vars) + return data + + def _process_outcome(self, data, vars, split): + logging.debug(f"Processing {split} outcome values.") + outcome_rec = Recipe(data[split][Segment.outcome], vars["LABEL"], [], vars["GROUP"]) + # If the range is predefined, use predefined transformation function + if self.outcome_max is not None and self.outcome_min is not None: + if self.outcome_max == self.outcome_min: + logging.warning("outcome_max equals outcome_min. Skipping outcome scaling.") + else: + outcome_rec.add_step( + StepSklearn( + sklearn_transformer=FunctionTransformer( + func=lambda x: ((x - self.outcome_min) / (self.outcome_max - self.outcome_min)) + ), + sel=all_outcomes(), + ) + ) + else: + # If the range is not predefined, use MinMaxScaler + outcome_rec.add_step(StepSklearn(MinMaxScaler(), sel=all_outcomes())) + outcome_rec.prep() + data[split][Segment.outcome] = outcome_rec.bake() + return data + + +@gin.configurable("pandas_classification_preprocessor") +class PandasClassificationPreprocessor(Preprocessor): + def __init__( + self, + generate_features: bool = True, + scaling: bool = True, + use_static_features: bool = True, + save_cache=None, + load_cache=None, + ): + """ + Args: + generate_features: Generate features for dynamic data. + scaling: Scaling of dynamic and static data. + use_static_features: Use static features. + save_cache: Save recipe cache from this path. + load_cache: Load recipe cache from this path. + Returns: + Preprocessed data. + """ + self.generate_features = generate_features + self.scaling = scaling + self.use_static_features = use_static_features + self.imputation_model = None + self.save_cache = save_cache + self.load_cache = load_cache def apply(self, data, vars) -> dict[dict[pd.DataFrame]]: """ @@ -55,6 +315,7 @@ def apply(self, data, vars) -> dict[dict[pd.DataFrame]]: Preprocessed data. """ logging.info("Preprocessing dynamic features.") + data = self._process_dynamic(data, vars) if self.use_static_features: logging.info("Preprocessing static features.") @@ -85,6 +346,11 @@ def apply(self, data, vars) -> dict[dict[pd.DataFrame]]: data[Split.train][Segment.features] = data[Split.train].pop(Segment.dynamic) data[Split.val][Segment.features] = data[Split.val].pop(Segment.dynamic) data[Split.test][Segment.features] = data[Split.test].pop(Segment.dynamic) + + logging.debug("Data head") + logging.debug(data[Split.train][Segment.features].head()) + logging.debug(data[Split.train][Segment.outcome].head()) + logging.info(f"Generate features: {self.generate_features}") return data def _process_static(self, data, vars): @@ -92,11 +358,12 @@ def _process_static(self, data, vars): if self.scaling: sta_rec.add_step(StepScale()) - sta_rec.add_step(StepImputeFill(sel=all_numeric_predictors(), value=0)) - sta_rec.add_step(StepSklearn(SimpleImputer(missing_values=None, strategy="most_frequent"), sel=has_type("object"))) - sta_rec.add_step(StepSklearn(LabelEncoder(), sel=has_type("object"), columnwise=True)) + sta_rec.add_step(StepImputeFastZeroFill(sel=all_numeric_predictors())) + if len(data[Split.train][Segment.static].select_dtypes(include=["object"]).columns) > 0: + sta_rec.add_step(StepSklearn(SimpleImputer(missing_values=None, strategy="most_frequent"), sel=has_type("object"))) + sta_rec.add_step(StepSklearn(LabelEncoder(), sel=has_type("object"), columnwise=True)) - data = apply_recipe_to_splits(sta_rec, data, Segment.static) + data = apply_recipe_to_splits(sta_rec, data, Segment.static, self.save_cache, self.load_cache) return data @@ -122,11 +389,11 @@ def _process_dynamic(self, data, vars): if self.imputation_model is not None: dyn_rec.add_step(StepImputeModel(model=self.model_impute, sel=all_of(vars[Segment.dynamic]))) dyn_rec.add_step(StepSklearn(MissingIndicator(), sel=all_of(vars[Segment.dynamic]), in_place=False)) - dyn_rec.add_step(StepImputeFill(method="ffill")) - dyn_rec.add_step(StepImputeFill(value=0)) + dyn_rec.add_step(StepImputeFastForwardFill()) + dyn_rec.add_step(StepImputeFastZeroFill()) if self.generate_features: dyn_rec = self._dynamic_feature_generation(dyn_rec, all_of(vars[Segment.dynamic])) - data = apply_recipe_to_splits(dyn_rec, data, Segment.dynamic) + data = apply_recipe_to_splits(dyn_rec, data, Segment.dynamic, self.save_cache, self.load_cache) return data def _dynamic_feature_generation(self, data, dynamic_vars): @@ -144,8 +411,8 @@ def to_cache_string(self): ) -@gin.configurable("base_regression_preprocessor") -class DefaultRegressionPreprocessor(DefaultClassificationPreprocessor): +@gin.configurable("pandas_regression_preprocessor") +class PandasRegressionPreprocessor(PandasClassificationPreprocessor): # Override base classification preprocessor def __init__( self, @@ -154,6 +421,8 @@ def __init__( use_static_features: bool = True, outcome_max=None, outcome_min=None, + save_cache=None, + load_cache=None, ): """ Args: @@ -162,10 +431,12 @@ def __init__( use_static_features: Use static features. max_range: Maximum value in outcome. min_range: Minimum value in outcome. + save_cache: Save recipe cache. + load_cache: Load recipe cache. Returns: Preprocessed data. """ - super().__init__(generate_features, scaling, use_static_features) + super().__init__(generate_features, scaling, use_static_features, save_cache, load_cache) self.outcome_max = outcome_max self.outcome_min = outcome_min @@ -205,7 +476,7 @@ def _process_outcome(self, data, vars, split): @gin.configurable("base_imputation_preprocessor") -class DefaultImputationPreprocessor(Preprocessor): +class PandasImputationPreprocessor(Preprocessor): def __init__( self, scaling: bool = True, @@ -236,7 +507,7 @@ def apply(self, data, vars): dyn_rec = Recipe(data[Split.train][Segment.dynamic], [], vars[Segment.dynamic], vars["GROUP"], vars["SEQUENCE"]) if self.scaling: dyn_rec.add_step(StepScale()) - data = apply_recipe_to_splits(dyn_rec, data, Segment.dynamic) + data = apply_recipe_to_splits(dyn_rec, data, Segment.dynamic, self.save_cache, self.load_cache) data[Split.train][Segment.features] = ( data[Split.train].pop(Segment.dynamic).loc[:, vars[Segment.dynamic] + [vars["GROUP"], vars["SEQUENCE"]]] @@ -262,10 +533,15 @@ def _process_dynamic_data(self, data, vars): @staticmethod -def apply_recipe_to_splits(recipe: Recipe, data: dict[dict[pd.DataFrame]], type: str) -> dict[dict[pd.DataFrame]]: +def apply_recipe_to_splits( + recipe: Recipe, data: dict[dict[pd.DataFrame]], type: str, save_cache=None, load_cache=None +) -> dict[dict[pd.DataFrame]]: """Fits and transforms the training features, then transforms the validation and test features with the recipe. + Works with both Polars and Pandas versions of recipys. Args: + load_cache: Load recipe from cache, for e.g. transfer learning. + save_cache: Save recipe to cache, for e.g. transfer learning. recipe: Object containing info about the features and steps. data: Dict containing 'train', 'val', and 'test' and types of features per split. type: Whether to apply recipe to dynamic features, static features or outcomes. @@ -273,7 +549,42 @@ def apply_recipe_to_splits(recipe: Recipe, data: dict[dict[pd.DataFrame]], type: Returns: Transformed features divided into 'train', 'val', and 'test'. """ - data[Split.train][type] = recipe.prep() + + if isinstance(load_cache, str): + # Load existing recipe + recipe = restore_recipe(load_cache) + data[Split.train][type] = recipe.bake(data[Split.train][type]) + elif isinstance(save_cache, str): + # Save prepped recipe + data[Split.train][type] = recipe.prep() + cache_recipe(recipe, save_cache) + else: + # No saving or loading of existing cache + data[Split.train][type] = recipe.prep() + data[Split.val][type] = recipe.bake(data[Split.val][type]) data[Split.test][type] = recipe.bake(data[Split.test][type]) return data + + +def cache_recipe(recipe: Recipe, cache_file: str) -> None: + """Cache recipe to make it available for e.g. transfer learning.""" + recipe_cache = copy.deepcopy(recipe) + recipe_cache.cache() + if not (cache_file / "..").exists(): + (cache_file / "..").mkdir() + cache_file.touch() + with open(cache_file, "wb") as f: + pickle.dump(recipe_cache, f, pickle.HIGHEST_PROTOCOL) + logging.info(f"Cached recipe in {cache_file}.") + + +def restore_recipe(cache_file: str) -> Recipe: + """Restore recipe from cache to use for e.g. transfer learning.""" + if cache_file.exists(): + with open(cache_file, "rb") as f: + logging.info(f"Loading cached recipe from {cache_file}.") + recipe = pickle.load(f) + return recipe + else: + raise FileNotFoundError(f"Cache file {cache_file} not found.") diff --git a/icu_benchmarks/data/split_process_data.py b/icu_benchmarks/data/split_process_data.py index 08db9d75..a47c9729 100644 --- a/icu_benchmarks/data/split_process_data.py +++ b/icu_benchmarks/data/split_process_data.py @@ -1,16 +1,19 @@ +import copy import logging +import os + import gin import json import hashlib import pandas as pd -import pyarrow.parquet as pq +import polars as pl from pathlib import Path import pickle - -from sklearn.model_selection import StratifiedKFold, KFold - -from icu_benchmarks.data.preprocessor import Preprocessor, DefaultClassificationPreprocessor -from icu_benchmarks.contants import RunMode +from timeit import default_timer as timer +from sklearn.model_selection import StratifiedKFold, KFold, StratifiedShuffleSplit, ShuffleSplit +from icu_benchmarks.data.preprocessor import Preprocessor, PandasClassificationPreprocessor, PolarsClassificationPreprocessor +from icu_benchmarks.constants import RunMode +from icu_benchmarks.run_utils import check_required_keys from .constants import DataSplit as Split, DataSegment as Segment, VarType as Var @@ -18,23 +21,33 @@ def preprocess_data( data_dir: Path, file_names: dict[str] = gin.REQUIRED, - preprocessor: Preprocessor = DefaultClassificationPreprocessor, + preprocessor: Preprocessor = PolarsClassificationPreprocessor, use_static: bool = True, vars: dict[str] = gin.REQUIRED, + modality_mapping: dict[str] = {}, + selected_modalities: list[str] = "all", seed: int = 42, debug: bool = False, cv_repetitions: int = 5, repetition_index: int = 0, cv_folds: int = 5, + train_size: int = None, load_cache: bool = False, generate_cache: bool = False, fold_index: int = 0, pretrained_imputation_model: str = None, + complete_train: bool = False, runmode: RunMode = RunMode.classification, -) -> dict[dict[pd.DataFrame]]: + label: str = None, + required_var_types=["GROUP", "SEQUENCE", "LABEL"], + required_segments=[Segment.static, Segment.dynamic, Segment.outcome], +) -> dict[dict[pl.DataFrame]] or dict[dict[pd.DataFrame]]: """Perform loading, splitting, imputing and normalising of task data. Args: + use_static: Whether to use static features (for DL models). + complete_train: Whether to use all data for training/validation. + runmode: Run mode. Can be one of the values of RunMode preprocessor: Define the preprocessor. data_dir: Path to the directory holding the data. file_names: Contains the parquet file names in data_dir. @@ -44,6 +57,7 @@ def preprocess_data( cv_repetitions: Number of times to repeat cross validation. repetition_index: Index of the repetition to return. cv_folds: Number of folds to use for cross validation. + train_size: Fixed size of train split (including validation data). load_cache: Use cached preprocessed data if true. generate_cache: Generate cached preprocessed data if true. fold_index: Index of the fold to return. @@ -55,21 +69,45 @@ def preprocess_data( """ cache_dir = data_dir / "cache" - + check_required_keys(vars, required_var_types) + check_required_keys(file_names, required_segments) if not use_static: file_names.pop(Segment.static) vars.pop(Segment.static) - + if isinstance(vars[Var.label], list) and len(vars[Var.label]) > 1: + if label is not None: + vars[Var.label] = [label] + else: + logging.debug(f"Multiple labels found and no value provided. Using first label: {vars[Var.label]}") + vars[Var.label] = vars[Var.label][0] + logging.info(f"Using label: {vars[Var.label]}") + if not vars[Var.label]: + raise ValueError("No label selected after filtering.") dumped_file_names = json.dumps(file_names, sort_keys=True) dumped_vars = json.dumps(vars, sort_keys=True) + cache_filename = f"s_{seed}_r_{repetition_index}_f_{fold_index}_t_{train_size}_d_{debug}" + logging.log(logging.INFO, f"Using preprocessor: {preprocessor.__name__}") - preprocessor = preprocessor(use_static_features=use_static) - if isinstance(preprocessor, DefaultClassificationPreprocessor): + vars_to_exclude = ( + modality_mapping.get("cat_clinical_notes") + modality_mapping.get("cat_med_embeddings_map") + if ( + modality_mapping.get("cat_clinical_notes") is not None + and modality_mapping.get("cat_med_embeddings_map") is not None + ) + else None + ) + + preprocessor = preprocessor( + use_static_features=use_static, + save_cache=data_dir / "preproc" / (cache_filename + "_recipe"), + vars_to_exclude=vars_to_exclude, + ) + if isinstance(preprocessor, PandasClassificationPreprocessor): preprocessor.set_imputation_model(pretrained_imputation_model) - hash_config = f"{preprocessor.to_cache_string()}{dumped_file_names}{dumped_vars}{debug}".encode("utf-8") - cache_filename = f"s_{seed}_r_{repetition_index}_f_{fold_index}_d_{debug}_{hashlib.md5(hash_config).hexdigest()}" + hash_config = hashlib.md5(f"{preprocessor.to_cache_string()}{dumped_file_names}{dumped_vars}".encode("utf-8")) + cache_filename += f"_{hash_config.hexdigest()}" cache_file = cache_dir / cache_filename if load_cache: @@ -82,16 +120,66 @@ def preprocess_data( # Read parquet files into pandas dataframes and remove the parquet file from memory logging.info(f"Loading data from directory {data_dir.absolute()}") - data = {f: pq.read_table(data_dir / file_names[f]).to_pandas(self_destruct=True) for f in file_names.keys()} + data = { + f: pl.read_parquet(data_dir / file_names[f]) for f in file_names.keys() if os.path.exists(data_dir / file_names[f]) + } + logging.info(f"Loaded data: {list(data.keys())}") + data = check_sanitize_data(data, vars) + + if not (Segment.dynamic in data.keys()): + logging.warning("No dynamic data found, using only static data.") + + logging.debug(f"Modality mapping: {modality_mapping}") + if len(modality_mapping) > 0: + # Optional modality selection + if selected_modalities not in [None, "all", ["all"]]: + data, vars = modality_selection(data, modality_mapping, selected_modalities, vars) + else: + logging.info("Selecting all modalities.") # Generate the splits logging.info("Generating splits.") - data = make_single_split( - data, vars, cv_repetitions, repetition_index, cv_folds, fold_index, seed=seed, debug=debug, runmode=runmode - ) + # complete_train = True + if not complete_train: + data = make_single_split( + data, + vars, + cv_repetitions, + repetition_index, + cv_folds, + fold_index, + train_size=train_size, + seed=seed, + debug=debug, + runmode=runmode, + ) + else: + # If full train is set, we use all data for training/validation + data = make_train_val(data, vars, train_size=None, seed=seed, debug=debug, runmode=runmode) # Apply preprocessing + + start = timer() data = preprocessor.apply(data, vars) + end = timer() + logging.info(f"Preprocessing took {end - start:.2f} seconds.") + logging.info(f"Checking for NaNs and nulls in {data.keys()}.") + for dict in data.values(): + for key, val in dict.items(): + logging.debug(f"Data type: {key}") + logging.debug("Is NaN:") + sel = dict[key].select(pl.selectors.numeric().is_nan().max()) + logging.debug(sel.select(col.name for col in sel if col.item(0))) + # logging.info(dict[key].select(pl.all().has_nulls()).sum_horizontal()) + logging.debug("Has nulls:") + sel = dict[key].select(pl.all().has_nulls()) + logging.debug(sel.select(col.name for col in sel if col.item(0))) + # dict[key] = val[:, [not (s.null_count() > 0) for s in val]] + dict[key] = val.fill_null(strategy="zero") + dict[key] = val.fill_nan(0) + logging.debug("Dropping columns with nulls") + sel = dict[key].select(pl.all().has_nulls()) + logging.debug(sel.select(col.name for col in sel if col.item(0))) # Generate cache if generate_cache: @@ -104,6 +192,154 @@ def preprocess_data( return data +def check_sanitize_data(data, vars): + """Check for duplicates in the loaded data and remove them.""" + group = vars[Var.group] if Var.group in vars.keys() else None + sequence = vars[Var.sequence] if Var.sequence in vars.keys() else None + keep = "last" + if Segment.static in data.keys(): + old_len = len(data[Segment.static]) + data[Segment.static] = data[Segment.static].unique(subset=group, keep=keep, maintain_order=True) + logging.warning(f"Removed {old_len - len(data[Segment.static])} duplicates from static data.") + if Segment.dynamic in data.keys(): + old_len = len(data[Segment.dynamic]) + data[Segment.dynamic] = data[Segment.dynamic].unique(subset=[group, sequence], keep=keep, maintain_order=True) + logging.warning(f"Removed {old_len - len(data[Segment.dynamic])} duplicates from dynamic data.") + if Segment.outcome in data.keys(): + old_len = len(data[Segment.outcome]) + if sequence in data[Segment.outcome].columns: + # We have a dynamic outcome with group and sequence + data[Segment.outcome] = data[Segment.outcome].unique(subset=[group, sequence], keep=keep, maintain_order=True) + else: + data[Segment.outcome] = data[Segment.outcome].unique(subset=[group], keep=keep, maintain_order=True) + logging.warning(f"Removed {old_len - len(data[Segment.outcome])} duplicates from outcome data.") + return data + + +def modality_selection( + data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars +) -> dict[pl.DataFrame]: + logging.info(f"Selected modalities: {selected_modalities}") + selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping.keys()] + if not any(col in modality_mapping.keys() for col in selected_modalities): + raise ValueError("None of the selected modalities found in modality mapping.") + if selected_columns == []: + logging.info("No columns selected. Using all columns.") + return data, vars + selected_columns = sum(selected_columns, []) + selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]]) + old_columns = [] + # Update vars dict + for key, value in vars.items(): + if key not in [Var.group, Var.label, Var.sequence]: + old_columns.extend(value) + vars[key] = [col for col in value if col in selected_columns] + # -3 because of standard columns + logging.info(f"Selected columns: {len(selected_columns) - 3}, old columns: {len(old_columns)}") + logging.debug(f"Difference: {set(old_columns) - set(selected_columns)}") + # Update data dict + for key in data.keys(): + sel_col = [col for col in data[key].columns if col in selected_columns] + data[key] = data[key].select(sel_col) + logging.debug(f"Selected columns in {key}: {len(data[key].columns)}") + return data, vars + + +def make_train_val( + data: dict[pd.DataFrame], + vars: dict[str], + train_size=0.8, + seed: int = 42, + debug: bool = False, + runmode: RunMode = RunMode.classification, + polars: bool = True, +) -> dict[dict[pl.DataFrame]]: + """Randomly split the data into training and validation sets for fitting a full model. + + Args: + data: dictionary containing data divided int OUTCOME, STATIC, and DYNAMIC. + vars: Contains the names of columns in the data. + train_size: Fixed size of train split (including validation data). + seed: Random seed. + debug: Load less data if true. + Returns: + Input data divided into 'train', 'val', and 'test'. + """ + # ID variable + id = vars[Var.group] + + if debug: + # Only use 1% of the data + logging.info("Using only 1% of the data for debugging. Note that this might lead to errors for small datasets.") + if polars: + data[Segment.outcome] = data[Segment.outcome].sample(fraction=0.01, seed=seed) + else: + data[Segment.outcome] = data[Segment.outcome].sample(frac=0.01, random_state=seed) + + # Get stay IDs from outcome segment + stays = _get_stays(data, id, polars) + + # If there are labels, and the task is classification, use stratified k-fold + if Var.label in vars and runmode is RunMode.classification: + # Get labels from outcome data (takes the highest value (or True) in case seq2seq classification) + labels = _get_labels(data, id, vars, polars) + train_val = StratifiedShuffleSplit(train_size=train_size, random_state=seed, n_splits=1) + train, val = list(train_val.split(stays, labels))[0] + + else: + # If there are no labels, use random split + train_val = ShuffleSplit(train_size=train_size, random_state=seed) + train, val = list(train_val.split(stays))[0] + + if polars: + split = { + Split.train: stays[train].cast(pl.datatypes.Int64).to_frame(), + Split.val: stays[val].cast(pl.datatypes.Int64).to_frame(), + } + else: + split = {Split.train: stays.iloc[train], Split.val: stays.iloc[val]} + + data_split = {} + + for fold in split.keys(): # Loop through splits (train / val / test) + # Loop through segments (DYNAMIC / STATIC / OUTCOME) + # set sort to true to make sure that IDs are reordered after scrambling earlier + if polars: + data_split[fold] = { + data_type: split[fold] + .join(data[data_type].with_columns(pl.col(id).cast(pl.datatypes.Int64)), on=id, how="left") + .sort(by=id) + for data_type in data.keys() + } + else: + data_split[fold] = { + data_type: data[data_type].merge(split[fold], on=id, how="right", sort=True) for data_type in data.keys() + } + + # Maintain compatibility with test split + data_split[Split.test] = copy.deepcopy(data_split[Split.val]) + return data_split + + +def _get_stays(data, id, polars): + return ( + pl.Series(name=id, values=data[Segment.outcome][id].unique()) + if polars + else pd.Series(data[Segment.outcome][id].unique(), name=id) + ) + + +def _get_labels(data, id, vars, polars): + # Get labels from outcome data (takes the highest value (or True) in case seq2seq classification) + if polars: + return data[Segment.outcome].group_by(id).max()[vars[Var.label]] + else: + return data[Segment.outcome].groupby(id).max()[vars[Var.label]].reset_index(drop=True) + + +# Use these helper functions in both make_train_val and make_single_split + + def make_single_split( data: dict[pd.DataFrame], vars: dict[str], @@ -111,10 +347,12 @@ def make_single_split( repetition_index: int, cv_folds: int, fold_index: int, + train_size: int = None, seed: int = 42, debug: bool = False, runmode: RunMode = RunMode.classification, -) -> dict[dict[pd.DataFrame]]: + polars: bool = True, +) -> dict[dict[pl.DataFrame]]: """Randomly split the data into training, validation, and test set. Args: @@ -125,6 +363,7 @@ def make_single_split( repetition_index: Index of the repetition to return. cv_folds: Number of folds for cross validation. fold_index: Index of the fold to return. + train_size: Fixed size of train split (including validation data). seed: Random seed. debug: Load less data if true. @@ -134,51 +373,93 @@ def make_single_split( # ID variable id = vars[Var.group] - # Get stay IDs from outcome segment - stays = pd.Series(data[Segment.outcome][id].unique(), name=id) - if debug: # Only use 1% of the data - stays = stays.sample(frac=0.01, random_state=seed) + logging.info("Using only 1% of the data for debugging. Note that this might lead to errors for small datasets.") + if polars: + data[Segment.outcome] = data[Segment.outcome].sample(fraction=0.01, seed=seed) + else: + data[Segment.outcome] = data[Segment.outcome].sample(frac=0.01, random_state=seed) + # Get stay IDs from outcome segment + if polars: + stays = pl.Series(name=id, values=data[Segment.outcome][id].unique()) + else: + stays = pd.Series(data[Segment.outcome][id].unique(), name=id) # If there are labels, and the task is classification, use stratified k-fold if Var.label in vars and runmode is RunMode.classification: # Get labels from outcome data (takes the highest value (or True) in case seq2seq classification) - labels = data[Segment.outcome].groupby(id).max()[vars[Var.label]].reset_index(drop=True) - if labels.value_counts().min() < cv_folds: - raise Exception( - f"The smallest amount of samples in a class is: {labels.value_counts().min()}, " - f"but {cv_folds} folds are requested. Reduce the number of folds or use more data." - ) - outer_cv = StratifiedKFold(cv_repetitions, shuffle=True, random_state=seed) + if polars: + labels = data[Segment.outcome].group_by(id).max()[vars[Var.label]] + if labels.value_counts().min().item(0, 1) < cv_folds: + raise Exception( + f"The smallest amount of samples in a class is: {labels.value_counts().min()}, " + f"but {cv_folds} folds are requested. Reduce the number of folds or use more data." + ) + else: + labels = data[Segment.outcome].groupby(id).max()[vars[Var.label]].reset_index(drop=True) + if labels.value_counts().min() < cv_folds: + raise Exception( + f"The smallest amount of samples in a class is: {labels.value_counts().min()}, " + f"but {cv_folds} folds are requested. Reduce the number of folds or use more data." + ) + + if train_size: + outer_cv = StratifiedShuffleSplit(cv_repetitions, train_size=train_size) + else: + outer_cv = StratifiedKFold(cv_repetitions, shuffle=True, random_state=seed) inner_cv = StratifiedKFold(cv_folds, shuffle=True, random_state=seed) dev, test = list(outer_cv.split(stays, labels))[repetition_index] - dev_stays = stays.iloc[dev] - train, val = list(inner_cv.split(dev_stays, labels.iloc[dev]))[fold_index] + if polars: + dev_stays = stays[dev] + train, val = list(inner_cv.split(dev_stays, labels[dev]))[fold_index] + else: + dev_stays = stays.iloc[dev] + train, val = list(inner_cv.split(dev_stays, labels.iloc[dev]))[fold_index] else: # If there are no labels, or the task is regression, use regular k-fold. - outer_cv = KFold(cv_repetitions, shuffle=True, random_state=seed) + if train_size: + outer_cv = ShuffleSplit(cv_repetitions, train_size=train_size) + else: + outer_cv = KFold(cv_repetitions, shuffle=True, random_state=seed) inner_cv = KFold(cv_folds, shuffle=True, random_state=seed) dev, test = list(outer_cv.split(stays))[repetition_index] - dev_stays = stays.iloc[dev] + if polars: + dev_stays = stays[dev] + else: + dev_stays = stays.iloc[dev] train, val = list(inner_cv.split(dev_stays))[fold_index] - - split = { - Split.train: dev_stays.iloc[train], - Split.val: dev_stays.iloc[val], - Split.test: stays.iloc[test], - } + if polars: + split = { + Split.train: dev_stays[train].cast(pl.datatypes.Int64).to_frame(), + Split.val: dev_stays[val].cast(pl.datatypes.Int64).to_frame(), + Split.test: stays[test].cast(pl.datatypes.Int64).to_frame(), + } + else: + split = { + Split.train: dev_stays.iloc[train], + Split.val: dev_stays.iloc[val], + Split.test: stays.iloc[test], + } data_split = {} for fold in split.keys(): # Loop through splits (train / val / test) # Loop through segments (DYNAMIC / STATIC / OUTCOME) # set sort to true to make sure that IDs are reordered after scrambling earlier - data_split[fold] = { - data_type: data[data_type].merge(split[fold], on=id, how="right", sort=True) for data_type in data.keys() - } - + if polars: + data_split[fold] = { + data_type: split[fold] + .join(data[data_type].with_columns(pl.col(id).cast(pl.datatypes.Int64)), on=id, how="left") + .sort(by=id) + for data_type in data.keys() + } + else: + data_split[fold] = { + data_type: data[data_type].merge(split[fold], on=id, how="right", sort=True) for data_type in data.keys() + } + logging.info(f"Data split: {data_split}") return data_split diff --git a/icu_benchmarks/imputation/baselines.py b/icu_benchmarks/imputation/baselines.py index 27ddf307..75b23fc5 100644 --- a/icu_benchmarks/imputation/baselines.py +++ b/icu_benchmarks/imputation/baselines.py @@ -15,8 +15,7 @@ class KNNImputation(ImputationWrapper): """Imputation using Scikit-Learn K-Nearest Neighbour.""" - needs_training = False - needs_fit = True + requires_backprop = False def __init__(self, *args, n_neighbors=2, **kwargs) -> None: super().__init__(*args, n_neighbors=n_neighbors, **kwargs) @@ -38,8 +37,7 @@ def forward(self, amputated_values, amputation_mask): class MICEImputation(ImputationWrapper): """Imputation using Scikit-Learn MICE.""" - needs_training = False - needs_fit = True + requires_backprop = False def __init__(self, *args, max_iter=100, verbose=2, imputation_order="random", random_state=0, **kwargs) -> None: super().__init__( @@ -69,8 +67,7 @@ def forward(self, amputated_values, amputation_mask): class MeanImputation(ImputationWrapper): """Mean imputation using Scikit-Learn SimpleImputer.""" - needs_training = False - needs_fit = True + requires_backprop = False def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -90,8 +87,9 @@ def forward(self, amputated_values, amputation_mask): @gin.configurable("Median") class MedianImputation(ImputationWrapper): - needs_training = False - needs_fit = True + """Median imputation using Scikit-Learn SimpleImputer.""" + + requires_backprop = False def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -113,8 +111,7 @@ def forward(self, amputated_values, amputation_mask): class ZeroImputation(ImputationWrapper): """Zero imputation using Scikit-Learn SimpleImputer.""" - needs_training = False - needs_fit = True + requires_backprop = False def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -136,8 +133,7 @@ def forward(self, amputated_values, amputation_mask): class MostFrequentImputation(ImputationWrapper): """Most frequent imputation using Scikit-Learn SimpleImputer.""" - needs_training = False - needs_fit = True + requires_backprop = False def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -159,8 +155,7 @@ def wrap_hyperimpute_model(methodName: str, configName: str) -> Type: class HyperImputeImputation(ImputationWrapper): """Imputation using HyperImpute package.""" - needs_training = False - needs_fit = True + requires_backprop = False def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -198,8 +193,7 @@ def forward(self, amputated_values, amputation_mask): class BRITSImputation(ImputationWrapper): """Bidirectional Recurrent Imputation for Time Series (BRITS) imputation using PyPots package.""" - needs_training = False - needs_fit = True + requires_backprop = False def __init__(self, *args, input_size, epochs=1, rnn_hidden_size=64, batch_size=256, **kwargs) -> None: super().__init__( @@ -234,8 +228,7 @@ def forward(self, amputated_values, amputation_mask): class SAITSImputation(ImputationWrapper): """Self-Attention based Imputation for Time Series (SAITS) imputation using PyPots package.""" - needs_training = False - needs_fit = True + requires_backprop = False def __init__(self, *args, input_size, epochs, n_layers, d_model, d_inner, n_head, d_k, d_v, dropout, **kwargs) -> None: super().__init__( @@ -284,8 +277,8 @@ def forward(self, amputated_values, amputation_mask): class AttentionImputation(ImputationWrapper): """Attention based Imputation (Transformer) imputation using PyPots package.""" - needs_training = False - needs_fit = True + # Handled within the library + requires_backprop = False def __init__(self, *args, input_size, epochs, n_layers, d_model, d_inner, n_head, d_k, d_v, dropout, **kwargs) -> None: super().__init__( diff --git a/icu_benchmarks/imputation/diffusion.py b/icu_benchmarks/imputation/diffusion.py index 2f270bd3..c841d703 100644 --- a/icu_benchmarks/imputation/diffusion.py +++ b/icu_benchmarks/imputation/diffusion.py @@ -20,8 +20,7 @@ class SimpleDiffusionModel(ImputationWrapper): """Simple Diffusion Model for Imputation. See https://arxiv.org/abs/2006.11239 for more details.""" - needs_training = True - needs_fit = False + requires_backprop = True input_size = [] diff --git a/icu_benchmarks/imputation/diffwave.py b/icu_benchmarks/imputation/diffwave.py index 0945c0e6..147eeb37 100644 --- a/icu_benchmarks/imputation/diffwave.py +++ b/icu_benchmarks/imputation/diffwave.py @@ -301,7 +301,7 @@ def forward(self, input_data): cond = self.cond_conv(cond) h += cond - out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :]) + out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :]) res = self.res_conv(out) assert x.shape == res.shape diff --git a/icu_benchmarks/imputation/mlp.py b/icu_benchmarks/imputation/mlp.py index 68d98e23..e6062814 100644 --- a/icu_benchmarks/imputation/mlp.py +++ b/icu_benchmarks/imputation/mlp.py @@ -8,8 +8,7 @@ class MLPImputation(ImputationWrapper): """Imputation model based on a Multi-Layer Perceptron (MLP).""" - needs_training = True - needs_fit = False + requires_backprop = True def __init__(self, *args, input_size, num_hidden_layers=3, hidden_layer_size=10, **kwargs) -> None: super().__init__( diff --git a/icu_benchmarks/imputation/np.py b/icu_benchmarks/imputation/np.py index 6a974fe5..6f0d766d 100644 --- a/icu_benchmarks/imputation/np.py +++ b/icu_benchmarks/imputation/np.py @@ -13,8 +13,7 @@ class NPImputation(ImputationWrapper): """Imputation using Neural Processes. Implementation adapted from https://github.com/EmilienDupont/neural-processes/. Provides imputation wrapper for NeuralProcess class.""" - needs_training = True - needs_fit = False + requires_backprop = True def __init__( self, diff --git a/icu_benchmarks/imputation/rnn.py b/icu_benchmarks/imputation/rnn.py index 9f6a7213..1d514c32 100644 --- a/icu_benchmarks/imputation/rnn.py +++ b/icu_benchmarks/imputation/rnn.py @@ -10,8 +10,7 @@ class RNNImputation(ImputationWrapper): """Imputation model with Gated Recurrent Units (GRU) or Long-Short Term Memory Network (LSTM). Defaults to GRU.""" - needs_training = True - needs_fit = False + requires_backprop = True def __init__(self, *args, input_size, hidden_size=64, state_init="zero", cell="gru", **kwargs) -> None: super().__init__(*args, input_size=input_size, hidden_size=hidden_size, state_init=state_init, cell=cell, **kwargs) @@ -69,8 +68,7 @@ class BRNNImputation(ImputationWrapper): """Imputation model with Bidirectional Gated Recurrent Units (GRU) or Long-Short Term Memory Network (LSTM). Defaults to GRU.""" - needs_training = True - needs_fit = False + requires_backprop = True def __init__(self, *args, input_size, hidden_size=64, state_init="zero", dropout=0.0, cell="gru", **kwargs) -> None: super().__init__( diff --git a/icu_benchmarks/imputation/simple_diffusion.py b/icu_benchmarks/imputation/simple_diffusion.py index 3590a5a0..f2c23d10 100644 --- a/icu_benchmarks/imputation/simple_diffusion.py +++ b/icu_benchmarks/imputation/simple_diffusion.py @@ -11,8 +11,7 @@ class SimpleDiffusionModel(ImputationWrapper): """Imputation model based on a Simple Diffusion Model. Adapted from https://colab.research.google.com/drive/1sjy9odlSSy0RBVgMTgP7s99NXsqglsUL.""" - needs_training = True - needs_fit = False + requires_backprop = True input_size = [] diff --git a/icu_benchmarks/models/constants.py b/icu_benchmarks/models/constants.py index 70ba5b65..43843db8 100644 --- a/icu_benchmarks/models/constants.py +++ b/icu_benchmarks/models/constants.py @@ -1,5 +1,5 @@ -from ignite.contrib.metrics import AveragePrecision, ROC_AUC, PrecisionRecallCurve, RocCurve -from ignite.metrics import Accuracy, RootMeanSquaredError # , ConfusionMatrix +from ignite.contrib.metrics import AveragePrecision, ROC_AUC, RocCurve, PrecisionRecallCurve +from ignite.metrics import Accuracy, RootMeanSquaredError from sklearn.calibration import calibration_curve from sklearn.metrics import ( average_precision_score, @@ -9,36 +9,43 @@ mean_absolute_error, precision_recall_curve, roc_curve, - # confusion_matrix, r2_score, mean_squared_error, - # f1_score, +) +from torchmetrics.classification import ( + AUROC, + AveragePrecision as TorchMetricsAveragePrecision, + PrecisionRecallCurve as TorchMetricsPrecisionRecallCurve, + CalibrationError, + F1Score, ) from enum import Enum - -from icu_benchmarks.models.metrics import CalibrationCurve, BalancedAccuracy, MAE, JSD +from icu_benchmarks.models.custom_metrics import ( + CalibrationCurve, + BalancedAccuracy, + MAE, + JSD, + BinaryFairnessWrapper, + confusion_matrix, +) -# TODO: revise transformation for metrics in wrappers.py in order to handle metrics that can not handle a mix of binary and -# continuous targets class MLMetrics: BINARY_CLASSIFICATION = { "AUC": roc_auc_score, "Calibration_Curve": calibration_curve, - # "Confusion_Matrix": confusion_matrix, - # "F1": f1_score, "PR": average_precision_score, "PR_Curve": precision_recall_curve, "RO_Curve": roc_curve, + "Confusion_Matrix": confusion_matrix, } MULTICLASS_CLASSIFICATION = { "Accuracy": accuracy_score, "AUC": roc_auc_score, "Balanced_Accuracy": balanced_accuracy_score, - # "Confusion_Matrix": confusion_matrix, - # "F1": f1_score, - "PR": average_precision_score, + # "PR": average_precision_score, + "Confusion_Matrix": confusion_matrix, } REGRESSION = { @@ -53,12 +60,20 @@ class DLMetrics: BINARY_CLASSIFICATION = { "AUC": ROC_AUC, "Calibration_Curve": CalibrationCurve, - # "Confusion_Matrix": ConfusionMatrix(num_classes=2), "PR": AveragePrecision, "PR_Curve": PrecisionRecallCurve, "RO_Curve": RocCurve, } + BINARY_CLASSIFICATION_TORCHMETRICS = { + "AUC": AUROC(task="binary"), + "PR": TorchMetricsAveragePrecision(task="binary"), + "PrecisionRecallCurve": TorchMetricsPrecisionRecallCurve(task="binary"), + "Calibration_Error": CalibrationError(task="binary", n_bins=10), + "F1": F1Score(task="binary", num_classes=2), + "Binary_Fairness": BinaryFairnessWrapper(num_groups=2, task="demographic_parity", group_name="sex"), + } + MULTICLASS_CLASSIFICATION = { "Accuracy": Accuracy, "BalancedAccuracy": BalancedAccuracy, diff --git a/icu_benchmarks/models/custom_metrics.py b/icu_benchmarks/models/custom_metrics.py new file mode 100644 index 00000000..eb0a5d23 --- /dev/null +++ b/icu_benchmarks/models/custom_metrics.py @@ -0,0 +1,145 @@ +import torch +from typing import Callable +import numpy as np +from ignite.metrics import EpochMetric +from numpy import ndarray +from sklearn.metrics import balanced_accuracy_score, mean_absolute_error, confusion_matrix as sk_confusion_matrix +from sklearn.calibration import calibration_curve +from scipy.spatial.distance import jensenshannon +from torchmetrics.classification import BinaryFairness + +"""" +This file contains custom metrics that can be added to YAIB. +""" + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +class BalancedAccuracy(EpochMetric): + def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: + super(BalancedAccuracy, self).__init__( + self.balanced_accuracy_compute, output_transform=output_transform, check_compute_fn=check_compute_fn + ) + + def balanced_accuracy_compute(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float: + y_true = y_targets.numpy() + y_pred = np.argmax(y_preds.numpy(), axis=-1) + return balanced_accuracy_score(y_true, y_pred) + + +class CalibrationCurve(EpochMetric): + def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: + super(CalibrationCurve, self).__init__( + self.ece_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn + ) + + def ece_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, n_bins=10) -> float: + y_true = y_targets.numpy() + y_pred = y_preds.numpy() + return calibration_curve(y_true, y_pred, n_bins=n_bins) + + +class MAE(EpochMetric): + def __init__( + self, + output_transform: Callable = lambda x: x, + check_compute_fn: bool = False, + invert_transform: Callable = lambda x: x, + ) -> None: + super(MAE, self).__init__( + lambda x, y: mae_with_invert_compute_fn(x, y, invert_transform), + output_transform=output_transform, + check_compute_fn=check_compute_fn, + ) + + def mae_with_invert_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float: + y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0] + y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0] + return mean_absolute_error(y_true, y_pred) + + +class JSD(EpochMetric): + def __init__( + self, + output_transform: Callable = lambda x: x, + check_compute_fn: bool = False, + ) -> None: + super(JSD, self).__init__( + lambda x, y: JSD_fn(x, y), + output_transform=output_transform, + check_compute_fn=check_compute_fn, + ) + + def JSD_fn(y_preds: torch.Tensor, y_targets: torch.Tensor): + return jensenshannon(abs(y_preds).flatten(), abs(y_targets).flatten()) ** 2 + + +class TorchMetricsWrapper: + metric = None + + def __init__(self, metric) -> None: + self.metric = metric + + def update(self, output_tuple) -> None: + self.metric.update(output_tuple[0], output_tuple[1]) + + def compute(self) -> None: + return self.metric.compute() + + def reset(self) -> None: + return self.metric.reset() + + +class BinaryFairnessWrapper(BinaryFairness): + """ + This class is a wrapper for the BinaryFairness metric from TorchMetrics. + """ + + group_name = None + + def __init__(self, group_name="sex", *args, **kwargs) -> None: + self.group_name = group_name + super().__init__(*args, **kwargs) + + def update(self, preds, target, data, feature_names) -> None: + """ " Standard metric update function""" + groups = data[:, :, feature_names.index(self.group_name)] + group_per_id = groups[:, 0] + return super().update(preds=preds.cpu(), target=target.cpu(), groups=group_per_id.long().cpu()) + + def feature_helper(self, trainer, step_prefix): + """Helper function to get the feature names from the trainer""" + if step_prefix == "train": + feature_names = trainer.train_dataloader.dataset.features + elif step_prefix == "val": + feature_names = trainer.train_dataloader.dataset.features + else: + feature_names = trainer.test_dataloaders.dataset.features + return feature_names + + +def confusion_matrix(y_true: ndarray, y_pred: ndarray, normalize=False) -> torch.tensor: + y_pred = np.rint(y_pred).astype(int) + confusion = sk_confusion_matrix(y_true, y_pred) + if normalize: + confusion = confusion / confusion.sum() + confusion_dict = {} + for i in range(confusion.shape[0]): + for j in range(confusion.shape[1]): + confusion_dict[f"class_{i}_pred_{j}"] = confusion[i][j] + return confusion_dict diff --git a/icu_benchmarks/models/dl_models.py b/icu_benchmarks/models/dl_models.py deleted file mode 100644 index 0fb1b0d2..00000000 --- a/icu_benchmarks/models/dl_models.py +++ /dev/null @@ -1,282 +0,0 @@ -import gin -from numbers import Integral -import numpy as np -import torch.nn as nn -from icu_benchmarks.contants import RunMode -from icu_benchmarks.models.layers import TransformerBlock, LocalBlock, TemporalBlock, PositionalEncoding -from icu_benchmarks.models.wrappers import DLPredictionWrapper - - -@gin.configurable -class RNNet(DLPredictionWrapper): - """Torch standard RNN model""" - - _supported_run_modes = [RunMode.classification, RunMode.regression] - - def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): - super().__init__( - input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs - ) - self.hidden_dim = hidden_dim - self.layer_dim = layer_dim - self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True) - self.logit = nn.Linear(hidden_dim, num_classes) - - def init_hidden(self, x): - h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) - return h0 - - def forward(self, x): - h0 = self.init_hidden(x) - out, hn = self.rnn(x, h0) - pred = self.logit(out) - return pred - - -@gin.configurable -class LSTMNet(DLPredictionWrapper): - """Torch standard LSTM model.""" - - _supported_run_modes = [RunMode.classification, RunMode.regression] - - def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): - super().__init__( - input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs - ) - self.hidden_dim = hidden_dim - self.layer_dim = layer_dim - self.rnn = nn.LSTM(input_size[2], hidden_dim, layer_dim, batch_first=True) - self.logit = nn.Linear(hidden_dim, num_classes) - - def init_hidden(self, x): - h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) - c0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) - return [t for t in (h0, c0)] - - def forward(self, x): - h0, c0 = self.init_hidden(x) - out, h = self.rnn(x, (h0, c0)) - pred = self.logit(out) - return pred - - -@gin.configurable -class GRUNet(DLPredictionWrapper): - """Torch standard GRU model.""" - - _supported_run_modes = [RunMode.classification, RunMode.regression] - - def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): - super().__init__( - input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs - ) - self.hidden_dim = hidden_dim - self.layer_dim = layer_dim - self.rnn = nn.GRU(input_size[2], hidden_dim, layer_dim, batch_first=True) - self.logit = nn.Linear(hidden_dim, num_classes) - - def init_hidden(self, x): - h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) - return h0 - - def forward(self, x): - h0 = self.init_hidden(x) - out, hn = self.rnn(x, h0) - pred = self.logit(out) - - return pred - - -@gin.configurable -class Transformer(DLPredictionWrapper): - """Transformer model as defined by the HiRID-Benchmark (https://github.com/ratschlab/HIRID-ICU-Benchmark).""" - - _supported_run_modes = [RunMode.classification, RunMode.regression] - - def __init__( - self, - input_size, - hidden, - heads, - ff_hidden_mult, - depth, - num_classes, - *args, - dropout=0.0, - l1_reg=0, - pos_encoding=True, - dropout_att=0.0, - **kwargs, - ): - super().__init__( - input_size=input_size, - hidden=hidden, - heads=heads, - ff_hidden_mult=ff_hidden_mult, - depth=depth, - num_classes=num_classes, - *args, - dropout=dropout, - l1_reg=l1_reg, - pos_encoding=pos_encoding, - dropout_att=dropout_att, - **kwargs, - ) - hidden = hidden if hidden % 2 == 0 else hidden + 1 # Make sure hidden is even - self.input_embedding = nn.Linear(input_size[2], hidden) # This acts as a time-distributed layer by defaults - if pos_encoding: - self.pos_encoder = PositionalEncoding(hidden) - else: - self.pos_encoder = None - - tblocks = [] - for i in range(depth): - tblocks.append( - TransformerBlock( - emb=hidden, - hidden=hidden, - heads=heads, - mask=True, - ff_hidden_mult=ff_hidden_mult, - dropout=dropout, - dropout_att=dropout_att, - ) - ) - - self.tblocks = nn.Sequential(*tblocks) - self.logit = nn.Linear(hidden, num_classes) - self.l1_reg = l1_reg - - def forward(self, x): - x = self.input_embedding(x) - if self.pos_encoder is not None: - x = self.pos_encoder(x) - x = self.tblocks(x) - pred = self.logit(x) - - return pred - - -@gin.configurable -class LocalTransformer(DLPredictionWrapper): - _supported_run_modes = [RunMode.classification, RunMode.regression] - - def __init__( - self, - input_size, - hidden, - heads, - ff_hidden_mult, - depth, - num_classes, - *args, - dropout=0.0, - l1_reg=0, - pos_encoding=True, - local_context=1, - dropout_att=0.0, - **kwargs, - ): - super().__init__( - input_size=input_size, - hidden=hidden, - heads=heads, - ff_hidden_mult=ff_hidden_mult, - depth=depth, - num_classes=num_classes, - *args, - dropout=dropout, - l1_reg=l1_reg, - pos_encoding=pos_encoding, - local_context=local_context, - dropout_att=dropout_att, - **kwargs, - ) - - hidden = hidden if hidden % 2 == 0 else hidden + 1 # Make sure hidden is even - self.input_embedding = nn.Linear(input_size[2], hidden) # This acts as a time-distributed layer by defaults - if pos_encoding: - self.pos_encoder = PositionalEncoding(hidden) - else: - self.pos_encoder = None - - tblocks = [] - for i in range(depth): - tblocks.append( - LocalBlock( - emb=hidden, - hidden=hidden, - heads=heads, - mask=True, - ff_hidden_mult=ff_hidden_mult, - local_context=local_context, - dropout=dropout, - dropout_att=dropout_att, - ) - ) - - self.tblocks = nn.Sequential(*tblocks) - self.logit = nn.Linear(hidden, num_classes) - self.l1_reg = l1_reg - - def forward(self, x): - x = self.input_embedding(x) - if self.pos_encoder is not None: - x = self.pos_encoder(x) - x = self.tblocks(x) - pred = self.logit(x) - - return pred - - -@gin.configurable -class TemporalConvNet(DLPredictionWrapper): - """Temporal Convolutional Network. Adapted from TCN original paper https://github.com/locuslab/TCN""" - - _supported_run_modes = [RunMode.classification, RunMode.regression] - - def __init__(self, input_size, num_channels, num_classes, *args, max_seq_length=0, kernel_size=2, dropout=0.0, **kwargs): - super().__init__( - input_size=input_size, - num_channels=num_channels, - num_classes=num_classes, - *args, - max_seq_length=max_seq_length, - kernel_size=kernel_size, - dropout=dropout, - **kwargs, - ) - layers = [] - - # We compute automatically the depth based on the desired seq_length. - if isinstance(num_channels, Integral) and max_seq_length: - num_channels = [num_channels] * int(np.ceil(np.log(max_seq_length / 2) / np.log(kernel_size))) - elif isinstance(num_channels, Integral) and not max_seq_length: - raise Exception("a maximum sequence length needs to be provided if num_channels is int") - - num_levels = len(num_channels) - for i in range(num_levels): - dilation_size = 2**i - in_channels = input_size[2] if i == 0 else num_channels[i - 1] - out_channels = num_channels[i] - layers += [ - TemporalBlock( - in_channels, - out_channels, - kernel_size, - stride=1, - dilation=dilation_size, - padding=(kernel_size - 1) * dilation_size, - dropout=dropout, - ) - ] - - self.network = nn.Sequential(*layers) - self.logit = nn.Linear(num_channels[-1], num_classes) - - def forward(self, x): - x = x.permute(0, 2, 1) # Permute to channel first - o = self.network(x) - o = o.permute(0, 2, 1) # Permute to channel last - pred = self.logit(o) - return pred diff --git a/icu_benchmarks/models/dl_models/__init__.py b/icu_benchmarks/models/dl_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/icu_benchmarks/models/layers.py b/icu_benchmarks/models/dl_models/layers.py similarity index 100% rename from icu_benchmarks/models/layers.py rename to icu_benchmarks/models/dl_models/layers.py diff --git a/icu_benchmarks/models/dl_models/rnn.py b/icu_benchmarks/models/dl_models/rnn.py new file mode 100644 index 00000000..4f0c65bc --- /dev/null +++ b/icu_benchmarks/models/dl_models/rnn.py @@ -0,0 +1,84 @@ +import gin +from torch import nn as nn + +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.wrappers import DLPredictionWrapper + + +@gin.configurable +class RNNet(DLPredictionWrapper): + """Torch standard RNN model""" + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + *args, input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, **kwargs + ) + self.hidden_dim = hidden_dim + self.layer_dim = layer_dim + self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.logit = nn.Linear(hidden_dim, num_classes) + + def init_hidden(self, x): + h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + return h0 + + def forward(self, x): + h0 = self.init_hidden(x) + out, hn = self.rnn(x, h0) + pred = self.logit(out) + return pred + + +@gin.configurable +class LSTMNet(DLPredictionWrapper): + """Torch standard LSTM model.""" + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + *args, input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, **kwargs + ) + self.hidden_dim = hidden_dim + self.layer_dim = layer_dim + self.rnn = nn.LSTM(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.logit = nn.Linear(hidden_dim, num_classes) + + def init_hidden(self, x): + h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + c0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + return [t for t in (h0, c0)] + + def forward(self, x): + h0, c0 = self.init_hidden(x) + out, h = self.rnn(x, (h0, c0)) + pred = self.logit(out) + return pred + + +@gin.configurable +class GRUNet(DLPredictionWrapper): + """Torch standard GRU model.""" + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + *args, input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, **kwargs + ) + self.hidden_dim = hidden_dim + self.layer_dim = layer_dim + self.rnn = nn.GRU(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.logit = nn.Linear(hidden_dim, num_classes) + + def init_hidden(self, x): + h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + return h0 + + def forward(self, x): + h0 = self.init_hidden(x) + out, hn = self.rnn(x, h0) + pred = self.logit(out) + return pred diff --git a/icu_benchmarks/models/dl_models/tcn.py b/icu_benchmarks/models/dl_models/tcn.py new file mode 100644 index 00000000..8be71fea --- /dev/null +++ b/icu_benchmarks/models/dl_models/tcn.py @@ -0,0 +1,62 @@ +from numbers import Integral + +import gin +import numpy as np +from torch import nn as nn + +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.dl_models.layers import TemporalBlock +from icu_benchmarks.models.wrappers import DLPredictionWrapper + + +@gin.configurable +class TemporalConvNet(DLPredictionWrapper): + """Temporal Convolutional Network. Adapted from TCN original paper https://github.com/locuslab/TCN""" + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__(self, input_size, num_channels, num_classes, *args, max_seq_length=0, kernel_size=2, dropout=0.0, **kwargs): + super().__init__( + *args, + input_size=input_size, + num_channels=num_channels, + num_classes=num_classes, + max_seq_length=max_seq_length, + kernel_size=kernel_size, + dropout=dropout, + **kwargs, + ) + layers = [] + + # We compute automatically the depth based on the desired seq_length. + if isinstance(num_channels, Integral) and max_seq_length: + num_channels = [num_channels] * int(np.ceil(np.log(max_seq_length / 2) / np.log(kernel_size))) + elif isinstance(num_channels, Integral) and not max_seq_length: + raise Exception("a maximum sequence length needs to be provided if num_channels is int") + + num_levels = len(num_channels) + for i in range(num_levels): + dilation_size = 2**i + in_channels = input_size[2] if i == 0 else num_channels[i - 1] + out_channels = num_channels[i] + layers += [ + TemporalBlock( + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=dilation_size, + padding=(kernel_size - 1) * dilation_size, + dropout=dropout, + ) + ] + + self.network = nn.Sequential(*layers) + self.logit = nn.Linear(num_channels[-1], num_classes) + + def forward(self, x): + x = x.permute(0, 2, 1) # Permute to channel first + o = self.network(x) + o = o.permute(0, 2, 1) # Permute to channel last + pred = self.logit(o) + return pred diff --git a/icu_benchmarks/models/dl_models/transformer.py b/icu_benchmarks/models/dl_models/transformer.py new file mode 100644 index 00000000..ed7d4a2c --- /dev/null +++ b/icu_benchmarks/models/dl_models/transformer.py @@ -0,0 +1,81 @@ +import gin +from torch import nn as nn + +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.dl_models.layers import PositionalEncoding, TransformerBlock, LocalBlock +from icu_benchmarks.models.wrappers import DLPredictionWrapper + + +class BaseTransformer(DLPredictionWrapper): + _supported_run_modes = [RunMode.classification, RunMode.regression] + """Refactored Transformer model as defined by the HiRID-Benchmark (https://github.com/ratschlab/HIRID-ICU-Benchmark).""" + + def __init__( + self, + block_class, + input_size, + hidden, + heads, + ff_hidden_mult, + depth, + num_classes, + dropout=0.0, + l1_reg=0, + pos_encoding=True, + dropout_att=0.0, + local_context=None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + if local_context is not None and self._get_name() == "Transformer": + raise ValueError("Local context is only supported for LocalTransformer") + hidden = hidden if hidden % 2 == 0 else hidden + 1 # Make sure hidden is even + self.input_embedding = nn.Linear(input_size[2], hidden) # This acts as a time-distributed layer by defaults + if pos_encoding: + self.pos_encoder = PositionalEncoding(hidden) + else: + self.pos_encoder = None + + t_blocks = [] + for _ in range(depth): + t_blocks.append( + block_class( + emb=hidden, + hidden=hidden, + heads=heads, + mask=True, + ff_hidden_mult=ff_hidden_mult, + dropout=dropout, + dropout_att=dropout_att, + **({"local_context": local_context} if local_context is not None else {}), + ) + ) + + self.t_blocks = nn.Sequential(*t_blocks) + self.logit = nn.Linear(hidden, num_classes) + self.l1_reg = l1_reg + + def forward(self, x): + x = self.input_embedding(x) + if self.pos_encoder is not None: + x = self.pos_encoder(x) + x = self.t_blocks(x) + pred = self.logit(x) + return pred + + +@gin.configurable +class Transformer(BaseTransformer): + """Transformer model.""" + + def __init__(self, *kwargs, **args): + super().__init__(TransformerBlock, *kwargs, **args) + + +@gin.configurable +class LocalTransformer(BaseTransformer): + """Transformer model with local context.""" + + def __init__(self, *kwargs, **args): + super().__init__(LocalBlock, *kwargs, **args) diff --git a/icu_benchmarks/models/metrics.py b/icu_benchmarks/models/metrics.py deleted file mode 100644 index 280a0653..00000000 --- a/icu_benchmarks/models/metrics.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -from typing import Callable -import numpy as np -from ignite.metrics import EpochMetric -from sklearn.metrics import balanced_accuracy_score, mean_absolute_error -from sklearn.calibration import calibration_curve -from scipy.spatial.distance import jensenshannon - -"""" -This file contains metrics that are not available in ignite.metrics. Specifically, it adds transformation capabilities to some -metrics. -""" - - -def accuracy(output, target, topk=(1,)): - """Computes the accuracy over the k top predictions for the specified values of k""" - with torch.no_grad(): - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) - res.append(correct_k.mul_(100.0 / batch_size)) - return res - - -def balanced_accuracy_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float: - y_true = y_targets.numpy() - y_pred = np.argmax(y_preds.numpy(), axis=-1) - return balanced_accuracy_score(y_true, y_pred) - - -def ece_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float: - y_true = y_targets.numpy() - y_pred = y_preds.numpy() - return calibration_curve(y_true, y_pred, n_bins=10) - - -def mae_with_invert_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float: - y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0] - y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0] - return mean_absolute_error(y_true, y_pred) - - -def JSD_fn(y_preds: torch.Tensor, y_targets: torch.Tensor): - return jensenshannon(abs(y_preds).flatten(), abs(y_targets).flatten()) ** 2 - - -class BalancedAccuracy(EpochMetric): - def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: - super(BalancedAccuracy, self).__init__( - balanced_accuracy_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn - ) - - -class CalibrationCurve(EpochMetric): - def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: - super(CalibrationCurve, self).__init__( - ece_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn - ) - - -class MAE(EpochMetric): - def __init__( - self, - output_transform: Callable = lambda x: x, - check_compute_fn: bool = False, - invert_transform: Callable = lambda x: x, - ) -> None: - super(MAE, self).__init__( - lambda x, y: mae_with_invert_compute_fn(x, y, invert_transform), - output_transform=output_transform, - check_compute_fn=check_compute_fn, - ) - - -class JSD(EpochMetric): - def __init__( - self, - output_transform: Callable = lambda x: x, - check_compute_fn: bool = False, - ) -> None: - super(JSD, self).__init__( - lambda x, y: JSD_fn(x, y), - output_transform=output_transform, - check_compute_fn=check_compute_fn, - ) diff --git a/icu_benchmarks/models/ml_models/__init__.py b/icu_benchmarks/models/ml_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/icu_benchmarks/models/ml_models/catboost.py b/icu_benchmarks/models/ml_models/catboost.py new file mode 100644 index 00000000..4bbebea1 --- /dev/null +++ b/icu_benchmarks/models/ml_models/catboost.py @@ -0,0 +1,26 @@ +import gin +import catboost as cb +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.wrappers import MLWrapper + + +@gin.configurable +class CBClassifier(MLWrapper): + _supported_run_modes = [RunMode.classification] + + def __init__(self, task_type="CPU", *args, **kwargs): + model_kwargs = {"task_type": task_type, **kwargs} + self.model = self.set_model_args(cb.CatBoostClassifier, *args, **model_kwargs) + super().__init__(*args, **kwargs) + + def predict(self, features): + """ + Predicts class probabilities for the given features. + + Args: + features: Input features for prediction. + + Returns: + numpy.ndarray: Predicted probabilities for each class. + """ + return self.model.predict_proba(features) diff --git a/icu_benchmarks/models/ml_models/imblearn.py b/icu_benchmarks/models/ml_models/imblearn.py new file mode 100644 index 00000000..d1db0703 --- /dev/null +++ b/icu_benchmarks/models/ml_models/imblearn.py @@ -0,0 +1,22 @@ +from imblearn.ensemble import BalancedRandomForestClassifier, RUSBoostClassifier +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.wrappers import MLWrapper +import gin + + +@gin.configurable +class BRFClassifier(MLWrapper): + _supported_run_modes = [RunMode.classification] + + def __init__(self, *args, **kwargs): + self.model = self.set_model_args(BalancedRandomForestClassifier, *args, **kwargs) + super().__init__(*args, **kwargs) + + +@gin.configurable +class RUSBClassifier(MLWrapper): + _supported_run_modes = [RunMode.classification] + + def __init__(self, *args, **kwargs): + self.model = self.set_model_args(RUSBoostClassifier, *args, **kwargs) + super().__init__(*args, **kwargs) diff --git a/icu_benchmarks/models/ml_models/lgbm.py b/icu_benchmarks/models/ml_models/lgbm.py new file mode 100644 index 00000000..c2207555 --- /dev/null +++ b/icu_benchmarks/models/ml_models/lgbm.py @@ -0,0 +1,57 @@ +import gin +import lightgbm as lgbm +import numpy as np +import wandb +from wandb.integration.lightgbm import wandb_callback as wandb_lgbm + +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.wrappers import MLWrapper + + +class LGBMWrapper(MLWrapper): + def fit_model(self, train_data, train_labels, val_data, val_labels): + """Fitting function for LGBM models.""" + self.model.set_params(random_state=np.random.get_state()[1][0]) + callbacks = [lgbm.early_stopping(self.hparams.patience, verbose=True), lgbm.log_evaluation(period=-1)] + + if wandb.run is not None: + callbacks.append(wandb_lgbm()) + + self.model = self.model.fit( + train_data, + train_labels, + eval_set=(val_data, val_labels), + callbacks=callbacks, + ) + val_loss = list(self.model.best_score_["valid_0"].values())[0] + return val_loss + + +@gin.configurable +class LGBMClassifier(LGBMWrapper): + _supported_run_modes = [RunMode.classification] + + def __init__(self, *args, **kwargs): + self.model = self.set_model_args(lgbm.LGBMClassifier, *args, **kwargs) + super().__init__(*args, **kwargs) + + def predict(self, features): + """ + Predicts class probabilities for the given features. + + Args: + features: Input features for prediction. + + Returns: + numpy.ndarray: Predicted probabilities for each class. + """ + return self.model.predict_proba(features) + + +@gin.configurable +class LGBMRegressor(LGBMWrapper): + _supported_run_modes = [RunMode.regression] + + def __init__(self, *args, **kwargs): + self.model = self.set_model_args(lgbm.LGBMRegressor, *args, **kwargs) + super().__init__(*args, **kwargs) diff --git a/icu_benchmarks/models/ml_models.py b/icu_benchmarks/models/ml_models/sklearn.py similarity index 64% rename from icu_benchmarks/models/ml_models.py rename to icu_benchmarks/models/ml_models/sklearn.py index e06d3fe7..1fe7a87b 100644 --- a/icu_benchmarks/models/ml_models.py +++ b/icu_benchmarks/models/ml_models/sklearn.py @@ -1,49 +1,9 @@ import gin -import lightgbm -from sklearn import linear_model -from sklearn import ensemble -from sklearn import neural_network -from sklearn import svm +from sklearn import linear_model, ensemble, svm, neural_network +from icu_benchmarks.constants import RunMode from icu_benchmarks.models.wrappers import MLWrapper -from icu_benchmarks.contants import RunMode - - -class LGBMWrapper(MLWrapper): - def fit_model(self, train_data, train_labels, val_data, val_labels): - """Fitting function for LGBM models.""" - self.model.fit( - train_data, - train_labels, - eval_set=(val_data, val_labels), - verbose=True, - callbacks=[ - lightgbm.early_stopping(self.hparams.patience, verbose=False), - lightgbm.log_evaluation(period=-1, show_stdv=False), - ], - ) - val_loss = list(self.model.best_score_["valid_0"].values())[0] - return val_loss -@gin.configurable -class LGBMClassifier(LGBMWrapper): - _supported_run_modes = [RunMode.classification] - - def __init__(self, *args, **kwargs): - self.model = self.set_model_args(lightgbm.LGBMClassifier, *args, **kwargs) - super().__init__(*args, **kwargs) - - -@gin.configurable -class LGBMRegressor(LGBMWrapper): - _supported_run_modes = [RunMode.regression] - - def __init__(self, *args, **kwargs): - self.model = self.set_model_args(lightgbm.LGBMRegressor, *args, **kwargs) - super().__init__(*args, **kwargs) - - -# Scikit-learn models @gin.configurable class LogisticRegression(MLWrapper): __supported_run_modes = [RunMode.classification] diff --git a/icu_benchmarks/models/ml_models/xgboost.py b/icu_benchmarks/models/ml_models/xgboost.py new file mode 100644 index 00000000..5ca738ac --- /dev/null +++ b/icu_benchmarks/models/ml_models/xgboost.py @@ -0,0 +1,74 @@ +import inspect +import logging +from statistics import mean + +import gin +import shap +import wandb +import xgboost as xgb +from xgboost.callback import EarlyStopping +from wandb.integration.xgboost import wandb_callback as wandb_xgb + +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.wrappers import MLWrapper + + +# Uncomment if needed in the future +# from optuna.integration import XGBoostPruningCallback + + +@gin.configurable +class XGBClassifier(MLWrapper): + _supported_run_modes = [RunMode.classification] + _explain_values = False + + def __init__(self, *args, **kwargs): + self.model = self.set_model_args(xgb.XGBClassifier, *args, **kwargs, device="cpu") + super().__init__(*args, **kwargs) + + def predict(self, features): + """ + Predicts class probabilities for the given features. + + Args: + features: Input features for prediction. + + Returns: + numpy.ndarray: Predicted probabilities for each class. + """ + return self.model.predict_proba(features) + + def fit_model(self, train_data, train_labels, val_data, val_labels): + """Fit the model to the training data (default SKlearn syntax)""" + callbacks = [EarlyStopping(self.hparams.patience)] + + if wandb.run is not None: + callbacks.append(wandb_xgb()) + logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}") + logging.info(train_labels) + self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=False) + if self._explain_values: + self.explainer = shap.TreeExplainer(self.model) + self.train_shap_values = self.explainer(train_data) + # shap.summary_plot(shap_values, X_test, feature_names=features) + # logging.info(self.model.get_booster().get_score(importance_type='weight')) + # self.log_dict(self.model.get_booster().get_score(importance_type='weight')) + # Return the first metric we use for validation + eval_score = mean(next(iter(self.model.evals_result_["validation_0"].values()))) + return eval_score # , callbacks=callbacks) + + def set_model_args(self, model, *args, **kwargs): + """XGBoost signature does not include the hyperparams so we need to pass them manually.""" + signature = inspect.signature(model.__init__).parameters + valid_params = signature.keys() + + # Filter out invalid arguments + valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + + logging.debug(f"Creating model with: {valid_kwargs}.") + return model(**valid_kwargs) + + def get_feature_importance(self): + if not hasattr(self.model, "feature_importances_"): + raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.") + return self.model.feature_importances_ diff --git a/icu_benchmarks/models/train.py b/icu_benchmarks/models/train.py index 1c2a91cd..7bc8e45a 100644 --- a/icu_benchmarks/models/train.py +++ b/icu_benchmarks/models/train.py @@ -1,17 +1,19 @@ import os import gin +import numpy as np import torch import logging -import pandas as pd +import polars as pl +from joblib import load from torch.optim import Adam from torch.utils.data import DataLoader -from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar, LearningRateMonitor from pathlib import Path -from icu_benchmarks.data.loader import PredictionDataset, ImputationDataset +from icu_benchmarks.data.loader import PredictionPandasDataset, ImputationPandasDataset, PredictionPolarsDataset from icu_benchmarks.models.utils import save_config_file, JSONMetricsLogger -from icu_benchmarks.contants import RunMode +from icu_benchmarks.constants import RunMode from icu_benchmarks.data.constants import DataSplit as Split cpu_core_count = len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else os.cpu_count() @@ -25,8 +27,9 @@ def assure_minimum_length(dataset): @gin.configurable("train_common") def train_common( - data: dict[str, pd.DataFrame], + data: dict[str, pl.DataFrame], log_dir: Path, + eval_only: bool = False, load_weights: bool = False, source_dir: Path = None, reproducible: bool = True, @@ -35,22 +38,28 @@ def train_common( weight: str = None, optimizer: type = Adam, precision=32, - batch_size=64, - epochs=1000, + batch_size=1, + epochs=100, patience=20, min_delta=1e-5, test_on: str = Split.test, + dataset_names=None, use_wandb: bool = False, cpu: bool = False, verbose=False, ram_cache=False, - num_workers: int = min(cpu_core_count, torch.cuda.device_count() * 4 * int(torch.cuda.is_available()), 32), + pl_model=True, + train_only=False, + num_workers: int = min(cpu_core_count, torch.cuda.device_count() * 8 * int(torch.cuda.is_available()), 32), + polars=True, + persistent_workers=None, ): """Common wrapper to train all benchmarked models. Args: data: Dict containing data to be trained on. log_dir: Path to directory where model output should be saved. + eval_only: If set to true, skip training and only evaluate the model. load_weights: If set to true, skip training and load weights from source_dir instead. source_dir: If set to load weights, path to directory containing trained weights. reproducible: If set to true, set torch to run reproducibly. @@ -61,98 +70,110 @@ def train_common( precision: Pytorch precision to be used for training. Can be 16 or 32. batch_size: Batch size to be used for training. epochs: Number of epochs to train for. - patience: Number of epochs to wait before early stopping. + patience: Number of epochs to wait for improvement before early stopping. min_delta: Minimum change in loss to be considered an improvement. test_on: If set to "test", evaluate the model on the test set. If set to "val", evaluate on the validation set. use_wandb: If set to true, log to wandb. cpu: If set to true, run on cpu. verbose: Enable detailed logging. ram_cache: Whether to cache the data in RAM. + pl_model: Loading a pytorch lightning model. num_workers: Number of workers to use for data loading. """ logging.info(f"Training model: {model.__name__}.") - dataset_class = ImputationDataset if mode == RunMode.imputation else PredictionDataset - + # todo: add support for polars versions of datasets + dataset_classes = { + RunMode.imputation: ImputationPandasDataset, + RunMode.classification: PredictionPolarsDataset if polars else PredictionPandasDataset, + RunMode.regression: PredictionPolarsDataset if polars else PredictionPandasDataset, + } + dataset_class = dataset_classes[mode] + + logging.info(f"Using dataset class: {dataset_class.__name__}.") logging.info(f"Logging to directory: {log_dir}.") save_config_file(log_dir) # We save the operative config before and also after training - - train_dataset = dataset_class(data, split=Split.train, ram_cache=ram_cache) - val_dataset = dataset_class(data, split=Split.val, ram_cache=ram_cache) + train_dataset = dataset_class(data, split=Split.train, ram_cache=ram_cache, name=dataset_names["train"]) + val_dataset = dataset_class(data, split=Split.val, ram_cache=ram_cache, name=dataset_names["val"]) train_dataset, val_dataset = assure_minimum_length(train_dataset), assure_minimum_length(val_dataset) batch_size = min(batch_size, len(train_dataset), len(val_dataset)) - logging.debug(f"Training on {len(train_dataset)} samples and validating on {len(val_dataset)} samples.") + if not eval_only: + logging.info( + f"Training on {train_dataset.name} with {len(train_dataset)} samples and validating on {val_dataset.name} with" + f" {len(val_dataset)} samples." + ) logging.info(f"Using {num_workers} workers for data loading.") - train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, - pin_memory=True, drop_last=True, + persistent_workers=persistent_workers, ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, - pin_memory=True, drop_last=True, + persistent_workers=persistent_workers, ) data_shape = next(iter(train_loader))[0].shape - model = model(optimizer=optimizer, input_size=data_shape, epochs=epochs, run_mode=mode) - model.set_weight(weight, train_dataset) if load_weights: - if source_dir.exists(): - # if not model.needs_training: - checkpoint = torch.load(source_dir / "model.ckpt") - # else: - model = model.load_state_dict(checkpoint["state_dict"]) - else: - raise Exception(f"No weights to load at path : {source_dir}") + model = load_model(model, source_dir, pl_model=pl_model, cpu=cpu) + else: + model = model(optimizer=optimizer, input_size=data_shape, epochs=epochs, run_mode=mode, cpu=cpu) + model.set_weight(weight, train_dataset) model.set_trained_columns(train_dataset.get_feature_names()) - loggers = [TensorBoardLogger(log_dir), JSONMetricsLogger(log_dir)] - + if use_wandb: + loggers.append(WandbLogger(save_dir=log_dir)) callbacks = [ - EarlyStopping(monitor="val/loss", min_delta=min_delta, patience=patience, strict=False), + EarlyStopping(monitor="val/loss", min_delta=min_delta, patience=patience, strict=False, verbose=verbose), ModelCheckpoint(log_dir, filename="model", save_top_k=1, save_last=True), + LearningRateMonitor(logging_interval="step"), ] if verbose: callbacks.append(TQDMProgressBar(refresh_rate=min(100, len(train_loader) // 2))) if precision == 16 or "16-mixed": torch.set_float32_matmul_precision("medium") + trainer = Trainer( - max_epochs=epochs if model.needs_training else 1, + max_epochs=epochs if model.requires_backprop else 1, + min_epochs=1, # We need at least one epoch to get results. callbacks=callbacks, precision=precision, accelerator="auto" if not cpu else "cpu", devices=max(torch.cuda.device_count(), 1), - deterministic=reproducible, + deterministic="warn" if reproducible else False, benchmark=not reproducible, enable_progress_bar=verbose, logger=loggers, - num_sanity_val_steps=0, + num_sanity_val_steps=2, # Helps catch errors in the validation loop before training begins. + log_every_n_steps=5, ) - - if model.needs_fit: - logging.info("Fitting model to data.") - model.fit(train_dataset, val_dataset) - model.save_model(log_dir, "last") - logging.info("Fitting complete.") - - if model.needs_training: - logging.info("Training model.") - trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) - logging.info("Training complete.") - - test_dataset = dataset_class(data, split=test_on) + if not eval_only: + if model.requires_backprop: + logging.info("Training DL model.") + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + logging.info("Training complete.") + else: + logging.info("Training ML model.") + model.fit(train_dataset, val_dataset) + model.save_model(log_dir, "last") + logging.info("Training complete.") + if train_only: + logging.info("Finished training full model.") + save_config_file(log_dir) + return 0 + test_dataset = dataset_class(data, split=test_on, name=dataset_names["test"], ram_cache=ram_cache) test_dataset = assure_minimum_length(test_dataset) + logging.info(f"Testing on {test_dataset.name} with {len(test_dataset)} samples.") test_loader = ( DataLoader( test_dataset, @@ -161,12 +182,63 @@ def train_common( num_workers=num_workers, pin_memory=True, drop_last=True, + persistent_workers=persistent_workers, ) - if model.needs_training + if model.requires_backprop else DataLoader([test_dataset.to_tensor()], batch_size=1) ) model.set_weight("balanced", train_dataset) test_loss = trainer.test(model, dataloaders=test_loader, verbose=verbose)[0]["test/loss"] + persist_shap_data(trainer, log_dir) save_config_file(log_dir) return test_loss + + +def persist_shap_data(trainer: Trainer, log_dir: Path): + """ + Persist shap values to disk. + Args: + trainer: Pytorch lightning trainer object + log_dir: Log directory + """ + try: + if trainer.lightning_module.test_shap_values is not None: + shap_values = trainer.lightning_module.test_shap_values + shaps_test = pl.DataFrame(schema=trainer.lightning_module.trained_columns, data=np.transpose(shap_values.values)) + with (log_dir / "shap_values_test.parquet").open("wb") as f: + shaps_test.write_parquet(f) + logging.info(f"Saved shap values to {log_dir / 'test_shap_values.parquet'}") + if trainer.lightning_module.train_shap_values is not None: + shap_values = trainer.lightning_module.train_shap_values + shaps_train = pl.DataFrame(schema=trainer.lightning_module.trained_columns, data=np.transpose(shap_values.values)) + with (log_dir / "shap_values_train.parquet").open("wb") as f: + shaps_train.write_parquet(f) + + except Exception as e: + logging.error(f"Failed to save shap values: {e}") + + +def load_model(model, source_dir, pl_model=True): + if source_dir.exists(): + if model.requires_backprop: + if (source_dir / "model.ckpt").exists(): + model_path = source_dir / "model.ckpt" + elif (source_dir / "model-v1.ckpt").exists(): + model_path = source_dir / "model-v1.ckpt" + elif (source_dir / "last.ckpt").exists(): + model_path = source_dir / "last.ckpt" + else: + return Exception(f"No weights to load at path : {source_dir}") + if pl_model: + model = model.load_from_checkpoint(model_path) + else: + checkpoint = torch.load(model_path) + model.load_from_checkpoint(checkpoint) + else: + model_path = source_dir / "model.joblib" + model = load(model_path) + else: + raise Exception(f"No weights to load at path : {source_dir}") + logging.info(f"Loaded {type(model)} model from {model_path}") + return model diff --git a/icu_benchmarks/models/utils.py b/icu_benchmarks/models/utils.py index 53bfcae2..fc5b5506 100644 --- a/icu_benchmarks/models/utils.py +++ b/icu_benchmarks/models/utils.py @@ -11,6 +11,7 @@ from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.utilities import rank_zero_only +from sklearn.metrics import average_precision_score from torch.nn import Module from torch.optim import Optimizer, Adam, SGD, RAdam from typing import Optional, Union @@ -23,7 +24,7 @@ def save_config_file(log_dir): f.write(gin.operative_config_str()) -def create_optimizer(name: str, model: Module, lr: float, momentum: float) -> Optimizer: +def create_optimizer(name: str, model: Module, lr: float, momentum: float = 0) -> Optimizer: """creates the specified optimizer with the given parameters Args: @@ -188,3 +189,96 @@ def version(self): @rank_zero_only def log_hyperparams(self, params): pass + + +class scorer_wrapper: + """ + Wrapper that flattens the binary classification input such that we can use a broader range of sklearn metrics. + """ + + def __init__(self, scorer=average_precision_score): + self.scorer = scorer + + def __call__(self, y_true, y_pred): + if len(np.unique(y_true)) <= 2 and y_pred.ndim > 1: + y_pred_argmax = np.argmax(y_pred, axis=1) + return self.scorer(y_true, y_pred_argmax) + else: + return self.scorer(y_true, y_pred) + + def __name__(self): + return "scorer_wrapper" + + +# Source: https://github.com/ratschlab/tls +@gin.configurable("get_smoothed_labels") +def get_smoothed_labels( + label, event, smoothing_fn=gin.REQUIRED, h_true=gin.REQUIRED, h_min=gin.REQUIRED, h_max=gin.REQUIRED, delta_h=12, gamma=0.1 +): + diffs = np.concatenate([np.zeros(1), event[1:] - event[:-1]], axis=-1) + pos_event_change_full = np.where((diffs == 1) & (event == 1))[0] + + multihorizon = isinstance(h_true, list) + if multihorizon: + label_for_event = label[0] + h_for_event = h_true[0] + else: + label_for_event = label + h_for_event = h_true + diffs_label = np.concatenate([np.zeros(1), label_for_event[1:] - label_for_event[:-1]], axis=-1) + + # Event that occurred after the end of the stay for M3B. + # In that case event are equal to the number of hours after the end of stay when the event occured. + pos_event_change_delayed = np.where((diffs >= 1) & (event > 1))[0] + if len(pos_event_change_delayed) > 0: + delays = event[pos_event_change_delayed] - 1 + pos_event_change_delayed += delays.astype(int) + pos_event_change_full = np.sort(np.concatenate([pos_event_change_full, pos_event_change_delayed])) + + last_know_label = label_for_event[np.where(label_for_event != -1)][-1] + last_know_idx = np.where(label_for_event == last_know_label)[0][-1] + + # Need to handle the case where the ts was truncatenated at 2016 for HiB + if ((last_know_label == 1) and (len(pos_event_change_full) == 0)) or ( + (last_know_label == 1) and (last_know_idx >= pos_event_change_full[-1]) + ): + last_know_event = 0 + if len(pos_event_change_full) > 0: + last_know_event = pos_event_change_full[-1] + + last_known_stable = 0 + known_stable = np.where(label_for_event == 0)[0] + if len(known_stable) > 0: + last_known_stable = known_stable[-1] + + pos_change = np.where((diffs_label >= 1) & (label_for_event == 1))[0] + last_pos_change = pos_change[np.where(pos_change > max(last_know_event, last_known_stable))][0] + pos_event_change_full = np.concatenate([pos_event_change_full, [last_pos_change + h_for_event]]) + + # No event case + if len(pos_event_change_full) == 0: + pos_event_change_full = np.array([np.inf]) + + time_array = np.arange(len(label)) + dist = pos_event_change_full.reshape(-1, 1) - time_array + dte = np.where(dist > 0, dist, np.inf).min(axis=0) + if multihorizon: + smoothed_labels = [] + for k in range(label.shape[-1]): + smoothed_labels.append( + np.array( + list( + map( + lambda x: smoothing_fn( + x, h_true=h_true[k], h_min=h_min[k], h_max=h_max[k], delta_h=delta_h, gamma=gamma + ), + dte, + ) + ) + ) + ) + return np.stack(smoothed_labels, axis=-1) + else: + return np.array( + list(map(lambda x: smoothing_fn(x, h_true=h_true, h_min=h_min, h_max=h_max, delta_h=delta_h, gamma=gamma), dte)) + ) diff --git a/icu_benchmarks/models/wrappers.py b/icu_benchmarks/models/wrappers.py index 5fc94a12..310f9fe6 100644 --- a/icu_benchmarks/models/wrappers.py +++ b/icu_benchmarks/models/wrappers.py @@ -1,43 +1,53 @@ import logging from abc import ABC -from typing import Dict, Any -from typing import List, Optional, Union +from typing import Dict, Any, List, Optional, Union +from pathlib import Path +import torchmetrics +from sklearn.metrics import log_loss, mean_squared_error, average_precision_score, roc_auc_score -import sklearn.metrics -from sklearn.metrics import log_loss +import torch from torch.nn import MSELoss, CrossEntropyLoss -from torch.nn.modules.loss import _Loss -from torch.optim import Optimizer +import torch.nn as nn +from torch import Tensor, FloatTensor +from torch.optim import Optimizer, Adam + import inspect import gin import numpy as np -import torch from ignite.exceptions import NotComputableError from icu_benchmarks.models.constants import ImputationInit +from icu_benchmarks.models.custom_metrics import confusion_matrix from icu_benchmarks.models.utils import create_optimizer, create_scheduler from joblib import dump from pytorch_lightning import LightningModule from icu_benchmarks.models.constants import MLMetrics, DLMetrics -from icu_benchmarks.contants import RunMode +from icu_benchmarks.constants import RunMode -gin.config.external_configurable(torch.nn.functional.nll_loss, module="torch.nn.functional") -gin.config.external_configurable(torch.nn.functional.cross_entropy, module="torch.nn.functional") -gin.config.external_configurable(torch.nn.functional.mse_loss, module="torch.nn.functional") +gin.config.external_configurable(nn.functional.nll_loss, module="torch.nn.functional") +gin.config.external_configurable(nn.functional.cross_entropy, module="torch.nn.functional") +gin.config.external_configurable(nn.functional.mse_loss, module="torch.nn.functional") -gin.config.external_configurable(sklearn.metrics.mean_squared_error, module="sklearn.metrics") -gin.config.external_configurable(sklearn.metrics.log_loss, module="sklearn.metrics") +gin.config.external_configurable(mean_squared_error, module="sklearn.metrics") +gin.config.external_configurable(log_loss, module="sklearn.metrics") +gin.config.external_configurable(average_precision_score, module="sklearn.metrics") +gin.config.external_configurable(roc_auc_score, module="sklearn.metrics") +# gin.config.external_configurable(scorer_wrapper, module="icu_benchmarks.models.utils") @gin.configurable("BaseModule") class BaseModule(LightningModule): - needs_training = False - needs_fit = False - + # DL type models, requires backpropagation + requires_backprop = False + # Loss function weight initialization type weight = None + # Metrics to be logged metrics = {} trained_columns = None + # Type of run mode run_mode = None + debug = False + explain_features = False def forward(self, *args, **kwargs): raise NotImplementedError() @@ -54,8 +64,14 @@ def set_metrics(self, *args, **kwargs): def set_trained_columns(self, columns: List[str]): self.trained_columns = columns - def set_weight(self, weight, *args, **kwargs): - pass + def set_weight(self, weight, dataset): + """Set the weight for the loss function.""" + + if isinstance(weight, list): + weight = FloatTensor(weight).to(self.device) + elif weight == "balanced": + weight = FloatTensor(dataset.get_balance()).to(self.device) + self.loss_weights = weight def training_step(self, batch, batch_idx): return self.step_fn(batch, "train") @@ -91,15 +107,14 @@ def check_supported_runmode(self, runmode: RunMode): @gin.configurable("DLWrapper") class DLWrapper(BaseModule, ABC): - needs_training = True - needs_fit = False + requires_backprop = True _metrics_warning_printed = set() _supported_run_modes = [RunMode.classification, RunMode.regression, RunMode.imputation] def __init__( self, loss=CrossEntropyLoss(), - optimizer=torch.optim.Adam, + optimizer=Adam, run_mode: RunMode = RunMode.classification, input_shape=None, lr: float = 0.002, @@ -108,18 +123,27 @@ def __init__( lr_factor: float = 0.99, lr_steps: Optional[List[int]] = None, epochs: int = 100, - input_size: torch.Tensor = None, + input_size: Tensor = None, initialization_method: str = "normal", **kwargs, ): - """Interface for Deep Learning models.""" + """General interface for Deep Learning (DL) models.""" super().__init__() self.save_hyperparameters(ignore=["loss", "optimizer"]) self.loss = loss self.optimizer = optimizer - self.scaler = None self.check_supported_runmode(run_mode) self.run_mode = run_mode + self.input_shape = input_shape + self.lr = lr + self.momentum = momentum + self.lr_scheduler = lr_scheduler + self.lr_factor = lr_factor + self.lr_steps = lr_steps + self.epochs = epochs + self.input_size = input_size + self.initialization_method = initialization_method + self.scaler = None def on_fit_start(self): self.metrics = { @@ -131,11 +155,23 @@ def on_fit_start(self): } return super().on_fit_start() + def on_train_start(self): + self.metrics = { + step_name: { + metric_name: (metric() if isinstance(metric, type) else metric) + for metric_name, metric in self.set_metrics().items() + } + for step_name in ["train", "val", "test"] + } + return super().on_train_start() + def finalize_step(self, step_prefix=""): try: self.log_dict( { - f"{step_prefix}/{name}": metric.compute() + f"{step_prefix}/{name}": ( + np.float32(metric.compute()) if isinstance(metric.compute(), np.float64) else metric.compute() + ) for name, metric in self.metrics[step_prefix].items() if "_Curve" not in name }, @@ -150,8 +186,13 @@ def finalize_step(self, step_prefix=""): pass def configure_optimizers(self): + """Configure optimizers and learning rate schedulers.""" + if isinstance(self.optimizer, str): - optimizer = create_optimizer(self.optimizer, self, self.hparams.lr, self.hparams.momentum) + optimizer = create_optimizer(self.optimizer, self.lr, self.hparams.momentum) + elif isinstance(self.optimizer, Optimizer): + # Already set + optimizer = self.optimizer else: optimizer = self.optimizer(self.parameters()) @@ -160,7 +201,9 @@ def configure_optimizers(self): scheduler = create_scheduler( self.hparams.lr_scheduler, optimizer, self.hparams.lr_factor, self.hparams.lr_steps, self.hparams.epochs ) - return {"optimizer": optimizer, "lr_scheduler": scheduler} + optimizers = {"optimizer": optimizer, "lr_scheduler": scheduler} + logging.info(f"Using: {optimizers}") + return optimizers def on_test_epoch_start(self) -> None: self.metrics = { @@ -184,14 +227,39 @@ class DLPredictionWrapper(DLWrapper): _supported_run_modes = [RunMode.classification, RunMode.regression] - def set_weight(self, weight, dataset): - """Set the weight for the loss function.""" - - if isinstance(weight, list): - weight = torch.FloatTensor(weight).to(self.device) - elif weight == "balanced": - weight = torch.FloatTensor(dataset.get_balance()).to(self.device) - self.loss_weights = weight + def __init__( + self, + loss=CrossEntropyLoss(), + optimizer=torch.optim.Adam, + run_mode: RunMode = RunMode.classification, + input_shape=None, + lr: float = 0.002, + momentum: float = 0.9, + lr_scheduler: Optional[str] = None, + lr_factor: float = 0.99, + lr_steps: Optional[List[int]] = None, + epochs: int = 100, + input_size: Tensor = None, + initialization_method: str = "normal", + **kwargs, + ): + super().__init__( + loss=loss, + optimizer=optimizer, + run_mode=run_mode, + input_shape=input_shape, + lr=lr, + momentum=momentum, + lr_scheduler=lr_scheduler, + lr_factor=lr_factor, + lr_steps=lr_steps, + epochs=epochs, + input_size=input_size, + initialization_method=initialization_method, + kwargs=kwargs, + ) + self.output_transform = None + self.loss_weights = None def set_metrics(self, *args): """Set the evaluation metrics for the prediction model.""" @@ -224,10 +292,19 @@ def softmax_multi_output_transform(output): metrics = DLMetrics.REGRESSION else: raise ValueError(f"Run mode {self.run_mode} not supported.") + for key, value in metrics.items(): + # Torchmetrics metrics are not moved to the device by default + if isinstance(value, torchmetrics.Metric): + value.to(self.device) return metrics def step_fn(self, element, step_prefix=""): - """Perform a step in the training loop.""" + """Perform a step in the DL prediction model training loop. + + Args: + element (object): + step_prefix (str): Step type, by default: test, train, val. + """ if len(element) == 2: data, labels = element[0], element[1].to(self.device) @@ -248,24 +325,36 @@ def step_fn(self, element, step_prefix=""): else: raise Exception("Loader should return either (data, label) or (data, label, mask)") out = self(data) + + # If aux_loss is present, it is returned as a tuple if len(out) == 2 and isinstance(out, tuple): out, aux_loss = out else: aux_loss = 0 + # Get prediction and target prediction = torch.masked_select(out, mask.unsqueeze(-1)).reshape(-1, out.shape[-1]).to(self.device) target = torch.masked_select(labels, mask).to(self.device) + if prediction.shape[-1] > 1 and self.run_mode == RunMode.classification: # Classification task loss = self.loss(prediction, target.long(), weight=self.loss_weights.to(self.device)) + aux_loss - # torch.long because NLL + # Returns torch.long because negative log likelihood loss elif self.run_mode == RunMode.regression: # Regression task loss = self.loss(prediction[:, 0], target.float()) + aux_loss else: - raise ValueError(f"Run mode {self.run_mode} not supported.") + raise ValueError(f"Run mode {self.run_mode} not yet supported. Please implement it.") transformed_output = self.output_transform((prediction, target)) - for metric in self.metrics[step_prefix].values(): - metric.update(transformed_output) + + for key, value in self.metrics[step_prefix].items(): + if isinstance(value, torchmetrics.Metric): + if key == "Binary_Fairness": + feature_names = key.feature_helper(self.trainer) + value.update(transformed_output[0], transformed_output[1], data, feature_names) + else: + value.update(transformed_output[0], transformed_output[1]) + else: + value.update(transformed_output) self.log(f"{step_prefix}/loss", loss, on_step=False, on_epoch=True, sync_dist=True) return loss @@ -274,11 +363,10 @@ def step_fn(self, element, step_prefix=""): class MLWrapper(BaseModule, ABC): """Interface for prediction with traditional Scikit-learn-like Machine Learning models.""" - needs_training = False - needs_fit = True + requires_backprop = False _supported_run_modes = [RunMode.classification, RunMode.regression] - def __init__(self, *args, run_mode=RunMode.classification, loss=log_loss, patience=10, **kwargs): + def __init__(self, *args, run_mode=RunMode.classification, loss=log_loss, patience=10, mps=False, **kwargs): super().__init__() self.save_hyperparameters() self.scaler = None @@ -286,14 +374,15 @@ def __init__(self, *args, run_mode=RunMode.classification, loss=log_loss, patien self.run_mode = run_mode self.loss = loss self.patience = patience + self.mps = mps + self.loss_weight = None def set_metrics(self, labels): if self.run_mode == RunMode.classification: # Binary classification if len(np.unique(labels)) == 2: # if isinstance(self.model, lightgbm.basic.Booster): - self.output_transform = lambda x: x - # self.output_transform = lambda x: x[:, 1] + self.output_transform = lambda x: x[:, 1] self.label_transform = lambda x: x self.metrics = MLMetrics.BINARY_CLASSIFICATION @@ -316,19 +405,20 @@ def set_metrics(self, labels): def fit(self, train_dataset, val_dataset): """Fit the model to the training data.""" - train_rep, train_label = train_dataset.get_data_and_labels() - val_rep, val_label = val_dataset.get_data_and_labels() + train_rep, train_label, row_indicators = train_dataset.get_data_and_labels() + val_rep, val_label, row_indicators = val_dataset.get_data_and_labels() self.set_metrics(train_label) - # if "class_weight" in self.model.get_params().keys(): # Set class weights - # self.model.set_params(class_weight=self.weight) + if "class_weight" in self.model.get_params().keys(): # Set class weights + self.model.set_params(class_weight=self.weight) val_loss = self.fit_model(train_rep, train_label, val_rep, val_label) train_pred = self.predict(train_rep) logging.debug(f"Model:{self.model}") + self.log("train/loss", self.loss(train_label, train_pred), sync_dist=True) logging.debug(f"Train loss: {self.loss(train_label, train_pred)}") self.log("val/loss", val_loss, sync_dist=True) @@ -342,7 +432,7 @@ def fit_model(self, train_data, train_labels, val_data, val_labels): return val_loss def validation_step(self, val_dataset, _): - val_rep, val_label = val_dataset.get_data_and_labels() + val_rep, val_label, row_indicators = val_dataset.get_data_and_labels() val_rep, val_label = torch.from_numpy(val_rep).to(self.device), torch.from_numpy(val_label).to(self.device) self.set_metrics(val_label) @@ -353,34 +443,64 @@ def validation_step(self, val_dataset, _): self.log_metrics(val_label, val_pred, "val") def test_step(self, dataset, _): - test_rep, test_label = dataset - test_rep, test_label = test_rep.squeeze().cpu().numpy(), test_label.squeeze().cpu().numpy() + test_rep, test_label, pred_indicators = dataset + test_rep, test_label, pred_indicators = ( + test_rep.squeeze().cpu().numpy(), + test_label.squeeze().cpu().numpy(), + pred_indicators.squeeze().cpu().numpy(), + ) self.set_metrics(test_label) test_pred = self.predict(test_rep) - - self.log("test/loss", self.loss(test_label, test_pred), sync_dist=True) + if self.debug: + self._save_model_outputs(pred_indicators, test_pred, test_label) + if self.explain_features: + self.explain_model(test_rep, test_label) + if self.mps: + self.log("test/loss", np.float32(self.loss(test_label, test_pred)), sync_dist=True) + self.log_metrics(np.float32(test_label), np.float32(test_pred), "test") + else: + self.log("test/loss", self.loss(test_label, test_pred), sync_dist=True) + self.log_metrics(test_label, test_pred, "test") logging.debug(f"Test loss: {self.loss(test_label, test_pred)}") - self.log_metrics(test_label, test_pred, "test") def predict(self, features): if self.run_mode == RunMode.regression: return self.model.predict(features) - else: - return self.model.predict(features) + else: # Classification: return probabilities + return self.model.predict_proba(features) def log_metrics(self, label, pred, metric_type): """Log metrics to the PL logs.""" - + if "Confusion_Matrix" in self.metrics: + self.log_dict(confusion_matrix(self.label_transform(label), self.output_transform(pred)), sync_dist=True) self.log_dict( { - f"{metric_type}/{name}": metric(self.label_transform(label), self.output_transform(pred)) + f"{metric_type}/{name}": (metric(self.label_transform(label), self.output_transform(pred))) + # For every metric for name, metric in self.metrics.items() # Filter out metrics that return a tuple (e.g. precision_recall_curve) if not isinstance(metric(self.label_transform(label), self.output_transform(pred)), tuple) + and name != "Confusion_Matrix" }, sync_dist=True, ) + def _explain_model(self, test_rep, test_label): + if self.explainer is not None: + self.test_shap_values = self.explainer(test_rep) + else: + logging.warning("No explainer or explain_features values set.") + + def _save_model_outputs(self, pred_indicators, test_pred, test_label): + if len(pred_indicators.shape) > 1 and len(test_pred.shape) > 1 and pred_indicators.shape[1] == test_pred.shape[1]: + pred_indicators = np.hstack((pred_indicators, test_label.reshape(-1, 1))) + pred_indicators = np.hstack((pred_indicators, test_pred)) + # Save as: id, time (hours), ground truth, prediction 0, prediction 1 + np.savetxt(Path(self.logger.save_dir) / "pred_indicators.csv", pred_indicators, delimiter=",") + logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}") + else: + logging.warning("Could not save row indicators.") + def configure_optimizers(self): return None @@ -405,6 +525,7 @@ def set_model_args(self, model, *args, **kwargs): # Get passed keyword arguments arguments = locals()["kwargs"] # Get valid hyperparameters + logging.debug(f"Possible hps: {possible_hps}") hyperparams = {key: value for key, value in arguments.items() if key in possible_hps} logging.debug(f"Creating model with: {hyperparams}.") return model(**hyperparams) @@ -414,27 +535,40 @@ def set_model_args(self, model, *args, **kwargs): class ImputationWrapper(DLWrapper): """Interface for imputation models.""" - needs_training = True - needs_fit = False + requires_backprop = True _supported_run_modes = [RunMode.imputation] def __init__( self, - loss: _Loss = MSELoss(), + loss: nn.modules.loss._Loss = MSELoss(), optimizer: Union[str, Optimizer] = "adam", - runmode: RunMode = RunMode.imputation, + run_mode: RunMode = RunMode.imputation, lr: float = 0.002, momentum: float = 0.9, lr_scheduler: Optional[str] = None, lr_factor: float = 0.99, lr_steps: Optional[List[int]] = None, - input_size: torch.Tensor = None, + input_size: Tensor = None, initialization_method: ImputationInit = ImputationInit.NORMAL, + epochs=100, **kwargs: str, ) -> None: - super().__init__() - self.check_supported_runmode(runmode) - self.run_mode = runmode + super().__init__( + loss=loss, + optimizer=optimizer, + run_mode=run_mode, + lr=lr, + momentum=momentum, + lr_scheduler=lr_scheduler, + lr_factor=lr_factor, + lr_steps=lr_steps, + epochs=epochs, + input_size=input_size, + initialization_method=initialization_method, + kwargs=kwargs, + ) + self.check_supported_runmode(run_mode) + self.run_mode = run_mode self.save_hyperparameters(ignore=["loss", "optimizer"]) self.loss = loss self.optimizer = optimizer @@ -447,20 +581,20 @@ def init_func(m): classname = m.__class__.__name__ if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1): if init_type == ImputationInit.NORMAL: - torch.nn.init.normal_(m.weight.data, 0.0, gain) + nn.init.normal_(m.weight.data, 0.0, gain) elif init_type == ImputationInit.XAVIER: - torch.nn.init.xavier_normal_(m.weight.data, gain=gain) + nn.init.xavier_normal_(m.weight.data, gain=gain) elif init_type == ImputationInit.KAIMING: - torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out") + nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out") elif init_type == ImputationInit.ORTHOGONAL: - torch.nn.init.orthogonal_(m.weight.data, gain=gain) + nn.init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError(f"Initialization method {init_type} is not implemented") if hasattr(m, "bias") and m.bias is not None: - torch.nn.init.constant_(m.bias.data, 0.0) + nn.init.constant_(m.bias.data, 0.0) elif classname.find("BatchNorm2d") != -1: - torch.nn.init.normal_(m.weight.data, 1.0, gain) - torch.nn.init.constant_(m.bias.data, 0.0) + nn.init.normal_(m.weight.data, 1.0, gain) + nn.init.constant_(m.bias.data, 0.0) self.apply(init_func) diff --git a/icu_benchmarks/run.py b/icu_benchmarks/run.py index 0f8b7859..1c5479a0 100644 --- a/icu_benchmarks/run.py +++ b/icu_benchmarks/run.py @@ -1,16 +1,12 @@ # -*- coding: utf-8 -*- from datetime import datetime - import gin import logging import sys from pathlib import Path -import importlib.util - import torch.cuda - -from icu_benchmarks.wandb_utils import update_wandb_config, apply_wandb_sweep, set_wandb_run_name -from icu_benchmarks.tuning.hyperparameters import choose_and_bind_hyperparameters +from icu_benchmarks.wandb_utils import update_wandb_config, apply_wandb_sweep, set_wandb_experiment_name +from icu_benchmarks.tuning.hyperparameters import choose_and_bind_hyperparameters_optuna from scripts.plotting.utils import plot_aggregated_results from icu_benchmarks.cross_validation import execute_repeated_cv from icu_benchmarks.run_utils import ( @@ -20,8 +16,11 @@ log_full_line, load_pretrained_imputation_model, setup_logging, + import_preprocessor, + name_datasets, + get_config_files, ) -from icu_benchmarks.contants import RunMode +from icu_benchmarks.constants import RunMode @gin.configurable("Run") @@ -33,61 +32,71 @@ def get_mode(mode: gin.REQUIRED): def main(my_args=tuple(sys.argv[1:])): args, _ = build_parser().parse_known_args(my_args) - # Set arguments for wandb sweep if args.wandb_sweep: args = apply_wandb_sweep(args) - + set_wandb_experiment_name(args, "run") # Initialize loggers log_format = "%(asctime)s - %(levelname)s - %(name)s : %(message)s" date_format = "%Y-%m-%d %H:%M:%S" verbose = args.verbose setup_logging(date_format, log_format, verbose) - - # Load weights if in evaluation mode - load_weights = args.command == "evaluate" - data_dir = Path(args.data_dir) - # Get arguments + data_dir = Path(args.data_dir) name = args.name task = args.task model = args.model reproducible = args.reproducible + evaluate = args.eval + experiment = args.experiment + source_dir = args.source_dir + modalities = args.modalities + if modalities: + logging.debug(f"Binding modalities: {modalities}") + gin.bind_parameter("preprocess.selected_modalities", modalities) + if args.label: + logging.debug(f"Binding label: {args.label}") + gin.bind_parameter("preprocess.label", args.label) + tasks, models = get_config_files(Path("configs")) + if task not in tasks or model not in models: + raise ValueError( + f"Invalid task or model. Task: {task} {'not ' if task not in tasks else ''} found. " + f"Model: {model} {'not ' if model not in models else ''}found." + ) + # Load task config + gin.parse_config_file(f"configs/tasks/{task}.gin") + mode = get_mode() # Set experiment name if name is None: name = data_dir.name logging.info(f"Running experiment {name}.") - - # Load task config - gin.parse_config_file(f"configs/tasks/{task}.gin") - - mode = get_mode() - - if args.wandb_sweep: - run_name = f"{mode}_{model}_{name}" - set_wandb_run_name(run_name) - logging.info(f"Task mode: {mode}.") - experiment = args.experiment - pretrained_imputation_model = load_pretrained_imputation_model(args.pretrained_imputation) + # Set train size to fine tune size if fine tune is set, else use custom train size + train_size = args.fine_tune if args.fine_tune is not None else args.samples if args.samples is not None else None + # Whether to load weights from a previous run + load_weights = evaluate or args.fine_tune is not None + pretrained_imputation_model = load_pretrained_imputation_model(args.pretrained_imputation) # Log imputation model to wandb update_wandb_config( { - "pretrained_imputation_model": pretrained_imputation_model.__class__.__name__ - if pretrained_imputation_model is not None - else "None" + "pretrained_imputation_model": ( + pretrained_imputation_model.__class__.__name__ if pretrained_imputation_model is not None else "None" + ) } ) - source_dir = None + log_dir_name = args.log_dir / name log_dir = ( (log_dir_name / experiment) if experiment else (log_dir_name / (args.task_name if args.task_name is not None else args.task) / model) ) + log_full_line(f"Logging to {log_dir}.", logging.INFO) + + # Check cuda availability if torch.cuda.is_available(): for name in range(0, torch.cuda.device_count()): log_full_line(f"Available GPU {name}: {torch.cuda.get_device_name(name)}", level=logging.INFO) @@ -96,29 +105,32 @@ def main(my_args=tuple(sys.argv[1:])): "No GPUs available: please check your device and Torch,Cuda installation if unintended.", level=logging.WARNING ) - log_full_line(f"Logging to {log_dir}.", logging.INFO) - if args.preprocessor: - # Import custom supplied preprocessor - log_full_line(f"Importing custom preprocessor from {args.preprocessor}.", logging.INFO) - try: - spec = importlib.util.spec_from_file_location("CustomPreprocessor", args.preprocessor) - module = importlib.util.module_from_spec(spec) - sys.modules["preprocessor"] = module - spec.loader.exec_module(module) - gin.bind_parameter("preprocess.preprocessor", module.CustomPreprocessor) - except Exception as e: - logging.error(f"Could not import custom preprocessor from {args.preprocessor}: {e}") + import_preprocessor(args.preprocessor) + # Load pretrained model in evaluate mode or when finetuning if load_weights: - # Evaluate - log_dir /= f"from_{args.source_name}" + if args.source_dir is None: + raise ValueError("Please specify a source directory when evaluating or fine-tuning.") + log_dir /= f"_from_{args.source_name}" + name_datasets(args.source_name, args.source_name, args.name) + if args.fine_tune: + log_dir /= f"fine_tune_{args.fine_tune}" + name_datasets(args.name, args.name, args.name) run_dir = create_run_dir(log_dir) source_dir = args.source_dir + logging.info(f"Will load weights from {source_dir} and bind train gin-config. Note: this might override your config.") gin.parse_config_file(source_dir / "train_config.gin") + elif args.samples and args.source_dir is not None: # Train model with limited samples and bind existing config + logging.info("Binding train gin-config. Note: this might override your config.") + gin.parse_config_file(args.source_dir / "train_config.gin") + log_dir /= f"samples_{args.fine_tune}" + name_datasets(args.name, args.name, args.name) + run_dir = create_run_dir(log_dir) else: - # Train - checkpoint = log_dir / args.checkpoint if args.checkpoint else None + # Normal train and evaluate + name_datasets(args.name, args.name, args.name) + hp_checkpoint = log_dir / args.hp_checkpoint if args.hp_checkpoint else None model_path = ( Path("configs") / ("imputation_models" if mode == RunMode.imputation else "prediction_models") / f"{model}.gin" ) @@ -130,26 +142,36 @@ def main(my_args=tuple(sys.argv[1:])): gin.parse_config_files_and_bindings(gin_config_files, args.hyperparams, finalize_config=False) log_full_line(f"Data directory: {data_dir.resolve()}", level=logging.INFO) run_dir = create_run_dir(log_dir) - choose_and_bind_hyperparameters( - args.tune, - data_dir, - run_dir, - args.seed, + choose_and_bind_hyperparameters_optuna( + do_tune=args.tune, + data_dir=data_dir, + log_dir=run_dir, + seed=args.seed, run_mode=mode, - checkpoint=checkpoint, + checkpoint=hp_checkpoint, debug=args.debug, generate_cache=args.generate_cache, load_cache=args.load_cache, - verbose=args.verbose, + verbose=verbose, + wandb=args.wandb_sweep, ) log_full_line(f"Logging to {run_dir.resolve()}", level=logging.INFO) - log_full_line("STARTING TRAINING", level=logging.INFO, char="=", num_newlines=3) + if evaluate: + mode_string = "STARTING EVALUATION" + elif args.fine_tune: + mode_string = "STARTING FINE TUNING" + else: + mode_string = "STARTING TRAINING" + log_full_line(mode_string, level=logging.INFO, char="=", num_newlines=3) + start_time = datetime.now() execute_repeated_cv( data_dir, run_dir, args.seed, + eval_only=evaluate, + train_size=train_size, load_weights=load_weights, source_dir=source_dir, reproducible=reproducible, @@ -161,12 +183,17 @@ def main(my_args=tuple(sys.argv[1:])): pretrained_imputation_model=pretrained_imputation_model, cpu=args.cpu, wandb=args.wandb_sweep, + complete_train=args.complete_train, ) log_full_line("FINISHED TRAINING", level=logging.INFO, char="=", num_newlines=3) execution_time = datetime.now() - start_time log_full_line(f"DURATION: {execution_time}", level=logging.INFO, char="") - aggregate_results(run_dir, execution_time) + try: + aggregate_results(run_dir, execution_time) + except Exception as e: + logging.error(f"Failed to aggregate results: {e}") + logging.debug("Error details:", exc_info=True) if args.plot: plot_aggregated_results(run_dir, "aggregated_test_metrics.json") diff --git a/icu_benchmarks/run_utils.py b/icu_benchmarks/run_utils.py index 179ab1f5..97b50f8c 100644 --- a/icu_benchmarks/run_utils.py +++ b/icu_benchmarks/run_utils.py @@ -1,9 +1,12 @@ +import importlib +import sys import warnings from math import sqrt +import gin import torch import json -from argparse import ArgumentParser, BooleanOptionalAction +from argparse import ArgumentParser, BooleanOptionalAction as BOA from datetime import datetime, timedelta import logging from pathlib import Path @@ -12,6 +15,7 @@ from statistics import mean, pstdev from icu_benchmarks.models.utils import JsonResultLoggingEncoder from icu_benchmarks.wandb_utils import wandb_log +import polars as pl def build_parser() -> ArgumentParser: @@ -20,76 +24,42 @@ def build_parser() -> ArgumentParser: Returns: The configured ArgumentParser. """ - parser = ArgumentParser(description="Benchmark lib for processing and evaluation of deep learning models on ICU data") - - parent_parser = ArgumentParser(add_help=False) - subparsers = parser.add_subparsers(title="Commands", dest="command", required=True) - - # ARGUMENTS FOR ALL COMMANDS - general_args = parent_parser.add_argument_group("General arguments") - general_args.add_argument("-d", "--data-dir", required=True, type=Path, help="Path to the parquet data directory.") - general_args.add_argument("-t", "--task", default="BinaryClassification", required=True, help="Name of the task gin.") - general_args.add_argument("-n", "--name", required=False, help="Name of the (target) dataset.") - general_args.add_argument("-tn", "--task-name", required=False, help="Name of the task, used for naming experiments.") - general_args.add_argument("-m", "--model", default="LGBMClassifier", required=False, help="Name of the model gin.") - general_args.add_argument("-e", "--experiment", required=False, help="Name of the experiment gin.") - general_args.add_argument( - "-l", "--log-dir", required=False, default=Path("../yaib_logs/"), type=Path, help="Log directory with model weights." + parser = ArgumentParser(description="Framework for benchmarking ML/DL models on ICU data") + + parser.add_argument("-d", "--data-dir", required=True, type=Path, help="Path to the parquet data directory.") + parser.add_argument("-t", "--task", default="BinaryClassification", required=True, help="Name of the task gin.") + parser.add_argument("-n", "--name", help="Name of the (target) dataset.") + parser.add_argument("-tn", "--task-name", help="Name of the task, used for naming experiments.") + parser.add_argument("-m", "--model", default="LGBMClassifier", help="Name of the model gin.") + parser.add_argument("-e", "--experiment", help="Name of the experiment gin.") + parser.add_argument("-l", "--log-dir", default=Path("../yaib_logs/"), type=Path, help="Log directory for model weights.") + parser.add_argument("-s", "--seed", default=1234, type=int, help="Random seed for processing, tuning and training.") + parser.add_argument("-v", "--verbose", default=False, action=BOA, help="Set to log verbosly. Disable for clean logs.") + parser.add_argument("--cpu", default=False, action=BOA, help="Set to use CPU.") + parser.add_argument("-db", "--debug", default=False, action=BOA, help="Set to load less data.") + parser.add_argument("--reproducible", default=True, action=BOA, help="Make torch reproducible.") + parser.add_argument("-lc", "--load_cache", default=False, action=BOA, help="Set to load generated data cache.") + parser.add_argument("-gc", "--generate_cache", default=False, action=BOA, help="Set to generate data cache.") + parser.add_argument("-p", "--preprocessor", type=Path, help="Load custom preprocessor from file.") + parser.add_argument("-pl", "--plot", action=BOA, help="Generate common plots.") + parser.add_argument("-wd", "--wandb-sweep", action="store_true", help="Activates wandb hyper parameter sweep.") + parser.add_argument("-imp", "--pretrained-imputation", type=str, help="Path to pretrained imputation model.") + parser.add_argument("-hp", "--hyperparams", nargs="+", help="Hyperparameters for model.") + parser.add_argument("--tune", default=False, action=BOA, help="Find best hyperparameters.") + parser.add_argument("--hp-checkpoint", type=Path, help="Use previous hyperparameter checkpoint.") + parser.add_argument("--eval", default=False, action=BOA, help="Only evaluate model, skip training.") + parser.add_argument("--complete-train", default=False, action=BOA, help="Use all data to train model, skip testing.") + parser.add_argument("-ft", "--fine-tune", default=None, type=int, help="Finetune model with amount of train data.") + parser.add_argument("-sn", "--source-name", type=Path, help="Name of the source dataset.") + parser.add_argument("--source-dir", type=Path, help="Directory containing gin and model weights.") + parser.add_argument("-sa", "--samples", type=int, default=None, help="Number of samples to use for evaluation.") + parser.add_argument( + "-mo", + "--modalities", + nargs="+", + help="Optional modality selection to use. Specify multiple modalities separated by spaces.", ) - general_args.add_argument( - "-s", "--seed", required=False, default=1234, type=int, help="Random seed for processing, tuning and training." - ) - general_args.add_argument( - "-v", - "--verbose", - default=False, - required=False, - action=BooleanOptionalAction, - help="Whether to use verbose logging. Disable for clean logs.", - ) - general_args.add_argument("--cpu", default=False, required=False, action=BooleanOptionalAction, help="Set to use CPU.") - general_args.add_argument( - "-db", "--debug", required=False, default=False, action=BooleanOptionalAction, help="Set to load less data." - ) - general_args.add_argument( - "-lc", - "--load_cache", - required=False, - default=False, - action=BooleanOptionalAction, - help="Set to load generated data cache.", - ) - general_args.add_argument( - "-gc", - "--generate_cache", - required=False, - default=False, - action=BooleanOptionalAction, - help="Set to generate data cache.", - ) - general_args.add_argument("-p", "--preprocessor", required=False, type=Path, help="Load custom preprocessor from file.") - general_args.add_argument("-pl", "--plot", required=False, action=BooleanOptionalAction, help="Generate common plots.") - general_args.add_argument( - "-wd", "--wandb-sweep", required=False, action="store_true", help="Activates wandb hyper parameter sweep." - ) - general_args.add_argument( - "-imp", "--pretrained-imputation", required=False, type=str, help="Path to pretrained imputation model." - ) - - # MODEL TRAINING ARGUMENTS - prep_and_train = subparsers.add_parser("train", help="Preprocess features and train model.", parents=[parent_parser]) - prep_and_train.add_argument( - "--reproducible", required=False, default=True, action=BooleanOptionalAction, help="Make torch reproducible." - ) - prep_and_train.add_argument("-hp", "--hyperparams", required=False, nargs="+", help="Hyperparameters for model.") - prep_and_train.add_argument("--tune", default=False, action=BooleanOptionalAction, help="Find best hyperparameters.") - prep_and_train.add_argument("--checkpoint", required=False, type=Path, help="Use previous checkpoint.") - - # EVALUATION PARSER - evaluate = subparsers.add_parser("evaluate", help="Evaluate trained model on data.", parents=[parent_parser]) - evaluate.add_argument("-sn", "--source-name", required=True, type=Path, help="Name of the source dataset.") - evaluate.add_argument("--source-dir", required=True, type=Path, help="Directory containing gin and model weights.") - + parser.add_argument("--label", type=str, help="Label to use for evaluation in case of multiple labels.", default=None) return parser @@ -107,12 +77,27 @@ def create_run_dir(log_dir: Path, randomly_searched_params: str = None) -> Path: Path to the created run log directory. """ log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + while log_dir_run.exists(): + log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")) log_dir_run.mkdir(parents=True) if randomly_searched_params: (log_dir_run / randomly_searched_params).touch() return log_dir_run +def import_preprocessor(preprocessor_path: str): + # Import custom supplied preprocessor + log_full_line(f"Importing custom preprocessor from {preprocessor_path}.", logging.INFO) + try: + spec = importlib.util.spec_from_file_location("CustomPreprocessor", preprocessor_path) + module = importlib.util.module_from_spec(spec) + sys.modules["preprocessor"] = module + spec.loader.exec_module(module) + gin.bind_parameter("preprocess.preprocessor", module.CustomPreprocessor) + except Exception as e: + logging.error(f"Could not import custom preprocessor from {preprocessor_path}: {e}") + + def aggregate_results(log_dir: Path, execution_time: timedelta = None): """Aggregates results from all folds and writes to JSON file. @@ -121,6 +106,7 @@ def aggregate_results(log_dir: Path, execution_time: timedelta = None): execution_time: Overall execution time. """ aggregated = {} + shap_values_test = [] for repetition in log_dir.iterdir(): if repetition.is_dir(): aggregated[repetition.name] = {} @@ -139,7 +125,18 @@ def aggregate_results(log_dir: Path, execution_time: timedelta = None): with open(fold_iter / "durations.json", "r") as f: result = json.load(f) aggregated[repetition.name][fold_iter.name].update(result) - + if (fold_iter / "test_shap_values.parquet").is_file(): + shap_values_test.append(pl.read_parquet(fold_iter / "test_shap_values.parquet")) + + if shap_values_test: + shap_values = pl.concat(shap_values_test) + shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet") + + try: + shap_values = pl.concat(shap_values_test) + shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet") + except Exception as e: + logging.error(f"Error aggregating or writing SHAP values: {e}") # Aggregate results per metric list_scores = {} for repetition, folds in aggregated.items(): @@ -179,6 +176,11 @@ def aggregate_results(log_dir: Path, execution_time: timedelta = None): wandb_log(json.loads(json.dumps(accumulated_metrics, cls=JsonResultLoggingEncoder))) +def name_datasets(train="default", val="default", test="default"): + """Names the datasets for logging (optional).""" + gin.bind_parameter("train_common.dataset_names", {"train": train, "val": val, "test": test}) + + def log_full_line(msg: str, level: int = logging.INFO, char: str = "-", num_newlines: int = 0): """Logs a full line of a given character with a message centered. @@ -252,3 +254,46 @@ def setup_logging(date_format, log_format, verbose): for logger in loggers: logging.getLogger(logger).setLevel(logging.DEBUG) warnings.filterwarnings("default") + + +def get_config_files(config_dir: Path): + """ + Get all task and model config files in the specified directory. + Args: + config_dir: Name of the directory containing the config gin files. + + Returns: + tasks: List of task names + models: List of model names + """ + try: + tasks = list((config_dir / "tasks").glob("*")) + models = list((config_dir / "prediction_models").glob("*")) + tasks = [task.stem for task in tasks if task.is_file()] + models = [model.stem for model in models if model.is_file()] + except Exception as e: + logging.error(f"Error retrieving config files: {e}") + return [], [] + if "common" in tasks: + tasks.remove("common") + if "common" in models: + models.remove("common") + logging.info(f"Found tasks: {tasks}") + logging.info(f"Found models: {models}") + return tasks, models + + +def check_required_keys(vars, required_keys): + """ + Checks if all required keys are present in the vars dictionary. + + Args: + vars (dict): The dictionary to check. + required_keys (list): The list of required keys. + + Raises: + KeyError: If any required key is missing. + """ + missing_keys = [key for key in required_keys if key not in vars] + if missing_keys: + raise KeyError(f"Missing required keys in vars: {', '.join(missing_keys)}") diff --git a/icu_benchmarks/tuning/gin_utils.py b/icu_benchmarks/tuning/gin_utils.py index c88d3a0f..0f9765a9 100644 --- a/icu_benchmarks/tuning/gin_utils.py +++ b/icu_benchmarks/tuning/gin_utils.py @@ -32,15 +32,28 @@ def get_gin_hyperparameters(class_to_tune: str = gin.REQUIRED, **hyperparams: di return hyperparams_to_tune -def bind_gin_params(hyperparams_names: list[str], hyperparams_values: list): - """Binds hyperparameters to gin config and logs them. +# def bind_gin_params(hyperparams_names: list[str], hyperparams_values: list): +# """Binds hyperparameters to gin config and logs them. +# +# Args: +# hyperparams_names: List of hyperparameter names. +# hyperparams_values: List of hyperparameter values. +# """ +# logging.info("Binding Hyperparameters:") +# for param, value in zip(hyperparams_names, hyperparams_values): +# gin.bind_parameter(param, value) +# logging.info(f"{param} = {value}") +# wandb_log({param: value}) + + +def bind_gin_params(hyperparams: dict[str, any]): + """Binds hyperparameter dict to gin config and logs them. Args: - hyperparams_names: List of hyperparameter names. - hyperparams_values: List of hyperparameter values. + hyperparams: Dictionary of hyperparameters. """ logging.info("Binding Hyperparameters:") - for param, value in zip(hyperparams_names, hyperparams_values): + for param, value in hyperparams.items(): gin.bind_parameter(param, value) logging.info(f"{param} = {value}") wandb_log({param: value}) diff --git a/icu_benchmarks/tuning/hyperparameters.py b/icu_benchmarks/tuning/hyperparameters.py index 212a7753..c879bd62 100644 --- a/icu_benchmarks/tuning/hyperparameters.py +++ b/icu_benchmarks/tuning/hyperparameters.py @@ -2,24 +2,27 @@ import gin import logging from logging import NOTSET +import matplotlib.pyplot as plt import numpy as np from pathlib import Path from skopt import gp_minimize import tempfile - +import optuna +from optuna.integration.wandb import WeightsAndBiasesCallback from icu_benchmarks.models.utils import JsonResultLoggingEncoder, log_table_row, Align from icu_benchmarks.cross_validation import execute_repeated_cv from icu_benchmarks.run_utils import log_full_line from icu_benchmarks.tuning.gin_utils import get_gin_hyperparameters, bind_gin_params -from icu_benchmarks.contants import RunMode +from icu_benchmarks.constants import RunMode from icu_benchmarks.wandb_utils import wandb_log +from optuna.visualization import plot_param_importances, plot_optimization_history TUNE = 25 logging.addLevelName(25, "TUNE") -@gin.configurable("tune_hyperparameters") -def choose_and_bind_hyperparameters( +@gin.configurable("tune_hyperparameters_deprecated") +def choose_and_bind_hyperparameters_scikit_optimize( do_tune: bool, data_dir: Path, log_dir: Path, @@ -60,6 +63,9 @@ def choose_and_bind_hyperparameters( Raises: ValueError: If checkpoint is not None and the checkpoint does not exist. """ + logging.warning( + "This function is deprecated and will be removed in the future. " "Use choose_and_bind_hyperparameters_optuna instead." + ) hyperparams = {} if len(scopes) == 0 or folds_to_tune_on is None: @@ -169,6 +175,216 @@ def tune_step_callback(res): bind_gin_params(hyperparams_names, res.x) +@gin.configurable("tune_hyperparameters") +def choose_and_bind_hyperparameters_optuna( + do_tune: bool, + data_dir: Path, + log_dir: Path, + seed: int, + run_mode: RunMode = RunMode.classification, + checkpoint: str = None, + scopes: list[str] = [], + n_initial_points: int = 3, + n_calls: int = 20, + sampler=optuna.samplers.GPSampler, + folds_to_tune_on: int = None, + checkpoint_file: str = "hyperparameter_tuning_logs.db", + generate_cache: bool = False, + load_cache: bool = False, + debug: bool = False, + verbose: bool = False, + wandb: bool = False, + plot: bool = True, +): + """Choose hyperparameters to tune and bind them to gin. Uses Optuna for hyperparameter optimization. + + Args: + plot: Whether to plot hyperparameter importances. + sampler: The sampler to use for hyperparameter optimization. + wandb: Whether we use wandb or not. + load_cache: Load cached data if available. + generate_cache: Generate cache data. + do_tune: Whether to tune hyperparameters or not. + data_dir: Path to the data directory. + log_dir: Path to the log directory. + seed: Random seed. + run_mode: The run mode of the experiment. + checkpoint: Name of the checkpoint run to load previously explored hyperparameters from. + scopes: List of gin scopes to search for hyperparameters to tune. + n_initial_points: Number of initial points to explore. + n_calls: Number of iterations to optimize the hyperparameters. + folds_to_tune_on: Number of folds to tune on. + checkpoint_file: Name of the checkpoint file. + debug: Whether to load less data. + verbose: Set to true to increase log output. + + Raises: + ValueError: If checkpoint is not None and the checkpoint does not exist. + """ + hyperparams = {} + + if len(scopes) == 0 or folds_to_tune_on is None: + logging.warning("No scopes and/or folds to tune on, skipping tuning.") + return + + # Collect hyperparameters. + hyperparams_bounds, hyperparams_names = collect_bound_hyperparameters(hyperparams, scopes) + + if do_tune and not hyperparams_bounds: + logging.info("No hyperparameters to tune, skipping tuning.") + return + + # Function that trains the model with the given hyperparameters. + + header = ["ITERATION"] + hyperparams_names + ["LOSS AT ITERATION"] + + # Optuna objective function + def objective(trial, hyperparams_bounds, hyperparams_names): + # Optuna objective function + hyperparams = {} + logging.info(f"Bounds: {hyperparams_bounds}, Names: {hyperparams_names}") + for name, value in zip(hyperparams_names, hyperparams_bounds): + if isinstance(value, tuple): + + def suggest_int_param(trial, name, value): + return trial.suggest_int(name, value[0], value[1], log=value[2] == "log" if len(value) == 3 else False) + + def suggest_float_param(trial, name, value): + return trial.suggest_float(name, value[0], value[1], log=value[2] == "log" if len(value) == 3 else False) + + def suggest_categorical_param(trial, name, value): + return trial.suggest_categorical(name, value) + + # Then in the objective function: + if isinstance(value[0], int) and isinstance(value[1], int): + hyperparams[name] = suggest_int_param(trial, name, value) + elif isinstance(value[0], (int, float)) and isinstance(value[1], (int, float)): + hyperparams[name] = suggest_float_param(trial, name, value) + else: + hyperparams[name] = suggest_categorical_param(trial, name, value) + else: + hyperparams[name] = trial.suggest_categorical(name, value) + return bind_params_and_train(hyperparams) + + def tune_step_callback(study: optuna.study.Study, trial: optuna.trial.FrozenTrial): + table_cells = [str(len(study.trials)), *list(study.trials[-1].params.values()), study.trials[-1].value] + highlight = study.trials[-1] == study.best_trial # highlight if best so far + log_table_row(header, TUNE) + log_table_row(table_cells, TUNE, align=Align.RIGHT, header=header, highlight=highlight) + wandb_log({"HP-optimization-iteration": len(study.trials)}) + + if do_tune: + log_full_line("STARTING TUNING", level=TUNE, char="=") + logging.log( + TUNE, + f"Applying {sampler} from {n_initial_points} points in {n_calls} " f"iterations on {folds_to_tune_on} folds.", + ) + log_table_row(header, TUNE) + else: + logging.log(TUNE, "Hyperparameter tuning disabled") + if checkpoint: + study = optuna.load_study(study_name="tuning", storage="sqlite:///" + str(checkpoint)) + configuration = study.best_params + # We have loaded a checkpoint, use the best hyperparameters. + logging.info("Training with the best hyperparameters from loaded checkpoint:") + bind_gin_params(configuration) + return + else: + logging.log( + TUNE, "Choosing hyperparameters randomly from bounds using hp tuning as no earlier checkpoint " "supplied." + ) + n_initial_points = 1 + n_calls = 1 + + def bind_params_and_train(hyperparams): + with tempfile.TemporaryDirectory(dir=log_dir) as temp_dir: + bind_gin_params(hyperparams) + if not do_tune: + return 0 + score = execute_repeated_cv( + data_dir, + Path(temp_dir), + seed, + mode=run_mode, + cv_repetitions_to_train=1, + cv_folds_to_train=folds_to_tune_on, + generate_cache=generate_cache, + load_cache=load_cache, + test_on="val", + debug=debug, + verbose=verbose, + wandb=wandb, + ) + logging.info(f"Score: {score}") + return score + + if isinstance(sampler, optuna.samplers.GPSampler): + sampler = sampler(seed=seed, n_startup_trials=n_initial_points, deterministic_objective=True) + else: + sampler = sampler(seed=seed) + pruner = optuna.pruners.HyperbandPruner() + # Optuna study + # Attempt checkpoint loading + if checkpoint and checkpoint.exists(): + logging.warning(f"Hyperparameter checkpoint {checkpoint} does not exist.") + # logging.info("Attempting to find latest checkpoint file.") + # checkpoint_path = find_checkpoint(log_dir.parent, checkpoint_file) + # Check if we found a checkpoint file + logging.info(f"Loading checkpoint at {checkpoint}") + study = optuna.load_study(study_name="tuning", storage="sqlite:///" + str(checkpoint), sampler=sampler, pruner=pruner) + n_calls = n_calls - len(study.trials) + else: + if checkpoint: + logging.warning("Checkpoint path given as flag but not found, starting from scratch.") + study = optuna.create_study( + sampler=sampler, + storage="sqlite:///" + str(log_dir / checkpoint_file), + study_name="tuning", + pruner=pruner, + load_if_exists=True, + ) + + callbacks = [tune_step_callback] + if wandb: + wandb_kwargs = { + "config": {"sampler": sampler}, + "allow_val_change": True, + } + wandbc = WeightsAndBiasesCallback(metric_name="loss", wandb_kwargs=wandb_kwargs) + callbacks.append(wandbc) + + logging.info(f"Starting or resuming Optuna study with {n_calls} trails and callbacks: {callbacks}.") + if n_calls > 0: + study.optimize( + lambda trail: objective(trail, hyperparams_bounds, hyperparams_names), + n_trials=n_calls, + callbacks=callbacks, + gc_after_trial=True, + ) + else: + logging.info("No more hyperparameter tuning iterations left, skipping tuning.") + logging.info("Training with these hyperparameters:") + bind_gin_params(study.best_params) + return + logging.disable(level=NOTSET) + + if do_tune: + log_full_line("FINISHED TUNING", level=TUNE, char="=", num_newlines=4) + + logging.info("Training with these hyperparameters:") + bind_gin_params(study.best_params) + + if plot: + try: + logging.info("Plotting hyperparameter importances.") + plot_param_importances(study) + plt.savefig(log_dir / "param_importances.png") + plot_optimization_history(study) + plt.savefig(log_dir / "optimization_history.png") + except Exception as e: + logging.error(f"Failed to plot hyperparameter importances: {e}") + + def collect_bound_hyperparameters(hyperparams, scopes): for scope in scopes: with gin.config_scope(scope): diff --git a/icu_benchmarks/wandb_utils.py b/icu_benchmarks/wandb_utils.py index be8bac7d..2ea06b57 100644 --- a/icu_benchmarks/wandb_utils.py +++ b/icu_benchmarks/wandb_utils.py @@ -1,5 +1,7 @@ from argparse import Namespace import logging +from pathlib import Path + import wandb @@ -16,7 +18,7 @@ def update_wandb_config(config: dict) -> None: """ logging.debug(f"Updating Wandb config: {config}") if wandb_running(): - wandb.config.update(config) + wandb.config.update(config, allow_val_change=True) def apply_wandb_sweep(args: Namespace) -> Namespace: @@ -28,7 +30,7 @@ def apply_wandb_sweep(args: Namespace) -> Namespace: Returns: Namespace: arguments with sweep configuration applied (some are applied via hyperparams) """ - wandb.init() + wandb.init(allow_val_change=True, dir=args.log_dir) sweep_config = wandb.config args.__dict__.update(sweep_config) if args.hyperparams is None: @@ -49,13 +51,28 @@ def wandb_log(log_dict): wandb.log(log_dict) -def set_wandb_run_name(run_name): +def set_wandb_experiment_name(args, mode): """stores the run name in wandb config Args: - run_name (str): name of the current run + args (Namespace): parsed arguments + mode (RunMode): run mode """ + if args.name is None: + data_dir = Path(args.data_dir) + args.name = data_dir.name + run_name = f"{mode}_{args.model}_{args.name}" + if args.modalities: + run_name += f"_mods_{args.modalities}" + if args.fine_tune: + run_name += f"_source_{args.source_name}_fine-tune_{args.fine_tune}_samples" + elif args.eval: + run_name += f"_source_{args.source_name}" + elif args.samples: + run_name += f"_train_size_{args.samples}_samples" + elif args.complete_train: + run_name += "_complete_training" + if wandb_running(): wandb.config.update({"run-name": run_name}) wandb.run.name = run_name - wandb.run.save() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..e48f4d06 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,35 @@ +--extra-index-url https://download.pytorch.org/whl/cu118 +black==24.3.0 +coverage==7.2.3 +flake8>=7.0.0 +matplotlib==3.7.1 +gin-config==0.5.0 +pytorch-ignite==0.5.0.post2 +# Note: versioning of Pytorch might be dependent on compatible CUDA version. +# Please check yourself if your Pytorch installation supports cuda (for gpu acceleration) +torch==2.4 +lightning==2.4.0 +torchmetrics==1.0.3 +lightgbm==4.4.0 +xgboost==2.1.0 +imbalanced-learn==0.12.3 +catboost==1.2.5 +numpy==1.24.3 +pandas==2.2.2 +polars==1.9.0 +pyarrow==14.0.1 +pytest==7.3.1 +scikit-learn==1.5.0 +tensorboard==2.12.2 +tqdm==4.66.3 +einops==0.6.1 +hydra-core==1.3 +optuna==4.0.0 +optuna-integration==4.0.0 +wandb==0.17.3 +recipies==1.0 +#Fixed version because of NumPy incompatibility and stale development status. +scikit-optimize-fix==0.9.1 +hydra-submitit-launcher==1.2.0 +pytest-runner==6.0.1 + diff --git a/setup.py b/setup.py index c4edbb99..384562ea 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ def parse_environment_yml(): keywords="benchmark mimic-iii eicu hirid clinical-ml machine-learning benchmark time-series mimic-iv patient-monitoring " "amsterdamumcdb clinical-data ehr icu ricu pyicu", name="yaib", - packages=find_packages(include=["icu_benchmarks"]), + packages=find_packages(), setup_requires=setup_requirements, test_suite="tests", tests_require=[],