Commit 6a46390

Publication
1 parent cb32977 commit 6a46390

137 files changed: +55862 -1 lines changed


README.md (+195 -1)
# BrainSynth [![DOI](https://zenodo.org/badge/706209780.svg)](https://zenodo.org/doi/10.5281/zenodo.10014960)

Official implementation of "Realistic Morphology-preserving Generative Modelling of the Brain"

-# Code will soon be uploaded

# Pretrained models

Experiment scripts are provided in the experiments folder and are named after the models they train.

Below are some toy examples showing how to use the pretrained models.

To use the pretrained models you need to do the following:

0) Create a docker container based on the Dockerfile and requirements file found in the docker folder.
1) Create a folder with the following structure (sketched after this list), where you replace 'experiment_name' with the name of your experiment and choose either baseline_vqvae or performer depending on which weights you want to use:
```
<<experiment_name>>
    ├── baseline_vqvae/performer
        ├── checkpoints
        ├── logs
        └── outputs
```
2) Download the weights of the desired model from here (weights are being uploaded) and put them in the checkpoints folder.
3) Rename the file to 'checkpoint_epoch=0.pt'.
4) Use the corresponding script from the examples below and remember to:
* Replace the training/validation subjects with paths to either a folder of .nii.gz files or to csv/tsv files that have a path column holding the full paths to the files.
* Replace the conditioning file with the correct one for the transformer training.
* Replace the project_directory with the path where you created the folder from point 1.
* Replace the experiment_name with the name of the experiment you created at point 1.
5) Properly mount the paths to the data and results folders and launch your docker container.
6) Use the appropriate script for the model from below and change the mode to the desired one.
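A minimal shell sketch of steps 1-3 and of a subjects file with the required path column is shown below. The experiment name, weights URL, and .nii.gz paths are placeholders (the actual weight links are still being uploaded):

```bash
# Steps 1-3 as commands (sketch; every name and the URL are placeholders).
EXPERIMENT=example_run
mkdir -p /results/${EXPERIMENT}/baseline_vqvae/{checkpoints,logs,outputs}

# Steps 2-3: download the desired weights and rename them for the runner scripts.
wget -O "/results/${EXPERIMENT}/baseline_vqvae/checkpoints/checkpoint_epoch=0.pt" \
    https://example.com/brainsynth/baseline_vqvae.pt

# A subjects tsv only needs a 'path' column with the full paths to the images.
cat > /data/training_subjects.tsv <<'EOF'
path
/data/sub-001_T1w.nii.gz
/data/sub-002_T1w.nii.gz
EOF
```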

# VQ-VAE

To extract the quantized latent representations of the images, run the same command you used for training and replace the `--mode='training'` parameter with `--mode='extracting'`. For decoding, replace it with `--mode='decoding'`.

Training script example for VQ-VAE:
```bash
python /project/run_vqvae.py run \
  --training_subjects="/path/to/training/data/tsv/" \
  --validation_subjects="/path/to/validation/data/tsv/" \
  --load_nii_canonical=False \
  --project_directory="/results/" \
  --experiment_name="example_run" \
  --mode='training' \
  --device='ddp' \
  --distributed_port=29500 \
  --amp=True \
  --deterministic=False \
  --cuda_benchmark=True \
  --seed=4 \
  --epochs=500 \
  --learning_rate=0.000165 \
  --gamma=0.99999 \
  --log_every=1 \
  --checkpoint_every=1 \
  --eval_every=1 \
  --loss='jukebox_perceptual' \
  --adversarial_component=True \
  --discriminator_network='baseline_discriminator' \
  --discriminator_learning_rate=5e-05 \
  --discriminator_loss='least_square' \
  --generator_loss='least_square' \
  --initial_factor_value=0 \
  --initial_factor_steps=25 \
  --max_factor_steps=50 \
  --max_factor_value=5 \
  --batch_size=8 \
  --normalize=True \
  --roi='((16,176),(16,240),(96,256))' \
  --eval_batch_size=8 \
  --num_workers=8 \
  --prefetch_factor=8 \
  --starting_epoch=172 \
  --network='baseline_vqvae' \
  --use_subpixel_conv=False \
  --use_slim_residual=True \
  --no_levels=4 \
  --downsample_parameters='((4,2,1,1),(4,2,1,1),(4,2,1,1),(4,2,1,1))' \
  --upsample_parameters='((4,2,1,0,1),(4,2,1,0,1),(4,2,1,0,1),(4,2,1,0,1))' \
  --no_res_layers=3 \
  --no_channels=256 \
  --codebook_type='ema' \
  --num_embeddings='(2048,)' \
  --embedding_dim='(32,)' \
  --decay='(0.5,)' \
  --commitment_cost='(0.25,)' \
  --max_decay_epochs=100 \
  --dropout=0.0 \
  --act='RELU'
```
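
For example, assuming the command above was saved to a file (the name train_vqvae.sh is hypothetical), switching modes is a one-line swap:

```bash
# Re-run the identical command with only the mode flag changed (sketch).
sed "s/--mode='training'/--mode='extracting'/" train_vqvae.sh | bash  # extract latents
sed "s/--mode='training'/--mode='decoding'/" train_vqvae.sh | bash    # decode latents
```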

# Transformer

To sample new images from the trained model, run the same command you used for training and replace the `--mode='training'` parameter with `--mode='inference'`.

Best performance was found by equalising the normalised continuous conditioning variables.
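
As a sketch of one way to do this (the file conditioning.tsv and the column age are hypothetical), a rank transform flattens a normalised continuous variable to a uniform distribution:

```bash
# Rank-equalise a continuous conditioning column to uniform (0, 1)
# (hypothetical file/column names; one possible equalisation scheme).
python3 - <<'PY'
import pandas as pd

df = pd.read_csv("conditioning.tsv", sep="\t")
ranks = df["age"].rank(method="average")   # ranks 1..N
df["age"] = (ranks - 0.5) / len(df)        # map ranks to uniform (0, 1)
df.to_csv("conditioning_equalised.tsv", sep="\t", index=False)
PY
```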

Training script example for the Transformer, based on the UKB one:
```bash
python3 /project/run_transformer.py run \
  --training_subjects='/path/to/training/data/tsv/' \
  --validation_subjects='/path/to/validation/data/tsv/' \
  --conditioning_path='/path/to/continuous/equalised/tsv/' \
  --conditionings='(\"used\", \"conditioning\", \"columns\")' \
  --project_directory='/results/' \
  --experiment_name='example_run' \
  --mode='training' \
  --deterministic=False \
  --cuda_benchmark=False \
  --cuda_enable=True \
  --use_zero=True \
  --device='ddp' \
  --seed=4 \
  --epochs=500 \
  --learning_rate=0.0005 \
  --gamma='auto' \
  --log_every=1 \
  --checkpoint_every=1 \
  --eval_every=0 \
  --weighted_sampling=True \
  --batch_size=2 \
  --eval_batch_size=2 \
  --num_workers=16 \
  --prefetch_factor=16 \
  --vqvae_checkpoint='/path/to/vqvae/checkpoint/' \
  --vqvae_aug_conditionings='none' \
  --vqvae_aug_load_nii_canonical=False \
  --vqvae_aug_augmentation_probability=0.00 \
  --vqvae_aug_augmentation_strength=0.0 \
  --vqvae_aug_normalize=True \
  --vqvae_aug_roi='((16,176),(16,240),(96,256))' \
  --vqvae_network='baseline_vqvae' \
  --vqvae_net_level=0 \
  --vqvae_net_use_subpixel_conv=False \
  --vqvae_net_use_slim_residual=True \
  --vqvae_net_no_levels=4 \
  --vqvae_net_downsample_parameters='((4,2,1,1),(4,2,1,1),(4,2,1,1),(4,2,1,1))' \
  --vqvae_net_upsample_parameters='((4,2,1,0,1),(4,2,1,0,1),(4,2,1,0,1),(4,2,1,0,1))' \
  --vqvae_net_no_res_layers=3 \
  --vqvae_net_no_channels=256 \
  --vqvae_net_codebook_type='ema' \
  --vqvae_net_num_embeddings='(2048,)' \
  --vqvae_net_embedding_dim='(32,)' \
  --vqvae_net_embedding_init='(\"normal\",)' \
  --vqvae_net_commitment_cost='(0.25,)' \
  --vqvae_net_decay='(0.5,)' \
  --vqvae_net_dropout=0.0 \
  --vqvae_net_act='RELU' \
  --starting_epoch=0 \
  --ordering_type='raster_scan' \
  --transpositions_axes='((2, 0, 1),)' \
  --rot90_axes='((0, 1),)' \
  --transformation_order='(\"rotate_90\", \"transpose\")' \
  --network='xtransformer' \
  --vocab_size=2048 \
  --n_embd=1024 \
  --n_layers=36 \
  --n_head=16 \
  --tie_embedding=False \
  --ff_glu=False \
  --emb_dropout=0.001 \
  --ff_dropout=0.001 \
  --attn_dropout=0.001 \
  --use_rezero=False \
  --position_emb='rotary' \
  --conditioning_type='cross_attend' \
  --use_continuous_conditioning='(True, True, True, True)' \
  --local_attn_heads=8 \
  --local_window_size=420 \
  --feature_redraw_interval=1 \
  --generalized_attention=False \
  --use_rmsnorm=True \
  --attn_talking_heads=False \
  --attn_on_attn=False \
  --attn_gate_values=True \
  --sandwich_norm=False \
  --rel_pos_bias=False \
  --use_qk_norm_attn=False \
  --spatial_rel_pos_bias=True \
  --bucket_values=False \
  --shift_mem_down=1
```
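
As with the VQ-VAE, sampling reuses the training command; assuming it was saved as train_transformer.sh (hypothetical filename):

```bash
# Swap only the mode flag to sample new images from the trained model (sketch).
sed "s/--mode='training'/--mode='inference'/" train_transformer.sh | bash
```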
182+
183+
# Acknowledgements
184+
185+
Work done through the collaboration between NVIDIA and KCL.
186+
187+
The models in this work were trained on [NVIDIA Cambridge-1](https://www.nvidia.com/en-us/industries/healthcare-life-sciences/cambridge-1/), the UK’s largest supercomputer, aimed at accelerating digital biology.
188+
189+
# Funding
190+
- Jointly with UCL - Wellcome Flagship Programme (WT213038/Z/18/Z)
191+
- Wellcome/EPSRC Centre for Medical Engineering (WT203148/Z/16/Z)
192+
- EPSRC Research Council DTP (EP/R513064/1)
193+
- The London AI Center for Value-Based Healthcare
194+
- GE Healthcare
195+
- Intramural Research Program of the NIMH (ZIC-MH002960 and ZIC-MH002968).
196+
- European Union’s HORIZON 2020 Research
197+
- Innovation Programme under the Marie Sklodowska-Curie Grant Agreement No 814302
198+
- UCLH NIHR Biomedical Research Centre.

docker/Dockerfile (+48)
FROM nvcr.io/nvidia/pytorch:21.05-py3

# The commented lines are meant for internal use by KCL BMEIS staff
#ARG USER_ID
#ARG GROUP_ID
#ARG USER

ENV TZ=Europe/London

#RUN addgroup --gid $GROUP_ID $USER
#RUN adduser --disabled-password --gecos '' --uid $USER_ID --gid $GROUP_ID $USER

# This is required for cases where the sqsh file is run without internet connection
RUN rm -rf ~/.cache
RUN rm -rf /root/.cache
RUN mkdir /cache_dir
RUN ln -s /cache_dir ~/.cache
RUN ln -s /cache_dir /root/.cache
ENV XDG_CACHE_HOME=/cache_dir
ENV TORCH_HOME=/cache_dir
ENV MPLCONFIGDIR=/cache_dir

RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

RUN apt-get update && apt-get install -y sudo
RUN pip3 install --upgrade pip
RUN apt-get install -y ffmpeg

RUN pip install -U --no-cache-dir git+https://github.com/idiap/fast-transformers.git@c99d771cdff096ce44336e06d9fcf2fe163b7626
COPY ./requirements.txt .
RUN pip3 install -r requirements.txt

# This is required for cases where the sqsh file is run without internet connection
RUN python3 -c "import lpips;lpips.LPIPS(net='alex')"

ENV BUILD_MONAI=1
RUN wget https://github.com/Project-MONAI/MONAI/archive/0aa936f87a694a66d54e514ec823a37e999be862.zip && \
    unzip 0aa936f87a694a66d54e514ec823a37e999be862.zip && \
    rm 0aa936f87a694a66d54e514ec823a37e999be862.zip && \
    cd MONAI-0aa936f87a694a66d54e514ec823a37e999be862 && \
    python3 setup.py develop

#USER $USER

#RUN source /home/$USER/.bashrc
#RUN source /home/$USER/.profile

ENTRYPOINT ["/bin/bash"]
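
For steps 0 and 5 of the setup, a minimal sketch of building this image and launching a container with the data and results folders mounted (the image tag and host paths are placeholders):

```bash
# Build from the docker folder, then run with GPU access and mounts (sketch).
docker build -t brainsynth docker/
docker run --rm -it --gpus all --ipc=host \
  -v /host/path/to/data:/data \
  -v /host/path/to/results:/results \
  -v /host/path/to/BrainSynth:/project \
  brainsynth
```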

docker/requirements.txt (+13)
pytorch-ignite==0.4.4
fire==0.4.0
lpips==0.1.4
pytorch_msssim==0.2.1
nilearn==0.9.0
tensorboard==2.8.0
pandas==1.1.5
nibabel==3.2.2
moviepy==1.0.3
performer-pytorch==1.1.4
deepspeed==0.4.2
mpi4py==3.1.3
x-transformers==0.25.4
