Multiflow is a protein sequence and structure generative model based on our preprint: Generative Flows on Discrete State-Spaces: Enabling Multimodal Flows with Applications to Protein Co-Design.
Our codebase is built on top of FrameFlow. The sequence generative model is adapted from Discrete Flow Models (DFM).
If you use this codebase, please cite:
@article{campbell2024generative,
title={Generative Flows on Discrete State-Spaces: Enabling Multimodal Flows with Applications to Protein Co-Design},
author={Campbell, Andrew and Yim, Jason and Barzilay, Regina and Rainforth, Tom and Jaakkola, Tommi},
journal={arXiv preprint arXiv:2402.04997},
year={2024}
}
Note
This codebase is very new, so we expect bugs and issues on other systems and environments. Please open a GitHub issue or pull request and we will try to help.
LICENSE: MIT
We recommend using mamba. If you use mamba, replace conda with mamba in the commands below.
# Install environment with dependencies.
conda env create -f multiflow.yml
# Activate environment
conda activate multiflow
# Install local package.
# Current directory should have setup.py.
pip install -e .
Next, install torch-scatter manually to match your torch version. (Unfortunately, torch-scatter cannot be installed as part of the environment file.) We use torch 2.0.1 with CUDA 11.7, so we install the following:
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu117.html
If you use a different torch version, find it as follows and substitute it into the URL above:
# Find your installed version of torch
python
>>> import torch
>>> torch.__version__
# Example: torch 2.0.1+cu117
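The index URL embeds the full torch version string, including its local suffix (+cu117 above). The helpers below are a hypothetical sketch (torch_scatter_wheel_index and install_command are illustrative names, not part of Multiflow) that turn that string into the pip command. They assume the data.pyg.org/whl/torch-<version>.html pattern used above and reject bare versions without a suffix, which some CPU-only builds report; in that case append the suffix by hand, e.g. +cpu.

```python
def torch_scatter_wheel_index(torch_version: str) -> str:
    """Build the PyG wheel index URL for a version string like "2.0.1+cu117".

    Hypothetical helper: assumes the data.pyg.org/whl/torch-<version>.html
    pattern from the pip command above.
    """
    if "+" not in torch_version:
        # The index URL needs a local suffix such as +cu117 or +cpu.
        raise ValueError(f"version {torch_version!r} has no suffix like +cu117")
    return f"https://data.pyg.org/whl/torch-{torch_version}.html"


def install_command(torch_version: str) -> str:
    """Return the pip command that installs torch-scatter for this torch build."""
    return f"pip install torch-scatter -f {torch_scatter_wheel_index(torch_version)}"
```

Inside the activated environment, run import torch and then print(install_command(str(torch.__version__))) to get the command for your installed build.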
Warning
You will likely run into the following error from DeepSpeed:
ModuleNotFoundError: No module named 'torch._six'
If so, replace from torch._six import inf with from torch import inf in these two files:
/path/to/envs/site-packages/deepspeed/runtime/utils.py
/path/to/envs/site-packages/deepspeed/runtime/zero/stage_1_and_2.py
where /path/to/envs is the path to your environment. We would appreciate a pull request that removes the need for this monkey patch!
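The two-file edit described in the Warning can be scripted. This is a hypothetical sketch, not a script shipped with Multiflow: it locates site-packages with importlib.util.find_spec, which resolves a top-level package's location without executing it (so it works while import deepspeed is still failing), and then swaps the import line in each of the two files, skipping any that are missing or already patched.

```python
from __future__ import annotations

import importlib.util
from pathlib import Path

OLD_IMPORT = "from torch._six import inf"
NEW_IMPORT = "from torch import inf"

# The two DeepSpeed files listed in the Warning, relative to site-packages.
TARGETS = [
    Path("deepspeed") / "runtime" / "utils.py",
    Path("deepspeed") / "runtime" / "zero" / "stage_1_and_2.py",
]


def find_site_packages() -> Path | None:
    """Return the directory containing the deepspeed package, or None."""
    spec = importlib.util.find_spec("deepspeed")
    if spec is None or spec.origin is None:
        return None
    # spec.origin is .../site-packages/deepspeed/__init__.py
    return Path(spec.origin).parent.parent


def patch_deepspeed(site_packages: Path) -> list[Path]:
    """Replace the torch._six import in each target; return the files changed."""
    changed = []
    for rel in TARGETS:
        path = site_packages / rel
        if not path.is_file():
            continue
        text = path.read_text()
        if OLD_IMPORT in text:
            path.write_text(text.replace(OLD_IMPORT, NEW_IMPORT))
            changed.append(path)
    return changed
```

With the environment activated, call patch_deepspeed(find_site_packages()) after checking that find_site_packages() did not return None. The patch is idempotent, but reinstalling DeepSpeed will undo it.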
Our training relies on logging with Weights & Biases (wandb). Create a wandb account, log in, and authorize wandb here.
We host the datasets on Zenodo here. Download the following files:
- real_train_set.tar.gz (2.5 GB)
- synthetic_train_set.tar.gz (220 MB)
- test_set.tar.gz (347 MB)
Next, untar the files:
# Uncompress training data
mkdir train_set
tar -xzvf real_train_set.tar.gz -C train_set/
tar -xzvf synthetic_train_set.tar.gz -C train_set/
# Uncompress test data
mkdir test_set
tar -xzvf test_set.tar.gz -C test_set/
The resulting directory structure should look like
<current_dir>
├── train_set
│   ├── processed_pdb
│   │   ├── <subdir>
│   │   │   └── <protein_id>.pkl
│   ├── processed_synthetic
│   │   └── <protein_id>.pkl
├── test_set
│   └── processed
│       ├── <subdir>
│       │   └── <protein_id>.pkl
...
Our experiments read the data using relative paths, so keep this directory structure to avoid path errors.
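Because the paths are relative, a quick check from <current_dir> before training can catch a misplaced archive. The helpers below are a hypothetical sketch (not part of Multiflow) that look only for the three directories shown in the tree and count the .pkl files beneath each, searching nested <subdir> folders.

```python
from pathlib import Path

# Directories from the tree above, relative to <current_dir>.
EXPECTED_DIRS = [
    Path("train_set") / "processed_pdb",
    Path("train_set") / "processed_synthetic",
    Path("test_set") / "processed",
]


def missing_data_dirs(root: Path = Path(".")) -> list:
    """Return the expected data directories that are absent under root."""
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


def count_pickles(root: Path = Path(".")) -> dict:
    """Count <protein_id>.pkl files per existing directory, searching subdirs."""
    return {
        str(d): sum(1 for _ in (root / d).rglob("*.pkl"))
        for d in EXPECTED_DIRS
        if (root / d).is_dir()
    }
```

An empty list from missing_data_dirs() means the layout matches; nonzero counts from count_pickles() confirm that the archives actually extracted files.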
To run co-design training, use the following command:
python -W ignore multiflow/experiments/train_se3_flows.py -cn pdb_codesign
We use Hydra to manage our configs.
The training config is multiflow/configs/pdb_codesign.yaml.
Most important fields:
- experiment.num_devices: Number of GPUs to use for training. Default is 2.
- data.sampler.max_batch_size: Maximum batch size. We use dynamic batch sizes that depend on data.sampler.max_num_res_squared. Both parameters need to be tuned for your GPU memory; our default settings are set for a 40GB Nvidia RTX card.
- data.sampler.max_num_res_squared: See above.
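To build intuition for how the two sampler fields interact, here is one plausible rule, written as a hypothetical sketch rather than Multiflow's actual sampler code: for proteins with num_res residues, pick the largest batch size such that batch_size * num_res**2 stays within max_num_res_squared, capped at max_batch_size. Pairwise residue-by-residue features grow quadratically with length, so bounding this product roughly bounds memory per batch; long proteins get small batches, while short ones are limited by max_batch_size.

```python
def dynamic_batch_size(num_res: int, max_batch_size: int, max_num_res_squared: int) -> int:
    """Hypothetical rule: keep batch_size * num_res**2 <= max_num_res_squared.

    The result is capped at max_batch_size and is always at least 1, so a
    single very long protein still forms a batch of one.
    """
    if num_res <= 0:
        raise ValueError("num_res must be positive")
    by_memory = max_num_res_squared // (num_res ** 2)
    return max(1, min(max_batch_size, by_memory))


if __name__ == "__main__":
    # Illustrative values only; they are not Multiflow's defaults.
    for n in (50, 100, 200, 400):
        print(n, dynamic_batch_size(n, max_batch_size=64, max_num_res_squared=500_000))
```

With these illustrative values, lengths 50, 100, 200 and 400 get batch sizes 64, 50, 12 and 3. Under this rule, halving max_num_res_squared roughly halves the batch for long proteins, which is why both fields need tuning together for a given GPU.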
We provide pre-trained model weights at this Zenodo link.
Run the following to unpack the weights:
tar -xzvf weights.tar.gz
The following three inference tasks are supported:
# Unconditional Co-Design
python -W ignore multiflow/experiments/inference_se3_flows.py -cn inference_unconditional
# Inverse Folding
python -W ignore multiflow/experiments/inference_se3_flows.py -cn inference_inverse_folding
# Forward Folding
python -W ignore multiflow/experiments/inference_se3_flows.py -cn inference_forward_folding
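The three commands differ only in the Hydra config name passed with -cn. Hydra also accepts key=value overrides after the config name, so the fields listed below can be changed without editing the YAML files. The helpers below are a hypothetical sketch (inference_command, run_inference and TASK_CONFIGS are illustrative names, not part of Multiflow) that build and run these commands from the repository root.

```python
import shlex
import subprocess
from typing import Dict, List, Optional

SCRIPT = "multiflow/experiments/inference_se3_flows.py"

# Hydra config name (-cn) for each of the three tasks above.
TASK_CONFIGS = {
    "unconditional": "inference_unconditional",
    "inverse_folding": "inference_inverse_folding",
    "forward_folding": "inference_forward_folding",
}


def inference_command(task: str, overrides: Optional[Dict[str, object]] = None) -> List[str]:
    """Build argv for one task, appending Hydra key=value overrides."""
    argv = ["python", "-W", "ignore", SCRIPT, "-cn", TASK_CONFIGS[task]]
    for key, value in (overrides or {}).items():
        argv.append(f"{key}={value}")
    return argv


def run_inference(task: str, overrides: Optional[Dict[str, object]] = None) -> None:
    """Echo and run the command; raises CalledProcessError if it fails."""
    argv = inference_command(task, overrides)
    print(shlex.join(argv))
    subprocess.run(argv, check=True)
```

For example, run_inference("unconditional", {"inference.num_gpus": 1}) runs the unconditional config on a single GPU, which is equivalent to appending inference.num_gpus=1 to the first command above.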
Config locations:
- configs/inference_unconditional.yaml: unconditional sampling config.
- configs/inference_inverse_folding.yaml: inverse folding config.
- configs/inference_forward_folding.yaml: forward folding config.
Most important fields:
- inference.num_gpus: Number of GPUs to use. We typically use 2 or 4.
- inference.{...}_ckpt_path: Checkpoint path for hallucination.
- inference.interpolant.sampling.num_timesteps: Number of steps in the flow.
- inference.folding.folding_model: esmf for ESMFold and af2 for AlphaFold2.
[Only for hallucination]
- inference.samples.samples_per_length: Number of samples per length.
- inference.samples.min_length: Start of length range to sample.
- inference.samples.max_length: End of length range to sample.
- inference.samples.length_subset: Subset of lengths to sample. When set, it overrides min_length and max_length.
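To estimate how many structures a hallucination run will produce, expand the length fields and multiply by samples_per_length. The sketch below is hypothetical: it assumes the range from min_length to max_length is inclusive with a step of 1, and that length_subset, when not null, replaces the range entirely, as the field description says. The real sampler may use a different step or endpoint convention, so check the config before relying on the count.

```python
from typing import List, Optional, Sequence


def lengths_to_sample(min_length: int, max_length: int,
                      length_subset: Optional[Sequence[int]] = None) -> List[int]:
    """Hypothetical expansion of the length fields into a list of lengths.

    length_subset wins when set; otherwise every length from min_length to
    max_length inclusive, with an assumed step of 1.
    """
    if length_subset is not None:
        return list(length_subset)
    if min_length > max_length:
        raise ValueError("min_length must not exceed max_length")
    return list(range(min_length, max_length + 1))


def total_samples(lengths: Sequence[int], samples_per_length: int) -> int:
    """Total structures generated: samples_per_length for each length."""
    return len(lengths) * samples_per_length
```

For example, lengths_to_sample(60, 64) gives [60, 61, 62, 63, 64], so 10 samples per length yields 50 structures, while passing length_subset=[70, 100] ignores the range and samples only those two lengths.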