diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6c857c4b..048e2275 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,4 +40,4 @@ jobs: # - name: Setup package # run: pip install -e . # - name: Test command line tool - # run: python -m icu_benchmarks.run --help + # run: python -m icu_benchmarks.run --help \ No newline at end of file diff --git a/.gitignore b/.gitignore index 7b1be7d7..92bd0a44 100644 --- a/.gitignore +++ b/.gitignore @@ -113,7 +113,18 @@ lsf* icu_benchmarks/legacy_MH MH_DEBUG +.vscode +lightning_logs + +# Mac file system +.DS_Store + +# wandb logs +wandb/ # Cached data **/cache/ .DS_Store .vscode/launch.json +yaib_logs/ +*.ckpt +*.csv \ No newline at end of file diff --git a/CONTRIBUTING.MD b/CONTRIBUTING.MD new file mode 100644 index 00000000..34bd1212 --- /dev/null +++ b/CONTRIBUTING.MD @@ -0,0 +1,101 @@ +# Contributing + +When contributing to this repository, please first discuss the change you wish to make via issue, +email, or any other method with the owners of this repository before making a change. + +Please note we have a code of conduct, please follow it in all your interactions with the project. + +## Pull Request Process + +1. Ensure any install or build dependencies are removed before the end of the layer when doing a + build. +2. Update the README.md with details of changes to the interface, this includes new environment + variables, exposed ports, useful file locations and container parameters. +3. Increase the version numbers in any examples files and the README.md to the new version that this + Pull Request would represent. The versioning scheme we use is [SemVer](http://semver.org/). +4. You may merge the Pull Request in once you have the sign-off of two other developers, or if you + do not have permission to do that, you may request the second reviewer to merge it for you. + +## Code of Conduct + +### Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, gender identity and expression, level of experience, +nationality, personal appearance, race, religion, or sexual identity and +orientation. + +### Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +### Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. 
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+### Scope
+
+This Code of Conduct applies both within project spaces and in public spaces
+when an individual is representing the project or its community. Examples of
+representing a project or community include using an official project e-mail
+address, posting via an official social media account, or acting as an appointed
+representative at an online or offline event. Representation of a project may be
+further defined and clarified by project maintainers.
+
+[//]: # (### Enforcement)
+
+[//]: # ()
+[//]: # (Instances of abusive, harassing, or otherwise unacceptable behavior may be)
+
+[//]: # (reported by contacting the project team at [INSERT EMAIL ADDRESS]. All)
+
+[//]: # (complaints will be reviewed and investigated and will result in a response that)
+
+[//]: # (is deemed necessary and appropriate to the circumstances. The project team is)
+
+[//]: # (obligated to maintain confidentiality with regard to the reporter of an incident.)
+
+[//]: # (Further details of specific enforcement policies may be posted separately.)
+
+[//]: # ()
+[//]: # (Project maintainers who do not follow or enforce the Code of Conduct in good)
+
+[//]: # (faith may face temporary or permanent repercussions as determined by other)
+
+[//]: # (members of the project's leadership.)
+
+### Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at [http://contributor-covenant.org/version/1/4][version].
+
+[homepage]: http://contributor-covenant.org
+[version]: http://contributor-covenant.org/version/1/4/
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
index 97c71cd4..533fac62 100644
--- a/LICENSE
+++ b/LICENSE
@@ -2,7 +2,7 @@
 MIT License
 
-Copyright (c) 2022, Robin van de Water, Hendrik Schmidt, Patrick Rockenschaub
+Copyright (c) 2023, Robin van de Water, Hendrik Schmidt, Patrick Rockenschaub
 Copyright (c) 2021, ETH Zurich, Biomedical Informatics Group; ratschlab
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
diff --git a/PAPER.md b/PAPER.md
new file mode 100644
index 00000000..cc88ed50
--- /dev/null
+++ b/PAPER.md
@@ -0,0 +1,108 @@
+> 📋 This file follows the template for releasing ML research code
+> from [papers with code](https://github.com/paperswithcode/releasing-research-code)
+
+![YAIB](docs/figures/yaib_logo.png)
+
+# Yet Another ICU Benchmark: _A Flexible Multi-Center Framework for Clinical ML_
+
+This repository is the official implementation of [placeholder](https://arxiv.org/abs/2030.12345).
+See a graphical overview of our framework below:
+![yaib_flow](docs/figures/yaib_flow_combined.svg)
+We propose Yet Another ICU Benchmark. It was designed to address reproducibility issues and to provide a unified interface for
+developing clinical prediction models for the ICU. An experiment in YAIB consists of four steps:
+1) Defining clinical concepts from the raw data.
+2) Extracting the patient cohort and specifying the prediction task.
+3) Preprocessing and feature generation.
+4) Training and evaluation of the ML model.
+
+## 📋 Requirements
+
+YAIB can be installed using conda or pip. Below you will find the three CLI commands to install YAIB using conda.
+
+The first command will install an environment based on Python 3.10 (currently).
+This should work on x86 hardware.
+
+```
+conda env update -f environment.yml
+```
+
+We then activate the environment and install a package called `icu-benchmarks`, after which YAIB should be operational.
+
+```
+conda activate yaib
+pip install -e .
+```
+
+To get the datasets for this paper, please see the [YAIB-cohorts repository](https://github.com/rvandewater/YAIB-cohorts) and
+the [page on the YAIB wiki](https://github.com/rvandewater/YAIB/wiki/Generating-Cohorts). You
+will need to get access to the ICU datasets that you want to use by following a credentialing procedure.
+
+## Training
+
+The easiest method to train the models in the paper is to run these commands from the directory root:
+
+```train
+wandb sweep --verbose experiments/benchmark_classification.yml
+wandb sweep --verbose experiments/benchmark_regression.yml
+```
+
+This will create two WandB hyperparameter sweeps, one for the classification tasks and one for the regression tasks.
+This configuration will train all the models in the paper. You can then run the following command to train the models:
+
+```train
+wandb agent <sweep_id>
+```
+
+> Tip: You can run each of the configurations on a SLURM cluster instance with `wandb agent --count 1 <sweep_id>`
+
+### Quickstart
+
+If you do not yet have access to the ICU datasets, you can run the following commands to train models for the included demo
+(MIMIC-III and eICU) task cohorts:
+
+```train
+wandb sweep --verbose experiments/demo_benchmark_classification.yml
+wandb sweep --verbose experiments/demo_benchmark_regression.yml
+```
+
+Use the commands above to create the sweeps, then run them with `wandb agent` as shown above.
+
+## Evaluation
+
+Evaluation happens automatically after the training command above. Additionally, YAIB will generate extensive log files and
+model files. The logging location is specified within the `.yml` files. We recommend using the `wandb` web interface to inspect
+the results (see your personal WandB project).
+
+## Pre-trained Models
+
+You can download pretrained models here: [YAIB-models GitHub repository](https://github.com/rvandewater/YAIB-models).
+YAIB has built-in functionality to evaluate these models. See the command below for an example:
+
+```
+icu-benchmarks evaluate \
+    -d demo_data/mortality24/eicu_demo \
+    -n eicu_demo \
+    -t BinaryClassification \
+    -tn Mortality24 \
+    -m LGBMClassifier \
+    --generate_cache \
+    --load_cache \
+    -s 2222 \
+    -l ../yaib_logs \
+    -sn mimic \
+    --source-dir ../yaib_logs/mimic_demo/Mortality24/LGBMClassifier/2022-12-12T15-24-46/fold_0
+```
+
+## 📊Results
+
+The latest results are shown below. Note that there have been major changes between the classification and regression
+task experiments. However, results should be comparable overall. Updated results will be posted in the near future.
+![Results](docs/figures/results_yaib.png)
+
+## Contributing
+
+This source code is released under the MIT license, included [here](LICENSE). We do not own any of the datasets used or
+included in this repository. The demo datasets have been released under
+an [Open Data Commons Open Database License (ODbL)](https://opendatacommons.org/licenses/odbl/1-0/).
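For readers who prefer to script the training workflow above rather than use the CLI, the sweep-and-agent pair can also be driven from Python. This is a minimal sketch, assuming the experiment `.yml` files parse as standard wandb sweep configs and that you are logged in to wandb; the project name is a placeholder, not part of YAIB.

```python
# A minimal sketch of the `wandb sweep` + `wandb agent` workflow in Python
# (an illustration, not YAIB's own API; the project name is a placeholder).
import yaml
import wandb

with open("experiments/demo_benchmark_classification.yml") as f:
    sweep_config = yaml.safe_load(f)  # assumes a plain wandb sweep config

sweep_id = wandb.sweep(sweep=sweep_config, project="yaib-demo")
wandb.agent(sweep_id, count=1)  # run a single configuration, e.g. per SLURM task
```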
diff --git a/README.md b/README.md
index 41e6542d..22ad48be 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 ![YAIB logo](docs/figures/yaib_logo.png)
 
-# Yet Another ICU Benchmark
+# 🧪 Yet Another ICU Benchmark
 
 [![CI](https://github.com/rvandewater/YAIB/actions/workflows/ci.yml/badge.svg?branch=development)](https://github.com/rvandewater/YAIB/actions/workflows/ci.yml)
 [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
@@ -9,45 +9,59 @@
 
 [//]: # (TODO: add coverage once we have some tests )
 
-Yet another ICU benchmark (YAIB) provides a framework for doing clinical machine learning experiments on (ICU) EHR data.
+Yet Another ICU Benchmark (YAIB) provides a framework for doing clinical machine learning experiments on Intensive Care Unit
+(ICU) EHR data.
+
 We support the following datasets out of the box:
 
-| Dataset | [MIMIC-III](https://physionet.org/content/mimiciii/) / [IV](https://physionet.org/content/mimiciv/) | [eICU-CRD](https://physionet.org/content/eicu-crd/) | [HiRID](https://physionet.org/content/hirid/1.1.1/) | [AUMCdb](https://doi.org/10.17026/dans-22u-f8vd) |
+| **Dataset**                  | [MIMIC-III](https://physionet.org/content/mimiciii/) / [IV](https://physionet.org/content/mimiciv/) | [eICU-CRD](https://physionet.org/content/eicu-crd/) | [HiRID](https://physionet.org/content/hirid/1.1.1/) | [AUMCdb](https://doi.org/10.17026/dans-22u-f8vd) |
 |-------------------------|-----------------------------------------------------------------------------------------------------|-----------------------------------------------------|-----------------------------------------------------|--------------------------------------------------|
-| Admissions | 40k / 50k | 200k | 33k | 23k |
-| Frequency (time-series) | 1 hour | 5 minutes | 2 / 5 minutes | up to 1 minute |
-| Origin | USA | USA | Switzerland | Netherlands |
+| **Admissions**               | 40k / 73k | 200k | 33k | 23k |
+| **Version**                  | v1.4 / v2.2 | v2.0 | v1.1.1 | v1.0.2 |
+| **Frequency** (time-series)  | 1 hour | 5 minutes | 2 / 5 minutes | up to 1 minute |
+| **Originally published**     | 2015 / 2020 | 2017 | 2020 | 2019 |
+| **Origin**                   | USA | USA | Switzerland | Netherlands |
 
-The benchmark is designed for operating on preprocessed parquet files. We refer to the PyICU (in development)
-or [ricu package](https://github.com/eth-mds/ricu) for generating these parquet files for particular cohorts and endpoints.
+New datasets can also be added. We are currently working on a package to make this process as smooth as possible.
+The benchmark is designed for operating on preprocessed parquet files.
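As a concrete illustration of the preprocessed parquet layout mentioned above (and detailed under the Data section further down), the sketch below previews one of the bundled demo cohorts with pandas. The directory matches the demo commands later in this README; the individual file names inside it are assumptions for illustration, not a documented interface.

```python
# Sketch: preview a preprocessed demo cohort (file names are assumptions).
import pandas as pd

root = "demo_data/mortality24/mimic_demo"
dynamic = pd.read_parquet(f"{root}/dyn.parquet")   # time-varying variables (hypothetical name)
static = pd.read_parquet(f"{root}/sta.parquet")    # per-stay constants (hypothetical name)
outcome = pd.read_parquet(f"{root}/outc.parquet")  # prediction task labels (hypothetical name)

print(dynamic.head())
print(static.shape, outcome.shape)
```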
+
 
-We provide several common tasks for clinical prediction:
+We provide five common tasks for clinical prediction by default:
 
-| No | Task Theme | Temporality | Type |
+| No  | Task                      | Frequency                 | Type                   |
 |-----|---------------------------|--------------------|-------------------------------------|
-| 1 | ICU Mortality | Hourly (after 24H) | Sequential Classification |
-| 2 | Acute Kidney Injury (AKI) | Hourly (within 6H) | Sequence to Sequence Classification |
-| 3 | Sepsis | Hourly (within 6H) | Sequence to Sequence Classification |
-
-[//]: # (| 4 | Circulatory Failure | 5 Minutes | Sequence to Sequence Classification |)
+| 1   | ICU Mortality             | Once per Stay (after 24H) | Binary Classification  |
+| 2   | Acute Kidney Injury (AKI) | Hourly (within 6H)        | Binary Classification  |
+| 3   | Sepsis                    | Hourly (within 6H)        | Binary Classification  |
+| 4   | Kidney Function (KF)      | Once per Stay             | Regression             |
+| 5   | Length of Stay (LoS)      | Hourly (within 7D)        | Regression             |
 
-[//]: # (| 5 | Length of Stay (LoS) | Hourly | Sequence to Sequence Regression | )
+New tasks can be easily added.
+To get started right away, we include the eICU and MIMIC-III demo datasets in our repository.
+The following repositories may be relevant as well:
 
-Please refer to [cohort definitions]() for further information.
+- [YAIB-cohorts](https://github.com/rvandewater/YAIB-cohorts): Cohort generation for YAIB.
+- [YAIB-models](https://github.com/rvandewater/YAIB-models): Pretrained models for YAIB.
+- [ReciPys](https://github.com/rvandewater/ReciPys): Preprocessing package for YAIB pipelines.
 
-## Paper
+For all YAIB-related repositories, please see: https://github.com/stars/rvandewater/lists/yaib.
 
+# 📄Paper
+
+To reproduce the benchmarks in our paper, we refer to the [ML reproducibility document](PAPER.md).
 If you use this code in your research, please cite the following publication:
 
 ```
 ```
 
-
 This paper can also be found on arxiv: TBD
 
-# Installation
+# 💿Installation
+
+
-YAIB currently requires an installation of Conda. Below you will find the three CLI commands to install YAIB.
+
+YAIB can be installed using conda or pip. Below you will find the three CLI commands to install YAIB using conda.
 
 The first command will install an environment based on Python 3.10 (currently).
@@ -65,44 +79,94 @@ conda activate yaib
 pip install -e .
 ```
 
-# Usage
+If you want to install the icu-benchmarks package with pip, execute the command below:
 
-## Getting the Datasets
+```
+pip install torch numpy && pip install -e .
+```
 
-HiRID, eICU, and MIMIC IV can be accessed through [PhysioNet](https://physionet.org/). A guide to this process can be
-found [here](https://eicu-crd.mit.edu/gettingstarted/access/).
-AUMCdb can be accessed through a separate access [procedure](https://github.com/AmsterdamUMC/AmsterdamUMCdb). We do not have
-involvement in the access procedure and can not answer to any requests for data access.
+If you are on a Mac with Metal Performance Shaders (MPS), install the package with the following command:
 
-## Data Conversion
+```
+pip install torch numpy && pip install -e .[mps]
+```
 
-Since the datasets were created independently of each other, they do not share the same data structure or data identifiers. In
-order to make them interoperable, use the preprocessing utilities
-provided by the [ricu package](https://github.com/eth-mds/ricu).
-Ricu pre-defines a large number of clinical concepts and how to load them from a given dataset, providing a common interface to
-the data, that is used in this
-benchmark.
+# 👩‍💻Usage
 
-### Extracting cohorts
+Please refer to [our wiki](https://github.com/rvandewater/YAIB/wiki) for detailed information on how to use YAIB.
 
-TODO
+## Quickstart 🚀 (demo data)
 
-# Data
+In the folder `demo_data` we provide processed publicly available demo datasets from eICU and MIMIC with the necessary labels
+for `Mortality at 24h`, `Sepsis`, `Acute Kidney Injury`, `Kidney Function`, and `Length of Stay`.
 
-YAIB expects data generated by [pyicu](https://github.com/prockenschaub/pyicu), a
-rewritten [ricu](https://github.com/prockenschaub/ricu-package) for Python.
+If you do not yet have access to the ICU datasets, you can run the following commands to train models for the included demo
+cohorts:
 
-## Demo data
+```
+wandb sweep --verbose experiments/demo_benchmark_classification.yml
+wandb sweep --verbose experiments/demo_benchmark_regression.yml
+```
 
-In the folder `demo_data` we provide processed publicly available demo datasets from eICU and MIMIC with the necessary lables
-for `Akute Kidney Injury`, `Mortality at 24h` and `Sepsis`.
+```train
+wandb agent <sweep_id>
+```
 
-# Use with CLI Commands
+> Tip: You can run each of the configurations on a SLURM cluster instance with `wandb agent --count 1 <sweep_id>`
 
-## Preprocess and Train
+> Note: You will need to have a wandb account and be logged in to run the above commands.
 
-The following command will run training and evaluation on the MIMIC demo dataset for (Binary) Mortality prediction at 24h with the
-LGBMClassifier. Child samples are reduced due to the small amount of training data. We load available cache and, if available, load
+## Getting the datasets
+
+HiRID, eICU, and MIMIC IV can be accessed through [PhysioNet](https://physionet.org/). A guide to this process can be
+found [here](https://eicu-crd.mit.edu/gettingstarted/access/).
+AUMCdb can be accessed through a separate access [procedure](https://github.com/AmsterdamUMC/AmsterdamUMCdb). We are not
+involved in the access procedure and cannot answer any requests for data access.
+
+## Cohort creation
+
+Since the datasets were created independently of each other, they do not share the same data structure or data identifiers. In
+order to make them interoperable, use the preprocessing utilities
+provided by the [ricu package](https://github.com/eth-mds/ricu).
+Ricu pre-defines a large number of clinical concepts and how to load them from a given dataset, providing a common interface to
+the data that is used in this
+benchmark. Please refer to our [cohort definition](https://github.com/rvandewater/YAIB-cohorts) code for generating the cohorts
+using our Python interface for ricu.
+After this, you can run the benchmark once you have gained access to the datasets.
+
+## Data
+
+Users can supply their own datasets in a specific format.
+
+Adding a new dataset type can easily be done by providing it in a `.gin`
+task definition file. Note, however, that any datasets formatted in the default way do not require any changes to be
+used by YAIB. By default, we have chosen to work with the Apache
+Parquet file format, which is a modern,
+open-source, column-oriented format that does not require a lot of
+storage due to efficient data compression. We separate the data into
+three separate files: `DYNAMIC`, `STATIC`, and `OUTCOME`; these contain
+the dynamic variables (which change during the stay), the constant
+parameters, and the prediction task label, respectively. Our [cohort
+definition code](https://github.com/rvandewater/YAIB-cohorts) produces
+the files exactly in this format. Furthermore, we use the concept of
+`roles` in the definition of the `vars` dictionary. These roles are
+assigned as defined in [ReciPys](https://github.com/rvandewater/ReciPys), the preprocessing package developed
+alongside YAIB.
+The `GROUP` variable defines which internal dataset variable should be
+used to "group by", e.g., for aggregating patient vital signs. The
+`SEQUENCE` variable defines the sequential dimension of the dataset (in
+the common case, time). The other keys in this dictionary define the
+feature columns and outcome variables to be used for prediction.
+
+# 👟 Running YAIB
+
+## Preprocessing and Training
+
+The following command will run training and evaluation on the MIMIC demo dataset for (binary) mortality prediction at 24h with
+the LGBMClassifier. Child samples are reduced due to the small amount of training data. We generate a data cache and, if
+available, load existing cache files.
 
 ```
@@ -113,8 +177,8 @@ icu-benchmarks train \
 -tn Mortality24 \
 -m LGBMClassifier \
 -hp LGBMClassifier.min_child_samples=10 \
- -gc \
- -lc \
+ --generate_cache \
+ --load_cache \
 -s 2222 \
 -l ../yaib_logs/ \
 --tune
@@ -128,73 +193,24 @@ icu-benchmarks train \
 
 > For Windows-based systems, the line continuation character (\) needs to be replaced by (^) (Command Prompt) or (`)
 > (PowerShell) respectively.
 
-### Hyperparameter Tuning
 
-To understand how a parameter can be automatically tuned via bayesian optimization, let's look at the following example
-configuration:
+Alternatively, the easiest method to train all the models in the paper is to run these commands from the directory root:
 
-```
-...
-# Optimizer params
-optimizer/hyperparameter.class_to_tune = @Adam
-optimizer/hyperparameter.weight_decay = 1e-6
-optimizer/hyperparameter.lr = (1e-5, 3e-4)
-
-# Encoder params
-model/hyperparameter.class_to_tune = @LSTMNet
-model/hyperparameter.input_dim = %EMB
-model/hyperparameter.num_classes = %NUM_CLASSES
-model/hyperparameter.hidden_dim = (32, 256)
-model/hyperparameter.layer_dim = (1, 3)
-
-tune_hyperparameters.scopes = ["model", "optimizer"] # defines the scopes that the random search runs in
-tune_hyperparameters.n_initial_points = 5 # defines random points to initilaize gaussian process
-tune_hyperparameters.n_calls = 30 # numbe rof iterations to find best set of hyperparameters
-tune_hyperparameters.folds_to_tune_on = 2 # number of folds to use to evaluate set of hyperparameters
+```train
+wandb sweep --verbose experiments/benchmark_classification.yml
+wandb sweep --verbose experiments/benchmark_regression.yml
 ```
 
-In this example, we have the two scopes `model` and `optimizer`, the scopes take care of adding the parameters only to the
-pertinent classes.
-For each scope a `class_to_tune` needs to be set to the class it represents, in this case `LSTMNet` and `Adam`
-respectively.
-We can add whichever parameter we want to the classes following this syntax:
+This will create two hyperparameter sweeps for WandB for the classification and regression tasks.
+This configuration will train all the models in the paper. You can then run the following command to train the models:
 
-```
-tune_hyperparameters.scopes = ["<scope>", ...]
-<scope>/hyperparameter.class_to_tune = @<class_to_tune>
-<scope>/hyperparameter.<parameter> = ['list', 'of', 'possible', 'values']
+```train
+wandb agent <sweep_id>
 ```
 
-If we run `experiments` and want to overwrite the model configuration, this can be done easily:
+> Tip: You can run each of the configurations on a SLURM cluster instance with `wandb agent --count 1 <sweep_id>`
 
-```
-include "configs/tasks/Mortality_At24Hours.gin"
-include "configs/models/LSTM.gin"
-
-optimizer/hyperparameter.lr = 1e-4
-
-model/hyperparameter.hidden_dim = [100, 200]
-```
-
-This configuration for example overwrites the `lr` parameter of `Adam` with a concrete value,
-while it only specifies a different search space for `hidden_dim` of `LSTMNet` to run the random search on.
-
-The same holds true for the command line. Setting the following flag would achieve the same result (make sure to only have
-spaces between parameters):
-
-```
--hp optimizer/hyperparameter.lr=1e-4 model/hyperparameter.hidden_dim='[100,200]'
-```
-
-There is an implicit hierarchy, independent of where the parameters are added (`model.gin`, `experiment.gin` or CLI `-hp`):
-
-```
-LSTM.hidden_dim = 8 # always takes precedence
-model/hyperparameter.hidden_dim = 6 # second most important
-model/hyperparameter.hidden_dim = (4, 6) # only evaluated if the others aren't found in gin configs and CLI
-```
-
-The hierarchy CLI `-hp` > `experiment.gin` > `model.gin` is only important for bindings on the same "level" from above.
+> Note: You will need to have a wandb account and be logged in to run the above commands.
 
 ## Evaluate
 
@@ -208,85 +224,70 @@ icu-benchmarks evaluate \
 -t BinaryClassification \
 -tn Mortality24 \
 -m LGBMClassifier \
- -c \
+ --generate_cache \
+ --load_cache \
 -s 2222 \
 -l ../yaib_logs \
 -sn mimic \
 --source-dir ../yaib_logs/mimic_demo/Mortality24/LGBMClassifier/2022-12-12T15-24-46/fold_0
 ```
 
-## Metrics
+[//]: # (## Metrics)
+
+[//]: # ()
 
-Several metrics are defined for this benchmark:
+[//]: # (Several metrics are defined for this benchmark:)
 
-- Binary Classification: Because our tasks are all highly imbalanced, we use both ROC and PR Area Under the Curve
-  using [sklearn.metrics.roc_auc_score](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html)
-  and [sklearn.metrics.average_precision_score](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score)
-- Regression : The Mean Absolute Error (MAE) is used
-  with [sklearn.metrics.mean_absolute_error](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html)
+[//]: # ()
+
+[//]: # (- Binary Classification: Because our tasks are all highly imbalanced, we use both ROC and PR Area Under the Curve)
+
+[//]: # ( using [sklearn.metrics.roc_auc_score](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html))
+
+[//]: # ( and [sklearn.metrics.average_precision_score](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score))
+
+[//]: # (- Regression : The Mean Absolute Error (MAE) is used)
+
+[//]: # ( with [sklearn.metrics.mean_absolute_error](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html))
 
 ## Models
 
 We provide several existing machine learning models that are commonly used for multivariate time-series data.
-`pytorch` is used for the deep learning models, `lightgbm` for the boosted tree approaches, and `sklearn` for the logistic
-regression model and metrics.
-The benchmark provides the following built-in models:
+`pytorch` is used for the deep learning models, `lightgbm` for the boosted tree approaches, and `sklearn` for other classical
+machine learning models.
+The benchmark provides (among others) the following built-in models:
 
 - [Logistic Regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html?highlight=logistic+regression):
   Standard regression approach.
+- [Elastic Net](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html): Linear regression with
+  combined L1 and L2 priors as regularizer.
 - [LightGBM](https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf): Efficient gradient
   boosting trees.
 - [Long Short-term Memory (LSTM)](https://ieeexplore.ieee.org/document/818041): The most commonly used type of Recurrent Neural
   Networks for long sequences.
-- [Gated Recurrent Unit (GRU)](https://arxiv.org/abs/1406.1078) : A extension to LSTM which showed improvement over them in the
-  context of polyphonic music modeling and speech signal modeling ([paper](https://arxiv.org/abs/1412.3555)).
+- [Gated Recurrent Unit (GRU)](https://arxiv.org/abs/1406.1078): An extension of LSTM that showed
+  improvements ([paper](https://arxiv.org/abs/1412.3555)).
 - [Temporal Convolutional Networks (TCN)](https://arxiv.org/pdf/1803.01271): 1D convolution approach to sequence data. By
  using dilated convolutions to extend the receptive field of the network, it has shown great performance on long-term
 dependencies.
 - [Transformers](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf): The most common
  attention-based approach.
 
-# Development
-
-YAIB is in active development. The following sections could be relevant for adding new code to our repository
-
-## Libraries
-
-The following libraries are important to the operation of YAIB:
-
-- [Pandas](https://github.com/pandas-dev/pandas): Popular data structure framework.
-- [ReciPys](https://github.com/rvandewater/recipys): A modular preprocessing package for Pandas dataframes.
-- [Pytorch](https://pytorch.org/): An open source machine learning framework for deep learning applications.
-- [Pytorch Ignite](https://github.com/pytorch/ignite): Library for training and evaluating neural networks in Pytorch.
-- [Cuda Toolkit](https://developer.nvidia.com/cuda-toolkit): GPU acceleration used for deep learning models.
-- [Scikit-learn](https://github.com/scikit-learn/scikit-learn): Machine learning library.
-- [LightGBM](https://github.com/microsoft/LightGBM): Gradient boosting framework.
-- [GIN](https://github.com/google/gin-config): Provides a lightweight configuration framework for Python.
+# 🛠️ Development
 
-## Run Tests
-
-```
-python -m pytest ./tests/recipes
-coverage run -m pytest ./tests/recipes
-# then use either of the following
-coverage report
-coverage html
-```
-
-## Autoformat and lint
-
-For development purposes, we use the `Black` package to autoformat our code and a `Flake8` Linting/CI check:
-
-```
-black . -l 127
-flake8 . --count --max-complexity=14 --max-line-length=127 --statistics
-```
+To adapt YAIB to your own use case, you can use
+the [development information](https://github.com/rvandewater/YAIB/wiki/Contribution-and-development) page as a reference.
+We appreciate contributions to the project. Please read the [contribution guidelines](CONTRIBUTING.MD) before submitting a pull
+request.
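To make the model list above concrete, here is a minimal PyTorch sketch of a GRU sequence classifier in the spirit of the built-in `GRUNet`. The `hidden_dim`, `layer_dim`, and `num_classes` names mirror the gin hyperparameters in the configs further down this diff; the actual YAIB architecture may differ.

```python
# Minimal GRU classifier sketch (illustrative; not the exact YAIB GRUNet).
import torch
import torch.nn as nn

class GRUClassifier(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int = 64, layer_dim: int = 2, num_classes: int = 2):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers=layer_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, features), e.g. hourly measurements for one ICU stay
        out, _ = self.gru(x)
        return self.head(out[:, -1])  # logits from the last time step

logits = GRUClassifier(input_dim=53)(torch.randn(8, 24, 53))  # 8 stays, 24 hourly steps
```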
# Acknowledgements -We do not own any of the datasets used in this benchmark. This project uses adapted components of -the [HiRID benchmark](https://github.com/ratschlab/HIRID-ICU-Benchmark/). We thank the authors for providing this codebase. +We do not own any of the datasets used in this benchmark. This project uses heavily adapted components of +the [HiRID benchmark](https://github.com/ratschlab/HIRID-ICU-Benchmark/). We thank the authors for providing this codebase and +encourage further development to benefit the scientific community. The demo datasets have been released under +an [Open Data Commons Open Database License (ODbL)](https://opendatacommons.org/licenses/odbl/1-0/). # License -This source code is released under the MIT license, included [here](LICENSE). +This source code is released under the MIT license, included [here](LICENSE). We do not own any of the datasets used or +included in this repository. diff --git a/configs/imputation_models/Attention.gin b/configs/imputation_models/Attention.gin new file mode 100644 index 00000000..ba638b81 --- /dev/null +++ b/configs/imputation_models/Attention.gin @@ -0,0 +1,19 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.baselines +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @Attention + +# Model params +Attention.n_layers = 4 +Attention.d_model = 256 +Attention.d_inner = 128 +Attention.n_head = 4 +Attention.d_k = 32 +Attention.d_v = 32 +Attention.dropout = 0.0 +Attention.epochs = 100 diff --git a/configs/imputation_models/BRITS.gin b/configs/imputation_models/BRITS.gin new file mode 100644 index 00000000..11ca7512 --- /dev/null +++ b/configs/imputation_models/BRITS.gin @@ -0,0 +1,14 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.baselines +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @BRITS +train_common.epochs = 100 + +# Model params +BRITS.rnn_hidden_size = 64 +BRITS.batch_size = 256 diff --git a/configs/imputation_models/BRNN.gin b/configs/imputation_models/BRNN.gin new file mode 100644 index 00000000..4994629f --- /dev/null +++ b/configs/imputation_models/BRNN.gin @@ -0,0 +1,26 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.rnn +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @BRNN +train_common.epochs = 100 +train_common.batch_size = 64 +train_common.patience = 10 +train_common.min_delta = 1e-4 + +train_common.optimizer = @Adam +ImputationWrapper.lr_scheduler = "" + +# Optimizer params +Adam.lr = 0.001 +Adam.weight_decay = 1e-6 + +# Model params +BRNN.cell = 'lstm' +BRNN.hidden_size = 64 +BRNN.state_init = 'zero' +BRNN.dropout = 0.3 \ No newline at end of file diff --git a/configs/imputation_models/CSDI.gin b/configs/imputation_models/CSDI.gin new file mode 100644 index 00000000..e7881d07 --- /dev/null +++ b/configs/imputation_models/CSDI.gin @@ -0,0 +1,37 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.data.preprocess +import icu_benchmarks.imputation.csdi + +# Train params 
+train_common.model = @CSDI + +# here you can set some training parameters +train_common.epochs = 1000 +train_common.batch_size = 64 +train_common.patience = 10 +train_common.min_delta = 1e-4 +train_common.use_wandb = True + +ImputationWrapper.optimizer = @Adam +ImputationWrapper.lr_scheduler = "cosine" + +# Optimizer params +Adam.lr = 0.001 +Adam.weight_decay = 1e-6 + +CSDI.time_step_embedding_size = 64 +CSDI.feature_embedding_size = 64 +CSDI.unconditional = False +CSDI.target_strategy = "hist" +CSDI.num_diffusion_steps = 50 +CSDI.diffusion_step_embedding_dim = 128 +CSDI.n_attention_heads = 8 +CSDI.num_residual_layers = 8 +CSDI.noise_schedule = "quad" +CSDI.beta_start = 0.0001 +CSDI.beta_end = 0.5 +CSDI.conv_channels = 64 +CSDI.n_samples = 15 diff --git a/configs/imputation_models/DiffWave.gin b/configs/imputation_models/DiffWave.gin new file mode 100644 index 00000000..94de2952 --- /dev/null +++ b/configs/imputation_models/DiffWave.gin @@ -0,0 +1,38 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.utils +import icu_benchmarks.data.preprocess +import icu_benchmarks.imputation.diffwave + +# Train params +train_common.model = @DiffWave + +# here you can set some training parameters +train_common.epochs = 2 +train_common.batch_size = 64 +train_common.patience = 10 +train_common.min_delta = 1e-4 +train_common.use_wandb = True + +ImputationWrapper.optimizer = @Adam +ImputationWrapper.lr_scheduler = "cosine" + +# Optimizer params +Adam.lr = 0.001 +Adam.weight_decay = 1e-6 + +# Model params +DiffWave.in_channels = 6 +DiffWave.res_channels = 256 +DiffWave.out_channels = 6 +DiffWave.skip_channels = 256 +DiffWave.num_res_layers = 36 +DiffWave.dilation_cycle = 12 +DiffWave.diffusion_step_embed_dim_in = 128 +DiffWave.diffusion_step_embed_dim_mid = 512 +DiffWave.diffusion_step_embed_dim_out = 512 + +# Probably also needed +DiffWave.diffusion_time_steps = 1000 +DiffWave.beta_0 = 1e-4 +DiffWave.beta_T = 2e-2 diff --git a/configs/imputation_models/Diffusion.gin b/configs/imputation_models/Diffusion.gin new file mode 100644 index 00000000..28a40a66 --- /dev/null +++ b/configs/imputation_models/Diffusion.gin @@ -0,0 +1,30 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.data.preprocess +import icu_benchmarks.imputation.diffusion + +# Train params +train_common.model = @Diffusion + +# here you can set some training parameters +train_common.epochs = 5 +train_common.batch_size = 64 +train_common.patience = 10 +train_common.min_delta = 1e-4 +train_common.use_wandb = False + +ImputationWrapper.optimizer = @Adam +ImputationWrapper.lr_scheduler = "cosine" + +# Optimizer params +Adam.lr = 3e-4 +Adam.weight_decay = 1e-6 + +# Model params +Diffusion.n_onedirectional_conv = 3 +Diffusion.T = 300 +Diffusion.min_noise = 0.0001 +Diffusion.max_noise = 0.02 +Diffusion.noise_scheduler = 'linear' \ No newline at end of file diff --git a/configs/imputation_models/GAIN.gin b/configs/imputation_models/GAIN.gin new file mode 100644 index 00000000..40352d50 --- /dev/null +++ b/configs/imputation_models/GAIN.gin @@ -0,0 +1,4 @@ +import icu_benchmarks.imputation.baselines + +# Train params +train_common.model = @GAIN diff --git a/configs/imputation_models/ICE.gin b/configs/imputation_models/ICE.gin new file mode 100644 index 00000000..02b3637d --- /dev/null +++ b/configs/imputation_models/ICE.gin @@ -0,0 +1,4 @@ +import 
icu_benchmarks.imputation.baselines + +# Train params +train_common.model = @ICE diff --git a/configs/imputation_models/KNN.gin b/configs/imputation_models/KNN.gin new file mode 100644 index 00000000..4d5811d9 --- /dev/null +++ b/configs/imputation_models/KNN.gin @@ -0,0 +1,11 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.baselines +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @KNN + +KNN.n_neighbors = 10 diff --git a/configs/imputation_models/MICE.gin b/configs/imputation_models/MICE.gin new file mode 100644 index 00000000..efc97036 --- /dev/null +++ b/configs/imputation_models/MICE.gin @@ -0,0 +1,14 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.baselines +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @MICE + +MICE.max_iter = 100 +MICE.verbose = 2 +MICE.imputation_order = 'random' +MICE.random_state = 0 \ No newline at end of file diff --git a/configs/imputation_models/MLP.gin b/configs/imputation_models/MLP.gin new file mode 100644 index 00000000..43c52e49 --- /dev/null +++ b/configs/imputation_models/MLP.gin @@ -0,0 +1,25 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.mlp +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @MLP +train_common.epochs = 2 +train_common.batch_size = 64 +train_common.patience = 10 +train_common.min_delta = 1e-4 + +train_common.optimizer = @Adam +ImputationWrapper.lr_scheduler = "cosine" + +# Optimizer params +Adam.lr = 1e-2 +Adam.weight_decay = 1e-6 + +# Encoder params + +MLP.num_hidden_layers = 5 +MLP.hidden_layer_size = 15 diff --git a/configs/imputation_models/Mean.gin b/configs/imputation_models/Mean.gin new file mode 100644 index 00000000..82cb01f6 --- /dev/null +++ b/configs/imputation_models/Mean.gin @@ -0,0 +1,10 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.baselines +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @Mean +ImputationDataset.ram_cache = False diff --git a/configs/imputation_models/Median.gin b/configs/imputation_models/Median.gin new file mode 100644 index 00000000..814816e3 --- /dev/null +++ b/configs/imputation_models/Median.gin @@ -0,0 +1,9 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.baselines +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @Median diff --git a/configs/imputation_models/Miracle.gin b/configs/imputation_models/Miracle.gin new file mode 100644 index 00000000..c9001223 --- /dev/null +++ b/configs/imputation_models/Miracle.gin @@ -0,0 +1,4 @@ +import icu_benchmarks.imputation.baselines + +# Train params +train_common.model = @Miracle diff --git a/configs/imputation_models/MissForest.gin b/configs/imputation_models/MissForest.gin new file mode 100644 index 00000000..6a6d68d0 --- /dev/null +++ b/configs/imputation_models/MissForest.gin @@ -0,0 +1,4 
@@ +import icu_benchmarks.imputation.baselines + +# Train params +train_common.model = @MissForest diff --git a/configs/imputation_models/Miwae.gin b/configs/imputation_models/Miwae.gin new file mode 100644 index 00000000..9deb9dc1 --- /dev/null +++ b/configs/imputation_models/Miwae.gin @@ -0,0 +1,4 @@ +import icu_benchmarks.imputation.baselines + +# Train params +train_common.model = @Miwae diff --git a/configs/imputation_models/MostFrequent.gin b/configs/imputation_models/MostFrequent.gin new file mode 100644 index 00000000..9ed94fc0 --- /dev/null +++ b/configs/imputation_models/MostFrequent.gin @@ -0,0 +1,9 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.baselines +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @MostFrequent diff --git a/configs/imputation_models/NP.gin b/configs/imputation_models/NP.gin new file mode 100644 index 00000000..57568e49 --- /dev/null +++ b/configs/imputation_models/NP.gin @@ -0,0 +1,34 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.np +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @NP + +ImputationWrapper.optimizer = @Adam +ImputationWrapper.lr_scheduler = "cosine" + +Adam.lr = 0.1 +Adam.weight_decay = 1e-6 + +# Model params +# Deterministic encoder params +NP.encoder_layers = 6 +NP.encoder_h_dim = 72 + +# Decoder params +NP.decoder_layers = 3 +NP.decoder_h_dim = 72 + +# Additional params +NP.r_dim = 12 # Dimension of output representation r +NP.z_dim = 12 # Dimension of latent variable z + +# Sampling params +NP.train_sample_times = 100 +NP.val_sample_times = 500 +NP.test_sample_times = 500 +NP.predict_sample_times = 100 \ No newline at end of file diff --git a/configs/imputation_models/RNN.gin b/configs/imputation_models/RNN.gin new file mode 100644 index 00000000..5d2731ef --- /dev/null +++ b/configs/imputation_models/RNN.gin @@ -0,0 +1,25 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.rnn +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @RNN +train_common.epochs = 100 +train_common.batch_size = 64 +train_common.patience = 10 +train_common.min_delta = 1e-4 + +train_common.optimizer = @Adam +ImputationWrapper.lr_scheduler = "cosine" + +# Optimizer params +Adam.lr = 1e-2 +Adam.weight_decay = 1e-6 + +# Model params +RNN.cell = 'gru' +RNN.hidden_size = 64 +RNN.state_init = 'zero' \ No newline at end of file diff --git a/configs/imputation_models/SAITS.gin b/configs/imputation_models/SAITS.gin new file mode 100644 index 00000000..1b734576 --- /dev/null +++ b/configs/imputation_models/SAITS.gin @@ -0,0 +1,21 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.baselines +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @SAITS + + + +# Model params +SAITS.n_layers = 4 +SAITS.d_model = 128 +SAITS.d_inner = 128 +SAITS.n_head = 4 +SAITS.d_k = 32 +SAITS.d_v = 32 +SAITS.dropout = 0.1 +SAITS.epochs = 100 diff --git a/configs/imputation_models/SSSDS4.gin 
b/configs/imputation_models/SSSDS4.gin new file mode 100644 index 00000000..ffe2cc7b --- /dev/null +++ b/configs/imputation_models/SSSDS4.gin @@ -0,0 +1,38 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.utils +import icu_benchmarks.data.preprocess +import icu_benchmarks.imputation.sssds4 + +# Train params +train_common.model = @SSSDS4 + +# here you can set some training parameters +train_common.epochs = 1000 +train_common.batch_size = 64 +train_common.patience = 10 +train_common.min_delta = 1e-4 +train_common.use_wandb = True + +ImputationWrapper.optimizer = @Adam +ImputationWrapper.lr_scheduler = "" + +# Optimizer params +Adam.lr = 0.1 +Adam.weight_decay = 1e-6 + +# Model params +SSSDS4.res_channels = 64 +SSSDS4.skip_channels = 256 +SSSDS4.num_res_layers = 36 +SSSDS4.diffusion_step_embed_dim_in = 256 +SSSDS4.diffusion_step_embed_dim_mid = 128 +SSSDS4.diffusion_step_embed_dim_out = 256 +SSSDS4.s4_lmax = 100 +SSSDS4.s4_d_state = 64 +SSSDS4.s4_dropout = 0.3 +SSSDS4.s4_bidirectional = False +SSSDS4.s4_layernorm = False +SSSDS4.diffusion_time_steps = 2000 +SSSDS4.beta_0 = 1e-4 +SSSDS4.beta_T = 2e-2 diff --git a/configs/imputation_models/SSSDSA.gin b/configs/imputation_models/SSSDSA.gin new file mode 100644 index 00000000..ea481931 --- /dev/null +++ b/configs/imputation_models/SSSDSA.gin @@ -0,0 +1,47 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.utils +import icu_benchmarks.data.preprocess +import icu_benchmarks.imputation.sssdsa + +# Train params +train_common.model = @SSSDSA + +# here you can set some training parameters +train_common.epochs = 2 +train_common.batch_size = 64 +train_common.patience = 10 +train_common.min_delta = 1e-4 +train_common.use_wandb = True + +ImputationWrapper.optimizer = @Adam +ImputationWrapper.lr_scheduler = "cosine" + +# Optimizer params +Adam.lr = 0.001 +Adam.weight_decay = 1e-6 + +# Model params +SSSDSA.d_model = 64 +SSSDSA.n_layers = 6 +SSSDSA.pool = [2, 2] +SSSDSA.expand = 2 +SSSDSA.ff = 2 +SSSDSA.glu = True +SSSDSA.unet = True +SSSDSA.dropout = 0.0 +SSSDSA.in_channels = 6 +SSSDSA.out_channels = 6 +SSSDSA.diffusion_step_embed_dim_in = 128 +SSSDSA.diffusion_step_embed_dim_mid = 512 +SSSDSA.diffusion_step_embed_dim_out = 512 +SSSDSA.label_embed_dim = 128 +SSSDSA.label_embed_classes = 71 +SSSDSA.bidirectional = True +SSSDSA.s4_lmax = 1 +SSSDSA.s4_d_state = 64 +SSSDSA.s4_dropout = 0.0 +SSSDSA.s4_bidirectional = True +SSSDSA.diffusion_time_steps = 1000 +SSSDSA.beta_0 = 1e-4 +SSSDSA.beta_T = 2e-2 \ No newline at end of file diff --git a/configs/imputation_models/Simple_Diffusion.gin b/configs/imputation_models/Simple_Diffusion.gin new file mode 100644 index 00000000..1571bc60 --- /dev/null +++ b/configs/imputation_models/Simple_Diffusion.gin @@ -0,0 +1,23 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.data.preprocess +import icu_benchmarks.imputation.simple_diffusion + +# Train params +train_common.model = @Simple_Diffusion + +# here you can set some training parameters +train_common.epochs = 10 +train_common.batch_size = 64 +train_common.patience = 10 +train_common.min_delta = 1e-4 +train_common.use_wandb = True + +ImputationWrapper.optimizer = @Adam +ImputationWrapper.lr_scheduler = "cosine" + +# Optimizer params +Adam.lr = 3e-4 +Adam.weight_decay = 1e-6 diff --git 
a/configs/imputation_models/Sinkhorn.gin b/configs/imputation_models/Sinkhorn.gin new file mode 100644 index 00000000..d4cf8faa --- /dev/null +++ b/configs/imputation_models/Sinkhorn.gin @@ -0,0 +1,4 @@ +import icu_benchmarks.imputation.baselines + +# Train params +train_common.model = @Sinkhorn diff --git a/configs/imputation_models/SoftImpute.gin b/configs/imputation_models/SoftImpute.gin new file mode 100644 index 00000000..d6e5b04f --- /dev/null +++ b/configs/imputation_models/SoftImpute.gin @@ -0,0 +1,4 @@ +import icu_benchmarks.imputation.baselines + +# Train params +train_common.model = @SoftImpute diff --git a/configs/imputation_models/Zero.gin b/configs/imputation_models/Zero.gin new file mode 100644 index 00000000..d7bffee4 --- /dev/null +++ b/configs/imputation_models/Zero.gin @@ -0,0 +1,9 @@ +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils +import icu_benchmarks.imputation.baselines +import icu_benchmarks.data.preprocess + +# Train params +train_common.model = @Zero diff --git a/configs/models/GRU.gin b/configs/models/GRU.gin deleted file mode 100644 index 6de637e4..00000000 --- a/configs/models/GRU.gin +++ /dev/null @@ -1,33 +0,0 @@ -import gin.torch.external_configurables -import icu_benchmarks.models.wrappers -import icu_benchmarks.models.encoders -import icu_benchmarks.models.utils - -default_preprocessor.generate_features = False - -# Train params -train_common.model = @DLWrapper() - -DLWrapper.encoder = @GRUNet() -DLWrapper.optimizer_fn = @Adam - -DLWrapper.train.epochs = 1000 -DLWrapper.train.batch_size = 64 -DLWrapper.train.patience = 10 -DLWrapper.train.min_delta = 1e-4 - -# Optimizer params -optimizer/hyperparameter.class_to_tune = @Adam -optimizer/hyperparameter.weight_decay = 1e-6 -optimizer/hyperparameter.lr = (1e-5, 3e-4) - -# Encoder params -model/hyperparameter.class_to_tune = @GRUNet -model/hyperparameter.num_classes = %NUM_CLASSES -model/hyperparameter.hidden_dim = (32, 256, "log-uniform", 2) -model/hyperparameter.layer_dim = (1, 3) - -tune_hyperparameters.scopes = ["model", "optimizer"] -tune_hyperparameters.n_initial_points = 5 -tune_hyperparameters.n_calls = 30 -tune_hyperparameters.folds_to_tune_on = 2 diff --git a/configs/models/LGBMClassifier.gin b/configs/models/LGBMClassifier.gin deleted file mode 100644 index e50783ad..00000000 --- a/configs/models/LGBMClassifier.gin +++ /dev/null @@ -1,26 +0,0 @@ -import gin.torch.external_configurables -import icu_benchmarks.models.wrappers -import icu_benchmarks.models.encoders -import icu_benchmarks.models.utils - -default_preprocessor.generate_features = True - -# Train params -train_common.model = @MLWrapper() - -MLWrapper.model = @LGBMClassifier() -MLWrapper.train.patience = 10 - -model/hyperparameter.class_to_tune = @LGBMClassifier -model/hyperparameter.colsample_bytree = (0.33, 1.0) -model/hyperparameter.max_depth = (3, 7) -model/hyperparameter.min_child_samples = 1000 -model/hyperparameter.n_estimators = 100000 -model/hyperparameter.num_leaves = (8, 128, "log-uniform", 2) -model/hyperparameter.subsample = (0.33, 1.0) -model/hyperparameter.subsample_freq = 1 - -tune_hyperparameters.scopes = ["model"] -tune_hyperparameters.n_initial_points = 10 -tune_hyperparameters.n_calls = 50 -tune_hyperparameters.folds_to_tune_on = 3 diff --git a/configs/models/LGBMRegressor.gin b/configs/models/LGBMRegressor.gin deleted file mode 100644 index c1c119ee..00000000 --- a/configs/models/LGBMRegressor.gin +++ /dev/null @@ 
-1,26 +0,0 @@ -import gin.torch.external_configurables -import icu_benchmarks.models.wrappers -import icu_benchmarks.models.encoders -import icu_benchmarks.models.utils - -default_preprocessor.generate_features = True - -# Train params -train_common.model = @MLWrapper() - -MLWrapper.model = @LGBMRegressor() -MLWrapper.train.patience = 10 - -model/hyperparameter.class_to_tune = @LGBMRegressor -model/hyperparameter.colsample_bytree = (0.33, 1.0) -model/hyperparameter.max_depth = (3, 7) -model/hyperparameter.min_child_samples = 1000 -model/hyperparameter.n_estimators = 100000 -model/hyperparameter.num_leaves = (8, 128, "log-uniform", 2) -model/hyperparameter.subsample = (0.33, 1.0) -model/hyperparameter.subsample_freq = 1 - -tune_hyperparameters.scopes = ["model"] -tune_hyperparameters.n_initial_points = 10 -tune_hyperparameters.n_calls = 250 -tune_hyperparameters.folds_to_tune_on = 3 diff --git a/configs/models/LSTM.gin b/configs/models/LSTM.gin deleted file mode 100644 index 7a12de7a..00000000 --- a/configs/models/LSTM.gin +++ /dev/null @@ -1,31 +0,0 @@ -import gin.torch.external_configurables -import icu_benchmarks.models.wrappers -import icu_benchmarks.models.encoders - -default_preprocessor.generate_features = False - -# Train params -train_common.model = @DLWrapper() - -DLWrapper.encoder = @LSTMNet() -DLWrapper.optimizer_fn = @Adam -DLWrapper.train.epochs = 1000 -DLWrapper.train.batch_size = 64 # FIXME bug with tensor dimensions on Apple MPS leads to BS = input length constraint -DLWrapper.train.patience = 10 -DLWrapper.train.min_delta = 1e-4 - -# Optimizer params -optimizer/hyperparameter.class_to_tune = @Adam -optimizer/hyperparameter.weight_decay = 1e-6 -optimizer/hyperparameter.lr = (1e-5, 3e-4) - -# Encoder params -model/hyperparameter.class_to_tune = @LSTMNet -model/hyperparameter.num_classes = %NUM_CLASSES -model/hyperparameter.hidden_dim = (32, 256, "log-uniform", 2) -model/hyperparameter.layer_dim = (1, 3) - -tune_hyperparameters.scopes = ["model", "optimizer"] -tune_hyperparameters.n_initial_points = 5 -tune_hyperparameters.n_calls = 30 -tune_hyperparameters.folds_to_tune_on = 2 \ No newline at end of file diff --git a/configs/models/LogisticRegression.gin b/configs/models/LogisticRegression.gin deleted file mode 100644 index 7961d7e5..00000000 --- a/configs/models/LogisticRegression.gin +++ /dev/null @@ -1,25 +0,0 @@ -import gin.torch.external_configurables -import icu_benchmarks.models.wrappers -import icu_benchmarks.models.encoders -import icu_benchmarks.models.utils - -default_preprocessor.generate_features = True - -# Train params -train_common.model = @MLWrapper() - -MLWrapper.model = @LogisticRegression() -MLWrapper.train.patience = 10 - -model/hyperparameter.class_to_tune = @LogisticRegression -model/hyperparameter.solver = "saga" -model/hyperparameter.n_jobs = 8 -model/hyperparameter.max_iter = 100000 -model/hyperparameter.C = (1e-3, 1e1, "log-uniform") -model/hyperparameter.penalty = ["l1", "l2", "elasticnet"] -model/hyperparameter.l1_ratio = (0.0, 1.0) - -tune_hyperparameters.scopes = ["model"] -tune_hyperparameters.n_initial_points = 10 -tune_hyperparameters.n_calls = 50 -tune_hyperparameters.folds_to_tune_on = 3 diff --git a/configs/models/TCN.gin b/configs/models/TCN.gin deleted file mode 100644 index 84ee30c8..00000000 --- a/configs/models/TCN.gin +++ /dev/null @@ -1,35 +0,0 @@ -import gin.torch.external_configurables -import icu_benchmarks.models.wrappers -import icu_benchmarks.models.encoders -import icu_benchmarks.models.utils - 
-default_preprocessor.generate_features = False - -# Train params -train_common.model = @DLWrapper() - -DLWrapper.encoder = @TemporalConvNet() -DLWrapper.optimizer_fn = @Adam - -DLWrapper.train.epochs = 1000 -DLWrapper.train.batch_size = 64 -DLWrapper.train.patience = 10 -DLWrapper.train.min_delta = 1e-4 - -# Optimizer params -optimizer/hyperparameter.class_to_tune = @Adam -optimizer/hyperparameter.weight_decay = 1e-6 -optimizer/hyperparameter.lr = (1e-5, 3e-4) - -# Encoder params -model/hyperparameter.class_to_tune = @TemporalConvNet -model/hyperparameter.num_classes = %NUM_CLASSES -model/hyperparameter.max_seq_length = %HORIZON -model/hyperparameter.num_channels = (32, 256, "log-uniform", 2) -model/hyperparameter.kernel_size = (2, 32, "log-uniform", 2) -model/hyperparameter.dropout = (0.0, 0.4) - -tune_hyperparameters.scopes = ["model", "optimizer"] -tune_hyperparameters.n_initial_points = 5 -tune_hyperparameters.n_calls = 30 -tune_hyperparameters.folds_to_tune_on = 2 \ No newline at end of file diff --git a/configs/models/Transformer.gin b/configs/models/Transformer.gin deleted file mode 100644 index d04c4daf..00000000 --- a/configs/models/Transformer.gin +++ /dev/null @@ -1,39 +0,0 @@ -import gin.torch.external_configurables -import icu_benchmarks.models.wrappers -import icu_benchmarks.models.encoders -import icu_benchmarks.models.utils -import icu_benchmarks.data.loader - -default_preprocessor.generate_features = False - -# Train params -train_common.model = @DLWrapper() - -DLWrapper.encoder = @Transformer() -DLWrapper.optimizer_fn = @Adam - -DLWrapper.train.epochs = 1000 -DLWrapper.train.batch_size = 64 -DLWrapper.train.patience = 10 -DLWrapper.train.min_delta = 1e-4 - -# Optimizer params -optimizer/hyperparameter.class_to_tune = @Adam -optimizer/hyperparameter.weight_decay = 1e-6 -optimizer/hyperparameter.lr = (1e-5, 3e-4) - -# Encoder params -model/hyperparameter.class_to_tune = @Transformer -model/hyperparameter.ff_hidden_mult = 2 -model/hyperparameter.l1_reg = 0.0 -model/hyperparameter.num_classes = %NUM_CLASSES -model/hyperparameter.hidden = (32, 256, "log-uniform", 2) -model/hyperparameter.heads = (1, 8, "log-uniform", 2) -model/hyperparameter.depth = (1, 3) -model/hyperparameter.dropout = (0.0, 0.4) -model/hyperparameter.dropout_att = (0.0, 0.4) - -tune_hyperparameters.scopes = ["model", "optimizer"] -tune_hyperparameters.n_initial_points = 5 -tune_hyperparameters.n_calls = 30 -tune_hyperparameters.folds_to_tune_on = 2 diff --git a/configs/prediction_models/ElasticNet.gin b/configs/prediction_models/ElasticNet.gin new file mode 100644 index 00000000..9de43cac --- /dev/null +++ b/configs/prediction_models/ElasticNet.gin @@ -0,0 +1,15 @@ +# Settings for ElasticNet model. + +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @ElasticNet + +model/hyperparameter.class_to_tune = @ElasticNet +#model/hyperparameter.solver = "saga" +model/hyperparameter.n_jobs = 8 +model/hyperparameter.max_iter = 10000 +model/hyperparameter.alpha = (1e-2, 1e1, "log-uniform") +model/hyperparameter.tol = (1e-5, 1e-1, "log-uniform") +model/hyperparameter.l1_ratio = (0.0, 1.0) diff --git a/configs/prediction_models/GRU.gin b/configs/prediction_models/GRU.gin new file mode 100644 index 00000000..d2a28a79 --- /dev/null +++ b/configs/prediction_models/GRU.gin @@ -0,0 +1,19 @@ +# Settings for Gated Recurrent Unit (GRU) model. 
+ +#Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" + +# Train params +train_common.model = @GRUNet + +# Optimizer params +optimizer/hyperparameter.class_to_tune = @Adam +optimizer/hyperparameter.weight_decay = 1e-6 +optimizer/hyperparameter.lr = (1e-5, 3e-4) + +# Encoder params +model/hyperparameter.class_to_tune = @GRUNet +model/hyperparameter.num_classes = %NUM_CLASSES +model/hyperparameter.hidden_dim = (32, 256, "log-uniform", 2) +model/hyperparameter.layer_dim = (1, 3) + diff --git a/configs/prediction_models/LGBMClassifier.gin b/configs/prediction_models/LGBMClassifier.gin new file mode 100644 index 00000000..b7bfec9a --- /dev/null +++ b/configs/prediction_models/LGBMClassifier.gin @@ -0,0 +1,16 @@ +# Settings for Light Gradient Boosting Machine (LGBM) classifier. + +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @LGBMClassifier + +model/hyperparameter.class_to_tune = @LGBMClassifier +model/hyperparameter.colsample_bytree = (0.33, 1.0) +model/hyperparameter.max_depth = (3, 7) +model/hyperparameter.min_child_samples = 1000 +model/hyperparameter.n_estimators = 100000 +model/hyperparameter.num_leaves = (8, 128, "log-uniform", 2) +model/hyperparameter.subsample = (0.33, 1.0) +model/hyperparameter.subsample_freq = 1 diff --git a/configs/prediction_models/LGBMRegressor.gin b/configs/prediction_models/LGBMRegressor.gin new file mode 100644 index 00000000..d677fa48 --- /dev/null +++ b/configs/prediction_models/LGBMRegressor.gin @@ -0,0 +1,17 @@ +# Settings for Light Gradient Boosting Machine (LGBM) regressor. + +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @LGBMRegressor + +model/hyperparameter.class_to_tune = @LGBMRegressor +model/hyperparameter.colsample_bytree = (0.33, 1.0) +model/hyperparameter.max_depth = (3, 7) +model/hyperparameter.min_child_samples = 1000 +model/hyperparameter.n_estimators = 100000 +model/hyperparameter.num_leaves = (8, 128, "log-uniform", 2) +model/hyperparameter.subsample = (0.33, 1.0) +model/hyperparameter.subsample_freq = 1 +model/hyperparameter.eval_metric = "logloss" diff --git a/configs/prediction_models/LSTM.gin b/configs/prediction_models/LSTM.gin new file mode 100644 index 00000000..b6390841 --- /dev/null +++ b/configs/prediction_models/LSTM.gin @@ -0,0 +1,20 @@ +# Settings for Long Short-Term Memory (LSTM) model. 
+ +# Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" + +# Train params +train_common.model = @LSTMNet + +# Optimizer params +optimizer/hyperparameter.class_to_tune = @Adam +optimizer/hyperparameter.weight_decay = 1e-6 +optimizer/hyperparameter.lr = (1e-5, 3e-4) + +# Encoder params +model/hyperparameter.class_to_tune = @LSTMNet +model/hyperparameter.num_classes = %NUM_CLASSES +model/hyperparameter.hidden_dim = (32, 256, "log-uniform", 2) +model/hyperparameter.layer_dim = (1, 3) + + diff --git a/configs/models/LocalTransformer.gin b/configs/prediction_models/LocalTransformer.gin similarity index 52% rename from configs/models/LocalTransformer.gin rename to configs/prediction_models/LocalTransformer.gin index 39d84eed..69ae1fe5 100644 --- a/configs/models/LocalTransformer.gin +++ b/configs/prediction_models/LocalTransformer.gin @@ -1,19 +1,11 @@ -import gin.torch.external_configurables -import icu_benchmarks.models.wrappers -import icu_benchmarks.models.encoders -import icu_benchmarks.models.utils +# Settings for Local Transformer Model. -default_preprocessor.generate_features = False +# Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" # Train params -train_common.model = @DLWrapper() -DLWrapper.encoder = @LocalTransformer() -DLWrapper.optimizer_fn = @Adam -DLWrapper.train.epochs = 1000 -DLWrapper.train.batch_size = 64 -DLWrapper.train.patience = 10 -DLWrapper.train.min_delta = 1e-4 +train_common.model = @LocalTransformer # Optimizer params optimizer/hyperparameter.class_to_tune = @Adam @@ -31,8 +23,3 @@ model/hyperparameter.heads = (1, 8, "log-uniform", 2) model/hyperparameter.depth = (1, 3) model/hyperparameter.dropout = (0.0, 0.4) model/hyperparameter.dropout_att = (0.0, 0.4) - -tune_hyperparameters.scopes = ["model", "optimizer"] -tune_hyperparameters.n_initial_points = 5 -tune_hyperparameters.n_calls = 30 -tune_hyperparameters.folds_to_tune_on = 2 \ No newline at end of file diff --git a/configs/prediction_models/LogisticRegression.gin b/configs/prediction_models/LogisticRegression.gin new file mode 100644 index 00000000..86ff89db --- /dev/null +++ b/configs/prediction_models/LogisticRegression.gin @@ -0,0 +1,18 @@ +# Settings for Logistic Regression model. + + +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @LogisticRegression +MLWrapper.patience = 10 + +model/hyperparameter.class_to_tune = @LogisticRegression +model/hyperparameter.solver = "saga" +model/hyperparameter.n_jobs = 8 +model/hyperparameter.max_iter = 100000 +model/hyperparameter.C = (1e-3, 1e1, "log-uniform") +model/hyperparameter.penalty = ["l1", "l2", "elasticnet"] +model/hyperparameter.l1_ratio = (0.0, 1.0) + diff --git a/configs/prediction_models/RFClassifier.gin b/configs/prediction_models/RFClassifier.gin new file mode 100644 index 00000000..72d03e66 --- /dev/null +++ b/configs/prediction_models/RFClassifier.gin @@ -0,0 +1,18 @@ +# Settings for Random Forest Classifier. 
+ +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @RFClassifier + +model/hyperparameter.class_to_tune = @RFClassifier +model/hyperparameter.n_estimators = (10, 50, 100, 200, 500) +model/hyperparameter.max_depth = (None, 5, 10, 20) +model/hyperparameter.min_samples_split = (2, 5, 10) +model/hyperparameter.min_samples_leaf = (1, 2, 4) +model/hyperparameter.max_features = ('sqrt', 'log2', None) +model/hyperparameter.bootstrap = (True, False) +model/hyperparameter.class_weight = (None, 'balanced') + + diff --git a/configs/prediction_models/RNN.gin b/configs/prediction_models/RNN.gin new file mode 100644 index 00000000..531aeff6 --- /dev/null +++ b/configs/prediction_models/RNN.gin @@ -0,0 +1,18 @@ +# Settings for Recurrent Neural Network (RNN) models. + +# Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" + +# Train params +train_common.model = @RNNet + +# Optimizer params +optimizer/hyperparameter.class_to_tune = @Adam +optimizer/hyperparameter.weight_decay = 1e-6 +optimizer/hyperparameter.lr = (1e-5, 3e-4) + +# Encoder params +model/hyperparameter.class_to_tune = @RNNet +model/hyperparameter.num_classes = %NUM_CLASSES +model/hyperparameter.hidden_dim = (32, 256, "log-uniform", 2) +model/hyperparameter.layer_dim = (1, 3) diff --git a/configs/prediction_models/TCN.gin b/configs/prediction_models/TCN.gin new file mode 100644 index 00000000..c6b314db --- /dev/null +++ b/configs/prediction_models/TCN.gin @@ -0,0 +1,20 @@ +# Settings for Temporal Convolutional Network (TCN) model. + +# Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" + +# Train params +train_common.model = @TemporalConvNet + +# Optimizer params +optimizer/hyperparameter.class_to_tune = @Adam +optimizer/hyperparameter.weight_decay = 1e-6 +optimizer/hyperparameter.lr = (1e-5, 3e-4) + +# Encoder params +model/hyperparameter.class_to_tune = @TemporalConvNet +model/hyperparameter.num_classes = %NUM_CLASSES +model/hyperparameter.max_seq_length = %HORIZON +model/hyperparameter.num_channels = (32, 256, "log-uniform", 2) +model/hyperparameter.kernel_size = (2, 32, "log-uniform", 2) +model/hyperparameter.dropout = (0.0, 0.4) diff --git a/configs/prediction_models/Transformer.gin b/configs/prediction_models/Transformer.gin new file mode 100644 index 00000000..2767fd37 --- /dev/null +++ b/configs/prediction_models/Transformer.gin @@ -0,0 +1,24 @@ +# Settings for Transformer model. 
+ +# Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" + +# Optimizer params +train_common.model = @Transformer + +optimizer/hyperparameter.class_to_tune = @Adam +optimizer/hyperparameter.weight_decay = 1e-6 +optimizer/hyperparameter.lr = (1e-5, 3e-4) + +# Encoder params +model/hyperparameter.class_to_tune = @Transformer +model/hyperparameter.ff_hidden_mult = 2 +model/hyperparameter.l1_reg = 0.0 +model/hyperparameter.num_classes = %NUM_CLASSES +model/hyperparameter.hidden = (32, 256, "log-uniform", 2) +model/hyperparameter.heads = (1, 8, "log-uniform", 2) +model/hyperparameter.depth = (1, 3) +model/hyperparameter.dropout = (0.0, 0.4) +model/hyperparameter.dropout_att = (0.0, 0.4) + + diff --git a/configs/prediction_models/common/DLCommon.gin b/configs/prediction_models/common/DLCommon.gin new file mode 100644 index 00000000..c220e6ab --- /dev/null +++ b/configs/prediction_models/common/DLCommon.gin @@ -0,0 +1,21 @@ +# Common settings for DL models + +# Imports to register the models +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.utils + +# Do not generate features from dynamic data +base_classification_preprocessor.generate_features = False +base_regression_preprocessor.generate_features = False + +# Train params +train_common.optimizer = @Adam +train_common.epochs = 1000 +train_common.batch_size = 64 +train_common.patience = 10 +train_common.min_delta = 1e-4 + +# Hyperparameter tuning settings +include "configs/prediction_models/common/DLTuning.gin" \ No newline at end of file diff --git a/configs/prediction_models/common/DLTuning.gin b/configs/prediction_models/common/DLTuning.gin new file mode 100644 index 00000000..b4d13e12 --- /dev/null +++ b/configs/prediction_models/common/DLTuning.gin @@ -0,0 +1,5 @@ +# Hyperparameter tuner settings for Deep Learning. +tune_hyperparameters.scopes = ["model", "optimizer"] +tune_hyperparameters.n_initial_points = 5 +tune_hyperparameters.n_calls = 30 +tune_hyperparameters.folds_to_tune_on = 2 \ No newline at end of file diff --git a/configs/prediction_models/common/MLCommon.gin b/configs/prediction_models/common/MLCommon.gin new file mode 100644 index 00000000..460bceba --- /dev/null +++ b/configs/prediction_models/common/MLCommon.gin @@ -0,0 +1,17 @@ +# Common settings for ML models + +# Imports to register the models +import gin.torch.external_configurables +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.ml_models +import icu_benchmarks.models.utils + +# Patience for early stopping +MLWrapper.patience = 10 + +# Generate features from dynamic data +base_classification_preprocessor.generate_features = True +base_regression_preprocessor.generate_features = True + +# Hyperparameter tuning settings +include "configs/prediction_models/common/MLTuning.gin" \ No newline at end of file diff --git a/configs/prediction_models/common/MLTuning.gin b/configs/prediction_models/common/MLTuning.gin new file mode 100644 index 00000000..c582a02d --- /dev/null +++ b/configs/prediction_models/common/MLTuning.gin @@ -0,0 +1,5 @@ +# Hyperparameter tuner settings for classical Machine Learning. 
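+# scopes: gin scopes whose hyperparameters are tuned; n_initial_points: random
+# draws before Bayesian optimization takes over; n_calls: total tuning
+# iterations; folds_to_tune_on: number of CV folds used during tuning.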
+tune_hyperparameters.scopes = ["model"] +tune_hyperparameters.n_initial_points = 10 +tune_hyperparameters.n_calls = 50 +tune_hyperparameters.folds_to_tune_on = 3 \ No newline at end of file diff --git a/configs/tasks/BinaryClassification.gin b/configs/tasks/BinaryClassification.gin index 65752c15..fe1e33f3 100644 --- a/configs/tasks/BinaryClassification.gin +++ b/configs/tasks/BinaryClassification.gin @@ -1,48 +1,32 @@ -import icu_benchmarks.data.preprocess +import icu_benchmarks.data.split_process_data import icu_benchmarks.data.loader import icu_benchmarks.models.wrappers -import icu_benchmarks.models.encoders +import icu_benchmarks.models.dl_models +import icu_benchmarks.models.ml_models -# CLASSIFICATION -NUM_CLASSES = 2 -HORIZON = 24 +# DATASET CONFIGURATION +include "configs/tasks/common/PredictionTaskVariables.gin" + +# CROSS-VALIDATION +include "configs/tasks/common/CrossValidation.gin" +# MODE SETTINGS +Run.mode = "Classification" +NUM_CLASSES = 2 # Binary classification +HORIZON = 24 train_common.weight = "balanced" +train_common.ram_cache = True # DEEP LEARNING -DLWrapper.loss = @cross_entropy - -# DATASET AND PREPROCESSING -preprocess.file_names = { - "DYNAMIC": "dyn.parquet", - "OUTCOME": "outc.parquet", - "STATIC": "sta.parquet", -} - -vars = { - "GROUP": "stay_id", - "LABEL": "label", - "SEQUENCE": "time", - "DYNAMIC": ["alb", "alp", "alt", "ast", "be", "bicar", "bili", - "bili_dir", "bnd", "bun", "ca", "cai", "ck", "ckmb", "cl", "crea", - "crp", "dbp", "fgn", "fio2", "glu", "hgb", "hr", "inr_pt", "k", "lact", - "lymph", "map", "mch", "mchc", "mcv", "methb", "mg", "na", "neut", - "o2sat", "pco2", "ph", "phos", "plt", "po2", "ptt", "resp", "sbp", - "temp", "tnt", "urine", "wbc"], - "STATIC": ["age", "sex", "height", "weight"], -} -static_features = True +DLPredictionWrapper.loss = @cross_entropy +# SELECTING PREPROCESSOR +preprocess.preprocessor = @base_classification_preprocessor preprocess.vars = %vars -preprocess.use_static = %static_features - -default_preprocessor.use_static_features = %static_features -default_preprocessor.vars = %vars - -Dataset.vars = %vars +preprocess.use_static = True -# CROSS VALIDATION +# SELECTING DATASET +PredictionDataset.vars = %vars +PredictionDataset.ram_cache = True -execute_repeated_cv.cv_repetitions = 5 -execute_repeated_cv.cv_folds = 5 diff --git a/configs/tasks/DatasetImputation.gin b/configs/tasks/DatasetImputation.gin new file mode 100644 index 00000000..452649c6 --- /dev/null +++ b/configs/tasks/DatasetImputation.gin @@ -0,0 +1,29 @@ +import icu_benchmarks.models.wrappers +import icu_benchmarks.models.ml_models +import icu_benchmarks.models.dl_models +import icu_benchmarks.data.split_process_data + +# CROSS-VALIDATION +include "configs/tasks/common/CrossValidation.gin" + +Run.mode = "Imputation" + +# DATASET STRUCTURE +vars = { + "GROUP": "stay_id", + "SEQUENCE": "time", + "DYNAMIC": ["hr","map","sbp", "dbp", "resp", "o2sat"], + "STATIC": ["age", "sex", "height", "weight"], +} + +preprocess.file_names = { + "DYNAMIC": "dyn.parquet", + "STATIC": "sta.parquet", +} + +preprocess.preprocessor = @base_imputation_preprocessor + +preprocess.vars = %vars +ImputationDataset.vars = %vars +ImputationDataset.ram_cache = True + diff --git a/configs/tasks/Regression.gin b/configs/tasks/Regression.gin new file mode 100644 index 00000000..76516ad9 --- /dev/null +++ b/configs/tasks/Regression.gin @@ -0,0 +1,36 @@ +import icu_benchmarks.data.split_process_data +import icu_benchmarks.data.loader +import icu_benchmarks.models.wrappers +import 
icu_benchmarks.models.dl_models +import icu_benchmarks.models.ml_models + +# DATASET CONFIGURATION +include "configs/tasks/common/PredictionTaskVariables.gin" + +# CROSS-VALIDATION +include "configs/tasks/common/CrossValidation.gin" + +# MODE SETTINGS +Run.mode = "Regression" +NUM_CLASSES = 1 +HORIZON = 24 +train_common.weight = "balanced" +train_common.ram_cache = True + +# LOSS FUNCTION +DLPredictionWrapper.loss = @mse_loss +MLWrapper.loss = @mean_squared_error + +# SELECTING PREPROCESSOR +preprocess.preprocessor = @base_regression_preprocessor +preprocess.vars = %vars +preprocess.use_static = True + +# SPECIFYING REGRESSION OUTCOME SCALING +base_regression_preprocessor.outcome_min = 0 +base_regression_preprocessor.outcome_max = 15 + +# SELECTING DATASET +PredictionDataset.vars = %vars +PredictionDataset.ram_cache = True + diff --git a/configs/tasks/common/CrossValidation.gin b/configs/tasks/common/CrossValidation.gin new file mode 100644 index 00000000..f3efd9ef --- /dev/null +++ b/configs/tasks/common/CrossValidation.gin @@ -0,0 +1,3 @@ +# CROSS-VALIDATION SETTINGS +execute_repeated_cv.cv_repetitions = 5 +execute_repeated_cv.cv_folds = 5 \ No newline at end of file diff --git a/configs/tasks/common/PredictionTaskVariables.gin b/configs/tasks/common/PredictionTaskVariables.gin new file mode 100644 index 00000000..6e38638e --- /dev/null +++ b/configs/tasks/common/PredictionTaskVariables.gin @@ -0,0 +1,18 @@ +# DEFAULT DATASET CONFIGURATION + +preprocess.file_names = { + "DYNAMIC": "dyn.parquet", + "OUTCOME": "outc.parquet", + "STATIC": "sta.parquet", +} + +vars = { + "GROUP": "stay_id", + "SEQUENCE": "time", + "LABEL": "label", + "DYNAMIC": ["alb", "alp", "alt", "ast", "be", "bicar", "bili", "bili_dir", "bnd", "bun", "ca", "cai", "ck", "ckmb", "cl", + "crea", "crp", "dbp", "fgn", "fio2", "glu", "hgb", "hr", "inr_pt", "k", "lact", "lymph", "map", "mch", "mchc", "mcv", + "methb", "mg", "na", "neut", "o2sat", "pco2", "ph", "phos", "plt", "po2", "ptt", "resp", "sbp", "temp", "tnt", "urine", + "wbc"], + "STATIC": ["age", "sex", "height", "weight"], +} \ No newline at end of file diff --git a/demo_data/kf/eicu_demo/dyn.parquet b/demo_data/kf/eicu_demo/dyn.parquet new file mode 100644 index 00000000..e69de29b diff --git a/demo_data/kf/eicu_demo/outc.parquet b/demo_data/kf/eicu_demo/outc.parquet new file mode 100644 index 00000000..e69de29b diff --git a/demo_data/kf/eicu_demo/sta.parquet b/demo_data/kf/eicu_demo/sta.parquet new file mode 100644 index 00000000..e69de29b diff --git a/demo_data/kf/mimic_demo/dyn.parquet b/demo_data/kf/mimic_demo/dyn.parquet new file mode 100644 index 00000000..e69de29b diff --git a/demo_data/kf/mimic_demo/outc.parquet b/demo_data/kf/mimic_demo/outc.parquet new file mode 100644 index 00000000..e69de29b diff --git a/demo_data/kf/mimic_demo/sta.parquet b/demo_data/kf/mimic_demo/sta.parquet new file mode 100644 index 00000000..e69de29b diff --git a/demo_data/los/eicu_demo/dyn.parquet b/demo_data/los/eicu_demo/dyn.parquet new file mode 100644 index 00000000..745ce043 Binary files /dev/null and b/demo_data/los/eicu_demo/dyn.parquet differ diff --git a/demo_data/los/eicu_demo/outc.parquet b/demo_data/los/eicu_demo/outc.parquet new file mode 100644 index 00000000..d6d9f16f Binary files /dev/null and b/demo_data/los/eicu_demo/outc.parquet differ diff --git a/demo_data/los/eicu_demo/sta.parquet b/demo_data/los/eicu_demo/sta.parquet new file mode 100644 index 00000000..83821214 Binary files /dev/null and b/demo_data/los/eicu_demo/sta.parquet differ diff --git 
a/demo_data/los/mimic_demo/dyn.parquet b/demo_data/los/mimic_demo/dyn.parquet
new file mode 100644
index 00000000..d623a6c2
Binary files /dev/null and b/demo_data/los/mimic_demo/dyn.parquet differ
diff --git a/demo_data/los/mimic_demo/outc.parquet b/demo_data/los/mimic_demo/outc.parquet
new file mode 100644
index 00000000..d1bf3741
Binary files /dev/null and b/demo_data/los/mimic_demo/outc.parquet differ
diff --git a/demo_data/los/mimic_demo/sta.parquet b/demo_data/los/mimic_demo/sta.parquet
new file mode 100644
index 00000000..b7250c2f
Binary files /dev/null and b/demo_data/los/mimic_demo/sta.parquet differ
diff --git a/docs/development.md b/docs/development.md
new file mode 100644
index 00000000..0ff1c738
--- /dev/null
+++ b/docs/development.md
@@ -0,0 +1,41 @@
+# Development
+
+YAIB is in active development. The following sections are relevant if you want to add new code to our repository.
+
+## Libraries
+
+The following libraries are important to the operation of YAIB:
+
+- [Pandas](https://github.com/pandas-dev/pandas): Popular data structure framework.
+- [ReciPys](https://github.com/rvandewater/recipys): A modular preprocessing package for Pandas dataframes.
+- [Pytorch](https://pytorch.org/): An open source machine learning framework for deep learning applications.
+- [Pytorch Lightning](https://www.pytorchlightning.ai/): A lightweight Pytorch wrapper for AI research.
+- [Pytorch Ignite](https://github.com/pytorch/ignite): Library for training and evaluating neural networks in Pytorch.
+- [Cuda Toolkit](https://developer.nvidia.com/cuda-toolkit): GPU acceleration used for deep learning models.
+- [Scikit-learn](https://github.com/scikit-learn/scikit-learn): Machine learning library.
+- [Scikit-optimize](https://scikit-optimize.github.io/stable/): Used for Bayesian optimization.
+- [LightGBM](https://github.com/microsoft/LightGBM): Gradient boosting framework.
+- [GIN](https://github.com/google/gin-config): Provides a lightweight configuration framework for Python.
+- [Wandb](https://wandb.ai/): A tool for visualizing and tracking machine learning experiments.
+- [Pytest](https://docs.pytest.org/en/stable/): A testing framework for Python.
+
+### Imputation
+
+- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute): Imputation library for MissForest and GAIN.
+- [PyPOTS](https://github.com/WenjieDu/PyPOTS): Imputation library.
+
+## Run Tests
+
+```
+python -m pytest ./tests/recipes
+coverage run -m pytest ./tests/recipes
+# then use either of the following
+coverage report
+coverage html
+```
+
+## Autoformat and lint
+
+For development purposes, we use the `Black` package to autoformat our code and `Flake8` for linting/CI checks:
+
+```
+black . -l 127
+flake8 . --count --max-complexity=14 --max-line-length=127 --statistics
+```
diff --git a/docs/figures/results_yaib.png b/docs/figures/results_yaib.png
new file mode 100644
index 00000000..a170a75e
Binary files /dev/null and b/docs/figures/results_yaib.png differ
diff --git a/docs/figures/yaib_flow_combined.svg b/docs/figures/yaib_flow_combined.svg
new file mode 100644
index 00000000..6c61f2c8
--- /dev/null
+++ b/docs/figures/yaib_flow_combined.svg
@@ -0,0 +1 @@
+Training PredictorBayesianHyperparameter OptimizationModelsLRLGBMRFRNNLSTMGRUTCNTransformerSelectYet Another ICU BenchmarkTaskDefinitionModel Settings & Hyperparameters Pre-processorConfigurationPre-processingScalingMissing IndicatorsImputationFeature ExtractionSplittingFold splittingCachingCache data for reuseMetricsClassificationROCAUROCPRCAUPRCAccuracyCalibration CurveRegressionMAER2RMSEConfig(ML)ResearcherDefineAUMCdbMIMIC-III/IVeICUHiRIDCompatible DataHarmonizationClinical ExpertDataset XDefineHarmonizeCohort and Variable SelectionSepsisAKIMortalityLoSKFVariable MappingArtifact RemovalUnit Harmonization
\ No newline at end of file
diff --git a/docs/imputation_methods.md b/docs/imputation_methods.md
new file mode 100644
index 00000000..2eead00b
--- /dev/null
+++ b/docs/imputation_methods.md
@@ -0,0 +1,89 @@
+
+# Adding new Imputation Models
+
+To add another imputation model, you have to create a class that inherits from `ImputationWrapper` in `icu_benchmarks.models.wrappers`. Your model class should look like this:
+
+```python
+from icu_benchmarks.models.wrappers import ImputationWrapper
+import gin
+
+@gin.configurable("newmethod")
+class New_Method(ImputationWrapper):
+
+    # adjust this accordingly
+    needs_training = False  # if true, the method is trained iteratively (like a deep learning model)
+    needs_fit = True  # if true, it receives the complete training data to perform a fit on
+
+    def __init__(self, *args, model_arg1, model_arg2, **kwargs):
+        super().__init__(*args, **kwargs)
+        # define your new model here
+        self.model = ...
+
+    # the following method has to be implemented for all methods
+    def forward(self, amputated_values, amputation_mask):
+        imputed_values = amputated_values
+        ...
+        return imputed_values
+
+    # implement this if needs_fit is true; otherwise you can leave it out.
+    # this method receives the complete input training data to perform a fit on.
+    def fit(self, train_data):
+        ...
+```
+
+You also need to create a gin configuration file in the `configs/imputation` directory,
+named `newmethod.gin` after the name that was entered into the `gin.configurable` decorator call.
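+
+Before wiring the class into gin, it can help to smoke-test the forward contract directly. The following is a minimal sketch, not part of YAIB itself; it assumes the hypothetical `New_Method` from above was saved as `icu_benchmarks/imputation/new_model.py`, and the tensor shapes are purely illustrative (stays x time steps x dynamic variables):
+
+```python
+import torch
+
+from icu_benchmarks.imputation.new_model import New_Method  # hypothetical module from above
+
+# Fake batch: 4 stays, 25 time steps, 6 dynamic variables.
+amputated_values = torch.randn(4, 25, 6)
+# Mask marks the artificially removed ("amputated") entries with 1.
+amputation_mask = (torch.rand(4, 25, 6) < 0.3).float()
+amputated_values[amputation_mask.bool()] = 0.0
+
+# Depending on ImputationWrapper, additional constructor arguments may be required.
+model = New_Method(model_arg1=20, model_arg2=15)
+imputed_values = model(amputated_values, amputation_mask)
+
+# The forward contract: same shape out as in, with masked entries filled.
+assert imputed_values.shape == amputated_values.shape
+```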
+
+Your `.gin` file should look like this:
+
+```python
+import gin.torch.external_configurables
+import icu_benchmarks.models.wrappers
+import icu_benchmarks.models.dl_models
+import icu_benchmarks.models.utils
+import icu_benchmarks.data.split_process_data
+# import here the file in which you created your New_Method class
+import icu_benchmarks.imputation.new_model
+
+# Train params
+train_common.model = @newmethod  # change this into the name of the gin configuration file
+
+# here you can set some training parameters
+train_common.epochs = 1000
+train_common.batch_size = 64
+train_common.patience = 10
+train_common.min_delta = 1e-4
+train_common.use_wandb = True
+
+ImputationWrapper.optimizer = @Adam
+ImputationWrapper.lr_scheduler = "cosine"
+
+# Optimizer params
+Adam.lr = 3e-4
+Adam.weight_decay = 1e-6
+
+# here you can set the model parameters you want to configure
+newmethod.model_arg1 = 20
+newmethod.model_arg2 = 15
+```
+
+You can find further configurations in the `DatasetImputation.gin` file in the `configs/tasks/` directory.
+To train the newly created imputation method, use the following command:
+
+```bash
+python run.py train -d path/to/preprocessed/data/files -n dataset_name -t DatasetImputation -m newmethod
+```
+
+For the dataset path, enter the directory where the preprocessed `dyn.parquet`, `outc.parquet` and `sta.parquet` files are stored. The `dataset_name` is used for logging purposes only; nothing breaks if it is not set correctly. Keep in mind to use the name of the `.gin` config file created for the imputation method as the model name for the `-m` parameter.
+
+For a reference deep learning based imputation method, take a look at how `MLPImputation` is implemented in `icu_benchmarks/imputation/mlp.py` with its `MLP.gin` configuration file. For reference regarding methods with `needs_fit=True`, take a look at `icu_benchmarks/imputation/baselines.py`, which contains several baseline implementations with corresponding config files in `configs/imputation/`.
diff --git a/docs/wanddb.md b/docs/wanddb.md
new file mode 100644
index 00000000..223c061f
--- /dev/null
+++ b/docs/wanddb.md
@@ -0,0 +1,30 @@
+# Hyperparameter Optimization using Weights and Biases Sweeps
+
+[This sweep file](wandb_sweep.yaml) shows an example of how to run a hyperparameter sweep with W&B. The general structure of the YAML should look like this:
+``` yaml
+program: icu-benchmarks
+command:
+  - ${env}
+  - ${program}
+  - "train"
+# .... other program parameters ....
+  - "--wandb-sweep"
+method: grid
+parameters:
+  # gin config parameter name:
+  #   values: [a, b, etc...]
+  # example:
+  ImputationDataset.mask_method:
+    values: ["MCAR", "MAR", "MNAR"]
+```
+
+You can then create a sweep with
+``` bash
+wandb sweep path/to/sweep_file.yaml
+```
+which will give you a sweep id,
+ +and start an agent to perform the optimization using the following command: +``` bash +wandb agent YOUR_SWEEP_ID +``` \ No newline at end of file diff --git a/environment.yml b/environment.yml index 75065175..849095ad 100644 --- a/environment.yml +++ b/environment.yml @@ -1,25 +1,36 @@ name: yaib channels: - pytorch + - nvidia - conda-forge - anaconda dependencies: - - black=22.10.0 - - coverage=6.5.0 + - python=3.10 + - black=23.3.0 + - coverage=7.2.3 - flake8=5.0.4 - gin-config=0.5.0 - - ignite=0.4.10 - - lightgbm=3.3.3 - - numpy=1.23.4 - - pandas=1.5.1 - - pyarrow=9.0.0 - - pytest=7.2.0 - - pytorch=1.12.1 - - cudatoolkit=11.6.0 - - scikit-learn=1.1.3 - - tensorboard=2.11.0 + - ignite=0.4.11 + # Versioning of Pytorch (2.0) dependent on compatible CUDA version + - pytorch + - pytorch-cuda=11.8 + - lightgbm=3.3.5 + - numpy=1.24.3 + - pandas=2.0.0 + - pyarrow=11.0.0 + - pytest=7.3.1 + - scikit-learn=1.2.2 + - tensorboard=2.12.2 - tqdm=4.64.1 - - pip=22.3.1 + - pytorch-lightning=2.0.1 + - wandb=0.15.0 + - pip=23.1 - scikit-optimize=0.9.0 + - einops=0.6.1 + - hydra-core=1.3 - pip: - git+https://github.com/rvandewater/recipys.git + - hyperimpute==0.1.16 + - pypots==0.0.10 + - hydra-submitit-launcher==1.2.0 + diff --git a/environment_mps.yml b/environment_mps.yml index db41d15c..eac240c8 100644 --- a/environment_mps.yml +++ b/environment_mps.yml @@ -1,24 +1,36 @@ name: yaib channels: - pytorch + - nvidia - conda-forge - anaconda dependencies: - - black=22.10.0 - - coverage=6.5.0 + - python=3.10 + - black=23.3.0 + - coverage=7.2.3 - flake8=5.0.4 - gin-config=0.5.0 - - ignite=0.4.10 - - lightgbm=3.3.3 - - numpy=1.23.4 - - pandas=1.5.1 - - pyarrow=9.0.0 - - pytest=7.2.0 - - pytorch=1.13.0 - - scikit-learn=1.1.3 - - tensorboard=2.11.0 + - ignite=0.4.11 + # Versioning of Pytorch (2.0) dependent on compatible CUDA version + - pytorch + - pytorch-cuda=11.8 + - lightgbm=3.3.5 + - numpy=1.24.3 + - pandas=2.0.0 + - pyarrow=11.0.0 + - pytest=7.3.1 + - scikit-learn=1.2.2 + - tensorboard=2.12.2 - tqdm=4.64.1 - - pip=22.3.1 + - mkl < 2022 + - pytorch-lightning=2.0.1 + - wandb=0.15.0 + - pip=23.1 - scikit-optimize=0.9.0 + - einops=0.6.1 + - hydra-core=1.3 - pip: - git+https://github.com/rvandewater/recipys.git + - hyperimpute==0.1.16 + - pypots==0.0.10 + - hydra-submitit-launcher==1.2.0 diff --git a/experiments/benchmark_classification.yml b/experiments/benchmark_classification.yml new file mode 100644 index 00000000..61763ffa --- /dev/null +++ b/experiments/benchmark_classification.yml @@ -0,0 +1,46 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t + - BinaryClassification + - --log-dir + - ../yaib_logs + - --tune + - --wandb-sweep + - -gc + - -lc +method: grid +name: yaib_classification_benchmark +parameters: + data_dir: + values: + - ../data/mortality24/miiv + - ../data/mortality24/hirid + - ../data/mortality24/eicu + - ../data/mortality24/aumc + - ../data/aki/miiv + - ../data/aki/hirid + - ../data/aki/eicu + - ../data/aki/aumc + - ../data/sepsis/miiv + - ../data/sepsis/hirid + - ../data/sepsis/eicu + - ../data/sepsis/aumc + model: + values: + - LogisticRegression + - LGBMClassifier + - GRU + - LSTM + - TCN + - Transformer + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/benchmark_regression.yml b/experiments/benchmark_regression.yml new file mode 100644 index 00000000..8aa8d13e --- /dev/null +++ b/experiments/benchmark_regression.yml @@ -0,0 +1,42 @@ +command: + - ${env} + - 
${program} + - train + - -d + - ../data/ + - -t + - Regression + - --log-dir + - ../yaib_logs + - --tune + - --wandb-sweep + - -gc + - -lc +method: grid +name: yaib_regression_benchmark +parameters: + data_dir: + values: + - ../data/los/miiv + - ../data/los/hirid + - ../data/los/eicu + - ../data/los/aumc + - ../data/kf/miiv + - ../data/kf/hirid + - ../data/kf/eicu + - ../data/kf/aumc + model: + values: + - ElasticNet + - LGBMRegressor + - GRU + - LSTM + - TCN + - Transformer + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/demo_benchmark_classification.yml b/experiments/demo_benchmark_classification.yml new file mode 100644 index 00000000..32516df2 --- /dev/null +++ b/experiments/demo_benchmark_classification.yml @@ -0,0 +1,40 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t + - BinaryClassification + - --log-dir + - ../yaib_logs + - --tune + - --wandb-sweep + - -gc + - -lc +method: grid +name: yaib_demo_classification_benchmark +parameters: + data_dir: + values: + - demo_data/mortality24/eicu_demo + - demo_data/mortality24/mimic_demo + - demo_data/aki/eicu_demo + - demo_data/aki/mimic_demo + - demo_data/sepsis/eicu_demo + - demo_data/sepsis/mimic_demo + model: + values: + - LogisticRegression + - LGBMClassifier + - GRU + - LSTM + - TCN + - Transformer + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/demo_benchmark_regression.yml b/experiments/demo_benchmark_regression.yml new file mode 100644 index 00000000..3b678371 --- /dev/null +++ b/experiments/demo_benchmark_regression.yml @@ -0,0 +1,38 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t + - Regression + - --log-dir + - ../yaib_logs + - --tune + - --wandb-sweep + - -gc + - -lc +method: grid +name: yaib_demo_regression_benchmark +parameters: + data_dir: + values: + - demo_data/los/eicu_demo + - demo_data/los/mimic_demo + - demo_data/kf/eicu_demo + - demo_data/kf/mimic_demo + model: + values: + - ElasticNet + - LGBMRegressor + - GRU + - LSTM + - TCN + - Transformer + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/icu_benchmarks/__init__.py b/icu_benchmarks/__init__.py index c51b8f69..38c4e50d 100644 --- a/icu_benchmarks/__init__.py +++ b/icu_benchmarks/__init__.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - """Top-level package for YAIB.""" __author__ = "Robin van de Water" diff --git a/icu_benchmarks/contants.py b/icu_benchmarks/contants.py new file mode 100644 index 00000000..da698a63 --- /dev/null +++ b/icu_benchmarks/contants.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class RunMode(str, Enum): + classification = "Classification" + imputation = "Imputation" + regression = "Regression" diff --git a/icu_benchmarks/cross_validation.py b/icu_benchmarks/cross_validation.py index abf99c40..89a98864 100644 --- a/icu_benchmarks/cross_validation.py +++ b/icu_benchmarks/cross_validation.py @@ -3,11 +3,15 @@ import logging import gin from pathlib import Path +from pytorch_lightning import seed_everything -from icu_benchmarks.data.preprocess import preprocess_data +from icu_benchmarks.wandb_utils import wandb_log +from icu_benchmarks.run_utils import aggregate_results +from icu_benchmarks.data.split_process_data import preprocess_data from icu_benchmarks.models.train import train_common from 
icu_benchmarks.models.utils import JsonResultLoggingEncoder from icu_benchmarks.run_utils import log_full_line +from icu_benchmarks.contants import RunMode @gin.configurable @@ -26,10 +30,16 @@ def execute_repeated_cv( generate_cache: bool = False, load_cache: bool = False, test_on: str = "test", + mode: str = RunMode.classification, + pretrained_imputation_model: object = None, + cpu: bool = False, + verbose: bool = False, + wandb: bool = False, ) -> float: """Preprocesses data and trains a model for each fold. Args: + data_dir: Path to the data directory. log_dir: Path to the log directory. seed: Random seed. @@ -37,11 +47,17 @@ def execute_repeated_cv( source_dir: Path to the source directory. cv_folds: Number of folds for cross validation. cv_folds_to_train: Number of folds to use during training. If None, all folds are trained on. + cv_repetitions: Amount of cross validation repetitions. + cv_repetitions_to_train: Amount of training repetitions. If None, all repetitions are trained on. reproducible: Whether to make torch reproducible. debug: Whether to load less data and enable more logging. generate_cache: Whether to generate and save cache. load_cache: Whether to load previously cached data. test_on: Dataset to test on. Can be "test" or "val" (e.g. for hyperparameter tuning). + mode: Run mode. Can be one of the values of RunMode + pretrained_imputation_model: Use a pretrained imputation model. + cpu: Whether to run on CPU. + verbose: Enable detailed logging. Returns: The average loss of all folds. """ @@ -50,6 +66,8 @@ def execute_repeated_cv( if not cv_folds_to_train: cv_folds_to_train = cv_folds agg_loss = 0 + + seed_everything(seed, reproducible) for repetition in range(cv_repetitions_to_train): for fold_index in range(cv_folds_to_train): start_time = datetime.now() @@ -63,6 +81,8 @@ def execute_repeated_cv( repetition_index=repetition, cv_folds=cv_folds, fold_index=fold_index, + pretrained_imputation_model=pretrained_imputation_model, + runmode=mode, ) repetition_fold_dir = log_dir / f"repetition_{repetition}" / f"fold_{fold_index}" @@ -74,9 +94,12 @@ def execute_repeated_cv( log_dir=repetition_fold_dir, load_weights=load_weights, source_dir=source_dir, - seed=seed, reproducible=reproducible, test_on=test_on, + mode=mode, + cpu=cpu, + verbose=verbose, + use_wandb=wandb, ) train_time = datetime.now() - start_time @@ -88,6 +111,10 @@ def execute_repeated_cv( with open(repetition_fold_dir / "durations.json", "w") as f: json.dump(durations, f, cls=JsonResultLoggingEncoder) + if wandb: + wandb_log({"Iteration": repetition * cv_folds_to_train + fold_index}) + if repetition * cv_folds_to_train + fold_index > 1: + aggregate_results(log_dir) log_full_line(f"FINISHED CV REPETITION {repetition}", level=logging.INFO, char="=", num_newlines=3) return agg_loss / (cv_repetitions_to_train * cv_folds_to_train) diff --git a/icu_benchmarks/data/constants.py b/icu_benchmarks/data/constants.py index 05984945..031f35ca 100644 --- a/icu_benchmarks/data/constants.py +++ b/icu_benchmarks/data/constants.py @@ -9,3 +9,9 @@ class DataSegment: dynamic = "DYNAMIC" outcome = "OUTCOME" # Labels features = "FEATURES" # Combined features from static and dynamic data. 
+ + +class VarType: + group = "GROUP" + sequence = "SEQUENCE" + label = "LABEL" diff --git a/icu_benchmarks/data/loader.py b/icu_benchmarks/data/loader.py index af75f1b4..9b831aa7 100644 --- a/icu_benchmarks/data/loader.py +++ b/icu_benchmarks/data/loader.py @@ -1,35 +1,49 @@ +from typing import List +from pandas import DataFrame import gin import numpy as np -import torch -from torch import Tensor +from torch import Tensor, cat, from_numpy, float32 from torch.utils.data import Dataset +import logging +from typing import Dict, Tuple -from .constants import DataSegment as Segment, DataSplit as Split +from icu_benchmarks.imputation.amputations import ampute_data +from .constants import DataSegment as Segment +from .constants import DataSplit as Split -@gin.configurable("Dataset") -class SICUDataset(Dataset): - """Standardized ICU Dataset: subclass of Torch Dataset that represents the data to learn on. +class CommonDataset(Dataset): + """Common dataset: subclass of Torch Dataset that represents the data to learn on. - Args: - data: Dict of the different splits of the data. - split: Either 'train','val' or 'test'. - vars: Contains the names of columns in the data. + Args: data: Dict of the different splits of the data. split: Either 'train','val' or 'test'. vars: Contains the names of + columns in the data. grouping_segment: str, optional: The segment of the data contains the grouping column with only + unique values. Defaults to Segment.outcome. Is used to calculate the number of stays in the data. """ - def __init__(self, data: dict, split: str = Split.train, vars: dict[str] = gin.REQUIRED): + def __init__( + self, + data: dict, + split: str = Split.train, + vars: Dict[str, str] = gin.REQUIRED, + grouping_segment: str = Segment.outcome, + ): self.split = split self.vars = vars - self.outcome_df = data[split][Segment.outcome].set_index(self.vars["GROUP"]) + self.grouping_df = data[split][grouping_segment].set_index(self.vars["GROUP"]) self.features_df = ( data[split][Segment.features].set_index(self.vars["GROUP"]).drop(labels=self.vars["SEQUENCE"], axis=1) ) # calculate basic info for the data - self.num_stays = self.outcome_df.index.unique().shape[0] - self.num_measurements = self.features_df.shape[0] + self.num_stays = self.grouping_df.index.unique().shape[0] self.maxlen = self.features_df.groupby([self.vars["GROUP"]]).size().max() + def ram_cache(self, cache: bool = True): + self._cached_dataset = None + if cache: + logging.info("Caching dataset in ram.") + self._cached_dataset = [self[i] for i in range(len(self))] + def __len__(self) -> int: """Returns number of stays in the data. @@ -38,10 +52,34 @@ def __len__(self) -> int: """ return self.num_stays - def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: - """Function to sample from the data split of choice. + def get_feature_names(self): + return self.features_df.columns - Used for deep learning implementations. + def to_tensor(self): + values = [] + for entry in self: + for i, value in enumerate(entry): + if len(values) <= i: + values.append([]) + values[i].append(value.unsqueeze(0)) + return [cat(value, dim=0) for value in values] + + +@gin.configurable("PredictionDataset") +class PredictionDataset(CommonDataset): + """Subclass of common dataset for prediction tasks. + + Args: + ram_cache (bool, optional): Whether the complete dataset should be stored in ram. Defaults to True. 
+ """ + + def __init__(self, *args, ram_cache: bool = True, **kwargs): + super().__init__(*args, grouping_segment=Segment.outcome, **kwargs) + self.outcome_df = self.grouping_df + self.ram_cache(ram_cache) + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]: + """Function to sample from the data split of choice. Used for deep learning implementations. Args: idx: A specific row index to sample. @@ -49,12 +87,15 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: Returns: A sample from the data, consisting of data, labels and padding mask. """ + if self._cached_dataset is not None: + return self._cached_dataset[idx] + pad_value = 0.0 stay_id = self.outcome_df.index.unique()[idx] # [self.vars["GROUP"]] # slice to make sure to always return a DF window = self.features_df.loc[stay_id:stay_id].to_numpy() - labels = self.outcome_df.loc[stay_id:stay_id]["label"].to_numpy(dtype=float) + labels = self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy(dtype=float) if len(labels) == 1: # only one label per stay, align with window @@ -80,7 +121,7 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: labels = labels.astype(np.float32) data = window.astype(np.float32) - return torch.from_numpy(data), torch.from_numpy(labels), torch.from_numpy(pad_mask) + return from_numpy(data), from_numpy(labels), from_numpy(pad_mask) def get_balance(self) -> list: """Return the weight balance for the split of interest. @@ -88,22 +129,159 @@ def get_balance(self) -> list: Returns: Weights for each label. """ - counts = self.outcome_df["label"].value_counts() + counts = self.outcome_df[self.vars["LABEL"]].value_counts() return list((1 / counts) * np.sum(counts) / counts.shape[0]) - def get_data_and_labels(self) -> tuple[np.array, np.array]: + def get_data_and_labels(self) -> Tuple[np.array, np.array]: """Function to return all the data and labels aligned at once. We use this function for the ML methods which don't require an iterator. Returns: - A tuple containing data points and label for the split. + A Tuple containing data points and label for the split. """ - labels = self.outcome_df["label"].to_numpy().astype(float) + labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float) rep = self.features_df if len(labels) == self.num_stays: # order of groups could be random, we make sure not to change it rep = rep.groupby(level=self.vars["GROUP"], sort=False).last() - rep = rep.to_numpy() + rep = rep.to_numpy().astype(float) return rep, labels + + def to_tensor(self): + data, labels = self.get_data_and_labels() + return from_numpy(data), from_numpy(labels) + + +@gin.configurable("ImputationDataset") +class ImputationDataset(CommonDataset): + """Subclass of Common Dataset that contains data for imputation models.""" + + def __init__( + self, + data: Dict[str, DataFrame], + split: str = Split.train, + vars: Dict[str, str] = gin.REQUIRED, + mask_proportion=0.3, + mask_method="MCAR", + mask_observation_proportion=0.3, + ram_cache: bool = True, + ): + """ + Args: + data (Dict[str, DataFrame]): data to use + split (str, optional): split to apply. Defaults to Split.train. + vars (Dict[str, str], optional): contains names of columns in the data. Defaults to gin.REQUIRED. + mask_proportion (float, optional): proportion to artificially mask for amputation. Defaults to 0.3. + mask_method (str, optional): masking mechanism. Defaults to "MCAR". + mask_observation_proportion (float, optional): poportion of the observed data to be masked. Defaults to 0.3. 
+ ram_cache (bool, optional): if the dataset should be completely stored in ram and not generated on the fly during + training. Defaults to True. + """ + super().__init__(data, split, vars, grouping_segment=Segment.static) + self.amputated_values, self.amputation_mask = ampute_data( + self.features_df, mask_method, mask_proportion, mask_observation_proportion + ) + self.amputation_mask = (self.amputation_mask + self.features_df.isna().values).bool() + self.amputation_mask = DataFrame(self.amputation_mask, columns=self.vars[Segment.dynamic]) + self.amputation_mask[self.vars["GROUP"]] = self.features_df.index + self.amputation_mask.set_index(self.vars["GROUP"], inplace=True) + + self.target_missingness_mask = self.features_df.isna() + self.features_df.fillna(0, inplace=True) + self.ram_cache(ram_cache) + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]: + """Function to sample from the data split of choice. + + Used for deep learning implementations. + + Args: + idx: A specific row index to sample. + + Returns: + A sample from the data, consisting of data, labels and padding mask. + """ + if self._cached_dataset is not None: + return self._cached_dataset[idx] + stay_id = self.grouping_df.iloc[idx].name + + # slice to make sure to always return a DF + window = self.features_df.loc[stay_id:stay_id, self.vars[Segment.dynamic]] + window_missingness_mask = self.target_missingness_mask.loc[stay_id:stay_id, self.vars[Segment.dynamic]] + amputated_window = self.amputated_values.loc[stay_id:stay_id, self.vars[Segment.dynamic]] + amputation_mask = self.amputation_mask.loc[stay_id:stay_id, self.vars[Segment.dynamic]] + + return ( + from_numpy(amputated_window.values).to(float32), + from_numpy(amputation_mask.values).to(float32), + from_numpy(window.values).to(float32), + from_numpy(window_missingness_mask.values).to(float32), + ) + + +@gin.configurable("ImputationPredictionDataset") +class ImputationPredictionDataset(Dataset): + """Subclass of torch dataset that represents data with missingness for imputation. + + Args: + data (DataFrame): dict of the different splits of the data + grouping_column (str, optional): column that is used for grouping. Defaults to "stay_id". + select_columns (List[str], optional): the columns to serve as input for the imputation model. Defaults to None. + ram_cache (bool, optional): wether the dataset should be stored in ram. Defaults to True. + """ + + def __init__( + self, + data: DataFrame, + grouping_column: str = "stay_id", + select_columns: List[str] = None, + ram_cache: bool = True, + ): + self.dyn_df = data + + if select_columns is not None: + self.dyn_df = self.dyn_df[list(select_columns) + grouping_column] + + if grouping_column is not None: + self.dyn_df = self.dyn_df.set_index(grouping_column) + else: + self.dyn_df = data + + # calculate basic info for the data + self.group_indices = self.dyn_df.index.unique() + self.maxlen = self.dyn_df.groupby(grouping_column).size().max() + + self._cached_dataset = None + if ram_cache: + logging.info("Caching dataset in ram.") + self._cached_dataset = [self[i] for i in range(len(self))] + + def __len__(self) -> int: + """Returns number of stays in the data. + + Returns: + number of stays in the data + """ + return self.group_indices.shape[0] + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]: + """Function to sample from the data split of choice. + + Used for deep learning implementations. + + Args: + idx: A specific row index to sample. 
+ + Returns: + A sample from the data, consisting of data, labels and padding mask. + """ + if self._cached_dataset is not None: + return self._cached_dataset[idx] + stay_id = self.group_indices[idx] + + # slice to make sure to always return a DF + window = self.dyn_df.loc[stay_id:stay_id, :] + + return from_numpy(window.values).to(float32) diff --git a/icu_benchmarks/data/preprocessor.py b/icu_benchmarks/data/preprocessor.py index 68b7a66a..e61ff06d 100644 --- a/icu_benchmarks/data/preprocessor.py +++ b/icu_benchmarks/data/preprocessor.py @@ -1,20 +1,21 @@ +import torch import logging import gin import pandas as pd from recipys.recipe import Recipe -from recipys.selector import all_numeric_predictors, has_type, all_of -from recipys.step import StepScale, StepImputeFill, StepSklearn, StepHistorical, Accumulator +from recipys.selector import all_numeric_predictors, all_outcomes, has_type, all_of +from recipys.step import StepScale, StepImputeFill, StepSklearn, StepHistorical, Accumulator, StepImputeModel from sklearn.impute import SimpleImputer, MissingIndicator -from sklearn.preprocessing import LabelEncoder +from sklearn.preprocessing import LabelEncoder, FunctionTransformer, MinMaxScaler + +from icu_benchmarks.wandb_utils import update_wandb_config +from icu_benchmarks.data.loader import ImputationPredictionDataset from .constants import DataSplit as Split, DataSegment as Segment import abc class Preprocessor: - def __init__(self): - pass - @abc.abstractmethod def apply(self, data, vars): return data @@ -23,16 +24,15 @@ def apply(self, data, vars): def to_cache_string(self): return f"{self.__class__.__name__}" - @abc.abstractmethod - def calculate_input_dim(self): - pass + def set_imputation_model(self, imputation_model): + self.imputation_model = imputation_model + if self.imputation_model is not None: + update_wandb_config({"imputation_model": self.imputation_model.__class__.__name__}) -@gin.configurable("default_preprocessor") -class DefaultPreprocessor(Preprocessor): - def __init__( - self, generate_features: bool = True, scaling: bool = True, use_static_features: bool = True, vars: dict = None - ): +@gin.configurable("base_classification_preprocessor") +class DefaultClassificationPreprocessor(Preprocessor): + def __init__(self, generate_features: bool = True, scaling: bool = True, use_static_features: bool = True): """ Args: generate_features: Generate features for dynamic data. @@ -44,9 +44,9 @@ def __init__( self.generate_features = generate_features self.scaling = scaling self.use_static_features = use_static_features - self.vars = vars + self.imputation_model = None - def apply(self, data, vars): + def apply(self, data, vars) -> dict[dict[pd.DataFrame]]: """ Args: data: Train, validation and test data dictionary. Further divided in static, dynamic, and outcome. @@ -55,10 +55,10 @@ def apply(self, data, vars): Preprocessed data. 
""" logging.info("Preprocessing dynamic features.") - data = self.process_dynamic(data, vars) + data = self._process_dynamic(data, vars) if self.use_static_features: logging.info("Preprocessing static features.") - data = self.process_static(data, vars) + data = self._process_static(data, vars) # Set index to grouping variable data[Split.train][Segment.static] = data[Split.train][Segment.static].set_index(vars["GROUP"]) @@ -87,7 +87,7 @@ def apply(self, data, vars): data[Split.test][Segment.features] = data[Split.test].pop(Segment.dynamic) return data - def process_static(self, data, vars): + def _process_static(self, data, vars): sta_rec = Recipe(data[Split.train][Segment.static], [], vars[Segment.static]) if self.scaling: sta_rec.add_step(StepScale()) @@ -96,23 +96,41 @@ def process_static(self, data, vars): sta_rec.add_step(StepSklearn(SimpleImputer(missing_values=None, strategy="most_frequent"), sel=has_type("object"))) sta_rec.add_step(StepSklearn(LabelEncoder(), sel=has_type("object"), columnwise=True)) - data = self.apply_recipe_to_Splits(sta_rec, data, Segment.static) + data = apply_recipe_to_splits(sta_rec, data, Segment.static) + + return data + def _model_impute(self, data, group=None): + dataset = ImputationPredictionDataset(data, group, self.imputation_model.trained_columns) + input_data = torch.cat([data_point.unsqueeze(0) for data_point in dataset], dim=0) + self.imputation_model.eval() + with torch.no_grad(): + logging.info(f"Imputing with {self.imputation_model.__class__.__name__}.") + imputation = self.imputation_model.predict(input_data) + logging.info("Imputation done.") + assert imputation.isnan().sum() == 0 + data = data.copy() + data.loc[:, self.imputation_model.trained_columns] = imputation.flatten(end_dim=1).to("cpu") + if group is not None: + data.drop(columns=group, inplace=True) return data - def process_dynamic(self, data, vars): + def _process_dynamic(self, data, vars): dyn_rec = Recipe(data[Split.train][Segment.dynamic], [], vars[Segment.dynamic], vars["GROUP"], vars["SEQUENCE"]) if self.scaling: dyn_rec.add_step(StepScale()) + if self.imputation_model is not None: + dyn_rec.add_step(StepImputeModel(model=self.model_impute, sel=all_of(vars[Segment.dynamic]))) dyn_rec.add_step(StepSklearn(MissingIndicator(), sel=all_of(vars[Segment.dynamic]), in_place=False)) dyn_rec.add_step(StepImputeFill(method="ffill")) dyn_rec.add_step(StepImputeFill(value=0)) if self.generate_features: - dyn_rec = self.dynamic_feature_generation(dyn_rec, all_of(vars[Segment.dynamic])) - data = self.apply_recipe_to_Splits(dyn_rec, data, Segment.dynamic) + dyn_rec = self._dynamic_feature_generation(dyn_rec, all_of(vars[Segment.dynamic])) + data = apply_recipe_to_splits(dyn_rec, data, Segment.dynamic) return data - def dynamic_feature_generation(self, data, dynamic_vars): + def _dynamic_feature_generation(self, data, dynamic_vars): + logging.debug("Adding dynamic feature generation.") data.add_step(StepHistorical(sel=dynamic_vars, fun=Accumulator.MIN, suffix="min_hist")) data.add_step(StepHistorical(sel=dynamic_vars, fun=Accumulator.MAX, suffix="max_hist")) data.add_step(StepHistorical(sel=dynamic_vars, fun=Accumulator.COUNT, suffix="count_hist")) @@ -120,33 +138,142 @@ def dynamic_feature_generation(self, data, dynamic_vars): return data def to_cache_string(self): - return super().to_cache_string() + f"_{self.generate_features}_{self.scaling}" + return ( + super().to_cache_string() + + f"_classification_{self.generate_features}_{self.scaling}_{self.imputation_model.__class__.__name__}" + ) 
- def calculate_input_dim(self): - if self.generate_features: - len_dynamic = len(self.vars[Segment.dynamic]) * 6 - else: - len_dynamic = len(self.vars[Segment.dynamic]) * 2 - if self.use_static_features: - len_static = len(self.vars[Segment.static]) + +@gin.configurable("base_regression_preprocessor") +class DefaultRegressionPreprocessor(DefaultClassificationPreprocessor): + # Override base classification preprocessor + def __init__( + self, + generate_features: bool = True, + scaling: bool = True, + use_static_features: bool = True, + outcome_max=None, + outcome_min=None, + ): + """ + Args: + generate_features: Generate features for dynamic data. + scaling: Scaling of dynamic and static data. + use_static_features: Use static features. + max_range: Maximum value in outcome. + min_range: Minimum value in outcome. + Returns: + Preprocessed data. + """ + super().__init__(generate_features, scaling, use_static_features) + self.outcome_max = outcome_max + self.outcome_min = outcome_min + + def apply(self, data, vars): + """ + Args: + data: Train, validation and test data dictionary. Further divided in static, dynamic, and outcome. + vars: Variables for static, dynamic, outcome. + Returns: + Preprocessed data. + """ + for split in [Split.train, Split.val, Split.test]: + data = self._process_outcome(data, vars, split) + + data = super().apply(data, vars) + return data + + def _process_outcome(self, data, vars, split): + logging.debug(f"Processing {split} outcome values.") + outcome_rec = Recipe(data[split][Segment.outcome], vars["LABEL"], [], vars["GROUP"]) + # If the range is predefined, use predefined transformation function + if self.outcome_max is not None and self.outcome_min is not None: + outcome_rec.add_step( + StepSklearn( + sklearn_transformer=FunctionTransformer( + func=lambda x: ((x + abs(self.outcome_min)) / (abs(self.outcome_min) + self.outcome_max)) + ), + sel=all_outcomes(), + ) + ) else: - len_static = 0 + # If the range is not predefined, use MinMaxScaler + outcome_rec.add_step(StepSklearn(MinMaxScaler(), sel=all_outcomes())) + outcome_rec.prep() + data[split][Segment.outcome] = outcome_rec.bake() + return data - return len_dynamic + len_static - @staticmethod - def apply_recipe_to_Splits(recipe: Recipe, data: dict[dict[pd.DataFrame]], type: str) -> dict[dict[pd.DataFrame]]: - """Fits and transforms the training features, then transforms the validation and test features with the recipe. +@gin.configurable("base_imputation_preprocessor") +class DefaultImputationPreprocessor(Preprocessor): + def __init__( + self, + scaling: bool = True, + use_static_features: bool = True, + filter_missing_values: bool = True, + ): + """Preprocesses data for imputation. Args: - recipe: Object containing info about the features and steps. - data: Dict containing 'train', 'val', and 'test' and types of features per Split. - type: Whether to apply recipe to dynamic features, static features or outcomes. + scaling (bool, optional): If the values in each column should be normalized. Defaults to True. + use_static_features (bool, optional): If static features should be included in the dataset. Defaults to True. + """ + self.scaling = scaling + self.use_static_features = use_static_features + self.filter_missing_values = filter_missing_values + def apply(self, data, vars): + """ + Args: + data: Train, validation and test data dictionary. Further divided in static, dynamic, and outcome. + vars: Variables for static, dynamic, outcome. 
        Returns:
-            Transformed features divided into 'train', 'val', and 'test'.
+            Preprocessed data.
        """
-        data[Split.train][type] = recipe.prep()
-        data[Split.val][type] = recipe.bake(data[Split.val][type])
-        data[Split.test][type] = recipe.bake(data[Split.test][type])
+        logging.info("Preprocessing dynamic features.")
+        data = {step: self._process_dynamic_data(data[step], vars) for step in data}
+
+        dyn_rec = Recipe(data[Split.train][Segment.dynamic], [], vars[Segment.dynamic], vars["GROUP"], vars["SEQUENCE"])
+        if self.scaling:
+            dyn_rec.add_step(StepScale())
+        data = apply_recipe_to_splits(dyn_rec, data, Segment.dynamic)
+
+        data[Split.train][Segment.features] = (
+            data[Split.train].pop(Segment.dynamic).loc[:, vars[Segment.dynamic] + [vars["GROUP"], vars["SEQUENCE"]]]
+        )
+        data[Split.val][Segment.features] = (
+            data[Split.val].pop(Segment.dynamic).loc[:, vars[Segment.dynamic] + [vars["GROUP"], vars["SEQUENCE"]]]
+        )
+        data[Split.test][Segment.features] = (
+            data[Split.test].pop(Segment.dynamic).loc[:, vars[Segment.dynamic] + [vars["GROUP"], vars["SEQUENCE"]]]
+        )
        return data
+
+    def to_cache_string(self):
+        return super().to_cache_string() + f"_imputation_{self.use_static_features}_{self.scaling}"
+
+    def _process_dynamic_data(self, data, vars):
+        if self.filter_missing_values:
+            rows_to_remove = data[Segment.dynamic][vars[Segment.dynamic]].isna().sum(axis=1) != 0
+            ids_to_remove = data[Segment.dynamic].loc[rows_to_remove][vars["GROUP"]].unique()
+            data = {table_name: table.loc[~table[vars["GROUP"]].isin(ids_to_remove)] for table_name, table in data.items()}
+            logging.info(f"Removed {len(ids_to_remove)} stays with missing values.")
+        return data
+
+
+def apply_recipe_to_splits(recipe: Recipe, data: dict[dict[pd.DataFrame]], type: str) -> dict[dict[pd.DataFrame]]:
+    """Fits and transforms the training features, then transforms the validation and test features with the recipe.
+
+    Args:
+        recipe: Object containing info about the features and steps.
+        data: Dict containing 'train', 'val', and 'test' and types of features per split.
+        type: Whether to apply recipe to dynamic features, static features or outcomes.
+
+    Returns:
+        Transformed features divided into 'train', 'val', and 'test'.
+    """
+    data[Split.train][type] = recipe.prep()
+    data[Split.val][type] = recipe.bake(data[Split.val][type])
+    data[Split.test][type] = recipe.bake(data[Split.test][type])
+    return data
diff --git a/icu_benchmarks/data/preprocess.py b/icu_benchmarks/data/split_process_data.py
similarity index 61%
rename from icu_benchmarks/data/preprocess.py
rename to icu_benchmarks/data/split_process_data.py
index b790b040..08db9d75 100644
--- a/icu_benchmarks/data/preprocess.py
+++ b/icu_benchmarks/data/split_process_data.py
@@ -6,74 +6,19 @@
import pyarrow.parquet as pq
from pathlib import Path
import pickle
-from sklearn.model_selection import StratifiedKFold
-from icu_benchmarks.data.preprocessor import Preprocessor, DefaultPreprocessor
-from .constants import DataSplit as Split, DataSegment as Segment
+from sklearn.model_selection import StratifiedKFold, KFold
-
-def make_single_split(
-    data: dict[pd.DataFrame],
-    vars: dict[str],
-    cv_repetitions: int,
-    repetition_index: int,
-    cv_folds: int,
-    fold_index: int,
-    seed: int = 42,
-    debug: bool = False,
-) -> dict[dict[pd.DataFrame]]:
-    """Randomly split the data into training, validation, and test set.
-
-    Args:
-        data: dictionary containing data divided int OUTCOME, STATIC, and DYNAMIC.
-        vars: Contains the names of columns in the data.
- cv_repetitions: Number of times to repeat cross validation. - repetition_index: Index of the repetition to return. - cv_folds: Number of folds for cross validation. - fold_index: Index of the fold to return. - seed: Random seed. - debug: Load less data if true. - - Returns: - Input data divided into 'train', 'val', and 'test'. - """ - id = vars["GROUP"] - stays = data[Segment.outcome][id] - if debug: - # Only use 1% of the data - stays = stays.sample(frac=0.01, random_state=seed) - labels = data[Segment.outcome][vars["LABEL"]].loc[stays.index] - - outer_CV = StratifiedKFold(cv_repetitions, shuffle=True, random_state=seed) - inner_CV = StratifiedKFold(cv_folds, shuffle=True, random_state=seed) - - dev, test = list(outer_CV.split(stays, labels))[repetition_index] - dev_stays = stays.iloc[dev] - dev_labels = labels.iloc[dev] - train, val = list(inner_CV.split(dev_stays, dev_labels))[fold_index] - - split = { - Split.train: dev_stays.iloc[train], - Split.val: dev_stays.iloc[val], - Split.test: stays.iloc[test], - } - data_split = {} - - for fold in split.keys(): # Loop through splits (train / val / test) - # Loop through segments (DYNAMIC / STATIC / OUTCOME) - # set sort to true to make sure that IDs are reordered after scrambling earlier - data_split[fold] = { - data_type: data[data_type].merge(split[fold], on=id, how="right", sort=True) for data_type in data.keys() - } - - return data_split +from icu_benchmarks.data.preprocessor import Preprocessor, DefaultClassificationPreprocessor +from icu_benchmarks.contants import RunMode +from .constants import DataSplit as Split, DataSegment as Segment, VarType as Var @gin.configurable("preprocess") def preprocess_data( data_dir: Path, file_names: dict[str] = gin.REQUIRED, - preprocessor: Preprocessor = DefaultPreprocessor, + preprocessor: Preprocessor = DefaultClassificationPreprocessor, use_static: bool = True, vars: dict[str] = gin.REQUIRED, seed: int = 42, @@ -84,6 +29,8 @@ def preprocess_data( load_cache: bool = False, generate_cache: bool = False, fold_index: int = 0, + pretrained_imputation_model: str = None, + runmode: RunMode = RunMode.classification, ) -> dict[dict[pd.DataFrame]]: """Perform loading, splitting, imputing and normalising of task data. @@ -100,6 +47,7 @@ def preprocess_data( load_cache: Use cached preprocessed data if true. generate_cache: Generate cached preprocessed data if true. fold_index: Index of the fold to return. + pretrained_imputation_model: pretrained imputation model to use. if None, standard imputation is used. 
    Returns:
        Preprocessed data as a DataFrame in a hierarchical dict with feature types (STATIC / DYNAMIC / OUTCOME)
@@ -115,12 +63,13 @@ def preprocess_data(
    dumped_file_names = json.dumps(file_names, sort_keys=True)
    dumped_vars = json.dumps(vars, sort_keys=True)

-    if preprocessor is not DefaultPreprocessor:
-        logging.log(logging.INFO, "Using user-supplied preprocessor.")
+    logging.log(logging.INFO, f"Using preprocessor: {preprocessor.__name__}")
    preprocessor = preprocessor(use_static_features=use_static)
+    if isinstance(preprocessor, DefaultClassificationPreprocessor):
+        preprocessor.set_imputation_model(pretrained_imputation_model)

    hash_config = f"{preprocessor.to_cache_string()}{dumped_file_names}{dumped_vars}{debug}".encode("utf-8")
-    cache_filename = f"s_{seed}_r_{repetition_index}_f_{fold_index}_{hashlib.md5(hash_config).hexdigest()}"
+    cache_filename = f"s_{seed}_r_{repetition_index}_f_{fold_index}_d_{debug}_{hashlib.md5(hash_config).hexdigest()}"
    cache_file = cache_dir / cache_filename

    if load_cache:
@@ -131,27 +80,110 @@
    else:
        logging.info(f"No cached data found in {cache_file}, loading raw features.")

-    data = {f: pq.read_table(data_dir / file_names[f]).to_pandas() for f in file_names.keys()}
+    # Read parquet files into pandas dataframes; self_destruct frees the Arrow table memory during conversion
+    logging.info(f"Loading data from directory {data_dir.absolute()}")
+    data = {f: pq.read_table(data_dir / file_names[f]).to_pandas(self_destruct=True) for f in file_names.keys()}

+    # Generate the splits
    logging.info("Generating splits.")
-    data = make_single_split(data, vars, cv_repetitions, repetition_index, cv_folds, fold_index, seed=seed, debug=debug)
+    data = make_single_split(
+        data, vars, cv_repetitions, repetition_index, cv_folds, fold_index, seed=seed, debug=debug, runmode=runmode
+    )

+    # Apply preprocessing
    data = preprocessor.apply(data, vars)

+    # Generate cache
    if generate_cache:
        caching(cache_dir, cache_file, data, load_cache)
    else:
        logging.info("Cache will not be saved.")

-    gin.bind_parameter("model/hyperparameter.input_dim", data[Split.train][Segment.features].shape[1] - 2)
-
    logging.info("Finished preprocessing.")

    return data


-def caching(cache_dir, cache_file, data, use_cache):
-    if use_cache and not cache_file.exists():
+def make_single_split(
+    data: dict[pd.DataFrame],
+    vars: dict[str],
+    cv_repetitions: int,
+    repetition_index: int,
+    cv_folds: int,
+    fold_index: int,
+    seed: int = 42,
+    debug: bool = False,
+    runmode: RunMode = RunMode.classification,
+) -> dict[dict[pd.DataFrame]]:
+    """Randomly split the data into training, validation, and test set.
+
+    Args:
+        runmode: Run mode. Can be one of the values of RunMode.
+        data: dictionary containing data divided into OUTCOME, STATIC, and DYNAMIC.
+        vars: Contains the names of columns in the data.
+        cv_repetitions: Number of times to repeat cross validation.
+        repetition_index: Index of the repetition to return.
+        cv_folds: Number of folds for cross validation.
+        fold_index: Index of the fold to return.
+        seed: Random seed.
+        debug: Load less data if true.
+
+    Returns:
+        Input data divided into 'train', 'val', and 'test'.
+ """ + # ID variable + id = vars[Var.group] + + # Get stay IDs from outcome segment + stays = pd.Series(data[Segment.outcome][id].unique(), name=id) + + if debug: + # Only use 1% of the data + stays = stays.sample(frac=0.01, random_state=seed) + + # If there are labels, and the task is classification, use stratified k-fold + if Var.label in vars and runmode is RunMode.classification: + # Get labels from outcome data (takes the highest value (or True) in case seq2seq classification) + labels = data[Segment.outcome].groupby(id).max()[vars[Var.label]].reset_index(drop=True) + if labels.value_counts().min() < cv_folds: + raise Exception( + f"The smallest amount of samples in a class is: {labels.value_counts().min()}, " + f"but {cv_folds} folds are requested. Reduce the number of folds or use more data." + ) + outer_cv = StratifiedKFold(cv_repetitions, shuffle=True, random_state=seed) + inner_cv = StratifiedKFold(cv_folds, shuffle=True, random_state=seed) + + dev, test = list(outer_cv.split(stays, labels))[repetition_index] + dev_stays = stays.iloc[dev] + train, val = list(inner_cv.split(dev_stays, labels.iloc[dev]))[fold_index] + else: + # If there are no labels, or the task is regression, use regular k-fold. + outer_cv = KFold(cv_repetitions, shuffle=True, random_state=seed) + inner_cv = KFold(cv_folds, shuffle=True, random_state=seed) + + dev, test = list(outer_cv.split(stays))[repetition_index] + dev_stays = stays.iloc[dev] + train, val = list(inner_cv.split(dev_stays))[fold_index] + + split = { + Split.train: dev_stays.iloc[train], + Split.val: dev_stays.iloc[val], + Split.test: stays.iloc[test], + } + data_split = {} + + for fold in split.keys(): # Loop through splits (train / val / test) + # Loop through segments (DYNAMIC / STATIC / OUTCOME) + # set sort to true to make sure that IDs are reordered after scrambling earlier + data_split[fold] = { + data_type: data[data_type].merge(split[fold], on=id, how="right", sort=True) for data_type in data.keys() + } + + return data_split + + +def caching(cache_dir, cache_file, data, use_cache, overwrite=True): + if use_cache and (not overwrite or not cache_file.exists()): if not cache_dir.exists(): cache_dir.mkdir() cache_file.touch() diff --git a/icu_benchmarks/imputation/__init__.py b/icu_benchmarks/imputation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/icu_benchmarks/imputation/amputations.py b/icu_benchmarks/imputation/amputations.py new file mode 100644 index 00000000..704bef5d --- /dev/null +++ b/icu_benchmarks/imputation/amputations.py @@ -0,0 +1,234 @@ +"""This file implements amputation mechanisms (MCAR, MAR (logisitc) and MNAR (logistic)) for missing data generation. +It was inspired from: https://rmisstastic.netlify.app/how-to/python/generate_html/how%20to%20generate%20missing%20values +Original code: https://github.com/BorisMuzellec/MissingDataOT/blob/master/utils.py +""" + +import gin +import torch +import logging +import numpy as np +from scipy import optimize + + +def MCAR_mask(X, p): + """ + Missing completely at random mechanism. + + Parameters + ---------- + X : torch.FloatTensor, shape (n, d) + Data for which missing values will be simulated. + p : float + Proportion of missing values to generate for variables which will have missing values. + + Returns + ------- + mask : torch.BoolTensor + Mask of generated missing values (True if the value is missing). 
+ """ + + n, d = X.shape + mask = np.zeros((n, d)) + + ber = torch.rand(n, d) + mask = ber < p + + return mask + + +def BO_mask(X, p): + """ + Black out missing mechanism. Removes values across dimensions. + + Parameters + ---------- + X : torch.FloatTensor, shape (n, d) + Data for which missing values will be simulated. + p : float + Proportion of missing values to generate for variables which will have missing values. + + Returns + ------- + mask : torch.BoolTensor + Mask of generated missing values (True if the value is missing). + """ + + n, d = X.shape + + indices = torch.randperm(n)[: int(n * p)] + mask = torch.zeros(n, d).bool() + mask[indices, :] = True + + return mask + + +def MAR_logistic_mask(X, p, p_obs): + """ + Missing at random mechanism with a logistic masking model. First, a subset of variables with *no* missing values is + randomly selected. The remaining variables have missing values according to a logistic model with random weights, + re-scaled so as to attain the desired proportion of missing values on those variables. + + Parameters + ---------- + X : torch.FloatTensor, shape (n, d) + Data for which missing values will be simulated. + p : float + Proportion of missing values to generate for variables which will have missing values. + p_obs : float + Proportion of variables with *no* missing values that will be used for the logistic masking model. + + Returns + ------- + mask : torch.BoolTensor + Mask of generated missing values (True if the value is missing). + """ + + n, d = X.shape + mask = torch.zeros(n, d).bool() + + # number of variables that will have no missing values (at least one variable) + d_obs = max(int(p_obs * d), 1) + # number of variables that will have missing values + d_na = d - d_obs + + # Sample variables that will all be observed, and those with missing values + idxs_obs = np.random.choice(d, d_obs, replace=False) + idxs_nas = np.array([i for i in range(d) if i not in idxs_obs]) + + # Other variables will have NA proportions that depend on those observed variables, through a logistic model + # The parameters of this logistic model are random + + # Pick coefficients so that W^Tx has unit variance (avoids shrinking) + coeffs = pick_coeffs(X, idxs_obs, idxs_nas) + # Pick the intercepts to have a desired amount of missing values + intercepts = fit_intercepts(X[:, idxs_obs], coeffs, p) + + ps = torch.sigmoid(X[:, idxs_obs].mm(coeffs) + intercepts) + + ber = torch.rand(n, d_na) + mask[:, idxs_nas] = ber < ps + + return mask + + +def MNAR_logistic_mask(X, p, p_params=0.3, exclude_inputs=True): + """ + Missing not at random mechanism with a logistic masking model. It implements two mechanisms: + (i) Missing probabilities are selected with a logistic model, taking all variables as inputs. Hence, values that are + inputs can also be missing. + (ii) Variables are split into a set of intputs for a logistic model, and a set whose missing probabilities are + determined by the logistic model. Then inputs are then masked MCAR (hence, missing values from the second set will + depend on masked values. + In either case, weights are random and the intercept is selected to attain the desired proportion of missing values. + + Parameters + ---------- + X : torch.FloatTensor, shape (n, d) + Data for which missing values will be simulated. + p : float + Proportion of missing values to generate for variables which will have missing values. + p_params : float + Proportion of variables that will be used for the logistic masking model (only if exclude_inputs). 
+    exclude_inputs : boolean, default=True
+        True: mechanism (ii) is used, False: (i)
+
+    Returns
+    -------
+    mask : torch.BoolTensor
+        Mask of generated missing values (True if the value is missing).
+    """
+
+    n, d = X.shape
+    mask = torch.zeros(n, d).bool()
+
+    # number of variables used as inputs (at least 1)
+    d_params = max(int(p_params * d), 1) if exclude_inputs else d
+    # number of variables masked with the logistic model
+    d_na = d - d_params if exclude_inputs else d
+
+    # Sample variables that will be parameters for the logistic regression:
+    idxs_params = np.random.choice(d, d_params, replace=False) if exclude_inputs else np.arange(d)
+    idxs_nas = np.array([i for i in range(d) if i not in idxs_params]) if exclude_inputs else np.arange(d)
+
+    # Other variables will have NA proportions selected by a logistic model
+    # The parameters of this logistic model are random.
+
+    # Pick coefficients so that W^Tx has unit variance (avoids shrinking)
+    coeffs = pick_coeffs(X, idxs_params, idxs_nas)
+    # Pick the intercepts to have a desired amount of missing values
+    intercepts = fit_intercepts(X[:, idxs_params], coeffs, p)
+
+    ps = torch.sigmoid(X[:, idxs_params].mm(coeffs) + intercepts)
+
+    ber = torch.rand(n, d_na)
+    mask[:, idxs_nas] = ber < ps
+
+    # If the inputs of the logistic model are excluded from MNAR missingness, mask some
+    # values used in the logistic model at random
+    # This makes the missingness of other variables potentially dependent on masked values
+
+    if exclude_inputs:
+        mask[:, idxs_params] = torch.rand(n, d_params) < p
+
+    return mask
+
+
+def pick_coeffs(X, idxs_obs=None, idxs_nas=None):
+    d_obs = len(idxs_obs)
+    d_na = len(idxs_nas)
+    coeffs = torch.randn(d_obs, d_na)
+    Wx = X[:, idxs_obs].mm(coeffs)
+    coeffs /= torch.std(Wx, 0, keepdim=True)
+    return coeffs
+
+
+def fit_intercepts(X, coeffs, p):
+    d_obs, d_na = coeffs.shape
+    intercepts = torch.zeros(d_na)
+    for j in range(d_na):
+
+        def f(x):
+            return torch.sigmoid(X.mv(coeffs[:, j]) + x).mean().item() - p
+
+        intercepts[j] = optimize.bisect(f, -50, 50)
+    return intercepts
+
+
+@gin.configurable("amputation")
+def ampute_data(data, mechanism, p_miss, p_obs=0.3):
+    """
+    Generate missing values for a specific missing-data mechanism and proportion of missing values.
+
+    Parameters
+    ----------
+    data : DataFrame
+        Data for which missing values will be simulated.
+    mechanism : str
+        Indicates the missing-data mechanism to be used. ("MCAR", "MAR", "MNAR" or "BO")
+    p_miss : float
+        Proportion of missing values to generate for variables which will have missing values.
+    p_obs : float
+        If mechanism = "MAR" or "MNAR", proportion of variables with *no* missing values
+        that will be used for the logistic masking model.
+
+    Returns
+    -------
+    amputed_data : DataFrame
+        The data with the generated missing values.
+    mask : torch.BoolTensor
+        The corresponding missingness mask.
+    """
+    logging.info(f"Applying {mechanism} amputation.")
+    X = torch.tensor(data.values.astype(np.float32))
+
+    if mechanism == "MAR":
+        mask = MAR_logistic_mask(X, p_miss, p_obs)
+    elif mechanism == "MNAR":
+        mask = MNAR_logistic_mask(X, p_miss, p_obs)
+    elif mechanism == "MCAR":
+        mask = MCAR_mask(X, p_miss)
+    elif mechanism == "BO":
+        mask = BO_mask(X, p_miss)
+    else:
+        raise ValueError("Not a valid amputation mechanism. 
Missing-data mechanisms to be used are MCAR, MAR or MNAR.") + + amputed_data = data.mask(mask) + return amputed_data, mask diff --git a/icu_benchmarks/imputation/baselines.py b/icu_benchmarks/imputation/baselines.py new file mode 100644 index 00000000..27ddf307 --- /dev/null +++ b/icu_benchmarks/imputation/baselines.py @@ -0,0 +1,330 @@ +"""Baseline imputation methods. These methods imported from other frameworks and are used as baselines for comparison.""" +import torch +from hyperimpute.plugins.imputers import Imputers as HyperImpute +from sklearn.experimental import enable_iterative_imputer # noqa: F401 +from sklearn.impute import KNNImputer, SimpleImputer, IterativeImputer +from sklearn.linear_model import LinearRegression +from typing import Type + +from icu_benchmarks.models.wrappers import ImputationWrapper +from pypots.imputation import BRITS, SAITS, Transformer +import gin + + +@gin.configurable("KNN") +class KNNImputation(ImputationWrapper): + """Imputation using Scikit-Learn K-Nearest Neighbour.""" + + needs_training = False + needs_fit = True + + def __init__(self, *args, n_neighbors=2, **kwargs) -> None: + super().__init__(*args, n_neighbors=n_neighbors, **kwargs) + self.imputer = KNNImputer(n_neighbors=n_neighbors) + + def fit(self, train_dataset, val_dataset): + self.imputer.fit(train_dataset.amputated_values.values) + + def forward(self, amputated_values, amputation_mask): + debatched_values = amputated_values.reshape((-1, amputated_values.shape[-1])) + debatched_values = debatched_values.to("cpu") + output = torch.Tensor(self.imputer.transform(debatched_values)).to(amputated_values.device) + + output = output.reshape(amputated_values.shape) + return output + + +@gin.configurable("MICE") +class MICEImputation(ImputationWrapper): + """Imputation using Scikit-Learn MICE.""" + + needs_training = False + needs_fit = True + + def __init__(self, *args, max_iter=100, verbose=2, imputation_order="random", random_state=0, **kwargs) -> None: + super().__init__( + *args, max_iter=max_iter, verbose=verbose, imputation_order=imputation_order, random_state=random_state, **kwargs + ) + self.imputer = IterativeImputer( + estimator=LinearRegression(), + max_iter=max_iter, + verbose=verbose, + imputation_order=imputation_order, + random_state=random_state, + ) + + def fit(self, train_dataset, val_dataset): + self.imputer.fit(train_dataset.amputated_values.values) + + def forward(self, amputated_values, amputation_mask): + debatched_values = amputated_values.reshape((-1, amputated_values.shape[-1])) + debatched_values = debatched_values.to("cpu") + output = torch.Tensor(self.imputer.transform(debatched_values)).to(amputated_values.device) + + output = output.reshape(amputated_values.shape) + return output + + +@gin.configurable("Mean") +class MeanImputation(ImputationWrapper): + """Mean imputation using Scikit-Learn SimpleImputer.""" + + needs_training = False + needs_fit = True + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.imputer = SimpleImputer(strategy="mean") + + def fit(self, train_dataset, val_dataset): + self.imputer.fit(train_dataset.amputated_values.values) + + def forward(self, amputated_values, amputation_mask): + debatched_values = amputated_values.reshape((-1, amputated_values.shape[-1])) + debatched_values = debatched_values.to("cpu") + output = torch.Tensor(self.imputer.transform(debatched_values)).to(amputated_values.device) + + output = output.reshape(amputated_values.shape) + return output + + +@gin.configurable("Median") +class 
MedianImputation(ImputationWrapper): + needs_training = False + needs_fit = True + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.imputer = SimpleImputer(strategy="median") + + def fit(self, train_dataset, val_dataset): + self.imputer.fit(train_dataset.amputated_values.values) + + def forward(self, amputated_values, amputation_mask): + debatched_values = amputated_values.reshape((-1, amputated_values.shape[-1])) + debatched_values = debatched_values.to("cpu") + output = torch.Tensor(self.imputer.transform(debatched_values)).to(amputated_values.device) + + output = output.reshape(amputated_values.shape) + return output + + +@gin.configurable("Zero") +class ZeroImputation(ImputationWrapper): + """Zero imputation using Scikit-Learn SimpleImputer.""" + + needs_training = False + needs_fit = True + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.imputer = SimpleImputer(strategy="constant", fill_value=0.0) + + def fit(self, train_dataset, val_dataset): + self.imputer.fit(train_dataset.amputated_values.values) + + def forward(self, amputated_values, amputation_mask): + debatched_values = amputated_values.reshape((-1, amputated_values.shape[-1])) + debatched_values = debatched_values.to("cpu") + output = torch.Tensor(self.imputer.transform(debatched_values)).to(amputated_values.device) + + output = output.reshape(amputated_values.shape) + return output + + +@gin.configurable("MostFrequent") +class MostFrequentImputation(ImputationWrapper): + """Most frequent imputation using Scikit-Learn SimpleImputer.""" + + needs_training = False + needs_fit = True + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.imputer = SimpleImputer(strategy="most_frequent") + + def fit(self, train_dataset, val_dataset): + self.imputer.fit(train_dataset.amputated_values.values) + + def forward(self, amputated_values, amputation_mask): + debatched_values = amputated_values.reshape((-1, amputated_values.shape[-1])) + debatched_values = debatched_values.to("cpu") + output = torch.Tensor(self.imputer.transform(debatched_values)).to(amputated_values.device) + + output = output.reshape(amputated_values.shape) + return output + + +def wrap_hyperimpute_model(methodName: str, configName: str) -> Type: + class HyperImputeImputation(ImputationWrapper): + """Imputation using HyperImpute package.""" + + needs_training = False + needs_fit = True + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.imputer = HyperImpute().get(methodName, **self.model_params()) + + @gin.configurable(module=configName) + def model_params(self, **kwargs): + return kwargs + + def fit(self, train_dataset, val_dataset): + self.imputer.fit(train_dataset.amputated_values.values) + + def forward(self, amputated_values, amputation_mask): + debatched_values = amputated_values.reshape((-1, amputated_values.shape[-1])) + debatched_values = debatched_values.to(float).to("cpu").numpy() + with torch.inference_mode(mode=False): + output = torch.Tensor(self.imputer.transform(debatched_values).values).to(amputated_values.device) + output = output.reshape(amputated_values.shape) + return output + + return gin.configurable(configName)(HyperImputeImputation) + + +GAINImputation = wrap_hyperimpute_model("gain", "GAIN") +MissForestImputation = wrap_hyperimpute_model("sklearn_missforest", "MissForest") +ICEImputation = wrap_hyperimpute_model("ice", "ICE") +SoftImputeImputation = wrap_hyperimpute_model("softimpute", 
"SoftImpute") +SinkhornImputation = wrap_hyperimpute_model("sinkhorn", "Sinkhorn") +MiracleImputation = wrap_hyperimpute_model("miracle", "Miracle") +MiwaeImputation = wrap_hyperimpute_model("miwae", "Miwae") +HyperImputation = wrap_hyperimpute_model("hyperimpute", "HyperImpute") + + +@gin.configurable("BRITS") +class BRITSImputation(ImputationWrapper): + """Bidirectional Recurrent Imputation for Time Series (BRITS) imputation using PyPots package.""" + + needs_training = False + needs_fit = True + + def __init__(self, *args, input_size, epochs=1, rnn_hidden_size=64, batch_size=256, **kwargs) -> None: + super().__init__( + *args, input_size=input_size, epochs=epochs, rnn_hidden_size=rnn_hidden_size, batch_size=batch_size, **kwargs + ) + self.imputer = BRITS( + n_steps=input_size[1], + n_features=input_size[2], + rnn_hidden_size=rnn_hidden_size, + batch_size=batch_size, + epochs=epochs, + device="cuda" if torch.cuda.is_available() else "cpu", + ) + + def fit(self, train_dataset, val_dataset): + self.imputer.fit( + torch.Tensor( + train_dataset.amputated_values.values.reshape(-1, train_dataset.maxlen, train_dataset.features_df.shape[1]) + ) + ) + + def forward(self, amputated_values, amputation_mask): + debatched_values = amputated_values.to(self.imputer.device).squeeze() + self.imputer.model = self.imputer.model.to(self.imputer.device) + output = torch.Tensor(self.imputer.impute(debatched_values)).to(self.device) + + output = output.reshape(amputated_values.shape) + return output + + +@gin.configurable("SAITS") +class SAITSImputation(ImputationWrapper): + """Self-Attention based Imputation for Time Series (SAITS) imputation using PyPots package.""" + + needs_training = False + needs_fit = True + + def __init__(self, *args, input_size, epochs, n_layers, d_model, d_inner, n_head, d_k, d_v, dropout, **kwargs) -> None: + super().__init__( + *args, + input_size=input_size, + epochs=epochs, + n_layers=n_layers, + d_model=d_model, + d_inner=d_inner, + n_head=n_head, + d_k=d_k, + d_v=d_v, + dropout=dropout, + **kwargs + ) + self.imputer = SAITS( + n_steps=input_size[1], + n_features=input_size[2], + n_layers=n_layers, + d_model=d_model, + d_inner=d_inner, + n_head=n_head, + d_k=d_k, + d_v=d_v, + dropout=dropout, + epochs=epochs, + ) + + def fit(self, train_dataset, val_dataset): + self.imputer.fit( + torch.Tensor( + train_dataset.amputated_values.values.reshape(-1, train_dataset.maxlen, train_dataset.features_df.shape[1]) + ) + ) + + def forward(self, amputated_values, amputation_mask): + debatched_values = amputated_values.to(self.imputer.device).squeeze() + self.imputer.model = self.imputer.model.to(self.imputer.device) + output = torch.Tensor(self.imputer.impute(debatched_values)).to(self.device) + + output = output.reshape(amputated_values.shape) + return output + + +@gin.configurable("Attention") +class AttentionImputation(ImputationWrapper): + """Attention based Imputation (Transformer) imputation using PyPots package.""" + + needs_training = False + needs_fit = True + + def __init__(self, *args, input_size, epochs, n_layers, d_model, d_inner, n_head, d_k, d_v, dropout, **kwargs) -> None: + super().__init__( + *args, + input_size=input_size, + epochs=epochs, + n_layers=n_layers, + d_model=d_model, + d_inner=d_inner, + n_head=n_head, + d_k=d_k, + d_v=d_v, + dropout=dropout, + **kwargs + ) + self.imputer = Transformer( + n_steps=input_size[1], + n_features=input_size[2], + n_layers=n_layers, + d_model=d_model, + d_inner=d_inner, + n_head=n_head, + d_k=d_k, + d_v=d_v, + dropout=dropout, + 
epochs=epochs, + ) + + def fit(self, train_dataset, val_dataset): + self.imputer.fit( + torch.Tensor( + train_dataset.amputated_values.values.reshape(-1, train_dataset.maxlen, train_dataset.features_df.shape[1]) + ) + ) + + def forward(self, amputated_values, amputation_mask): + debatched_values = amputated_values.to(self.imputer.device).squeeze() + self.imputer.model = self.imputer.model.to(self.imputer.device) + output = torch.Tensor(self.imputer.impute(debatched_values)).to(self.device) + + output = output.reshape(amputated_values.shape) + return output diff --git a/icu_benchmarks/imputation/csdi.py b/icu_benchmarks/imputation/csdi.py new file mode 100644 index 00000000..ba56f9c8 --- /dev/null +++ b/icu_benchmarks/imputation/csdi.py @@ -0,0 +1,431 @@ +# Source: https://github.com/ermongroup/CSDI + +import gin +from icu_benchmarks.models.wrappers import ImputationWrapper +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +@gin.configurable("CSDI") +class CSDI(ImputationWrapper): + """Conditional Score-based Diffusion Models for Imputation (CSDI) of Time Series. See https://arxiv.org/abs/2107.03502 + for details.""" + + def __init__( + self, + input_size, + time_step_embedding_size, + feature_embedding_size, + unconditional, + target_strategy, + num_diffusion_steps, + diffusion_step_embedding_dim, + n_attention_heads, + num_residual_layers, + noise_schedule, + beta_start, + beta_end, + n_samples, + conv_channels, + *args, + **kwargs, + ): + super().__init__( + input_size=input_size, + time_step_embedding_size=time_step_embedding_size, + feature_embedding_size=feature_embedding_size, + unconditional=unconditional, + target_strategy=target_strategy, + num_diffusion_steps=num_diffusion_steps, + diffusion_step_embedding_dim=diffusion_step_embedding_dim, + n_attention_heads=n_attention_heads, + num_residual_layers=num_residual_layers, + noise_schedule=noise_schedule, + beta_start=beta_start, + beta_end=beta_end, + n_samples=n_samples, + conv_channels=conv_channels, + *args, + **kwargs, + ) + self.target_dim = input_size[2] + self.n_samples = n_samples + + self.emb_time_dim = time_step_embedding_size + self.emb_feature_dim = feature_embedding_size + self.is_unconditional = unconditional + self.target_strategy = target_strategy + + self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim + if not self.is_unconditional: + self.emb_total_dim += 1 # for conditional mask + self.embed_layer = nn.Embedding(num_embeddings=self.target_dim, embedding_dim=self.emb_feature_dim) + + input_dim = 1 if self.is_unconditional else 2 + self.diffmodel = diff_CSDI( + conv_channels, + num_diffusion_steps, + diffusion_step_embedding_dim, + self.emb_total_dim, + n_attention_heads, + num_residual_layers, + input_dim, + ) + + # parameters for diffusion models + self.num_steps = num_diffusion_steps + if noise_schedule == "quad": + self.beta = np.linspace(beta_start**0.5, beta_end**0.5, self.num_steps) ** 2 + elif noise_schedule == "linear": + self.beta = np.linspace(beta_start, beta_end, self.num_steps) + + self.alpha_hat = 1 - self.beta + self.alpha = np.cumprod(self.alpha_hat) + self.alpha_torch = torch.tensor(self.alpha).float().unsqueeze(1).unsqueeze(1) + + def on_fit_start(self) -> None: + self.alpha_torch = self.alpha_torch.to(self.device) + self.alpha_hat = torch.from_numpy(self.alpha_hat).to(self.device) + self.beta = torch.from_numpy(self.beta).to(self.device) + self.alpha = torch.from_numpy(self.alpha).to(self.device) + return super().on_fit_start() + 
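The beta schedule set up in the constructor determines how strongly each diffusion step corrupts the input, and its cumulative product is what `forward` later uses to noise a clean sample in one shot. A standalone sketch of this bookkeeping with illustrative values (not the configured defaults):

```python
import numpy as np

num_steps, beta_start, beta_end = 50, 1e-4, 0.5

# "quad" schedule as above; "linear" would be np.linspace(beta_start, beta_end, num_steps).
beta = np.linspace(beta_start**0.5, beta_end**0.5, num_steps) ** 2

alpha_hat = 1 - beta            # per-step retention factor
alpha = np.cumprod(alpha_hat)   # \bar{alpha}_t, used to jump straight to step t

# Forward noising as in CSDI.forward(): x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * eps
rng = np.random.default_rng(0)
x_0 = rng.normal(size=4)
t = 10
x_t = np.sqrt(alpha[t]) * x_0 + np.sqrt(1.0 - alpha[t]) * rng.normal(size=4)
print(x_t)
```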
+ def time_embedding(self, pos, d_model=128): + pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(self.device) + position = pos.unsqueeze(2) + div_term = 1 / torch.pow(10000.0, torch.arange(0, d_model, 2).to(self.device) / d_model) + pe[:, :, 0::2] = torch.sin(position * div_term) + pe[:, :, 1::2] = torch.cos(position * div_term) + return pe + + def get_randmask(self, observed_mask): + rand_for_mask = torch.rand_like(observed_mask) * observed_mask + rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1) + sample_ratios = torch.rand((len(observed_mask),), device=self.device) + for i in range(len(observed_mask)): + num_observed = observed_mask[i].sum().item() + num_masked = round(num_observed * sample_ratios[i].item()) + rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1 + cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float() + return cond_mask + + def get_hist_mask(self, observed_mask, for_pattern_mask=None): + if for_pattern_mask is None: + for_pattern_mask = observed_mask + if self.target_strategy == "mix": + rand_mask = self.get_randmask(observed_mask) + + cond_mask = observed_mask.clone() + random_tensor = torch.rand((len(cond_mask),), device=self.device) + for i in range(len(cond_mask)): + mask_choice = random_tensor[i] + if self.target_strategy == "mix" and mask_choice > 0.5: + cond_mask[i] = rand_mask[i] + else: # draw another sample for histmask (i-1 corresponds to another sample) + cond_mask[i] = cond_mask[i] * for_pattern_mask[i - 1] + return cond_mask + + def get_side_info(self, observed_tp, cond_mask): + B, K, L = cond_mask.shape + + time_embed = self.time_embedding(observed_tp, self.emb_time_dim) # (B,L,emb) + time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1) + feature_embed = self.embed_layer(torch.arange(self.target_dim).to(self.device)) # (K,emb) + feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1) + + side_info = torch.cat([time_embed, feature_embed], dim=-1) # (B,L,K,*) + side_info = side_info.permute(0, 3, 2, 1) # (B,*,K,L) + + if not self.is_unconditional: + side_mask = cond_mask.unsqueeze(1) # (B,1,K,L) + side_info = torch.cat([side_info, side_mask], dim=1) + + return side_info + + def calc_loss_valid(self, observed_data, cond_mask, observed_mask, side_info, is_train): + loss_sum = 0 + for t in range(self.num_steps): # calculate loss for all t + loss = self.calc_loss(observed_data, cond_mask, observed_mask, side_info, is_train, set_t=t) + loss_sum += loss.detach() + return loss_sum / self.num_steps + + def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask): + if self.is_unconditional: + total_input = noisy_data.unsqueeze(1) # (B,1,K,L) + else: + cond_obs = (cond_mask * observed_data).unsqueeze(1) + noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1) + total_input = torch.cat([cond_obs, noisy_target], dim=1) # (B,2,K,L) + + return total_input + + def impute(self, amputated_data, cond_mask, side_info, n_samples): + B, K, L = amputated_data.shape + + imputed_samples = torch.zeros(B, n_samples, K, L).to(self.device) + + for i in range(n_samples): + # generate noisy observation for unconditional model + if self.is_unconditional: + noisy_obs = amputated_data + noisy_cond_history = [] + for t in range(self.num_steps): + noise = torch.randn_like(noisy_obs) + noisy_obs = (self.alpha_hat[t] ** 0.5) * noisy_obs + self.beta[t] ** 0.5 * noise + noisy_cond_history.append(noisy_obs * cond_mask) + + current_sample = torch.randn_like(amputated_data) + + for t in range(self.num_steps - 1, 
-1, -1): + if self.is_unconditional: + diff_input = noisy_cond_history[t] + (1.0 - cond_mask) * current_sample + diff_input = diff_input.unsqueeze(1) # (B,1,K,L) + else: + cond_obs = amputated_data.unsqueeze(1) + noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1) + diff_input = torch.cat([cond_obs, noisy_target], dim=1) # (B,2,K,L) + predicted = self.diffmodel(diff_input, side_info, torch.tensor([t]).to(self.device)) + + coeff1 = 1 / self.alpha_hat[t] ** 0.5 + coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5 + current_sample = coeff1 * (current_sample - coeff2 * predicted) + + if t > 0: + noise = torch.randn_like(current_sample) + sigma = ((1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]) ** 0.5 + current_sample += sigma * noise + + imputed_samples[:, i] = current_sample.detach() + return imputed_samples + + def get_conditional_mask(self, observed_mask): + if self.target_strategy == "random": + return self.get_randmask(observed_mask) + return self.get_hist_mask(observed_mask) + + def forward(self, amputated_data, amputation_mask): + amputated_data = amputated_data.permute(0, 2, 1) + amputation_mask = amputation_mask.permute(0, 2, 1) + observed_mask = torch.ones_like(amputation_mask) - amputation_mask + B, K, L = amputated_data.shape + + observed_time_points = torch.arange(0, L, 1, device=self.device).expand(B, L) + + cond_mask = self.get_conditional_mask(observed_mask) + + side_info = self.get_side_info(observed_time_points, cond_mask) + + # return loss_func(observed_data, cond_mask, observed_mask, side_info, is_train) + t = torch.randint(0, self.num_steps, [B]).to(self.device) + current_alpha = self.alpha_torch[t] # (B,1,1) + noise = torch.randn_like(amputated_data) + noisy_data = (current_alpha**0.5) * amputated_data + (1.0 - current_alpha) ** 0.5 * noise + + total_input = self.set_input_to_diffmodel(noisy_data, amputated_data, cond_mask) + + predicted = self.diffmodel(total_input, side_info, t) # (B,K,L) + + target_mask = observed_mask - cond_mask + return noise * target_mask, predicted * target_mask + + def step_fn(self, batch, step_prefix): + amputated_data, amputation_mask, target, target_missingness = batch + amputated_data = amputated_data.nan_to_num() + + if step_prefix == "test": + prediction = self.evaluate(amputated_data, amputation_mask, self.n_samples) + amputated_data[amputation_mask > 0] = prediction[amputation_mask > 0] + amputated_data[target_missingness > 0] = target[target_missingness > 0] + loss = self.loss(target, amputated_data) + for metric in self.metrics[step_prefix].values(): + metric.update( + ( + torch.flatten(amputated_data.detach(), start_dim=1).clone(), + torch.flatten(target.detach(), start_dim=1).clone(), + ) + ) + else: + noise, prediction = self(amputated_data, amputation_mask) + loss = self.loss(noise, prediction) + + self.log(f"{step_prefix}/loss", loss.item(), prog_bar=True) + return loss + + def predict_step(self, data, amputation_mask): + data = data.nan_to_num() + prediction = self.evaluate(data, amputation_mask, self.n_samples) + data[amputation_mask > 0] = prediction[amputation_mask > 0] + return data + + def evaluate(self, amputated_data, amputation_mask, n_samples): + amputated_data = amputated_data.permute(0, 2, 1) + amputation_mask = amputation_mask.permute(0, 2, 1) + B, K, L = amputated_data.shape + + observed_time_points = torch.arange(0, L, 1, device=self.device).expand(B, L) + + cond_mask = torch.ones_like(amputation_mask) - amputation_mask + + side_info = self.get_side_info(observed_time_points, 
cond_mask) + + samples = self.impute(amputated_data, cond_mask, side_info, n_samples) + + previous_deterministic_setting = torch.are_deterministic_algorithms_enabled() + torch.use_deterministic_algorithms(False) + samples = samples.median(dim=1)[0].permute(0, 2, 1) + torch.use_deterministic_algorithms(previous_deterministic_setting) + + return samples + + +def get_torch_trans(heads=8, layers=1, channels=64): + encoder_layer = nn.TransformerEncoderLayer(d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu") + return nn.TransformerEncoder(encoder_layer, num_layers=layers) + + +def Conv1d_with_init(in_channels, out_channels, kernel_size): + layer = nn.Conv1d(in_channels, out_channels, kernel_size) + nn.init.kaiming_normal_(layer.weight) + return layer + + +class DiffusionStepEmbedding(nn.Module): + def __init__(self, num_steps, embedding_dim=128, projection_dim=None): + super().__init__() + if projection_dim is None: + projection_dim = embedding_dim + self.register_buffer( + "embedding", + self._build_embedding(num_steps, embedding_dim / 2), + persistent=False, + ) + self.projection1 = nn.Linear(embedding_dim, projection_dim) + self.projection2 = nn.Linear(projection_dim, projection_dim) + + def forward(self, diffusion_step): + x = self.embedding[diffusion_step] + x = self.projection1(x) + x = F.silu(x) + x = self.projection2(x) + x = F.silu(x) + return x + + def _build_embedding(self, num_steps, dim=64): + steps = torch.arange(num_steps).unsqueeze(1) # (T,1) + frequencies = 10.0 ** (torch.arange(dim) / (dim - 1) * 4.0).unsqueeze(0) # (1,dim) + table = steps * frequencies # (T,dim) + table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) # (T,dim*2) + return table + + +class diff_CSDI(nn.Module): + def __init__( + self, channels, num_diffusion_steps, diffusion_step_embedding_dim, side_dim, nheads, num_residual_blocks, inputdim=2 + ): + super().__init__() + self.channels = channels + + self.diffusion_step_embedding = DiffusionStepEmbedding( + num_steps=num_diffusion_steps, + embedding_dim=diffusion_step_embedding_dim, + ) + + self.input_projection = Conv1d_with_init(inputdim, self.channels, 1) + self.output_projection1 = Conv1d_with_init(self.channels, self.channels, 1) + self.output_projection2 = Conv1d_with_init(self.channels, 1, 1) + nn.init.zeros_(self.output_projection2.weight) + + self.residual_layers = nn.ModuleList( + [ + ResidualBlock( + side_dim=side_dim, + channels=self.channels, + diffusion_embedding_dim=diffusion_step_embedding_dim, + nheads=nheads, + ) + for _ in range(num_residual_blocks) + ] + ) + + def forward(self, x, cond_info, diffusion_step): + B, inputdim, K, L = x.shape + + x = x.reshape(B, inputdim, K * L) + x = self.input_projection(x) + x = F.relu(x) + x = x.reshape(B, self.channels, K, L) + + diffusion_emb = self.diffusion_step_embedding(diffusion_step) + + skip = [] + for layer in self.residual_layers: + x, skip_connection = layer(x, cond_info, diffusion_emb) + skip.append(skip_connection) + + x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers)) + x = x.reshape(B, self.channels, K * L) + x = self.output_projection1(x) # (B,channel,K*L) + x = F.relu(x) + x = self.output_projection2(x) # (B,1,K*L) + x = x.reshape(B, K, L) + return x + + +class ResidualBlock(nn.Module): + def __init__(self, side_dim, channels, diffusion_embedding_dim, nheads): + super().__init__() + self.diffusion_projection = nn.Linear(diffusion_embedding_dim, channels) + self.cond_projection = Conv1d_with_init(side_dim, 2 * channels, 1) + 
self.mid_projection = Conv1d_with_init(channels, 2 * channels, 1) + self.output_projection = Conv1d_with_init(channels, 2 * channels, 1) + + self.time_layer = get_torch_trans(heads=nheads, layers=1, channels=channels) + self.feature_layer = get_torch_trans(heads=nheads, layers=1, channels=channels) + + def forward_time(self, y, base_shape): + B, channel, K, L = base_shape + if L == 1: + return y + y = y.reshape(B, channel, K, L).permute(0, 2, 1, 3).reshape(B * K, channel, L) + y = self.time_layer(y.permute(2, 0, 1)).permute(1, 2, 0) + y = y.reshape(B, K, channel, L).permute(0, 2, 1, 3).reshape(B, channel, K * L) + return y + + def forward_feature(self, y, base_shape): + B, channel, K, L = base_shape + if K == 1: + return y + y = y.reshape(B, channel, K, L).permute(0, 3, 1, 2).reshape(B * L, channel, K) + y = self.feature_layer(y.permute(2, 0, 1)).permute(1, 2, 0) + y = y.reshape(B, L, channel, K).permute(0, 2, 3, 1).reshape(B, channel, K * L) + return y + + def forward(self, x, cond_info, diffusion_emb): + B, channel, K, L = x.shape + base_shape = x.shape + x = x.reshape(B, channel, K * L) + + diffusion_emb = self.diffusion_projection(diffusion_emb).unsqueeze(-1) # (B,channel,1) + y = x + diffusion_emb + + y = self.forward_time(y, base_shape) + y = self.forward_feature(y, base_shape) # (B,channel,K*L) + y = self.mid_projection(y) # (B,2*channel,K*L) + + _, cond_dim, _, _ = cond_info.shape + cond_info = cond_info.reshape(B, cond_dim, K * L) + cond_info = self.cond_projection(cond_info) # (B,2*channel,K*L) + y = y + cond_info + + gate, filter = torch.chunk(y, 2, dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter) # (B,channel,K*L) + y = self.output_projection(y) + + residual, skip = torch.chunk(y, 2, dim=1) + x = x.reshape(base_shape) + residual = residual.reshape(base_shape) + skip = skip.reshape(base_shape) + return (x + residual) / math.sqrt(2.0), skip diff --git a/icu_benchmarks/imputation/diffusion.py b/icu_benchmarks/imputation/diffusion.py new file mode 100644 index 00000000..2f270bd3 --- /dev/null +++ b/icu_benchmarks/imputation/diffusion.py @@ -0,0 +1,306 @@ +# Source: +# https://colab.research.google.com/drive/1sjy9odlSSy0RBVgMTgP7s99NXsqglsUL?usp=sharing#scrollTo=qWw50ui9IZ5q +# https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/annotated_diffusion.ipynb#scrollTo=290edb0b +# Tutorial: +# https://m.youtube.com/watch?v=a4Yfz2FxXiY +# Source Paper: +# https://arxiv.org/abs/2006.11239 +# Paper for Cosine Schedule: +# https://arxiv.org/abs/2102.09672 + +from icu_benchmarks.models.wrappers import ImputationWrapper +import gin +import math +import torch +from torch import nn +import torch.nn.functional as F + + +@gin.configurable("Diffusion") +class SimpleDiffusionModel(ImputationWrapper): + """Simple Diffusion Model for Imputation. 
See https://arxiv.org/abs/2006.11239 for more details.""" + + needs_training = True + needs_fit = False + + input_size = [] + + def __init__(self, input_size, n_onedirectional_conv, T, min_noise, max_noise, noise_scheduler, *args, **kwargs): + super().__init__( + n_onedirectional_conv=n_onedirectional_conv, + T=T, + min_noise=min_noise, + max_noise=max_noise, + noise_scheduler=noise_scheduler, + *args, + **kwargs + ) + + self.n_onedirectional_conv = n_onedirectional_conv + self.T = T + self.min_noise = min_noise + self.max_noise = max_noise + self.noise_scheduler = noise_scheduler + + # == Noise Schedulers == # + # Linear + if self.noise_scheduler == "linear": + self.betas = torch.linspace(self.min_noise, self.max_noise, self.T) + # Quadratic + elif self.noise_scheduler == "quadratic": + self.betas = torch.linspace(self.min_noise**0.5, self.max_noise**0.5, self.T) ** 2 + # Cosine + elif self.noise_scheduler == "cosine": + x = torch.linspace(0, self.T, self.T + 1) + alphas_cumprod = torch.cos(((x / self.T) + 0.008) / (1 + 0.008) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + self.betas = torch.clip(betas, 0.0001, 0.9999) + # Sigmoid + elif self.noise_scheduler == "sigmoid": + betas = torch.linspace(-6, 6, self.T) + self.betas = torch.sigmoid(betas) * (self.max_noise - self.min_noise) + self.min_noise + # Error + else: + raise NotImplementedError( + "Noise Scheduler must be linear, quadratic, cosine or sigmoid.\n Your Entry: [%s] is not implemented" + % self.noise_scheduler + ) + + # Helper Values + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, axis=0) + self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0) + self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) + self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + + # Store Input Size + self.input_size = input_size + + # Time embedding + self.time_mlp = nn.Sequential( + SinusoidalPositionEmbeddings(input_size[2]), nn.Linear(input_size[2], input_size[2]), nn.ReLU() + ) + + # Blocks + self.downs = nn.ModuleList() + self.ups = nn.ModuleList() + + for i in range(self.n_onedirectional_conv): + self.downs.append(Block(input_size, i)) + self.ups.append(Block(input_size, (self.n_onedirectional_conv - i), up=True)) + + def forward(self, amputated, timestep): + amputated = torch.nan_to_num(amputated, nan=0.0) + amputated = amputated[:, None, :, :] + x = amputated + + # Embedd time + t = self.time_mlp(timestep) + + # Residual Connections + residuals = [] + + for down in self.downs: + x = down(x, t) + residuals.append(x) + for up in self.ups: + residual = residuals.pop() + x = torch.cat((x, residual), dim=1) + x = up(x, t) + return x.squeeze() + + def training_step(self, batch): + amputated, amputation_mask, target, target_missingness = batch + amputated = torch.nan_to_num(amputated, nan=0.0) + + self.input_size = amputated.shape + + # # Context / Target Split (Credits @Allie) + # # Take an inverse of the amputed mask - to get the observation mask + # observed_mask = (~(amputation_mask > 0)).float() + + # # Generate a random tensor with the same dimensions as mask and multiply mask by it + # # This removes all missing values from the following calculations + # rand_for_mask = torch.rand_like(observed_mask) * 
observed_mask + + # # Create a context mask - the selection of the elements is so that only 50% of all observed values are selected + # context_mask = (rand_for_mask > 0.5).float() + + # # Create a target mask - the selection of the elements is so that all values not selected by the context mask + # # but are still observed are selected + # target_mask = (~(rand_for_mask > 0.5)).float() * observed_mask + + # context = amputated * context_mask + # target = amputated * target_mask + + # x_0 = context + + x_0 = amputated + + # Take a random timestep + t = torch.randint(0, self.T, (self.input_size[0],)).long() + + # Introduce Noise into the samples according + x_t, noise = self.forward_diffusion_sample(x_0, t) + + # Let the model predict the noise in the noised sample + noise_pred = self(x_t, t) + + # Calculate Loss: Difference between actual noise and noise prediction + loss = F.l1_loss(noise, noise_pred) + + self.log("train/loss", loss.item(), prog_bar=True) + + return loss + + def validation_step(self, batch, batch_idx): + amputated, amputation_mask, target, target_missingness = batch + amputated = torch.nan_to_num(amputated, nan=0.0) + + self.input_size = amputated.shape + + # TODO: - Context / Target Split + x_0 = amputated + + # Take a random timestep + t = torch.randint(0, self.T, (self.input_size[0],)).long() + + # Introduce Noise into the samples according + x_t, noise = self.forward_diffusion_sample(x_0, t) + + # Let the model predict the noise in the noised sample + noise_pred = self(x_t, t) + + # Calculate Loss: Difference between actual noise and noise prediction + loss = F.l1_loss(noise, noise_pred) + + self.log("val/loss", loss.item(), prog_bar=True) + + def test_step(self, batch, batch_idx): + amputated, amputation_mask, target, target_missingness = batch + amputated = torch.nan_to_num(amputated, nan=0.0) + + self.input_size = amputated.shape + + x_0 = amputated + + # Take the last timestep + t = torch.full((self.input_size[0],), self.T - 1) + + # Let the Model predict the noise + noise_pred = self(x_0, t) + + # Calculate the forward sample + x_t, _ = self.forward_diffusion_sample(x_0, t) + + # Calculate the backward sample for timestep 0 replacing the original x_0 and having imputed data + x_0 = self.backward_diffusion_sample(noise_pred, x_t, t) + + # Use the prediction only where the original data is missing + x_0 = amputated.masked_scatter_(amputation_mask.bool(), x_0) + + # Calculate Loss: Difference between imputed and target + loss = self.loss(x_0, target) + + self.log("test/loss", loss.item(), prog_bar=True) + + x_0[target_missingness > 0] = target[target_missingness > 0] + # Update Metrics + for metric in self.metrics["test"].values(): + metric.update((torch.flatten(x_0, start_dim=1), torch.flatten(target, start_dim=1))) + + # Helper function to return a value for a specific timestep t from a list reformatted for the current input size + def get_index_from_list(self, values, t, x_shape): + batch_size = t.shape[0] + out = values.gather(-1, t) + return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))) + + # Function that takes an original sample x_0 and introduces random noise for some timestep t + def forward_diffusion_sample(self, x_0, t): + # Random Noise + noise = torch.randn_like(x_0) + + sqrt_alphas_cumprod_t = self.get_index_from_list(self.sqrt_alphas_cumprod, t, x_0.shape) + sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(self.sqrt_one_minus_alphas_cumprod, t, x_0.shape) + + # Mean + mean = sqrt_alphas_cumprod_t * x_0 + + # Variance + variance = 
sqrt_one_minus_alphas_cumprod_t * noise + + # Forward Sample + forward_sample = mean + variance + + return forward_sample, noise + + # + def backward_diffusion_sample(self, noise_pred, x_t, t, t_index=0): + """Function that takes a noised image at some timestep t and the noise prediction and tries to compute the original + sample. The t needs to be one specific timestamp -> always the same value. It does not have to be like + this in the forward diffusion sample function""" + + betas_t = self.get_index_from_list(self.betas, t, x_t.shape) + sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) + sqrt_recip_alphas_t = self.get_index_from_list(self.sqrt_recip_alphas, t, x_t.shape) + posterior_variance_t = self.get_index_from_list(self.posterior_variance, t, x_t.shape) + + model_mean = sqrt_recip_alphas_t * (x_t - betas_t * noise_pred / sqrt_one_minus_alphas_cumprod_t) + + if t_index == 0: + return model_mean + else: + noise = torch.randn_like(x_t) + return model_mean + torch.sqrt(posterior_variance_t) * noise + + +class Block(nn.Module): + def __init__(self, input_size, i, up=False): + super().__init__() + + n_timestamps = input_size[1] - 3 * i + + self.time_mlp = nn.Linear(input_size[2], n_timestamps) + if up: + # take 2 times the number of input channels because residuals were added in the upsampling process + self.conv1 = nn.ConvTranspose2d(2, 1, 3, padding=1) + self.transform = nn.ConvTranspose2d(1, 1, (4, 2)) + else: + self.conv1 = nn.Conv2d(1, 1, 3, padding=1) + self.transform = nn.Conv2d(1, 1, (4, 2)) + self.conv2 = nn.Conv2d(1, 1, 3, padding=1) + self.bnorm1 = nn.BatchNorm2d(1) + self.bnorm2 = nn.BatchNorm2d(1) + self.relu = nn.ReLU() + + def forward(self, x, t): + # First Convolution + h = self.bnorm1(self.relu(self.conv1(x))) + # TODO: - Add Attention Layer before Time Embedding + + # Time Embedding + time_emb = self.relu(self.time_mlp(t[:, None, :])) + # Extend last dimension + time_emb = time_emb[(...,) + (None,)] + # Add time + h += time_emb + # Second Convolution + h = self.bnorm2(self.relu(self.conv2(h))) + return self.transform(h) + + +class SinusoidalPositionEmbeddings(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, time): + half_dim = self.dim // 2 + embeddings = math.log(10000) / (half_dim - 1 + 0.05) + embeddings = torch.exp(torch.arange(half_dim) * -embeddings) + embeddings = time[:, None] * embeddings[None, :] + embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) + return embeddings diff --git a/icu_benchmarks/imputation/diffwave.py b/icu_benchmarks/imputation/diffwave.py new file mode 100644 index 00000000..0945c0e6 --- /dev/null +++ b/icu_benchmarks/imputation/diffwave.py @@ -0,0 +1,370 @@ +import math + +import gin +import numpy as np +import torch +import torch.nn as nn + +from icu_benchmarks.models.wrappers import ImputationWrapper + + +@gin.configurable("DiffWave") +class DiffWaveImputer(ImputationWrapper): + """Imputation model based on DiffWave (https://arxiv.org/abs/2009.09761). 
diff --git a/icu_benchmarks/imputation/diffwave.py b/icu_benchmarks/imputation/diffwave.py
new file mode 100644
index 00000000..0945c0e6
--- /dev/null
+++ b/icu_benchmarks/imputation/diffwave.py
@@ -0,0 +1,370 @@
+import math
+
+import gin
+import numpy as np
+import torch
+import torch.nn as nn
+
+from icu_benchmarks.models.wrappers import ImputationWrapper
+
+
+@gin.configurable("DiffWave")
+class DiffWaveImputer(ImputationWrapper):
+    """Imputation model based on DiffWave (https://arxiv.org/abs/2009.09761). Adapted from
+    https://github.com/AI4HealthUOL/SSSD/blob/main/src/imputers/DiffWaveImputer.py"""
+
+    def __init__(
+        self,
+        in_channels,
+        res_channels,
+        skip_channels,
+        out_channels,
+        num_res_layers,
+        dilation_cycle,
+        diffusion_step_embed_dim_in,
+        diffusion_step_embed_dim_mid,
+        diffusion_step_embed_dim_out,
+        diffusion_time_steps,
+        beta_0,
+        beta_T,
+        *args,
+        **kwargs,
+    ):
+        super(DiffWaveImputer, self).__init__(
+            in_channels=in_channels,
+            res_channels=res_channels,
+            skip_channels=skip_channels,
+            out_channels=out_channels,
+            num_res_layers=num_res_layers,
+            dilation_cycle=dilation_cycle,
+            diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
+            diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
+            diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
+            diffusion_time_steps=diffusion_time_steps,
+            beta_0=beta_0,
+            beta_T=beta_T,
+            *args,
+            **kwargs,
+        )
+
+        self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU())
+
+        self.residual_layer = Residual_group(
+            res_channels=res_channels,
+            skip_channels=skip_channels,
+            num_res_layers=num_res_layers,
+            dilation_cycle=dilation_cycle,
+            diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
+            diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
+            diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
+            in_channels=in_channels,
+        )
+
+        self.final_conv = nn.Sequential(
+            Conv(skip_channels, skip_channels, kernel_size=1), nn.ReLU(), ZeroConv1d(skip_channels, out_channels)
+        )
+
+        self.diffusion_parameters = calc_diffusion_hyperparams(diffusion_time_steps, beta_0, beta_T)
+
+    def on_fit_start(self) -> None:
+        # move the schedule tensors to the training device; non-tensor entries (such as T) are
+        # dropped here and read from self.hparams instead
+        self.diffusion_parameters = {
+            k: v.to(self.device) for k, v in self.diffusion_parameters.items() if isinstance(v, torch.Tensor)
+        }
+        return super().on_fit_start()
+
+    def forward(self, input_data):
+        noise, conditional, mask, diffusion_steps = input_data
+
+        conditional = conditional * mask
+        conditional = torch.cat([conditional, mask.float()], dim=1)
+
+        x = noise
+        x = self.init_conv(x)
+        x = self.residual_layer((x, conditional, diffusion_steps))
+        y = self.final_conv(x)
+
+        return y
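+
+    # Illustrative shapes for forward above (B = batch, C = in_channels, L = sequence length):
+    #   noise, conditional, mask: (B, C, L); diffusion_steps: (B, 1).
+    # The conditional and its mask are concatenated to (B, 2C, L), which is why Residual_block's
+    # cond_conv expects 2 * in_channels.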
+
+    def step_fn(self, batch, step_prefix=""):
+        amputated_data, amputation_mask, target, target_missingness = batch
+
+        amputated_data = torch.nan_to_num(amputated_data).permute(0, 2, 1)
+        amputation_mask = amputation_mask.permute(0, 2, 1).bool()
+        observed_mask = 1 - amputation_mask.float()
+
+        if step_prefix in ["train", "val"]:
+            T, Alpha_bar = self.hparams.diffusion_time_steps, self.diffusion_parameters["Alpha_bar"]
+
+            B, C, L = amputated_data.shape  # B is the batch size, C the number of channels, L the sequence length
+            diffusion_steps = torch.randint(T, size=(B, 1, 1)).to(self.device)  # randomly sample diffusion steps from 0..T-1
+
+            z = std_normal(amputated_data.shape, self.device)
+            z = amputated_data * observed_mask.float() + z * (1 - observed_mask).float()
+            transformed_X = (
+                torch.sqrt(Alpha_bar[diffusion_steps]) * amputated_data + torch.sqrt(1 - Alpha_bar[diffusion_steps]) * z
+            )  # compute x_t from q(x_t|x_0)
+            epsilon_theta = self(
+                (
+                    transformed_X,
+                    amputated_data,
+                    observed_mask,
+                    diffusion_steps.view(B, 1),
+                )
+            )  # predict \epsilon according to \epsilon_\theta
+
+            loss = self.loss(epsilon_theta[amputation_mask.bool()], z[amputation_mask.bool()])
+        else:
+            target = target.permute(0, 2, 1)
+            target_missingness = target_missingness.permute(0, 2, 1)
+            imputed_data = self.sampling(amputated_data, observed_mask)
+            amputated_data[amputation_mask.bool()] = imputed_data[amputation_mask.bool()]
+            amputated_data[target_missingness > 0] = target[target_missingness > 0]
+            loss = self.loss(amputated_data, target)
+            for metric in self.metrics[step_prefix].values():
+                metric.update((torch.flatten(amputated_data, start_dim=1).clone(), torch.flatten(target, start_dim=1).clone()))
+
+        self.log(f"{step_prefix}/loss", loss.item(), prog_bar=True)
+        return loss
+
+    def sampling(self, cond, mask):
+        """
+        Perform the complete reverse sampling pass according to p(x_0|x_T) = \prod_{t=1}^T p_{\theta}(x_{t-1}|x_t)
+
+        Parameters:
+        cond (torch.tensor):    conditioning series of shape (batch, channels, length), i.e. the observed values
+        mask (torch.tensor):    observation mask of the same shape, 1 where a value is observed
+
+        Returns:
+        the generated series as a torch.tensor of shape (batch, channels, length)
+        """
+
+        Alpha, Alpha_bar, Sigma = (
+            self.diffusion_parameters["Alpha"],
+            self.diffusion_parameters["Alpha_bar"],
+            self.diffusion_parameters["Sigma"],
+        )
+
+        T = self.hparams.diffusion_time_steps
+        assert len(Alpha) == T
+        assert len(Alpha_bar) == T
+        assert len(Sigma) == T
+
+        B, _, _ = cond.shape
+        x = std_normal(cond.shape, self.device)
+
+        for t in range(T - 1, -1, -1):
+            x = x * (1 - mask).float() + cond * mask.float()
+            diffusion_steps = (t * torch.ones((B, 1))).to(self.device)  # use the corresponding reverse step
+            epsilon_theta = self(
+                (
+                    x,
+                    cond,
+                    mask,
+                    diffusion_steps,
+                )
+            )  # predict \epsilon according to \epsilon_\theta
+            # update x_{t-1} to \mu_\theta(x_t)
+            x = (x - (1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta) / torch.sqrt(Alpha[t])
+            if t > 0:
+                x = x + Sigma[t] * std_normal(cond.shape, self.device)  # add the variance term to x_{t-1}
+
+        return x
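+
+
+# Illustrative use of DiffWaveImputer.sampling (shapes only; the variable names are assumptions):
+#   cond = torch.nan_to_num(series).permute(0, 2, 1)  # (batch, in_channels, length)
+#   mask = observed.permute(0, 2, 1).float()          # 1 where a value is observed
+#   x0_hat = model.sampling(cond, mask)               # reverse diffusion starting from x_T ~ N(0, I)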
+
+
+def swish(x):
+    return x * torch.sigmoid(x)
+
+
+def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in, device):
+    """
+    Embed a diffusion step $t$ into a higher dimensional space.
+    E.g. the embedding vector in the 128-dimensional space is
+    [sin(t * 10^(0*4/63)), ..., sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ..., cos(t * 10^(63*4/63))]
+
+    Parameters:
+    diffusion_steps (torch.long tensor, shape=(batchsize, 1)):
+                                diffusion steps for batch data
+    diffusion_step_embed_dim_in (int, default=128):
+                                dimensionality of the embedding space for discrete diffusion steps
+
+    Returns:
+    the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in))
+    """
+
+    assert diffusion_step_embed_dim_in % 2 == 0
+
+    half_dim = diffusion_step_embed_dim_in // 2
+    _embed = np.log(10000) / (half_dim - 1)
+    _embed = torch.exp(torch.arange(half_dim) * -_embed).to(device)
+    _embed = diffusion_steps * _embed
+    diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)
+
+    return diffusion_step_embed
+
+
+def calc_diffusion_hyperparams(diffusion_time_steps, beta_0, beta_T):
+    """
+    Compute diffusion process hyperparameters
+
+    Parameters:
+    diffusion_time_steps (int): number of diffusion steps T
+    beta_0 and beta_T (float):  beta schedule start/end value,
+                                where any beta_t in the middle is linearly interpolated
+
+    Returns:
+    a dictionary of diffusion hyperparameters including:
+        T (int), Beta/Alpha/Alpha_bar/Sigma (torch.tensor on cpu, shape=(T, ))
+        These CPU tensors are moved to the training device later (see on_fit_start)
+    """
+
+    Beta = torch.linspace(beta_0, beta_T, diffusion_time_steps)  # Linear schedule
+    Alpha = 1 - Beta
+    Alpha_bar = Alpha + 0  # + 0 copies the tensor
+    Beta_tilde = Beta + 0
+    for t in range(1, diffusion_time_steps):
+        Alpha_bar[t] *= Alpha_bar[t - 1]  # \bar{\alpha}_t = \prod_{s=1}^t \alpha_s
+        Beta_tilde[t] *= (1 - Alpha_bar[t - 1]) / (1 - Alpha_bar[t])  # \tilde{\beta}_t = \beta_t * (1-\bar{\alpha}_{t-1})
+        # / (1-\bar{\alpha}_t)
+    Sigma = torch.sqrt(Beta_tilde)  # \sigma_t^2 = \tilde{\beta}_t
+
+    _dh = {}
+    _dh["T"], _dh["Beta"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"] = diffusion_time_steps, Beta, Alpha, Alpha_bar, Sigma
+    diffusion_hyperparams = _dh
+    return diffusion_hyperparams
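+
+
+# Quick sanity check of calc_diffusion_hyperparams (illustrative values): \bar{\alpha}_t decays
+# monotonically, so the forward process drifts toward pure noise.
+#   dh = calc_diffusion_hyperparams(diffusion_time_steps=50, beta_0=1e-4, beta_T=0.02)
+#   assert dh["Alpha_bar"][-1] < dh["Alpha_bar"][0]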
+
+
+class Conv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
+        super(Conv, self).__init__()
+        self.padding = dilation * (kernel_size - 1) // 2
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
+        self.conv = nn.utils.weight_norm(self.conv)
+        nn.init.kaiming_normal_(self.conv.weight)
+
+    def forward(self, x):
+        out = self.conv(x)
+        return out
+
+
+class ZeroConv1d(nn.Module):
+    def __init__(self, in_channel, out_channel):
+        super(ZeroConv1d, self).__init__()
+        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
+        self.conv.weight.data.zero_()
+        self.conv.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv(x)
+        return out
+
+
+class Residual_block(nn.Module):
+    def __init__(self, res_channels, skip_channels, dilation, diffusion_step_embed_dim_out, in_channels):
+        super(Residual_block, self).__init__()
+
+        self.res_channels = res_channels
+        # the layer-specific fc for diffusion step embedding
+        self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)
+
+        # dilated conv layer
+        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)
+
+        # conditioner conv1x1 layer; the conditioner here is the observed series concatenated with its
+        # mask, hence 2 * in_channels (in the original DiffWave this carried the mel-spectrogram bands)
+        self.cond_conv = Conv(2 * in_channels, 2 * self.res_channels, kernel_size=1)
+
+        # residual conv1x1 layer, connect to next residual layer
+        self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1)
+        self.res_conv = nn.utils.weight_norm(self.res_conv)
+        nn.init.kaiming_normal_(self.res_conv.weight)
+
+        # skip conv1x1 layer, add to all skip outputs through skip connections
+        self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
+        self.skip_conv = nn.utils.weight_norm(self.skip_conv)
+        nn.init.kaiming_normal_(self.skip_conv.weight)
+
+    def forward(self, input_data):
+        x, cond, diffusion_step_embed = input_data
+        h = x
+        B, C, L = x.shape
+        assert C == self.res_channels
+
+        part_t = self.fc_t(diffusion_step_embed)
+        part_t = part_t.view([B, self.res_channels, 1])
+        h = h + part_t
+
+        h = self.dilated_conv_layer(h)
+        # add (local) conditioner
+        assert cond is not None
+
+        cond = self.cond_conv(cond)
+        h += cond
+
+        # gated activation unit as in WaveNet: tanh branch times sigmoid branch
+        out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])
+
+        res = self.res_conv(out)
+        assert x.shape == res.shape
+        skip = self.skip_conv(out)
+
+        return (x + res) * math.sqrt(0.5), skip
+
+
+class Residual_group(nn.Module):
+    def __init__(
+        self,
+        res_channels,
+        skip_channels,
+        num_res_layers,
+        dilation_cycle,
+        diffusion_step_embed_dim_in,
+        diffusion_step_embed_dim_mid,
+        diffusion_step_embed_dim_out,
+        in_channels,
+    ):
+        super(Residual_group, self).__init__()
+        self.num_res_layers = num_res_layers
+        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
+
+        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
+        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out)
+
+        self.residual_blocks = nn.ModuleList()
+        for n in range(self.num_res_layers):
+            self.residual_blocks.append(
+                Residual_block(
+                    res_channels,
+                    skip_channels,
+                    dilation=2 ** (n % dilation_cycle),
+                    diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
+                    in_channels=in_channels,
+                )
+            )
+
+    def get_device(self):
+        return next(self.parameters()).device
+
+    def forward(self, input_data):
+        noise, conditional, diffusion_steps = input_data
+
+        diffusion_step_embed = calc_diffusion_step_embedding(
+            diffusion_steps, self.diffusion_step_embed_dim_in, self.get_device()
+        )
+        diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
+        diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))
+
+        h = noise
+        skip = 0
+        for n in range(self.num_res_layers):
+            # chain the blocks: each residual block receives the previous block's output h
+            h, skip_n = self.residual_blocks[n]((h, conditional, diffusion_step_embed))
+            skip += skip_n
+
+        return skip * math.sqrt(1.0 / self.num_res_layers)  # normalize for training stability
+
+
+def std_normal(size, device):
+    """
+    Generate a standard Gaussian variable of a certain size
+    """
+
+    return torch.normal(0, 1, size=size).to(device)
diff --git a/icu_benchmarks/imputation/layers/s4layer.py b/icu_benchmarks/imputation/layers/s4layer.py
new file mode 100644
index 00000000..0691c7ef
--- /dev/null
+++ b/icu_benchmarks/imputation/layers/s4layer.py
@@ -0,0 +1,1179 @@
+# Source: https://github.com/AI4HealthUOL/SSSD
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from functools import partial
+from scipy import special as ss
+from einops import rearrange, repeat
+import opt_einsum as oe
+
+contract = oe.contract
+contract_expression = oe.contract_expression
+
+""" Standalone CSDI + S4 imputer for random missing, non-random missing and black-out missing.
+This file contains the CSDI and S4 functions and utilities. The imputer itself is located in the
+last class of the file; see the documentation of use there. 
Additional at this file can be added for CUDA multiplication +the cauchy kernel.""" + +""" Cauchy kernel """ + +try: # Try pykeops + from pykeops.torch import Genred + + has_pykeops = True + + def cauchy_conj(v, z, w): + """Pykeops version""" + expr_num = "z * ComplexReal(v) - Real2Complex(Sum(v * w))" + expr_denom = "ComplexMult(z-w, z-Conj(w))" + + cauchy_mult = Genred( + f"ComplexDivide({expr_num}, {expr_denom})", + # expr_num, + # expr_denom, + [ + "v = Vj(2)", + "z = Vi(2)", + "w = Vj(2)", + ], + reduction_op="Sum", + axis=1, + dtype="float32" if v.dtype == torch.cfloat else "float64", + ) + + v, z, w = _broadcast_dims(v, z, w) + v = _c2r(v) + z = _c2r(z) + w = _c2r(w) + + r = 2 * cauchy_mult(v, z, w, backend="GPU") + return _r2c(r) + +except ImportError: + has_pykeops = False + + def cauchy_slow(v, z, w): + """ + v, w: (..., N) + z: (..., L) + returns: (..., L) + """ + cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1)) # (... N L) + return torch.sum(cauchy_matrix, dim=-2) + + +def _broadcast_dims(*tensors): + max_dim = max([len(tensor.shape) for tensor in tensors]) + tensors = [tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape) for tensor in tensors] + return tensors + + +def _c2r(input): + return torch.view_as_real(input) + + +def _r2c(input): + return torch.view_as_complex(input) + + +def _conj(x): + return torch.cat([x, x.conj()], dim=-1) + + +if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10): + + def _resolve_conj(x): + return x.conj().resolve_conj() + +else: + + def _resolve_conj(x): + return x.conj() + +""" simple nn.Module components """ + + +def Activation(activation=None, dim=-1): + if activation in [None, "id", "identity", "linear"]: + return nn.Identity() + elif activation == "tanh": + return nn.Tanh() + elif activation == "relu": + return nn.ReLU() + elif activation == "gelu": + return nn.GELU() + elif activation in ["swish", "silu"]: + return nn.SiLU() + elif activation == "glu": + return nn.GLU(dim=dim) + elif activation == "sigmoid": + return nn.Sigmoid() + else: + raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) + + +def get_initializer(name, activation=None): + if activation in [None, "id", "identity", "linear", "modrelu"]: + nonlinearity = "linear" + elif activation in ["relu", "tanh", "sigmoid"]: + nonlinearity = activation + elif activation in ["gelu", "swish", "silu"]: + nonlinearity = "relu" # Close to ReLU so approximate with ReLU's gain + else: + raise NotImplementedError(f"get_initializer: activation {activation} not supported") + + if name == "uniform": + initializer = partial(torch.nn.init.kaiming_uniform_, nonlinearity=nonlinearity) + elif name == "normal": + initializer = partial(torch.nn.init.kaiming_normal_, nonlinearity=nonlinearity) + elif name == "xavier": + initializer = torch.nn.init.xavier_normal_ + elif name == "zero": + initializer = partial(torch.nn.init.constant_, val=0) + elif name == "one": + initializer = partial(torch.nn.init.constant_, val=1) + else: + raise NotImplementedError(f"get_initializer: initializer type {name} not supported") + + return initializer + + +class TransposedLinear(nn.Module): + """Linear module on the second-to-last dimension""" + + def __init__(self, d_input, d_output, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.empty(d_output, d_input)) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # nn.Linear default init + # nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') # should be equivalent 
+ + if bias: + self.bias = nn.Parameter(torch.empty(d_output, 1)) + bound = 1 / math.sqrt(d_input) + nn.init.uniform_(self.bias, -bound, bound) + else: + self.bias = 0.0 + + def forward(self, x): + return contract("... u l, v u -> ... v l", x, self.weight) + self.bias + + +def LinearActivation( + d_input, + d_output, + bias=True, + zero_bias_init=False, + transposed=False, + initializer=None, + activation=None, + activate=False, # Apply activation as part of this module + weight_norm=False, + **kwargs, +): + """Returns a linear nn.Module with control over axes order, initialization, and activation""" + + # Construct core module + linear_cls = TransposedLinear if transposed else nn.Linear + if activation == "glu": + d_output *= 2 + linear = linear_cls(d_input, d_output, bias=bias, **kwargs) + + # Initialize weight + if initializer is not None: + get_initializer(initializer, activation)(linear.weight) + + # Initialize bias + if bias and zero_bias_init: + nn.init.zeros_(linear.bias) + + # Weight norm + if weight_norm: + linear = nn.utils.weight_norm(linear) + + if activate and activation is not None: + activation = Activation(activation, dim=-2 if transposed else -1) + linear = nn.Sequential(linear, activation) + return linear + + +""" Misc functional utilities """ + + +def krylov(L, A, b, c=None, return_power=False): + """ + Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick. + + If return_power=True, return A^{L-1} as well + """ + # TODO: There is an edge case if L=1 where output doesn't get broadcast, + # which might be an issue if caller is expecting broadcasting semantics... can deal with it if it arises + + x = b.unsqueeze(-1) # (..., N, 1) + A_ = A + + AL = None + if return_power: + AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device) + _L = L - 1 + + done = L == 1 + # loop invariant: _L represents how many indices left to compute + while not done: + if return_power: + if _L % 2 == 1: + AL = A_ @ AL + _L //= 2 + + # Save memory on last iteration + _l = x.shape[-1] + if L - _l <= _l: + done = True + _x = x[..., : L - _l] + else: + _x = x + + _x = A_ @ _x + x = torch.cat([x, _x], dim=-1) # there might be a more efficient way of ordering axes + if not done: + A_ = A_ @ A_ + + assert x.shape[-1] == L + + if c is not None: + x = torch.einsum("...nl, ...n -> ...l", x, c) + x = x.contiguous() # WOW!! + if return_power: + return x, AL + else: + return x + + +def power(L, A, v=None): + """Compute A^L and the scan sum_i A^i v_i + + A: (..., N, N) + v: (..., N, L) + """ + + _I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) + + powers = [A] + _l = 1 + while True: + if L % 2 == 1: + _I = powers[-1] @ _I + L //= 2 + if L == 0: + break + _l *= 2 + powers.append(powers[-1] @ powers[-1]) + + if v is None: + return _I + + # Invariants: + # powers[-1] := A^l + # l := largest po2 at most L + + # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop + # without caching intermediate powers of A. + # We do this reverse divide-and-conquer for efficiency reasons: + # 1) it involves fewer padding steps for non-po2 L + # 2) it involves more contiguous arrays + + # Take care of edge case for non-po2 arrays + # Note that this initial step is a no-op for the case of power of 2 (l == L) + k = v.size(-1) - _l + v_ = powers.pop() @ v[..., _l:] + v = v[..., :_l] + v[..., :k] = v[..., :k] + v_ + + # Handle reduction for power of 2 + while v.size(-1) > 1: + v = rearrange(v, "... (z l) -> ... 
z l", z=2) + v = v[..., 0, :] + powers.pop() @ v[..., 1, :] + return _I, v.squeeze(-1) + + +""" HiPPO utilities """ + + +def embed_c2r(A): + A = rearrange(A, "... m n -> ... m () n ()") + A = np.pad(A, ((0, 0), (0, 1), (0, 0), (0, 1))) + np.pad(A, ((0, 0), (1, 0), (0, 0), (1, 0))) + return rearrange(A, "m x n y -> (m x) (n y)") + + +def transition(measure, N, **measure_args): + """A, B transition matrices for different measures + + measure: the type of measure + legt - Legendre (translated) + legs - Legendre (scaled) + glagt - generalized Laguerre (translated) + lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization + """ + # Laguerre (translated) + if measure == "lagt": + b = measure_args.get("beta", 1.0) + A = np.eye(N) / 2 - np.tril(np.ones((N, N))) + B = b * np.ones((N, 1)) + # Generalized Laguerre + # alpha 0, beta small is most stable (limits to the 'lagt' measure) + # alpha 0, beta 1 has transition matrix A = [lower triangular 1] + elif measure == "glagt": + alpha = measure_args.get("alpha", 0.0) + beta = measure_args.get("beta", 0.01) + A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1) + B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None] + + L = np.exp(0.5 * (ss.gammaln(np.arange(N) + alpha + 1) - ss.gammaln(np.arange(N) + 1))) + A = (1.0 / L[:, None]) * A * L[None, :] + B = (1.0 / L[:, None]) * B * np.exp(-0.5 * ss.gammaln(1 - alpha)) * beta ** ((1 - alpha) / 2) + # Legendre (translated) + elif measure == "legt": + Q = np.arange(N, dtype=np.float64) + R = (2 * Q + 1) ** 0.5 + j, i = np.meshgrid(Q, Q) + A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :] + B = R[:, None] + A = -A + # Legendre (scaled) + elif measure == "legs": + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." 
after torch.as_tensor(B) + elif measure == "fourier": + freqs = np.arange(N // 2) + d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1] + A = 2 * np.pi * (np.diag(d, 1) - np.diag(d, -1)) + A = A - embed_c2r(np.ones((N // 2, N // 2))) + B = embed_c2r(np.ones((N // 2, 1)))[..., :1] + elif measure == "random": + A = np.random.randn(N, N) / N + B = np.random.randn(N, 1) + elif measure == "diagonal": + A = -np.diag(np.exp(np.random.randn(N))) + B = np.random.randn(N, 1) + else: + raise NotImplementedError + + return A, B + + +def rank_correction(measure, N, rank=1, dtype=torch.float): + """Return low-rank matrix L such that A + L is normal""" + + if measure == "legs": + assert rank >= 1 + P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N) + elif measure == "legt": + assert rank >= 2 + P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype)) # (N) + P0 = P.clone() + P0[0::2] = 0.0 + P1 = P.clone() + P1[1::2] = 0.0 + P = torch.stack([P0, P1], dim=0) # (2 N) + elif measure == "lagt": + assert rank >= 1 + P = 0.5 ** 0.5 * torch.ones(1, N, dtype=dtype) + elif measure == "fourier": + P = torch.ones(N, dtype=dtype) # (N) + P0 = P.clone() + P0[0::2] = 0.0 + P1 = P.clone() + P1[1::2] = 0.0 + P = torch.stack([P0, P1], dim=0) # (2 N) + else: + raise NotImplementedError + + d = P.size(0) + if rank > d: + P = torch.cat([P, torch.zeros(rank - d, N, dtype=dtype)], dim=0) # (rank N) + return P + + +def nplr(measure, N, rank=1, dtype=torch.float): + """Return w, p, q, V, B such that + (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V + i.e. A = V[w - p q^*]V^*, B = V B + """ + assert dtype == torch.float or torch.cfloat + if measure == "random": + dtype = torch.cfloat if dtype == torch.float else torch.cdouble + # w = torch.randn(N//2, dtype=dtype) + w = -torch.exp(torch.randn(N // 2)) + 1j * torch.randn(N // 2) + P = torch.randn(rank, N // 2, dtype=dtype) + B = torch.randn(N // 2, dtype=dtype) + V = torch.eye(N, dtype=dtype)[..., : N // 2] # Only used in testing + return w, P, B, V + + A, B = transition(measure, N) + A = torch.as_tensor(A, dtype=dtype) # (N, N) + B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,) + + P = rank_correction(measure, N, rank=rank, dtype=dtype) + AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3) + w, V = torch.linalg.eig(AP) # (..., N) (..., N, N) + # V w V^{-1} = A + + # Only keep one of the conjugate pairs + w = w[..., 0::2].contiguous() + V = V[..., 0::2].contiguous() + + V_inv = V.conj().transpose(-1, -2) + + B = contract("ij, j -> i", V_inv, B.to(V)) # V^* B + P = contract("ij, ...j -> ...i", V_inv, P.to(V)) # V^* P + + return w, P, B, V + + +def bilinear(dt, A, B=None): + """ + dt: (...) timescales + A: (... N N) + B: (... N) + """ + N = A.shape[-1] + _I = torch.eye(N).to(A) + A_backwards = _I - dt[:, None, None] / 2 * A + A_forwards = _I + dt[:, None, None] / 2 * A + + if B is None: + dB = None + else: + dB = dt[..., None] * torch.linalg.solve(A_backwards, B.unsqueeze(-1)).squeeze(-1) # (... N) + + dA = torch.linalg.solve(A_backwards, A_forwards) # (... N N) + return dA, dB + + +class SSKernelNPLR(nn.Module): + """Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state s + pace, where A is Normal + Low Rank (NPLR). The class name stands for 'State-Space SSKernel for Normal Plus Low-Rank'. + The parameters of this function are as follows. + + Args: + A: (... N N) the state matrix + B: (... N) input matrix + C: (... N) output matrix + dt: (...) 
timescales / discretization step size + p, q: (... P N) low-rank correction to A, such that Ap=A+pq^T is a normal matrix + + The forward pass of this Module returns: + (... L) that represents represents FFT SSKernel_L(A^dt, B^dt, C) + + """ + + @torch.no_grad() + def _setup_C(self, double_length=False): + """Construct C~ from C + + double_length: current C is for length L, convert it to length 2L + """ + C = _r2c(self.C) + self._setup_state() + dA_L = power(self.L, self.dA) + # Multiply C by I - dA_L + C_ = _conj(C) + prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_) + if double_length: + prod = -prod # Multiply by I + dA_L instead + C_ = C_ - prod + C_ = C_[..., : self.N] # Take conjugate pairs again + self.C.copy_(_c2r(C_)) + + if double_length: + self.L *= 2 + self._omega(self.L, dtype=C.dtype, device=C.device, cache=True) + + def _omega(self, L, dtype, device, cache=True): + """Calculate (and cache) FFT nodes and their "unprocessed" them with the bilinear transform + This should be called everytime the internal length self.L changes""" + omega = torch.tensor(np.exp(-2j * np.pi / (L)), dtype=dtype, device=device) # \omega_{2L} + omega = omega ** torch.arange(0, L // 2 + 1, device=device) + z = 2 * (1 - omega) / (1 + omega) + if cache: + self.register_buffer("omega", _c2r(omega)) + self.register_buffer("z", _c2r(z)) + return omega, z + + def __init__( + self, + L, + w, + P, + B, + C, + log_dt, + hurwitz=False, + trainable=None, + lr=None, + tie_state=False, + length_correction=True, + ): + """ + L: Maximum length; this module computes an SSM kernel of length L + w: (N) + p: (r, N) low-rank correction to A + q: (r, N) + A represented by diag(w) - pq^* + + B: (N) + dt: (H) timescale per feature + C: (H, C, N) system is 1-D to c-D (channels) + + hurwitz: tie pq and ensure w has negative real part + trainable: toggle which of the parameters is trainable + lr: add hook to set lr of hippo parameters specially (everything besides C) + tie_state: tie all state parameters across the H hidden features + length_correction: multiply C by (I - dA^L) - can be turned off when L is large for slight speedup at initialization + (only relevant when N large as well) + + Note: tensor shape N here denotes half the true state size, because of conjugate symmetry + """ + + super().__init__() + self.hurwitz = hurwitz + self.tie_state = tie_state + + # Rank of low-rank correction + self.rank = P.shape[-2] + assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1) + self.H = log_dt.size(-1) + self.N = w.size(-1) + + # Broadcast everything to correct shapes + C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))).clone() # (H, C, N) + H = 1 if self.tie_state else self.H + B = repeat(B, "n -> 1 h n", h=H).clone() + P = repeat(P, "r n -> r h n", h=H).clone() + w = repeat(w, "n -> h n", h=H).clone() + + # Cache Fourier nodes every time we set up a desired length + self.L = L + if self.L is not None: + self._omega(self.L, dtype=C.dtype, device=C.device, cache=True) + + # Register parameters + # C is a regular parameter, not state + # self.C = nn.Parameter(_c2r(C.conj().resolve_conj())) + self.C = nn.Parameter(_c2r(_resolve_conj(C))) + train = False + if trainable is None: + trainable = {} + if not trainable: + trainable = {} + if trainable: + trainable, train = {}, True + self.register("log_dt", log_dt, trainable.get("dt", train), lr, 0.0) + self.register("B", _c2r(B), trainable.get("B", train), lr, 0.0) + self.register("P", _c2r(P), trainable.get("P", train), lr, 0.0) + if 
self.hurwitz: + log_w_real = torch.log(-w.real + 1e-3) # Some of the HiPPO methods have real part 0 + w_imag = w.imag + self.register("log_w_real", log_w_real, trainable.get("A", 0), lr, 0.0) + self.register("w_imag", w_imag, trainable.get("A", train), lr, 0.0) + self.Q = None + else: + self.register("w", _c2r(w), trainable.get("A", train), lr, 0.0) + # self.register("Q", _c2r(P.clone().conj().resolve_conj()), trainable.get('P', train), lr, 0.0) + Q = _resolve_conj(P.clone()) + self.register("Q", _c2r(Q), trainable.get("P", train), lr, 0.0) + + if length_correction: + self._setup_C() + + def _w(self): + # Get the internal w (diagonal) parameter + if self.hurwitz: + w_real = -torch.exp(self.log_w_real) + w_imag = self.w_imag + w = w_real + 1j * w_imag + else: + w = _r2c(self.w) # (..., N) + return w + + def forward(self, state=None, rate=1.0, L=None): + """ + Args: + state: (..., s, N) extra tensor that augments B + rate: sampling rate factor + + returns: (..., c+s, L) + """ + # Handle sampling rate logic + # The idea is that this kernel's length (in continuous units) is self.L, + # while we are asked to provide a kernel of length L at (relative) sampling rate. + # If either are not passed in, assume we're not asked to change the scale of our kernel + assert not (rate is None and L is None) + if rate is None: + rate = self.L / L + if L is None: + L = int(self.L / rate) + + # Increase the internal length if needed + while rate * L > self.L: + self.double_length() + + dt = torch.exp(self.log_dt) * rate + B = _r2c(self.B) + C = _r2c(self.C) + P = _r2c(self.P) + Q = P.conj() if self.Q is None else _r2c(self.Q) + w = self._w() + + if rate == 1.0: + # Use cached FFT nodes + omega, z = _r2c(self.omega), _r2c(self.z) # (..., L) + else: + omega, z = self._omega(int(self.L / rate), dtype=w.dtype, device=w.device, cache=False) + + if self.tie_state: + B = repeat(B, "... 1 n -> ... h n", h=self.H) + P = repeat(P, "... 1 n -> ... h n", h=self.H) + Q = repeat(Q, "... 1 n -> ... h n", h=self.H) + + # Augment B + if state is not None: + # Have to "unbilinear" the state to put it into the same "type" as B + # Compute 1/dt * (I + dt/2 A) @ state + + # Can do this without expanding (maybe minor speedup using conj symmetry in theory), + # but it's easier to read this way + s = _conj(state) if state.size(-1) == self.N else state # (B H N) + sA = s * _conj(w) - contract("bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P)) # (B H N) + s = s / dt.unsqueeze(-1) + sA / 2 + s = s[..., : self.N] + + B = torch.cat([s, B], dim=-3) # (s+1, H, N) + + # Incorporate dt into A + w = w * dt.unsqueeze(-1) # (H N) + + # Stack B and p, C and q for convenient batching + B = torch.cat([B, P], dim=-3) # (s+1+r, H, N) + C = torch.cat([C, Q], dim=-3) # (c+r, H, N) + + # Incorporate B and C batch dimensions + v = B.unsqueeze(-3) * C.unsqueeze(-4) # (s+1+r, c+r, H, N) + # w = w[None, None, ...] # (1, 1, H, N) + # z = z[None, None, None, ...] 
# (1, 1, 1, L) + + # Calculate resolvent at omega + # if has_cauchy_extension and z.dtype == torch.cfloat: + # r = cauchy_mult(v, z, w, symmetric=True) + if has_pykeops: + r = cauchy_conj(v, z, w) + else: + r = cauchy_slow(v, z, w) + r = r * dt[None, None, :, None] # (S+1+R, C+R, H, L) + + # Low-rank Woodbury correction + if self.rank == 1: + k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (1 + r[-1:, -1:, :, :]) + elif self.rank == 2: + r00 = r[: -self.rank, : -self.rank, :, :] + r01 = r[: -self.rank, -self.rank:, :, :] + r10 = r[-self.rank:, : -self.rank, :, :] + r11 = r[-self.rank:, -self.rank:, :, :] + det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[:1, 1:, :, :] * r11[1:, :1, :, :] + s = ( + r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :] + + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :] + - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :] + - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :] + ) + s = s / det + k_f = r00 - s + else: + r00 = r[: -self.rank, : -self.rank, :, :] + r01 = r[: -self.rank, -self.rank:, :, :] + r10 = r[-self.rank:, : -self.rank, :, :] + r11 = r[-self.rank:, -self.rank:, :, :] + r11 = rearrange(r11, "a b h n -> h n a b") + r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11) + r11 = rearrange(r11, "h n a b -> a b h n") + k_f = r00 - torch.einsum("i j h n, j k h n, k l h n -> i l h n", r01, r11, r10) + + # Final correction for the bilinear transform + k_f = k_f * 2 / (1 + omega) + + # Move from frequency to coefficients + k = torch.fft.irfft(k_f) # (S+1, C, H, L) + + # Truncate to target length + k = k[..., :L] + + if state is not None: + k_state = k[:-1, :, :, :] # (S, C, H, L) + else: + k_state = None + k_B = k[-1, :, :, :] # (C H L) + return k_B, k_state + + @torch.no_grad() + def double_length(self): + self._setup_C(double_length=True) + + def _setup_linear(self): + """Create parameters that allow fast linear stepping of state""" + w = self._w() + B = _r2c(self.B) # (H N) + P = _r2c(self.P) + Q = P.conj() if self.Q is None else _r2c(self.Q) + + # Prepare Linear stepping + dt = torch.exp(self.log_dt) + D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N) + R = ( + torch.eye(self.rank, dtype=w.dtype, device=w.device) + 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real + ) # (H r r) + Q_D = rearrange(Q * D, "r h n -> h r n") + R = torch.linalg.solve(R.to(Q_D), Q_D) # (H r N) + R = rearrange(R, "h r n -> r h n") + + self.step_params = { + "D": D, # (H N) + "R": R, # (r H N) + "P": P, # (r H N) + "Q": Q, # (r H N) + "B": B, # (1 H N) + "E": 2.0 / dt.unsqueeze(-1) + w, # (H N) + } + + def _step_state_linear(self, u=None, state=None): + """ + Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form + and bilinear discretization. + + Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. 
Perhaps + a fused CUDA kernel implementation would be much faster + + u: (H) input + state: (H, N/2) state with conjugate pairs + Optionally, the state can have last dimension N + Returns: same shape as state + """ + C = _r2c(self.C) # View used for dtype/device + + if u is None: # Special case used to find dA + u = torch.zeros(self.H, dtype=C.dtype, device=C.device) + if state is None: # Special case used to find dB + state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device) + + step_params = self.step_params.copy() + if state.size(-1) == self.N: # Only store half of the conjugate pairs; should be true by default + # There should be a slightly faster way using conjugate symmetry + + def contract_fn(p, x, y): + return contract("r h n, r h m, ... h m -> ... h n", _conj(p), _conj(x), _conj(y))[ + ..., : self.N + ] # inner outer product + + else: + assert state.size(-1) == 2 * self.N + step_params = {k: _conj(v) for k, v in step_params.items()} + + # TODO worth setting up a contract_expression in default_state if we want to use this at inference + # time for stepping + + def contract_fn(p, x, y): + return contract("r h n, r h m, ... h m -> ... h n", p, x, y) # inner outer product + + D = step_params["D"] # (H N) + E = step_params["E"] # (H N) + R = step_params["R"] # (r H N) + P = step_params["P"] # (r H N) + Q = step_params["Q"] # (r H N) + B = step_params["B"] # (1 H N) + + new_state = E * state - contract_fn(P, Q, state) # (B H N) + new_state = new_state + 2.0 * B * u.unsqueeze(-1) # (B H N) + new_state = D * (new_state - contract_fn(P, R, new_state)) + + return new_state + + def _setup_state(self): + """Construct dA and dB for discretized state equation""" + + # Construct dA and dB by using the stepping + self._setup_linear() + C = _r2c(self.C) # Just returns a view that we use for finding dtype/device + + state = torch.eye(2 * self.N, dtype=C.dtype, device=C.device).unsqueeze(-2) # (N 1 N) + dA = self._step_state_linear(state=state) + dA = rearrange(dA, "n h m -> h m n") + self.dA = dA # (H N N) + + u = C.new_ones(self.H) + dB = self._step_state_linear(u=u) + dB = _conj(dB) + self.dB = rearrange(dB, "1 h n -> h n") # (H N) + + def _step_state(self, u, state): + """Must be called after self.default_state() is used to construct an initial state!""" + next_state = self.state_contraction(self.dA, state) + self.input_contraction(self.dB, u) + return next_state + + def setup_step(self, mode="dense"): + """Set up dA, dB, dC discretized parameters for stepping""" + self._setup_state() + + # Calculate original C + dA_L = power(self.L, self.dA) + _I = torch.eye(self.dA.size(-1)).to(dA_L) + C = _conj(_r2c(self.C)) # (H C N) + + dC = torch.linalg.solve( + _I - dA_L.transpose(-1, -2), + C.unsqueeze(-1), + ).squeeze(-1) + self.dC = dC + + # Do special preprocessing for different step modes + + self._step_mode = mode + if mode == "linear": + # Linear case: special step function for the state, we need to handle output + # use conjugate symmetry by default, which affects the output projection + self.dC = 2 * self.dC[:, :, : self.N] + elif mode == "diagonal": + # Eigendecomposition of the A matrix + L, V = torch.linalg.eig(self.dA) + V_inv = torch.linalg.inv(V) + + # Change the parameterization to diagonalize + self.dA = L + self.dB = contract("h n m, h m -> h n", V_inv, self.dB) + self.dC = contract("h n m, c h n -> c h m", V, self.dC) + + elif mode == "dense": + pass + else: + raise NotImplementedError("NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}") + + def 
default_state(self, *batch_shape): + C = _r2c(self.C) + N = C.size(-1) + H = C.size(-2) + + # Cache the tensor contractions we will later do, for efficiency + # These are put in this function because they depend on the batch size + if self._step_mode != "linear": + N *= 2 + + if self._step_mode == "diagonal": + self.state_contraction = contract_expression( + "h n, ... h n -> ... h n", + (H, N), + batch_shape + (H, N), + ) + else: + # Dense (quadratic) case: expand all terms + self.state_contraction = contract_expression( + "h m n, ... h n -> ... h m", + (H, N, N), + batch_shape + (H, N), + ) + + self.input_contraction = contract_expression( + "h n, ... h -> ... h n", + (H, N), # self.dB.shape + batch_shape + (H,), + ) + + self.output_contraction = contract_expression( + "c h n, ... h n -> ... c h", + (C.shape[0], H, N), # self.dC.shape + batch_shape + (H, N), + ) + + state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device) + return state + + def step(self, u, state): + """Must have called self.setup_step() and created state with self.default_state() before calling this""" + + if self._step_mode == "linear": + new_state = self._step_state_linear(u, state) + else: + new_state = self._step_state(u, state) + y = self.output_contraction(self.dC, new_state) + return y, new_state + + def register(self, name, tensor, trainable=False, lr=None, wd=None): + """Utility method: register a tensor as a buffer or trainable parameter""" + + if trainable: + self.register_parameter(name, nn.Parameter(tensor)) + else: + self.register_buffer(name, tensor) + + optim = {} + if trainable and lr is not None: + optim["lr"] = lr + if trainable and wd is not None: + optim["weight_decay"] = wd + if len(optim) > 0: + setattr(getattr(self, name), "_optim", optim) + + +class HippoSSKernel(nn.Module): + """Wrapper around SSKernel that generates A, B, C, dt according to HiPPO arguments. + + The SSKernel is expected to support the interface + forward() + default_state() + setup_step() + step() + """ + + def __init__( + self, + H, + N=64, + L=1, + measure="legs", + rank=1, + channels=1, # 1-dim to C-dim map; can think of C as having separate "heads" + dt_min=0.001, + dt_max=0.1, + trainable=None, # Dictionary of options to train various HiPPO parameters + lr=None, # Hook to set LR of hippo parameters differently + length_correction=True, # Multiply by I-A|^L after initialization; can be turned off for initialization speed + hurwitz=False, + tie_state=False, # Tie parameters of HiPPO ODE across the H features + precision=1, # 1 (single) or 2 (double) for the kernel + resample=False, # If given inputs of different lengths, adjust the sampling rate. 
+ # Note that L should always be provided in this case, as it assumes that L is the true underlying + # length of the continuous signal + ): + super().__init__() + self.N = N + self.H = H + L = L or 1 + self.precision = precision + dtype = torch.double if self.precision == 2 else torch.float + cdtype = torch.cfloat if dtype == torch.float else torch.cdouble + self.rate = None if resample else 1.0 + self.channels = channels + + # Generate dt + log_dt = torch.rand(self.H, dtype=dtype) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) + + w, p, B, _ = nplr(measure, self.N, rank, dtype=dtype) + C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype) + self.kernel = SSKernelNPLR( + L, + w, + p, + B, + C, + log_dt, + hurwitz=hurwitz, + trainable=trainable, + lr=lr, + tie_state=tie_state, + length_correction=length_correction, + ) + + def forward(self, L=None): + k, _ = self.kernel(rate=self.rate, L=L) + return k.float() + + def step(self, u, state, **kwargs): + u, state = self.kernel.step(u, state, **kwargs) + return u.float(), state + + def default_state(self, *args, **kwargs): + return self.kernel.default_state(*args, **kwargs) + + +def get_torch_trans(heads=8, layers=1, channels=64): + encoder_layer = nn.TransformerEncoderLayer(d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu") + return nn.TransformerEncoder(encoder_layer, num_layers=layers) + + +class S4(nn.Module): + def __init__( + self, + d_model, + d_state=64, + l_max=1, # Maximum length of sequence. Fine if not provided: the kernel will keep doubling in length until longer + # than sequence. However, this can be marginally slower if the true length is not a power of 2 + channels=1, # maps 1-dim to C-dim + bidirectional=False, + # Arguments for FF + activation="gelu", # activation in between SS and FF + postact=None, # activation after FF + initializer=None, # initializer on FF + weight_norm=False, # weight normalization on FF + hyper_act=None, # Use a "hypernetwork" multiplication + dropout=0.0, + transposed=True, # axis ordering (B, L, D) or (B, D, L) + # SSM Kernel arguments + **kernel_args, + ): + """ + d_state: the dimension of the state, also denoted by N + l_max: the maximum sequence length, also denoted by L + if this is not known at model creation, set l_max=1 + channels: can be interpreted as a number of "heads" + bidirectional: bidirectional + dropout: standard dropout argument + transposed: choose backbone axis ordering of (B, L, H) or (B, H, L) + [B=batch size, L=sequence length, H=hidden dimension]. 
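+
+        A minimal illustrative call (shape conventions as above; the values are assumptions):
+            layer = S4(d_model=64, l_max=100)
+            y, _ = layer(torch.randn(8, 64, 100))  # (B, H, L) -> (B, H, L) with transposed=True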
+ + Other options are all experimental and should not need to be configured + """ + + super().__init__() + + self.h = d_model + self.n = d_state + self.bidirectional = bidirectional + self.channels = channels + self.transposed = transposed + + # optional multiplicative modulation GLU-style + # https://arxiv.org/abs/2002.05202 + self.hyper = hyper_act is not None + if self.hyper: + channels *= 2 + self.hyper_activation = Activation(hyper_act) + + self.D = nn.Parameter(torch.randn(channels, self.h)) + + if self.bidirectional: + channels *= 2 + + # SSM Kernel + self.kernel = HippoSSKernel(self.h, N=self.n, L=l_max, channels=channels, **kernel_args) + + # Pointwise + self.activation = Activation(activation) + dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout + self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + + # position-wise output transform to mix features + self.output_linear = LinearActivation( + self.h * self.channels, + self.h, + transposed=self.transposed, + initializer=initializer, + activation=postact, + activate=True, + weight_norm=weight_norm, + ) + + # self.time_transformer = get_torch_trans(heads=8, layers=1, channels=self.h) + + def forward(self, u, **kwargs): # absorbs return_output and transformer src mask + """ + u: (B H L) if self.transposed else (B L H) + state: (H N) never needed unless you know what you're doing + + Returns: same shape as u + """ + if not self.transposed: + u = u.transpose(-1, -2) + L = u.size(-1) + + # Compute SS Kernel + k = self.kernel(L=L) # (C H L) (B C H L) + + # Convolution + if self.bidirectional: + k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2) + k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0)) + k_f = torch.fft.rfft(k, n=2 * L) # (C H L) + u_f = torch.fft.rfft(u, n=2 * L) # (B H L) + y_f = contract("bhl,chl->bchl", u_f, k_f) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L) + y = torch.fft.irfft(y_f, n=2 * L)[..., :L] # (B C H L) + + # Compute D term in state space equation - essentially a skip connection + y = y + contract("bhl,ch->bchl", u, self.D) # u.unsqueeze(-3) * self.D.unsqueeze(-1) + + # Optional hyper-network multiplication + if self.hyper: + y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2) + y = self.hyper_activation(yh) * y + + # Reshape to flatten channels + y = rearrange(y, "... c h l -> ... (c h) l") + + y = self.dropout(self.activation(y)) + + if not self.transposed: + y = y.transpose(-1, -2) + + y = self.output_linear(y) + + # ysize = b, k, l, requieres l, b, k + # y = self.time_transformer(y.permute(2,0,1)).permute(1,2,0) + + return y, None + + def step(self, u, state): + """Step one time step as a recurrent model. Intended to be used during validation. + + u: (B H) + state: (B H N) + Returns: output (B H), state (B H N) + """ + assert not self.training + + y, next_state = self.kernel.step(u, state) # (B C H) + y = y + u.unsqueeze(-2) * self.D + y = rearrange(y, "... c h -> ... (c h)") + y = self.activation(y) + if self.transposed: + y = self.output_linear(y.unsqueeze(-1)).squeeze(-1) + else: + y = self.output_linear(y) + return y, next_state + + def default_state(self, *batch_shape, device=None): + return self.kernel.default_state(*batch_shape) + + @property + def d_state(self): + return self.h * self.n + + @property + def d_output(self): + return self.h + + @property + def state_to_tensor(self): + return lambda state: rearrange("... h n -> ... 
(h n)", state) + + +class S4Layer(nn.Module): + # S4 Layer that can be used as a drop-in replacement for a TransformerEncoder + def __init__(self, features, lmax, N=64, dropout=0.0, bidirectional=True, layer_norm=True): + super().__init__() + self.s4_layer = S4(d_model=features, d_state=N, l_max=lmax, bidirectional=bidirectional) + + self.norm_layer = nn.LayerNorm(features) if layer_norm else nn.Identity() + self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity() + + def forward(self, x): + # x has shape seq, batch, feature + x = x.permute((1, 2, 0)) # batch, feature, seq (as expected from S4 with transposed=True) + xout, _ = self.s4_layer(x) # batch, feature, seq + xout = self.dropout(xout) + xout = xout + x # skip connection # batch, feature, seq + xout = xout.permute((2, 0, 1)) # seq, batch, feature + return self.norm_layer(xout) diff --git a/icu_benchmarks/imputation/mlp.py b/icu_benchmarks/imputation/mlp.py new file mode 100644 index 00000000..68d98e23 --- /dev/null +++ b/icu_benchmarks/imputation/mlp.py @@ -0,0 +1,36 @@ +from icu_benchmarks.models.wrappers import ImputationWrapper +from torch.nn import Linear, ReLU, BatchNorm1d, Sequential, Sigmoid, Flatten +import torch +import gin + + +@gin.configurable("MLP") +class MLPImputation(ImputationWrapper): + """Imputation model based on a Multi-Layer Perceptron (MLP).""" + + needs_training = True + needs_fit = False + + def __init__(self, *args, input_size, num_hidden_layers=3, hidden_layer_size=10, **kwargs) -> None: + super().__init__( + *args, input_size=input_size, num_hidden_layers=num_hidden_layers, hidden_layer_size=hidden_layer_size, **kwargs + ) + self.model = [ + Flatten(), + Linear(input_size[1] * input_size[2], hidden_layer_size), + ReLU(), + BatchNorm1d(hidden_layer_size), + ] + for _ in range(num_hidden_layers): + self.model += [Linear(hidden_layer_size, hidden_layer_size), ReLU(), BatchNorm1d(hidden_layer_size)] + self.model += [Linear(hidden_layer_size, input_size[1] * input_size[2]), Sigmoid()] + + self.model = Sequential(*self.model) + + def forward(self, amputated, amputation_mask): + amputated = torch.nan_to_num(amputated, nan=0.0) + + output = self.model(amputated) + output = output.reshape(amputated.shape) + + return output diff --git a/icu_benchmarks/imputation/np.py b/icu_benchmarks/imputation/np.py new file mode 100644 index 00000000..6a974fe5 --- /dev/null +++ b/icu_benchmarks/imputation/np.py @@ -0,0 +1,432 @@ +import gin + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributions import Normal + +from icu_benchmarks.models.wrappers import ImputationWrapper + + +@gin.configurable("NP") +class NPImputation(ImputationWrapper): + """Imputation using Neural Processes. Implementation adapted from https://github.com/EmilienDupont/neural-processes/. 
Provides an imputation wrapper for the NeuralProcess class."""
+
+    needs_training = True
+    needs_fit = False
+
+    def __init__(
+        self,
+        input_size,
+        encoder_layers,
+        encoder_h_dim,
+        decoder_layers,
+        decoder_h_dim,
+        r_dim,
+        z_dim,
+        train_sample_times,
+        val_sample_times,
+        test_sample_times,
+        predict_sample_times,
+        *args,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            input_size=input_size,
+            encoder_layers=encoder_layers,
+            encoder_h_dim=encoder_h_dim,
+            decoder_layers=decoder_layers,
+            decoder_h_dim=decoder_h_dim,
+            r_dim=r_dim,
+            z_dim=z_dim,
+            train_sample_times=train_sample_times,
+            val_sample_times=val_sample_times,
+            test_sample_times=test_sample_times,
+            predict_sample_times=predict_sample_times,
+            *args,
+            **kwargs
+        )
+
+        self.x_dim = input_size[2]
+        self.y_dim = input_size[2]
+        self.z_dim = z_dim
+
+        self.train_sample_times = train_sample_times
+        self.val_sample_times = val_sample_times
+        self.test_sample_times = test_sample_times
+        self.predict_sample_times = predict_sample_times
+
+        self.model = NeuralProcess(
+            self.x_dim,
+            self.y_dim,
+            encoder_layers,
+            encoder_h_dim,
+            decoder_layers,
+            decoder_h_dim,
+            r_dim,
+            z_dim,
+        )
+
+    def forward(self, x_context, y_context, x_target, y_target=None):
+        return self.model(x_context, y_context, x_target, y_target)
+
+    # Override the training step - needed for the custom loss calculation
+    def training_step(self, batch, _):
+        self.model.train(True)
+
+        # Unpack the batch; only the amputed values and the amputation mask are needed here
+        amputed, mask, _, _ = batch
+        batch_size, num_timesteps, num_obs_var = amputed.shape
+
+        amputed = torch.nan_to_num(amputed, nan=0.0).to(self.device)
+
+        # Create the time grid x with the same shape as the variables (x holds the timestep index);
+        # the float dtype lets it be combined with the float masks and observations
+        x = torch.arange(0, num_timesteps, device=self.device).float()
+        # (T,) -> (batch size, T, number of variables); a flat repeat followed by a reshape would
+        # scramble the timestep values across dimensions
+        x = x[None, :, None].repeat(batch_size, 1, num_obs_var)
+        # Resulting size is [batch size, number of timesteps, number of observed variables]
+
+        # Do a context/target split with the mask - see the CSDI implementation, line 56:
+        # https://github.com/ermongroup/CSDI/blob/main/main_model.py
+        x_context, y_context, x_target, y_target = self._context_target_split(x, amputed, mask)
+
+        # Get the predicted probability distribution
+        p_y_pred, _, _ = self(x_context, y_context, x_target, y_target)
+
+        # Sample K times and keep the largest loss, so the optimization targets the worst-case draw
+        best_loss = self.loss(p_y_pred.rsample(), y_target)
+
+        for _ in range(0, self.train_sample_times):
+            loss = self.loss(p_y_pred.rsample(), y_target)
+            if best_loss < loss:
+                best_loss = loss
+
+        self.log("train/loss", best_loss.item(), prog_bar=True)
+        return best_loss
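+
+    # Illustrative: for a single fully observed row [2.0, 3.0, 5.0, 7.0], _context_target_split
+    # (defined below) draws a random roughly 50/50 partition of the observed points, e.g.
+    #   context_mask = [1, 0, 0, 1] and target_mask = [0, 1, 1, 0],
+    # so the model learns to predict the target half from the context half.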
+
+    # Override the validation step - needed for the custom loss calculation
+    @torch.no_grad()
+    def validation_step(self, batch, _):
+        self.model.eval()
+        # Unpack the batch (amputed values, amputation mask, complete data, its missingness mask)
+        amputed, mask, complete, complete_missingness_mask = batch
+        batch_size, num_timesteps, num_obs_var = amputed.shape
+
+        amputed = torch.nan_to_num(amputed, nan=0.0).to(self.device)
+
+        # Create the time grid x with the same shape as the variables (x holds the timestep index)
+        x = torch.arange(0, num_timesteps, device=self.device).float()
+        x = x[None, :, None].repeat(batch_size, 1, num_obs_var)
+        # Resulting size is [batch size, number of timesteps, number of observed variables]
+
+        # Do a context/target split with the mask - see the CSDI implementation, line 56:
+        # https://github.com/ermongroup/CSDI/blob/main/main_model.py
+        x_context, y_context, x_target, y_target = self._context_target_split(x, amputed, mask)
+
+        # Get the predicted probability distribution
+        p_y_pred, _, _ = self(x_context, y_context, x_target, y_target)
+
+        # Sample K times and keep the largest loss, mirroring the training objective
+        best_loss = self.loss(p_y_pred.rsample(), y_target)
+
+        for _ in range(0, self.val_sample_times):
+            loss = self.loss(p_y_pred.rsample(), y_target)
+            if best_loss < loss:
+                best_loss = loss
+
+        self.log("val/loss", best_loss.item(), prog_bar=True)
+
+        # Do the metric calculations - use the full grid as the target now
+        x_target = x
+
+        # Get the predicted probability distribution
+        p_y_pred = self(x_context, y_context, x_target)
+
+        # Sample the distribution K times to put the values from it into the amputed dataset
+        generated_list = []
+        for _ in range(0, self.val_sample_times):
+            generated = p_y_pred.sample()
+            generated_list.append(generated)
+
+        # Calculate the mean of all K samples - dim = 0 is required to do an element-wise mean
+        # calculation on the stacked multidimensional tensors
+        generated = torch.mean(torch.stack(generated_list), dim=0).to(self.device)
+        # Use the indexing functionality of the tensor to impute values at the indices
+        # specified by the mask
+        amputed[mask > 0] = generated[mask > 0]
+        amputed[complete_missingness_mask > 0] = complete[complete_missingness_mask > 0]
+
+        # Update the metrics
+        for metric in self.metrics["val"].values():
+            metric.update(
+                (
+                    torch.flatten(amputed, start_dim=1),
+                    torch.flatten(complete, start_dim=1),
+                )
+            )
+
+    @torch.no_grad()
+    def test_step(self, batch, _):
+        self.model.eval()
+        # Unpack the batch (amputed values, amputation mask, complete data, its missingness mask)
+        amputed, mask, complete, complete_missingness_mask = batch
+        batch_size, num_timesteps, num_obs_var = amputed.shape
+
+        # Create the time grid x with the same shape as the variables (x holds the timestep index)
+        x = torch.arange(0, num_timesteps, device=self.device).float()
+        x = x[None, :, None].repeat(batch_size, 1, num_obs_var)
+        # Resulting size is [batch size, number of timesteps, number of observed variables]
+
+        # For now, do the most basic thing - put 0s instead of NaNs
+        amputed = torch.nan_to_num(amputed, nan=0.0).to(self.device)
+
+        x_context, y_context, _, _ = self._context_target_split(x, amputed, mask)
+
+        x_target = x
+
+        # Get the predicted probability distribution
+        p_y_pred = self(x_context, y_context, x_target)
+
+        # Sample the distribution K times to put the values from it into the amputed dataset
+        generated_list = []
+        for _ in range(0, self.test_sample_times):
+            generated = p_y_pred.sample()
+            generated_list.append(generated)
+
+        # Calculate the mean of all K samples - dim = 0 is required to do an element-wise mean
+        # calculation on the stacked multidimensional tensors
+        generated = torch.mean(torch.stack(generated_list), dim=0).to(self.device)
+        # Use the indexing functionality of the tensor to impute values at the indices
+        # specified by the mask
+        amputed[mask > 0] = generated[mask > 0]
+
+        # In the val/test loops, use the MSE loss - the KL divergence can't be calculated
+        # without a target distribution
+        loss = self.loss(amputed, complete)
+
+        self.log("test/loss", loss.item(), prog_bar=True)
+
+        amputed[complete_missingness_mask > 0] = complete[complete_missingness_mask > 0]
+        # Update the metrics
+        for metric in self.metrics["test"].values():
+            metric.update(
+                (
+                    torch.flatten(amputed, start_dim=1),
+                    torch.flatten(complete, start_dim=1),
+                )
+            )
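+
+    # Illustrative use of predict below (names and values are assumptions): given a raw batch
+    # with NaNs at the missing positions,
+    #   data = torch.tensor([[[1.0], [float("nan")], [3.0]]])  # (batch=1, time=3, variables=1)
+    #   imputed = model.predict(data)  # NaNs replaced by the mean of K sampled predictions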
num_timesteps, num_obs_var = data.shape + + # Take an inverse of missingness mask for a mask of observed values + observation_mask = ~(torch.isnan(data)) + + # Create and rearrange x to be the same shape as variables (x is timesteps) + x = torch.arange(0, num_timesteps, device=self.device) + x = x.repeat(batch_size) + x = x.repeat(num_obs_var) + x = x.reshape(batch_size, num_timesteps, num_obs_var) + + x_context = x * observation_mask + y_context = torch.nan_to_num(data, nan=0.0).to(self.device) + + x_target = x.to(self.device) + + p_y_pred = self(x_context, y_context, x_target) + + # Sample the distribution K times to put the values from it into the amputed dataset + generated_list = [] + for _ in range(0, self.predict_sample_times): + generated = p_y_pred.sample() + generated_list.append(generated) + + # Calculate mean of all K samples - dim = 0 is required to do a element-wise mean calculation on + # multidimensional tensor stack + generated = torch.mean(torch.stack(generated_list), dim=0).to(self.device) + data[observation_mask == 0] = generated[observation_mask == 0] + + return data + + def _context_target_split(self, x, y, amputed_mask): + # Take an inverse of the amputed mask - to get the observation mask + observed_mask = (~(amputed_mask > 0)).float() + + # Generate a random tensor with the same dimensions as mask and multiply mask by it + # This removes all missing values from the following calculations + rand_for_mask = torch.rand_like(observed_mask) * observed_mask + # Create a context mask - the selection of the elements is so that only + # 50% of all observed values are selected + context_mask = (rand_for_mask > 0.5).float() + # Create a target mask - the selection of the elements is so that all values + # not selected by the context mask but are still observed are selected + target_mask = (~(rand_for_mask > 0.5)).float() * observed_mask + + # Multiply x and y by masks to get the context/target split + x_context = x * context_mask + y_context = y * context_mask + + x_target = x * target_mask + y_target = y * target_mask + + return x_context, y_context, x_target, y_target + + +# Actual class that implements neural processes +class NeuralProcess(nn.Module): + """Class that implements neural processes.""" + + def __init__( + self, + x_dim, + y_dim, + encoder_layers, + encoder_h_dim, + decoder_layers, + decoder_h_dim, + r_dim, + z_dim, + ): + super().__init__() + + self.x_dim = x_dim + self.y_dim = y_dim + self.r_dim = r_dim + self.z_dim = z_dim + + # Initialize encoders/decoder + self.encoder = MLPEncoder(x_dim, y_dim, encoder_h_dim, encoder_layers, r_dim) + + self.latent_encoder = MuEncoder(r_dim=r_dim, z_dim=z_dim) + + self.decoder = Decoder(decoder_h_dim, decoder_layers, x_dim, y_dim, z_dim) + + def forward(self, x_context, y_context, x_target, y_target=None): + if y_target is not None: + # Encode target and context (context needs to be encoded to + # calculate kl term) + mu_target, sigma_target = self._encode(x_target, y_target) + mu_context, sigma_context = self._encode(x_context, y_context) + # Sample from encoded distribution using reparameterization trick + q_target = Normal(mu_target, sigma_target) + q_context = Normal(mu_context, sigma_context) + z_sample = q_target.rsample() + # Get parameters of output distribution + y_pred_mu, y_pred_sigma = self.decoder(x_target, z_sample) + p_y_pred = Normal(y_pred_mu, y_pred_sigma) + + return p_y_pred, q_target, q_context + else: + # At testing time, encode only context + mu_context, sigma_context = self._encode(x_context, 
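+
+# --- Illustrative aside (not part of this PR): the (p_y_pred, q_target, q_context)
+# triple produced by the training branch of NeuralProcess.forward is classically
+# combined into the neural-process ELBO, a negative log-likelihood plus a KL term
+# pulling the target posterior towards the context prior; the wrapper above instead
+# applies its configured loss to a sample. The classic objective, as a sketch:
+from torch.distributions.kl import kl_divergence
+
+def np_elbo_loss(p_y_pred, y_target, q_target, q_context):
+    nll = -p_y_pred.log_prob(y_target).mean()
+    kl = kl_divergence(q_target, q_context).mean()
+    return nll + kl
+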
y_context) + # Sample from distribution based on context + q_context = Normal(mu_context, sigma_context) + z_sample = q_context.rsample() + # Predict target points based on context + y_pred_mu, y_pred_sigma = self.decoder(x_target, z_sample) + p_y_pred = Normal(y_pred_mu, y_pred_sigma) + + return p_y_pred + + def _aggregate(self, r_i): + return torch.mean(r_i, dim=1) + + def _encode(self, x, y): + # Encode each point into a representation r_i + r_i = self.encoder(x, y) + # Aggregate representations r_i into a single representation r + r = self._aggregate(r_i) + # Return parameters of distribution + return self.latent_encoder(r) + + +# This class describes the deterministic encoder +# The encoding is (x_i, y_i) to representation r_i +class MLPEncoder(nn.Module): + def __init__(self, x_dim, y_dim, h_dim, h_layers, r_dim): + super().__init__() + + # Define the first input layer + layers = [nn.Linear(x_dim + y_dim, h_dim), nn.ReLU(inplace=True)] + # Define the multilayer structure + for _ in range(h_layers): + layers.append(nn.Linear(h_dim, h_dim)) + layers.append(nn.ReLU(inplace=True)) + # Add the final layer (without ReLU) + layers.append(nn.Linear(h_dim, r_dim)) + + self.model = nn.Sequential(*layers) + + def forward(self, x, y): + input_pairs = torch.cat((x, y), dim=2) + return self.model(input_pairs) + + +# This class describes the latent encoder +# The encoding is r_i to mu and sigma of the distribution from which to sample latent variable z +class MuEncoder(nn.Module): + def __init__(self, r_dim, z_dim): + super().__init__() + + self.model_hidden = nn.Linear(r_dim, r_dim) + self.model_mu = nn.Linear(r_dim, z_dim) + self.model_sigma = nn.Linear(r_dim, z_dim) + + def forward(self, r): + hidden = torch.relu(self.model_hidden(r)) + mu = self.model_mu(hidden) + # Define sigma following convention in "Empirical Evaluation of Neural + # Process Objectives" and "Attentive Neural Processes" + sigma = 0.1 + 0.9 * torch.sigmoid(self.model_sigma(hidden)) + return mu, sigma + + +# This class describes the decoder +# The encoding is from x_target and z to y_target (i.e. making a prediction of y) +class Decoder(nn.Module): + def __init__(self, h_dim, h_layers, x_dim, y_dim, z_dim): + super().__init__() + + self.x_dim = x_dim + self.y_dim = y_dim + self.z_dim = z_dim + + layers = [nn.Linear(x_dim + z_dim, h_dim), nn.ReLU(inplace=True)] + + for _ in range(h_layers): + layers.append(nn.Linear(h_dim, h_dim)) + layers.append(nn.ReLU(inplace=True)) + self.model_hidden = nn.Sequential(*layers) + + self.model_mu = nn.Linear(h_dim, y_dim) + self.model_sigma = nn.Linear(h_dim, y_dim) + + def forward(self, x, z): + batch_size, num_points, _ = x.size() + # Repeat z, so it can be concatenated with every x. 
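+
+# --- Illustrative aside (not part of this PR): the sigma head in MuEncoder above
+# maps any real pre-activation into the open interval (0.1, 1.0), preventing the
+# predictive Normal from collapsing (sigma -> 0) or saturating. A quick numeric
+# check of those bounds:
+#
+#   >>> import torch
+#   >>> 0.1 + 0.9 * torch.sigmoid(torch.tensor([-100.0, 0.0, 100.0]))
+#   tensor([0.1000, 0.5500, 1.0000])
+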
This changes shape + # from (batch_size, z_dim) to (batch_size, num_points, z_dim) + z = z.unsqueeze(1).repeat(1, num_points, 1) + # Flatten x and z to fit with linear layer + x_flat = x.view(batch_size * num_points, self.x_dim) + z_flat = z.view(batch_size * num_points, self.z_dim) + # Input is concatenation of z with every row of x + input_pairs = torch.cat((x_flat, z_flat), dim=1) + hidden = self.model_hidden(input_pairs) + mu = self.model_mu(hidden) + pre_sigma = self.model_sigma(hidden) + # Reshape output into expected shape + mu = mu.view(batch_size, num_points, self.y_dim) + pre_sigma = pre_sigma.view(batch_size, num_points, self.y_dim) + # Define sigma following convention in "Empirical Evaluation of Neural + # Process Objectives" and "Attentive Neural Processes" + sigma = 0.1 + 0.9 * F.softplus(pre_sigma) + return mu, sigma diff --git a/icu_benchmarks/imputation/rnn.py b/icu_benchmarks/imputation/rnn.py new file mode 100644 index 00000000..9f6a7213 --- /dev/null +++ b/icu_benchmarks/imputation/rnn.py @@ -0,0 +1,103 @@ +from icu_benchmarks.models.wrappers import ImputationWrapper +import torch +import torch.nn as nn +from torch.autograd import Variable +import gin + + +# Adapted from https://github.com/Graph-Machine-Learning-Group/grin/blob/main/lib/nn/models/rnn_imputers.py +@gin.configurable("RNN") +class RNNImputation(ImputationWrapper): + """Imputation model with Gated Recurrent Units (GRU) or Long-Short Term Memory Network (LSTM). Defaults to GRU.""" + + needs_training = True + needs_fit = False + + def __init__(self, *args, input_size, hidden_size=64, state_init="zero", cell="gru", **kwargs) -> None: + super().__init__(*args, input_size=input_size, hidden_size=hidden_size, state_init=state_init, cell=cell, **kwargs) + self.input_size = input_size + self.n_features = input_size[2] + self.hidden_size = hidden_size + self.state_init = state_init + self.cell = cell + + if cell == "gru": + cell = nn.GRUCell + elif cell == "lstm": + cell = nn.LSTMCell + else: + raise NotImplementedError(f'"{cell}" cell not implemented.') + + self.rnn = cell(self.n_features, self.hidden_size) + self.fn = nn.Linear(self.hidden_size, self.n_features) + + def init_hidden_state(self, x): + if self.state_init == "zero": + return torch.zeros((x.size(0), self.hidden_size), device=x.device, dtype=x.dtype) + if self.state_init == "noise": + return torch.randn(x.size(0), self.hidden_size, device=x.device, dtype=x.dtype) + + def forward(self, amputated, amputation_mask, return_hidden=False): + steps = amputated.size(1) + amputated = torch.where(amputation_mask.bool(), torch.zeros_like(amputated), amputated) + h = self.init_hidden_state(amputated) + c = self.init_hidden_state(amputated) + + output = self.fn(h) + hs = [h] + preds = [output] + for s in range(steps - 1): + x_t = torch.where(amputation_mask[:, s].bool(), output, amputated[:, s]) + if self.cell == "gru": + h = self.rnn(x_t, h) + elif self.cell == "lstm": + h, c = self.rnn(x_t, (h, c)) + output = self.fn(h) + hs.append(h) + preds.append(output) + + output = torch.stack(preds, 1) + h = torch.stack(hs, 1) + + if return_hidden: + return output, h + return output + + +@gin.configurable("BRNN") +class BRNNImputation(ImputationWrapper): + """Imputation model with Bidirectional Gated Recurrent Units (GRU) or Long-Short Term Memory Network (LSTM). 
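+
+# --- Illustrative aside (not part of this PR): RNNImputation's forward loop above
+# feeds the previous prediction back in wherever a value is marked missing, so gaps
+# are filled one step ahead. A minimal usage sketch with hypothetical shapes
+# (batch=8, timesteps=24, features=12); gin bindings and wrapper kwargs are omitted:
+#
+#   model = RNNImputation(input_size=(8, 24, 12), hidden_size=64)
+#   amputated = torch.randn(8, 24, 12)
+#   mask = (torch.rand(8, 24, 12) > 0.8).float()  # 1 marks an amputed value
+#   imputed = model(amputated, mask)              # same (8, 24, 12) shape back
+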
Defaults to + GRU.""" + + needs_training = True + needs_fit = False + + def __init__(self, *args, input_size, hidden_size=64, state_init="zero", dropout=0.0, cell="gru", **kwargs) -> None: + super().__init__( + *args, input_size=input_size, hidden_size=hidden_size, state_init=state_init, dropout=dropout, cell=cell, **kwargs + ) + self.hidden_size = hidden_size + self.fwd_rnn = RNNImputation(input_size=input_size, hidden_size=hidden_size, state_init=state_init, cell=cell) + self.bwd_rnn = RNNImputation(input_size=input_size, hidden_size=hidden_size, state_init=state_init, cell=cell) + self.dropout = nn.Dropout(dropout) + self.fn = nn.Linear(2 * hidden_size, input_size[2]) + + def forward(self, amputated, amputation_mask): + _, h_fwd = self.fwd_rnn(amputated, amputation_mask, return_hidden=True) + _, h_bwd = self.bwd_rnn(self.reverse_tensor(amputated, 1), self.reverse_tensor(amputation_mask, 1), return_hidden=True) + h_bwd = self.reverse_tensor(h_bwd, 1) + + h = self.dropout(torch.cat([h_fwd, h_bwd], -1)) + output = self.fn(h) + + return output + + @staticmethod + def reverse_tensor(tensor=None, axis=-1): + if tensor is None: + return None + if tensor.dim() <= 1: + return tensor + indices = range(tensor.size()[axis])[::-1] + indices = Variable(torch.LongTensor(indices), requires_grad=False).to(tensor.device) + return tensor.index_select(axis, indices) diff --git a/icu_benchmarks/imputation/simple_diffusion.py b/icu_benchmarks/imputation/simple_diffusion.py new file mode 100644 index 00000000..3590a5a0 --- /dev/null +++ b/icu_benchmarks/imputation/simple_diffusion.py @@ -0,0 +1,264 @@ +from icu_benchmarks.models.wrappers import ImputationWrapper +import gin +import math +import torch +from torch import nn +import torch.nn.functional as F + + +@gin.configurable("Simple_Diffusion") +class SimpleDiffusionModel(ImputationWrapper): + """Imputation model based on a Simple Diffusion Model. 
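+
+# --- Illustrative aside (not part of this PR): reverse_tensor above flips a tensor
+# along one axis via index_select (a pattern from older PyTorch); current releases
+# offer the built-in torch.flip, which should be equivalent here:
+#
+#   x = torch.arange(6).reshape(2, 3)
+#   assert torch.equal(BRNNImputation.reverse_tensor(x, axis=1), torch.flip(x, dims=[1]))
+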
+ Adapted from https://colab.research.google.com/drive/1sjy9odlSSy0RBVgMTgP7s99NXsqglsUL.""" + + needs_training = True + needs_fit = False + + input_size = [] + + def __init__(self, *args, input_size, **kwargs): + super().__init__(*args, input_size=input_size, **kwargs) + + down_channels = (25, 20, 18, 15) + up_channels = (15, 18, 20, 25) + time_emb_dim = 6 + + self.input_size = input_size + + # Time embedding + self.time_mlp = nn.Sequential( + SinusoidalPositionEmbeddings(time_emb_dim), nn.Linear(time_emb_dim, time_emb_dim), nn.ReLU() + ) + + # Initial projection + self.conv0 = nn.Conv1d(input_size[1], down_channels[0], 2) + + # Downsample + self.downs = nn.ModuleList( + [Block(down_channels[i], down_channels[i + 1], time_emb_dim) for i in range(len(down_channels) - 1)] + ) + + # Upsample + self.ups = nn.ModuleList( + [Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True) for i in range(len(up_channels) - 1)] + ) + + # Final Output + self.output = nn.ConvTranspose1d(up_channels[-1], input_size[1], 2) + + def forward(self, amputated, timestep): + amputated = torch.nan_to_num(amputated, nan=0.0) + # model_input = torch.cat((amputated, amputation_mask), dim=1) + + # output = self.model(model_input) + # output = output.reshape(amputated.shape) + + # Embedd time + t = self.time_mlp(timestep) + + # Initial Convolution + x = self.conv0(amputated) + # Unet + residual_inputs = [] + for down in self.downs: + x = down(x, t) + residual_inputs.append(x) + for up in self.ups: + residual_x = residual_inputs.pop() + # Add residual x as additional channels + x = torch.cat((x, residual_x), dim=1) + x = up(x, t) + + # Output Layer + output = self.output(x) + + return output + + def linear_beta_schedule(timesteps, start=0.0001, end=0.02): + return torch.linspace(start, end, timesteps) + + def get_index_from_list(self, vals, t, x_shape): + """ + Returns a specific index t of a passed list of values vals + while considering the batch dimension. 
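+
+# --- Illustrative aside (not part of this PR): every closed-form quantity the model
+# precomputes (alphas_cumprod, posterior_variance, ...) derives from the linear beta
+# schedule above. The same precomputation as a plain, self-contained function:
+import torch
+import torch.nn.functional as F
+
+def make_schedule(timesteps=300, start=0.0001, end=0.02):
+    betas = torch.linspace(start, end, timesteps)
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
+    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+    return betas, alphas_cumprod, posterior_variance
+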
+ """ + batch_size = t.shape[0] + out = vals.gather(-1, t) + return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device) + + def forward_diffusion_sample(self, x_0, t): + """ + Takes an image and a timestep as input and + returns the noisy version of it + """ + noise = torch.randn_like(x_0) + sqrt_alphas_cumprod_t = self.get_index_from_list(self.sqrt_alphas_cumprod, t, x_0.shape) + sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(self.sqrt_one_minus_alphas_cumprod, t, x_0.shape) + # mean + variance + return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise + + # Define beta schedule + T = 300 + betas = linear_beta_schedule(timesteps=T) + + # Pre-calculate different terms for closed form + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + sqrt_recip_alphas = torch.sqrt(1.0 / alphas) + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + + def get_loss(self, x_0, t): + x_noisy, noise = self.forward_diffusion_sample(x_0, t) + noise_pred = self(x_noisy, t) + return F.l1_loss(noise, noise_pred) + + def on_fit_start(self) -> None: + self.betas = self.betas.to(self.device) + self.alphas = self.alphas.to(self.device) + self.alphas_cumprod = self.alphas_cumprod.to(self.device) + self.alphas_cumprod_prev = self.alphas_cumprod_prev.to(self.device) + self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device) + self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(self.device) + self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(self.device) + self.posterior_variance = self.posterior_variance.to(self.device) + super().on_fit_start() + + def training_step(self, batch): + amputated, amputation_mask, target, target_missingness = batch + amputated = torch.nan_to_num(amputated, nan=0.0) + + t = torch.randint(0, self.T, (self.input_size[0],), device=self.device).long() + loss = self.get_loss(target, t) + + self.log("train/loss", loss.item(), prog_bar=True) + + for metric in self.metrics["train"].values(): + metric.update((torch.flatten(target, start_dim=1), torch.flatten(target, start_dim=1))) + + return loss + + def validation_step(self, batch, batch_index): + amputated, amputation_mask, target, target_missingness = batch + amputated = torch.nan_to_num(amputated, nan=0.0) + # imputated = self(amputated, amputation_mask) + + t = torch.randint(0, self.T, (1,), device=self.device).long() + + betas_t = self.get_index_from_list(self.betas, t, amputated.shape) + sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(self.sqrt_one_minus_alphas_cumprod, t, amputated.shape) + sqrt_recip_alphas_t = self.get_index_from_list(self.sqrt_recip_alphas, t, amputated.shape) + + model_mean = sqrt_recip_alphas_t * (amputated - betas_t * self(amputated, t) / sqrt_one_minus_alphas_cumprod_t) + + posterior_variance_t = self.get_index_from_list(self.posterior_variance, t, amputated.shape) + + if t == 0: + imputated = model_mean + else: + noise = torch.randn_like(amputated) + imputated = model_mean + torch.sqrt(posterior_variance_t) * noise + + # imputated = amputated.masked_scatter_(amputation_mask.bool(), imputated) + + amputated[amputation_mask > 0] = imputated[amputation_mask > 0] + amputated[target_missingness > 0] = target[target_missingness > 0] + + loss = self.loss(amputated, target) + 
self.log("val/loss", loss.item(), prog_bar=True) + + for metric in self.metrics["val"].values(): + metric.update((torch.flatten(amputated, start_dim=1), torch.flatten(target, start_dim=1))) + + def test_step(self, batch, batch_index): + amputated, amputation_mask, target, target_missingness = batch + amputated = torch.nan_to_num(amputated, nan=0.0) + # imputated = self(amputated, amputation_mask) + + t = torch.randint(0, self.T, (1,), device=self.device).long() + + betas_t = self.get_index_from_list(self.betas, t, amputated.shape) + sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(self.sqrt_one_minus_alphas_cumprod, t, amputated.shape) + sqrt_recip_alphas_t = self.get_index_from_list(self.sqrt_recip_alphas, t, amputated.shape) + + model_mean = sqrt_recip_alphas_t * (amputated - betas_t * self(amputated, t) / sqrt_one_minus_alphas_cumprod_t) + + posterior_variance_t = self.get_index_from_list(self.posterior_variance, t, amputated.shape) + + if t == 0: + imputated = model_mean + else: + noise = torch.randn_like(amputated) + imputated = model_mean + torch.sqrt(posterior_variance_t) * noise + + # imputated = amputated.masked_scatter_(amputation_mask.bool(), imputated) + + amputated[amputation_mask > 0] = imputated[amputation_mask > 0] + amputated[target_missingness > 0] = target[target_missingness > 0] + + loss = self.loss(amputated, target) + self.log("test/loss", loss.item(), prog_bar=True) + + for metric in self.metrics["test"].values(): + metric.update((torch.flatten(amputated, start_dim=1), torch.flatten(target, start_dim=1))) + + +class Block(nn.Module): + def __init__(self, in_ch, out_ch, time_emb_dim, up=False): + super().__init__() + self.time_mlp = nn.Linear(time_emb_dim, out_ch) + time_dim = 5 if in_ch == 25 else 4 if in_ch == 20 else 3 if in_ch == 18 else 2 + if up: + # take 2 times the number of input channels because residuals were added in the upsampling process + in_ch *= 2 + self.conv1 = nn.ConvTranspose1d(in_ch, out_ch, 3, padding=1) + self.transform = nn.ConvTranspose1d(out_ch, out_ch, 2) + else: + self.conv1 = nn.Conv1d(in_ch, out_ch, 3, padding=1) + self.transform = nn.Conv1d(out_ch, out_ch, 2) + self.conv2 = nn.Conv1d(out_ch, out_ch, 3, padding=1) + self.bnorm1 = nn.BatchNorm1d(out_ch) + self.bnorm2 = nn.BatchNorm1d(out_ch) + self.relu = nn.ReLU() + + # Transformer Encoder for Feature Self-Attention + self.feature_layer = nn.TransformerEncoderLayer(d_model=in_ch, nhead=1, dim_feedforward=64, activation="gelu") + self.feature_transformer = nn.TransformerEncoder(self.feature_layer, num_layers=1) + + # Transformer Encoder for Time Self-Attention + self.time_layer = nn.TransformerEncoderLayer(d_model=time_dim, nhead=1, dim_feedforward=64, activation="gelu") + self.time_transformer = nn.TransformerEncoder(self.time_layer, num_layers=1) + + def forward(self, x, t): + # Apply Feature Self-Attention + h = self.feature_transformer(x.permute(0, 2, 1)).permute(0, 2, 1) + # Apply Time Self-Attention + h = self.time_transformer(h) + # First Convolution + h = self.bnorm1(self.relu(self.conv1(h))) + # Time Embedding + time_emb = self.relu(self.time_mlp(t)) + # Extend last dimension + time_emb = time_emb[(...,) + (None,)] + # Add time + h += time_emb + # Second Convolution + h = self.bnorm2(self.relu(self.conv2(h))) + return self.transform(h) + + +class SinusoidalPositionEmbeddings(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, time): + device = time.device + half_dim = self.dim // 2 + embeddings = math.log(10000) / (half_dim 
- 1 + 0.05) + embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) + embeddings = time[:, None] * embeddings[None, :] + embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) + return embeddings diff --git a/icu_benchmarks/imputation/sssds4.py b/icu_benchmarks/imputation/sssds4.py new file mode 100644 index 00000000..ce6e8c0d --- /dev/null +++ b/icu_benchmarks/imputation/sssds4.py @@ -0,0 +1,424 @@ +# Source: https://github.com/AI4HealthUOL/SSSD +import gin +import math +import numpy as np +import torch +from torch import nn +from icu_benchmarks.models.wrappers import ImputationWrapper +from icu_benchmarks.imputation.layers.s4layer import S4Layer + + +@gin.configurable("SSSDS4") +class SSSDS4(ImputationWrapper): + """Implements the SSSD model from https://arxiv.org/abs/2208.09399.""" + + def __init__( + self, + input_size, + res_channels, + skip_channels, + num_res_layers, + diffusion_step_embed_dim_in, + diffusion_step_embed_dim_mid, + diffusion_step_embed_dim_out, + s4_lmax, + s4_d_state, + s4_dropout, + s4_bidirectional, + s4_layernorm, + diffusion_time_steps, + beta_0, + beta_T, + *args, + **kwargs: str, + ): + super(SSSDS4, self).__init__( + input_size=input_size, + res_channels=res_channels, + skip_channels=skip_channels, + num_res_layers=num_res_layers, + diffusion_step_embed_dim_in=diffusion_step_embed_dim_in, + diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid, + diffusion_step_embed_dim_out=diffusion_step_embed_dim_out, + s4_lmax=s4_lmax, + s4_d_state=s4_d_state, + s4_dropout=s4_dropout, + s4_bidirectional=s4_bidirectional, + s4_layernorm=s4_layernorm, + diffusion_time_steps=diffusion_time_steps, + beta_0=beta_0, + beta_T=beta_T, + *args, + **kwargs, + ) + + num_channels = input_size[2] + self.init_conv = nn.Sequential(Conv(num_channels, res_channels, kernel_size=1), nn.ReLU()) + + self.residual_layer = Residual_group( + res_channels=res_channels, + skip_channels=skip_channels, + num_res_layers=num_res_layers, + diffusion_step_embed_dim_in=diffusion_step_embed_dim_in, + diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid, + diffusion_step_embed_dim_out=diffusion_step_embed_dim_out, + in_channels=num_channels, + s4_lmax=s4_lmax, + s4_d_state=s4_d_state, + s4_dropout=s4_dropout, + s4_bidirectional=s4_bidirectional, + s4_layernorm=s4_layernorm, + ) + + self.final_conv = nn.Sequential( + Conv(skip_channels, skip_channels, kernel_size=1), nn.ReLU(), ZeroConv1d(skip_channels, num_channels) + ) + + self.diffusion_parameters = calc_diffusion_hyperparams(diffusion_time_steps, beta_0, beta_T) + + def on_fit_start(self) -> None: + self.diffusion_parameters = { + k: v.to(self.device) for k, v in self.diffusion_parameters.items() if isinstance(v, torch.Tensor) + } + return super().on_fit_start() + + def forward(self, input_data): + noise, conditional, mask, diffusion_steps = input_data + + conditional = torch.cat([conditional, mask.float()], dim=1) + + x = noise + x = self.init_conv(x) + x = self.residual_layer((x, conditional, diffusion_steps)) + y = self.final_conv(x) + + return y + + def step_fn(self, batch, step_prefix=""): + amputated_data, amputation_mask, target, target_missingness = batch + + amputated_data = torch.nan_to_num(amputated_data).permute(0, 2, 1) + amputation_mask = amputation_mask.permute(0, 2, 1).bool() + observed_mask = 1 - amputation_mask.float() + amputation_mask = amputation_mask.bool() + + if step_prefix in ["train", "val"]: + T, Alpha_bar = self.hparams.diffusion_time_steps, self.diffusion_parameters["Alpha_bar"] + 
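+
+# --- Illustrative aside (not part of this PR): the training branch of step_fn here
+# is the standard denoising-diffusion objective restricted to the amputed entries:
+# draw a step t, form x_t in closed form, and regress the injected noise. As a
+# sketch, with `eps_model` standing in for the conditioned network call:
+#
+#   t = torch.randint(T, size=(B,))
+#   z = torch.randn_like(x0)
+#   x_t = torch.sqrt(Alpha_bar[t])[:, None, None] * x0 \
+#       + torch.sqrt(1 - Alpha_bar[t])[:, None, None] * z
+#   loss = F.mse_loss(eps_model(x_t, t)[mask], z[mask])
+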
+ B, C, L = amputated_data.shape # B is batchsize, C=1, L is audio length + diffusion_steps = torch.randint(T, size=(B, 1, 1)).to(self.device) # randomly sample diffusion steps from 1~T + + z = std_normal(amputated_data.shape, self.device) + z = amputated_data * observed_mask.float() + z * (1 - observed_mask).float() + transformed_X = ( + torch.sqrt(Alpha_bar[diffusion_steps]) * amputated_data + torch.sqrt(1 - Alpha_bar[diffusion_steps]) * z + ) # compute x_t from q(x_t|x_0) + epsilon_theta = self( + ( + transformed_X, + amputated_data, + observed_mask, + diffusion_steps.view(B, 1), + ) + ) # predict \epsilon according to \epsilon_\theta + + loss = self.loss(epsilon_theta[amputation_mask], z[amputation_mask]) + else: + target = target.permute(0, 2, 1) + target_missingness = target_missingness.permute(0, 2, 1) + imputed_data = self.sampling(amputated_data, observed_mask) + amputated_data[amputation_mask] = imputed_data[amputation_mask] + amputated_data[target_missingness > 0] = target[target_missingness > 0] + loss = self.loss(amputated_data, target) + for metric in self.metrics[step_prefix].values(): + metric.update((torch.flatten(amputated_data, start_dim=1).clone(), torch.flatten(target, start_dim=1).clone())) + + self.log(f"{step_prefix}/loss", loss.item(), prog_bar=True) + return loss + + def predict_step(self, amputated_data, amputation_mask): + amputated_data = torch.nan_to_num(amputated_data).permute(0, 2, 1) + amputation_mask = amputation_mask.permute(0, 2, 1) + observed_mask = 1 - amputation_mask.float() + amputation_mask = amputation_mask.bool() + imputed_data = self.sampling(amputated_data, observed_mask) + amputated_data[amputation_mask] = imputed_data[amputation_mask] + amputated_data = amputated_data.permute(0, 2, 1) + return amputated_data + + def sampling(self, cond, mask): + """ + Perform the complete sampling step according to p(x_0|x_T) = prod_{t=1}^T p_{\theta}(x_{t-1}|x_t) + + Parameters: + net (torch network): the wavenet model + size (tuple): size of tensor to be generated, + usually is (number of audios to generate, channels=1, length of audio) + diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams + note, the tensors need to be cuda tensors + + Returns: + the generated audio(s) in torch.tensor, shape=size + """ + + Alpha, Alpha_bar, Sigma = ( + self.diffusion_parameters["Alpha"], + self.diffusion_parameters["Alpha_bar"], + self.diffusion_parameters["Sigma"], + ) + + T = self.hparams.diffusion_time_steps + assert len(Alpha) == T + assert len(Alpha_bar) == T + assert len(Sigma) == T + + B, _, _ = cond.shape + x = std_normal(cond.shape, self.device) + + for t in range(T - 1, -1, -1): + x = x * (1 - mask).float() + cond * mask.float() + diffusion_steps = (t * torch.ones((B, 1))).to(self.device) # use the corresponding reverse step + epsilon_theta = self( + ( + x, + cond, + mask, + diffusion_steps, + ) + ) # predict \epsilon according to \epsilon_\theta + # update x_{t-1} to \mu_\theta(x_t) + x = (x - (1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta) / torch.sqrt(Alpha[t]) + if t > 0: + x = x + Sigma[t] * std_normal(cond.shape, self.device) # add the variance term to x_{t-1} + + return x + + +def std_normal(size, device): + """ + Generate the standard Gaussian variable of a certain size + """ + + return torch.normal(0, 1, size=size).to(device) + + +def swish(x): + return x * torch.sigmoid(x) + + +class Conv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1): + 
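+
+# --- Illustrative aside (not part of this PR): the key line in sampling() above is
+# the re-imposition of observed values at every reverse step,
+#     x = x * (1 - mask) + cond * mask
+# so the chain only denoises the missing entries while the observed ones act as a
+# fixed boundary condition. One reverse update, with `eps_model` hypothetical:
+#
+#   eps = eps_model(x, cond, mask, t)
+#   x = (x - (1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) * eps) / torch.sqrt(Alpha[t])
+#   if t > 0:
+#       x = x + Sigma[t] * torch.randn_like(x)
+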
super(Conv, self).__init__() + self.padding = dilation * (kernel_size - 1) // 2 + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding) + self.conv = nn.utils.weight_norm(self.conv) + nn.init.kaiming_normal_(self.conv.weight) + + def forward(self, x): + out = self.conv(x) + return out + + +class ZeroConv1d(nn.Module): + def __init__(self, in_channel, out_channel): + super(ZeroConv1d, self).__init__() + self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0) + self.conv.weight.data.zero_() + self.conv.bias.data.zero_() + + def forward(self, x): + out = self.conv(x) + return out + + +class Residual_block(nn.Module): + def __init__( + self, + res_channels, + skip_channels, + diffusion_step_embed_dim_out, + in_channels, + s4_lmax, + s4_d_state, + s4_dropout, + s4_bidirectional, + s4_layernorm, + ): + super(Residual_block, self).__init__() + self.res_channels = res_channels + + self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels) + + self.S41 = S4Layer( + features=2 * self.res_channels, + lmax=s4_lmax, + N=s4_d_state, + dropout=s4_dropout, + bidirectional=s4_bidirectional, + layer_norm=s4_layernorm, + ) + + self.conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3) + + self.S42 = S4Layer( + features=2 * self.res_channels, + lmax=s4_lmax, + N=s4_d_state, + dropout=s4_dropout, + bidirectional=s4_bidirectional, + layer_norm=s4_layernorm, + ) + + self.cond_conv = Conv(2 * in_channels, 2 * self.res_channels, kernel_size=1) + + self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1) + self.res_conv = nn.utils.weight_norm(self.res_conv) + nn.init.kaiming_normal_(self.res_conv.weight) + + self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1) + self.skip_conv = nn.utils.weight_norm(self.skip_conv) + nn.init.kaiming_normal_(self.skip_conv.weight) + + def forward(self, input_data): + x, cond, diffusion_step_embed = input_data + h = x + B, C, L = x.shape + assert C == self.res_channels + + part_t = self.fc_t(diffusion_step_embed) + part_t = part_t.view([B, self.res_channels, 1]) + h = h + part_t + + h = self.conv_layer(h) + h = self.S41(h.permute(2, 0, 1)).permute(1, 2, 0) + + assert cond is not None + cond = self.cond_conv(cond) + h += cond + + h = self.S42(h.permute(2, 0, 1)).permute(1, 2, 0) + + out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :]) + + res = self.res_conv(out) + assert x.shape == res.shape + skip = self.skip_conv(out) + + return (x + res) * math.sqrt(0.5), skip # normalize for training stability + + +class Residual_group(nn.Module): + def __init__( + self, + res_channels, + skip_channels, + num_res_layers, + diffusion_step_embed_dim_in, + diffusion_step_embed_dim_mid, + diffusion_step_embed_dim_out, + in_channels, + s4_lmax, + s4_d_state, + s4_dropout, + s4_bidirectional, + s4_layernorm, + ): + super(Residual_group, self).__init__() + self.num_res_layers = num_res_layers + self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in + + self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid) + self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out) + + self.residual_blocks = nn.ModuleList() + for n in range(self.num_res_layers): + self.residual_blocks.append( + Residual_block( + res_channels, + skip_channels, + diffusion_step_embed_dim_out=diffusion_step_embed_dim_out, + in_channels=in_channels, + s4_lmax=s4_lmax, + s4_d_state=s4_d_state, + s4_dropout=s4_dropout, 
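+
+# --- Illustrative aside (not part of this PR): Residual_block above ends in a
+# WaveNet-style gated unit, splitting its 2*C channels into a tanh "filter" half
+# and a sigmoid "gate" half. torch.chunk gives the same split as the explicit
+# slicing in the source:
+#
+#   filt, gate = h.chunk(2, dim=1)                # each (B, C, L)
+#   out = torch.tanh(filt) * torch.sigmoid(gate)
+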
+ s4_bidirectional=s4_bidirectional, + s4_layernorm=s4_layernorm, + ) + ) + + def get_device(self): + return next(self.parameters()).device + + def forward(self, input_data): + noise, conditional, diffusion_steps = input_data + + diffusion_step_embed = calc_diffusion_step_embedding( + diffusion_steps, self.diffusion_step_embed_dim_in, self.get_device() + ) + diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed)) + diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed)) + + h = noise + skip = 0 + for n in range(self.num_res_layers): + h, skip_n = self.residual_blocks[n]((h, conditional, diffusion_step_embed)) + skip += skip_n + + return skip * math.sqrt(1.0 / self.num_res_layers) + + +def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in, device): + """ + Embed a diffusion step $t$ into a higher dimensional space + E.g. the embedding vector in the 128-dimensional space is + [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))] + Parameters: + diffusion_steps (torch.long tensor, shape=(batchsize, 1)): + diffusion steps for batch data + diffusion_step_embed_dim_in (int, default=128): + dimensionality of the embedding space for discrete diffusion steps + + Returns: + the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)): + """ + + assert diffusion_step_embed_dim_in % 2 == 0 + + half_dim = diffusion_step_embed_dim_in // 2 + _embed = np.log(10000) / (half_dim - 1) + _embed = torch.exp(torch.arange(half_dim) * -_embed).to(device) + _embed = diffusion_steps * _embed + diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1) + + return diffusion_step_embed + + +def calc_diffusion_hyperparams(diffusion_time_steps, beta_0, beta_T): + """ + Compute diffusion process hyperparameters + + Parameters: + T (int): number of diffusion steps + beta_0 and beta_T (float): beta schedule start/end value, + where any beta_t in the middle is linearly interpolated + + Returns: + a dictionary of diffusion hyperparameters including: + T (int), Beta/Alpha/Alpha_bar/Sigma (torch.tensor on cpu, shape=(T, )) + These cpu tensors are changed to cuda tensors on each individual gpu + """ + + Beta = torch.linspace(beta_0, beta_T, diffusion_time_steps) # Linear schedule + Alpha = 1 - Beta + Alpha_bar = Alpha + 0 + Beta_tilde = Beta + 0 + for t in range(1, diffusion_time_steps): + Alpha_bar[t] *= Alpha_bar[t - 1] # \bar{\alpha}_t = \prod_{s=1}^t \alpha_s + Beta_tilde[t] *= (1 - Alpha_bar[t - 1]) / (1 - Alpha_bar[t]) # \tilde{\beta}_t = \beta_t * (1-\bar{\alpha}_{t-1}) + # / (1-\bar{\alpha}_t) + Sigma = torch.sqrt(Beta_tilde) # \sigma_t^2 = \tilde{\beta}_t + + _dh = {} + _dh["T"], _dh["Beta"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"] = diffusion_time_steps, Beta, Alpha, Alpha_bar, Sigma + diffusion_hyperparams = _dh + return diffusion_hyperparams diff --git a/icu_benchmarks/imputation/sssdsa.py b/icu_benchmarks/imputation/sssdsa.py new file mode 100644 index 00000000..fb384efe --- /dev/null +++ b/icu_benchmarks/imputation/sssdsa.py @@ -0,0 +1,788 @@ +# Source: https://github.com/AI4HealthUOL/SSSD/blob/main/src/imputers/SSSDSAImputer.py +import numpy as np +import gin +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from icu_benchmarks.imputation.layers.s4layer import S4, LinearActivation +from icu_benchmarks.models.wrappers import ImputationWrapper + + +@gin.configurable("SSSDSA") +class SSSDSA(ImputationWrapper): + """ "SaShiMi model 
backbone. Adapted from https://github.com/AI4HealthUOL/SSSD/blob/main/src/imputers/SSSDSAImputer.py""" + + def __init__( + self, + d_model, + n_layers, + pool, + expand, + ff, + glu, + unet, + dropout, + in_channels, + out_channels, + diffusion_step_embed_dim_in, + diffusion_step_embed_dim_mid, + diffusion_step_embed_dim_out, + label_embed_dim, + label_embed_classes, + bidirectional, + s4_lmax, + s4_d_state, + s4_dropout, + s4_bidirectional, + diffusion_time_steps, + beta_0, + beta_T, + input_size, + *args, + **kwargs, + ): + """ + SaShiMi model backbone. + + Args: + d_model: dimension of the model. We generally use 64 for all our experiments. + n_layers: number of (Residual (S4) --> Residual (FF)) blocks at each pooling level. + We use 8 layers for our experiments, although we found that increasing layers even further generally + improves performance at the expense of training / inference speed. + pool: pooling factor at each level. Pooling shrinks the sequence length at lower levels. + We experimented with a pooling factor of 4 with 1 to 4 tiers of pooling and found 2 tiers to be best. + It's possible that a different combination of pooling factors and number of tiers may perform better. + expand: expansion factor when pooling. Features are expanded (i.e. the model becomes wider) at lower levels of the + architecture.We generally found 2 to perform best (among 2, 4). + ff: expansion factor for the FF inverted bottleneck. We generally found 2 to perform best (among 2, 4). + bidirectional: use bidirectional S4 layers. Bidirectional layers are suitable for use with non-causal models + such as diffusion models like DiffWave. + glu: use gated linear unit in the S4 layers. Adds parameters and generally improves performance. + unet: use a unet-like architecture, adding (Residual (S4) --> Residual (FF)) layers before downpooling. + All else fixed, this slows down inference (and slightly slows training), but generally improves performance. + We use this variant when dropping in SaShiMi into diffusion models, and this should generally be preferred + for non-autoregressive models. + dropout: dropout rate. Default to 0.0, since we haven't found settings where SaShiMi overfits. 
+ """ + super().__init__( + d_model=d_model, + n_layers=n_layers, + pool=pool, + expand=expand, + ff=ff, + glu=glu, + unet=unet, + dropout=dropout, + in_channels=in_channels, + out_channels=out_channels, + diffusion_step_embed_dim_in=diffusion_step_embed_dim_in, + diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid, + diffusion_step_embed_dim_out=diffusion_step_embed_dim_out, + label_embed_dim=label_embed_dim, + label_embed_classes=label_embed_classes, + bidirectional=bidirectional, + s4_lmax=s4_lmax, + s4_d_state=s4_d_state, + s4_dropout=s4_dropout, + s4_bidirectional=s4_bidirectional, + diffusion_time_steps=diffusion_time_steps, + beta_0=beta_0, + beta_T=beta_T, + input_size=input_size, + *args, + **kwargs, + ) + self.diffusion_parameters = calc_diffusion_hyperparams(diffusion_time_steps, beta_0, beta_T) + self.d_model = H = d_model + self.unet = unet + + def s4_block(dim, stride): + layer = S4( + d_model=dim, + l_max=s4_lmax, + d_state=s4_d_state, + bidirectional=s4_bidirectional, + postact="glu" if glu else None, + dropout=dropout, + transposed=True, + # hurwitz=True, # use the Hurwitz parameterization for stability + # tie_state=True, # tie SSM parameters across d_state in the S4 layer + trainable={ + "dt": True, + "A": True, + "P": True, + "B": True, + }, # train all internal S4 parameters + ) + + return ResidualBlock( + d_model=dim, + layer=layer, + dropout=dropout, + diffusion_step_embed_dim_out=diffusion_step_embed_dim_out, + in_channels=in_channels, + label_embed_dim=label_embed_dim, + stride=stride, + ) + + def ff_block(dim, stride): + layer = FFBlock( + d_model=dim, + expand=ff, + dropout=dropout, + ) + return ResidualBlock( + d_model=dim, + layer=layer, + dropout=dropout, + diffusion_step_embed_dim_out=diffusion_step_embed_dim_out, + in_channels=in_channels, + label_embed_dim=label_embed_dim, + stride=stride, + ) + + # Down blocks + d_layers, H = self.init_down_blocks(pool, unet, n_layers, ff, H, expand, s4_block, ff_block) + + # Center block + c_layers = self.init_center_blocks(pool, n_layers, ff, H, s4_block, ff_block) + + # Up blocks + u_layers, H = self.init_up_blocks(pool, n_layers, ff, H, expand, bidirectional, s4_block, ff_block) + + self.d_layers = nn.ModuleList(d_layers) + self.c_layers = nn.ModuleList(c_layers) + self.u_layers = nn.ModuleList(u_layers) + self.norm = nn.LayerNorm(H) + + self.init_conv = nn.Sequential(nn.Conv1d(in_channels, d_model, kernel_size=1), nn.ReLU()) + self.final_conv = nn.Sequential( + nn.Conv1d(d_model, d_model, kernel_size=1), nn.ReLU(), nn.Conv1d(d_model, out_channels, kernel_size=1) + ) + self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid) + self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out) + self.cond_embedding = nn.Embedding(label_embed_classes, label_embed_dim) + self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in + + assert H == d_model + + def init_down_blocks(self, pool, unet, n_layers, ff, H, expand, s4_block, ff_block): + d_layers = [] + for i, p in enumerate(pool): + if unet: + # Add blocks in the down layers + for _ in range(n_layers): + if i == 0: + d_layers.append(s4_block(H, 1)) + if ff > 0: + d_layers.append(ff_block(H, 1)) + elif i == 1: + d_layers.append(s4_block(H, p)) + if ff > 0: + d_layers.append(ff_block(H, p)) + # Add sequence downsampling and feature expanding + d_layers.append(DownPool(H, expand, p)) + H *= expand + return d_layers, H + + def init_up_blocks(self, pool, n_layers, ff, H, expand, bidirectional, s4_block, ff_block): + 
u_layers = [] + for i, p in enumerate(pool[::-1]): + block = [] + H //= expand + block.append(UpPool(H * expand, expand, p, causal=not bidirectional)) + + for _ in range(n_layers): + if i == 0: + block.append(s4_block(H, pool[0])) + if ff > 0: + block.append(ff_block(H, pool[0])) + + elif i == 1: + block.append(s4_block(H, 1)) + if ff > 0: + block.append(ff_block(H, 1)) + + u_layers.append(nn.ModuleList(block)) + return u_layers, H + + def init_center_blocks(self, pool, n_layers, ff, H, s4_block, ff_block): + c_layers = [] + for _ in range(n_layers): + c_layers.append(s4_block(H, pool[1] * 2)) + if ff > 0: + c_layers.append(ff_block(H, pool[1] * 2)) + return c_layers + + def on_fit_start(self) -> None: + self.diffusion_parameters = { + k: v.to(self.device) for k, v in self.diffusion_parameters.items() if isinstance(v, torch.Tensor) + } + return super().on_fit_start() + + def get_device(self): + return next(self.parameters()).device + + def forward(self, input_data): + # (transformed_X, cond, mask, diffusion_steps.view(B,1),)) + # audio_cond: same shape as audio, audio_mask: same shape as audio but binary to be imputed where zero + noise, conditional, mask, diffusion_steps = input_data + + conditional = conditional * mask + conditional = torch.cat([conditional, mask.float()], dim=1) + + diffusion_step_embed = calc_diffusion_step_embedding( + diffusion_steps, self.diffusion_step_embed_dim_in, self.get_device() + ) + diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed)) + diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed)) + + x = noise + x = self.init_conv(x) + + # Down blocks + outputs = [] + outputs.append(x) + for layer in self.d_layers: + if isinstance(layer, ResidualBlock): + x = layer((x, conditional, diffusion_step_embed)) + else: + x = layer(x) + outputs.append(x) + + # Center block + for layer in self.c_layers: + if isinstance(layer, ResidualBlock): + x = layer((x, conditional, diffusion_step_embed)) + else: + x = layer(x) + x = x + outputs.pop() # add a skip connection to the last output of the down block + + # Up blocks + for block in self.u_layers: + if self.unet: + for layer in block: + if isinstance(layer, ResidualBlock): + x = layer((x, conditional, diffusion_step_embed)) + else: + x = layer(x) + x = x + outputs.pop() # skip connection + else: + for layer in block: + if isinstance(layer, ResidualBlock): + x = layer((x, conditional, diffusion_step_embed)) + else: + x = layer(x) + if isinstance(layer, UpPool): + # Before modeling layer in the block + x = x + outputs.pop() + outputs.append(x) + x = x + outputs.pop() # add a skip connection from the input of the modeling part of this up block + + # feature projection + x = x.transpose(1, 2) # (batch, length, expand) + x = self.norm(x).transpose(1, 2) # (batch, expand, length) + + x = self.final_conv(x) # 128 to 12 + return x + + def step_fn(self, batch, step_prefix=""): + amputated_data, amputation_mask, target, target_missingness = batch + + amputated_data = torch.nan_to_num(amputated_data).permute(0, 2, 1) + amputation_mask = amputation_mask.permute(0, 2, 1).bool() + + padding_size = next_power(amputated_data.shape[2]) - amputated_data.shape[2] + amputated_data = torch.cat( + [ + amputated_data, + torch.zeros((amputated_data.shape[0], amputated_data.shape[1], padding_size), device=self.device), + ], + dim=2, + ) + amputation_mask = torch.cat( + [ + amputation_mask, + torch.zeros( + (amputation_mask.shape[0], amputation_mask.shape[1], padding_size), device=self.device, dtype=bool + ), + ], + dim=2, + ) + + 
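+
+# --- Illustrative aside (not part of this PR): the padding above rounds the
+# sequence length up to the next power of two so that the repeated down-pooling
+# always divides the length evenly. next_power (defined at the bottom of this
+# file) is a bit-length trick:
+#
+#   def next_power(number):
+#       return 1 << (number - 1).bit_length()
+#
+#   next_power(25)  # -> 32, so a length-25 series receives 7 zero-padded steps
+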
observed_mask = 1 - amputation_mask.float() + amputation_mask = amputation_mask.bool() + + if step_prefix in ["train", "val"]: + T, Alpha_bar = self.hparams.diffusion_time_steps, self.diffusion_parameters["Alpha_bar"] + + B, C, L = amputated_data.shape # B is batchsize, C=1, L is audio length + diffusion_steps = torch.randint(T, size=(B, 1, 1)).to(self.device) # randomly sample diffusion steps from 1~T + + z = std_normal(amputated_data.shape, self.device) + z = amputated_data * observed_mask.float() + z * (1 - observed_mask).float() + transformed_X = ( + torch.sqrt(Alpha_bar[diffusion_steps]) * amputated_data + torch.sqrt(1 - Alpha_bar[diffusion_steps]) * z + ) # compute x_t from q(x_t|x_0) + epsilon_theta = self( + ( + transformed_X, + amputated_data, + observed_mask, + diffusion_steps.view(B, 1), + ) + ) # predict \epsilon according to \epsilon_\theta + + loss = self.loss(epsilon_theta[amputation_mask], z[amputation_mask]) + else: + target = target.permute(0, 2, 1) + target_missingness = target_missingness.permute(0, 2, 1) + target = torch.cat( + [target, torch.zeros((target.shape[0], target.shape[1], padding_size), device=self.device)], dim=2 + ) + target_missingness = torch.cat( + [ + target_missingness, + torch.zeros((target_missingness.shape[0], target_missingness.shape[1], padding_size), device=self.device), + ], + dim=2, + ) + imputed_data = self.sampling(amputated_data, observed_mask) + amputated_data[amputation_mask] = imputed_data[amputation_mask] + amputated_data[target_missingness > 0] = target[target_missingness > 0] + loss = self.loss(amputated_data, target) + for metric in self.metrics[step_prefix].values(): + metric.update((torch.flatten(amputated_data, start_dim=1).clone(), torch.flatten(target, start_dim=1).clone())) + + self.log(f"{step_prefix}/loss", loss.item(), prog_bar=True) + return loss + + def default_state(self, *args, **kwargs): + layers = list(self.d_layers) + list(self.c_layers) + [layer for block in self.u_layers for layer in block] + return [layer.default_state(*args, **kwargs) for layer in layers] + + def step(self, x, state, **kwargs): + """ + input: (batch, d_input) + output: (batch, d_output) + """ + # States will be popped in reverse order for convenience + state = state[::-1] + + # Down blocks + outputs = [] # Store all layers for SaShiMi + next_state = [] + for layer in self.d_layers: + outputs.append(x) + x, _next_state = layer.step(x, state=state.pop(), **kwargs) + next_state.append(_next_state) + if x is None: + break + + # Center block + if x is None: + # Skip computations since we've downsized + skipped = len(self.d_layers) - len(outputs) + for _ in range(skipped + len(self.c_layers)): + next_state.append(state.pop()) + if self.unet: + for i in range(skipped): + next_state.append(state.pop()) + u_layers = list(self.u_layers)[skipped // 3:] + else: + for i in range(skipped): + for _ in range(len(self.u_layers[i])): + next_state.append(state.pop()) + u_layers = list(self.u_layers)[skipped:] + else: + outputs.append(x) + for layer in self.c_layers: + x, _next_state = layer.step(x, state=state.pop(), **kwargs) + next_state.append(_next_state) + x = x + outputs.pop() + u_layers = self.u_layers + + x, next_state = self.up_blocks_loop(x, u_layers, next_state, state, outputs, **kwargs) + + # feature projection + x = self.norm(x) + return x, next_state + + def up_blocks_loop(self, x, u_layers, next_state, state, outputs, **kwargs): + for block in u_layers: + if self.unet: + for layer in block: + x, _next_state = layer.step(x, state=state.pop(), **kwargs) 
+ next_state.append(_next_state) + x = x + outputs.pop() + else: + for layer in block: + x, _next_state = layer.step(x, state=state.pop(), **kwargs) + next_state.append(_next_state) + if isinstance(layer, UpPool): + # Before modeling layer in the block + x = x + outputs.pop() + outputs.append(x) + x = x + outputs.pop() + return x, next_state + + def setup_rnn(self, mode="dense"): + """ + Convert the SaShiMi model to a RNN for autoregressive generation. + + Args: + mode: S4 recurrence mode. Using `diagonal` can speed up generation by 10-20%. + `linear` should be faster theoretically but is slow in practice since it + dispatches more operations (could benefit from fused operations). + Note that `diagonal` could potentially be unstable if the diagonalization is numerically unstable + (although we haven't encountered this case in practice), while `dense` should always be stable. + """ + assert mode in ["dense", "diagonal", "linear"] + for module in self.modules(): + if hasattr(module, "setup_step"): + module.setup_step(mode) + + def sampling(self, cond, mask): + """ + Perform the complete sampling step according to p(x_0|x_T) = prod_{t=1}^T p_{\theta}(x_{t-1}|x_t) + + Parameters: + net (torch network): the wavenet model + size (tuple): size of tensor to be generated, + usually is (number of audios to generate, channels=1, length of audio) + diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams + note, the tensors need to be cuda tensors + + Returns: + the generated audio(s) in torch.tensor, shape=size + """ + + Alpha, Alpha_bar, Sigma = ( + self.diffusion_parameters["Alpha"], + self.diffusion_parameters["Alpha_bar"], + self.diffusion_parameters["Sigma"], + ) + + T = self.hparams.diffusion_time_steps + assert len(Alpha) == T + assert len(Alpha_bar) == T + assert len(Sigma) == T + + B, _, _ = cond.shape + x = std_normal(cond.shape, self.device) + + for t in range(T - 1, -1, -1): + x = x * (1 - mask).float() + cond * mask.float() + diffusion_steps = (t * torch.ones((B, 1))).to(self.device) # use the corresponding reverse step + epsilon_theta = self( + ( + x, + cond, + mask, + diffusion_steps, + ) + ) # predict \epsilon according to \epsilon_\theta + # update x_{t-1} to \mu_\theta(x_t) + x = (x - (1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta) / torch.sqrt(Alpha[t]) + if t > 0: + x = x + Sigma[t] * std_normal(cond.shape, self.device) # add the variance term to x_{t-1} + + return x + + +def swish(x): + return x * torch.sigmoid(x) + + +def std_normal(size, device): + """ + Generate the standard Gaussian variable of a certain size + """ + + return torch.normal(0, 1, size=size).to(device) + + +def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in, device): + """ + Embed a diffusion step $t$ into a higher dimensional space + E.g. the embedding vector in the 128-dimensional space is + [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... 
, cos(t * 10^(63*4/63))] + Parameters: + diffusion_steps (torch.long tensor, shape=(batchsize, 1)): + diffusion steps for batch data + diffusion_step_embed_dim_in (int, default=128): + dimensionality of the embedding space for discrete diffusion steps + + Returns: + the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)): + """ + + assert diffusion_step_embed_dim_in % 2 == 0 + + half_dim = diffusion_step_embed_dim_in // 2 + _embed = np.log(10000) / (half_dim - 1) + _embed = torch.exp(torch.arange(half_dim) * -_embed).to(device) + _embed = diffusion_steps * _embed + diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1) + + return diffusion_step_embed + + +class Conv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1): + super(Conv, self).__init__() + self.padding = dilation * (kernel_size - 1) // 2 + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding, stride=stride) + + self.conv = nn.utils.weight_norm(self.conv) + nn.init.kaiming_normal_(self.conv.weight) + + def forward(self, x): + out = self.conv(x) + return out + + +class DownPool(nn.Module): + def __init__(self, d_input, expand, pool): + super().__init__() + self.d_input = d_input + self.d_output = d_input * expand + self.pool = pool + + self.linear = LinearActivation( + d_input * pool, + self.d_output, + transposed=True, + weight_norm=True, + ) + + def forward(self, x): + x = rearrange(x, "... h (l s) -> ... (h s) l", s=self.pool) + x = self.linear(x) + return x + + def step(self, x, state, **kwargs): + """ + x: (..., H) + """ + + if x is None: + return None, state + state.append(x) + if len(state) == self.pool: + x = rearrange(torch.stack(state, dim=-1), "... h s -> ... (h s)") + x = x.unsqueeze(-1) + x = self.linear(x) + x = x.squeeze(-1) + return x, [] + else: + return None, state + + def default_state(self, *args, **kwargs): + return [] + + +class UpPool(nn.Module): + def __init__(self, d_input, expand, pool, causal=True): + super().__init__() + self.d_output = d_input // expand + self.pool = pool + self.causal = causal + + self.linear = LinearActivation( + d_input, + self.d_output * pool, + transposed=True, + weight_norm=True, + ) + + def forward(self, x): + x = self.linear(x) + + if self.causal: + x = F.pad(x[..., :-1], (1, 0)) # Shift to ensure causality + x = rearrange(x, "... (h s) l -> ... h (l s)", s=self.pool) + + return x + + def step(self, x, state, **kwargs): + """ + x: (..., H) + """ + assert len(state) > 0 + y, state = state[0], state[1:] + if len(state) == 0: + assert x is not None + x = x.unsqueeze(-1) + x = self.linear(x) + x = x.squeeze(-1) + x = rearrange(x, "... (h s) -> ... h s", s=self.pool) + state = list(torch.unbind(x, dim=-1)) + else: + assert x is None + return y, state + + def default_state(self, *batch_shape, device=None): + state = torch.zeros(batch_shape + (self.d_output, self.pool), device=device) # (batch, h, s) + state = list(torch.unbind(state, dim=-1)) # List of (..., H) + return state + + +class FFBlock(nn.Module): + def __init__(self, d_model, expand=2, dropout=0.0): + """ + Feed-forward block. 
+ + Args: + d_model: dimension of input + expand: expansion factor for inverted bottleneck + dropout: dropout rate + """ + super().__init__() + + input_linear = LinearActivation( + d_model, + d_model * expand, + transposed=True, + activation="gelu", + activate=True, + ) + dropout = nn.Dropout2d(dropout) if dropout > 0.0 else nn.Identity() + output_linear = LinearActivation( + d_model * expand, + d_model, + transposed=True, + activation=None, + activate=False, + ) + + self.ff = nn.Sequential( + input_linear, + dropout, + output_linear, + ) + + def forward(self, x): + return self.ff(x), None + + def default_state(self, *args, **kwargs): + return None + + def step(self, x, state, **kwargs): + # expects: (B, D, L) + return self.ff(x.unsqueeze(-1)).squeeze(-1), state + + +class ResidualBlock(nn.Module): + def __init__(self, d_model, layer, dropout, diffusion_step_embed_dim_out, in_channels, label_embed_dim, stride): + """ + Residual S4 block. + + Args: + d_model: dimension of the model + bidirectional: use bidirectional S4 layer + glu: use gated linear unit in the S4 layer + dropout: dropout rate + """ + super().__init__() + + self.layer = layer + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout2d(dropout) if dropout > 0.0 else nn.Identity() + + self.fc_t = nn.Linear(diffusion_step_embed_dim_out, d_model) + self.cond_conv = Conv(2 * in_channels, d_model, kernel_size=stride, stride=stride) + self.fc_label = nn.Linear(label_embed_dim, d_model) if label_embed_dim is not None else None + + def forward(self, input_data): + """ + Input x is shape (B, d_input, L) + """ + x, cond, diffusion_step_embed = input_data + + # add in diffusion step embedding + part_t = self.fc_t(diffusion_step_embed).unsqueeze(2) + z = x + part_t + + # Prenorm + z = self.norm(z.transpose(-1, -2)).transpose(-1, -2) + + z, _ = self.layer(z) + + cond = self.cond_conv(cond) + # cond = self.fc_label(cond) + + z = z + cond + + # Dropout on the output of the layer + z = self.dropout(z) + + # Residual connection + x = z + x + + return x + + def default_state(self, *args, **kwargs): + return self.layer.default_state(*args, **kwargs) + + def step(self, x, state, **kwargs): + z = x + + # Prenorm + z = self.norm(z) + + # Apply layer + z, state = self.layer.step(z, state, **kwargs) + + # Residual connection + x = z + x + + return x, state + + +def largets_component(number): + """ + returns the prime number that divides number the most times. + """ + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + return i + return number + + +def next_power(number): + """ + returns the next power of 2. + """ + return 1 << (number - 1).bit_length() + + +def calc_diffusion_hyperparams(diffusion_time_steps, beta_0, beta_T): + """ + Compute diffusion process hyperparameters + + Params: + T (int): number of diffusion steps. 
+ beta_0 beta_T (float) :beta schedule start/end value, where any beta_t in the middle is linearly interpolated + + Returns: + a dictionary of diffusion hyperparameters including: + T (int), Beta/Alpha/Alpha_bar/Sigma (torch.tensor on cpu, shape=(T, )) + These cpu tensors are changed to cuda tensors on each individual gpu + """ + + Beta = torch.linspace(beta_0, beta_T, diffusion_time_steps) # Linear schedule + Alpha = 1 - Beta + Alpha_bar = Alpha + 0 + Beta_tilde = Beta + 0 + for t in range(1, diffusion_time_steps): + Alpha_bar[t] *= Alpha_bar[t - 1] # \bar{\alpha}_t = \prod_{s=1}^t \alpha_s + Beta_tilde[t] *= (1 - Alpha_bar[t - 1]) / (1 - Alpha_bar[t]) # \tilde{\beta}_t = \beta_t * (1-\bar{\alpha}_{t-1}) + # / (1-\bar{\alpha}_t) + Sigma = torch.sqrt(Beta_tilde) # \sigma_t^2 = \tilde{\beta}_t + + _dh = {} + _dh["T"], _dh["Beta"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"] = diffusion_time_steps, Beta, Alpha, Alpha_bar, Sigma + diffusion_hyperparams = _dh + return diffusion_hyperparams diff --git a/icu_benchmarks/models/metric_constants.py b/icu_benchmarks/models/constants.py similarity index 65% rename from icu_benchmarks/models/metric_constants.py rename to icu_benchmarks/models/constants.py index 6fcb9268..70ba5b65 100644 --- a/icu_benchmarks/models/metric_constants.py +++ b/icu_benchmarks/models/constants.py @@ -1,5 +1,5 @@ from ignite.contrib.metrics import AveragePrecision, ROC_AUC, PrecisionRecallCurve, RocCurve -from ignite.metrics import MeanAbsoluteError, Accuracy # , ConfusionMatrix +from ignite.metrics import Accuracy, RootMeanSquaredError # , ConfusionMatrix from sklearn.calibration import calibration_curve from sklearn.metrics import ( average_precision_score, @@ -14,8 +14,9 @@ mean_squared_error, # f1_score, ) +from enum import Enum -from icu_benchmarks.models.metrics import CalibrationCurve, BalancedAccuracy +from icu_benchmarks.models.metrics import CalibrationCurve, BalancedAccuracy, MAE, JSD # TODO: revise transformation for metrics in wrappers.py in order to handle metrics that can not handle a mix of binary and @@ -27,8 +28,8 @@ class MLMetrics: # "Confusion_Matrix": confusion_matrix, # "F1": f1_score, "PR": average_precision_score, - "PRC": precision_recall_curve, - "ROC": roc_curve, + "PR_Curve": precision_recall_curve, + "RO_Curve": roc_curve, } MULTICLASS_CLASSIFICATION = { @@ -50,19 +51,35 @@ class MLMetrics: # TODO: add support for confusion matrix class DLMetrics: BINARY_CLASSIFICATION = { - "AUC": ROC_AUC(), - "Calibration_Curve": CalibrationCurve(), + "AUC": ROC_AUC, + "Calibration_Curve": CalibrationCurve, # "Confusion_Matrix": ConfusionMatrix(num_classes=2), - "PR": AveragePrecision(), - "PRC": PrecisionRecallCurve(), - "ROC": RocCurve(), + "PR": AveragePrecision, + "PR_Curve": PrecisionRecallCurve, + "RO_Curve": RocCurve, } MULTICLASS_CLASSIFICATION = { - "Accuracy": Accuracy(), - "BalancedAccuracy": BalancedAccuracy(), + "Accuracy": Accuracy, + "BalancedAccuracy": BalancedAccuracy, } REGRESSION = { - "MAE": MeanAbsoluteError(), + "MAE": MAE, } + + IMPUTATION = { + "rmse": RootMeanSquaredError, + "mae": MAE, + "jsd": JSD, + } + + +class ImputationInit(str, Enum): + """Type of initialization to use for the imputation model.""" + + NORMAL = "normal" + UNIFORM = "uniform" + XAVIER = "xavier" + KAIMING = "kaiming" + ORTHOGONAL = "orthogonal" diff --git a/icu_benchmarks/models/encoders.py b/icu_benchmarks/models/dl_models.py similarity index 57% rename from icu_benchmarks/models/encoders.py rename to icu_benchmarks/models/dl_models.py index 8abe058b..5756e1af 
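+
+# --- Illustrative aside (not part of this PR): the ImputationInit enum added above
+# lends itself to a small dispatch table onto torch.nn.init. A hypothetical helper,
+# not present in this diff, shown as a sketch:
+import torch.nn as nn
+
+INIT_FNS = {
+    ImputationInit.NORMAL: nn.init.normal_,
+    ImputationInit.UNIFORM: nn.init.uniform_,
+    ImputationInit.XAVIER: nn.init.xavier_uniform_,
+    ImputationInit.KAIMING: nn.init.kaiming_uniform_,
+    ImputationInit.ORTHOGONAL: nn.init.orthogonal_,
+}
+
+def init_weights(module, scheme: ImputationInit):
+    if isinstance(module, nn.Linear):
+        INIT_FNS[scheme](module.weight)
+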
100644 --- a/icu_benchmarks/models/encoders.py +++ b/icu_benchmarks/models/dl_models.py @@ -2,17 +2,45 @@ from numbers import Integral import numpy as np import torch.nn as nn - from icu_benchmarks.models.layers import TransformerBlock, LocalBlock, TemporalBlock, PositionalEncoding +from icu_benchmarks.models.wrappers import DLPredictionWrapper + + +@gin.configurable +class RNNet(DLPredictionWrapper): + """Torch standard RNN model""" + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs + ) + self.hidden_dim = hidden_dim + self.layer_dim = layer_dim + self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.logit = nn.Linear(hidden_dim, num_classes) + + def init_hidden(self, x): + h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + return h0 + + def forward(self, x): + h0 = self.init_hidden(x) + out, hn = self.rnn(x, h0) + pred = self.logit(out) + return pred @gin.configurable -class LSTMNet(nn.Module): - def __init__(self, input_dim, hidden_dim, layer_dim, num_classes): - super().__init__() +class LSTMNet(DLPredictionWrapper): + """Torch standard LSTM model.""" + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs + ) self.hidden_dim = hidden_dim self.layer_dim = layer_dim - self.rnn = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True) + self.rnn = nn.LSTM(input_size[2], hidden_dim, layer_dim, batch_first=True) self.logit = nn.Linear(hidden_dim, num_classes) def init_hidden(self, x): @@ -28,12 +56,16 @@ def forward(self, x): @gin.configurable -class GRUNet(nn.Module): - def __init__(self, input_dim, hidden_dim, layer_dim, num_classes): - super().__init__() +class GRUNet(DLPredictionWrapper): + """Torch standard GRU model.""" + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs + ) self.hidden_dim = hidden_dim self.layer_dim = layer_dim - self.rnn = nn.GRU(input_dim, hidden_dim, layer_dim, batch_first=True) + self.rnn = nn.GRU(input_size[2], hidden_dim, layer_dim, batch_first=True) self.logit = nn.Linear(hidden_dim, num_classes) def init_hidden(self, x): @@ -49,23 +81,40 @@ def forward(self, x): @gin.configurable -class Transformer(nn.Module): +class Transformer(DLPredictionWrapper): + """Transformer model as defined by the HiRID-Benchmark (https://github.com/ratschlab/HIRID-ICU-Benchmark).""" + def __init__( self, - input_dim, + input_size, hidden, heads, ff_hidden_mult, depth, num_classes, + *args, dropout=0.0, l1_reg=0, pos_encoding=True, dropout_att=0.0, + **kwargs, ): - super().__init__() + super().__init__( + input_size=input_size, + hidden=hidden, + heads=heads, + ff_hidden_mult=ff_hidden_mult, + depth=depth, + num_classes=num_classes, + *args, + dropout=dropout, + l1_reg=l1_reg, + pos_encoding=pos_encoding, + dropout_att=dropout_att, + **kwargs, + ) hidden = hidden if hidden % 2 == 0 else hidden + 1 # Make sure hidden is even - self.input_embedding = nn.Linear(input_dim, hidden) # This acts as a time-distributed layer by defaults + self.input_embedding = nn.Linear(input_size[2], hidden) # This acts as a time-distributed layer by defaults if 
pos_encoding: self.pos_encoder = PositionalEncoding(hidden) else: @@ -100,25 +149,41 @@ def forward(self, x): @gin.configurable -class LocalTransformer(nn.Module): +class LocalTransformer(DLPredictionWrapper): def __init__( self, - input_dim, + input_size, hidden, heads, ff_hidden_mult, depth, num_classes, + *args, dropout=0.0, l1_reg=0, pos_encoding=True, local_context=1, dropout_att=0.0, + **kwargs, ): - super().__init__() + super().__init__( + input_size=input_size, + hidden=hidden, + heads=heads, + ff_hidden_mult=ff_hidden_mult, + depth=depth, + num_classes=num_classes, + *args, + dropout=dropout, + l1_reg=l1_reg, + pos_encoding=pos_encoding, + local_context=local_context, + dropout_att=dropout_att, + **kwargs, + ) hidden = hidden if hidden % 2 == 0 else hidden + 1 # Make sure hidden is even - self.input_embedding = nn.Linear(input_dim, hidden) # This acts as a time-distributed layer by defaults + self.input_embedding = nn.Linear(input_size[2], hidden) # This acts as a time-distributed layer by defaults if pos_encoding: self.pos_encoder = PositionalEncoding(hidden) else: @@ -153,11 +218,21 @@ def forward(self, x): return pred -# From TCN original paper https://github.com/locuslab/TCN @gin.configurable -class TemporalConvNet(nn.Module): - def __init__(self, input_dim, num_channels, num_classes, max_seq_length=0, kernel_size=2, dropout=0.0): - super(TemporalConvNet, self).__init__() +class TemporalConvNet(DLPredictionWrapper): + """Temporal Convolutional Network. Adapted from TCN original paper https://github.com/locuslab/TCN""" + + def __init__(self, input_size, num_channels, num_classes, *args, max_seq_length=0, kernel_size=2, dropout=0.0, **kwargs): + super().__init__( + input_size=input_size, + num_channels=num_channels, + num_classes=num_classes, + *args, + max_seq_length=max_seq_length, + kernel_size=kernel_size, + dropout=dropout, + **kwargs, + ) layers = [] # We compute automatically the depth based on the desired seq_length. @@ -169,7 +244,7 @@ def __init__(self, input_dim, num_channels, num_classes, max_seq_length=0, kerne num_levels = len(num_channels) for i in range(num_levels): dilation_size = 2**i - in_channels = input_dim if i == 0 else num_channels[i - 1] + in_channels = input_size[2] if i == 0 else num_channels[i - 1] out_channels = num_channels[i] layers += [ TemporalBlock( diff --git a/icu_benchmarks/models/layers.py b/icu_benchmarks/models/layers.py index da93a419..c08623bd 100644 --- a/icu_benchmarks/models/layers.py +++ b/icu_benchmarks/models/layers.py @@ -35,7 +35,7 @@ def parallel_recomb(q_t, kv_t, att_type="all", local_context=3, bin_size=None): class PositionalEncoding(nn.Module): - "Positional Encoding, mostly from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html" + """Positional Encoding, mostly from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html""" def __init__(self, emb, max_len=3000): super().__init__() @@ -53,25 +53,18 @@ def forward(self, x): class SelfAttention(nn.Module): - """Multi Head Attention block from Attention is All You Need. - Input has shape (batch_size, n_timestamps, emb). - - ---------- - emb: - Dimension of the input vector. - hidden: - Dimension of query, key, value matrices. - heads: - Number of heads. - - mask: - Mask the future timestamps - """ + """Multi Head Attention block from Attention is All You Need (https://arxiv.org/abs/1706.03762). 
Input has shape
+    (batch_size, n_timestamps, emb)."""
 
     def __init__(
         self, emb, hidden, heads=8, mask=True, att_type="all", local_context=None, mask_aggregation="union", dropout_att=0.0
     ):
-        """Initialize the Multi Head Block."""
+        """Initialize the Multi Head Block.
+        Args:
+            emb: Dimension of the input vector.
+            hidden: Hidden dimension of the query, key, and value matrices.
+            heads: Number of heads.
+            mask: Whether to mask the attention matrix."""
         super().__init__()
 
         self.emb = emb
@@ -95,13 +88,10 @@ def __init__(
 
     def forward(self, x):
         """
-        x:
-            Input data tensor with shape (batch_size, n_timestemps, emb)
-        hidden:
-            Hidden dim (dimension of query, key, value matrixes)
-
-        Returns
-            Self attention tensor with shape (batch_size, n_timestemps, emb)
+        Args:
+            x: Input data tensor with shape (batch_size, n_timestamps, emb)
+        Returns:
+            Self attention tensor with shape (batch_size, n_timestamps, emb)
         """
         # bs - batch_size, n - vectors number, emb - embedding dimensionality
         bs, n, emb = x.size()
@@ -139,7 +129,6 @@ def forward(self, x):
                 dot = torch.where(mask_tensor.bool(), dot, torch.tensor(float("-inf")).to(dot.device)).view(bs * h, n, n)
 
         elif self.mask_aggregation == "split":
-
             dot_list = list(torch.split(dot, dot.shape[0] // len(self.att_type), dim=0))
             for i, att_type in enumerate(self.att_type):
                 mask_tensor = parallel_recomb(
@@ -280,8 +269,9 @@ def forward(self, x):
         return x
 
 
-# From TCN original paper https://github.com/locuslab/TCN
 class Chomp1d(nn.Module):
+    """From TCN original paper https://github.com/locuslab/TCN"""
+
     def __init__(self, chomp_size):
         super(Chomp1d, self).__init__()
         self.chomp_size = chomp_size
diff --git a/icu_benchmarks/models/metrics.py b/icu_benchmarks/models/metrics.py
index 8f81f3e6..3592277a 100644
--- a/icu_benchmarks/models/metrics.py
+++ b/icu_benchmarks/models/metrics.py
@@ -2,6 +2,13 @@
 from typing import Callable
 import numpy as np
 from ignite.metrics import EpochMetric
+from sklearn.metrics import balanced_accuracy_score, mean_absolute_error
+from sklearn.calibration import calibration_curve
+from scipy.spatial.distance import jensenshannon
+"""
+This file contains metrics that are not available in ignite.metrics. Specifically, it adds transformation capabilities to some
+metrics.
+"""
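The classes in this file all follow the same pattern: ignite's `EpochMetric` buffers predictions and targets over an epoch and hands the concatenated tensors to a single sklearn-style compute function. A minimal standalone sketch of that pattern (toy tensors, illustrative only):

```python
import torch
from ignite.metrics import EpochMetric
from sklearn.metrics import balanced_accuracy_score


def my_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float:
    # EpochMetric hands over the accumulated epoch tensors in one call
    return balanced_accuracy_score(y_targets.numpy(), y_preds.argmax(dim=-1).numpy())


metric = EpochMetric(my_compute_fn, check_compute_fn=False)
metric.update((torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 1])))
print(metric.compute())  # 1.0 on this toy batch
```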
+""" def accuracy(output, target, topk=(1,)): @@ -22,38 +29,27 @@ def accuracy(output, target, topk=(1,)): def balanced_accuracy_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float: - try: - from sklearn.metrics import balanced_accuracy_score - except ImportError: - raise RuntimeError("This contrib module requires sklearn to be installed.") - y_true = y_targets.numpy() y_pred = np.argmax(y_preds.numpy(), axis=-1) return balanced_accuracy_score(y_true, y_pred) def ece_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float: - try: - from sklearn.calibration import calibration_curve - except ImportError: - raise RuntimeError("This contrib module requires sklearn to be installed.") - y_true = y_targets.numpy() y_pred = y_preds.numpy() return calibration_curve(y_true, y_pred, n_bins=10) def mae_with_invert_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float: - try: - from sklearn.metrics import mean_absolute_error - except ImportError: - raise RuntimeError("This contrib module requires sklearn to be installed.") - y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0] y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0] return mean_absolute_error(y_true, y_pred) +def JSD_fn(y_preds: torch.Tensor, y_targets: torch.Tensor): + return jensenshannon(abs(y_preds).flatten(), abs(y_targets).flatten()) ** 2 + + class BalancedAccuracy(EpochMetric): def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: super(BalancedAccuracy, self).__init__( @@ -80,3 +76,16 @@ def __init__( output_transform=output_transform, check_compute_fn=check_compute_fn, ) + + +class JSD(EpochMetric): + def __init__( + self, + output_transform: Callable = lambda x: x, + check_compute_fn: bool = False, + ) -> None: + super(JSD, self).__init__( + lambda x, y: JSD_fn(x, y), + output_transform=output_transform, + check_compute_fn=check_compute_fn, + ) diff --git a/icu_benchmarks/models/ml_models.py b/icu_benchmarks/models/ml_models.py new file mode 100644 index 00000000..e06d3fe7 --- /dev/null +++ b/icu_benchmarks/models/ml_models.py @@ -0,0 +1,124 @@ +import gin +import lightgbm +from sklearn import linear_model +from sklearn import ensemble +from sklearn import neural_network +from sklearn import svm +from icu_benchmarks.models.wrappers import MLWrapper +from icu_benchmarks.contants import RunMode + + +class LGBMWrapper(MLWrapper): + def fit_model(self, train_data, train_labels, val_data, val_labels): + """Fitting function for LGBM models.""" + self.model.fit( + train_data, + train_labels, + eval_set=(val_data, val_labels), + verbose=True, + callbacks=[ + lightgbm.early_stopping(self.hparams.patience, verbose=False), + lightgbm.log_evaluation(period=-1, show_stdv=False), + ], + ) + val_loss = list(self.model.best_score_["valid_0"].values())[0] + return val_loss + + +@gin.configurable +class LGBMClassifier(LGBMWrapper): + _supported_run_modes = [RunMode.classification] + + def __init__(self, *args, **kwargs): + self.model = self.set_model_args(lightgbm.LGBMClassifier, *args, **kwargs) + super().__init__(*args, **kwargs) + + +@gin.configurable +class LGBMRegressor(LGBMWrapper): + _supported_run_modes = [RunMode.regression] + + def __init__(self, *args, **kwargs): + self.model = self.set_model_args(lightgbm.LGBMRegressor, *args, **kwargs) + super().__init__(*args, **kwargs) + + +# Scikit-learn models +@gin.configurable +class LogisticRegression(MLWrapper): + __supported_run_modes = [RunMode.classification] + + def 
+# Scikit-learn models
+@gin.configurable
+class LogisticRegression(MLWrapper):
+    _supported_run_modes = [RunMode.classification]
+
+    def __init__(self, *args, **kwargs):
+        self.model = self.set_model_args(linear_model.LogisticRegression, *args, **kwargs)
+        super().__init__(*args, **kwargs)
+
+
+@gin.configurable()
+class LinearRegression(MLWrapper):
+    _supported_run_modes = [RunMode.regression]
+
+    def __init__(self, *args, **kwargs):
+        self.model = self.set_model_args(linear_model.LinearRegression, *args, **kwargs)
+        super().__init__(*args, **kwargs)
+
+
+@gin.configurable()
+class ElasticNet(MLWrapper):
+    _supported_run_modes = [RunMode.regression]
+
+    def __init__(self, *args, **kwargs):
+        self.model = self.set_model_args(linear_model.ElasticNet, *args, **kwargs)
+        super().__init__(*args, **kwargs)
+
+
+@gin.configurable
+class RFClassifier(MLWrapper):
+    _supported_run_modes = [RunMode.classification]
+
+    def __init__(self, *args, **kwargs):
+        self.model = self.set_model_args(ensemble.RandomForestClassifier, *args, **kwargs)
+        super().__init__(*args, **kwargs)
+
+
+@gin.configurable
+class SVMClassifier(MLWrapper):
+    _supported_run_modes = [RunMode.classification]
+
+    def __init__(self, *args, **kwargs):
+        self.model = self.set_model_args(svm.SVC, *args, **kwargs)
+        super().__init__(*args, **kwargs)
+
+
+@gin.configurable
+class SVMRegressor(MLWrapper):
+    _supported_run_modes = [RunMode.regression]
+
+    def __init__(self, *args, **kwargs):
+        self.model = self.set_model_args(svm.SVR, *args, **kwargs)
+        super().__init__(*args, **kwargs)
+
+
+@gin.configurable
+class PerceptronClassifier(MLWrapper):
+    _supported_run_modes = [RunMode.classification]
+
+    def __init__(self, *args, **kwargs):
+        self.model = self.set_model_args(neural_network.MLPClassifier, *args, **kwargs)
+        super().__init__(*args, **kwargs)
+
+
+@gin.configurable
+class MLPClassifier(MLWrapper):
+    _supported_run_modes = [RunMode.classification]
+
+    def __init__(self, *args, **kwargs):
+        self.model = self.set_model_args(neural_network.MLPClassifier, *args, **kwargs)
+        super().__init__(*args, **kwargs)
+
+
+@gin.configurable
+class MLPRegressor(MLWrapper):
+    _supported_run_modes = [RunMode.regression]
+
+    def __init__(self, *args, **kwargs):
+        self.model = self.set_model_args(neural_network.MLPRegressor, *args, **kwargs)
+        super().__init__(*args, **kwargs)
diff --git a/icu_benchmarks/models/train.py b/icu_benchmarks/models/train.py
index 7e4d730f..1c2a91cd 100644
--- a/icu_benchmarks/models/train.py
+++ b/icu_benchmarks/models/train.py
@@ -1,16 +1,26 @@
 import os
-import random
-import sys
 import gin
 import torch
 import logging
-import numpy as np
 import pandas as pd
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
 from pathlib import Path
+from icu_benchmarks.data.loader import PredictionDataset, ImputationDataset
+from icu_benchmarks.models.utils import save_config_file, JSONMetricsLogger
+from icu_benchmarks.contants import RunMode
+from icu_benchmarks.data.constants import DataSplit as Split
 
-from icu_benchmarks.data.loader import SICUDataset
-from icu_benchmarks.models.wrappers import MLWrapper
-from icu_benchmarks.models.utils import save_config_file
+cpu_core_count = len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else os.cpu_count()
+
+
+def assure_minimum_length(dataset):
+    if len(dataset) < 2:
+        return [dataset[0], dataset[0]]
+    return dataset
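Since the loaders below use `drop_last=True`, a singleton dataset would otherwise yield zero batches; `assure_minimum_length` simply duplicates the lone sample. A toy illustration (not part of the patch):

```python
# Toy stand-in for a dataset with a single sample:
singleton = ["only_sample"]

padded = assure_minimum_length(singleton)
print(len(padded))  # 2 -> a DataLoader with batch_size=2 and drop_last=True still yields one batch
```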
@@ -19,11 +29,22 @@ def train_common(
     data: dict[str, pd.DataFrame],
     log_dir: Path,
     load_weights: bool = False,
     source_dir: Path = None,
-    seed: int = 1234,
     reproducible: bool = True,
-    model: object = MLWrapper,
+    mode: str = RunMode.classification,
+    model: object = gin.REQUIRED,
     weight: str = None,
-    test_on: str = "test",
+    optimizer: type = Adam,
+    precision=32,
+    batch_size=64,
+    epochs=1000,
+    patience=20,
+    min_delta=1e-5,
+    test_on: str = Split.test,
+    use_wandb: bool = False,
+    cpu: bool = False,
+    verbose=False,
+    ram_cache=False,
+    num_workers: int = min(cpu_core_count, torch.cuda.device_count() * 4 * int(torch.cuda.is_available()), 32),
 ):
     """Common wrapper to train all benchmarked models.
 
@@ -32,48 +53,120 @@ def train_common(
         log_dir: Path to directory where model output should be saved.
         load_weights: If set to true, skip training and load weights from source_dir instead.
         source_dir: If set to load weights, path to directory containing trained weights.
-        seed: Common seed used for any random operation.
         reproducible: If set to true, set torch to run reproducibly.
+        mode: Mode of the model. Can be one of the values of RunMode.
+        model: Model to be trained.
+        weight: Weight to be used for the loss function.
+        optimizer: Optimizer to be used for training.
+        precision: PyTorch precision to be used for training. Can be 16 or 32.
+        batch_size: Batch size to be used for training.
+        epochs: Number of epochs to train for.
+        patience: Number of epochs to wait before early stopping.
+        min_delta: Minimum change in loss to be considered an improvement.
+        test_on: If set to "test", evaluate the model on the test set. If set to "val", evaluate on the validation set.
+        use_wandb: If set to true, log to wandb.
+        cpu: If set to true, run on cpu.
+        verbose: Enable detailed logging.
+        ram_cache: Whether to cache the data in RAM.
+        num_workers: Number of workers to use for data loading.
""" - # Setting the seed before gin parsing - os.environ["PYTHONHASHSEED"] = str(seed) - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) + logging.info(f"Training model: {model.__name__}.") + dataset_class = ImputationDataset if mode == RunMode.imputation else PredictionDataset - if reproducible: - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - torch.use_deterministic_algorithms(True) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False + logging.info(f"Logging to directory: {log_dir}.") + save_config_file(log_dir) # We save the operative config before and also after training - model.set_log_dir(log_dir) - save_config_file(log_dir) + train_dataset = dataset_class(data, split=Split.train, ram_cache=ram_cache) + val_dataset = dataset_class(data, split=Split.val, ram_cache=ram_cache) + train_dataset, val_dataset = assure_minimum_length(train_dataset), assure_minimum_length(val_dataset) + batch_size = min(batch_size, len(train_dataset), len(val_dataset)) + + logging.debug(f"Training on {len(train_dataset)} samples and validating on {len(val_dataset)} samples.") + logging.info(f"Using {num_workers} workers for data loading.") + + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + ) + val_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + ) - dataset = SICUDataset(data, split="train") - val_dataset = SICUDataset(data, split="val") + data_shape = next(iter(train_loader))[0].shape + model = model(optimizer=optimizer, input_size=data_shape, epochs=epochs, run_mode=mode) + model.set_weight(weight, train_dataset) if load_weights: - if (source_dir / "model.torch").is_file(): - model.load_weights(source_dir / "model.torch") - elif (source_dir / "model.txt").is_file(): - model.load_weights(source_dir / "model.txt") - elif (source_dir / "model.joblib").is_file(): - model.load_weights(source_dir / "model.joblib") + if source_dir.exists(): + # if not model.needs_training: + checkpoint = torch.load(source_dir / "model.ckpt") + # else: + model = model.load_state_dict(checkpoint["state_dict"]) else: - raise Exception("No weights to load at path : {}".format(source_dir / "model.*")) + raise Exception(f"No weights to load at path : {source_dir}") + + model.set_trained_columns(train_dataset.get_feature_names()) + + loggers = [TensorBoardLogger(log_dir), JSONMetricsLogger(log_dir)] + + callbacks = [ + EarlyStopping(monitor="val/loss", min_delta=min_delta, patience=patience, strict=False), + ModelCheckpoint(log_dir, filename="model", save_top_k=1, save_last=True), + ] + if verbose: + callbacks.append(TQDMProgressBar(refresh_rate=min(100, len(train_loader) // 2))) + if precision == 16 or "16-mixed": + torch.set_float32_matmul_precision("medium") + trainer = Trainer( + max_epochs=epochs if model.needs_training else 1, + callbacks=callbacks, + precision=precision, + accelerator="auto" if not cpu else "cpu", + devices=max(torch.cuda.device_count(), 1), + deterministic=reproducible, + benchmark=not reproducible, + enable_progress_bar=verbose, + logger=loggers, + num_sanity_val_steps=0, + ) + + if model.needs_fit: + logging.info("Fitting model to data.") + model.fit(train_dataset, val_dataset) + model.save_model(log_dir, "last") + logging.info("Fitting complete.") - else: - try: - model.train(dataset, val_dataset, weight, seed) - except ValueError as e: - logging.exception(e) 
- sys.exit(1) + if model.needs_training: + logging.info("Training model.") + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + logging.info("Training complete.") - test_dataset = SICUDataset(data, split=test_on) - weight = dataset.get_balance() + test_dataset = dataset_class(data, split=test_on) + test_dataset = assure_minimum_length(test_dataset) + test_loader = ( + DataLoader( + test_dataset, + batch_size=min(batch_size * 4, len(test_dataset)), + shuffle=False, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + ) + if model.needs_training + else DataLoader([test_dataset.to_tensor()], batch_size=1) + ) - # save config file again to capture missing gin parameters + model.set_weight("balanced", train_dataset) + test_loss = trainer.test(model, dataloaders=test_loader, verbose=verbose)[0]["test/loss"] save_config_file(log_dir) - return model.test(test_dataset, weight, seed) + return test_loss diff --git a/icu_benchmarks/models/utils.py b/icu_benchmarks/models/utils.py index 563abae5..53bfcae2 100644 --- a/icu_benchmarks/models/utils.py +++ b/icu_benchmarks/models/utils.py @@ -1,3 +1,6 @@ +import json +from typing import Dict +from pathlib import Path from datetime import timedelta from enum import Enum from json import JSONEncoder @@ -6,23 +9,12 @@ import numpy as np import torch - -def save_model(model, optimizer, epoch, save_file): - state = { - "model": model.state_dict(), - "optimizer": optimizer.state_dict(), - "epoch": epoch, - } - torch.save(state, save_file) - del state - - -def load_model_state(filepath, model, optimizer=None): - state = torch.load(filepath) - model.load_state_dict(state["model"]) - if optimizer is not None: - optimizer.load_state_dict(state["optimizer"]) - logging.info("Loaded model and optimizer") +from pytorch_lightning.loggers.logger import Logger +from pytorch_lightning.utilities import rank_zero_only +from torch.nn import Module +from torch.optim import Optimizer, Adam, SGD, RAdam +from typing import Optional, Union +from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR, MultiStepLR, ExponentialLR def save_config_file(log_dir): @@ -31,6 +23,69 @@ def save_config_file(log_dir): f.write(gin.operative_config_str()) +def create_optimizer(name: str, model: Module, lr: float, momentum: float) -> Optimizer: + """creates the specified optimizer with the given parameters + + Args: + name (str): str name of optimizer + model (Module): the model used for training + lr (float): learning rate + momentum (float): momentum (only for sgd optimizer) + + Raises: + ValueError: thrown if optimizer name not known + + Returns: + Optimizer: the model optimizer + """ + name = name.lower() + if name == "adam": + return Adam(params=model.parameters(), lr=lr) + elif name == "sgd": + return SGD(params=model.parameters(), lr=lr, momentum=momentum) + elif name == "radam": + return RAdam(params=model.parameters(), lr=lr) + else: + raise ValueError(f"No optimizer with name {name} found!") + + +def create_scheduler( + scheduler_name: Optional[str], + optimizer: Optimizer, + lr_factor: float, + lr_steps: Optional[list], + epochs: int, +) -> Union[_LRScheduler, None]: + """creates a learning rate scheduler with the given parameters + + Args: + scheduler_name (Optional[str]): str name of scheduler or None, in which case None will be returned + optimizer (Optimizer): the learning optimizer + lr_factor (float): the learning rate factor + lr_steps (Optional[list]): learning rate steps for the scheduler to take (only supported for step 
scheduler) + epochs (int): number of scheduler epochs (only supported for cosine scheduler) + + Raises: + ValueError: thrown if step scheduler was chosen but no steps were passed + ValueError: thrown if scheduler name not known and not None + + Returns: + Union[_LRScheduler, None]: either the learning rate scheduler object or None if scheduler_name was None + """ + if scheduler_name == "step": + if not lr_steps: + raise ValueError("step scheduler chosen but no lr steps passed!") + return MultiStepLR(optimizer, lr_steps, lr_factor) + elif scheduler_name == "exponential": + return ExponentialLR(optimizer, lr_factor) + elif scheduler_name == "cosine": + return CosineAnnealingLR(optimizer, epochs) + elif not scheduler_name: + return None + else: + raise ValueError(f"no scheduler with name {scheduler_name} found!") + + class JsonResultLoggingEncoder(JSONEncoder): """JSON converter for objects that are not serializable by default.""" @@ -52,7 +107,7 @@ def default(self, obj): return JSONEncoder.default(self, obj) -class Align(Enum): +class Align(str, Enum): LEFT = "<" CENTER = "^" RIGHT = ">" @@ -87,3 +142,49 @@ def log_table_row( if highlight: table_row = f"\x1b[31;32m{table_row}\x1b[0m" logging.log(level, table_row) + + +class JSONMetricsLogger(Logger): + def __init__(self, output_dir=None, **kwargs): + super().__init__(**kwargs) + if output_dir is None: + output_dir = Path.cwd() / "metrics" + logging.info(f"logging metrics to file: {str(output_dir.resolve())}") + self.output_dir = output_dir + self.output_dir.mkdir(parents=True, exist_ok=True) + + @property + def name(self): + return "json_metrics_logger" + + @rank_zero_only + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + old_metrics = {} + stage_metrics = { + "train": {"/".join(key.split("/")[1:]): value for key, value in metrics.items() if key.startswith("train/")}, + "val": {"/".join(key.split("/")[1:]): value for key, value in metrics.items() if key.startswith("val/")}, + "test": {"/".join(key.split("/")[1:]): value for key, value in metrics.items() if key.startswith("test/")}, + } + for stage, metrics in stage_metrics.items(): + if metrics: + output_file = self.output_dir / f"{stage}_metrics.json" + old_metrics = {} + if output_file.exists(): + try: + with output_file.open("r") as f: + old_metrics = json.load(f) + logging.debug(f"updating {stage} metrics file...") + except json.decoder.JSONDecodeError: + logging.warning("could not decode json file, overwriting...") + + old_metrics.update(metrics) + with output_file.open("w") as f: + json.dump(old_metrics, f, cls=JsonResultLoggingEncoder, indent=4) + + @property + def version(self): + return "0.1" + + @rank_zero_only + def log_hyperparams(self, params): + pass diff --git a/icu_benchmarks/models/wrappers.py b/icu_benchmarks/models/wrappers.py index 1e5580f9..5fc94a12 100644 --- a/icu_benchmarks/models/wrappers.py +++ b/icu_benchmarks/models/wrappers.py @@ -1,82 +1,201 @@ -import inspect -import json import logging -import os -from pathlib import Path +from abc import ABC +from typing import Dict, Any +from typing import List, Optional, Union +import sklearn.metrics +from sklearn.metrics import log_loss +from torch.nn import MSELoss, CrossEntropyLoss +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer +import inspect import gin -import joblib -import lightgbm import numpy as np import torch -from sklearn.linear_model import LogisticRegression -from sklearn.metrics import log_loss +from ignite.exceptions import 
NotComputableError +from icu_benchmarks.models.constants import ImputationInit +from icu_benchmarks.models.utils import create_optimizer, create_scheduler +from joblib import dump +from pytorch_lightning import LightningModule -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm, trange - -from icu_benchmarks.models.metric_constants import MLMetrics, DLMetrics -from icu_benchmarks.models.encoders import LSTMNet -from icu_benchmarks.models.metrics import MAE -from icu_benchmarks.models.utils import save_model, load_model_state, log_table_row, JsonResultLoggingEncoder +from icu_benchmarks.models.constants import MLMetrics, DLMetrics +from icu_benchmarks.contants import RunMode gin.config.external_configurable(torch.nn.functional.nll_loss, module="torch.nn.functional") gin.config.external_configurable(torch.nn.functional.cross_entropy, module="torch.nn.functional") gin.config.external_configurable(torch.nn.functional.mse_loss, module="torch.nn.functional") -gin.config.external_configurable(lightgbm.LGBMClassifier, module="lightgbm") -gin.config.external_configurable(lightgbm.LGBMRegressor, module="lightgbm") -gin.config.external_configurable(LogisticRegression) +gin.config.external_configurable(sklearn.metrics.mean_squared_error, module="sklearn.metrics") +gin.config.external_configurable(sklearn.metrics.log_loss, module="sklearn.metrics") + + +@gin.configurable("BaseModule") +class BaseModule(LightningModule): + needs_training = False + needs_fit = False + + weight = None + metrics = {} + trained_columns = None + run_mode = None + + def forward(self, *args, **kwargs): + raise NotImplementedError() + + def step_fn(self, batch, step_prefix=""): + raise NotImplementedError() + + def finalize_step(self, step_prefix=""): + pass + + def set_metrics(self, *args, **kwargs): + self.metrics = {} + + def set_trained_columns(self, columns: List[str]): + self.trained_columns = columns + def set_weight(self, weight, *args, **kwargs): + pass -def pick_device_config(hint=None): - if (hint == "cuda" or hint is None) and torch.cuda.is_available(): - device = torch.device("cuda:0") - pin_memory = True - n_worker = 1 - elif (hint == "mps" or hint is None) and hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - device = torch.device("mps") - pin_memory = True - n_worker = 1 - else: - device = torch.device("cpu") - pin_memory = False - n_worker = os.cpu_count() - return device, pin_memory, n_worker + def training_step(self, batch, batch_idx): + return self.step_fn(batch, "train") + + def validation_step(self, batch, batch_idx): + return self.step_fn(batch, "val") + + def test_step(self, batch, batch_idx): + return self.step_fn(batch, "test") + + def on_train_epoch_end(self) -> None: + self.finalize_step("train") + + def on_validation_epoch_end(self) -> None: + self.finalize_step("val") + + def on_test_epoch_end(self) -> None: + self.finalize_step("test") + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + checkpoint["class"] = self.__class__ + checkpoint["trained_columns"] = self.trained_columns + return super().on_save_checkpoint(checkpoint) + + def save_model(self, save_path, file_name, file_extension): + raise NotImplementedError() + + def check_supported_runmode(self, runmode: RunMode): + if runmode not in self._supported_run_modes: + raise ValueError(f"Runmode {runmode} not supported for {self.__class__.__name__}") + return True @gin.configurable("DLWrapper") -class DLWrapper(object): +class 
DLWrapper(BaseModule, ABC): + needs_training = True + needs_fit = False + _metrics_warning_printed = set() + _supported_run_modes = [RunMode.classification, RunMode.regression, RunMode.imputation] + def __init__( self, - encoder=LSTMNet, - loss=torch.nn.functional.cross_entropy, - optimizer_fn=torch.optim.Adam, - device=None, - verbose_logging=True, + loss=CrossEntropyLoss(), + optimizer=torch.optim.Adam, + run_mode: RunMode = RunMode.classification, + input_shape=None, + lr: float = 0.002, + momentum: float = 0.9, + lr_scheduler: Optional[str] = None, + lr_factor: float = 0.99, + lr_steps: Optional[List[int]] = None, + epochs: int = 100, + input_size: torch.Tensor = None, + initialization_method: str = "normal", + **kwargs, ): - device, pin_memory, n_worker = pick_device_config(device) - - self.device = device - logging.info(f"Model will be trained using {device}") - self.pin_memory = pin_memory - self.n_worker = n_worker - - self.encoder = encoder - self.encoder.to(device) + """Interface for Deep Learning models.""" + super().__init__() + self.save_hyperparameters(ignore=["loss", "optimizer"]) self.loss = loss - self.optimizer = optimizer_fn(self.encoder.parameters()) + self.optimizer = optimizer self.scaler = None - self.verbose_logging = verbose_logging + self.check_supported_runmode(run_mode) + self.run_mode = run_mode + + def on_fit_start(self): + self.metrics = { + step_name: { + metric_name: (metric() if isinstance(metric, type) else metric) + for metric_name, metric in self.set_metrics().items() + } + for step_name in ["train", "val", "test"] + } + return super().on_fit_start() + + def finalize_step(self, step_prefix=""): + try: + self.log_dict( + { + f"{step_prefix}/{name}": metric.compute() + for name, metric in self.metrics[step_prefix].items() + if "_Curve" not in name + }, + sync_dist=True, + ) + for metric in self.metrics[step_prefix].values(): + metric.reset() + except (NotComputableError, ValueError): + if step_prefix not in self._metrics_warning_printed: + self._metrics_warning_printed.add(step_prefix) + logging.warning(f"Metrics for {step_prefix} not computable") + pass + + def configure_optimizers(self): + if isinstance(self.optimizer, str): + optimizer = create_optimizer(self.optimizer, self, self.hparams.lr, self.hparams.momentum) + else: + optimizer = self.optimizer(self.parameters()) - def set_log_dir(self, log_dir: Path): - self.log_dir = log_dir + if self.hparams.lr_scheduler is None or self.hparams.lr_scheduler == "": + return optimizer + scheduler = create_scheduler( + self.hparams.lr_scheduler, optimizer, self.hparams.lr_factor, self.hparams.lr_steps, self.hparams.epochs + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} - def set_scaler(self, scaler): - self.scaler = scaler + def on_test_epoch_start(self) -> None: + self.metrics = { + step_name: {metric_name: metric() for metric_name, metric in self.set_metrics().items()} + for step_name in ["train", "val", "test"] + } + return super().on_test_epoch_start() + + def save_model(self, save_path, file_name, file_extension=".ckpt"): + path = save_path / (file_name + file_extension) + try: + torch.save(self, path) + logging.info(f"Model saved to {str(path.resolve())}.") + except Exception as e: + logging.error(f"Cannot save model to path {str(path.resolve())}: {e}.") + + +@gin.configurable("DLPredictionWrapper") +class DLPredictionWrapper(DLWrapper): + """Interface for Deep Learning models.""" + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + def set_weight(self, weight, 
dataset): + """Set the weight for the loss function.""" + + if isinstance(weight, list): + weight = torch.FloatTensor(weight).to(self.device) + elif weight == "balanced": + weight = torch.FloatTensor(dataset.get_balance()).to(self.device) + self.loss_weights = weight + + def set_metrics(self, *args): + """Set the evaluation metrics for the prediction model.""" - def set_metrics(self): def softmax_binary_output_transform(output): with torch.no_grad(): y_pred, y = output @@ -89,26 +208,26 @@ def softmax_multi_output_transform(output): y_pred = torch.softmax(y_pred, dim=1) return y_pred, y - # Binary classification - # output transform is not applied for contrib metrics so we do our own. - if self.encoder.logit.out_features == 2: - self.output_transform = softmax_binary_output_transform - self.metrics = DLMetrics.BINARY_CLASSIFICATION - + # Output transform is not applied for contrib metrics, so we do our own. + if self.run_mode == RunMode.classification: + # Binary classification + if self.logit.out_features == 2: + self.output_transform = softmax_binary_output_transform + metrics = DLMetrics.BINARY_CLASSIFICATION + else: + # Multiclass classification + self.output_transform = softmax_multi_output_transform + metrics = DLMetrics.MULTICLASS_CLASSIFICATION # Regression - elif self.encoder.logit.out_features == 1: + elif self.run_mode == RunMode.regression: self.output_transform = lambda x: x - if self.scaler is not None: - self.metrics = {"MAE": MAE(invert_transform=self.scaler.inverse_transform)} - else: - self.metrics = DLMetrics.REGRESSION - - # Multiclass classification + metrics = DLMetrics.REGRESSION else: - self.output_transform = softmax_multi_output_transform - self.metrics = DLMetrics.MULTICLASS_CLASSIFICATION + raise ValueError(f"Run mode {self.run_mode} not supported.") + return metrics - def step_fn(self, element, loss_weight=None): + def step_fn(self, element, step_prefix=""): + """Perform a step in the training loop.""" if len(element) == 2: data, labels = element[0], element[1].to(self.device) @@ -128,202 +247,62 @@ def step_fn(self, element, loss_weight=None): data = data.float().to(self.device) else: raise Exception("Loader should return either (data, label) or (data, label, mask)") - out = self.encoder(data) + out = self(data) if len(out) == 2 and isinstance(out, tuple): out, aux_loss = out else: aux_loss = 0 - out_flat = torch.masked_select(out, mask.unsqueeze(-1)).reshape(-1, out.shape[-1]) - label_flat = torch.masked_select(labels, mask) - if out_flat.shape[-1] > 1: - loss = self.loss(out_flat, label_flat.long(), weight=loss_weight) + aux_loss # torch.long because NLL + prediction = torch.masked_select(out, mask.unsqueeze(-1)).reshape(-1, out.shape[-1]).to(self.device) + target = torch.masked_select(labels, mask).to(self.device) + if prediction.shape[-1] > 1 and self.run_mode == RunMode.classification: + # Classification task + loss = self.loss(prediction, target.long(), weight=self.loss_weights.to(self.device)) + aux_loss + # torch.long because NLL + elif self.run_mode == RunMode.regression: + # Regression task + loss = self.loss(prediction[:, 0], target.float()) + aux_loss else: - loss = self.loss(out_flat[:, 0], label_flat.float()) + aux_loss # Regression task - - return loss, out_flat, label_flat - - def _do_training(self, train_loader, weight, metrics): - # Training epoch - self.encoder.train() - agg_train_loss = 0 - for elem in tqdm(train_loader, leave=False, disable=not self.verbose_logging): - loss, preds, target = self.step_fn(elem, weight) - loss.backward() - 
self.optimizer.step() - self.optimizer.zero_grad() - agg_train_loss += loss - for name, metric in metrics.items(): - metric.update(self.output_transform((preds, target))) - - train_metric_results = {} - for name, metric in metrics.items(): - train_metric_results[name] = metric.compute() - metric.reset() - train_loss = float(agg_train_loss / len(train_loader)) - return train_loss, train_metric_results - - @gin.configurable(module="DLWrapper") - def train( - self, - train_dataset, - val_dataset, - weight, - seed, - epochs=1000, - batch_size=64, - patience=10, - min_delta=1e-4, - ): - self.set_metrics() - metrics = self.metrics - - torch.autograd.set_detect_anomaly(True) # Check for any nans in gradients - - train_loader = DataLoader( - train_dataset, - batch_size=batch_size, - shuffle=True, - num_workers=self.n_worker, - pin_memory=self.pin_memory, - prefetch_factor=2, - ) - val_loader = DataLoader( - val_dataset, - batch_size=batch_size, - shuffle=False, - num_workers=self.n_worker, - pin_memory=self.pin_memory, - prefetch_factor=2, - ) - - if isinstance(weight, list): - weight = torch.FloatTensor(weight).to(self.device) - elif weight == "balanced": - weight = torch.FloatTensor(train_dataset.get_balance()).to(self.device) - - best_loss = float("inf") - epoch_no_improvement = 0 - train_writer = SummaryWriter(self.log_dir / "tensorboard" / "train") - val_writer = SummaryWriter(self.log_dir / "tensorboard" / "val") - - table_header = ["EPOCH", "SPLIT", "METRICS", "COMMENT"] - widths = [5, 5, 25, 50] - log_table_row(table_header, widths=widths) - disable_tqdm = logging.getLogger().isEnabledFor(logging.INFO) - for epoch in trange(epochs, leave=False, disable=not self.verbose_logging or disable_tqdm): - # Train step - train_loss, train_metric_results = self._do_training(train_loader, weight, metrics) - - # Validation step - val_loss, val_metric_results = self.evaluate(val_loader, metrics, weight) - - # Early stopping - if val_loss <= best_loss - min_delta: - best_metrics = val_metric_results - epoch_no_improvement = 0 - self.save_weights(epoch, self.log_dir / "model.torch") - best_loss = val_loss - comment = "Validation loss improved to {:.4f} ".format(val_loss) - else: - epoch_no_improvement += 1 - comment = "No improvement on loss for {} epochs".format(epoch_no_improvement) - if epoch_no_improvement >= patience: - logging.info("No improvement on loss for more than {} epochs. 
We stop training".format(patience)) - break - - # Logging - test_metric_strings = [] - for name, value in train_metric_results.items(): - if isinstance(value, np.float): - test_metric_strings.append(f"{name}: {value:.4f}") - train_writer.add_scalar(name, value, epoch) - train_writer.add_scalar("Loss", train_loss, epoch) - - val_metric_strings = [] - for name, value in val_metric_results.items(): - if isinstance(value, np.float): - val_metric_strings.append(f"{name}: {value:.4f}") - val_writer.add_scalar(name, value, epoch) - val_writer.add_scalar("Loss", val_loss, epoch) - - log_table_row([epoch, "Train", ", ".join(test_metric_strings), ""], widths=widths) - log_table_row([epoch, "Val", ", ".join(val_metric_strings), comment], widths=widths) - - best_metrics["loss"] = best_loss - - with open(self.log_dir / "best_metrics.json", "w") as f: - json.dump(best_metrics, f, cls=JsonResultLoggingEncoder) - - self.load_weights(self.log_dir / "model.torch") # We load back the best iteration - - def test(self, dataset, weight, seed): - self.set_metrics() - test_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.n_worker, pin_memory=self.pin_memory) - if isinstance(weight, list): - weight = torch.FloatTensor(weight).to(self.device) - test_loss, test_metrics = self.evaluate(test_loader, self.metrics, weight) - - test_metrics["loss"] = test_loss - with open(self.log_dir / "test_metrics.json", "w") as f: - json.dump(test_metrics, f, cls=JsonResultLoggingEncoder) - - for key, value in test_metrics.items(): - if isinstance(value, float): - logging.info("Test {}: {}".format(key, value)) - - return test_loss + raise ValueError(f"Run mode {self.run_mode} not supported.") + transformed_output = self.output_transform((prediction, target)) + for metric in self.metrics[step_prefix].values(): + metric.update(transformed_output) + self.log(f"{step_prefix}/loss", loss, on_step=False, on_epoch=True, sync_dist=True) + return loss - def evaluate(self, eval_loader, metrics, weight): - self.encoder.eval() - agg_eval_loss = 0 - with torch.no_grad(): - for elem in eval_loader: - loss, preds, target = self.step_fn(elem, weight) - agg_eval_loss += loss - for name, metric in metrics.items(): - metric.update(self.output_transform((preds, target))) - - eval_metric_results = {} - for name, metric in metrics.items(): - eval_metric_results[name] = metric.compute() - metric.reset() - eval_loss = float(agg_eval_loss / len(eval_loader)) - return eval_loss, eval_metric_results - - def save_weights(self, epoch, save_path): - save_model(self.encoder, self.optimizer, epoch, save_path) - - def load_weights(self, load_path): - load_model_state(load_path, self.encoder, optimizer=self.optimizer) +@gin.configurable("MLWrapper") +class MLWrapper(BaseModule, ABC): + """Interface for prediction with traditional Scikit-learn-like Machine Learning models.""" + needs_training = False + needs_fit = True + _supported_run_modes = [RunMode.classification, RunMode.regression] -@gin.configurable("MLWrapper") -class MLWrapper(object): - def __init__(self, model=lightgbm.LGBMClassifier): - self.model = model + def __init__(self, *args, run_mode=RunMode.classification, loss=log_loss, patience=10, **kwargs): + super().__init__() + self.save_hyperparameters() self.scaler = None - - def set_log_dir(self, log_dir: Path): - self.log_dir = log_dir + self.check_supported_runmode(run_mode) + self.run_mode = run_mode + self.loss = loss + self.patience = patience def set_metrics(self, labels): - - # Binary classification - if 
len(np.unique(labels)) == 2: - if isinstance(self.model, lightgbm.basic.Booster): + if self.run_mode == RunMode.classification: + # Binary classification + if len(np.unique(labels)) == 2: + # if isinstance(self.model, lightgbm.basic.Booster): self.output_transform = lambda x: x - else: - self.output_transform = lambda x: x[:, 1] - self.label_transform = lambda x: x - - self.metrics = MLMetrics.BINARY_CLASSIFICATION + # self.output_transform = lambda x: x[:, 1] + self.label_transform = lambda x: x - # Multiclass classification - elif np.all(labels[:10].astype(int) == labels[:10]): - self.output_transform = lambda x: np.argmax(x, axis=-1) - self.label_transform = lambda x: x - self.metrics = MLMetrics.MULTICLASS_CLASSIFICATION + self.metrics = MLMetrics.BINARY_CLASSIFICATION + # Multiclass classification + else: + # Todo: verify multiclass classification + self.output_transform = lambda x: np.argmax(x, axis=-1) + self.label_transform = lambda x: x + self.metrics = MLMetrics.MULTICLASS_CLASSIFICATION # Regression else: @@ -335,97 +314,188 @@ def set_metrics(self, labels): self.label_transform = lambda x: x self.metrics = MLMetrics.REGRESSION - def set_scaler(self, scaler): - self.scaler = scaler - - @gin.configurable(module="MLWrapper") - def train(self, train_dataset, val_dataset, weight, seed, patience=10): + def fit(self, train_dataset, val_dataset): + """Fit the model to the training data.""" train_rep, train_label = train_dataset.get_data_and_labels() val_rep, val_label = val_dataset.get_data_and_labels() + self.set_metrics(train_label) - metrics = self.metrics - - if "class_weight" in self.model.get_params().keys(): # Set class weights - self.model.set_params(class_weight=weight) - - if "eval_set" in inspect.getfullargspec(self.model.fit).args: # This is lightgbm - model_type = "lgbm" - self.model.set_params(random_state=seed) - self.model.fit( - train_rep, - train_label, - eval_set=(val_rep, val_label), - callbacks=[ - lightgbm.early_stopping(patience, verbose=False), - lightgbm.log_evaluation(period=-1, show_stdv=False), - ], - ) - val_loss = list(self.model.best_score_["valid_0"].values())[0] - else: - model_type = "sklearn" - self.model.fit(train_rep, train_label) - val_loss = 0.0 - if "MAE" in self.metrics.keys(): - val_pred = self.model.predict(val_rep) - train_pred = self.model.predict(train_rep) - else: - val_pred = self.model.predict_proba(val_rep) - train_pred = self.model.predict_proba(train_rep) - - train_metric_results = {} - train_string = "" - train_values = [] - val_string = "Val Results: loss: {:.4f}" - val_values = [val_loss] - val_metric_results = {"loss": val_loss} - for name, metric in metrics.items(): - train_metric_results[name] = metric(self.label_transform(train_label), self.output_transform(train_pred)) - val_metric_results[name] = metric(self.label_transform(val_label), self.output_transform(val_pred)) - if isinstance(train_metric_results[name], np.float): - train_string += "Train Results: " if len(train_string) == 0 else ", " - train_string += name + ":{:.4f}" - val_string += ", " + name + ":{:.4f}" - train_values.append(train_metric_results[name]) - val_values.append(val_metric_results[name]) - logging.info(train_string.format(*train_values)) - logging.info(val_string.format(*val_values)) - - model_file = "model.txt" if model_type == "lgbm" else "model.joblib" - self.save_weights(save_path=(self.log_dir / model_file), model_type=model_type) - with open(self.log_dir / "val_metrics.json", "w") as f: - json.dump(val_metric_results, f, 
cls=JsonResultLoggingEncoder)
-
-    def test(self, dataset, weight, seed):
-        test_rep, test_label = dataset.get_data_and_labels()
+        # if "class_weight" in self.model.get_params().keys():  # Set class weights
+        #     self.model.set_params(class_weight=self.weight)
+
+        val_loss = self.fit_model(train_rep, train_label, val_rep, val_label)
+
+        train_pred = self.predict(train_rep)
+
+        logging.debug(f"Model:{self.model}")
+        self.log("train/loss", self.loss(train_label, train_pred), sync_dist=True)
+        logging.debug(f"Train loss: {self.loss(train_label, train_pred)}")
+        self.log("val/loss", val_loss, sync_dist=True)
+        logging.debug(f"Val loss: {val_loss}")
+        self.log_metrics(train_label, train_pred, "train")
+
+    def fit_model(self, train_data, train_labels, val_data, val_labels):
+        """Fit the model to the training data (default scikit-learn syntax)."""
+        self.model.fit(train_data, train_labels)
+        val_loss = 0.0
+        return val_loss
+
+    def validation_step(self, val_dataset, _):
+        val_rep, val_label = val_dataset.get_data_and_labels()
+        val_rep, val_label = torch.from_numpy(val_rep).to(self.device), torch.from_numpy(val_label).to(self.device)
+        self.set_metrics(val_label)
+
+        val_pred = self.predict(val_rep)
+
+        self.log("val/loss", self.loss(val_label, val_pred), sync_dist=True)
+        logging.info(f"Val loss: {self.loss(val_label, val_pred)}")
+        self.log_metrics(val_label, val_pred, "val")
+
+    def test_step(self, dataset, _):
+        test_rep, test_label = dataset
+        test_rep, test_label = test_rep.squeeze().cpu().numpy(), test_label.squeeze().cpu().numpy()
         self.set_metrics(test_label)
 
-        if "MAE" in self.metrics.keys() or isinstance(self.model, lightgbm.basic.Booster):  # If we reload a LGBM classifier
-            test_pred = self.model.predict(test_rep)
+        test_pred = self.predict(test_rep)
+
+        self.log("test/loss", self.loss(test_label, test_pred), sync_dist=True)
+        logging.debug(f"Test loss: {self.loss(test_label, test_pred)}")
+        self.log_metrics(test_label, test_pred, "test")
+
+    def predict(self, features):
+        # Both run modes currently use the underlying estimator's predict().
+        return self.model.predict(features)
-        else:
-            test_pred = self.model.predict_proba(test_rep)
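The filtering in `log_metrics` below exists because curve-style sklearn metrics return tuples, which `log_dict` cannot log. An illustration with toy values:

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_curve

y_true = np.array([0, 1, 1, 0])
y_score = np.array([0.1, 0.8, 0.6, 0.4])

print(average_precision_score(y_true, y_score))   # scalar -> loggable
fpr, tpr, thresholds = roc_curve(y_true, y_score)  # tuple of arrays -> filtered out
```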
+    def log_metrics(self, label, pred, metric_type):
+        """Log metrics to the PL logs."""
+
+        # Compute each metric once, then filter out metrics that return a tuple (e.g. precision_recall_curve)
+        computed = {name: metric(self.label_transform(label), self.output_transform(pred)) for name, metric in self.metrics.items()}
+        self.log_dict(
+            {f"{metric_type}/{name}": value for name, value in computed.items() if not isinstance(value, tuple)},
+            sync_dist=True,
+        )
 
+    def configure_optimizers(self):
+        return None
+
+    def __getstate__(self) -> Dict[str, Any]:
+        state = self.__dict__.copy()
+        del state["label_transform"]
+        del state["output_transform"]
+        return state
+
+    def save_model(self, save_path, file_name, file_extension=".joblib"):
+        path = save_path / (file_name + file_extension)
+        try:
+            dump(self.model, path)
+            logging.info(f"Model saved to {str(path.resolve())}.")
+        except Exception as e:
+            logging.error(f"Cannot save model to path {str(path.resolve())}: {e}.")
+
+    def set_model_args(self, model, *args, **kwargs):
+        """Set hyperparameters of the model if they are supported by the model."""
+        signature = inspect.signature(model.__init__).parameters
+        possible_hps = list(signature.keys())
+        # Get passed keyword arguments
+        arguments = locals()["kwargs"]
+        # Keep only the hyperparameters the model actually accepts
+        hyperparams = {key: value for key, value in arguments.items() if key in possible_hps}
+        logging.debug(f"Creating model with: {hyperparams}.")
+        return model(**hyperparams)
+
+
+@gin.configurable("ImputationWrapper")
+class ImputationWrapper(DLWrapper):
+    """Interface for imputation models."""
+
+    needs_training = True
+    needs_fit = False
+    _supported_run_modes = [RunMode.imputation]
+
+    def __init__(
+        self,
+        loss: _Loss = MSELoss(),
+        optimizer: Union[str, Optimizer] = "adam",
+        runmode: RunMode = RunMode.imputation,
+        lr: float = 0.002,
+        momentum: float = 0.9,
+        lr_scheduler: Optional[str] = None,
+        lr_factor: float = 0.99,
+        lr_steps: Optional[List[int]] = None,
+        input_size: torch.Tensor = None,
+        initialization_method: ImputationInit = ImputationInit.NORMAL,
+        **kwargs: str,
+    ) -> None:
+        super().__init__()
+        self.check_supported_runmode(runmode)
+        self.run_mode = runmode
+        self.save_hyperparameters(ignore=["loss", "optimizer"])
+        self.loss = loss
+        self.optimizer = optimizer
+
+    def set_metrics(self):
+        return DLMetrics.IMPUTATION
+
+    def init_weights(self, init_type="normal", gain=0.02):
+        def init_func(m):
+            classname = m.__class__.__name__
+            if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
+                if init_type == ImputationInit.NORMAL:
+                    torch.nn.init.normal_(m.weight.data, 0.0, gain)
+                elif init_type == ImputationInit.XAVIER:
+                    torch.nn.init.xavier_normal_(m.weight.data, gain=gain)
+                elif init_type == ImputationInit.KAIMING:
+                    torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out")
+                elif init_type == ImputationInit.ORTHOGONAL:
+                    torch.nn.init.orthogonal_(m.weight.data, gain=gain)
+                else:
+                    raise NotImplementedError(f"Initialization method {init_type} is not implemented")
+                if hasattr(m, "bias") and m.bias is not None:
+                    torch.nn.init.constant_(m.bias.data, 0.0)
+            elif classname.find("BatchNorm2d") != -1:
+                torch.nn.init.normal_(m.weight.data, 1.0, gain)
+                torch.nn.init.constant_(m.bias.data, 0.0)
+
+        self.apply(init_func)
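Because `ImputationInit` subclasses `str`, the comparisons in `init_func` above also work when a config hands over a plain string (small sketch):

```python
from icu_benchmarks.models.constants import ImputationInit

assert ImputationInit("kaiming") is ImputationInit.KAIMING  # parse from a config string
assert ImputationInit.KAIMING == "kaiming"                  # str subclass compares to raw strings
```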
+    def on_fit_start(self) -> None:
+        self.init_weights(self.hparams.initialization_method)
+        for metrics in self.metrics.values():
+            for metric in metrics.values():
+                metric.reset()
+        return super().on_fit_start()
 
+    def step_fn(self, batch, step_prefix=""):
+        amputated, amputation_mask, target, target_missingness = batch
+        imputed = self(amputated, amputation_mask)
+        amputated[amputation_mask > 0] = imputed[amputation_mask > 0]
+        amputated[target_missingness > 0] = target[target_missingness > 0]
 
+        loss = self.loss(amputated, target)
+        self.log(f"{step_prefix}/loss", loss.item(), prog_bar=True)
+
+        for metric in self.metrics[step_prefix].values():
+            metric.update(
+                (torch.flatten(amputated.detach(), start_dim=1).clone(), torch.flatten(target.detach(), start_dim=1).clone())
+            )
+        return loss
+
+    def fit(self, train_dataset, val_dataset):
+        raise NotImplementedError()
+
+    def predict_step(self, data, amputation_mask=None):
+        return self(data, amputation_mask)
+
+    def predict(self, data):
+        self.eval()
+        data = data.to(self.device)
+        data_missingness = torch.isnan(data).to(torch.float32)
+        prediction = self.predict_step(data, data_missingness)
+        data[data_missingness.bool()] = prediction[data_missingness.bool()]
+        return data
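Typical inference-time usage of `ImputationWrapper.predict` (a sketch; `imputer` stands for any trained subclass implementing `forward(data, mask)`):

```python
import torch

x = torch.tensor([[1.0, float("nan"), 3.0]])
filled = imputer.predict(x.clone())  # NaN entries are replaced by the model's predictions
assert not torch.isnan(filled).any()
```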
- gin.bind_parameter("DLWrapper.verbose_logging", args.verbose) + # Set arguments for wandb sweep + if args.wandb_sweep: + args = apply_wandb_sweep(args) + # Initialize loggers + log_format = "%(asctime)s - %(levelname)s - %(name)s : %(message)s" + date_format = "%Y-%m-%d %H:%M:%S" + verbose = args.verbose + setup_logging(date_format, log_format, verbose) + + # Load weights if in evaluation mode load_weights = args.command == "evaluate" + data_dir = Path(args.data_dir) + + # Get arguments name = args.name task = args.task model = args.model + reproducible = args.reproducible + + # Set experiment name + if name is None: + name = data_dir.name + logging.info(f"Running experiment {name}.") + + # Load task config + gin.parse_config_file(f"configs/tasks/{task}.gin") + + mode = get_mode() + + if args.wandb_sweep: + run_name = f"{mode}_{model}_{name}" + set_wandb_run_name(run_name) + + logging.info(f"Task mode: {mode}.") experiment = args.experiment + + pretrained_imputation_model = load_pretrained_imputation_model(args.pretrained_imputation) + + # Log imputation model to wandb + update_wandb_config( + { + "pretrained_imputation_model": pretrained_imputation_model.__class__.__name__ + if pretrained_imputation_model is not None + else "None" + } + ) source_dir = None - # todo:check if this is correct - reproducible = False log_dir_name = args.log_dir / name - log_dir = (log_dir_name / experiment) if experiment else (log_dir_name / args.task_name / model) - train_on_cpu = args.cpu + log_dir = ( + (log_dir_name / experiment) + if experiment + else (log_dir_name / (args.task_name if args.task_name is not None else args.task) / model) + ) + if torch.cuda.is_available(): + for name in range(0, torch.cuda.device_count()): + log_full_line(f"Available GPU {name}: {torch.cuda.get_device_name(name)}", level=logging.INFO) + else: + log_full_line( + "No GPUs available: please check your device and Torch,Cuda installation if unintended.", level=logging.WARNING + ) + + log_full_line(f"Logging to {log_dir}.", logging.INFO) if args.preprocessor: # Import custom supplied preprocessor + log_full_line(f"Importing custom preprocessor from {args.preprocessor}.", logging.INFO) try: spec = importlib.util.spec_from_file_location("CustomPreprocessor", args.preprocessor) module = importlib.util.module_from_spec(spec) @@ -49,61 +109,58 @@ def main(my_args=tuple(sys.argv[1:])): gin.bind_parameter("preprocess.preprocessor", module.CustomPreprocessor) except Exception as e: logging.error(f"Could not import custom preprocessor from {args.preprocessor}: {e}") - else: - from icu_benchmarks.data.preprocessor import DefaultPreprocessor as preprocessor - if train_on_cpu: - gin.bind_parameter("DLWrapper.device", "cpu") if load_weights: # Evaluate log_dir /= f"from_{args.source_name}" run_dir = create_run_dir(log_dir) source_dir = args.source_dir gin.parse_config_file(source_dir / "train_config.gin") - if gin.query_parameter("train_common.model").selector == "DLWrapper": - # Calculate input dimensions for deep learning models based on preprocessing operations - gin.bind_parameter("model/hyperparameter.input_dim", preprocessor().calculate_input_dim()) - else: # Train - reproducible = args.reproducible checkpoint = log_dir / args.checkpoint if args.checkpoint else None + model_path = ( + Path("configs") / ("imputation_models" if mode == RunMode.imputation else "prediction_models") / f"{model}.gin" + ) gin_config_files = ( [Path(f"configs/experiments/{args.experiment}.gin")] if args.experiment - else [Path(f"configs/models/{model}.gin"), 
Path(f"configs/tasks/{task}.gin")] + else [model_path, Path(f"configs/tasks/{task}.gin")] ) gin.parse_config_files_and_bindings(gin_config_files, args.hyperparams, finalize_config=False) - - if gin.query_parameter("train_common.model").selector == "DLWrapper": - # Calculate input dimensions for deep learning models based on preprocessing operations - gin.bind_parameter("model/hyperparameter.input_dim", preprocessor().calculate_input_dim()) - + log_full_line(f"Data directory: {data_dir.resolve()}", level=logging.INFO) run_dir = create_run_dir(log_dir) choose_and_bind_hyperparameters( args.tune, - args.data_dir, + data_dir, run_dir, args.seed, + run_mode=mode, checkpoint=checkpoint, debug=args.debug, generate_cache=args.generate_cache, load_cache=args.load_cache, + verbose=args.verbose, ) - logging.info(f"Logging to {run_dir.resolve()}") + log_full_line(f"Logging to {run_dir.resolve()}", level=logging.INFO) log_full_line("STARTING TRAINING", level=logging.INFO, char="=", num_newlines=3) start_time = datetime.now() execute_repeated_cv( - args.data_dir, + data_dir, run_dir, args.seed, load_weights=load_weights, source_dir=source_dir, reproducible=reproducible, debug=args.debug, + verbose=args.verbose, load_cache=args.load_cache, generate_cache=args.generate_cache, + mode=mode, + pretrained_imputation_model=pretrained_imputation_model, + cpu=args.cpu, + wandb=args.wandb_sweep, ) log_full_line("FINISHED TRAINING", level=logging.INFO, char="=", num_newlines=3) diff --git a/icu_benchmarks/run_utils.py b/icu_benchmarks/run_utils.py index 6ea569c7..179ab1f5 100644 --- a/icu_benchmarks/run_utils.py +++ b/icu_benchmarks/run_utils.py @@ -1,3 +1,7 @@ +import warnings +from math import sqrt + +import torch import json from argparse import ArgumentParser, BooleanOptionalAction from datetime import datetime, timedelta @@ -5,8 +9,9 @@ from pathlib import Path import scipy.stats as stats import shutil -from statistics import mean, stdev +from statistics import mean, pstdev from icu_benchmarks.models.utils import JsonResultLoggingEncoder +from icu_benchmarks.wandb_utils import wandb_log def build_parser() -> ArgumentParser: @@ -23,37 +28,62 @@ def build_parser() -> ArgumentParser: # ARGUMENTS FOR ALL COMMANDS general_args = parent_parser.add_argument_group("General arguments") general_args.add_argument("-d", "--data-dir", required=True, type=Path, help="Path to the parquet data directory.") - general_args.add_argument("-n", "--name", required=True, help="Name of the (target) dataset.") - general_args.add_argument("-t", "--task", default="BinaryClassification", help="Name of the task gin.") - general_args.add_argument("-tn", "--task-name", help="Name of the task.") - general_args.add_argument("-m", "--model", default="LGBMClassifier", help="Name of the model gin.") - general_args.add_argument("-e", "--experiment", help="Name of the experiment gin.") - general_args.add_argument("-l", "--log-dir", required=True, type=Path, help="Log directory with model weights.") - general_args.add_argument("-s", "--seed", default=1111, type=int, help="Random seed for processing, tuning and training.") + general_args.add_argument("-t", "--task", default="BinaryClassification", required=True, help="Name of the task gin.") + general_args.add_argument("-n", "--name", required=False, help="Name of the (target) dataset.") + general_args.add_argument("-tn", "--task-name", required=False, help="Name of the task, used for naming experiments.") + general_args.add_argument("-m", "--model", default="LGBMClassifier", required=False, 
help="Name of the model gin.") + general_args.add_argument("-e", "--experiment", required=False, help="Name of the experiment gin.") + general_args.add_argument( + "-l", "--log-dir", required=False, default=Path("../yaib_logs/"), type=Path, help="Log directory with model weights." + ) + general_args.add_argument( + "-s", "--seed", required=False, default=1234, type=int, help="Random seed for processing, tuning and training." + ) general_args.add_argument( "-v", "--verbose", - default=True, + default=False, + required=False, action=BooleanOptionalAction, help="Whether to use verbose logging. Disable for clean logs.", ) - general_args.add_argument("--cpu", default=False, action=BooleanOptionalAction, help="Set to use CPU.") - general_args.add_argument("-db", "--debug", default=False, action=BooleanOptionalAction, help="Set to load less data.") + general_args.add_argument("--cpu", default=False, required=False, action=BooleanOptionalAction, help="Set to use CPU.") general_args.add_argument( - "-lc", "--load_cache", default=False, action=BooleanOptionalAction, help="Set to load generated data cache." + "-db", "--debug", required=False, default=False, action=BooleanOptionalAction, help="Set to load less data." ) general_args.add_argument( - "-gc", "--generate_cache", default=False, action=BooleanOptionalAction, help="Set to generate data cache." + "-lc", + "--load_cache", + required=False, + default=False, + action=BooleanOptionalAction, + help="Set to load generated data cache.", + ) + general_args.add_argument( + "-gc", + "--generate_cache", + required=False, + default=False, + action=BooleanOptionalAction, + help="Set to generate data cache.", + ) + general_args.add_argument("-p", "--preprocessor", required=False, type=Path, help="Load custom preprocessor from file.") + general_args.add_argument("-pl", "--plot", required=False, action=BooleanOptionalAction, help="Generate common plots.") + general_args.add_argument( + "-wd", "--wandb-sweep", required=False, action="store_true", help="Activates wandb hyper parameter sweep." + ) + general_args.add_argument( + "-imp", "--pretrained-imputation", required=False, type=str, help="Path to pretrained imputation model." ) - general_args.add_argument("-p", "--preprocessor", type=Path, help="Load custom preprocessor from file.") - general_args.add_argument("-pl", "--plot", default=False, action=BooleanOptionalAction, help="Generate common plots.") # MODEL TRAINING ARGUMENTS prep_and_train = subparsers.add_parser("train", help="Preprocess features and train model.", parents=[parent_parser]) - prep_and_train.add_argument("--reproducible", default=True, action=BooleanOptionalAction, help="Make torch reproducible.") - prep_and_train.add_argument("-hp", "--hyperparams", nargs="+", help="Hyperparameters for model.") + prep_and_train.add_argument( + "--reproducible", required=False, default=True, action=BooleanOptionalAction, help="Make torch reproducible." 
+    )
+    prep_and_train.add_argument("-hp", "--hyperparams", required=False, nargs="+", help="Hyperparameters for model.")
     prep_and_train.add_argument("--tune", default=False, action=BooleanOptionalAction, help="Find best hyperparameters.")
-    prep_and_train.add_argument("--checkpoint", type=Path, help="Use previous checkpoint.")
+    prep_and_train.add_argument("--checkpoint", required=False, type=Path, help="Use previous checkpoint.")

     # EVALUATION PARSER
     evaluate = subparsers.add_parser("evaluate", help="Evaluate trained model on data.", parents=[parent_parser])
@@ -83,7 +113,7 @@ def create_run_dir(log_dir: Path, randomly_searched_params: str = None) -> Path:
     return log_dir_run


-def aggregate_results(log_dir: Path, execution_time: timedelta = -1):
+def aggregate_results(log_dir: Path, execution_time: timedelta = None):
     """Aggregates results from all folds and writes to JSON file.

     Args:
@@ -95,10 +125,15 @@
         if repetition.is_dir():
             aggregated[repetition.name] = {}
             for fold_iter in repetition.iterdir():
+                aggregated[repetition.name][fold_iter.name] = {}
                 if (fold_iter / "test_metrics.json").is_file():
                     with open(fold_iter / "test_metrics.json", "r") as f:
                         result = json.load(f)
-                    aggregated[repetition.name][fold_iter.name] = result
+                    aggregated[repetition.name][fold_iter.name].update(result)
+                elif (fold_iter / "val_metrics.csv").is_file():
+                    # The file is CSV, not JSON, so parse it with csv.DictReader and
+                    # keep the last logged row (assumes numeric metric columns).
+                    with open(fold_iter / "val_metrics.csv", "r") as f:
+                        result = {key: float(value) for key, value in list(csv.DictReader(f))[-1].items()}
+                    aggregated[repetition.name][fold_iter.name].update(result)
                 # Add durations to metrics
                 if (fold_iter / "durations.json").is_file():
                     with open(fold_iter / "durations.json", "r") as f:
@@ -116,7 +151,11 @@
     # Compute statistical metric over aggregated results
     averaged_scores = {metric: (mean(list)) for metric, list in list_scores.items()}
-    std_scores = {metric: (stdev(list)) for metric, list in list_scores.items()}
+
+    # Population standard deviation over folds/iterations, divided by sqrt(n):
+    # this yields the standard error of the mean, reported under "std".
+    std_scores = {metric: (pstdev(list) / sqrt(len(list))) for metric, list in list_scores.items()}
+
     confidence_interval = {
         metric: (stats.t.interval(0.95, len(list) - 1, loc=mean(list), scale=stats.sem(list)))
         for metric, list in list_scores.items()
@@ -126,7 +165,7 @@
         "avg": averaged_scores,
         "std": std_scores,
         "CI_0.95": confidence_interval,
-        "execution_time": execution_time,
+        "execution_time": execution_time.total_seconds() if execution_time is not None else 0.0,
     }

     with open(log_dir / "aggregated_test_metrics.json", "w") as f:
@@ -137,6 +176,8 @@

     logging.info(f"Accumulated results: {accumulated_metrics}")

+    wandb_log(json.loads(json.dumps(accumulated_metrics, cls=JsonResultLoggingEncoder)))
+

 def log_full_line(msg: str, level: int = logging.INFO, char: str = "-", num_newlines: int = 0):
     """Logs a full line of a given character with a message centered.
@@ -153,3 +194,61 @@ def log_full_line(msg: str, level: int = logging.INFO, char: str = "-", num_newl
         level,
         "{0:{char}^{width}}{1}".format(msg, "\n" * num_newlines, char=char, width=terminal_size.columns - reserved_chars),
     )
+
+
+def load_pretrained_imputation_model(use_pretrained_imputation):
+    """Loads a pretrained imputation model.
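+
+    Accepts either a checkpoint dict with "class", "hyper_parameters", "trained_columns",
+    and "state_dict" entries, or a directly pickled model object.
+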
+    Args:
+        use_pretrained_imputation: Path to the pretrained imputation model.
+    """
+    if use_pretrained_imputation is not None and not Path(use_pretrained_imputation).exists():
+        logging.warning("The specified pretrained imputation model does not exist.")
+        use_pretrained_imputation = None
+
+    if use_pretrained_imputation is not None:
+        logging.info(f"Using pretrained imputation from {use_pretrained_imputation}.")
+        pretrained_imputation_model_checkpoint = torch.load(use_pretrained_imputation, map_location=torch.device("cpu"))
+        if isinstance(pretrained_imputation_model_checkpoint, dict):
+            imputation_model_class = pretrained_imputation_model_checkpoint["class"]
+            pretrained_imputation_model = imputation_model_class(**pretrained_imputation_model_checkpoint["hyper_parameters"])
+            pretrained_imputation_model.set_trained_columns(pretrained_imputation_model_checkpoint["trained_columns"])
+            pretrained_imputation_model.load_state_dict(pretrained_imputation_model_checkpoint["state_dict"])
+        else:
+            pretrained_imputation_model = pretrained_imputation_model_checkpoint
+        pretrained_imputation_model = pretrained_imputation_model.to("cuda" if torch.cuda.is_available() else "cpu")
+        try:
+            logging.info(f"Imputation model device: {next(pretrained_imputation_model.parameters()).device}")
+            pretrained_imputation_model.device = next(pretrained_imputation_model.parameters()).device
+        except Exception as e:
+            logging.debug(f"Could not set device of imputation model: {e}")
+    else:
+        pretrained_imputation_model = None
+
+    return pretrained_imputation_model
+
+
+def setup_logging(date_format, log_format, verbose):
+    """
+    Set up all loggers to use the same format and date format.
+
+    Args:
+        date_format: Format for the date.
+        log_format: Format for the log.
+        verbose: Whether to log debug messages.
+    """
+    logging.basicConfig(format=log_format, datefmt=date_format)
+    loggers = ["pytorch_lightning", "lightning_fabric"]
+    for logger in loggers:
+        logging.getLogger(logger).handlers[0].setFormatter(logging.Formatter(log_format, datefmt=date_format))
+
+    if not verbose:
+        logging.getLogger().setLevel(logging.INFO)
+        for logger in loggers:
+            logging.getLogger(logger).setLevel(logging.INFO)
+        warnings.filterwarnings("ignore")
+    else:
+        logging.getLogger().setLevel(logging.DEBUG)
+        for logger in loggers:
+            logging.getLogger(logger).setLevel(logging.DEBUG)
+        warnings.filterwarnings("default")
diff --git a/icu_benchmarks/tuning/gin_utils.py b/icu_benchmarks/tuning/gin_utils.py
index b43cb493..c88d3a0f 100644
--- a/icu_benchmarks/tuning/gin_utils.py
+++ b/icu_benchmarks/tuning/gin_utils.py
@@ -1,5 +1,5 @@
 import logging
-
+from ..wandb_utils import wandb_log
 import gin
@@ -39,6 +39,8 @@ def bind_gin_params(hyperparams_names: list[str], hyperparams_values: list):
         hyperparams_names: List of hyperparameter names.
         hyperparams_values: List of hyperparameter values.
""" + logging.info("Binding Hyperparameters:") for param, value in zip(hyperparams_names, hyperparams_values): gin.bind_parameter(param, value) logging.info(f"{param} = {value}") + wandb_log({param: value}) diff --git a/icu_benchmarks/tuning/hyperparameters.py b/icu_benchmarks/tuning/hyperparameters.py index 3fbcd186..212a7753 100644 --- a/icu_benchmarks/tuning/hyperparameters.py +++ b/icu_benchmarks/tuning/hyperparameters.py @@ -1,7 +1,7 @@ import json import gin import logging -from logging import INFO, NOTSET +from logging import NOTSET import numpy as np from pathlib import Path from skopt import gp_minimize @@ -11,6 +11,8 @@ from icu_benchmarks.cross_validation import execute_repeated_cv from icu_benchmarks.run_utils import log_full_line from icu_benchmarks.tuning.gin_utils import get_gin_hyperparameters, bind_gin_params +from icu_benchmarks.contants import RunMode +from icu_benchmarks.wandb_utils import wandb_log TUNE = 25 logging.addLevelName(25, "TUNE") @@ -22,36 +24,48 @@ def choose_and_bind_hyperparameters( data_dir: Path, log_dir: Path, seed: int, + run_mode: RunMode = RunMode.classification, checkpoint: str = None, - scopes: list[str] = gin.REQUIRED, + scopes: list[str] = [], n_initial_points: int = 3, n_calls: int = 20, - folds_to_tune_on: int = gin.REQUIRED, + folds_to_tune_on: int = None, checkpoint_file: str = "hyperparameter_tuning_logs.json", generate_cache: bool = False, load_cache: bool = False, debug: bool = False, + verbose: bool = False, + wandb: bool = False, ): """Choose hyperparameters to tune and bind them to gin. Args: + wandb: Whether we use wandb or not. + load_cache: Load cached data if available. + generate_cache: Generate cache data. do_tune: Whether to tune hyperparameters or not. data_dir: Path to the data directory. log_dir: Path to the log directory. seed: Random seed. + run_mode: The run mode of the experiment. checkpoint: Name of the checkpoint run to load previously explored hyperparameters from. scopes: List of gin scopes to search for hyperparameters to tune. n_initial_points: Number of initial points to explore. n_calls: Number of iterations to optimize the hyperparameters. folds_to_tune_on: Number of folds to tune on. checkpoint_file: Name of the checkpoint file. - debug: Whether to load less data and enable more logging. + debug: Whether to load less data. + verbose: Set to true to increase log output. Raises: ValueError: If checkpoint is not None and the checkpoint does not exist. """ hyperparams = {} + if len(scopes) == 0 or folds_to_tune_on is None: + logging.warning("No scopes and/or folds to tune on, skipping tuning.") + return + # Collect hyperparameters. hyperparams_bounds, hyperparams_names = collect_bound_hyperparameters(hyperparams, scopes) @@ -79,9 +93,9 @@ def choose_and_bind_hyperparameters( else: logging.warning("No checkpoint file found, starting from scratch.") - # Function to + # Function that trains the model with the given hyperparameters. 
def bind_params_and_train(hyperparams): - with tempfile.TemporaryDirectory() as temp_dir: + with tempfile.TemporaryDirectory(dir=log_dir) as temp_dir: bind_gin_params(hyperparams_names, hyperparams) if not do_tune: return 0 @@ -89,12 +103,15 @@ def bind_params_and_train(hyperparams): data_dir, Path(temp_dir), seed, + mode=run_mode, cv_repetitions_to_train=1, cv_folds_to_train=folds_to_tune_on, generate_cache=generate_cache, load_cache=load_cache, test_on="val", debug=debug, + verbose=verbose, + wandb=wandb, ) header = ["ITERATION"] + hyperparams_names + ["LOSS AT ITERATION"] @@ -108,11 +125,17 @@ def tune_step_callback(res): f.write(json.dumps(data, cls=JsonResultLoggingEncoder)) table_cells = [len(res.x_iters)] + res.x_iters[-1] + [res.func_vals[-1]] highlight = res.x_iters[-1] == res.x # highlight if best so far + log_table_row(header, TUNE) log_table_row(table_cells, TUNE, align=Align.RIGHT, header=header, highlight=highlight) + wandb_log({"hp-iteration": len(res.x_iters)}) if do_tune: log_full_line("STARTING TUNING", level=TUNE, char="=") - logging.log(TUNE, f"Tuning from {n_initial_points} points in {n_calls} iterations on {folds_to_tune_on} folds.") + logging.log( + TUNE, + f"Applying Bayesian Optimization from {n_initial_points} points in {n_calls} " + f"iterations on {folds_to_tune_on} folds.", + ) log_table_row(header, TUNE) else: logging.log(TUNE, "Hyperparameter tuning disabled") @@ -124,8 +147,6 @@ def tune_step_callback(res): logging.log(TUNE, "Choosing hyperparameters randomly from bounds.") n_initial_points = 1 n_calls = 1 - if not debug: - logging.disable(level=INFO) # Call gaussian process. To choose a random set of hyperparameters this functions is also called. res = gp_minimize( diff --git a/icu_benchmarks/wandb_utils.py b/icu_benchmarks/wandb_utils.py new file mode 100644 index 00000000..be8bac7d --- /dev/null +++ b/icu_benchmarks/wandb_utils.py @@ -0,0 +1,61 @@ +from argparse import Namespace +import logging +import wandb + + +def wandb_running() -> bool: + """Check if wandb is running.""" + return wandb.run is not None + + +def update_wandb_config(config: dict) -> None: + """updates wandb config if wandb is running + + Args: + config (dict): config to set + """ + logging.debug(f"Updating Wandb config: {config}") + if wandb_running(): + wandb.config.update(config) + + +def apply_wandb_sweep(args: Namespace) -> Namespace: + """applies the wandb sweep configuration to the namespace object + + Args: + args (Namespace): parsed arguments + + Returns: + Namespace: arguments with sweep configuration applied (some are applied via hyperparams) + """ + wandb.init() + sweep_config = wandb.config + args.__dict__.update(sweep_config) + if args.hyperparams is None: + args.hyperparams = [] + for key, value in sweep_config.items(): + args.hyperparams.append(f"{key}=" + (("'" + value + "'") if isinstance(value, str) else str(value))) + logging.info(f"hyperparams after loading sweep config: {args.hyperparams}") + return args + + +def wandb_log(log_dict): + """logs metrics to wandb + + Args: + log_dict (dict): metric dict to log + """ + if wandb_running(): + wandb.log(log_dict) + + +def set_wandb_run_name(run_name): + """stores the run name in wandb config + + Args: + run_name (str): name of the current run + """ + if wandb_running(): + wandb.config.update({"run-name": run_name}) + wandb.run.name = run_name + wandb.run.save() diff --git a/scripts/sweep_configs/01_experiment_1.yaml b/scripts/sweep_configs/01_experiment_1.yaml new file mode 100644 index 00000000..919e9dd5 --- /dev/null +++ 
b/scripts/sweep_configs/01_experiment_1.yaml @@ -0,0 +1,24 @@ +program: icu-benchmarks +command: + - ${env} + - ${program} + - "train" + - "-d" + - "../data/miiv" + - "-t" + - "DatasetImputation" + - "-c" + - "--wandb-sweep" + - "--wandb" +method: grid +parameters: + model: + values: ["NP", "CSDI", "Mean", "Median", "Zero", "MostFrequent", "KNN", "MICE", "GAIN", "BRITS", "SAITS", "Attention", "MissForest", "MLP", "BRNN", "RNN", "Simple_Diffusion"] + ImputationDataset.mask_method: + values: ["MCAR", "MAR", "MNAR"] + seed: + values: [2222] + ImputationDataset.mask_proportion: + values: [0.3, 0.5, 0.7] + data_dir: + values: ["../data/miiv", "../data/eicu", "../data/hirid"] diff --git a/scripts/sweep_configs/hyperparameter_sweeps/attention_sweep.yml b/scripts/sweep_configs/hyperparameter_sweeps/attention_sweep.yml new file mode 100644 index 00000000..c7b3d2c5 --- /dev/null +++ b/scripts/sweep_configs/hyperparameter_sweeps/attention_sweep.yml @@ -0,0 +1,48 @@ +program: icu-benchmarks +command: + - ${env} + - ${program} + - "train" + - "-d" + - "../data/miiv" + - "-t" + - "DatasetImputation" + - "-c" + - "--wandb-sweep" + - "--wandb" +method: bayes +metric: + name: avg.jsd + goal: minimize +name: Attention hyperparameter sweep +parameters: + model: + values: ["Attention"] # DiffWave not working yet + ImputationDataset.mask_method: + values: ["MCAR"] + seed: + values: [1111] + data_dir: + values: ["../data/hirid"] + execute_repeated_cv.cv_repetitions: + values: [2] + execute_repeated_cv.cv_folds: + values: [2] + ImputationWrapper.lr_scheduler: + values: ["cosine", ""] + Adam.lr: + values: [0.001, 0.01, 0.1] + Attention.n_layers: + values: [2, 4, 6] + Attention.d_model: + values: [64, 128, 256] + Attention.d_inner: + values: [64, 128, 256] + Attention.n_head: + values: [4, 8, 16] + Attention.d_k: + values: [32, 64, 128] + Attention.d_v: + values: [32, 64, 128] + Attention.dropout: + values: [0.0, 0.1, 0.3] diff --git a/scripts/sweep_configs/hyperparameter_sweeps/brnn_sweep.yml b/scripts/sweep_configs/hyperparameter_sweeps/brnn_sweep.yml new file mode 100644 index 00000000..e072cfb6 --- /dev/null +++ b/scripts/sweep_configs/hyperparameter_sweeps/brnn_sweep.yml @@ -0,0 +1,42 @@ +program: icu-benchmarks +command: + - ${env} + - ${program} + - "train" + - "-d" + - "../data/miiv" + - "-t" + - "DatasetImputation" + - "-c" + - "--wandb-sweep" + - "--wandb" +method: bayes +metric: + name: avg.jsd + goal: minimize +name: BRNN hyperparameter sweep +parameters: + model: + values: ["BRNN"] # DiffWave not working yet + ImputationDataset.mask_method: + values: ["MCAR"] + seed: + values: [1111] + data_dir: + values: ["../data/hirid"] + execute_repeated_cv.cv_repetitions: + values: [2] + execute_repeated_cv.cv_folds: + values: [2] + ImputationWrapper.lr_scheduler: + values: ["cosine", ""] + Adam.lr: + values: [0.001, 0.01, 0.1] + BRNN.cell: + values: ['gru', "lstm"] + BRNN.hidden_size: + values: [32, 64, 128] + BRNN.state_init: + values: ['zero', "noise"] + BRNN.dropout: + values: [0.0, 0.1, 0.3] \ No newline at end of file diff --git a/scripts/sweep_configs/hyperparameter_sweeps/csdi_sweep.yml b/scripts/sweep_configs/hyperparameter_sweeps/csdi_sweep.yml new file mode 100644 index 00000000..e143867b --- /dev/null +++ b/scripts/sweep_configs/hyperparameter_sweeps/csdi_sweep.yml @@ -0,0 +1,48 @@ +program: icu-benchmarks +command: + - ${env} + - ${program} + - "train" + - "-d" + - "../data/miiv" + - "-t" + - "DatasetImputation" + - "-c" + - "--wandb-sweep" + - "--wandb" +method: bayes +metric: + name: avg.jsd + 
goal: minimize +name: csdi hyperparameter sweep +parameters: + model: + values: ["CSDI"] # DiffWave not working yet + ImputationDataset.mask_method: + values: ["MCAR"] + seed: + values: [1111] + data_dir: + values: ["../data/hirid"] + execute_repeated_cv.cv_repetitions: + values: [2] + execute_repeated_cv.cv_folds: + values: [2] + CSDI.time_step_embedding_size: + values: [64, 128, 256] + CSDI.feature_embedding_size: + values: [8, 16, 32, 64] + CSDI.target_strategy: + values: ["random", "hist", "mix"] + CSDI.num_diffusion_steps: + values: [50, 100, 200] + CSDI.diffusion_step_embedding_dim: + values: [64, 128, 256] + CSDI.n_attention_heads: + values: [6, 8, 10] + CSDI.num_residual_layers: + values: [2, 4, 6, 8] + CSDI.noise_schedule: + values: ["quad", "linear"] + CSDI.n_samples: + values: [5, 10, 15] diff --git a/scripts/sweep_configs/hyperparameter_sweeps/np_sweep.yml b/scripts/sweep_configs/hyperparameter_sweeps/np_sweep.yml new file mode 100644 index 00000000..b8142d60 --- /dev/null +++ b/scripts/sweep_configs/hyperparameter_sweeps/np_sweep.yml @@ -0,0 +1,48 @@ +program: icu-benchmarks +command: + - ${env} + - ${program} + - "train" + - "-d" + - "../data/miiv" + - "-t" + - "DatasetImputation" + - "-c" + - "--wandb-sweep" + - "--wandb" +method: bayes +metric: + name: avg.jsd + goal: minimize +name: np hyperparameter sweep +parameters: + model: + values: ["NP"] # DiffWave not working yet + ImputationDataset.mask_method: + values: ["MCAR"] + seed: + values: [1111] + data_dir: + values: ["../data/hirid"] + execute_repeated_cv.cv_repetitions: + values: [2] + execute_repeated_cv.cv_folds: + values: [2] + + ImputationWrapper.lr_scheduler: + values: ["cosine", ""] + Adam.lr: + values: [0.001, 0.01, 0.1] + + NP.encoder_layers: + values: [3, 6, 12] + NP.encoder_h_dim: + values: [24, 36, 72] + NP.decoder_layers: + values: [3, 6, 12] + NP.decoder_h_dim: + values: [24, 36, 72] + NP.r_dim: + values: [3, 6, 12] + NP.z_dim: + values: [3, 6, 12] \ No newline at end of file diff --git a/scripts/sweep_configs/hyperparameter_sweeps/saits_sweep.yml b/scripts/sweep_configs/hyperparameter_sweeps/saits_sweep.yml new file mode 100644 index 00000000..ab9adb4c --- /dev/null +++ b/scripts/sweep_configs/hyperparameter_sweeps/saits_sweep.yml @@ -0,0 +1,48 @@ +program: icu-benchmarks +command: + - ${env} + - ${program} + - "train" + - "-d" + - "../data/miiv" + - "-t" + - "DatasetImputation" + - "-c" + - "--wandb-sweep" + - "--wandb" +method: bayes +metric: + name: avg.jsd + goal: minimize +name: SAITS hyperparameter sweep +parameters: + model: + values: ["SAITS"] # DiffWave not working yet + ImputationDataset.mask_method: + values: ["MCAR"] + seed: + values: [1111] + data_dir: + values: ["../data/hirid"] + execute_repeated_cv.cv_repetitions: + values: [2] + execute_repeated_cv.cv_folds: + values: [2] + ImputationWrapper.lr_scheduler: + values: ["cosine", ""] + Adam.lr: + values: [0.001, 0.01, 0.1] + SAITS.n_layers: + values: [2, 4, 6] + SAITS.d_model: + values: [64, 128, 256] + SAITS.d_inner: + values: [64, 128, 256] + SAITS.n_head: + values: [4, 8, 16] + SAITS.d_k: + values: [32, 64, 128] + SAITS.d_v: + values: [32, 64, 128] + SAITS.dropout: + values: [0.0, 0.1, 0.3] \ No newline at end of file diff --git a/scripts/sweep_configs/hyperparameter_sweeps/ssds4_sweep.yml b/scripts/sweep_configs/hyperparameter_sweeps/ssds4_sweep.yml new file mode 100644 index 00000000..60768d25 --- /dev/null +++ b/scripts/sweep_configs/hyperparameter_sweeps/ssds4_sweep.yml @@ -0,0 +1,59 @@ +program: icu-benchmarks +command: + - ${env} + - 
${program} + - "train" + - "-d" + - "../data/miiv" + - "-t" + - "DatasetImputation" + - "-c" + - "--wandb-sweep" + - "--wandb" +method: bayes +metric: + name: avg.jsd + goal: minimize +name: SSSDS4 hyperparameter sweep +parameters: + model: + values: ["SSSDS4"] # DiffWave not working yet + ImputationDataset.mask_method: + values: ["MCAR"] + seed: + values: [1111] + data_dir: + values: ["../data/hirid"] + execute_repeated_cv.cv_repetitions: + values: [2] + execute_repeated_cv.cv_folds: + values: [2] + ImputationWrapper.lr_scheduler: + values: ["cosine", ""] + Adam.lr: + values: [0.001, 0.01, 0.1] + + SSSDS4.res_channels: + values: [64, 128, 256] + SSSDS4.skip_channels: + values: [64, 128, 256] + SSSDS4.num_res_layers: + values: [12, 24, 36] + SSSDS4.diffusion_step_embed_dim_in: + values: [64, 128, 256] + SSSDS4.diffusion_step_embed_dim_mid: + values: [128, 256, 512] + SSSDS4.diffusion_step_embed_dim_out: + values: [128, 256, 512] + SSSDS4.s4_lmax: + values: [50, 100, 200] + SSSDS4.s4_d_state: + values: [32, 64, 128] + SSSDS4.s4_dropout: + values: [0.0, 0.1, 0.3] + SSSDS4.s4_bidirectional: + values: [True, False] + SSSDS4.s4_layernorm: + values: [True, False] + SSSDS4.diffusion_time_steps: + values: [500, 1000, 2000] diff --git a/setup.cfg b/setup.cfg index 1e319f3e..d7ef5fa7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,7 +15,8 @@ replace = __version__ = '{new_version}' universal = 1 [flake8] -exclude = docs +exclude = docs, venv*, tests +max-line-length = 120 [aliases] # Define setup.py command aliases here diff --git a/setup.py b/setup.py index 88cea3ce..491fa2b9 100644 --- a/setup.py +++ b/setup.py @@ -2,9 +2,49 @@ # -*- coding: utf-8 -*- """The setup script.""" - +from pathlib import Path from setuptools import setup, find_packages + +root_path = Path(__file__).resolve().parent + + +def parse_environment_yml(): + """Parse the environment.yml file and extract the package names.""" + # here we cannot use pyyaml because it is not installed yet + with open(root_path / "environment.yml") as f: + lines = f.readlines() + + lines = [line.strip() for line in lines] + dependencies = [] + inside_dependencies = False + for entry in lines: + if entry == "dependencies:": + inside_dependencies = True + continue + if inside_dependencies: + if not entry.startswith("-"): + break + dependency_name = entry.strip().split(" ")[-1] + if dependency_name != "pip:" and "python=" not in dependency_name: + dependencies.append(dependency_name) + + sanitized_dependencies = [] + for dependency in dependencies: + # conda package ignite is named pytorch-ignite on pypi + if "ignite" in dependency: + dependency = "pytorch-" + dependency + if dependency.startswith("pytorch="): + dependency = dependency.replace("pytorch", "torch") + if "=" in dependency and "==" not in dependency: + dependency = "==".join(dependency.split("=")) + if "http://" in dependency or "https://" in dependency: + package_name = dependency.split("/")[-1].split(".")[0] + dependency = package_name + "@" + dependency + sanitized_dependencies.append(dependency) + return sanitized_dependencies + + with open("README.md") as readme_file: readme = readme_file.read() @@ -23,7 +63,8 @@ description="Yet Another ICU Benchmark is a holistic framework for the automation of clinical prediction models " "on ICU data. 
Users can create custom datasets, cohorts, prediction tasks, endpoints, and models.", entry_points={"console_scripts": ["icu-benchmarks = icu_benchmarks.run:main"]}, - install_requires=[], # dependencies managed via conda for the moment + install_requires=parse_environment_yml(), + extras_require={"mps": ["mkl < 2022"]}, license="MIT license", long_description=readme, include_package_data=True, @@ -34,6 +75,6 @@ test_suite="tests", tests_require=[], url="https://github.com/rvandewater/YAIB", - version="0.1.0", + version="0.2.0", zip_safe=False, )
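
The sweep YAMLs added under scripts/sweep_configs/ are plain wandb sweep definitions whose command invokes the icu-benchmarks entry point registered in setup.py. Below is a minimal sketch of launching one programmatically; the project name "yaib" is an assumption (it is not set anywhere in this patch), and the equivalent CLI route is `wandb sweep <file>` followed by `wandb agent <sweep_id>`.

import yaml
import wandb

# Load one of the sweep definitions added in this change.
with open("scripts/sweep_configs/hyperparameter_sweeps/attention_sweep.yml") as f:
    sweep_config = yaml.safe_load(f)

# Register the sweep; "yaib" is an assumed project name.
sweep_id = wandb.sweep(sweep=sweep_config, project="yaib")

# Without a `function` argument, each agent run executes the `program`/`command`
# from the config, i.e. `icu-benchmarks train ... --wandb-sweep`; the swept values
# are then picked up by apply_wandb_sweep() at startup.
wandb.agent(sweep_id)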