# BrainSynth [DOI](https://zenodo.org/doi/10.5281/zenodo.10014960)

Official implementation of "Realistic Morphology-preserving Generative Modelling of the Brain"

# Pretrained models

Experiment scripts are provided in the experiments folder and are named after the models they train.

Below are some toy examples of how to use the pretrained models.

To use the pretrained models you need to do the following (an end-to-end sketch of these steps is given after the list):

0) Create a Docker container based on the Dockerfile and requirements file found in the docker folder.
1) Create a folder with the following structure, replacing 'experiment_name' with the name of your experiment and choosing either baseline_vqvae or performer depending on which weights you want to use:
```
<<experiment_name>>
├── baseline_vqvae/performer
    ├── checkpoints
    ├── logs
    └── outputs
```
2) Download the weights of the desired model from here (weights are being uploaded) and put them in the checkpoints folder.
3) Rename the file to 'checkpoint_epoch=0.pt'.
4) Use the corresponding script from the examples below and remember to:
* Replace the training/validation subjects with paths to either a folder of .nii.gz files or to csv/tsv files that have a path column containing the full paths to the files.
* Replace the conditioning files with the correct ones for the transformer training.
* Replace the project_directory with the path where you created the folder in step 1.
* Replace the experiment_name with the name of the experiment you created in step 1.
5) Mount the paths to the data and results folders and launch your Docker container.
6) Use the appropriate script for the model from below and change the mode to the desired one.
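
As a rough end-to-end sketch of steps 0–5 (the image tag `brainsynth`, the host paths, the Docker build context, the container mount points, and the downloaded checkpoint filename are illustrative assumptions, not names fixed by this repository):

```bash
# Step 0: build the image; the docker folder is assumed to be the build context.
docker build -t brainsynth ./docker

# Step 1: create the expected folder structure (use 'performer' instead of
# 'baseline_vqvae' for the other set of weights).
mkdir -p /results/example_run/baseline_vqvae/{checkpoints,logs,outputs}

# Steps 2-3: place the downloaded weights in the checkpoints folder under the expected name.
mv /path/to/downloaded_weights.pt \
   /results/example_run/baseline_vqvae/checkpoints/checkpoint_epoch=0.pt

# Step 5: mount the code, data and results folders and launch the container
# (assumes the NVIDIA Container Toolkit for GPU access).
docker run --rm -it --gpus all \
    -v /path/to/BrainSynth:/project \
    -v /path/to/data:/data \
    -v /results:/results \
    brainsynth /bin/bash
```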

# VQ-VAE

To extract the quantized latent representations of the images, run the same command used for training and replace the `--mode='training'` parameter with `--mode='extracting'`. For decoding, replace it with `--mode='decoding'`.

Training script example for the VQ-VAE:
```bash
python /project/run_vqvae.py run \
    --training_subjects="/path/to/training/data/tsv/" \
    --validation_subjects="/path/to/validation/data/tsv/" \
    --load_nii_canonical=False \
    --project_directory="/results/" \
    --experiment_name="example_run" \
    --mode='training' \
    --device='ddp' \
    --distributed_port=29500 \
    --amp=True \
    --deterministic=False \
    --cuda_benchmark=True \
    --seed=4 \
    --epochs=500 \
    --learning_rate=0.000165 \
    --gamma=0.99999 \
    --log_every=1 \
    --checkpoint_every=1 \
    --eval_every=1 \
    --loss='jukebox_perceptual' \
    --adversarial_component=True \
    --discriminator_network='baseline_discriminator' \
    --discriminator_learning_rate=5e-05 \
    --discriminator_loss='least_square' \
    --generator_loss='least_square' \
    --initial_factor_value=0 \
    --initial_factor_steps=25 \
    --max_factor_steps=50 \
    --max_factor_value=5 \
    --batch_size=8 \
    --normalize=True \
    --roi='((16,176), (16,240),(96,256))' \
    --eval_batch_size=8 \
    --num_workers=8 \
    --prefetch_factor=8 \
    --starting_epoch=172 \
    --network='baseline_vqvae' \
    --use_subpixel_conv=False \
    --use_slim_residual=True \
    --no_levels=4 \
    --downsample_parameters='((4,2,1,1),(4,2,1,1),(4,2,1,1),(4,2,1,1))' \
    --upsample_parameters='((4,2,1,0,1),(4,2,1,0,1),(4,2,1,0,1),(4,2,1,0,1))' \
    --no_res_layers=3 \
    --no_channels=256 \
    --codebook_type='ema' \
    --num_embeddings='(2048,)' \
    --embedding_dim='(32,)' \
    --decay='(0.5,)' \
    --commitment_cost='(0.25,)' \
    --max_decay_epochs=100 \
    --dropout=0.0 \
    --act='RELU'
```
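
For example, assuming the training command above has been saved to a shell script (hypothetically named `train_vqvae.sh`; the script names here are placeholders), extraction can reuse it with only the mode flag changed:

```bash
# Swap only the mode in the saved training command; use --mode='decoding' instead
# for the decoding mode described above.
sed "s/--mode='training'/--mode='extracting'/" train_vqvae.sh > extract_vqvae.sh
bash extract_vqvae.sh
```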

# Transformer

To sample new images from the trained model, run the same command used for training and replace the `--mode='training'` parameter with `--mode='inference'`.

Best performance was found by equalising the normalised continuous conditioning variables.

Training script example for the Transformer, based on the UKB one:
```bash
python3 /project/run_transformer.py run \
    --training_subjects='/path/to/training/data/tsv/' \
    --validation_subjects='/path/to/validation/data/tsv/' \
    --conditioning_path='/path/to/continuous/equalised/tsv/' \
    --conditionings='(\"used\", \"conditioning\", \"columns\")' \
    --project_directory='/results/' \
    --experiment_name='example_run' \
    --mode='training' \
    --deterministic=False \
    --cuda_benchmark=False \
    --cuda_enable=True \
    --use_zero=True \
    --device='ddp' \
    --seed=4 \
    --epochs=500 \
    --learning_rate=0.0005 \
    --gamma='auto' \
    --log_every=1 \
    --checkpoint_every=1 \
    --eval_every=0 \
    --weighted_sampling=True \
    --batch_size=2 \
    --eval_batch_size=2 \
    --num_workers=16 \
    --prefetch_factor=16 \
    --vqvae_checkpoint='/path/to/vqvae/checkpoint/' \
    --vqvae_aug_conditionings='none' \
    --vqvae_aug_load_nii_canonical=False \
    --vqvae_aug_augmentation_probability=0.00 \
    --vqvae_aug_augmentation_strength=0.0 \
    --vqvae_aug_normalize=True \
    --vqvae_aug_roi='((16,176), (16,240),(96,256))' \
    --vqvae_network='baseline_vqvae' \
    --vqvae_net_level=0 \
    --vqvae_net_use_subpixel_conv=False \
    --vqvae_net_use_slim_residual=True \
    --vqvae_net_no_levels=4 \
    --vqvae_net_downsample_parameters='((4,2,1,1),(4,2,1,1),(4,2,1,1),(4,2,1,1))' \
    --vqvae_net_upsample_parameters='((4,2,1,0,1),(4,2,1,0,1),(4,2,1,0,1),(4,2,1,0,1))' \
    --vqvae_net_no_res_layers=3 \
    --vqvae_net_no_channels=256 \
    --vqvae_net_codebook_type='ema' \
    --vqvae_net_num_embeddings='(2048,)' \
    --vqvae_net_embedding_dim='(32,)' \
    --vqvae_net_embedding_init='(\"normal\",)' \
    --vqvae_net_commitment_cost='(0.25, )' \
    --vqvae_net_decay='(0.5,)' \
    --vqvae_net_dropout=0.0 \
    --vqvae_net_act='RELU' \
    --starting_epoch=0 \
    --ordering_type='raster_scan' \
    --transpositions_axes='((2, 0, 1),)' \
    --rot90_axes='((0, 1),)' \
    --transformation_order='(\"rotate_90\", \"transpose\")' \
    --network='xtransformer' \
    --vocab_size=2048 \
    --n_embd=1024 \
    --n_layers=36 \
    --n_head=16 \
    --tie_embedding=False \
    --ff_glu=False \
    --emb_dropout=0.001 \
    --ff_dropout=0.001 \
    --attn_dropout=0.001 \
    --use_rezero=False \
    --position_emb='rotary' \
    --conditioning_type='cross_attend' \
    --use_continuous_conditioning='(True, True, True, True)' \
    --local_attn_heads=8 \
    --local_window_size=420 \
    --feature_redraw_interval=1 \
    --generalized_attention=False \
    --use_rmsnorm=True \
    --attn_talking_heads=False \
    --attn_on_attn=False \
    --attn_gate_values=True \
    --sandwich_norm=False \
    --rel_pos_bias=False \
    --use_qk_norm_attn=False \
    --spatial_rel_pos_bias=True \
    --bucket_values=False \
    --shift_mem_down=1
```
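
Likewise, assuming the transformer command above is saved as a script (hypothetically `train_transformer.sh`), sampling reuses it with only the mode changed:

```bash
# Switch the saved training command to inference mode to draw new samples.
sed "s/--mode='training'/--mode='inference'/" train_transformer.sh > sample_transformer.sh
bash sample_transformer.sh
```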

# Acknowledgements

This work was done through a collaboration between NVIDIA and KCL.

The models in this work were trained on [NVIDIA Cambridge-1](https://www.nvidia.com/en-us/industries/healthcare-life-sciences/cambridge-1/), the UK’s largest supercomputer, aimed at accelerating digital biology.

# Funding
- Jointly with UCL - Wellcome Flagship Programme (WT213038/Z/18/Z)
- Wellcome/EPSRC Centre for Medical Engineering (WT203148/Z/16/Z)
- EPSRC Research Council DTP (EP/R513064/1)
- The London AI Center for Value-Based Healthcare
- GE Healthcare
- Intramural Research Program of the NIMH (ZIC-MH002960 and ZIC-MH002968)
- European Union’s HORIZON 2020 Research and Innovation Programme under the Marie Sklodowska-Curie Grant Agreement No 814302
- UCLH NIHR Biomedical Research Centre