Skip to content

Commit

Permalink
added option to disable test-time augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
akhanf committed Feb 1, 2021
1 parent 41a129a commit 60cbbbb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 5 additions & 0 deletions hippunfold/config/snakebids.yml
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ parse_args:
default: False
action: 'store_true'

--nnunet_disable_tta:
help: 'Disable test-time augmentation for nnU-net inference, speeds up inference by 8x, at expense of accuracy (default: %(default)s)'
default: False
action: 'store_true'


#--- workflow specific configuration --

Expand Down
3 changes: 2 additions & 1 deletion hippunfold/workflow/rules/nnunet.smk
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ rule run_inference:
out_folder = 'templbl',
task = parse_task_from_tar,
chkpnt = parse_chkpnt_from_tar,
disable_tta = '' if config['nnunet_disable_tta'] else '--disable_tta'
output:
nnunet_seg = bids(root='work',datatype='seg_{modality}',**config['subj_wildcards'],suffix='dseg.nii.gz',desc='nnunet',space='corobl',hemi='{hemi,Lflip|R}')
shadow: 'minimal'
Expand All @@ -64,7 +65,7 @@ rule run_inference:
'tar -xvf {input.model_tar} -C {params.model_dir} && ' #extract model
'export RESULTS_FOLDER={params.model_dir} && ' #set nnunet env var to point to model
'export nnUNet_n_proc_DA={threads} && ' #set threads
'nnUNet_predict -i {params.in_folder} -o {params.out_folder} -t {params.task} -chk {params.chkpnt} && ' # run inference
'nnUNet_predict -i {params.in_folder} -o {params.out_folder} -t {params.task} -chk {params.chkpnt} {params.disable_tta} && ' # run inference
'cp -v {params.temp_lbl} {output.nnunet_seg}' #copy from temp output folder to final output


Expand Down

0 comments on commit 60cbbbb

Please sign in to comment.