This is the official codebase of the paper
Multi-Scale Representation Learning for Protein Fitness Prediction, NeurIPS'2024
[ArXiv] [OpenReview]
Zuobai Zhang*, Pascal Notin*, Yining Huang, Aurelie Lozano, Vijil Chenthamarakshan, Debora Marks, Payel Das, Jian Tang
Sequence-Structure-Surface Fitness Model (S3F) is a novel multimodal representation learning framework that integrates protein features across several scales. The model is pre-trained on the CATH dataset and evaluated by zero-shot protein fitness on ProteinGym. The summary of our results on ProteinGym can be found here. The datasets and model checkpoints for this project can be downloaded from this link.
This codebase is based on PyTorch and TorchDrug (TorchProtein).
You may install the dependencies via either conda or pip. Generally, GearNet works with Python 3.7/3.8 and PyTorch version >= 1.8.0.
conda install torchdrug pytorch=1.8.0 cudatoolkit=11.1 -c milagraph -c pytorch-lts -c pyg -c conda-forge
conda install easydict pyyaml -c conda-forge
pip install torch==1.8.0+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
pip install torchdrug
pip install easydict pyyaml
To evaluate on ProteinGym benchmark, you need to first download datasets from the official ProteinGym website.
# Download ProteinGym benchmark
wget https://marks.hms.harvard.edu/proteingym/DMS_ProteinGym_substitutions.zip -O ./dataset
unzip DMS_ProteinGym_substitutions.zip -d ./dataset/DMS_ProteinGym_substitutions
wget https://marks.hms.harvard.edu/proteingym/ProteinGym_AF2_structures.zip -O ./dataset
unzip ProteinGym_AF2_structures.zip -d ./dataset/ProteinGym_AF2_structures
Then, running S3F requires to first generate the surfaces based on ProteinGym structures. You can choose either to download our pre-processed version of surface graphs at link or process surface graphs by yourself.
# Download the processed surface graphs
wget https://zenodo.org/records/14257708/files/processed_surface_proteingym.zip -P ./dataset
unzip ./dataset/processed_surface_proteingym.zip -d ./dataset
# Or process surface graphs by yourself
python script/preload_dataset.py -i ./dataset/ProteinGym_AF2_structures/ -o ./dataset/processed_proteingym/
python script/process_surface.py -i ./dataset/processed_proteingym/ -o ./dataset/processed_surface_gym/
As the model is based on the ESM-2-650M model, you need to first download the ESM model checkpoint.
mkdir -p ~/scratch/protein-model-weights/esm-model-weights/
wget https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt -P ~/scratch/protein-model-weights/esm-model-weights/
There is a task.model.sequence_model.path
argument in each config file to control where to automatically download ESM model weights. Please modify this to your customized path to the esm model weights.
We prodive both S2F and S3F model checkpoints for evaluation. You can download these checkpoints from link and then run the following commands for evaluation. Right now, we only support single-gpu evaluation, which takes around 20 hours for all 217 assays on one A100 GPU. The output files can be found at ~/scratch/proteingym_output
, which is specified by the output_dir
argument in the *.yaml
.
# Run evaluation for S2F
python script/evaluate.py -c config/evaluate/s2f.yaml --datadir ./dataset/DMS_ProteinGym_substitutions --structdir ./dataset/ProteinGym_AF2_structures --ckpt <path_to_ckpt>
# Run evaluation for S3F
python script/evaluate.py -c config/evaluate/s3f.yaml --datadir ./dataset/DMS_ProteinGym_substitutions --structdir ./dataset/ProteinGym_AF2_structures --surfdir ./dataset/processed_surface_gym/ --ckpt <path_to_ckpt>
We also provide the code for pre-training on the CATH dataset with residue type prediction objective. To get started, you need to download a raw version from this repo and preload it with our script.
# Download raw cath dataset and process it
wget https://huggingface.co/datasets/tyang816/cath/blob/main/dompdb.tar -P ./dataset
tar -xvf dompdb.tar -C ./dataset
python script/preload_dataset.py -i ./dataset/dompdb/ -o ./dataset/processed_cath/
Then, to pre-train S3F, you need to first process the surface graphs for the protein. However, as the files are too large (over 150G), we cannot provide the processed version. You need to run the following command to re-process the surface graphs by yourself.
# Process surface graphs (require one GPU to compute)
python script/process_surface.py -i ./dataset/processed_cath/ -o ./dataset/processed_surface_cath/
To pre-train the S2F or S3F models, please run the following commands. Here we use 4 A100 GPUs for training.
# Pre-train S2F model
python -m torch.distributed.launch --nproc_per_node=4 script/pretrain.py -c config/pretrain/s2f.yaml --datadir ./dataset/processed_cath
# Pre-train S3F model
python -m torch.distributed.launch --nproc_per_node=4 script/pretrain.py -c config/pretrain/s3f.yaml --datadir ./dataset/processed_cath --surfdir ./dataset/processed_surface_cath
To customize your pre-training setting, you need to adapt config/pretrain/s3f.yaml
for your setting. You can set engine.gpus
as the devices you want to use and set engine.batch_size
as the batch size per gpu.
If you find this codebase useful in your research, please cite the following papers.
@inproceedings{zhang2024s3f,
title={Multi-Scale Representation Learning for Protein Fitness Prediction},
author={Zhang, Zuobai and Notin, Pascal and Huang, Yining and Lozano, Aurelie and Chenthamarakshan, Vijil and Marks, Debora and Das, Payel and Tang, Jian},
booktitle={Advances in Neural Information Processing Systems},
year={2024}
}