Code and report for my semester project on using rotation- and translation-equivariant graph neural networks to predict cardiac arrest from 3-dimensional reconstructed arteries. For more information about the experiments, please check out my write-up!
The way I recommend navigating this repo is to:
- Look at the write-up
- Look at the code structure below
- Look at the experiment results on the wandb platform. For more information on how to navigate wandb, see the appendix of the write-up.
- If you want to run your own experiments, you will need an MI-proj/data folder containing a patient_dict.pickle file (described below under datasets.py) and a data folder, for example CoordToCnc, with your mesh data.
- Generate your experiments by running `python main_cross_val.py` with appropriate hyperparameters set in hyper_params.yaml.
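A hyper_params.yaml for the grid search might look like the sketch below. The key names here (seed, lr, batch_size, hidden_dim) are illustrative assumptions; the actual keys must match whatever main_cross_val.py and evaluate.py read. Each key holds a list of candidate values for the grid search, and should hold exactly one value when used with evaluate.py:

```yaml
# Illustrative sketch — key names are assumptions, not the repo's actual schema.
seed: [42]            # keep identical between the grid search and evaluate.py
lr: [0.001, 0.0001]   # lists: the grid search tries every combination
batch_size: [8, 16]
hidden_dim: [64]
```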
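As a starting point, the patient_dict.pickle file can be produced with a few lines of Python. The patient IDs and artery names below are hypothetical placeholders, since the actual dataset is not public; only the file location and the dict-of-lists shape follow the description in this README:

```python
import pickle
from pathlib import Path

# Hypothetical example: map each patient ID to the list of its artery names.
# Replace the keys and values with the identifiers used in your own dataset.
patient_dict = {
    "patient_001": ["LAD", "LCX"],
    "patient_002": ["RCA"],
}

data_dir = Path("MI-proj/data")
data_dir.mkdir(parents=True, exist_ok=True)
with open(data_dir / "patient_dict.pickle", "wb") as f:
    pickle.dump(patient_dict, f)
```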
| Name | Description |
|------|-------------|
| create_data.py | Data fetching and preprocessing. Should be run from the MI-proj directory. The executed function is at the bottom of the file; note that our dataset is not public, so you won't have access to the path and label_path directories. |
| data_augmentation.py | Contains all the data augmentation schemes attempted. Used in create_data.py. |
| datasets.py | Contains our custom DataSet object, which is how we store the meshes. Also contains a custom split_data function that performs the train, validation, and test splits at the patient level. Note that you will need a file "MI-proj/data/patient_dict.pickle" containing a dictionary with patients as keys and lists of artery names as values. |
| hyper_params.yaml | Contains all hyperparameters of a given model. Used in evaluate.py and main_cross_val.py. If you plan on using it for evaluate.py, there should be exactly one value per hyperparameter. |
| main_cross_val.py | Runs a grid search with cross-validation over all combinations of hyperparameters in hyper_params.yaml. All experiments are recorded on the wandb platform. Make sure to change and remember MODEL_TYPE so you can retrieve the experiment on wandb! This does not use the test set. Should be called from inside the MI-proj/experiments directory. |
| evaluate.py | Same as main_cross_val.py, but evaluates the model on the test set once training has finished. Should be run with only one value per hyperparameter in hyper_params.yaml. It is crucial to use the same seed here as in the grid search. Also records all results on wandb. Should be called from inside the MI-proj/experiments directory. |
| gnnexplainer.ipynb | Coming soon! Jupyter notebook for the GNNExplainer experiment and visualization. |
| GNNExplainer.py | Slightly modified code from the paper of [1]; obtained from the repo of [1]. |
| egnn.py | Slightly modified code from the paper of [2]; obtained from the repo of [2]. |
| models.py | Contains all the different models used in the experiments. |
| train.py | Contains a custom GNN object definition. Main script used for training and evaluating our models. |
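The patient-level splitting done by split_data matters because one patient can contribute several arteries; if those arteries were scattered across train and test, the model could leak patient-specific information. A minimal sketch of the idea (not the repo's actual implementation — function and variable names are made up for illustration):

```python
import random

def patient_level_split(patient_dict, val_frac=0.2, test_frac=0.2, seed=42):
    """Split patients (not individual arteries) into train/val/test so that
    all arteries of a given patient land in the same split, avoiding leakage."""
    patients = sorted(patient_dict)
    rng = random.Random(seed)   # fixed seed: same split in search and evaluation
    rng.shuffle(patients)
    n = len(patients)
    n_test = int(n * test_frac)
    n_val = int(n * val_frac)
    test = patients[:n_test]
    val = patients[n_test:n_test + n_val]
    train = patients[n_test + n_val:]
    # Expand patient IDs back into (patient, artery) samples.
    to_arteries = lambda ps: [(p, a) for p in ps for a in patient_dict[p]]
    return to_arteries(train), to_arteries(val), to_arteries(test)
```

Using the same seed here as in the grid search (as the table above stresses for evaluate.py) is what keeps the test set untouched across both stages.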
[1] paper: GNNExplainer: Generating Explanations for Graph Neural Networks, repo: GNNExplainer.
[2] paper: E(n) Equivariant Graph Neural Networks, repo: egnn.