diff --git a/hippunfold/config/snakebids.yml b/hippunfold/config/snakebids.yml index 295b46d3..a350a176 100644 --- a/hippunfold/config/snakebids.yml +++ b/hippunfold/config/snakebids.yml @@ -147,8 +147,17 @@ parse_args: default: False action: 'store_true' + --use_gpu: + help: 'Enable gpu for inference by setting resource gpus=1 in run_inference rule (default: %(default)s)' + default: False + action: 'store_true' + + --nnunet_disable_tta: + help: 'Disable test-time augmentation for nnU-net inference, speeds up inference by 8x, at expense of accuracy (default: %(default)s)' + default: False + action: 'store_true' - + #--- workflow specific configuration -- singularity: diff --git a/hippunfold/workflow/rules/nnunet.smk b/hippunfold/workflow/rules/nnunet.smk index dee8ee59..01d37a96 100644 --- a/hippunfold/workflow/rules/nnunet.smk +++ b/hippunfold/workflow/rules/nnunet.smk @@ -59,21 +59,22 @@ rule run_inference: out_folder = 'templbl', task = parse_task_from_tar, chkpnt = parse_chkpnt_from_tar, + disable_tta = '' if config['nnunet_disable_tta'] else '--disable_tta' output: nnunet_seg = bids(root='work',datatype='seg_{modality}',**config['subj_wildcards'],suffix='dseg.nii.gz',desc='nnunet',space='corobl',hemi='{hemi,Lflip|R}') shadow: 'minimal' - threads: 16 + threads: 16 resources: - gpus = 1, + gpus = 1 if config['use_gpu'] else 0, mem_mb = 32000, - time = 30, + time = 30 if config['use_gpu'] else 60, group: 'subj' shell: 'mkdir -p {params.model_dir} {params.in_folder} {params.out_folder} && ' #create temp folders 'cp -v {input.in_img} {params.temp_img} && ' #cp input image to temp folder 'tar -xvf {input.model_tar} -C {params.model_dir} && ' #extract model 'export RESULTS_FOLDER={params.model_dir} && ' #set nnunet env var to point to model 'export nnUNet_n_proc_DA={threads} && ' #set threads - 'nnUNet_predict -i {params.in_folder} -o {params.out_folder} -t {params.task} -chk {params.chkpnt} && ' # run inference + 'nnUNet_predict -i {params.in_folder} -o {params.out_folder} -t {params.task} -chk {params.chkpnt} {params.disable_tta} && ' # run inference 'cp -v {params.temp_lbl} {output.nnunet_seg}' #copy from temp output folder to final output diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index fe4abfbb..00000000 --- a/requirements.txt +++ /dev/null @@ -1,44 +0,0 @@ -amply==0.1.4 -appdirs==1.4.4 -attrs==20.3.0 -bids-validator==1.5.8 -certifi==2020.12.5 -chardet==3.0.4 -click==7.1.2 -ConfigArgParse==1.2.3 -datrie==0.8.2 -docopt==0.6.2 -docutils==0.16 -gitdb==4.0.5 -GitPython==3.1.11 -idna==2.10 -ipython-genutils==0.2.0 -jsonschema==3.2.0 -jupyter-core==4.7.0 -nbformat==5.0.8 -nibabel==3.2.1 -num2words==0.5.10 -numpy==1.19.4 -packaging==20.8 -pandas==1.1.5 -patsy==0.5.1 -psutil==5.7.3 -PuLP==2.3.1 -pybids==0.12.3 -pyparsing==2.4.7 -pyrsistent==0.17.3 -python-dateutil==2.8.1 -pytz==2020.4 -PyYAML==5.3.1 -ratelimiter==1.2.0.post0 -requests==2.25.0 -scipy==1.5.4 -six==1.15.0 -smmap==3.0.4 -snakebids==0.2.0 -snakemake==5.30.1 -SQLAlchemy==1.3.20 -toposort==1.5 -traitlets==5.0.5 -urllib3==1.26.2 -wrapt==1.12.1 diff --git a/setup.py b/setup.py index 6bf9b982..f865abdc 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ install_requires=[ "snakebids>=0.2.0", "snakemake>=5.28.0", - "nnunet>=1.6.6", + "nnunet @ git+https://github.com/ylugithub/nnUNet.git@v1.6.6", "appdirs", "pandas", "nibabel",