This project implements a distributed training pipeline for learning representations over spatial cell graphs using PyTorch, PyTorch Geometric, and Distributed Data Parallel (DDP).
conda env create -f environment.yml
conda activate HEIST
- Place your dataset in the
data/pretraining/
directory. - The data should be preprocessed and saved as PyTorch objects using
torch.save
.
Run the the full HEIST model on all the available GPUs, use the following command to start distributed training:
python main_ddp.py --data_dir data/pretraining/ --pe --cross_message_passing
If you want to train only on single GPU, change main_ddp
with main
:
python main.py --data_dir data/pretraining/ --pe --cross_message_passing
Argument | Description | Default |
---|---|---|
--data_dir |
Path to preprocessed data | data/pretraining/ |
--pe_dim |
Positional Encoding Dim | 128 |
--init_dim |
Initial MLP Hidden Dim | 128 |
--hidden_dim |
Hidden Dimension | 128 |
--output_dim |
Output Dimension | 128 |
--num_layers |
Number of MLP Layers | 10 |
--num_heads |
Number of Transformer Heads | 8 |
--batch_size |
Batch Size | 128 |
--lr |
Learning Rate | 1e-3 |
--wd |
Weight Decay | 3e-3 |
--num_epochs |
Number of Training Epochs | 20 |
- Checkpoints are saved automatically under the
saved_models/
directory. - The best model is saved as
HEIST.pth
.
If you want to resume from a saved checkpoint, ensure that the model and optimizer state dictionaries are correctly loaded in the script.
After training, you can evaluate the model using the provided evaluation script.
bash eval.sh
This script will:
-
Initialize and activate the Conda environment
HEIST
. -
Run a series of evaluations across multiple datasets and tasks:
-
Representation Space Calculation
dfci
,upmc
,charville
,sea
,melanoma
,placenta
,lung
-
Tissue Classification
- Predict primary outcomes and recurrence for clinical datasets.
-
Melanoma and Cell Clustering Evaluations
-
Placenta Dataset Analysis
-
Gene Imputation Tasks
- Both standard and fine-tuned versions.
-
- Evaluation results will be stored in the corresponding directories or logged to the console.
- Ensure that the trained model checkpoint is available and named correctly (default:
HEIST
).