
Implement TTA Batch Processing to Improve Inference Speed #2153

Open · wants to merge 4 commits into master

Conversation

PengchengShi1220

Summary:
This PR proposes integrating Test Time Augmentation (TTA) with batch processing in nnUNet to improve inference efficiency; the benefit is most evident on larger 3D datasets. Speed improvements of 5-8% were measured, with results validated on the AMOS2022 dataset.
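Conceptually, instead of running one forward pass per mirrored copy of each patch, the mirrored copies are stacked along the batch dimension and pushed through the network in chunks. A simplified sketch of the idea (illustrative only, not the exact code in this PR; `network`, `patch`, and `tta_batch_size` are placeholder names):

```python
import itertools
import torch

@torch.no_grad()
def batched_mirror_prediction(network, patch, mirror_axes=(0, 1, 2), tta_batch_size=4):
    """patch has shape (1, C, X, Y, Z); mirror_axes index the spatial dims, so axis 0 maps to tensor dim 2."""
    # All subsets of the mirror axes, including the empty set (= no flip).
    flip_combos = [()]
    for r in range(1, len(mirror_axes) + 1):
        flip_combos += [tuple(a + 2 for a in c)
                        for c in itertools.combinations(mirror_axes, r)]

    # One flipped copy of the patch per combination.
    inputs = [patch if not axes else torch.flip(patch, axes) for axes in flip_combos]

    prediction = None
    # Run the copies through the network tta_batch_size at a time instead of one by one.
    for start in range(0, len(inputs), tta_batch_size):
        chunk_combos = flip_combos[start:start + tta_batch_size]
        out = network(torch.cat(inputs[start:start + tta_batch_size], dim=0))
        for i, axes in enumerate(chunk_combos):
            pred_i = out[i:i + 1]
            if axes:
                pred_i = torch.flip(pred_i, axes)  # undo the flip before accumulating
            prediction = pred_i if prediction is None else prediction + pred_i

    return prediction / len(flip_combos)
```

A larger tta_batch_size means fewer network calls but a larger activation footprint, which is where the VRAM numbers below come from.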

Implementation Details:

  • Python Version: 3.11
  • PyTorch Version: 2.2.2+cu121
  • Model: patch size [64, 160, 192], run on an NVIDIA RTX 3090 with 24 GB VRAM.

Results:

  • Mirror Axes (0, 1, 2):

    | ID        | TTA No Batch (s) | TTA Batch Size=2 (s) | TTA Batch Size=4 (s) | TTA Batch Size=8 (s) |
    |-----------|------------------|----------------------|----------------------|----------------------|
    | amos_0247 | 51               | 49                   | 48                   | 49                   |
    | amos_0111 | 80               | 76                   | 74                   | 75                   |
    | amos_0173 | 129              | 122                  | 119                  | 120                  |
  • Mirror Axes (0, 1):

    | ID        | TTA No Batch (s) | TTA Batch Size=2 (s) | TTA Batch Size=4 (s) |
    |-----------|------------------|----------------------|----------------------|
    | amos_0247 | 26               | 24                   | 25                   |
    | amos_0111 | 40               | 38                   | 38                   |
    | amos_0173 | 64               | 61                   | 60                   |
  • Mirror Axis (0):

    | ID        | TTA No Batch (s) | TTA Batch Size=2 (s) |
    |-----------|------------------|----------------------|
    | amos_0247 | 13               | 12                   |
    | amos_0111 | 20               | 19                   |
    | amos_0173 | 32               | 31                   |

VRAM Usage:

  • Detailed VRAM Consumption by TTA Batch Size:
    | Batch Size | VRAM (GB), Axes (0, 1, 2) | VRAM (GB), Axes (0, 1) | VRAM (GB), Axis (0) |
    |------------|---------------------------|------------------------|---------------------|
    | 1          | 7.57                      | 7.57                   | 7.57                |
    | 2          | 9.24                      | 9.24                   | 9.24                |
    | 4          | 12.37                     | 12.37                  | -                   |
    | 8          | 17.04                     | -                      | -                   |

Recommendations:

  • Use TTA Batch Size=4 for configurations with three Mirror Axes (0, 1, 2).
  • Use TTA Batch Size=2 for configurations with fewer Mirror Axes.
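
These two recommendations can be captured in a small helper; this is purely illustrative and not part of the PR:

```python
def default_tta_batch_size(num_mirror_axes: int) -> int:
    """Rule-of-thumb TTA batch size from the number of mirror axes (illustrative only)."""
    # Three mirror axes produce 8 mirrored copies per patch; batch size 4 fits in 24 GB.
    # With fewer axes there are at most 4 copies, so batch size 2 is sufficient.
    return 4 if num_mirror_axes >= 3 else 2
```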

The TTA batch processing approach has been thoroughly tested on the AMOS2022 dataset, producing results consistent with the original (non-batched) setup.

FabianIsensee self-assigned this on May 2, 2024
@FabianIsensee
Member

Hi,
thanks for the contribution + extensive benchmarking! That helps a lot in seeing the value!
If you would like us to include this, please make it an optional parameter people can set when calling nnUNetv2_predict. This should also (like all the other parameters) be set in the init of the nnUNetPredictor class.
The reason I want this to be optional is twofold:

  • sometimes we just don't have the VRAM to justify doing that
  • in case of limited VRAM, there are other VRAM-hungry features (perform_everything_on_device) that are more impactful for inference throughput and should be prioritized over batching TTA

Best,
Fabian

@PengchengShi1220
Author

Hi Fabian,

Thanks for your feedback! Based on your suggestions, I have made "use_batch_tta" an optional parameter in the nnUNetPredictor class, controlled via the parser argument "disable_batch_tta". This allows users to opt in or out of batch TTA based on their VRAM capacity and priorities.
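
In rough terms the wiring looks like this (a simplified sketch of the intent rather than the exact diff; the real nnUNetPredictor __init__ takes many more parameters, and the flag spelling shown here is illustrative):

```python
import argparse

class nnUNetPredictor:  # simplified stub; the real class has many more parameters
    def __init__(self, perform_everything_on_device: bool = True,
                 use_batch_tta: bool = True, verbose: bool = False):
        # When False, the predictor falls back to the original
        # one-forward-pass-per-mirror behaviour.
        self.perform_everything_on_device = perform_everything_on_device
        self.use_batch_tta = use_batch_tta
        self.verbose = verbose

def predict_entry_point(argv=None):
    parser = argparse.ArgumentParser()
    # Opt-out flag: batched TTA stays enabled unless the user disables it.
    parser.add_argument('--disable_batch_tta', action='store_true',
                        help='Disable batched TTA (useful when VRAM is tight).')
    args = parser.parse_args(argv)
    return nnUNetPredictor(use_batch_tta=not args.disable_batch_tta)
```

So running nnUNetv2_predict without the flag keeps batched TTA on, and passing the disable flag restores the original per-mirror behaviour.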

Please let me know if further adjustments are required.

Best,
Pengcheng
